pytorch  1.8.2
About: PyTorch provides Tensor computation (like NumPy) with strong GPU acceleration and Deep Neural Networks (in Python) built on a tape-based autograd system. LTS (Long Term Support) release.
  Fossies Dox: pytorch-1.8.2.tar.gz  ("unofficial" and yet experimental doxygen-generated source code documentation)  

4x8-neon.c
Go to the documentation of this file.
1/*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#include <arm_neon.h>
10
11#include <qnnpack/q8conv.h>
13
15 size_t mr,
16 size_t nr,
17 size_t kc,
18 size_t ks,
19 const uint8_t** restrict a,
20 const void* restrict w,
21 uint8_t* restrict c,
22 size_t c_stride,
23 size_t output_channel_index,
25 quantization_params[restrict static 1]) {
/*
 * Quantized uint8 convolution micro-kernel producing a tile of up to 4 output
 * rows x 8 output channels, with per-output-channel kernel zero points and
 * requantization scales.
 *
 * mr     - number of valid output rows, 1..4. Rows >= mr are stored over the
 *          last valid row (see the store section at the end).
 * nr     - number of valid output channels, 1..8.
 * kc     - number of uint8 elements reduced from each input-row pointer.
 * ks     - number of indirection steps; each step consumes 4 pointers from
 *          `a`. The do/while loop decrements ks before testing it, so ks must
 *          be >= 1 (ks == 0 would wrap around and loop far too long).
 * a      - indirection buffer of input-row pointers (4 * ks entries read).
 * w      - packed weights, read strictly sequentially: 8 int32 initial
 *          accumulator values (presumably the bias), then 8 uint8 weights
 *          (one per output channel) for each of the kc elements of each step.
 * c      - output row 0; row r (< mr) is written at c + r * c_stride bytes.
 * output_channel_index - first output channel of the tile; 8 consecutive
 *          entries of kernel_zero_points and requantization_scales are read
 *          starting at this index.
 */
26 const uint8x8_t va_zero_point =
27 vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
28 // Assumes that kernel_zero_points is an array padded with necessary elements
29 // in order to make it multiple of 8.
30 const uint8x8_t vb_zero_point =
31 vld1_u8((const uint8_t*)&quantization_params->neon.kernel_zero_points
32 [output_channel_index]);
33
// Initialize all four rows' accumulators from the 8 int32 values at the head
// of w; afterwards w points at the packed uint8 weights.
34 int32x4_t vacc0x0123 = vld1q_s32(w);
35 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
36 int32x4_t vacc0x4567 = vld1q_s32(w);
37 w = (void*)((uintptr_t)w + sizeof(int32x4_t));
38 int32x4_t vacc1x0123 = vacc0x0123;
39 int32x4_t vacc1x4567 = vacc0x4567;
40 int32x4_t vacc2x0123 = vacc0x0123;
41 int32x4_t vacc2x4567 = vacc0x4567;
42 int32x4_t vacc3x0123 = vacc0x0123;
43 int32x4_t vacc3x4567 = vacc0x4567;
44
// One iteration per indirection step: take the next four input-row pointers
// from `a` and reduce kc elements from each of them.
45 do {
46 const uint8_t* restrict a0 = *a++;
47 const uint8_t* restrict a1 = *a++;
48 const uint8_t* restrict a2 = *a++;
49 const uint8_t* restrict a3 = *a++;
50
51 size_t k = kc;
// Main loop: 8 reduction elements per iteration. va0..va3 hold 8 input bytes
// per row. sub_zero_point (runtime-neon.h) returns a uint16x8_t, presumably
// the input widened with the input zero point subtracted; it is then
// reinterpreted as int16.
52 for (; k >= 8; k -= 8) {
53 const uint8x8_t va0 = vld1_u8(a0);
54 a0 += 8;
55 const uint8x8_t va1 = vld1_u8(a1);
56 a1 += 8;
57 const uint8x8_t va2 = vld1_u8(a2);
58 a2 += 8;
59 const uint8x8_t va3 = vld1_u8(a3);
60 a3 += 8;
61 const int16x8_t vxa0 =
62 vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
63 const int16x8_t vxa1 =
64 vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
65 const int16x8_t vxa2 =
66 vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
67 const int16x8_t vxa3 =
68 vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));
69
// Each of the 8 sub-blocks below handles one reduction element j = 0..7.
// It loads the next 8 weights (one per output channel) and subtracts the
// kernel zero points with a widening vsubl_u8. The uint16 result
// reinterpreted as int16 is the exact signed difference, because it lies in
// [-255, 255]. The sub-block then multiply-accumulates that difference with
// element j of each row's input: lanes 0-3 of the low half for j = 0..3, and
// lanes 0-3 of the high half for j = 4..7.
70 {
71 const uint8x8_t vb01234567 = vld1_u8(w);
72 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
73 const int16x8_t vxb01234567 =
74 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
75
76 vacc0x0123 = vmlal_lane_s16(
77 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
78 vacc0x4567 = vmlal_lane_s16(
79 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
80 vacc1x0123 = vmlal_lane_s16(
81 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
82 vacc1x4567 = vmlal_lane_s16(
83 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
84 vacc2x0123 = vmlal_lane_s16(
85 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
86 vacc2x4567 = vmlal_lane_s16(
87 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
88 vacc3x0123 = vmlal_lane_s16(
89 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
90 vacc3x4567 = vmlal_lane_s16(
91 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
92 }
93
94 {
95 const uint8x8_t vb01234567 = vld1_u8(w);
96 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
97 const int16x8_t vxb01234567 =
98 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
99
100 vacc0x0123 = vmlal_lane_s16(
101 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
102 vacc0x4567 = vmlal_lane_s16(
103 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
104 vacc1x0123 = vmlal_lane_s16(
105 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
106 vacc1x4567 = vmlal_lane_s16(
107 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
108 vacc2x0123 = vmlal_lane_s16(
109 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
110 vacc2x4567 = vmlal_lane_s16(
111 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
112 vacc3x0123 = vmlal_lane_s16(
113 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
114 vacc3x4567 = vmlal_lane_s16(
115 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
116 }
117
118 {
119 const uint8x8_t vb01234567 = vld1_u8(w);
120 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
121 const int16x8_t vxb01234567 =
122 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
123
124 vacc0x0123 = vmlal_lane_s16(
125 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
126 vacc0x4567 = vmlal_lane_s16(
127 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
128 vacc1x0123 = vmlal_lane_s16(
129 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
130 vacc1x4567 = vmlal_lane_s16(
131 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
132 vacc2x0123 = vmlal_lane_s16(
133 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
134 vacc2x4567 = vmlal_lane_s16(
135 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
136 vacc3x0123 = vmlal_lane_s16(
137 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
138 vacc3x4567 = vmlal_lane_s16(
139 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
140 }
141
142 {
143 const uint8x8_t vb01234567 = vld1_u8(w);
144 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
145 const int16x8_t vxb01234567 =
146 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
147
148 vacc0x0123 = vmlal_lane_s16(
149 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
150 vacc0x4567 = vmlal_lane_s16(
151 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
152 vacc1x0123 = vmlal_lane_s16(
153 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
154 vacc1x4567 = vmlal_lane_s16(
155 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
156 vacc2x0123 = vmlal_lane_s16(
157 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
158 vacc2x4567 = vmlal_lane_s16(
159 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
160 vacc3x0123 = vmlal_lane_s16(
161 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
162 vacc3x4567 = vmlal_lane_s16(
163 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
164 }
165
166 {
167 const uint8x8_t vb01234567 = vld1_u8(w);
168 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
169 const int16x8_t vxb01234567 =
170 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
171
172 vacc0x0123 = vmlal_lane_s16(
173 vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
174 vacc0x4567 = vmlal_lane_s16(
175 vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
176 vacc1x0123 = vmlal_lane_s16(
177 vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
178 vacc1x4567 = vmlal_lane_s16(
179 vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
180 vacc2x0123 = vmlal_lane_s16(
181 vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
182 vacc2x4567 = vmlal_lane_s16(
183 vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
184 vacc3x0123 = vmlal_lane_s16(
185 vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
186 vacc3x4567 = vmlal_lane_s16(
187 vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
188 }
189
190 {
191 const uint8x8_t vb01234567 = vld1_u8(w);
192 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
193 const int16x8_t vxb01234567 =
194 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
195
196 vacc0x0123 = vmlal_lane_s16(
197 vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
198 vacc0x4567 = vmlal_lane_s16(
199 vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
200 vacc1x0123 = vmlal_lane_s16(
201 vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
202 vacc1x4567 = vmlal_lane_s16(
203 vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
204 vacc2x0123 = vmlal_lane_s16(
205 vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
206 vacc2x4567 = vmlal_lane_s16(
207 vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
208 vacc3x0123 = vmlal_lane_s16(
209 vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
210 vacc3x4567 = vmlal_lane_s16(
211 vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
212 }
213
214 {
215 const uint8x8_t vb01234567 = vld1_u8(w);
216 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
217 const int16x8_t vxb01234567 =
218 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
219
220 vacc0x0123 = vmlal_lane_s16(
221 vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
222 vacc0x4567 = vmlal_lane_s16(
223 vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
224 vacc1x0123 = vmlal_lane_s16(
225 vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
226 vacc1x4567 = vmlal_lane_s16(
227 vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
228 vacc2x0123 = vmlal_lane_s16(
229 vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
230 vacc2x4567 = vmlal_lane_s16(
231 vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
232 vacc3x0123 = vmlal_lane_s16(
233 vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
234 vacc3x4567 = vmlal_lane_s16(
235 vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
236 }
237
238 {
239 const uint8x8_t vb01234567 = vld1_u8(w);
240 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
241 const int16x8_t vxb01234567 =
242 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
243
244 vacc0x0123 = vmlal_lane_s16(
245 vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3);
246 vacc0x4567 = vmlal_lane_s16(
247 vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3);
248 vacc1x0123 = vmlal_lane_s16(
249 vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3);
250 vacc1x4567 = vmlal_lane_s16(
251 vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3);
252 vacc2x0123 = vmlal_lane_s16(
253 vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3);
254 vacc2x4567 = vmlal_lane_s16(
255 vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3);
256 vacc3x0123 = vmlal_lane_s16(
257 vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3);
258 vacc3x4567 = vmlal_lane_s16(
259 vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3);
260 }
261 }
// Remainder of 1..7 elements. Each row loads the 8 bytes that end at its last
// valid element, i.e. starting (8 - k) bytes BEFORE the current pointer. It
// then shifts the 64-bit value right by 8 * (8 - k) bits (vshl_u64 with a
// negative count), so the k valid bytes land in lanes 0..k-1. This assumes
// little-endian lane order.
// NOTE(review): this reads (8 - k) bytes before the current position; when
// kc < 8 those bytes lie before the row start. Presumably callers guarantee
// they are readable — confirm.
// The nested k tests below consume exactly k groups of 8 weights.
262 if (k != 0) {
263 const size_t a_predecrement = 8 - k;
// NOTE(review): -8 * a_predecrement is evaluated in size_t (unsigned)
// arithmetic and then converted to int64_t. That conversion is
// implementation-defined on 64-bit targets and yields a positive int64 on
// 32-bit targets. It still works because the NEON register shift only uses the
// signed low byte of the count. Computing it as int64_t would be clearer —
// confirm before changing.
264 const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement);
265 const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(
266 vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift));
267 const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(
268 vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift));
269 const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(
270 vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift));
271 const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64(
272 vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift));
273 const int16x8_t vxa0 =
274 vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
275 const int16x8_t vxa1 =
276 vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
277 const int16x8_t vxa2 =
278 vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
279 const int16x8_t vxa3 =
280 vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));
281
282 {
283 const uint8x8_t vb01234567 = vld1_u8(w);
284 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
285 const int16x8_t vxb01234567 =
286 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
287
288 vacc0x0123 = vmlal_lane_s16(
289 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
290 vacc0x4567 = vmlal_lane_s16(
291 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
292 vacc1x0123 = vmlal_lane_s16(
293 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
294 vacc1x4567 = vmlal_lane_s16(
295 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
296 vacc2x0123 = vmlal_lane_s16(
297 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
298 vacc2x4567 = vmlal_lane_s16(
299 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
300 vacc3x0123 = vmlal_lane_s16(
301 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
302 vacc3x4567 = vmlal_lane_s16(
303 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
304 }
305
306 if (k >= 2) {
307 const uint8x8_t vb01234567 = vld1_u8(w);
308 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
309 const int16x8_t vxb01234567 =
310 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
311
312 vacc0x0123 = vmlal_lane_s16(
313 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
314 vacc0x4567 = vmlal_lane_s16(
315 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
316 vacc1x0123 = vmlal_lane_s16(
317 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
318 vacc1x4567 = vmlal_lane_s16(
319 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
320 vacc2x0123 = vmlal_lane_s16(
321 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
322 vacc2x4567 = vmlal_lane_s16(
323 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
324 vacc3x0123 = vmlal_lane_s16(
325 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
326 vacc3x4567 = vmlal_lane_s16(
327 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
328
329 if (k > 2) {
330 const uint8x8_t vb01234567 = vld1_u8(w);
331 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
332 const int16x8_t vxb01234567 =
333 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
334
335 vacc0x0123 = vmlal_lane_s16(
336 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
337 vacc0x4567 = vmlal_lane_s16(
338 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
339 vacc1x0123 = vmlal_lane_s16(
340 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
341 vacc1x4567 = vmlal_lane_s16(
342 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
343 vacc2x0123 = vmlal_lane_s16(
344 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
345 vacc2x4567 = vmlal_lane_s16(
346 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
347 vacc3x0123 = vmlal_lane_s16(
348 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
349 vacc3x4567 = vmlal_lane_s16(
350 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
351
352 if (k >= 4) {
353 const uint8x8_t vb01234567 = vld1_u8(w);
354 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
355 const int16x8_t vxb01234567 =
356 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
357
358 vacc0x0123 = vmlal_lane_s16(
359 vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
360 vacc0x4567 = vmlal_lane_s16(
361 vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
362 vacc1x0123 = vmlal_lane_s16(
363 vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
364 vacc1x4567 = vmlal_lane_s16(
365 vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
366 vacc2x0123 = vmlal_lane_s16(
367 vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
368 vacc2x4567 = vmlal_lane_s16(
369 vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
370 vacc3x0123 = vmlal_lane_s16(
371 vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
372 vacc3x4567 = vmlal_lane_s16(
373 vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
374
375 if (k > 4) {
376 const uint8x8_t vb01234567 = vld1_u8(w);
377 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
378 const int16x8_t vxb01234567 =
379 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
380
381 vacc0x0123 = vmlal_lane_s16(
382 vacc0x0123,
383 vget_low_s16(vxb01234567),
384 vget_high_s16(vxa0),
385 0);
386 vacc0x4567 = vmlal_lane_s16(
387 vacc0x4567,
388 vget_high_s16(vxb01234567),
389 vget_high_s16(vxa0),
390 0);
391 vacc1x0123 = vmlal_lane_s16(
392 vacc1x0123,
393 vget_low_s16(vxb01234567),
394 vget_high_s16(vxa1),
395 0);
396 vacc1x4567 = vmlal_lane_s16(
397 vacc1x4567,
398 vget_high_s16(vxb01234567),
399 vget_high_s16(vxa1),
400 0);
401 vacc2x0123 = vmlal_lane_s16(
402 vacc2x0123,
403 vget_low_s16(vxb01234567),
404 vget_high_s16(vxa2),
405 0);
406 vacc2x4567 = vmlal_lane_s16(
407 vacc2x4567,
408 vget_high_s16(vxb01234567),
409 vget_high_s16(vxa2),
410 0);
411 vacc3x0123 = vmlal_lane_s16(
412 vacc3x0123,
413 vget_low_s16(vxb01234567),
414 vget_high_s16(vxa3),
415 0);
416 vacc3x4567 = vmlal_lane_s16(
417 vacc3x4567,
418 vget_high_s16(vxb01234567),
419 vget_high_s16(vxa3),
420 0);
421
422 if (k >= 6) {
423 const uint8x8_t vb01234567 = vld1_u8(w);
424 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
425 const int16x8_t vxb01234567 =
426 vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));
427
428 vacc0x0123 = vmlal_lane_s16(
429 vacc0x0123,
430 vget_low_s16(vxb01234567),
431 vget_high_s16(vxa0),
432 1);
433 vacc0x4567 = vmlal_lane_s16(
434 vacc0x4567,
435 vget_high_s16(vxb01234567),
436 vget_high_s16(vxa0),
437 1);
438 vacc1x0123 = vmlal_lane_s16(
439 vacc1x0123,
440 vget_low_s16(vxb01234567),
441 vget_high_s16(vxa1),
442 1);
443 vacc1x4567 = vmlal_lane_s16(
444 vacc1x4567,
445 vget_high_s16(vxb01234567),
446 vget_high_s16(vxa1),
447 1);
448 vacc2x0123 = vmlal_lane_s16(
449 vacc2x0123,
450 vget_low_s16(vxb01234567),
451 vget_high_s16(vxa2),
452 1);
453 vacc2x4567 = vmlal_lane_s16(
454 vacc2x4567,
455 vget_high_s16(vxb01234567),
456 vget_high_s16(vxa2),
457 1);
458 vacc3x0123 = vmlal_lane_s16(
459 vacc3x0123,
460 vget_low_s16(vxb01234567),
461 vget_high_s16(vxa3),
462 1);
463 vacc3x4567 = vmlal_lane_s16(
464 vacc3x4567,
465 vget_high_s16(vxb01234567),
466 vget_high_s16(vxa3),
467 1);
468
469 if (k > 6) {
470 const uint8x8_t vb01234567 = vld1_u8(w);
471 w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
472 const int16x8_t vxb01234567 = vreinterpretq_s16_u16(
473 vsubl_u8(vb01234567, vb_zero_point));
474
475 vacc0x0123 = vmlal_lane_s16(
476 vacc0x0123,
477 vget_low_s16(vxb01234567),
478 vget_high_s16(vxa0),
479 2);
480 vacc0x4567 = vmlal_lane_s16(
481 vacc0x4567,
482 vget_high_s16(vxb01234567),
483 vget_high_s16(vxa0),
484 2);
485 vacc1x0123 = vmlal_lane_s16(
486 vacc1x0123,
487 vget_low_s16(vxb01234567),
488 vget_high_s16(vxa1),
489 2);
490 vacc1x4567 = vmlal_lane_s16(
491 vacc1x4567,
492 vget_high_s16(vxb01234567),
493 vget_high_s16(vxa1),
494 2);
495 vacc2x0123 = vmlal_lane_s16(
496 vacc2x0123,
497 vget_low_s16(vxb01234567),
498 vget_high_s16(vxa2),
499 2);
500 vacc2x4567 = vmlal_lane_s16(
501 vacc2x4567,
502 vget_high_s16(vxb01234567),
503 vget_high_s16(vxa2),
504 2);
505 vacc3x0123 = vmlal_lane_s16(
506 vacc3x0123,
507 vget_low_s16(vxb01234567),
508 vget_high_s16(vxa3),
509 2);
510 vacc3x4567 = vmlal_lane_s16(
511 vacc3x4567,
512 vget_high_s16(vxb01234567),
513 vget_high_s16(vxa3),
514 2);
515 }
516 }
517 }
518 }
519 }
520 }
521 }
522 } while (--ks != 0);
523
524// Use 2 VLD1 instead of 1 VLD2: on Cortex-A75, VLD2 has the higher latency
525// (8 vs. 5 cycles), while VLD1 and VLD2 both have a throughput of
526// 2 per cycle, so two VLD1 loads are probably faster.
// Like kernel_zero_points, 8 scales are read starting at output_channel_index;
// presumably the scale array is padded to a multiple of 8 as well — confirm.
527 const float32x4_t requantization_scale_c0123 =
528 vld1q_f32(
529 &quantization_params->neon.requantization_scales[output_channel_index]
530 );
531 const float32x4_t requantization_scale_c4567 =
532 vld1q_f32(
533 &quantization_params->neon.requantization_scales[
534 output_channel_index + 4]);
535
536 /*
537 * Convert int32_t input to FP32 and multiply by FP32 scale.
538 * Both operations involve statistically unbiased roundings:
539 * - Large int32_t values can't be exactly represented as FP32. The
540 * conversion instruction in ARM NEON would round it to nearest FP32 value
541 * with ties to even.
542 * - Product of two FP32 values is generally not exactly representable as
543 * an FP32 value, and will be rounded to nearest FP32 value with ties to
544 * even.
545 */
546 const float32x4_t vacc0x0123_f =
547 vmulq_f32(vcvtq_f32_s32(vacc0x0123), requantization_scale_c0123);
548 const float32x4_t vacc1x0123_f =
549 vmulq_f32(vcvtq_f32_s32(vacc1x0123), requantization_scale_c0123);
550 const float32x4_t vacc2x0123_f =
551 vmulq_f32(vcvtq_f32_s32(vacc2x0123), requantization_scale_c0123);
552 const float32x4_t vacc3x0123_f =
553 vmulq_f32(vcvtq_f32_s32(vacc3x0123), requantization_scale_c0123);
554 const float32x4_t vacc0x4567_f =
555 vmulq_f32(vcvtq_f32_s32(vacc0x4567), requantization_scale_c4567);
556 const float32x4_t vacc1x4567_f =
557 vmulq_f32(vcvtq_f32_s32(vacc1x4567), requantization_scale_c4567);
558 const float32x4_t vacc2x4567_f =
559 vmulq_f32(vcvtq_f32_s32(vacc2x4567), requantization_scale_c4567);
560 const float32x4_t vacc3x4567_f =
561 vmulq_f32(vcvtq_f32_s32(vacc3x4567), requantization_scale_c4567);
562
563#ifdef __aarch64__
564 const int16x8_t voutput_zero_point =
565 vld1q_dup_s16(&quantization_params->neon.output_zero_point);
566 /*
567 * Leverage "Floating-point Convert to Signed integer, rounding to nearest
568 * with ties to even" instruction. This is an ARMv8 instruction (always
569 * available in AArch64), which saturates result on overflow. We don't need
570 * to specifically consider saturated results, they will be clamped at the
571 * last stage.
572 */
573 vacc0x0123 = vcvtnq_s32_f32(vacc0x0123_f);
574 vacc1x0123 = vcvtnq_s32_f32(vacc1x0123_f);
575 vacc2x0123 = vcvtnq_s32_f32(vacc2x0123_f);
576 vacc3x0123 = vcvtnq_s32_f32(vacc3x0123_f);
577 vacc0x4567 = vcvtnq_s32_f32(vacc0x4567_f);
578 vacc1x4567 = vcvtnq_s32_f32(vacc1x4567_f);
579 vacc2x4567 = vcvtnq_s32_f32(vacc2x4567_f);
580 vacc3x4567 = vcvtnq_s32_f32(vacc3x4567_f);
581
// Narrow each row to int16 with saturation and add the output zero point with
// saturation. Then pack two rows per uint8x16_t with unsigned saturation
// (rows 0/1 and rows 2/3) and clamp to [output_min, output_max].
582 const int16x8_t vacc0x01234567 = vqaddq_s16(
583 vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
584 const int16x8_t vacc1x01234567 = vqaddq_s16(
585 vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
586 const int16x8_t vacc2x01234567 = vqaddq_s16(
587 vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
588 const int16x8_t vacc3x01234567 = vqaddq_s16(
589 vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
590
591 uint8x16_t vout0x01234567_1x01234567 =
592 vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
593 uint8x16_t vout2x01234567_3x01234567 =
594 vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);
595
596 const uint8x16_t voutput_min =
597 vld1q_dup_u8(&quantization_params->neon.output_min);
598 const uint8x16_t voutput_max =
599 vld1q_dup_u8(&quantization_params->neon.output_max);
600
601 vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
602 vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
603 vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
604 vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
605#else
606 const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
607 const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
608 const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
609 const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
610 /*
611 * ARMv7 NEON offers only a floating-point to integer conversion instruction
612 * with rounding towards zero. In lieu of conversion instruction with
613 * rounding-to-nearest-even, we use a magic trick of adding a large number
614 * (1.5 * 2**23) to scaled value to cause rounding to integer, and then
615 * subtracting this magic number as integer. This trick works only in a
616 * limited range (absolute value of input must be less than 2**22), so
617 * generally we have to clamp input to this range before using the magic.
618 * However, clamping to any smaller range works just as well, and thus we
619 * clamp to [qmin - zero point, qmax - zero point] range so that after we
620 * add zero point to the result, it gets into target [qmin, qmax] range.
621 */
622 const float32x4_t vacc0x0123_f_clamped =
623 vminq_f32(vmaxq_f32(vacc0x0123_f, vfmin), vfmax);
624 const float32x4_t vacc1x0123_f_clamped =
625 vminq_f32(vmaxq_f32(vacc1x0123_f, vfmin), vfmax);
626 const float32x4_t vacc2x0123_f_clamped =
627 vminq_f32(vmaxq_f32(vacc2x0123_f, vfmin), vfmax);
628 const float32x4_t vacc3x0123_f_clamped =
629 vminq_f32(vmaxq_f32(vacc3x0123_f, vfmin), vfmax);
630 const float32x4_t vacc0x4567_f_clamped =
631 vminq_f32(vmaxq_f32(vacc0x4567_f, vfmin), vfmax);
632 const float32x4_t vacc1x4567_f_clamped =
633 vminq_f32(vmaxq_f32(vacc1x4567_f, vfmin), vfmax);
634 const float32x4_t vacc2x4567_f_clamped =
635 vminq_f32(vmaxq_f32(vacc2x4567_f, vfmin), vfmax);
636 const float32x4_t vacc3x4567_f_clamped =
637 vminq_f32(vmaxq_f32(vacc3x4567_f, vfmin), vfmax);
638
639 /*
640 * Conversion to integer using the "magic trick". Rounding is performed in
641 * the output of addition operation, and result is rounded to nearest
642 * integer with ties to even.
643 */
// NOTE(review): unlike the AArch64 path, this path has no visible
// output-zero-point addition. Presumably vimagic holds the magic constant's
// bit pattern minus the output zero point, so the subtraction below also
// applies it — confirm against the quantization-params initializer.
644 vacc0x0123 = vsubq_s32(
645 vreinterpretq_s32_f32(vaddq_f32(vacc0x0123_f_clamped, vfmagic)), vimagic);
646 vacc1x0123 = vsubq_s32(
647 vreinterpretq_s32_f32(vaddq_f32(vacc1x0123_f_clamped, vfmagic)), vimagic);
648 vacc2x0123 = vsubq_s32(
649 vreinterpretq_s32_f32(vaddq_f32(vacc2x0123_f_clamped, vfmagic)), vimagic);
650 vacc3x0123 = vsubq_s32(
651 vreinterpretq_s32_f32(vaddq_f32(vacc3x0123_f_clamped, vfmagic)), vimagic);
652 vacc0x4567 = vsubq_s32(
653 vreinterpretq_s32_f32(vaddq_f32(vacc0x4567_f_clamped, vfmagic)), vimagic);
654 vacc1x4567 = vsubq_s32(
655 vreinterpretq_s32_f32(vaddq_f32(vacc1x4567_f_clamped, vfmagic)), vimagic);
656 vacc2x4567 = vsubq_s32(
657 vreinterpretq_s32_f32(vaddq_f32(vacc2x4567_f_clamped, vfmagic)), vimagic);
658 vacc3x4567 = vsubq_s32(
659 vreinterpretq_s32_f32(vaddq_f32(vacc3x4567_f_clamped, vfmagic)), vimagic);
660
661 const int16x8_t vacc0x01234567 =
662 vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567));
663 const int16x8_t vacc1x01234567 =
664 vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567));
665 const int16x8_t vacc2x01234567 =
666 vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567));
667 const int16x8_t vacc3x01234567 =
668 vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567));
669
670 uint8x16_t vout0x01234567_1x01234567 =
671 vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
672 uint8x16_t vout2x01234567_3x01234567 =
673 vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
674#endif
675
// Output row pointers. Rows >= mr alias the last valid row, and because the
// stores run in c0, c1, c2, c3 order, their values overwrite that row.
// NOTE(review): this is only correct if the indirection buffer makes the
// surplus rows compute the same values as the last valid row — confirm in
// the caller.
676 uint8_t* c0 = c;
677 uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
678 if (mr < 2) {
679 c1 = c0;
680 }
681 uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
682 if (mr <= 2) {
683 c2 = c1;
684 }
685 uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
686 if (mr != 4) {
687 c3 = c2;
688 }
689 if (nr == 8) {
690 vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567));
691 vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567));
692 vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567));
693 vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567));
694 } else {
// nr < 8: store 4, then 2, then 1 column(s) per row. Rows 0/2 occupy bytes 0-7
// and rows 1/3 occupy bytes 8-15 of their vectors, which explains the lane
// indices (u32 lane 2, u16 lane 4, u8 lane 8). After each partial store,
// vextq_u8 rotates the whole vector left so the next unstored columns move
// to bytes 0 and 8. __builtin_assume_aligned(p, 1) promises only byte
// alignment.
695 if (nr >= 4) {
696 vst1q_lane_u32(
697 __builtin_assume_aligned(c0, 1),
698 vreinterpretq_u32_u8(vout0x01234567_1x01234567),
699 0);
700 c0 += 4;
701 vst1q_lane_u32(
702 __builtin_assume_aligned(c1, 1),
703 vreinterpretq_u32_u8(vout0x01234567_1x01234567),
704 2);
705 c1 += 4;
706 vst1q_lane_u32(
707 __builtin_assume_aligned(c2, 1),
708 vreinterpretq_u32_u8(vout2x01234567_3x01234567),
709 0);
710 c2 += 4;
711 vst1q_lane_u32(
712 __builtin_assume_aligned(c3, 1),
713 vreinterpretq_u32_u8(vout2x01234567_3x01234567),
714 2);
715 c3 += 4;
716 vout0x01234567_1x01234567 =
717 vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
718 vout2x01234567_3x01234567 =
719 vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
720 nr -= 4;
721 }
722 if (nr >= 2) {
723 vst1q_lane_u16(
724 __builtin_assume_aligned(c0, 1),
725 vreinterpretq_u16_u8(vout0x01234567_1x01234567),
726 0);
727 c0 += 2;
728 vst1q_lane_u16(
729 __builtin_assume_aligned(c1, 1),
730 vreinterpretq_u16_u8(vout0x01234567_1x01234567),
731 4);
732 c1 += 2;
733 vst1q_lane_u16(
734 __builtin_assume_aligned(c2, 1),
735 vreinterpretq_u16_u8(vout2x01234567_3x01234567),
736 0);
737 c2 += 2;
738 vst1q_lane_u16(
739 __builtin_assume_aligned(c3, 1),
740 vreinterpretq_u16_u8(vout2x01234567_3x01234567),
741 4);
742 c3 += 2;
743 vout0x01234567_1x01234567 =
744 vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
745 vout2x01234567_3x01234567 =
746 vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
747 nr -= 2;
748 }
749 if (nr != 0) {
750 vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
751 vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
752 vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
753 vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
754 }
755 }
756}
const int * c
PYTORCH_QNNP_INLINE uint16x8_t sub_zero_point(const uint8x8_t va, const uint8x8_t vzp)
Definition: runtime-neon.h:14
void pytorch_q8conv_ukernel_4x8__neon(size_t mr, size_t nr, size_t kc, size_t ks, const uint8_t **restrict a, const void *restrict w, uint8_t *restrict c, size_t c_stride, size_t output_channel_index, const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static 1])
Definition: 4x8-neon.c:14