Updated and Validated
[betaflight.git] / lib / main / CMSIS / DSP / Source / MatrixFunctions / arm_mat_mult_fast_q31.c
blob78b33ef5962ec2d661008bb1a049829a79af544b
1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_fast_q31.c
4 * Description: Q31 matrix multiplication (fast variant)
6 * $Date: 27. January 2017
7 * $Revision: V.1.5.1
9 * Target Processor: Cortex-M cores
10 * -------------------------------------------------------------------- */
12 * Copyright (C) 2010-2017 ARM Limited or its affiliates. All rights reserved.
14 * SPDX-License-Identifier: Apache-2.0
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
20 * www.apache.org/licenses/LICENSE-2.0
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
29 #include "arm_math.h"
31 /**
32 * @ingroup groupMatrix
35 /**
36 * @addtogroup MatrixMult
37 * @{
40 /**
41 * @brief Q31 matrix multiplication (fast variant) for Cortex-M3 and Cortex-M4
42 * @param[in] *pSrcA points to the first input matrix structure
43 * @param[in] *pSrcB points to the second input matrix structure
44 * @param[out] *pDst points to output matrix structure
45 * @return The function returns either
46 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
48 * @details
49 * <b>Scaling and Overflow Behavior:</b>
51 * \par
52 * The difference between the function arm_mat_mult_q31() and this fast variant is that
53 * the fast variant use a 32-bit rather than a 64-bit accumulator.
54 * The result of each 1.31 x 1.31 multiplication is truncated to
55 * 2.30 format. These intermediate results are accumulated in a 32-bit register in 2.30
56 * format. Finally, the accumulator is saturated and converted to a 1.31 result.
58 * \par
59 * The fast version has the same overflow behavior as the standard version but provides
60 * less precision since it discards the low 32 bits of each multiplication result.
61 * In order to avoid overflows completely the input signals must be scaled down.
62 * Scale down one of the input matrices by log2(numColsA) bits to
63 * avoid overflows, as a total of numColsA additions are computed internally for each
64 * output element.
66 * \par
67 * See <code>arm_mat_mult_q31()</code> for a slower implementation of this function
68 * which uses 64-bit accumulation to provide higher precision.
71 arm_status arm_mat_mult_fast_q31(
72 const arm_matrix_instance_q31 * pSrcA,
73 const arm_matrix_instance_q31 * pSrcB,
74 arm_matrix_instance_q31 * pDst)
76 q31_t *pInA = pSrcA->pData; /* input data matrix pointer A */
77 q31_t *pInB = pSrcB->pData; /* input data matrix pointer B */
78 q31_t *px; /* Temporary output data matrix pointer */
79 q31_t sum; /* Accumulator */
80 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
81 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
82 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
83 uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
84 arm_status status; /* status of matrix multiplication */
85 q31_t inA1, inB1;
87 #if defined (ARM_MATH_DSP)
89 q31_t sum2, sum3, sum4;
90 q31_t inA2, inB2;
91 q31_t *pInA2;
92 q31_t *px2;
94 #endif
96 #ifdef ARM_MATH_MATRIX_CHECK
98 /* Check for matrix mismatch condition */
99 if ((pSrcA->numCols != pSrcB->numRows) ||
100 (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
102 /* Set status as ARM_MATH_SIZE_MISMATCH */
103 status = ARM_MATH_SIZE_MISMATCH;
105 else
106 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
110 px = pDst->pData;
112 #if defined (ARM_MATH_DSP)
113 row = row >> 1;
114 px2 = px + numColsB;
115 #endif
117 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
118 /* row loop */
119 while (row > 0U)
122 /* For every row wise process, the column loop counter is to be initiated */
123 col = numColsB;
125 /* For every row wise process, the pIn2 pointer is set
126 ** to the starting address of the pSrcB data */
127 pInB = pSrcB->pData;
129 j = 0U;
131 #if defined (ARM_MATH_DSP)
132 col = col >> 1;
133 #endif
135 /* column loop */
136 while (col > 0U)
138 /* Set the variable sum, that acts as accumulator, to zero */
139 sum = 0;
141 /* Initiate data pointers */
142 pInA = pSrcA->pData + i;
143 pInB = pSrcB->pData + j;
145 #if defined (ARM_MATH_DSP)
146 sum2 = 0;
147 sum3 = 0;
148 sum4 = 0;
149 pInA2 = pInA + numColsA;
150 colCnt = numColsA;
151 #else
152 colCnt = numColsA >> 2;
153 #endif
155 /* matrix multiplication */
156 while (colCnt > 0U)
159 #if defined (ARM_MATH_DSP)
160 inA1 = *pInA++;
161 inB1 = pInB[0];
162 inA2 = *pInA2++;
163 inB2 = pInB[1];
164 pInB += numColsB;
166 sum = __SMMLA(inA1, inB1, sum);
167 sum2 = __SMMLA(inA1, inB2, sum2);
168 sum3 = __SMMLA(inA2, inB1, sum3);
169 sum4 = __SMMLA(inA2, inB2, sum4);
170 #else
171 /* c(m,n) = a(1,1)*b(1,1) + a(1,2) * b(2,1) + .... + a(m,p)*b(p,n) */
172 /* Perform the multiply-accumulates */
173 inB1 = *pInB;
174 pInB += numColsB;
175 inA1 = pInA[0];
176 sum = __SMMLA(inA1, inB1, sum);
178 inB1 = *pInB;
179 pInB += numColsB;
180 inA1 = pInA[1];
181 sum = __SMMLA(inA1, inB1, sum);
183 inB1 = *pInB;
184 pInB += numColsB;
185 inA1 = pInA[2];
186 sum = __SMMLA(inA1, inB1, sum);
188 inB1 = *pInB;
189 pInB += numColsB;
190 inA1 = pInA[3];
191 sum = __SMMLA(inA1, inB1, sum);
193 pInA += 4U;
194 #endif
196 /* Decrement the loop counter */
197 colCnt--;
200 #ifdef ARM_MATH_CM0_FAMILY
201 /* If the columns of pSrcA is not a multiple of 4, compute any remaining output samples here. */
202 colCnt = numColsA % 0x4U;
203 while (colCnt > 0U)
205 sum = __SMMLA(*pInA++, *pInB, sum);
206 pInB += numColsB;
207 colCnt--;
209 j++;
210 #endif
212 /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
213 *px++ = sum << 1;
215 #if defined (ARM_MATH_DSP)
216 *px++ = sum2 << 1;
217 *px2++ = sum3 << 1;
218 *px2++ = sum4 << 1;
219 j += 2;
220 #endif
222 /* Decrement the column loop counter */
223 col--;
227 i = i + numColsA;
229 #if defined (ARM_MATH_DSP)
230 i = i + numColsA;
231 px = px2 + (numColsB & 1U);
232 px2 = px + numColsB;
233 #endif
235 /* Decrement the row loop counter */
236 row--;
240 /* Compute any remaining odd row/column below */
242 #if defined (ARM_MATH_DSP)
244 /* Compute remaining output column */
245 if (numColsB & 1U) {
247 /* Avoid redundant computation of last element */
248 row = numRowsA & (~0x1);
250 /* Point to remaining unfilled column in output matrix */
251 px = pDst->pData+numColsB-1;
252 pInA = pSrcA->pData;
254 /* row loop */
255 while (row > 0)
258 /* point to last column in matrix B */
259 pInB = pSrcB->pData + numColsB-1;
261 /* Set the variable sum, that acts as accumulator, to zero */
262 sum = 0;
264 /* Compute 4 columns at once */
265 colCnt = numColsA >> 2;
267 /* matrix multiplication */
268 while (colCnt > 0U)
270 inA1 = *pInA++;
271 inA2 = *pInA++;
272 inB1 = *pInB;
273 pInB += numColsB;
274 inB2 = *pInB;
275 pInB += numColsB;
276 sum = __SMMLA(inA1, inB1, sum);
277 sum = __SMMLA(inA2, inB2, sum);
279 inA1 = *pInA++;
280 inA2 = *pInA++;
281 inB1 = *pInB;
282 pInB += numColsB;
283 inB2 = *pInB;
284 pInB += numColsB;
285 sum = __SMMLA(inA1, inB1, sum);
286 sum = __SMMLA(inA2, inB2, sum);
288 /* Decrement the loop counter */
289 colCnt--;
292 colCnt = numColsA & 3U;
293 while (colCnt > 0U) {
294 sum = __SMMLA(*pInA++, *pInB, sum);
295 pInB += numColsB;
296 colCnt--;
299 /* Convert the result from 2.30 to 1.31 format and store in destination buffer */
300 *px = sum << 1;
301 px += numColsB;
303 /* Decrement the row loop counter */
304 row--;
308 /* Compute remaining output row */
309 if (numRowsA & 1U) {
311 /* point to last row in output matrix */
312 px = pDst->pData+(numColsB)*(numRowsA-1);
314 col = numColsB;
315 i = 0U;
317 /* col loop */
318 while (col > 0)
321 /* point to last row in matrix A */
322 pInA = pSrcA->pData + (numRowsA-1)*numColsA;
323 pInB = pSrcB->pData + i;
325 /* Set the variable sum, that acts as accumulator, to zero */
326 sum = 0;
328 /* Compute 4 columns at once */
329 colCnt = numColsA >> 2;
331 /* matrix multiplication */
332 while (colCnt > 0U)
334 inA1 = *pInA++;
335 inA2 = *pInA++;
336 inB1 = *pInB;
337 pInB += numColsB;
338 inB2 = *pInB;
339 pInB += numColsB;
340 sum = __SMMLA(inA1, inB1, sum);
341 sum = __SMMLA(inA2, inB2, sum);
343 inA1 = *pInA++;
344 inA2 = *pInA++;
345 inB1 = *pInB;
346 pInB += numColsB;
347 inB2 = *pInB;
348 pInB += numColsB;
349 sum = __SMMLA(inA1, inB1, sum);
350 sum = __SMMLA(inA2, inB2, sum);
352 /* Decrement the loop counter */
353 colCnt--;
356 colCnt = numColsA & 3U;
357 while (colCnt > 0U) {
358 sum = __SMMLA(*pInA++, *pInB, sum);
359 pInB += numColsB;
360 colCnt--;
363 /* Saturate and store the result in the destination buffer */
364 *px++ = sum << 1;
365 i++;
367 /* Decrement the col loop counter */
368 col--;
372 #endif /* #if defined (ARM_MATH_DSP) */
374 /* set status as ARM_MATH_SUCCESS */
375 status = ARM_MATH_SUCCESS;
378 /* Return to application */
379 return (status);
383 * @} end of MatrixMult group