Subversion Repositories dashGPS

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. /*
  2.  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
  3.  *
  4.  * SPDX-License-Identifier: Apache-2.0
  5.  *
  6.  * Licensed under the Apache License, Version 2.0 (the License); you may
  7.  * not use this file except in compliance with the License.
  8.  * You may obtain a copy of the License at
  9.  *
  10.  * www.apache.org/licenses/LICENSE-2.0
  11.  *
  12.  * Unless required by applicable law or agreed to in writing, software
  13.  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14.  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15.  * See the License for the specific language governing permissions and
  16.  * limitations under the License.
  17.  */
  18.  
  19. /* ----------------------------------------------------------------------
  20.  * Project:      CMSIS NN Library
  21.  * Title:        arm_fully_connected_q15_opt.c
  22.  * Description:  Q15 opt fully-connected layer function
  23.  *
  24.  * $Date:        17. January 2018
  25.  * $Revision:    V.1.0.0
  26.  *
  27.  * Target Processor:  Cortex-M cores
  28.  *
  29.  * -------------------------------------------------------------------- */
  30.  
  31. #include "arm_math.h"
  32. #include "arm_nnfunctions.h"
  33.  
  34. /**
  35.  *  @ingroup groupNN
  36.  */
  37.  
  38. /**
  39.  * @addtogroup FC
  40.  * @{
  41.  */
  42.  
  43.   /**
  44.    * @brief Q15 opt fully-connected layer function
  45.    * @param[in]       pV          pointer to input vector
  46.    * @param[in]       pM          pointer to matrix weights
  47.    * @param[in]       dim_vec     length of the vector
  48.    * @param[in]       num_of_rows number of rows in weight matrix
  49.    * @param[in]       bias_shift  amount of left-shift for bias
  50.    * @param[in]       out_shift   amount of right-shift for output
  51.    * @param[in]       bias        pointer to bias
  52.    * @param[in,out]   pOut        pointer to output vector
  53.    * @param[in,out]   vec_buffer  pointer to buffer space for input
  54.    * @return     The function returns <code>ARM_MATH_SUCCESS</code>
  55.    *
  56.    *
  57.    * @details
  58.    *
  59.    * <b>Buffer size:</b>
  60.    *
  61.    * vec_buffer size: 0
  62.    *
  63.    *  Here we use only one pointer to read 4 rows in the weight
  64.    *  matrix. So if the original matrix looks like this:
  65.    *
  66.    *  | a11 | a12 | a13 |
  67.    *
  68.    *  | a21 | a22 | a23 |
  69.    *
  70.    *  | a31 | a32 | a33 |
  71.    *
  72.    *  | a41 | a42 | a43 |
  73.    *
  74.    *  | a51 | a52 | a53 |
  75.    *
  76.    *  | a61 | a62 | a63 |
  77.    *
  78.    *  We operates on multiple-of-4 rows, so the first four rows becomes
  79.    *
  80.    *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
  81.    *
  82.    *  | a13 | a23 | a33 | a43 |
  83.    *
  84.    *  Remaining rows are kept the same original order.
  85.    *
  86.    *  So the stored weight matrix looks like this:
  87.    *
  88.    *
  89.    *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
  90.    *
  91.    *  | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
  92.    *
  93.    *  | a62 | a63 |
  94.    */
  95.  
  96. arm_status
  97. arm_fully_connected_q15_opt(const q15_t * pV,
  98.                             const q15_t * pM,
  99.                             const uint16_t dim_vec,
  100.                             const uint16_t num_of_rows,
  101.                             const uint16_t bias_shift,
  102.                             const uint16_t out_shift,
  103.                             const q15_t * bias,
  104.                             q15_t * pOut,
  105.                             q15_t * vec_buffer)
  106. {
  107.  
  108. #if defined (ARM_MATH_DSP)
  109.     /* Run the following code for Cortex-M4 and Cortex-M7 */
  110.  
  111.     const q15_t *pB = pM;
  112.     q15_t    *pO = pOut;
  113.     const q15_t *pBias = bias;
  114.     const q15_t *pA = pV;
  115.  
  116.     uint16_t  rowCnt = num_of_rows >> 2;
  117.  
  118.     while (rowCnt)
  119.     {
  120.         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  121.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  122.         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  123.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  124.  
  125.         uint16_t  colCnt = dim_vec >> 1;
  126.  
  127.         pA = pV;
  128.  
  129. #ifdef USE_INTRINSIC
  130.  
  131.         while (colCnt)
  132.         {
  133.             q31_t     inM11, inM12, inM13, inM14;
  134.             q31_t     inV;
  135.  
  136.             inV = *__SIMD32(pA)++;
  137.             inM11 = *__SIMD32(pB)++;
  138.             sum = __SMLAD(inV, inM11, sum);
  139.             inM12 = *__SIMD32(pB)++;
  140.             sum2 = __SMLAD(inV, inM12, sum2);
  141.             inM13 = *__SIMD32(pB)++;
  142.             sum3 = __SMLAD(inV, inM13, sum3);
  143.             inM14 = *__SIMD32(pB)++;
  144.             sum4 = __SMLAD(inV, inM14, sum4);
  145.             colCnt--;
  146.         }
  147.  
  148. #else
  149.  
  150.         /*
  151.          * register needed:
  152.          * loop counter: colCnt
  153.          * accumulators: sum, sum2, sum3, sum4
  154.          * pointers: pB, pA
  155.          * weight data: inM11, inM12, inM13, inM14
  156.          * activation data: inV
  157.          */
  158.  
  159.         asm volatile ("COL_LOOP_%=:\n"
  160.                       "ldr.w r4, [%[pA]], #4\n"
  161.                       "ldr.w r0, [%[pB]], #16\n"
  162.                       "smlad %[sum], r4, r0, %[sum]\n"
  163.                       "ldr.w r1, [%[pB] , #-12]\n"
  164.                       "smlad %[sum2], r4, r1, %[sum2]\n"
  165.                       "ldr.w r2, [%[pB] , #-8]\n"
  166.                       "smlad %[sum3], r4, r2, %[sum3]\n"
  167.                       "ldr.w r3, [%[pB] , #-4]\n"
  168.                       "smlad %[sum4], r4, r3, %[sum4]\n"
  169.                       "subs %[colCnt], #1\n"
  170.                       "bne COL_LOOP_%=\n":[sum] "+r"(sum),
  171.                       [sum2] "+r"(sum2),[sum3] "+r"(sum3),
  172.                       [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
  173.  
  174. #endif                          /* USE_INTRINSIC */
  175.  
  176.         colCnt = dim_vec & 0x1;
  177.         while (colCnt)
  178.         {
  179.  
  180.             q15_t     inV = *pA++;
  181.             q15_t     inM = *pB++;
  182.             q15_t     inM2 = *pB++;
  183.             q15_t     inM3 = *pB++;
  184.             q15_t     inM4 = *pB++;
  185.  
  186.             sum += inV * inM;
  187.             sum2 += inV * inM2;
  188.             sum3 += inV * inM3;
  189.             sum4 += inV * inM4;
  190.             colCnt--;
  191.         }                       /* while over colCnt */
  192.         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
  193.         *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
  194.         *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
  195.         *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));
  196.  
  197.         /* adjust the pointers and counters */
  198.         rowCnt--;
  199.     }
  200.  
  201.     /* left-over part of the rows */
  202.     rowCnt = num_of_rows & 0x3;
  203.  
  204.     while (rowCnt)
  205.     {
  206.         q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  207.  
  208.         uint16_t  colCnt = dim_vec >> 2;
  209.  
  210.         pA = pV;
  211.  
  212.         while (colCnt)
  213.         {
  214.             q31_t     inV1, inV2, inM1, inM2;
  215.  
  216.             inM1 = *__SIMD32(pB)++;
  217.             inV1 = *__SIMD32(pA)++;
  218.             sum = __SMLAD(inV1, inM1, sum);
  219.  
  220.             inM2 = *__SIMD32(pB)++;
  221.             inV2 = *__SIMD32(pA)++;
  222.             sum = __SMLAD(inV2, inM2, sum);
  223.  
  224.             colCnt--;
  225.         }
  226.  
  227.         /* left-over of the vector */
  228.         colCnt = dim_vec & 0x3;
  229.         while (colCnt)
  230.         {
  231.             q15_t     inV = *pA++;
  232.             q15_t     inM = *pB++;
  233.             sum += inV * inM;
  234.             colCnt--;
  235.         }
  236.  
  237.         *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
  238.  
  239.         rowCnt--;
  240.     }
  241.  
  242. #else
  243.     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
  244.     uint16_t  rowCnt = num_of_rows >> 2;
  245.     const q15_t *pB = pM;
  246.     const q15_t *pA;
  247.     q15_t    *pO = pOut;
  248.     const q15_t *pBias = bias;
  249.  
  250.     while (rowCnt)
  251.     {
  252.         q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  253.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  254.         q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  255.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  256.  
  257.         uint16_t  colCnt = dim_vec >> 1;
  258.  
  259.         pA = pV;
  260.         while (colCnt)
  261.         {
  262.             q15_t     inA1 = *pA++;
  263.             q15_t     inA2 = *pA++;
  264.  
  265.             q15_t     inB1 = *pB++;
  266.             q15_t     inB2 = *pB++;
  267.             sum += inA1 * inB1 + inA2 * inB2;
  268.  
  269.             inB1 = *pB++;
  270.             inB2 = *pB++;
  271.             sum2 += inA1 * inB1 + inA2 * inB2;
  272.  
  273.             inB1 = *pB++;
  274.             inB2 = *pB++;
  275.             sum3 += inA1 * inB1 + inA2 * inB2;
  276.  
  277.             inB1 = *pB++;
  278.             inB2 = *pB++;
  279.             sum4 += inA1 * inB1 + inA2 * inB2;
  280.  
  281.             colCnt--;
  282.         }
  283.         colCnt = dim_vec & 0x1;
  284.         while (colCnt)
  285.         {
  286.             q15_t     inA = *pA++;
  287.             q15_t     inB = *pB++;
  288.             sum += inA * inB;
  289.             inB = *pB++;
  290.             sum2 += inA * inB;
  291.             inB = *pB++;
  292.             sum3 += inA * inB;
  293.             inB = *pB++;
  294.             sum4 += inA * inB;
  295.             colCnt--;
  296.         }
  297.         *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
  298.         *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
  299.         *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
  300.         *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);
  301.  
  302.         rowCnt--;
  303.     }
  304.     rowCnt = num_of_rows & 0x3;
  305.  
  306.     while (rowCnt)
  307.     {
  308.         int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  309.         int       j;
  310.  
  311.         pA = pV;
  312.         for (j = 0; j < dim_vec; j++)
  313.         {
  314.             q15_t     inA = *pA++;
  315.             q15_t     inB = *pB++;
  316.             ip_out += inA * inB;
  317.         }
  318.         *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);
  319.  
  320.         rowCnt--;
  321.     }
  322.  
  323. #endif                          /* ARM_MATH_DSP */
  324.  
  325.     /* Return to ARM_MATH_SUCCESS */
  326.     return (ARM_MATH_SUCCESS);
  327.  
  328. }
  329.  
  330. /**
  331.  * @} end of FC group
  332.  */
  333.