Subversion Repositories FuelGauge

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 mjames 1
/*
2
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
3
 *
4
 * SPDX-License-Identifier: Apache-2.0
5
 *
6
 * Licensed under the Apache License, Version 2.0 (the License); you may
7
 * not use this file except in compliance with the License.
8
 * You may obtain a copy of the License at
9
 *
10
 * www.apache.org/licenses/LICENSE-2.0
11
 *
12
 * Unless required by applicable law or agreed to in writing, software
13
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
 * See the License for the specific language governing permissions and
16
 * limitations under the License.
17
 */
18
 
19
/* ----------------------------------------------------------------------
20
 * Project:      CMSIS NN Library
21
 * Title:        arm_fully_connected_q7_opt.c
22
 * Description:  Q7 basic fully-connected layer function
23
 *
24
 * $Date:        17. January 2018
25
 * $Revision:    V.1.0.0
26
 *
27
 * Target Processor:  Cortex-M cores
28
 *
29
 * -------------------------------------------------------------------- */
30
 
31
#include "arm_math.h"
32
#include "arm_nnfunctions.h"
33
 
34
/**
35
 *  @ingroup groupNN
36
 */
37
 
38
/**
39
 * @addtogroup FC
40
 * @{
41
 */
42
 
43
  /**
44
   * @brief Q7 opt fully-connected layer function
45
   * @param[in]       pV          pointer to input vector
46
   * @param[in]       pM          pointer to matrix weights
47
   * @param[in]       dim_vec     length of the vector
48
   * @param[in]       num_of_rows number of rows in weight matrix
49
   * @param[in]       bias_shift  amount of left-shift for bias
50
   * @param[in]       out_shift   amount of right-shift for output
51
   * @param[in]       bias        pointer to bias
52
   * @param[in,out]   pOut        pointer to output vector
53
   * @param[in,out]   vec_buffer  pointer to buffer space for input
54
   * @return     The function returns <code>ARM_MATH_SUCCESS</code>
55
   *
56
   * @details
57
   *
58
   * <b>Buffer size:</b>
59
   *
60
   * vec_buffer size: dim_vec
61
   *
62
   * This opt function is designed to work with interleaved weight
63
   * matrix. The vector input is assumed in q7_t format, we call
64
   *  arm_q7_to_q15_no_shift_shuffle function to expand into
65
   *  q15_t format with certain weight re-ordering, refer to the function
66
   *  comments for more details.
67
   *  Here we use only one pointer to read 4 rows in the weight
68
   *  matrix. So if the original q7_t matrix looks like this:
69
   *
70
   *  | a11 | a12 | a13 | a14 | a15 | a16 | a17 |
71
   *
72
   *  | a21 | a22 | a23 | a24 | a25 | a26 | a27 |
73
   *
74
   *  | a31 | a32 | a33 | a34 | a35 | a36 | a37 |
75
   *
76
   *  | a41 | a42 | a43 | a44 | a45 | a46 | a47 |
77
   *
78
   *  | a51 | a52 | a53 | a54 | a55 | a56 | a57 |
79
   *
80
   *  | a61 | a62 | a63 | a64 | a65 | a66 | a67 |
81
   *
82
   *
83
   *  We operates on multiple-of-4 rows, so the first four rows becomes
84
   *
85
   *  | a11 | a21 | a13 | a23 | a31 | a41 | a33 | a43 |
86
   *
87
   *  | a12 | a22 | a14 | a24 | a32 | a42 | a34 | a44 |
88
   *
89
   *  | a15 | a25 | a35 | a45 | a16 | a26 | a36 | a46 |
90
   *
91
   *  So within the kernel, we first read the re-ordered vector in as:
92
   *
93
   *  | b1  | b3  | and | b2  | b4  |
94
   *
95
   *  the four q31_t weights will look like
96
   *
97
   *  | a11 | a13 |, | a21 | a23 |, | a31 | a33 |, | a41 | a43 |
98
   *
99
   *  | a12 | a14 |, | a22 | a24 |, | a32 | a34 |, | a42 | a44 |
100
   *
101
   *  The column left over will be in-order.
102
   *  which is:
103
   *
104
   *  | a17 | a27 | a37 | a47 |
105
   *
106
   *  For the left-over rows, we do 1x1 computation, so the data remains
107
   *  as its original order.
108
   *
109
   *  So the stored weight matrix looks like this:
110
   *
111
   *  | a11 | a21 | a13 | a23 | a31 | a41 |
112
   *
113
   *  | a33 | a43 | a12 | a22 | a14 | a24 |
114
   *
115
   *  | a32 | a42 | a34 | a44 | a15 | a25 |
116
   *
117
   *  | a35 | a45 | a16 | a26 | a36 | a46 |
118
   *
119
   *  | a17 | a27 | a37 | a47 | a51 | a52 |
120
   *
121
   *  | a53 | a54 | a55 | a56 | a57 | a61 |
122
   *
123
   *  | a62 | a63 | a64 | a65 | a66 | a67 |
124
   *
125
   *
126
   */
127
 
128
arm_status
129
arm_fully_connected_q7_opt(const q7_t * pV,
130
                           const q7_t * pM,
131
                           const uint16_t dim_vec,
132
                           const uint16_t num_of_rows,
133
                           const uint16_t bias_shift,
134
                           const uint16_t out_shift,
135
                           const q7_t * bias,
136
                           q7_t * pOut,
137
                           q15_t * vec_buffer)
138
{
139
 
140
#if defined (ARM_MATH_DSP)
141
    /* Run the following code for Cortex-M4 and Cortex-M7 */
142
 
143
    const q7_t *pB = pM;
144
    q7_t     *pO = pOut;
145
    const q7_t *pBias = bias;
146
    q15_t    *pA;
147
    uint16_t  rowCnt = num_of_rows >> 2;
148
 
149
    arm_q7_to_q15_reordered_no_shift(pV, vec_buffer, dim_vec);
150
 
151
    while (rowCnt)
152
    {
153
 
154
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
155
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
156
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
157
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
158
 
159
        uint16_t  colCnt = dim_vec >> 2;
160
 
161
        pA = vec_buffer;
162
 
163
#ifdef USE_INTRINSIC
164
 
165
#ifndef ARM_MATH_BIG_ENDIAN
166
        while (colCnt)
167
        {
168
            q31_t     inM11, inM12, inM13, inM14;
169
            q31_t     inV;
170
 
171
            inV = *__SIMD32(pA)++;
172
            inM11 = *__SIMD32(pB)++;
173
            inM12 = __SXTB16(__ROR(inM11, 8));
174
            inM11 = __SXTB16(inM11);
175
            sum = __SMLAD(inM11, inV, sum);
176
            sum2 = __SMLAD(inM12, inV, sum2);
177
            inM13 = *__SIMD32(pB)++;
178
            inM14 = __SXTB16(__ROR(inM13, 8));
179
            inM13 = __SXTB16(inM13);
180
            sum3 = __SMLAD(inM13, inV, sum3);
181
            sum4 = __SMLAD(inM14, inV, sum4);
182
 
183
            inV = *__SIMD32(pA)++;
184
            inM11 = *__SIMD32(pB)++;
185
            inM12 = __SXTB16(__ROR(inM11, 8));
186
            inM11 = __SXTB16(inM11);
187
            sum = __SMLAD(inM11, inV, sum);
188
            sum2 = __SMLAD(inM12, inV, sum2);
189
            inM13 = *__SIMD32(pB)++;
190
            inM14 = __SXTB16(__ROR(inM13, 8));
191
            inM13 = __SXTB16(inM13);
192
            sum3 = __SMLAD(inM13, inV, sum3);
193
            sum4 = __SMLAD(inM14, inV, sum4);
194
            colCnt--;
195
        }
196
#else
197
        while (colCnt)
198
        {
199
            q31_t     inM11, inM12, inM13, inM14;
200
            q31_t     inV;
201
 
202
            inV = *__SIMD32(pA)++;
203
            inM11 = *__SIMD32(pB)++;
204
            inM12 = __SXTB16(__ROR(inM11, 8));
205
            inM11 = __SXTB16(inM11);
206
            sum = __SMLAD(inM12, inV, sum);
207
            sum2 = __SMLAD(inM11, inV, sum2);
208
            inM13 = *__SIMD32(pB)++;
209
            inM14 = __SXTB16(__ROR(inM13, 8));
210
            inM13 = __SXTB16(inM13);
211
            sum3 = __SMLAD(inM14, inV, sum3);
212
            sum4 = __SMLAD(inM13, inV, sum4);
213
 
214
            inV = *__SIMD32(pA)++;
215
            inM11 = *__SIMD32(pB)++;
216
            inM12 = __SXTB16(__ROR(inM11, 8));
217
            inM11 = __SXTB16(inM11);
218
            sum = __SMLAD(inM12, inV, sum);
219
            sum2 = __SMLAD(inM11, inV, sum2);
220
            inM13 = *__SIMD32(pB)++;
221
            inM14 = __SXTB16(__ROR(inM13, 8));
222
            inM13 = __SXTB16(inM13);
223
            sum3 = __SMLAD(inM14, inV, sum3);
224
            sum4 = __SMLAD(inM13, inV, sum4);
225
            colCnt--;
226
        }
227
#endif                          /* ARM_MATH_BIG_ENDIAN */
228
 
229
#else
230
 
231
        /*
232
         * register needed:
233
         * loop counter: colCnt
234
         * accumulators: sum, sum2, sum3, sum4
235
         * pointers: pB, pA
236
         * weight data: inM11, inM12, inM13, inM14
237
         * activation data: inV
238
         */
239
 
240
#ifndef ARM_MATH_BIG_ENDIAN
241
        asm volatile ("COL_LOOP_%=:\n"
242
                      "ldr.w r4, [%[pA]], #8\n"
243
                      "ldr.w r1, [%[pB]], #16\n"
244
                      "mov.w r0, r1, ror #8\n"
245
                      "sxtb16 r0, r0\n"
246
                      "sxtb16 r1, r1\n"
247
                      "smlad %[sum], r4, r1, %[sum]\n"
248
                      "smlad %[sum2], r4, r0, %[sum2]\n"
249
                      "ldr.w r3, [%[pB], #-12]\n"
250
                      "mov.w r2, r3, ror #8\n"
251
                      "sxtb16 r2, r2\n"
252
                      "sxtb16 r3, r3\n"
253
                      "smlad %[sum3], r4, r3, %[sum3]\n"
254
                      "smlad %[sum4], r4, r2, %[sum4]\n"
255
                      "ldr.w r4, [%[pA], #-4]\n"
256
                      "ldr.w r1, [%[pB], #-8]\n"
257
                      "mov.w r0, r1, ror #8\n"
258
                      "sxtb16 r0, r0\n"
259
                      "sxtb16 r1, r1\n"
260
                      "smlad %[sum], r4, r1, %[sum]\n"
261
                      "smlad %[sum2], r4, r0, %[sum2]\n"
262
                      "ldr.w r3, [%[pB], #-4]\n"
263
                      "mov.w r2, r3, ror #8\n"
264
                      "sxtb16 r2, r2\n"
265
                      "sxtb16 r3, r3\n"
266
                      "smlad %[sum3], r4, r3, %[sum3]\n"
267
                      "smlad %[sum4], r4, r2, %[sum4]\n"
268
                      "subs %[colCnt], #1\n"
269
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
270
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
271
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
272
#else
273
        asm volatile ("COL_LOOP_%=:\n"
274
                      "ldr.w r4, [%[pA]], #8\n"
275
                      "ldr.w r1, [%[pB]], #16\n"
276
                      "mov.w r0, r1, ror #8\n"
277
                      "sxtb16 r0, r0\n"
278
                      "sxtb16 r1, r1\n"
279
                      "smlad %[sum], r4, r0, %[sum]\n"
280
                      "smlad %[sum2], r4, r1, %[sum2]\n"
281
                      "ldr.w r3, [%[pB], #-12]\n"
282
                      "mov.w r2, r3, ror #8\n"
283
                      "sxtb16 r2, r2\n"
284
                      "sxtb16 r3, r3\n"
285
                      "smlad %[sum3], r4, r2, %[sum3]\n"
286
                      "smlad %[sum4], r4, r3, %[sum4]\n"
287
                      "ldr.w r4, [%[pA], #-4]\n"
288
                      "ldr.w r1, [%[pB], #-8]\n"
289
                      "mov.w r0, r1, ror #8\n"
290
                      "sxtb16 r0, r0\n"
291
                      "sxtb16 r1, r1\n"
292
                      "smlad %[sum], r4, r0, %[sum]\n"
293
                      "smlad %[sum2], r4, r1, %[sum2]\n"
294
                      "ldr.w r3, [%[pB], #-4]\n"
295
                      "mov.w r2, r3, ror #8\n"
296
                      "sxtb16 r2, r2\n"
297
                      "sxtb16 r3, r3\n"
298
                      "smlad %[sum3], r4, r2, %[sum3]\n"
299
                      "smlad %[sum4], r4, r3, %[sum4]\n"
300
                      "subs %[colCnt], #1\n"
301
                      "bne COL_LOOP_%=\n":[sum] "+r"(sum),
302
                      [sum2] "+r"(sum2),[sum3] "+r"(sum3),
303
                      [sum4] "+r"(sum4),[pB] "+r"(pB),[pA] "+r"(pA):[colCnt] "r"(colCnt):"r0", "r1", "r2", "r3", "r4");
304
#endif                          /* ARM_MATH_BIG_ENDIAN */
305
 
306
#endif                          /* USE_INTRINSIC */
307
 
308
        colCnt = dim_vec & 0x3;
309
        while (colCnt)
310
        {
311
            q15_t     inV = *pA++;
312
            q7_t      inM = *pB++;
313
            q7_t      inM2 = *pB++;
314
            q7_t      inM3 = *pB++;
315
            q7_t      inM4 = *pB++;
316
 
317
            sum += inV * inM;
318
            sum2 += inV * inM2;
319
            sum3 += inV * inM3;
320
            sum4 += inV * inM4;
321
            colCnt--;
322
        }                       /* while over colCnt */
323
        *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
324
        *pO++ = (q7_t) (__SSAT((sum2 >> out_shift), 8));
325
        *pO++ = (q7_t) (__SSAT((sum3 >> out_shift), 8));
326
        *pO++ = (q7_t) (__SSAT((sum4 >> out_shift), 8));
327
 
328
        /* adjust the pointers and counters */
329
        rowCnt--;
330
    }
331
 
332
    /* left-over part of the rows */
333
    rowCnt = num_of_rows & 0x3;
334
 
335
    while (rowCnt)
336
    {
337
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
338
        uint16_t  colCnt = dim_vec >> 2;
339
 
340
        pA = vec_buffer;
341
 
342
        while (colCnt)
343
        {
344
            q31_t     inV1, inV2, inM11, inM12;
345
 
346
            pB = (q7_t *) read_and_pad_reordered((void *)pB, &inM11, &inM12);
347
 
348
            inV1 = *__SIMD32(pA)++;
349
            sum = __SMLAD(inV1, inM11, sum);
350
 
351
            inV2 = *__SIMD32(pA)++;
352
            sum = __SMLAD(inV2, inM12, sum);
353
 
354
            colCnt--;
355
        }
356
 
357
        /* left-over of the vector */
358
        colCnt = dim_vec & 0x3;
359
        while (colCnt)
360
        {
361
            q15_t     inV = *pA++;
362
            q7_t      inM = *pB++;
363
            sum += inV * inM;
364
            colCnt--;
365
        }
366
 
367
        *pO++ = (q7_t) (__SSAT((sum >> out_shift), 8));
368
 
369
        rowCnt--;
370
    }
371
 
372
#else
373
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
374
    uint16_t  rowCnt = num_of_rows >> 2;
375
    const q7_t *pB = pM;
376
    const q7_t *pA;
377
    q7_t     *pO = pOut;
378
    const q7_t *pBias = bias;
379
 
380
    while (rowCnt)
381
    {
382
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
383
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
384
        q31_t     sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
385
        q31_t     sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
386
 
387
        uint16_t  colCnt = dim_vec >> 2;
388
 
389
        pA = pV;
390
 
391
        while (colCnt)
392
        {
393
            q7_t      inA1 = *pA++;
394
            q7_t      inA3 = *pA++;
395
            q7_t      inA2 = *pA++;
396
            q7_t      inA4 = *pA++;
397
 
398
            q7_t      inB1 = *pB++;
399
            q7_t      inB3 = *pB++;
400
            q7_t      inB2 = *pB++;
401
            q7_t      inB4 = *pB++;
402
 
403
            sum += inA1 * inB1 + inA2 * inB2;
404
            sum2 += inA1 * inB3 + inA2 * inB4;
405
 
406
            inB1 = *pB++;
407
            inB3 = *pB++;
408
            inB2 = *pB++;
409
            inB4 = *pB++;
410
 
411
            sum3 += inA1 * inB1 + inA2 * inB2;
412
            sum4 += inA1 * inB3 + inA2 * inB4;
413
 
414
            inB1 = *pB++;
415
            inB3 = *pB++;
416
            inB2 = *pB++;
417
            inB4 = *pB++;
418
 
419
            sum += inA3 * inB1 + inA4 * inB2;
420
            sum2 += inA3 * inB3 + inA4 * inB4;
421
 
422
            inB1 = *pB++;
423
            inB3 = *pB++;
424
            inB2 = *pB++;
425
            inB4 = *pB++;
426
 
427
            sum3 += inA3 * inB1 + inA4 * inB2;
428
            sum4 += inA3 * inB3 + inA4 * inB4;
429
 
430
            colCnt--;
431
        }
432
        colCnt = dim_vec & 0x3;
433
        while (colCnt)
434
        {
435
            q7_t      inA = *pA++;
436
            q7_t      inB = *pB++;
437
            sum += inA * inB;
438
            inB = *pB++;
439
            sum2 += inA * inB;
440
            inB = *pB++;
441
            sum3 += inA * inB;
442
            inB = *pB++;
443
            sum4 += inA * inB;
444
 
445
            colCnt--;
446
        }
447
        *pO++ = (q7_t) __SSAT((sum >> out_shift), 8);
448
        *pO++ = (q7_t) __SSAT((sum2 >> out_shift), 8);
449
        *pO++ = (q7_t) __SSAT((sum3 >> out_shift), 8);
450
        *pO++ = (q7_t) __SSAT((sum4 >> out_shift), 8);
451
 
452
        rowCnt--;
453
    }
454
 
455
    rowCnt = num_of_rows & 0x3;
456
 
457
    while (rowCnt)
458
    {
459
        int       ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
460
 
461
        int       j;
462
 
463
        pA = pV;
464
        for (j = 0; j < dim_vec; j++)
465
        {
466
            q7_t      inA = *pA++;
467
            q7_t      inB = *pB++;
468
            ip_out += inA * inB;
469
        }
470
        *pO++ = (q7_t) __SSAT((ip_out >> out_shift), 8);
471
 
472
        rowCnt--;
473
    }
474
 
475
#endif                          /* ARM_MATH_DSP */
476
 
477
    /* Return to ARM_MATH_SUCCESS */
478
    return (ARM_MATH_SUCCESS);
479
 
480
}
481
 
482
/**
483
 * @} end of FC group
484
 */