Subversion Repositories dashGPS

Rev

Go to most recent revision | Blame | Compare with Previous | Last modification | View Log | Download | RSS feed

  1. /*
  2.  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
  3.  *
  4.  * SPDX-License-Identifier: Apache-2.0
  5.  *
  6.  * Licensed under the Apache License, Version 2.0 (the License); you may
  7.  * not use this file except in compliance with the License.
  8.  * You may obtain a copy of the License at
  9.  *
  10.  * www.apache.org/licenses/LICENSE-2.0
  11.  *
  12.  * Unless required by applicable law or agreed to in writing, software
  13.  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14.  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15.  * See the License for the specific language governing permissions and
  16.  * limitations under the License.
  17.  */
  18.  
  19. /* ----------------------------------------------------------------------
  20.  * Project:      CMSIS NN Library
  21.  * Title:        arm_fully_connected_q7_opt.c
  22.  * Description:  Q7 basic fully-connected layer function
  23.  *
  24.  * $Date:        17. January 2018
  25.  * $Revision:    V.1.0.0
  26.  *
  27.  * Target Processor:  Cortex-M cores
  28.  *
  29.  * -------------------------------------------------------------------- */
  30.  
  31. #include "arm_math.h"
  32. #include "arm_nnfunctions.h"
  33.  
  34. /**
  35.  *  @ingroup groupNN
  36.  */
  37.  
  38. /**
  39.  * @addtogroup FC
  40.  * @{
  41.  */
  42.  
  43.   /**
  44.    * @brief Q7 opt fully-connected layer function
  45.    * @param[in]       pV          pointer to input vector
  46.    * @param[in]       pM          pointer to matrix weights
  47.    * @param[in]       dim_vec     length of the vector
  48.    * @param[in]       num_of_rows number of rows in weight matrix
  49.    * @param[in]       bias_shift  amount of left-shift for bias
  50.    * @param[in]       out_shift   amount of right-shift for output
  51.    * @param[in]       bias        pointer to bias
  52.    * @param[in,out]   pOut        pointer to output vector
  53.    * @param[in,out]   vec_buffer  pointer to buffer space for input
  54.    * @return     The function returns <code>ARM_MATH_SUCCESS</code>
  55.    *
  56.    * @details
  57.    *
  58.    * <b>Buffer size:</b>
  59.    *
  60.    * vec_buffer size: dim_vec
  61.    *
  62.    * This opt function is designed to work with interleaved weight
  63.    * matrix. The vector input is assumed in q7_t format, we call
  64.    *  arm_q7_to_q15_no_shift_shuffle function to expand into
  65.    *  q15_t format with certain weight re-ordering, refer to the function
  66.    *  comments for more details.
  67.    *  Here we use only one pointer to read 4 rows in the weight
  68.    *  matrix. So if the original q7_t matrix looks like this:
  69.    *
  70.    *  | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
  71.    *
  72.    *  | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
  73.    *
  74.    *  | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
  75.    *
  76.    *  | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
  77.    *
  78.    *  | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
  79.    *
  80.    *  | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
  81.    *
  82.    *
  83.    *  We operates on multiple-of-4 rows, so the first four rows becomes
  84.    *
  85.    *  | a11 | a21 | a13 | a23 | a31 | a41 | a33 | a43 |
  86.    *
  87.    *  | a12 | a22 | a14 | a24 | a32 | a42 | a34 | a44 |
  88.    *
  89.    *  | a15 | a25 | a35 | a45 | a16 | a26 | a36 | a46 |
  90.    *
  91.    *  So within the kernel, we first read the re-ordered vector in as:
  92.    *
  93.    *  | b1  | b3  | and | b2  | b4  |
  94.    *
  95.    *  the four q31_t weights will look like
  96.    *
  97.    *  | a11 | a13 |, | a21 | a23 |, | a31 | a33 |, | a41 | a43 |
  98.    *
  99.    *  | a12 | a14 |, | a22 | a24 |, | a32 | a34 |, | a42 | a44 |
  100.    *
  101.    *  The column left over will be in-order.
  102.    *  which is:
  103.    *
  104.    *  | a17 | a27 | a37 | a47 |
  105.    *
  106.    *  For the left-over rows, we do 1x1 computation, so the data remains
  107.    *  as its original order.
  108.    *
  109.    *  So the stored weight matrix looks like this:
  110.    *
  111.    *  | a11 | a21 | a13 | a23 | a31 | a41 |
  112.    *
  113.    *  | a33 | a43 | a12 | a22 | a14 | a24 |
  114.    *
  115.    *  | a32 | a42 | a34 | a44 | a15 | a25 |
  116.    *
  117.    *  | a35 | a45 | a16 | a26 | a36 | a46 |
  118.    *
  119.    *  | a17 | a27 | a37 | a47 | a51 | a52 |
  120.    *
  121.    *  | a53 | a54 | a55 | a56 | a57 | a61 |
  122.    *
  123.    *  | a62 | a63 | a64 | a65 | a66 | a67 |
  124.    *
  125.    *
  126.    */
  127.  
  128. arm_status
  129. arm_fully_connected_q7_opt(const q7_t * pV,
  130.                            const q7_t * pM,
  131.                            const uint16_t dim_vec,
  132.                            const uint16_t num_of_rows,
  133.                            const uint16_t bias_shift,
  134.                            const uint16_t out_shift,
  135.                            const q7_t * bias,
  136.                            q7_t * pOut,
  137.                            q15_t * vec_buffer)
  138. {
  139.  
  140. #if defined (ARM_MATH_DSP)
  141.     /* Run the following code for Cortex-M4 and Cortex-M7 */
  142.  
  143.     const q7_t *pB = pM;
  144.     q7_t     *pO = pOut;
  145.     const q7_t *pBias = bias;
  146.     q15_t    *pA;
  147.     uint16_t  rowCnt = num_of_rows >> 2;
  148.  
  149.     arm_q7_to_q15_reordered_no_shift(pV, vec_buffer, dim_vec);
  150.  
  151.     while (rowCnt)
  152.     {
  153.  
  154.         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  155.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  156.         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  157.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  158.  
  159.         uint16_t  colCnt = dim_vec >> 2;
  160.  
  161.         pA = vec_buffer;
  162.  
  163. #ifdef USE_INTRINSIC
  164.  
  165. #ifndef ARM_MATH_BIG_ENDIAN
  166.         while (colCnt)
  167.         {
  168.             q31_t     inM11, inM12, inM13, inM14;
  169.             q31_t     inV;
  170.  
  171.             inV = *__SIMD32(pA)++;
  172.             inM11 = *__SIMD32(pB)++;
  173.             inM12 = __SXTB16(__ROR(inM11, 8));
  174.             inM11 = __SXTB16(inM11);
  175.             sum = __SMLAD(inM11, inV, sum);
  176.             sum2 = __SMLAD(inM12, inV, sum2);
  177.             inM13 = *__SIMD32(pB)++;
  178.             inM14 = __SXTB16(__ROR(inM13, 8));
  179.             inM13 = __SXTB16(inM13);
  180.             sum3 = __SMLAD(inM13, inV, sum3);
  181.             sum4 = __SMLAD(inM14, inV, sum4);
  182.  
  183.             inV = *__SIMD32(pA)++;
  184.             inM11 = *__SIMD32(pB)++;
  185.             inM12 = __SXTB16(__ROR(inM11, 8));
  186.             inM11 = __SXTB16(inM11);
  187.             sum = __SMLAD(inM11, inV, sum);
  188.             sum2 = __SMLAD(inM12, inV, sum2);
  189.             inM13 = *__SIMD32(pB)++;
  190.             inM14 = __SXTB16(__ROR(inM13, 8));
  191.             inM13 = __SXTB16(inM13);
  192.             sum3 = __SMLAD(inM13, inV, sum3);
  193.             sum4 = __SMLAD(inM14, inV, sum4);
  194.             colCnt--;
  195.         }
  196. #else
  197.         while (colCnt)
  198.         {
  199.             q31_t     inM11, inM12, inM13, inM14;
  200.             q31_t     inV;
  201.  
  202.             inV = *__SIMD32(pA)++;
  203.             inM11 = *__SIMD32(pB)++;
  204.             inM12 = __SXTB16(__ROR(inM11, 8));
  205.             inM11 = __SXTB16(inM11);
  206.             sum = __SMLAD(inM12, inV, sum);
  207.             sum2 = __SMLAD(inM11, inV, sum2);
  208.             inM13 = *__SIMD32(pB)++;
  209.             inM14 = __SXTB16(__ROR(inM13, 8));
  210.             inM13 = __SXTB16(inM13);
  211.             sum3 = __SMLAD(inM14, inV, sum3);
  212.             sum4 = __SMLAD(inM13, inV, sum4);
  213.  
  214.             inV = *__SIMD32(pA)++;
  215.             inM11 = *__SIMD32(pB)++;
  216.             inM12 = __SXTB16(__ROR(inM11, 8));
  217.             inM11 = __SXTB16(inM11);
  218.             sum = __SMLAD(inM12, inV, sum);
  219.             sum2 = __SMLAD(inM11, inV, sum2);
  220.             inM13 = *__SIMD32(pB)++;
  221.             inM14 = __SXTB16(__ROR(inM13, 8));
  222.             inM13 = __SXTB16(inM13);
  223.             sum3 = __SMLAD(inM14, inV, sum3);
  224.             sum4 = __SMLAD(inM13, inV, sum4);
  225.             colCnt--;
  226.         }
  227. #endif                          /* ARM_MATH_BIG_ENDIAN */
  228.  
  229. #else
  230.  
  231.         /*
  232.          * register needed:
  233.          * loop counter: colCnt
  234.          * accumulators: sum, sum2, sum3, sum4
  235.          * pointers: pB, pA
  236.          * weight data: inM11, inM12, inM13, inM14
  237.          * activation data: inV
  238.          */
  239.  
  240. #ifndef ARM_MATH_BIG_ENDIAN
  241.         asm volatile ("COL_LOOP_%=:\n"
  242.                       "ldr.w r4, [%[pA]], #8\n"
  243.                       "ldr.w r1, [%[pB]], #16\n"
  244.                       "mov.w r0, r1, ror #8\n"
  245.                       "sxtb16 r0, r0\n"
  246.                       "sxtb16 r1, r1\n"
  247.                       "smlad %[sum], r4, r1, %[sum]\n"
  248.                       "smlad %[sum2], r4, r0, %[sum2]\n"
  249.                       "ldr.w r3, [%[pB], #-12]\n"
  250.                       "mov.w r2, r3, ror #8\n"
  251.                       "sxtb16 r2, r2\n"
  252.                       "sxtb16 r3, r3\n"
  253.                       "smlad %[sum3], r4, r3, %[sum3]\n"
  254.                       "smlad %[sum4], r4, r2, %[sum4]\n"
  255.                       "ldr.w r4, [%[pA], #-4]\n"
  256.                       "ldr.w r1, [%[pB], #-8]\n"
  257.                       "mov.w r0, r1, ror #8\n"
  258.                       "sxtb16 r0, r0\n"
  259.                       "sxtb16 r1, r1\n"
  260.                       "smlad %[sum], r4, r1, %[sum]\n"
  261.                       "smlad %[sum2], r4, r0, %[sum2]\n"
  262.                       "ldr.w r3, [%[pB], #-4]\n"
  263.                       "mov.w r2, r3, ror #8\n"
  264.                       "sxtb16 r2, r2\n"
  265.                       "sxtb16 r3, r3\n"
  266.                       "smlad %[sum3], r4, r3, %[sum3]\n"
  267.                       "smlad %[sum4], r4, r2, %[sum4]\n"
  268.                       "subs %[colCnt], #1\n"
  269.                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  270.                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  271.                       [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  272. #else
  273.         asm volatile ("COL_LOOP_%=:\n"
  274.                       "ldr.w r4, [%[pA]], #8\n"
  275.                       "ldr.w r1, [%[pB]], #16\n"
  276.                       "mov.w r0, r1, ror #8\n"
  277.                       "sxtb16 r0, r0\n"
  278.                       "sxtb16 r1, r1\n"
  279.                       "smlad %[sum], r4, r0, %[sum]\n"
  280.                       "smlad %[sum2], r4, r1, %[sum2]\n"
  281.                       "ldr.w r3, [%[pB], #-12]\n"
  282.                       "mov.w r2, r3, ror #8\n"
  283.                       "sxtb16 r2, r2\n"
  284.                       "sxtb16 r3, r3\n"
  285.                       "smlad %[sum3], r4, r2, %[sum3]\n"
  286.                       "smlad %[sum4], r4, r3, %[sum4]\n"
  287.                       "ldr.w r4, [%[pA], #-4]\n"
  288.                       "ldr.w r1, [%[pB], #-8]\n"
  289.                       "mov.w r0, r1, ror #8\n"
  290.                       "sxtb16 r0, r0\n"
  291.                       "sxtb16 r1, r1\n"
  292.                       "smlad %[sum], r4, r0, %[sum]\n"
  293.                       "smlad %[sum2], r4, r1, %[sum2]\n"
  294.                       "ldr.w r3, [%[pB], #-4]\n"
  295.                       "mov.w r2, r3, ror #8\n"
  296.                       "sxtb16 r2, r2\n"
  297.                       "sxtb16 r3, r3\n"
  298.                       "smlad %[sum3], r4, r2, %[sum3]\n"
  299.                       "smlad %[sum4], r4, r3, %[sum4]\n"
  300.                       "subs %[colCnt], #1\n"
  301.                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  302.                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  303.                       [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  304. #endif                          /* ARM_MATH_BIG_ENDIAN */
  305.  
  306. #endif                          /* USE_INTRINSIC */
  307.  
  308.         colCnt = dim_vec & 0x3;
  309.         while (colCnt)
  310.         {
  311.             q15_t     inV = *pA++;
  312.             q7_t      inM = *pB++;
  313.             q7_t      inM2 = *pB++;
  314.             q7_t      inM3 = *pB++;
  315.             q7_t      inM4 = *pB++;
  316.  
  317.             sum += inV * inM;
  318.             sum2 += inV * inM2;
  319.             sum3 += inV * inM3;
  320.             sum4 += inV * inM4;
  321.             colCnt--;
  322.         }                       /* while over colCnt */
  323.         *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
  324.         *pO++ = (q7_t) (__SSAT((sum2 >> out_shift), 8));
  325.         *pO++ = (q7_t) (__SSAT((sum3 >> out_shift), 8));
  326.         *pO++ = (q7_t) (__SSAT((sum4 >> out_shift), 8));
  327.  
  328.         /* adjust the pointers and counters */
  329.         rowCnt--;
  330.     }
  331.  
  332.     /* left-over part of the rows */
  333.     rowCnt = num_of_rows & 0x3;
  334.  
  335.     while (rowCnt)
  336.     {
  337.         q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  338.         uint16_t  colCnt = dim_vec >> 2;
  339.  
  340.         pA = vec_buffer;
  341.  
  342.         while (colCnt)
  343.         {
  344.             q31_t     inV1, inV2, inM11, inM12;
  345.  
  346.             pB = (q7_t *) read_and_pad_reordered((void *)pB, &inM11, &inM12);
  347.  
  348.             inV1 = *__SIMD32(pA)++;
  349.             sum = __SMLAD(inV1, inM11, sum);
  350.  
  351.             inV2 = *__SIMD32(pA)++;
  352.             sum = __SMLAD(inV2, inM12, sum);
  353.  
  354.             colCnt--;
  355.         }
  356.  
  357.         /* left-over of the vector */
  358.         colCnt = dim_vec & 0x3;
  359.         while (colCnt)
  360.         {
  361.             q15_t     inV = *pA++;
  362.             q7_t      inM = *pB++;
  363.             sum += inV * inM;
  364.             colCnt--;
  365.         }
  366.  
  367.         *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
  368.  
  369.         rowCnt--;
  370.     }
  371.  
  372. #else
  373.     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
  374.     uint16_t  rowCnt = num_of_rows >> 2;
  375.     const q7_t *pB = pM;
  376.     const q7_t *pA;
  377.     q7_t     *pO = pOut;
  378.     const q7_t *pBias = bias;
  379.  
  380.     while (rowCnt)
  381.     {
  382.         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  383.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  384.         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  385.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  386.  
  387.         uint16_t  colCnt = dim_vec >> 2;
  388.  
  389.         pA = pV;
  390.  
  391.         while (colCnt)
  392.         {
  393.             q7_t      inA1 = *pA++;
  394.             q7_t      inA3 = *pA++;
  395.             q7_t      inA2 = *pA++;
  396.             q7_t      inA4 = *pA++;
  397.  
  398.             q7_t      inB1 = *pB++;
  399.             q7_t      inB3 = *pB++;
  400.             q7_t      inB2 = *pB++;
  401.             q7_t      inB4 = *pB++;
  402.  
  403.             sum += inA1 * inB1 + inA2 * inB2;
  404.             sum2 += inA1 * inB3 + inA2 * inB4;
  405.  
  406.             inB1 = *pB++;
  407.             inB3 = *pB++;
  408.             inB2 = *pB++;
  409.             inB4 = *pB++;
  410.  
  411.             sum3 += inA1 * inB1 + inA2 * inB2;
  412.             sum4 += inA1 * inB3 + inA2 * inB4;
  413.  
  414.             inB1 = *pB++;
  415.             inB3 = *pB++;
  416.             inB2 = *pB++;
  417.             inB4 = *pB++;
  418.  
  419.             sum += inA3 * inB1 + inA4 * inB2;
  420.             sum2 += inA3 * inB3 + inA4 * inB4;
  421.  
  422.             inB1 = *pB++;
  423.             inB3 = *pB++;
  424.             inB2 = *pB++;
  425.             inB4 = *pB++;
  426.  
  427.             sum3 += inA3 * inB1 + inA4 * inB2;
  428.             sum4 += inA3 * inB3 + inA4 * inB4;
  429.  
  430.             colCnt--;
  431.         }
  432.         colCnt = dim_vec & 0x3;
  433.         while (colCnt)
  434.         {
  435.             q7_t      inA = *pA++;
  436.             q7_t      inB = *pB++;
  437.             sum += inA * inB;
  438.             inB = *pB++;
  439.             sum2 += inA * inB;
  440.             inB = *pB++;
  441.             sum3 += inA * inB;
  442.             inB = *pB++;
  443.             sum4 += inA * inB;
  444.  
  445.             colCnt--;
  446.         }
  447.         *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
  448.         *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
  449.         *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
  450.         *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
  451.  
  452.         rowCnt--;
  453.     }
  454.  
  455.     rowCnt = num_of_rows & 0x3;
  456.  
  457.     while (rowCnt)
  458.     {
  459.         int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  460.  
  461.         int       j;
  462.  
  463.         pA = pV;
  464.         for (j = 0; j < dim_vec; j++)
  465.         {
  466.             q7_t      inA = *pA++;
  467.             q7_t      inB = *pB++;
  468.             ip_out += inA * inB;
  469.         }
  470.         *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);
  471.  
  472.         rowCnt--;
  473.     }
  474.  
  475. #endif                          /* ARM_MATH_DSP */
  476.  
  477.     /* Return to ARM_MATH_SUCCESS */
  478.     return (ARM_MATH_SUCCESS);
  479.  
  480. }
  481.  
  482. /**
  483.  * @} end of FC group
  484.  */
  485.