"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-neon.c" (23 Jul 2021, 31475 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "4x8-neon.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8gemm.h>
#include <requantization/runtime-neon.h>

void pytorch_q8gemm_ukernel_4x8__neon(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[restrict static 1]) {
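  // The packed buffer w begins with 8 per-channel int32 bias values; they
  // seed the accumulators for all four rows of the 4x8 output tile.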
  int32x4_t vacc0x0123 = vld1q_s32(w);
  w = (const void*)((uintptr_t)w + 16);
  int32x4_t vacc0x4567 = vld1q_s32(w);
  w = (const void*)((uintptr_t)w + 16);
  int32x4_t vacc1x0123 = vacc0x0123;
  int32x4_t vacc1x4567 = vacc0x4567;
  int32x4_t vacc2x0123 = vacc0x0123;
  int32x4_t vacc2x4567 = vacc0x4567;
  int32x4_t vacc3x0123 = vacc0x0123;
  int32x4_t vacc3x4567 = vacc0x4567;

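  // Set up one input row pointer per output row. When mr < 4, the pointers
  // for the missing rows alias the previous row, so all loads stay in
  // bounds; the aliased rows compute identical results, which are aliased
  // again at store time.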
  const uint8_t* a0 = a;
  const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride);
  if (mr < 2) {
    a1 = a0;
  }
  const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride);
  if (mr <= 2) {
    a2 = a1;
  }
  const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride);
  if (mr != 4) {
    a3 = a2;
  }

  const uint8x8_t va_zero_point =
      vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
  // Assumes that kernel_zero_points is padded with enough trailing elements
  // to make its length a multiple of 8.
  const uint8x8_t vb_zero_point =
      vld1_u8((const uint8_t*)&quantization_params->neon.kernel_zero_points
          [output_channel_index]);
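  // Main loop: each iteration consumes 8 columns of every A row and 8 rows
  // of the packed B panel. Inputs are widened to int16 after zero-point
  // subtraction, and products accumulate into int32 via vmlal_lane_s16,
  // one B row per A lane.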
  for (; k >= 8; k -= 8) {
    const uint8x8_t va0 = vld1_u8(a0);
    a0 += 8;
    const int16x8_t vxa0 =
        vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
    const uint8x8_t va1 = vld1_u8(a1);
    a1 += 8;
    const int16x8_t vxa1 =
        vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
    const uint8x8_t va2 = vld1_u8(a2);
    a2 += 8;
    const int16x8_t vxa2 =
        vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
    const uint8x8_t va3 = vld1_u8(a3);
    a3 += 8;
    const int16x8_t vxa3 =
        vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));

    const uint8x8_t vb01234567c0 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c0 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);

    const uint8x8_t vb01234567c1 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c1 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);

    const uint8x8_t vb01234567c2 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c2 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);

    const uint8x8_t vb01234567c3 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c3 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);

    const uint8x8_t vb01234567c4 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c4 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0);

    const uint8x8_t vb01234567c5 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c5 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1);

    const uint8x8_t vb01234567c6 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c6 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2);

    const uint8x8_t vb01234567c7 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c7 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
  }
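  // Remainder: 1-7 leftover columns. Each A row is read through an 8-byte
  // window ending at its last valid byte, then shifted right so the valid
  // bytes land in the low lanes; the nested k >= n guards below consume
  // exactly k more rows of the packed B panel.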
  if (k != 0) {
    const size_t a_predecrement = 8 - k;
    const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement);
    const uint8x8_t va0 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift));
    const int16x8_t vxa0 =
        vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
    const uint8x8_t va1 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift));
    const int16x8_t vxa1 =
        vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
    const uint8x8_t va2 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift));
    const int16x8_t vxa2 =
        vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
    const uint8x8_t va3 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift));
    const int16x8_t vxa3 =
        vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));

    const uint8x8_t vb01234567c0 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb01234567c0 =
        vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
    vacc0x4567 = vmlal_lane_s16(
        vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
    vacc1x4567 = vmlal_lane_s16(
        vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
    vacc2x4567 = vmlal_lane_s16(
        vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
    vacc3x4567 = vmlal_lane_s16(
        vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);

    if (k >= 2) {
      const uint8x8_t vb01234567c1 = vld1_u8(w);
      w = (const void*)((uintptr_t)w + 8);
      const int16x8_t vxb01234567c1 =
          vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));

      vacc0x0123 = vmlal_lane_s16(
          vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
      vacc0x4567 = vmlal_lane_s16(
          vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
      vacc1x0123 = vmlal_lane_s16(
          vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
      vacc1x4567 = vmlal_lane_s16(
          vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
      vacc2x0123 = vmlal_lane_s16(
          vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
      vacc2x4567 = vmlal_lane_s16(
          vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
      vacc3x0123 = vmlal_lane_s16(
          vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
      vacc3x4567 = vmlal_lane_s16(
          vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);

      if (k >= 3) {
        const uint8x8_t vb01234567c2 = vld1_u8(w);
        w = (const void*)((uintptr_t)w + 8);
        const int16x8_t vxb01234567c2 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);

        if (k >= 4) {
          const uint8x8_t vb01234567c3 = vld1_u8(w);
          w = (const void*)((uintptr_t)w + 8);
          const int16x8_t vxb01234567c3 =
              vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));

          vacc0x0123 = vmlal_lane_s16(
              vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
          vacc0x4567 = vmlal_lane_s16(
              vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
          vacc1x0123 = vmlal_lane_s16(
              vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
          vacc1x4567 = vmlal_lane_s16(
              vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
          vacc2x0123 = vmlal_lane_s16(
              vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
          vacc2x4567 = vmlal_lane_s16(
              vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
          vacc3x0123 = vmlal_lane_s16(
              vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
          vacc3x4567 = vmlal_lane_s16(
              vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);

          if (k >= 5) {
            const uint8x8_t vb01234567c4 = vld1_u8(w);
            w = (const void*)((uintptr_t)w + 8);
            const int16x8_t vxb01234567c4 =
                vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));

            vacc0x0123 = vmlal_lane_s16(
                vacc0x0123,
                vget_low_s16(vxb01234567c4),
                vget_high_s16(vxa0),
                0);
            vacc0x4567 = vmlal_lane_s16(
                vacc0x4567,
                vget_high_s16(vxb01234567c4),
                vget_high_s16(vxa0),
                0);
            vacc1x0123 = vmlal_lane_s16(
                vacc1x0123,
                vget_low_s16(vxb01234567c4),
                vget_high_s16(vxa1),
                0);
            vacc1x4567 = vmlal_lane_s16(
                vacc1x4567,
                vget_high_s16(vxb01234567c4),
                vget_high_s16(vxa1),
                0);
            vacc2x0123 = vmlal_lane_s16(
                vacc2x0123,
                vget_low_s16(vxb01234567c4),
                vget_high_s16(vxa2),
                0);
            vacc2x4567 = vmlal_lane_s16(
                vacc2x4567,
                vget_high_s16(vxb01234567c4),
                vget_high_s16(vxa2),
                0);
            vacc3x0123 = vmlal_lane_s16(
                vacc3x0123,
                vget_low_s16(vxb01234567c4),
                vget_high_s16(vxa3),
                0);
            vacc3x4567 = vmlal_lane_s16(
                vacc3x4567,
                vget_high_s16(vxb01234567c4),
                vget_high_s16(vxa3),
                0);

            if (k >= 6) {
              const uint8x8_t vb01234567c5 = vld1_u8(w);
              w = (const void*)((uintptr_t)w + 8);
              const int16x8_t vxb01234567c5 =
                  vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));

              vacc0x0123 = vmlal_lane_s16(
                  vacc0x0123,
                  vget_low_s16(vxb01234567c5),
                  vget_high_s16(vxa0),
                  1);
              vacc0x4567 = vmlal_lane_s16(
                  vacc0x4567,
                  vget_high_s16(vxb01234567c5),
                  vget_high_s16(vxa0),
                  1);
              vacc1x0123 = vmlal_lane_s16(
                  vacc1x0123,
                  vget_low_s16(vxb01234567c5),
                  vget_high_s16(vxa1),
                  1);
              vacc1x4567 = vmlal_lane_s16(
                  vacc1x4567,
                  vget_high_s16(vxb01234567c5),
                  vget_high_s16(vxa1),
                  1);
              vacc2x0123 = vmlal_lane_s16(
                  vacc2x0123,
                  vget_low_s16(vxb01234567c5),
                  vget_high_s16(vxa2),
                  1);
              vacc2x4567 = vmlal_lane_s16(
                  vacc2x4567,
                  vget_high_s16(vxb01234567c5),
                  vget_high_s16(vxa2),
                  1);
              vacc3x0123 = vmlal_lane_s16(
                  vacc3x0123,
                  vget_low_s16(vxb01234567c5),
                  vget_high_s16(vxa3),
                  1);
              vacc3x4567 = vmlal_lane_s16(
                  vacc3x4567,
                  vget_high_s16(vxb01234567c5),
                  vget_high_s16(vxa3),
                  1);

              if (k >= 7) {
                const uint8x8_t vb01234567c6 = vld1_u8(w);
                w = (const void*)((uintptr_t)w + 8);
                const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(
                    vsubl_u8(vb01234567c6, vb_zero_point));

                vacc0x0123 = vmlal_lane_s16(
                    vacc0x0123,
                    vget_low_s16(vxb01234567c6),
                    vget_high_s16(vxa0),
                    2);
                vacc0x4567 = vmlal_lane_s16(
                    vacc0x4567,
                    vget_high_s16(vxb01234567c6),
                    vget_high_s16(vxa0),
                    2);
                vacc1x0123 = vmlal_lane_s16(
                    vacc1x0123,
                    vget_low_s16(vxb01234567c6),
                    vget_high_s16(vxa1),
                    2);
                vacc1x4567 = vmlal_lane_s16(
                    vacc1x4567,
                    vget_high_s16(vxb01234567c6),
                    vget_high_s16(vxa1),
                    2);
                vacc2x0123 = vmlal_lane_s16(
                    vacc2x0123,
                    vget_low_s16(vxb01234567c6),
                    vget_high_s16(vxa2),
                    2);
                vacc2x4567 = vmlal_lane_s16(
                    vacc2x4567,
                    vget_high_s16(vxb01234567c6),
                    vget_high_s16(vxa2),
                    2);
                vacc3x0123 = vmlal_lane_s16(
                    vacc3x0123,
                    vget_low_s16(vxb01234567c6),
                    vget_high_s16(vxa3),
                    2);
                vacc3x4567 = vmlal_lane_s16(
                    vacc3x4567,
                    vget_high_s16(vxb01234567c6),
                    vget_high_s16(vxa3),
                    2);
              }
            }
          }
        }
      }
    }
  }

  // Use two VLD1 instructions instead of one VLD2: on the Cortex-A75, VLD2
  // has higher latency (8 cycles vs. 5), while both have a throughput of
  // 2 per cycle, so the two VLD1s are likely the better choice.
  const float32x4_t requantization_scale_c0123 =
      vld1q_f32(
          &quantization_params->neon.requantization_scales[output_channel_index]
          );
  const float32x4_t requantization_scale_c4567 =
      vld1q_f32(
          &quantization_params->neon.requantization_scales[
              output_channel_index + 4]);

  /*
   * Convert the int32_t input to FP32 and multiply by the FP32 scale.
   * Both operations involve statistically unbiased roundings:
   * - Large int32_t values can't be exactly represented as FP32. The
   * conversion instruction in ARM NEON rounds them to the nearest FP32
   * value, with ties to even.
   * - The product of two FP32 values is generally not exactly representable
   * as an FP32 value, and will be rounded to the nearest FP32 value, with
   * ties to even.
   */
  const float32x4_t vacc0x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x0123), requantization_scale_c0123);
  const float32x4_t vacc1x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x0123), requantization_scale_c0123);
  const float32x4_t vacc2x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x0123), requantization_scale_c0123);
  const float32x4_t vacc3x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x0123), requantization_scale_c0123);
  const float32x4_t vacc0x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x4567), requantization_scale_c4567);
  const float32x4_t vacc1x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x4567), requantization_scale_c4567);
  const float32x4_t vacc2x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x4567), requantization_scale_c4567);
  const float32x4_t vacc3x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x4567), requantization_scale_c4567);

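  // The two architecture paths below differ only in how the scaled FP32
  // values are rounded back to int32 and where the [qmin, qmax] clamp is
  // applied; the packing to uint8 is equivalent.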
#ifdef __aarch64__
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  /*
   * Leverage the "Floating-point Convert to Signed integer, rounding to
   * nearest with ties to even" instruction. This is an ARMv8 instruction
   * (always available in AArch64), which saturates the result on overflow.
   * We don't need to handle saturated results specially; they will be
   * clamped at the last stage.
   */
  vacc0x0123 = vcvtnq_s32_f32(vacc0x0123_f);
  vacc1x0123 = vcvtnq_s32_f32(vacc1x0123_f);
  vacc2x0123 = vcvtnq_s32_f32(vacc2x0123_f);
  vacc3x0123 = vcvtnq_s32_f32(vacc3x0123_f);
  vacc0x4567 = vcvtnq_s32_f32(vacc0x4567_f);
  vacc1x4567 = vcvtnq_s32_f32(vacc1x4567_f);
  vacc2x4567 = vcvtnq_s32_f32(vacc2x4567_f);
  vacc3x4567 = vcvtnq_s32_f32(vacc3x4567_f);

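  // Narrow each row's two int32 vectors into one int16x8 with saturation,
  // then add the output zero point in int16 (vqaddq_s16 saturates as well).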
  const int16x8_t vacc0x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
  const int16x8_t vacc1x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
  const int16x8_t vacc2x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
  const int16x8_t vacc3x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);

  uint8x16_t vout0x01234567_1x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
  uint8x16_t vout2x01234567_3x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);

  const uint8x16_t voutput_min =
      vld1q_dup_u8(&quantization_params->neon.output_min);
  const uint8x16_t voutput_max =
      vld1q_dup_u8(&quantization_params->neon.output_max);

  vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
  vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
  vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
  vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
  /*
   * ARMv7 NEON offers only a floating-point to integer conversion
   * instruction with rounding towards zero. In lieu of a conversion
   * instruction with rounding to nearest even, we use the "magic number"
   * trick: add a large constant (1.5 * 2**23) to the scaled value to force
   * rounding to an integer, then subtract this magic number, reinterpreted
   * as an integer. This trick works only in a limited range (the absolute
   * value of the input must be less than 2**22), so in general we would
   * have to clamp the input to this range before using the magic. However,
   * clamping to any smaller range works just as well, and thus we clamp to
   * the [qmin - zero point, qmax - zero point] range so that after we add
   * the zero point to the result, it lands in the target [qmin, qmax]
   * range.
   */
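  /*
   * A scalar sketch of the trick (the constants here are illustrative; in
   * this kernel vfmagic/vimagic come precomputed in quantization_params,
   * with the output zero point typically folded into vimagic):
   *
   *   float magic = 0x1.8p23f;        // 12582912.0f, bits 0x4B400000
   *   float x = 100.0f;               // |x| < 2**22
   *   float y = x + magic;            // 12583012.0f, bits 0x4B400064
   *   int32_t bits;
   *   memcpy(&bits, &y, sizeof(bits));
   *   int32_t r = bits - 0x4B400000;  // r == 100
   *
   * Adding the magic number pins the float's exponent, so the rounded
   * integer appears verbatim in the low mantissa bits.
   */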
  const float32x4_t vacc0x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x0123_f, vfmin), vfmax);
  const float32x4_t vacc1x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x0123_f, vfmin), vfmax);
  const float32x4_t vacc2x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x0123_f, vfmin), vfmax);
  const float32x4_t vacc3x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x0123_f, vfmin), vfmax);
  const float32x4_t vacc0x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x4567_f, vfmin), vfmax);
  const float32x4_t vacc1x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x4567_f, vfmin), vfmax);
  const float32x4_t vacc2x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x4567_f, vfmin), vfmax);
  const float32x4_t vacc3x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x4567_f, vfmin), vfmax);

  /*
   * Conversion to integer using the "magic trick". The rounding happens in
   * the floating-point addition, which rounds the result to the nearest
   * integer, with ties to even.
   */
  vacc0x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x0123_f_clamped, vfmagic)), vimagic);
  vacc1x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x0123_f_clamped, vfmagic)), vimagic);
  vacc2x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x0123_f_clamped, vfmagic)), vimagic);
  vacc3x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x0123_f_clamped, vfmagic)), vimagic);
  vacc0x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x4567_f_clamped, vfmagic)), vimagic);
  vacc1x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x4567_f_clamped, vfmagic)), vimagic);
  vacc2x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x4567_f_clamped, vfmagic)), vimagic);
  vacc3x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x4567_f_clamped, vfmagic)), vimagic);

  const int16x8_t vacc0x01234567 =
      vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567));
  const int16x8_t vacc1x01234567 =
      vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567));
  const int16x8_t vacc2x01234567 =
      vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567));
  const int16x8_t vacc3x01234567 =
      vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567));

  uint8x16_t vout0x01234567_1x01234567 =
      vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
  uint8x16_t vout2x01234567_3x01234567 =
      vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
#endif

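  // Output row pointers use the same aliasing scheme as the input pointers:
  // rows beyond mr point at the previous row and simply rewrite it with the
  // identical values computed from the aliased inputs.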
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr != 4) {
    c3 = c2;
  }
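  // Store the 4x8 tile. The full-width case writes each row with a single
  // 8-byte store; for nr < 8, columns go out in 4-, 2-, and 1-byte chunks,
  // rotating the vectors with vextq_u8 after each chunk so the next columns
  // move into the stored lanes.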
  if (nr == 8) {
    vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567));
    vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567));
    vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567));
    vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567));
  } else {
    if (nr >= 4) {
      vst1q_lane_u32(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          0);
      c0 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          2);
      c1 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          0);
      c2 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          2);
      c3 += 4;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
      nr -= 4;
    }
    if (nr >= 2) {
      vst1q_lane_u16(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          0);
      c0 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          4);
      c1 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          0);
      c2 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          4);
      c3 += 2;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
      nr -= 2;
    }
    if (nr != 0) {
      vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
      vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
      vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
      vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
    }
  }
}
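
/*
 * For reference, a minimal scalar model of the arithmetic this kernel
 * performs (an illustrative sketch, not part of the library; the names
 * bias, b, input_zero_point, kernel_zero_point, requantization_scale,
 * output_zero_point, output_min, and output_max are hypothetical, and the
 * packed layout of w is simplified to 8 bias values followed by row-major
 * k x 8 weights):
 *
 *   for (size_t m = 0; m < mr; m++) {
 *     for (size_t n = 0; n < nr; n++) {
 *       int32_t acc = bias[n];
 *       for (size_t i = 0; i < k; i++) {
 *         acc += ((int32_t)a[m * a_stride + i] - input_zero_point) *
 *                ((int32_t)b[i * 8 + n] - kernel_zero_point[n]);
 *       }
 *       float scaled = (float)acc * requantization_scale[n];
 *       long rounded = lrintf(scaled) + output_zero_point;
 *       if (rounded < output_min) rounded = output_min;
 *       if (rounded > output_max) rounded = output_max;
 *       c[m * c_stride + n] = (uint8_t)rounded;
 *     }
 *   }
 */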