"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-neon.c" (23 Jul 2021, 22120 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "4x8c2-xzp-neon.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8gemm.h>

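/*
 * 4x8 XZP GEMM micro-kernel: computes a 4-row x 8-column tile of the
 * quantized (uint8) product C = A * B. "c2" refers to the packed layout of
 * w, which interleaves two k slices per column so that each 16-byte load
 * feeds one widening multiply-accumulate step. In the XZP scheme the
 * zero-point cross terms are not handled in the inner loop; a_sum carries
 * precomputed per-row correction terms instead (a reading of the
 * surrounding QNNPACK code, not spelled out in this file).
 */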
void pytorch_q8gemm_xzp_ukernel_4x8c2__neon(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* restrict a,
    size_t a_stride,
    const int32_t* restrict a_sum,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    const union pytorch_qnnp_q31_requantization_params
        requantization_params[restrict static 1]) {
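  /* The packed buffer w starts with 8 int32 per-column values (the bias,
   * per QNNPACK's weight-packing convention); all four row accumulators
   * are initialized from the same 8 values. */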
  int32x4_t vacc0x0123 = vld1q_s32(w);
  w = (const void*)((uintptr_t)w + 16);
  int32x4_t vacc0x4567 = vld1q_s32(w);
  w = (const void*)((uintptr_t)w + 16);
  int32x4_t vacc1x0123 = vacc0x0123;
  int32x4_t vacc1x4567 = vacc0x4567;
  int32x4_t vacc2x0123 = vacc0x0123;
  int32x4_t vacc2x4567 = vacc0x4567;
  int32x4_t vacc3x0123 = vacc0x0123;
  int32x4_t vacc3x4567 = vacc0x4567;

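  /* Row pointer setup: rows beyond mr alias the preceding row, so all
   * loads stay in bounds and the duplicated rows simply recompute the
   * same values. */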
  const uint8_t* a0 = a;
  const uint8_t* a1 = a0;
  const int32_t* a_sum0 = a_sum;
  const int32_t* a_sum1 = a_sum0;
  if (mr >= 2) {
    a1 += a_stride;
    a_sum1 += 1;
  }
  const uint8_t* a2 = a1;
  const int32_t* a_sum2 = a_sum1;
  if (mr > 2) {
    a2 += a_stride;
    a_sum2 += 1;
  }
  const uint8_t* a3 = a2;
  const int32_t* a_sum3 = a_sum2;
  if (mr == 4) {
    a3 += a_stride;
    a_sum3 += 1;
  }

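  /* Fold the precomputed per-row sums into every accumulator lane; in the
   * XZP scheme this is where the zero-point correction terms enter,
   * letting the inner loop run on raw uint8 multiplies. */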
  const int32x4_t va_sum0 = vld1q_dup_s32(a_sum0);
  const int32x4_t va_sum1 = vld1q_dup_s32(a_sum1);
  const int32x4_t va_sum2 = vld1q_dup_s32(a_sum2);
  const int32x4_t va_sum3 = vld1q_dup_s32(a_sum3);
  vacc0x0123 = vaddq_s32(vacc0x0123, va_sum0);
  vacc0x4567 = vaddq_s32(vacc0x4567, va_sum0);
  vacc1x0123 = vaddq_s32(vacc1x0123, va_sum1);
  vacc1x4567 = vaddq_s32(vacc1x4567, va_sum1);
  vacc2x0123 = vaddq_s32(vacc2x0123, va_sum2);
  vacc2x4567 = vaddq_s32(vacc2x4567, va_sum2);
  vacc3x0123 = vaddq_s32(vacc3x0123, va_sum3);
  vacc3x4567 = vaddq_s32(vacc3x4567, va_sum3);

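  /*
   * Main loop: each iteration consumes 8 values of k per row. vmull_u8
   * widens the uint8 products to uint16 and vpadalq_u16 pairwise-adds them
   * into the 32-bit accumulator lanes (reinterpreted as uint32; the bit
   * pattern of the signed result is unaffected), so each step folds one
   * (k, k+1) pair per column. The vext_u8 rotations walk the A bytes
   * through the staggered order in which B was packed.
   */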
  for (; k >= 8; k -= 8) {
    uint8x8_t va0x01234567 = vld1_u8(a0);
    a0 += 8;
    uint8x8_t va1x01234567 = vld1_u8(a1);
    a1 += 8;
    uint8x8_t va2x01234567 = vld1_u8(a2);
    a2 += 8;
    uint8x8_t va3x01234567 = vld1_u8(a3);
    a3 += 8;

    /* k = 0, 1 */
    const uint8x16_t vb01234567x01 = vld1q_u8(w);
    w += 16;

    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x01234567, vget_low_u8(vb01234567x01))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x01234567, vget_high_u8(vb01234567x01))));

    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x01234567, vget_low_u8(vb01234567x01))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x01234567, vget_high_u8(vb01234567x01))));

    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x01234567, vget_low_u8(vb01234567x01))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x01234567, vget_high_u8(vb01234567x01))));

    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x01234567, vget_low_u8(vb01234567x01))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x01234567, vget_high_u8(vb01234567x01))));

    /* k = 2, 3 */
    va0x01234567 = vext_u8(va0x01234567, va0x01234567, 2);
    va1x01234567 = vext_u8(va1x01234567, va1x01234567, 2);
    va2x01234567 = vext_u8(va2x01234567, va2x01234567, 2);
    va3x01234567 = vext_u8(va3x01234567, va3x01234567, 2);

    const uint8x16_t vb01234567x23 = vld1q_u8(w);
    w += 16;

    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x01234567, vget_low_u8(vb01234567x23))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x01234567, vget_high_u8(vb01234567x23))));

    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x01234567, vget_low_u8(vb01234567x23))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x01234567, vget_high_u8(vb01234567x23))));

    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x01234567, vget_low_u8(vb01234567x23))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x01234567, vget_high_u8(vb01234567x23))));

    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x01234567, vget_low_u8(vb01234567x23))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x01234567, vget_high_u8(vb01234567x23))));

    /* k = 4, 5 */
    va0x01234567 = vext_u8(va0x01234567, va0x01234567, 2);
    va1x01234567 = vext_u8(va1x01234567, va1x01234567, 2);
    va2x01234567 = vext_u8(va2x01234567, va2x01234567, 2);
    va3x01234567 = vext_u8(va3x01234567, va3x01234567, 2);

    const uint8x16_t vb01234567x45 = vld1q_u8(w);
    w += 16;

    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x01234567, vget_low_u8(vb01234567x45))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x01234567, vget_high_u8(vb01234567x45))));

    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x01234567, vget_low_u8(vb01234567x45))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x01234567, vget_high_u8(vb01234567x45))));

    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x01234567, vget_low_u8(vb01234567x45))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x01234567, vget_high_u8(vb01234567x45))));

    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x01234567, vget_low_u8(vb01234567x45))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x01234567, vget_high_u8(vb01234567x45))));

    /* k = 6, 7 */
    va0x01234567 = vext_u8(va0x01234567, va0x01234567, 2);
    va1x01234567 = vext_u8(va1x01234567, va1x01234567, 2);
    va2x01234567 = vext_u8(va2x01234567, va2x01234567, 2);
    va3x01234567 = vext_u8(va3x01234567, va3x01234567, 2);

    const uint8x16_t vb01234567x67 = vld1q_u8(w);
    w += 16;

    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x01234567, vget_low_u8(vb01234567x67))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x01234567, vget_high_u8(vb01234567x67))));

    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x01234567, vget_low_u8(vb01234567x67))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x01234567, vget_high_u8(vb01234567x67))));

    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x01234567, vget_low_u8(vb01234567x67))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x01234567, vget_high_u8(vb01234567x67))));

    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x01234567, vget_low_u8(vb01234567x67))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x01234567, vget_high_u8(vb01234567x67))));
  }

  /* for k < 8, reuse the packing scheme for the original xzp ukernel */
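  /* Here the A bytes are broadcast as (k, k+1) pairs via vld1_dup_u16,
   * i.e. the same pair multiplies all 8 columns, matching the simpler
   * non-staggered remainder packing. */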
  if (k & 4) {
    /* k = 0, 1 */
    const uint8x8_t va0x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1)));
    a0 += 2;
    const uint8x8_t va1x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1)));
    a1 += 2;
    const uint8x8_t va2x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1)));
    a2 += 2;
    const uint8x8_t va3x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1)));
    a3 += 2;
    const uint8x16_t vb01234567x01 = vld1q_u8(w);
    w += 16;
    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x01010101, vget_low_u8(vb01234567x01))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x01010101, vget_high_u8(vb01234567x01))));
    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x01010101, vget_low_u8(vb01234567x01))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x01010101, vget_high_u8(vb01234567x01))));
    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x01010101, vget_low_u8(vb01234567x01))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x01010101, vget_high_u8(vb01234567x01))));
    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x01010101, vget_low_u8(vb01234567x01))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x01010101, vget_high_u8(vb01234567x01))));

    /* k = 2, 3 */
    const uint8x8_t va0x23232323 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1)));
    a0 += 2;
    const uint8x8_t va1x23232323 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1)));
    a1 += 2;
    const uint8x8_t va2x23232323 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1)));
    a2 += 2;
    const uint8x8_t va3x23232323 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1)));
    a3 += 2;
    const uint8x16_t vb01234567x23 = vld1q_u8(w);
    w += 16;
    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x23232323, vget_low_u8(vb01234567x23))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x23232323, vget_high_u8(vb01234567x23))));
    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x23232323, vget_low_u8(vb01234567x23))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x23232323, vget_high_u8(vb01234567x23))));
    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x23232323, vget_low_u8(vb01234567x23))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x23232323, vget_high_u8(vb01234567x23))));
    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x23232323, vget_low_u8(vb01234567x23))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x23232323, vget_high_u8(vb01234567x23))));
  }
  if (k & 2) {
    /* k = 0, 1 */
    const uint8x8_t va0x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1)));
    a0 += 2;
    const uint8x8_t va1x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1)));
    a1 += 2;
    const uint8x8_t va2x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1)));
    a2 += 2;
    const uint8x8_t va3x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1)));
    a3 += 2;
    const uint8x16_t vb01234567x01 = vld1q_u8(w);
    w += 16;
    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x01010101, vget_low_u8(vb01234567x01))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x01010101, vget_high_u8(vb01234567x01))));
    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x01010101, vget_low_u8(vb01234567x01))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x01010101, vget_high_u8(vb01234567x01))));
    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x01010101, vget_low_u8(vb01234567x01))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x01010101, vget_high_u8(vb01234567x01))));
    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x01010101, vget_low_u8(vb01234567x01))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x01010101, vget_high_u8(vb01234567x01))));
  }
  if (k & 1) {
    const uint8x8_t va0x00000000 = vld1_dup_u8(a0);
    const uint8x8_t va1x00000000 = vld1_dup_u8(a1);
    const uint8x8_t va2x00000000 = vld1_dup_u8(a2);
    const uint8x8_t va3x00000000 = vld1_dup_u8(a3);
    const uint8x16_t vb01234567x0 = vld1q_u8(w);
    vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x0123),
        vmull_u8(va0x00000000, vget_low_u8(vb01234567x0))));
    vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc0x4567),
        vmull_u8(va0x00000000, vget_high_u8(vb01234567x0))));
    vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x0123),
        vmull_u8(va1x00000000, vget_low_u8(vb01234567x0))));
    vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc1x4567),
        vmull_u8(va1x00000000, vget_high_u8(vb01234567x0))));
    vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x0123),
        vmull_u8(va2x00000000, vget_low_u8(vb01234567x0))));
    vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc2x4567),
        vmull_u8(va2x00000000, vget_high_u8(vb01234567x0))));
    vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x0123),
        vmull_u8(va3x00000000, vget_low_u8(vb01234567x0))));
    vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16(
        vreinterpretq_u32_s32(vacc3x4567),
        vmull_u8(va3x00000000, vget_high_u8(vb01234567x0))));
  }

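  /* Requantization, Q31 scheme: vqrdmulhq_s32 multiplies by the
   * fixed-point multiplier, keeping the doubled, rounded high half. */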
  const int32x4_t vmultiplier =
      vld1q_dup_s32(&requantization_params->neon.multiplier);
  vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier);
  vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier);
  vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier);
  vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier);
  vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier);
  vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier);
  vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier);
  vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier);

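  /* Rounding right shift: the vsraq_n_s32/vbicq_s32 step subtracts 1 from
   * negative accumulators (suppressed via vzero_shift_mask when
   * right_shift is 0), so the subsequent vrshlq_s32 rounds to nearest with
   * ties away from zero. */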
  const int32x4_t vright_shift =
      vld1q_dup_s32(&requantization_params->neon.right_shift);
  const int32x4_t vzero_shift_mask =
      vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
  vacc0x0123 =
      vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31);
  vacc0x4567 =
      vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31);
  vacc1x0123 =
      vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31);
  vacc1x4567 =
      vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31);
  vacc2x0123 =
      vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31);
  vacc2x4567 =
      vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31);
  vacc3x0123 =
      vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31);
  vacc3x4567 =
      vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31);

  vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift);
  vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift);
  vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift);
  vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift);
  vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift);
  vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift);
  vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift);
  vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift);

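  /* Narrow 32 -> 16 -> 8 bits with saturation, adding the output zero
   * point in the 16-bit domain; the AArch64 path uses the *_high variants
   * to avoid the extra vcombine moves. */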
  const int16x8_t vzero_point =
      vld1q_dup_s16(&requantization_params->neon.zero_point);
#ifdef __aarch64__
  const int16x8_t vacc0x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), vzero_point);
  const int16x8_t vacc1x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), vzero_point);
  const int16x8_t vacc2x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), vzero_point);
  const int16x8_t vacc3x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), vzero_point);

  uint8x16_t vout0x01234567_1x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
  uint8x16_t vout2x01234567_3x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);
#else
  const int16x8_t vacc0x01234567 = vqaddq_s16(
      vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)),
      vzero_point);
  const int16x8_t vacc1x01234567 = vqaddq_s16(
      vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)),
      vzero_point);
  const int16x8_t vacc2x01234567 = vqaddq_s16(
      vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)),
      vzero_point);
  const int16x8_t vacc3x01234567 = vqaddq_s16(
      vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)),
      vzero_point);

  uint8x16_t vout0x01234567_1x01234567 =
      vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
  uint8x16_t vout2x01234567_3x01234567 =
      vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
#endif
  const uint8x16_t vmin = vld1q_dup_u8(&requantization_params->neon.min);
  const uint8x16_t vmax = vld1q_dup_u8(&requantization_params->neon.max);

  vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, vmin);
  vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, vmin);
  vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, vmax);
  vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, vmax);

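  /* Output pointers mirror the input rows: for mr < 4 the extra rows alias
   * the previous row and rewrite identical bytes (their inputs were
   * aliased the same way above). */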
  uint8_t* c0 = c;
  uint8_t* c1 = c0;
  if (mr >= 2) {
    c1 += c_stride;
  }
  uint8_t* c2 = c1;
  if (mr > 2) {
    c2 += c_stride;
  }
  uint8_t* c3 = c2;
  if (mr == 4) {
    c3 += c_stride;
  }
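  /* Store the tile. For nr < 8 the result is written in 4-, 2- and 1-byte
   * pieces, rotating the vectors with vextq_u8 so the next piece is always
   * at the front; rows 0/1 (and 2/3) share one q register, hence the lane
   * offsets 2, 4 and 8 for the odd rows. */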
  if (nr == 8) {
    vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567));
    vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567));
    vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567));
    vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567));
  } else {
    if (nr >= 4) {
      vst1q_lane_u32(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          0);
      c0 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          2);
      c1 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          0);
      c2 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          2);
      c3 += 4;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
      nr -= 4;
    }
    if (nr >= 2) {
      vst1q_lane_u16(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          0);
      c0 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          4);
      c1 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          0);
      c2 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          4);
      c3 += 2;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
      nr -= 2;
    }
    if (nr != 0) {
      vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
      vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
      vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
      vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
    }
  }
}