Subversion Repositories dashGPS

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 mjames 1
/*
2
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3
 *
4
 * SPDX-License-Identifier: Apache-2.0
5
 *
6
 * Licensed under the Apache License, Version 2.0 (the License); you may
7
 * not use this file except in compliance with the License.
8
 * You may obtain a copy of the License at
9
 *
10
 * www.apache.org/licenses/LICENSE-2.0
11
 *
12
 * Unless required by applicable law or agreed to in writing, software
13
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
 * See the License for the specific language governing permissions and
16
 * limitations under the License.
17
 */
18
 
19
/* ----------------------------------------------------------------------
20
 * Project:      CMSIS NN Library
21
 * Title:        arm_fully_connected_q15_opt.c
22
 * Description:  Q15 opt fully-connected layer function
23
 *
24
 * $Date:        17. January 2018
25
 * $Revision:    V.1.0.0
26
 *
27
 * Target Processor:  Cortex-M cores
28
 *
29
 * -------------------------------------------------------------------- */
30
 
31
#include "arm_math.h"
32
#include "arm_nnfunctions.h"
33
 
34
/**
35
 *  @ingroup groupNN
36
 */
37
 
38
/**
39
 * @addtogroup FC
40
 * @{
41
 */
42
 
43
  /**
44
   * @brief Q15 opt fully-connected layer function
45
   * @param[in]       pV          pointer to input vector
46
   * @param[in]       pM          pointer to matrix weights
47
   * @param[in]       dim_vec     length of the vector
48
   * @param[in]       num_of_rows number of rows in weight matrix
49
   * @param[in]       bias_shift  amount of left-shift for bias
50
   * @param[in]       out_shift   amount of right-shift for output
51
   * @param[in]       bias        pointer to bias
52
   * @param[out]      pOut        pointer to output vector (num_of_rows elements)
53
   * @param[in]       vec_buffer  pointer to buffer space for input (not used by this function; buffer size is 0)
54
   * @return     The function returns <code>ARM_MATH_SUCCESS</code>
55
   *
56
   *
57
   * @details
58
   *
59
   * <b>Buffer size:</b>
60
   *
61
   * vec_buffer size: 0
62
   *
63
   *  Here we use only one pointer to read 4 rows in the weight
64
   *  matrix. So if the original matrix looks like this:
65
   *
66
   *  | a11 | a12 | a13 |
67
   *
68
   *  | a21 | a22 | a23 |
69
   *
70
   *  | a31 | a32 | a33 |
71
   *
72
   *  | a41 | a42 | a43 |
73
   *
74
   *  | a51 | a52 | a53 |
75
   *
76
   *  | a61 | a62 | a63 |
77
   *
78
   *  We operate on multiple-of-4 rows, so the first four rows become
79
   *
80
   *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
81
   *
82
   *  | a13 | a23 | a33 | a43 |
83
   *
84
   *  Remaining rows are kept the same original order.
85
   *
86
   *  So the stored weight matrix looks like this:
87
   *
88
   *
89
   *  | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
90
   *
91
   *  | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
92
   *
93
   *  | a62 | a63 |
94
   */
95
 
96
arm_status
97
arm_fully_connected_q15_opt(const q15_t * pV,
98
                            const q15_t * pM,
99
                            const uint16_t dim_vec,
100
                            const uint16_t num_of_rows,
101
                            const uint16_t bias_shift,
102
                            const uint16_t out_shift,
103
                            const q15_t * bias,
104
                            q15_t * pOut,
105
                            q15_t * vec_buffer)
106
{
107
 
108
#if defined (ARM_MATH_DSP)
109
    /* Run the following code for Cortex-M4 and Cortex-M7 */
110
 
111
    const q15_t *pB = pM;
112
    q15_t    *pO = pOut;
113
    const q15_t *pBias = bias;
114
    const q15_t *pA = pV;
115
 
116
    uint16_t  rowCnt = num_of_rows >> 2;
117
 
118
    while (rowCnt)
119
    {
120
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
121
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
122
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
123
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
124
 
125
        uint16_t  colCnt = dim_vec >> 1;
126
 
127
        pA = pV;
128
 
129
#ifdef USE_INTRINSIC
130
 
131
        while (colCnt)
132
        {
133
            q31_t     inM11, inM12, inM13, inM14;
134
            q31_t     inV;
135
 
136
            inV = *__SIMD32(pA)++;
137
            inM11 = *__SIMD32(pB)++;
138
            sum = __SMLAD(inV, inM11, sum);
139
            inM12 = *__SIMD32(pB)++;
140
            sum2 = __SMLAD(inV, inM12, sum2);
141
            inM13 = *__SIMD32(pB)++;
142
            sum3 = __SMLAD(inV, inM13, sum3);
143
            inM14 = *__SIMD32(pB)++;
144
            sum4 = __SMLAD(inV, inM14, sum4);
145
            colCnt--;
146
        }
147
 
148
#else
149
 
150
        /*
151
         * register needed:
152
         * loop counter: colCnt
153
         * accumulators: sum, sum2, sum3, sum4
154
         * pointers: pB, pA
155
         * weight data: inM11, inM12, inM13, inM14
156
         * activation data: inV
157
         */
158
 
159
        asm volatile ("COL_LOOP_%=:\n"
160
                      "ldr.w r4, [%[pA]], #4\n"
161
                      "ldr.w r0, [%[pB]], #16\n"
162
                      "smlad %[sum], r4, r0, %[sum]\n"
163
                      "ldr.w r1, [%[pB] , #-12]\n"
164
                      "smlad %[sum2], r4, r1, %[sum2]\n"
165
                      "ldr.w r2, [%[pB] , #-8]\n"
166
                      "smlad %[sum3], r4, r2, %[sum3]\n"
167
                      "ldr.w r3, [%[pB] , #-4]\n"
168
                      "smlad %[sum4], r4, r3, %[sum4]\n"
169
                      "subs %[colCnt], #1\n"
170
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
171
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
172
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
173
 
174
#endif                          /* USE_INTRINSIC */
175
 
176
        colCnt = dim_vec & 0x1;
177
        while (colCnt)
178
        {
179
 
180
            q15_t     inV = *pA++;
181
            q15_t     inM = *pB++;
182
            q15_t     inM2 = *pB++;
183
            q15_t     inM3 = *pB++;
184
            q15_t     inM4 = *pB++;
185
 
186
            sum += inV * inM;
187
            sum2 += inV * inM2;
188
            sum3 += inV * inM3;
189
            sum4 += inV * inM4;
190
            colCnt--;
191
        }                       /* while over colCnt */
192
        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
193
        *pO++ = (q15_t) (__SSAT((sum2 >> out_shift), 16));
194
        *pO++ = (q15_t) (__SSAT((sum3 >> out_shift), 16));
195
        *pO++ = (q15_t) (__SSAT((sum4 >> out_shift), 16));
196
 
197
        /* adjust the pointers and counters */
198
        rowCnt--;
199
    }
200
 
201
    /* left-over part of the rows */
202
    rowCnt = num_of_rows & 0x3;
203
 
204
    while (rowCnt)
205
    {
206
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
207
 
208
        uint16_t  colCnt = dim_vec >> 2;
209
 
210
        pA = pV;
211
 
212
        while (colCnt)
213
        {
214
            q31_t     inV1, inV2, inM1, inM2;
215
 
216
            inM1 = *__SIMD32(pB)++;
217
            inV1 = *__SIMD32(pA)++;
218
            sum = __SMLAD(inV1, inM1, sum);
219
 
220
            inM2 = *__SIMD32(pB)++;
221
            inV2 = *__SIMD32(pA)++;
222
            sum = __SMLAD(inV2, inM2, sum);
223
 
224
            colCnt--;
225
        }
226
 
227
        /* left-over of the vector */
228
        colCnt = dim_vec & 0x3;
229
        while (colCnt)
230
        {
231
            q15_t     inV = *pA++;
232
            q15_t     inM = *pB++;
233
            sum += inV * inM;
234
            colCnt--;
235
        }
236
 
237
        *pO++ = (q15_t) (__SSAT((sum >> out_shift), 16));
238
 
239
        rowCnt--;
240
    }
241
 
242
#else
243
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
244
    uint16_t  rowCnt = num_of_rows >> 2;
245
    const q15_t *pB = pM;
246
    const q15_t *pA;
247
    q15_t    *pO = pOut;
248
    const q15_t *pBias = bias;
249
 
250
    while (rowCnt)
251
    {
252
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
253
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
254
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
255
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
256
 
257
        uint16_t  colCnt = dim_vec >> 1;
258
 
259
        pA = pV;
260
        while (colCnt)
261
        {
262
            q15_t     inA1 = *pA++;
263
            q15_t     inA2 = *pA++;
264
 
265
            q15_t     inB1 = *pB++;
266
            q15_t     inB2 = *pB++;
267
            sum += inA1 * inB1 + inA2 * inB2;
268
 
269
            inB1 = *pB++;
270
            inB2 = *pB++;
271
            sum2 += inA1 * inB1 + inA2 * inB2;
272
 
273
            inB1 = *pB++;
274
            inB2 = *pB++;
275
            sum3 += inA1 * inB1 + inA2 * inB2;
276
 
277
            inB1 = *pB++;
278
            inB2 = *pB++;
279
            sum4 += inA1 * inB1 + inA2 * inB2;
280
 
281
            colCnt--;
282
        }
283
        colCnt = dim_vec & 0x1;
284
        while (colCnt)
285
        {
286
            q15_t     inA = *pA++;
287
            q15_t     inB = *pB++;
288
            sum += inA * inB;
289
            inB = *pB++;
290
            sum2 += inA * inB;
291
            inB = *pB++;
292
            sum3 += inA * inB;
293
            inB = *pB++;
294
            sum4 += inA * inB;
295
            colCnt--;
296
        }
297
        *pO++ = (q15_t) __SSAT((sum >> out_shift), 16);
298
        *pO++ = (q15_t) __SSAT((sum2 >> out_shift), 16);
299
        *pO++ = (q15_t) __SSAT((sum3 >> out_shift), 16);
300
        *pO++ = (q15_t) __SSAT((sum4 >> out_shift), 16);
301
 
302
        rowCnt--;
303
    }
304
    rowCnt = num_of_rows & 0x3;
305
 
306
    while (rowCnt)
307
    {
308
        int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
309
        int       j;
310
 
311
        pA = pV;
312
        for (j = 0; j < dim_vec; j++)
313
        {
314
            q15_t     inA = *pA++;
315
            q15_t     inB = *pB++;
316
            ip_out += inA * inB;
317
        }
318
        *pO++ = (q15_t) __SSAT((ip_out >> out_shift), 16);
319
 
320
        rowCnt--;
321
    }
322
 
323
#endif                          /* ARM_MATH_DSP */
324
 
325
    /* Return to ARM_MATH_SUCCESS */
326
    return (ARM_MATH_SUCCESS);
327
 
328
}
329
 
330
/**
331
 * @} end of FC group
332
 */