Subversion Repositories dashGPS

Rev

Blame | Last modification | View Log | Download | RSS feed

  1. /*
  2.  * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
  3.  *
  4.  * SPDX-License-Identifier: Apache-2.0
  5.  *
  6.  * Licensed under the Apache License, Version 2.0 (the License); you may
  7.  * not use this file except in compliance with the License.
  8.  * You may obtain a copy of the License at
  9.  *
  10.  * www.apache.org/licenses/LICENSE-2.0
  11.  *
  12.  * Unless required by applicable law or agreed to in writing, software
  13.  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14.  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15.  * See the License for the specific language governing permissions and
  16.  * limitations under the License.
  17.  */
  18.  
  19. /* ----------------------------------------------------------------------
  20.  * Project:      CMSIS NN Library
  21.  * Title:        arm_depthwise_separable_conv_HWC_q7_nonsquare.c
  22.  * Description:  Q7 depthwise separable convolution function (non-square shape)
  23.  *
  24.  * $Date:        17. January 2018
  25.  * $Revision:    V.1.0.0
  26.  *
  27.  * Target Processor:  Cortex-M cores
  28.  *
  29.  * -------------------------------------------------------------------- */
  30.  
  31. #include "arm_math.h"
  32. #include "arm_nnfunctions.h"
  33.  
  34. /**
  35.  *  @ingroup groupNN
  36.  */
  37.  
  38. /**
  39.  * @addtogroup NNConv
  40.  * @{
  41.  */
  42.  
  43. /**
  44.  * @brief Q7 depthwise separable convolution function (non-square shape)
  45.  * @param[in]       Im_in         pointer to input tensor
  46.  * @param[in]       dim_im_in_x   input tensor dimention x
  47.  * @param[in]       dim_im_in_y   input tensor dimention y
  48.  * @param[in]       ch_im_in      number of input tensor channels
  49.  * @param[in]       wt            pointer to kernel weights
  50.  * @param[in]       ch_im_out     number of filters, i.e., output tensor channels
  51.  * @param[in]       dim_kernel_x  filter kernel size x
  52.  * @param[in]       dim_kernel_y  filter kernel size y
  53.  * @param[in]       padding_x     padding sizes x
  54.  * @param[in]       padding_y     padding sizes y
  55.  * @param[in]       stride_x      convolution stride x
  56.  * @param[in]       stride_y      convolution stride y
  57.  * @param[in]       bias          pointer to bias
  58.  * @param[in]       bias_shift    amount of left-shift for bias
  59.  * @param[in]       out_shift     amount of right-shift for output
  60.  * @param[in,out]   Im_out        pointer to output tensor
  61.  * @param[in]       dim_im_out_x  output tensor dimension x
  62.  * @param[in]       dim_im_out_y  output tensor dimension y
  63.  * @param[in,out]   bufferA       pointer to buffer space for input
  64.  * @param[in,out]   bufferB       pointer to buffer space for output
  65.  * @return     The function returns either
  66.  * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
  67.  *
  68.  * This function is the version with full list of optimization tricks, but with
  69.  * some contraints:
  70.  *   ch_im_in is multiple of 2
  71.  *   ch_im_out is multiple of 2
  72.  */
  73.  
  74. arm_status arm_depthwise_separable_conv_HWC_q7_nonsquare(const q7_t * Im_in,
  75.                                                          const uint16_t dim_im_in_x,
  76.                                                          const uint16_t dim_im_in_y,
  77.                                                          const uint16_t ch_im_in,
  78.                                                          const q7_t * wt,
  79.                                                          const uint16_t ch_im_out,
  80.                                                          const uint16_t dim_kernel_x,
  81.                                                          const uint16_t dim_kernel_y,
  82.                                                          const uint16_t padding_x,
  83.                                                          const uint16_t padding_y,
  84.                                                          const uint16_t stride_x,
  85.                                                          const uint16_t stride_y,
  86.                                                          const q7_t * bias,
  87.                                                          const uint16_t bias_shift,
  88.                                                          const uint16_t out_shift,
  89.                                                          q7_t * Im_out,
  90.                                                          const uint16_t dim_im_out_x,
  91.                                                          const uint16_t dim_im_out_y,
  92.                                                          q15_t * bufferA,
  93.                                                          q7_t * bufferB)
  94. {
  95.  
  96. #if defined (ARM_MATH_DSP)
  97.     /* Run the following code for Cortex-M4 and Cortex-M7 */
  98.  
  99. /*
  100.  * Implementation:
  101.  * There are 3 nested loop here:
  102.  * Inner loop: calculate each output value with MAC instruction over an accumulator
  103.  * Mid   loop: loop over different output channel
  104.  * Outer loop: loop over different output (x, y)
  105.  *
  106.  */
  107.  
  108.     int16_t   i_out_y, i_out_x;
  109.     int16_t   i_ker_y, i_ker_x;
  110.     q7_t     *colBuffer = (q7_t *) bufferA;
  111.     q7_t     *pBuffer = colBuffer;
  112.     const q7_t *pBias = bias;
  113.     q7_t     *pOut = Im_out;
  114.     uint16_t  rowCnt;
  115.     uint16_t  row_shift;
  116.  
  117.     /* do some checking here, basically ch_im_in == ch_im_out */
  118.     if (ch_im_in != ch_im_out)
  119.     {
  120.         return ARM_MATH_SIZE_MISMATCH;
  121.     }
  122.  
  123.     for (i_out_y = 0; i_out_y < dim_im_out_y; i_out_y++)
  124.     {
  125.         for (i_out_x = 0; i_out_x < dim_im_out_x; i_out_x++)
  126.         {
  127.             /* we first do im2col here */
  128.             for (i_ker_y = i_out_y * stride_y - padding_y; i_ker_y < i_out_y * stride_y - padding_y + dim_kernel_y;
  129.                  i_ker_y++)
  130.             {
  131.                 for (i_ker_x = i_out_x * stride_x - padding_x; i_ker_x < i_out_x * stride_x - padding_x + dim_kernel_x;
  132.                      i_ker_x++)
  133.                 {
  134.                     if (i_ker_y < 0 || i_ker_y >= dim_im_in_y || i_ker_x < 0 || i_ker_x >= dim_im_in_x)
  135.                     {
  136.                         /* arm_fill_q7(0, pBuffer, ch_im_in); */
  137.                         memset(pBuffer, 0, ch_im_in);
  138.                     } else
  139.                     {
  140.                         /* arm_copy_q7((q7_t *) Im_in + (i_ker_y * dim_im_in_x + i_ker_x) * ch_im_in, pBuffer, ch_im_in); */
  141.                         memcpy(pBuffer, (q7_t *) Im_in + (i_ker_y * dim_im_in_x + i_ker_x) * ch_im_in, ch_im_in);
  142.                     }
  143.                     pBuffer += ch_im_in;
  144.                 }
  145.             }
  146.  
  147.             /* we will do the computation here for each channel */
  148.             rowCnt = ch_im_out >> 2;
  149.             row_shift = 0;
  150.             pBias = bias;
  151.  
  152.             while (rowCnt)
  153.             {
  154.                 q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  155.                 q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  156.                 q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  157.                 q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  158.  
  159.                 uint16_t  colCnt = (dim_kernel_x * dim_kernel_y) >> 1;
  160.                 q7_t     *pB = colBuffer + row_shift;
  161.                 const q7_t *pA = wt + row_shift;
  162.                 row_shift += 4;
  163.  
  164. #ifdef USE_INTRINSIC
  165.  
  166. #ifndef ARM_MATH_BIG_ENDIAN
  167.  
  168.                 while (colCnt)
  169.                 {
  170.                     q31_t     inA1, inA2, inB1, inB2, opA, opB;
  171.  
  172.                     inB1 = *__SIMD32(pB);
  173.                     pB += ch_im_in;
  174.                     opB = *__SIMD32(pB);
  175.                     pB += ch_im_in;
  176.                     inB2 = __PKHTB(opB, inB1, 16);
  177.                     inB1 = __PKHBT(inB1, opB, 16);
  178.                     inA1 = *__SIMD32(pA);
  179.                     pA += ch_im_in;
  180.                     opB = *__SIMD32(pA);
  181.                     pA += ch_im_in;
  182.                     inA2 = __PKHTB(opB, inA1, 16);
  183.                     inA1 = __PKHBT(inA1, opB, 16);
  184.                     opA = __SXTB16(inA1);
  185.                     opB = __SXTB16(inB1);
  186.                     sum = __SMLAD(opA, opB, sum);
  187.                     opA = __SXTB16(__ROR(inA1, 8));
  188.                     opB = __SXTB16(__ROR(inB1, 8));
  189.                     sum2 = __SMLAD(opA, opB, sum2);
  190.                     opA = __SXTB16(inA2);
  191.                     opB = __SXTB16(inB2);
  192.                     sum3 = __SMLAD(opA, opB, sum3);
  193.                     opA = __SXTB16(__ROR(inA2, 8));
  194.                     opB = __SXTB16(__ROR(inB2, 8));
  195.                     sum4 = __SMLAD(opA, opB, sum4);
  196.                     colCnt--;
  197.                 }
  198. #else
  199.  
  200.                 while (colCnt)
  201.                 {
  202.                     q31_t     inA1, inA2, inB1, inB2, opA, opB;
  203.  
  204.                     inB1 = *__SIMD32(pB);
  205.                     pB += ch_im_in;
  206.                     opB = *__SIMD32(pB);
  207.                     pB += ch_im_in;
  208.                     inB2 = __PKHBT(opB, inB1, 16);
  209.                     inB1 = __PKHTB(inB1, opB, 16);
  210.                     inA1 = *__SIMD32(pA);
  211.                     pA += ch_im_in;
  212.                     opB = *__SIMD32(pA);
  213.                     pA += ch_im_in;
  214.                     inA2 = __PKHBT(opB, inA1, 16);
  215.                     inA1 = __PKHTB(inA1, opB, 16);
  216.                     opA = __SXTB16(inA1);
  217.                     opB = __SXTB16(inB1);
  218.                     sum2 = __SMLAD(opA, opB, sum2);
  219.                     opA = __SXTB16(__ROR(inA1, 8));
  220.                     opB = __SXTB16(__ROR(inB1, 8));
  221.                     sum = __SMLAD(opA, opB, sum);
  222.                     opA = __SXTB16(inA2);
  223.                     opB = __SXTB16(inB2);
  224.                     sum4 = __SMLAD(opA, opB, sum4);
  225.                     opA = __SXTB16(__ROR(inA2, 8));
  226.                     opB = __SXTB16(__ROR(inB2, 8));
  227.                     sum3 = __SMLAD(opA, opB, sum3);
  228.                     colCnt--;
  229.                 }
  230.  
  231. #endif                          /* ARM_MATH_BIG_ENDIAN */
  232.  
  233. #else
  234.  
  235. #ifndef ARM_MATH_BIG_ENDIAN
  236.                 //  r0    r1    r2    r3    r4   r5
  237.                 // inA1, inA2, inB1, inB2, opA, opB
  238.                 asm volatile ("COL_LOOP:\n"
  239.                               "ldr.w r2, [%[pB], #0]\n"
  240.                               "add.w %[pB], %[pB], %[ch_im_in]\n"
  241.                               "ldr.w r5, [%[pB], #0]\n"
  242.                               "add.w %[pB], %[pB], %[ch_im_in]\n"
  243.                               "pkhtb r3, r5, r2, ASR #16\n"
  244.                               "pkhbt r2, r2, r5, LSL #16\n"
  245.                               "ldr.w r0, [%[pA], #0]\n"
  246.                               "add.w %[pA], %[pA], %[ch_im_in]\n"
  247.                               "ldr.w r5, [%[pA], #0]\n"
  248.                               "add.w %[pA], %[pA], %[ch_im_in]\n"
  249.                               "pkhtb r1, r5, r0, ASR #16\n"
  250.                               "pkhbt r0, r0, r5, LSL #16\n"
  251.                               "sxtb16 r4, r0\n"
  252.                               "sxtb16 r5, r2\n"
  253.                               "smlad %[sum], r4, r5, %[sum]\n"
  254.                               "mov.w r4, r0, ror #8\n"
  255.                               "mov.w r5, r2, ror #8\n"
  256.                               "sxtb16 r4, r4\n"
  257.                               "sxtb16 r5, r5\n"
  258.                               "smlad %[sum2], r4, r5, %[sum2]\n"
  259.                               "sxtb16 r4, r1\n"
  260.                               "sxtb16 r5, r3\n"
  261.                               "smlad %[sum3], r4, r5, %[sum3]\n"
  262.                               "mov.w r4, r1, ror #8\n"
  263.                               "mov.w r5, r3, ror #8\n"
  264.                               "sxtb16 r4, r4\n"
  265.                               "sxtb16 r5, r5\n"
  266.                               "smlad %[sum4], r4, r5, %[sum4]\n"
  267.                               "subs %[colCnt], #1\n"
  268.                               "bne COL_LOOP\n":[sum] "+r"(sum),[sum2] "+r"(sum2),[sum3] "+r"(sum3),
  269.                               [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt),
  270.                               [ch_im_in] "r"(ch_im_in):"r0", "r1", "r2", "r3", "r4", "r5");
  271. #else
  272.                 //  r0    r1    r2    r3    r4   r5
  273.                 // inA1, inA2, inB1, inB2, opA, opB
  274.                 asm volatile ("COL_LOOP:\n"
  275.                               "ldr.w r2, [%[pB], #0]\n"
  276.                               "add.w %[pB], %[pB], %[ch_im_in]\n"
  277.                               "ldr.w r5, [%[pB], #0]\n"
  278.                               "add.w %[pB], %[pB], %[ch_im_in]\n"
  279.                               "pkhbt r3, r5, r2, LSL #16\n"
  280.                               "pkhtb r2, r2, r5, ASR #16\n"
  281.                               "ldr.w r0, [%[pA], #0]\n"
  282.                               "add.w %[pA], %[pA], %[ch_im_in]\n"
  283.                               "ldr.w r5, [%[pA], #0]\n"
  284.                               "add.w %[pA], %[pA], %[ch_im_in]\n"
  285.                               "pkhbt r1, r5, r0, LSL #16\n"
  286.                               "pkhtb r0, r0, r5, ASR #16\n"
  287.                               "sxtb16 r4, r0\n"
  288.                               "sxtb16 r5, r2\n"
  289.                               "smlad %[sum2], r4, r5, %[sum2]\n"
  290.                               "mov.w r4, r0, ror #8\n"
  291.                               "mov.w r5, r2, ror #8\n"
  292.                               "sxtb16 r4, r4\n"
  293.                               "sxtb16 r5, r5\n"
  294.                               "smlad %[sum], r4, r5, %[sum]\n"
  295.                               "sxtb16 r4, r1\n"
  296.                               "sxtb16 r5, r3\n"
  297.                               "smlad %[sum4], r4, r5, %[sum4]\n"
  298.                               "mov.w r4, r1, ror #8\n"
  299.                               "mov.w r5, r3, ror #8\n"
  300.                               "sxtb16 r4, r4\n"
  301.                               "sxtb16 r5, r5\n"
  302.                               "smlad %[sum3], r4, r5, %[sum3]\n"
  303.                               "subs %[colCnt], #1\n"
  304.                               "bne COL_LOOP\n":[sum] "+r"(sum),[sum2] "+r"(sum2),[sum3] "+r"(sum3),
  305.                               [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt),
  306.                               [ch_im_in] "r"(ch_im_in):"r0", "r1", "r2", "r3", "r4", "r5");
  307. #endif                          /*ARM_MATH_BIG_ENDIAN */
  308.  
  309. #endif                          /* USE_INTRINSIC */
  310.  
  311.                 colCnt = (dim_kernel_x * dim_kernel_y) & 0x1;
  312.                 while (colCnt)
  313.                 {
  314.                     union arm_nnword inA, inB;
  315.                     inA.word = *__SIMD32(pA);
  316.                     pA += ch_im_in;
  317.                     inB.word = *__SIMD32(pB);
  318.                     pB += ch_im_in;
  319.                     sum += inA.bytes[0] * inB.bytes[0];
  320.                     sum2 += inA.bytes[1] * inB.bytes[1];
  321.                     sum3 += inA.bytes[2] * inB.bytes[2];
  322.                     sum4 += inA.bytes[3] * inB.bytes[3];
  323.                     colCnt--;
  324.                 }
  325.  
  326.                 *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
  327.                 *pOut++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
  328.                 *pOut++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
  329.                 *pOut++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
  330.  
  331.                 rowCnt--;
  332.             }
  333.  
  334.             rowCnt = ch_im_out & 0x3;
  335.             while (rowCnt)
  336.             {
  337.                 q7_t     *pB = colBuffer + row_shift;
  338.                 const q7_t *pA = wt + row_shift;
  339.                 q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  340.                 uint16_t  colCnt = (dim_kernel_x * dim_kernel_y);
  341.  
  342.                 row_shift += 1;
  343.  
  344.                 while (colCnt)
  345.                 {
  346.                     q7_t      A1 = *pA;
  347.                     q7_t      B1 = *pB;
  348.                     pA += ch_im_in;
  349.                     pB += ch_im_in;
  350.                     sum += A1 * B1;
  351.  
  352.                     colCnt--;
  353.                 }
  354.                 *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
  355.                 rowCnt--;
  356.             }
  357.  
  358.             // clear counter and pointers
  359.             pBuffer = colBuffer;
  360.         }
  361.     }
  362.  
  363. #else
  364.     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
  365.     int       i_out_y, i_out_x, i_ch_out;
  366.     int       i_ker_y, i_ker_x;
  367.  
  368.     /* do some checking here, basically ch_im_in == ch_im_out */
  369.     if (ch_im_in != ch_im_out)
  370.     {
  371.         return ARM_MATH_SIZE_MISMATCH;
  372.     }
  373.  
  374.     for (i_out_y = 0; i_out_y < dim_im_out_y; i_out_y++)
  375.     {
  376.         for (i_out_x = 0; i_out_x < dim_im_out_x; i_out_x++)
  377.         {
  378.             for (i_ch_out = 0; i_ch_out < ch_im_out; i_ch_out++)
  379.             {
  380.                 // for each output
  381.                 int       conv_out = ((q31_t)(bias[i_ch_out]) << bias_shift) + NN_ROUND(out_shift);
  382.                 for (i_ker_y = 0; i_ker_y < dim_kernel_y; i_ker_y++)
  383.                 {
  384.                     for (i_ker_x = 0; i_ker_x < dim_kernel_x; i_ker_x++)
  385.                     {
  386.                         int       in_row = stride_y * i_out_y + i_ker_y - padding_y;
  387.                         int       in_col = stride_x * i_out_x + i_ker_x - padding_x;
  388.                         if (in_row >= 0 && in_col >= 0 && in_row < dim_im_in_y && in_col < dim_im_in_x)
  389.                         {
  390.                             conv_out += Im_in[(in_row * dim_im_in_x + in_col) * ch_im_in + i_ch_out] *                        
  391.                                 wt[(i_ker_y * dim_kernel_x + i_ker_x) * ch_im_out + i_ch_out];
  392.                         }
  393.                     }
  394.                 }
  395.                 Im_out[(i_out_y * dim_im_out_x + i_out_x) * ch_im_out + i_ch_out] =
  396.                     (q7_t) __SSAT((conv_out >> out_shift), 8);
  397.             }
  398.         }
  399.     }
  400.  
  401. #endif                          /* ARM_MATH_DSP */
  402.  
  403.  
  404.     /* Return to application */
  405.     return ARM_MATH_SUCCESS;
  406.  
  407. }
  408.  
  409. /**
  410.  * @} end of NNConv group
  411.  */
  412.