Subversion Repositories canSerial

Rev

Details | Last modification | View Log | RSS feed

Rev Author Line No. Line
2 mjames 1
/*
 * Copyright (C) 2010-2018 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18
 
19
/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q15.c
 * Description:  Q15 basic fully-connected layer function
 *
 * $Date:        17. January 2018
 * $Revision:    V.1.0.0
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */
30
 
31
#include "arm_math.h"
32
#include "arm_nnfunctions.h"
33
 
34
/**
 *  @ingroup groupNN
 */

/**
 * @addtogroup FC
 * @{
 */
42
 
43
  /**
44
   * @brief Q15 opt fully-connected layer function
45
   * @param[in]       pV          pointer to input vector
46
   * @param[in]       pM          pointer to matrix weights
47
   * @param[in]       dim_vec     length of the vector
48
   * @param[in]       num_of_rows number of rows in weight matrix
49
   * @param[in]       bias_shift  amount of left-shift for bias
50
   * @param[in]       out_shift   amount of right-shift for output
51
   * @param[in]       bias        pointer to bias
52
   * @param[in,out]   pOut        pointer to output vector
53
   * @param[in,out]   vec_buffer  pointer to buffer space for input
54
   * @return     The function returns <code>ARM_MATH_SUCCESS</code>
55
   *
56
   *
57
   * @details
58
   *
59
   * <b>Buffer size:</b>
60
   *
61
   * vec_buffer size: 0
62
   *
63
   */
64
 
65
arm_status
66
arm_fully_connected_q15(const q15_t * pV,
67
                        const q15_t * pM,
68
                        const uint16_t dim_vec,
69
                        const uint16_t num_of_rows,
70
                        const uint16_t bias_shift,
71
                        const uint16_t out_shift,
72
                        const q15_t * bias,
73
                        q15_t * pOut,
74
                        q15_t * vec_buffer)
75
{
76
 
77
#if defined (ARM_MATH_DSP)
78
    /* Run the following code for Cortex-M4 and Cortex-M7 */
79
 
80
    const q15_t *pB = pM;
81
    const q15_t *pB2 = pB + dim_vec;
82
    q15_t    *pO = pOut;
83
    const q15_t    *pA;
84
    const q15_t    *pBias = bias;
85
    uint16_t rowCnt = num_of_rows >> 1;
86
 
87
    /* this loop loops over different output */
88
    while (rowCnt) {
89
        q31_t     sum =  ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
90
        q31_t     sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
91
 
92
        uint16_t  colCnt = dim_vec >> 2;
93
 
94
        pA = pV;
95
        pB2 = pB + dim_vec;
96
 
97
        while (colCnt)
98
        {
99
            q31_t     inV1, inM1, inM2;
100
            inV1 = *__SIMD32(pA)++;
101
            inM1 = *__SIMD32(pB)++;
102
            sum = __SMLAD(inV1, inM1, sum);
103
            inM2 = *__SIMD32(pB2)++;
104
            sum2 = __SMLAD(inV1, inM2, sum2);
105
 
106
            inV1 = *__SIMD32(pA)++;
107
            inM1 = *__SIMD32(pB)++;
108
            sum = __SMLAD(inV1, inM1, sum);
109
            inM2 = *__SIMD32(pB2)++;
110
            sum2 = __SMLAD(inV1, inM2, sum2);
111
 
112
            colCnt--;
113
        }
114
        colCnt = dim_vec & 0x3;
115
        while (colCnt)
116
        {
117
            q15_t     inV = *pA++;
118
            q15_t     inM = *pB++;
119
            q15_t     inM2 = *pB2++;
120
 
121
            sum += inV * inM;
122
            sum2 += inV * inM2;
123
            colCnt--;
124
        }                       /* while over colCnt */
125
        *pO++ =  (q15_t) (__SSAT((sum >> out_shift), 16));
126
        *pO++ = (q15_t) (__SSAT((sum2>> out_shift), 16));
127
 
128
        /* adjust the pointers and counters */
129
        pB = pB + dim_vec;
130
        rowCnt --;
131
    }
132
 
133
    rowCnt = num_of_rows & 0x1;
134
 
135
    while (rowCnt) {
136
        q31_t     sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
137
 
138
        uint16_t  colCnt = dim_vec >> 2;
139
 
140
        pA = pV;
141
 
142
        while (colCnt) {
143
            q31_t     inV1, inM1;
144
            inV1 = *__SIMD32(pA)++;
145
            inM1 = *__SIMD32(pB)++;
146
            sum = __SMLAD(inV1, inM1, sum);
147
 
148
            inV1 = *__SIMD32(pA)++;
149
            inM1 = *__SIMD32(pB)++;
150
            sum = __SMLAD(inV1, inM1, sum);
151
 
152
            colCnt--;
153
        }
154
 
155
        /* left-over of the vector */
156
        colCnt = dim_vec & 0x3;
157
        while(colCnt) {
158
            q15_t     inV = *pA++;
159
            q15_t     inM = *pB++;
160
 
161
            sum += inV * inM;
162
 
163
            colCnt--;
164
        }
165
 
166
        *pO++ =  (q15_t) (__SSAT((sum >> out_shift), 16));
167
 
168
        rowCnt --;
169
    }
170
 
171
#else
172
    int       i, j;
173
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
174
    for (i = 0; i < num_of_rows; i++)
175
    {
176
        int       ip_out = ((q31_t)(bias[i]) << bias_shift) + NN_ROUND(out_shift);
177
        for (j = 0; j < dim_vec; j++)
178
        {
179
            ip_out += pV[j] * pM[i * dim_vec + j];
180
        }
181
        pOut[i] = (q15_t) __SSAT((ip_out >> out_shift), 16);
182
    }
183
 
184
#endif                          /* ARM_MATH_DSP */
185
 
186
    /* Return to application */
187
    return (ARM_MATH_SUCCESS);
188
 
189
}
190
 
191
/**
 * @} end of FC group
 */