Subversion Repositories dashGPS

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. /*
  2.  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
  3.  *
  4.  * SPDX-License-Identifier: Apache-2.0
  5.  *
  6.  * Licensed under the Apache License, Version 2.0 (the License); you may
  7.  * not use this file except in compliance with the License.
  8.  * You may obtain a copy of the License at
  9.  *
  10.  * www.apache.org/licenses/LICENSE-2.0
  11.  *
  12.  * Unless required by applicable law or agreed to in writing, software
  13.  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14.  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15.  * See the License for the specific language governing permissions and
  16.  * limitations under the License.
  17.  */
  18.  
  19. /* ----------------------------------------------------------------------
  20.  * Project:      CMSIS NN Library
  21.  * Title:        arm_fully_connected_mat_q7_vec_q15_opt.c
  22.  * Description:  Mixed Q15-Q7 opt fully-connected layer function
  23.  *
  24.  * $Date:        17. January 2018
  25.  * $Revision:    V.1.0.0
  26.  *
  27.  * Target Processor:  Cortex-M cores
  28.  *
  29.  * -------------------------------------------------------------------- */
  30.  
  31. #include "arm_math.h"
  32. #include "arm_nnfunctions.h"
  33.  
  34. /**
  35.  *  @ingroup groupNN
  36.  */
  37.  
  38. /**
  39.  * @addtogroup FC
  40.  * @{
  41.  */
  42.  
  43.   /**
  44.    * @brief Mixed Q15-Q7 opt fully-connected layer function
  45.    * @param[in]       pV          pointer to input vector
  46.    * @param[in]       pM          pointer to matrix weights
  47.    * @param[in]       dim_vec     length of the vector
  48.    * @param[in]       num_of_rows number of rows in weight matrix
  49.    * @param[in]       bias_shift  amount of left-shift for bias
  50.    * @param[in]       out_shift   amount of right-shift for output
  51.    * @param[in]       bias        pointer to bias
  52.    * @param[in,out]   pOut        pointer to output vector
  53.    * @param[in,out]   vec_buffer  pointer to buffer space for input
  54.    * @return     The function returns <code>ARM_MATH_SUCCESS</code>
  55.    *
  56.    * @details
  57.    *
  58.    * <b>Buffer size:</b>
  59.    *
  60.    * vec_buffer size: 0
  61.    *
  62.    *  Q7_Q15 version of the fully connected layer
  63.    *
  64.    *  Weights are in q7_t and Activations are in q15_t
  65.    *
  66.    *  Limitation: x4 version requires weight reordering to work
  67.    *
  68.    *  Here we use only one pointer to read 4 rows in the weight
  69.    *  matrix. So if the original q7_t matrix looks like this:
  70.    *
  71.    *  | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
  72.    *
  73.    *  | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
  74.    *
  75.    *  | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
  76.    *
  77.    *  | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
  78.    *
  79.    *  | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
  80.    *
  81.    *  | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
  82.    *
  83.    *  We operates on multiple-of-4 rows, so the first four rows becomes
  84.    *
  85.    *  | a11 | a21 | a12 | a22 | a31 | a41 | a32 | a42 |
  86.    *
  87.    *  | a13 | a23 | a14 | a24 | a33 | a43 | a34 | a44 |
  88.    *
  89.    *  | a15 | a25 | a16 | a26 | a35 | a45 | a36 | a46 |
  90.    *
  91.    *  The column left over will be in-order.
  92.    *  which is:
  93.    *  | a17 | a27 | a37 | a47 |
  94.    *
  95.    *  For the left-over rows, we do 1x1 computation, so the data remains
  96.    *  as its original order.
  97.    *
  98.    *  So the stored weight matrix looks like this:
  99.    *
  100.    *  | a11 | a21 | a12 | a22 | a31 | a41 |
  101.    *
  102.    *  | a32 | a42 | a13 | a23 | a14 | a24 |
  103.    *
  104.    *  | a33 | a43 | a34 | a44 | a15 | a25 |
  105.    *
  106.    *  | a16 | a26 | a35 | a45 | a36 | a46 |
  107.    *
  108.    *  | a17 | a27 | a37 | a47 | a51 | a52 |
  109.    *
  110.    *  | a53 | a54 | a55 | a56 | a57 | a61 |
  111.    *
  112.    *  | a62 | a63 | a64 | a65 | a66 | a67 |
  113.    *
  114.    */
  115.  
  116. arm_status
  117. arm_fully_connected_mat_q7_vec_q15_opt(const q15_t * pV,
  118.                                        const q7_t * pM,
  119.                                        const uint16_t dim_vec,
  120.                                        const uint16_t num_of_rows,
  121.                                        const uint16_t bias_shift,
  122.                                        const uint16_t out_shift, const q7_t * bias, q15_t * pOut, q15_t * vec_buffer)
  123. {
  124.  
  125. #if defined (ARM_MATH_DSP)
  126.     /* Run the following code for Cortex-M4 and Cortex-M7 */
  127.  
  128.     const q7_t *pB = pM;
  129.     q15_t    *pO = pOut;
  130.     const q7_t *pBias = bias;
  131.     const q15_t *pA = pV;
  132.  
  133.     uint16_t  rowCnt = num_of_rows >> 2;
  134.  
  135.     while (rowCnt)
  136.     {
  137.         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  138.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  139.         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  140.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  141.  
  142.         uint16_t  colCnt = dim_vec >> 1;
  143.  
  144.         pA = pV;
  145.  
  146. #ifdef USE_INTRINSIC
  147.  
  148. #ifndef ARM_MATH_BIG_ENDIAN
  149.  
  150.         while (colCnt)
  151.         {
  152.             q31_t     inM11, inM12, inM13, inM14;
  153.             q31_t     inV;
  154.  
  155.             inV = *__SIMD32(pA)++;
  156.             inM11 = *__SIMD32(pB)++;
  157.             inM12 = __SXTB16(__ROR(inM11, 8));
  158.             inM11 = __SXTB16(inM11);
  159.             sum = __SMLAD(inM11, inV, sum);
  160.             sum2 = __SMLAD(inM12, inV, sum2);
  161.             inM13 = *__SIMD32(pB)++;
  162.             inM14 = __SXTB16(__ROR(inM13, 8));
  163.             inM13 = __SXTB16(inM13);
  164.             sum3 = __SMLAD(inM13, inV, sum3);
  165.             sum4 = __SMLAD(inM14, inV, sum4);
  166.             colCnt--;
  167.         }
  168.  
  169. #else
  170.  
  171.         while (colCnt)
  172.         {
  173.             q31_t     inM11, inM12, inM13, inM14;
  174.             q31_t     inV;
  175.  
  176.             inV = *__SIMD32(pA)++;
  177.             inM11 = *__SIMD32(pB)++;
  178.             inM12 = __SXTB16(__ROR(inM11, 8));
  179.             inM11 = __SXTB16(inM11);
  180.             sum = __SMLAD(inM12, inV, sum);
  181.             sum2 = __SMLAD(inM11, inV, sum2);
  182.             inM13 = *__SIMD32(pB)++;
  183.             inM14 = __SXTB16(__ROR(inM13, 8));
  184.             inM13 = __SXTB16(inM13);
  185.             sum3 = __SMLAD(inM14, inV, sum3);
  186.             sum4 = __SMLAD(inM13, inV, sum4);
  187.             colCnt--;
  188.         }
  189.  
  190. #endif                          /* ARM_MATH_BIG_ENDIAN */
  191.  
  192. #else
  193.  
  194.         /*
  195.          * register needed:
  196.          * loop counter: colCnt
  197.          * accumulators: sum, sum2, sum3, sum4
  198.          * pointers: pB, pA
  199.          * weight data: inM11, inM12, inM13, inM14
  200.          * activation data: inV
  201.          */
  202.  
  203. #ifndef ARM_MATH_BIG_ENDIAN
  204.         asm volatile ("COL_LOOP_%=:\n"
  205.                       "ldr.w r4, [%[pA]], #4\n"
  206.                       "ldr.w r1, [%[pB]], #8\n"
  207.                       "mov.w r0, r1, ror #8\n"
  208.                       "sxtb16 r0, r0\n"
  209.                       "sxtb16 r1, r1\n"
  210.                       "smlad %[sum], r4, r1, %[sum]\n"
  211.                       "smlad %[sum2], r4, r0, %[sum2]\n"
  212.                       "ldr.w r3, [%[pB], #-4]\n"
  213.                       "mov.w r2, r3, ror #8\n"
  214.                       "sxtb16 r2, r2\n"
  215.                       "sxtb16 r3, r3\n"
  216.                       "smlad %[sum3], r4, r3, %[sum3]\n"
  217.                       "smlad %[sum4], r4, r2, %[sum4]\n"
  218.                       "subs %[colCnt], #1\n"
  219.                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  220.                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  221.                       [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  222. #else
  223.         asm volatile ("COL_LOOP_%=:\n"
  224.                       "ldr.w r4, [%[pA]], #4\n"
  225.                       "ldr.w r1, [%[pB]], #8\n"
  226.                       "mov.w r0, r1, ror #8\n"
  227.                       "sxtb16 r0, r0\n"
  228.                       "sxtb16 r1, r1\n"
  229.                       "smlad %[sum], r4, r0, %[sum]\n"
  230.                       "smlad %[sum2], r4, r1, %[sum2]\n"
  231.                       "ldr.w r3, [%[pB], #-4]\n"
  232.                       "mov.w r2, r3, ror #8\n"
  233.                       "sxtb16 r2, r2\n"
  234.                       "sxtb16 r3, r3\n"
  235.                       "smlad %[sum3], r4, r2, %[sum3]\n"
  236.                       "smlad %[sum4], r4, r3, %[sum4]\n"
  237.                       "subs %[colCnt], #1\n"
  238.                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  239.                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  240.                       [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  241. #endif                          /* ARM_MATH_BIG_ENDIAN */
  242.  
  243. #endif                          /* USE_INTRINSIC */
  244.  
  245.         colCnt = dim_vec & 0x1;
  246.         while (colCnt)
  247.         {
  248.             q15_t     inV = *pA++;
  249.             q7_t      inM = *pB++;
  250.             q7_t      inM2 = *pB++;
  251.             q7_t      inM3 = *pB++;
  252.             q7_t      inM4 = *pB++;
  253.  
  254.             sum += inV * inM;
  255.             sum2 += inV * inM2;
  256.             sum3 += inV * inM3;
  257.             sum4 += inV * inM4;
  258.             colCnt--;
  259.         }                       /* while over colCnt */
  260.         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
  261.         *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
  262.         *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
  263.         *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));
  264.  
  265.         /* adjust the pointers and counters */
  266.         rowCnt--;
  267.     }
  268.  
  269.     /* left-over part of the rows */
  270.     rowCnt = num_of_rows & 0x3;
  271.  
  272.     while (rowCnt)
  273.     {
  274.         q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  275.  
  276.         uint16_t  colCnt = dim_vec >> 2;
  277.  
  278.         pA = pV;
  279.  
  280.         while (colCnt)
  281.         {
  282.             q31_t     inV1, inV2, inM11, inM12;
  283.  
  284.             pB = (q7_t *) read_and_pad((void *)pB, &inM11, &inM12);
  285.  
  286.             inV1 = *__SIMD32(pA)++;
  287.             sum = __SMLAD(inV1, inM11, sum);
  288.  
  289.             inV2 = *__SIMD32(pA)++;
  290.             sum = __SMLAD(inV2, inM12, sum);
  291.  
  292.             colCnt--;
  293.         }
  294.  
  295.         /* left-over of the vector */
  296.         colCnt = dim_vec & 0x3;
  297.         while (colCnt)
  298.         {
  299.             q15_t     inV = *pA++;
  300.             q7_t      inM = *pB++;
  301.             sum += inV * inM;
  302.             colCnt--;
  303.         }
  304.  
  305.         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
  306.  
  307.         rowCnt--;
  308.     }
  309.  
  310. #else
  311.     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
  312.     uint16_t  rowCnt = num_of_rows >> 2;
  313.     const q7_t *pB = pM;
  314.     const q15_t *pA;
  315.     q15_t    *pO = pOut;
  316.     const q7_t *pBias = bias;
  317.  
  318.     while (rowCnt)
  319.     {
  320.         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  321.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  322.         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  323.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  324.         uint16_t  colCnt = dim_vec >> 1;
  325.  
  326.         pA = pV;
  327.  
  328.         while (colCnt)
  329.         {
  330.             q15_t     inA1 = *pA++;
  331.             q15_t     inA2 = *pA++;
  332.  
  333.             q7_t      inB1 = *pB++;
  334.             q7_t      inB3 = *pB++;
  335.             q7_t      inB2 = *pB++;
  336.             q7_t      inB4 = *pB++;
  337.  
  338.             sum += inA1 * inB1 + inA2 * inB2;
  339.             sum2 += inA1 * inB3 + inA2 * inB4;
  340.  
  341.             inB1 = *pB++;
  342.             inB3 = *pB++;
  343.             inB2 = *pB++;
  344.             inB4 = *pB++;
  345.  
  346.             sum3 += inA1 * inB1 + inA2 * inB2;
  347.             sum4 += inA1 * inB3 + inA2 * inB4;
  348.  
  349.             colCnt--;
  350.         }
  351.  
  352.         colCnt = dim_vec & 0x1;
  353.         while (colCnt)
  354.         {
  355.             q15_t     inA = *pA++;
  356.             q7_t      inB = *pB++;
  357.             sum += inA * inB;
  358.             inB = *pB++;
  359.             sum2 += inA * inB;
  360.             inB = *pB++;
  361.             sum3 += inA * inB;
  362.             inB = *pB++;
  363.             sum4 += inA * inB;
  364.  
  365.             colCnt--;
  366.         }
  367.         *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
  368.         *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
  369.         *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
  370.         *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);
  371.  
  372.         rowCnt--;
  373.     }
  374.  
  375.     rowCnt = num_of_rows & 0x3;
  376.  
  377.     while (rowCnt)
  378.     {
  379.         int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  380.         int       j;
  381.  
  382.         pA = pV;
  383.         for (j = 0; j < dim_vec; j++)
  384.         {
  385.             q15_t     inA = *pA++;
  386.             q7_t      inB = *pB++;
  387.             ip_out += inA * inB;
  388.         }
  389.         *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);
  390.  
  391.         rowCnt--;
  392.     }
  393.  
  394. #endif                          /* ARM_MATH_DSP */
  395.  
  396.     /* Return to ARM_MATH_SUCCESS */
  397.     return (ARM_MATH_SUCCESS);
  398.  
  399. }
  400.  
  401. /**
  402.  * @} end of FC group
  403.  */
  404.