Subversion Repositories ScreenTimer

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 mjames 1
/*
2
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3
 *
4
 * SPDX-License-Identifier: Apache-2.0
5
 *
6
 * Licensed under the Apache License, Version 2.0 (the License); you may
7
 * not use this file except in compliance with the License.
8
 * You may obtain a copy of the License at
9
 *
10
 * www.apache.org/licenses/LICENSE-2.0
11
 *
12
 * Unless required by applicable law or agreed to in writing, software
13
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
 * See the License for the specific language governing permissions and
16
 * limitations under the License.
17
 */
18
 
19
/* ----------------------------------------------------------------------
20
 * Project:      CMSIS NN Library
21
 * Title:        arm_nn_mat_mult_kernel_q7_q15.c
22
 * Description:  Matrix-multiplication function for convolution
23
 *
24
 * $Date:        17. January 2018
25
 * $Revision:    V.1.0.0
26
 *
27
 * Target Processor:  Cortex-M cores
28
 * -------------------------------------------------------------------- */
29
 
30
#include "arm_math.h"
31
#include "arm_nnfunctions.h"
32
 
33
  /**
34
   * @brief Matrix-multiplication function for convolution
35
   * @param[in]       pA          pointer to operand A
36
   * @param[in]       pInBuffer   pointer to operand B, always conssists of 2 vectors
37
   * @param[in]       ch_im_out   numRow of A
38
   * @param[in]       numCol_A    numCol of A
39
   * @param[in]       bias_shift  amount of left-shift for bias
40
   * @param[in]       out_shift   amount of right-shift for output
41
   * @param[in]       bias        the bias
42
   * @param[in,out]   pOut        pointer to output
43
   * @return     The function returns the incremented output pointer
44
   *
45
   * @details
46
   *
47
   * This function does the matrix multiplication with weight matrix
48
   * and 2 columns from im2col.
49
   */
50
 
51
q7_t     *arm_nn_mat_mult_kernel_q7_q15(const q7_t * pA,
52
                                        const q15_t * pInBuffer,
53
                                        const uint16_t ch_im_out,
54
                                        const uint16_t numCol_A,
55
                                        const uint16_t bias_shift,
56
                                        const uint16_t out_shift,
57
                                        const q7_t * bias,
58
                                        q7_t * pOut)
59
{
60
#if defined (ARM_MATH_DSP)
61
    /* set up the second output pointers */
62
    q7_t     *pOut2 = pOut + ch_im_out;
63
    const q7_t *pBias = bias;
64
 
65
    uint16_t  rowCnt = ch_im_out >> 1;
66
    /* this loop over rows in A */
67
    while (rowCnt)
68
    {
69
        /* setup pointers for B */
70
        const q15_t *pB = pInBuffer;
71
        const q15_t *pB2 = pB + numCol_A;
72
 
73
        /* align the second pointer for A */
74
        const q7_t *pA2 = pA + numCol_A;
75
 
76
        /* init the sum with bias */
77
        q31_t     sum =  ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
78
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
79
        q31_t     sum3 = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
80
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
81
 
82
        uint16_t  colCnt = numCol_A >> 2;
83
        /* accumulate over the vector */
84
        while (colCnt)
85
        {
86
            q31_t     inA11, inA12, inA21, inA22;
87
            q31_t     inB1 = *__SIMD32(pB)++;
88
            q31_t     inB2 = *__SIMD32(pB2)++;
89
 
90
            pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
91
            pA2 = (q7_t *) read_and_pad((void *)pA2, &inA21, &inA22);
92
 
93
            sum = __SMLAD(inA11, inB1, sum);
94
            sum2 = __SMLAD(inA11, inB2, sum2);
95
            sum3 = __SMLAD(inA21, inB1, sum3);
96
            sum4 = __SMLAD(inA21, inB2, sum4);
97
 
98
            inB1 = *__SIMD32(pB)++;
99
            inB2 = *__SIMD32(pB2)++;
100
 
101
            sum = __SMLAD(inA12, inB1, sum);
102
            sum2 = __SMLAD(inA12, inB2, sum2);
103
            sum3 = __SMLAD(inA22, inB1, sum3);
104
            sum4 = __SMLAD(inA22, inB2, sum4);
105
 
106
            colCnt--;
107
        }                       /* while over colCnt */
108
        colCnt = numCol_A & 0x3;
109
        while (colCnt)
110
        {
111
            q7_t      inA1 = *pA++;
112
            q15_t     inB1 = *pB++;
113
            q7_t      inA2 = *pA2++;
114
            q15_t     inB2 = *pB2++;
115
 
116
            sum += inA1 * inB1;
117
            sum2 += inA1 * inB2;
118
            sum3 += inA2 * inB1;
119
            sum4 += inA2 * inB2;
120
            colCnt--;
121
        }                       /* while over colCnt */
122
        *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
123
        *pOut++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
124
        *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
125
        *pOut2++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
126
 
127
        /* skip the row computed with A2 */
128
        pA += numCol_A;
129
        rowCnt--;
130
    }                           /* for over ch_im_out */
131
 
132
    /* compute left-over row if any */
133
    if (ch_im_out & 0x1)
134
    {
135
        /* setup pointers for B */
136
        const q15_t *pB = pInBuffer;
137
        const q15_t *pB2 = pB + numCol_A;
138
 
139
        /* load the bias */
140
        q31_t     sum = ((q31_t)(*pBias) << bias_shift) + NN_ROUND(out_shift);
141
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
142
 
143
        uint16_t  colCnt = numCol_A >> 2;
144
        while (colCnt)
145
        {
146
            q31_t     inA11, inA12;
147
            q31_t     inB1 = *__SIMD32(pB)++;
148
            q31_t     inB2 = *__SIMD32(pB2)++;
149
 
150
            pA = (q7_t *) read_and_pad((void *)pA, &inA11, &inA12);
151
 
152
            sum = __SMLAD(inA11, inB1, sum);
153
            sum2 = __SMLAD(inA11, inB2, sum2);
154
 
155
            inB1 = *__SIMD32(pB)++;
156
            inB2 = *__SIMD32(pB2)++;
157
            sum = __SMLAD(inA12, inB1, sum);
158
            sum2 = __SMLAD(inA12, inB2, sum2);
159
 
160
            colCnt--;
161
        }
162
        colCnt = numCol_A & 0x3;
163
        while (colCnt)
164
        {
165
            q7_t      inA1 = *pA++;
166
            q15_t     inB1 = *pB++;
167
            q15_t     inB2 = *pB2++;
168
 
169
            sum += inA1 * inB1;
170
            sum2 += inA1 * inB2;
171
            colCnt--;
172
        }
173
 
174
        *pOut++ = (q7_t) __SSAT((sum >> out_shift), 8);
175
        *pOut2++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
176
    }
177
 
178
    pOut += ch_im_out;
179
 
180
    /* return the new output pointer with offset */
181
    return pOut;
182
#else
183
    /* To be completed */
184
    return NULL;
185
#endif                          /* ARM_MATH_DSP */
186
 
187
}