Subversion Repositories dashGPS

Rev

Rev 2 | Blame | Compare with Previous | Last modification | View Log | Download | RSS feed

  1. /*
  2.  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
  3.  *
  4.  * SPDX-License-Identifier: Apache-2.0
  5.  *
  6.  * Licensed under the Apache License, Version 2.0 (the License); you may
  7.  * not use this file except in compliance with the License.
  8.  * You may obtain a copy of the License at
  9.  *
  10.  * www.apache.org/licenses/LICENSE-2.0
  11.  *
  12.  * Unless required by applicable law or agreed to in writing, software
  13.  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14.  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15.  * See the License for the specific language governing permissions and
  16.  * limitations under the License.
  17.  */
  18.  
  19. /* ----------------------------------------------------------------------
  20.  * Project:      CMSIS NN Library
  21.  * Title:        arm_nn_mat_mult_kernel_q7_q15.c
  22.  * Description:  Matrix-multiplication function for convolution
  23.  *
  24.  * $Date:        17. January 2018
  25.  * $Revision:    V.1.0.0
  26.  *
  27.  * Target Processor:  Cortex-M cores
  28.  * -------------------------------------------------------------------- */
  29.  
  30. #include "arm_math.h"
  31. #include "arm_nnfunctions.h"
  32.  
  33.   /**
  34.    * @brief Matrix-multiplication function for convolution
  35.    * @param[in]       pA          pointer to operand A
  36.    * @param[in]       pInBuffer   pointer to operand B, always conssists of 2 vectors
  37.    * @param[in]       ch_im_out   numRow of A
  38.    * @param[in]       numCol_A    numCol of A
  39.    * @param[in]       bias_shift  amount of left-shift for bias
  40.    * @param[in]       out_shift   amount of right-shift for output
  41.    * @param[in]       bias        the bias
  42.    * @param[in,out]   pOut        pointer to output
  43.    * @return     The function returns the incremented output pointer
  44.    *
  45.    * @details
  46.    *
  47.    * This function does the matrix multiplication with weight matrix
  48.    * and 2 columns from im2col.
  49.    */
  50.  
  51. q7_t     *arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,
  52.                                         const q15_t * pInBuffer,
  53.                                         const uint16_t ch_im_out,
  54.                                         const uint16_t numCol_A,
  55.                                         const uint16_t bias_shift,
  56.                                         const uint16_t out_shift,
  57.                                         const q7_t * bias,
  58.                                         q7_t * pOut)
  59. {
  60. #if defined (ARM_MATH_DSP)
  61.     /* set up the second output pointers */
  62.     q7_t     *pOut2 = pOut + ch_im_out;
  63.     const q7_t *pBias = bias;
  64.  
  65.     uint16_t  rowCnt = ch_im_out >> 1;
  66.     /* this loop over rows in A */
  67.     while (rowCnt)
  68.     {
  69.         /* setup pointers for B */
  70.         const q15_t *pB = pInBuffer;
  71.         const q15_t *pB2 = pB + numCol_A;
  72.  
  73.         /* align the second pointer for A */
  74.         const q7_t *pA2 = pA + numCol_A;
  75.  
  76.         /* init the sum with bias */
  77.         q31_t     sum =  ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
  78.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  79.         q31_t     sum3 = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
  80.         q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  81.  
  82.         uint16_t  colCnt = numCol_A >> 2;
  83.         /* accumulate over the vector */
  84.         while (colCnt)
  85.         {
  86.             q31_t     inA11, inA12, inA21, inA22;
  87.             q31_t     inB1 = *__SIMD32(pB)++;
  88.             q31_t     inB2 = *__SIMD32(pB2)++;
  89.  
  90.             pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
  91.             pA2 = (q7_t *) read_and_pad((void *)pA2, &inA21, &inA22);
  92.  
  93.             sum = __SMLAD(inA11, inB1, sum);
  94.             sum2 = __SMLAD(inA11, inB2, sum2);
  95.             sum3 = __SMLAD(inA21, inB1, sum3);
  96.             sum4 = __SMLAD(inA21, inB2, sum4);
  97.  
  98.             inB1 = *__SIMD32(pB)++;
  99.             inB2 = *__SIMD32(pB2)++;
  100.  
  101.             sum = __SMLAD(inA12, inB1, sum);
  102.             sum2 = __SMLAD(inA12, inB2, sum2);
  103.             sum3 = __SMLAD(inA22, inB1, sum3);
  104.             sum4 = __SMLAD(inA22, inB2, sum4);
  105.  
  106.             colCnt--;
  107.         }                       /* while over colCnt */
  108.         colCnt = numCol_A & 0x3;
  109.         while (colCnt)
  110.         {
  111.             q7_t      inA1 = *pA++;
  112.             q15_t     inB1 = *pB++;
  113.             q7_t      inA2 = *pA2++;
  114.             q15_t     inB2 = *pB2++;
  115.  
  116.             sum += inA1 * inB1;
  117.             sum2 += inA1 * inB2;
  118.             sum3 += inA2 * inB1;
  119.             sum4 += inA2 * inB2;
  120.             colCnt--;
  121.         }                       /* while over colCnt */
  122.         *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
  123.         *pOut++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
  124.         *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
  125.         *pOut2++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
  126.  
  127.         /* skip the row computed with A2 */
  128.         pA += numCol_A;
  129.         rowCnt--;
  130.     }                           /* for over ch_im_out */
  131.  
  132.     /* compute left-over row if any */
  133.     if (ch_im_out & 0x1)
  134.     {
  135.         /* setup pointers for B */
  136.         const q15_t *pB = pInBuffer;
  137.         const q15_t *pB2 = pB + numCol_A;
  138.  
  139.         /* load the bias */
  140.         q31_t     sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
  141.         q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  142.  
  143.         uint16_t  colCnt = numCol_A >> 2;
  144.         while (colCnt)
  145.         {
  146.             q31_t     inA11, inA12;
  147.             q31_t     inB1 = *__SIMD32(pB)++;
  148.             q31_t     inB2 = *__SIMD32(pB2)++;
  149.  
  150.             pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
  151.  
  152.             sum = __SMLAD(inA11, inB1, sum);
  153.             sum2 = __SMLAD(inA11, inB2, sum2);
  154.  
  155.             inB1 = *__SIMD32(pB)++;
  156.             inB2 = *__SIMD32(pB2)++;
  157.             sum = __SMLAD(inA12, inB1, sum);
  158.             sum2 = __SMLAD(inA12, inB2, sum2);
  159.  
  160.             colCnt--;
  161.         }
  162.         colCnt = numCol_A & 0x3;
  163.         while (colCnt)
  164.         {
  165.             q7_t      inA1 = *pA++;
  166.             q15_t     inB1 = *pB++;
  167.             q15_t     inB2 = *pB2++;
  168.  
  169.             sum += inA1 * inB1;
  170.             sum2 += inA1 * inB2;
  171.             colCnt--;
  172.         }
  173.  
  174.         *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
  175.         *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
  176.     }
  177.  
  178.     pOut += ch_im_out;
  179.  
  180.     /* return the new output pointer with offset */
  181.     return pOut;
  182. #else
  183.     /* To be completed */
  184.     return NULL;
  185. #endif                          /* ARM_MATH_DSP */
  186.  
  187. }
  188.