"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-neon.c" (23 Jul 2021, 32536 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8conv.h>
#include <requantization/runtime-neon.h>

void pytorch_q8conv_ukernel_4x8__neon(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const uint8_t** restrict a,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[restrict static 1]) {
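  /*
   * 4x8 quantized convolution micro-kernel: computes an mr-by-nr (up to 4x8)
   * output tile over kc input channels and ks kernel taps. `a` is an
   * indirection buffer of input row pointers, consumed four at a time per
   * kernel tap; `w` points to the packed per-channel bias followed by the
   * packed weights.
   */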
  const uint8x8_t va_zero_point =
      vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
  // Assumes that kernel_zero_points is an array padded with enough elements
  // to make its length a multiple of 8.
  const uint8x8_t vb_zero_point =
      vld1_u8((const uint8_t*)&quantization_params->neon.kernel_zero_points
          [output_channel_index]);

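  /* The packed weight buffer starts with the per-channel bias: two int32x4
     loads cover the 8 output channels, and every row's accumulators start
     from the same bias values. */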
  int32x4_t vacc0x0123 = vld1q_s32(w);
  w = (void*)((uintptr_t)w + sizeof(int32x4_t));
  int32x4_t vacc0x4567 = vld1q_s32(w);
  w = (void*)((uintptr_t)w + sizeof(int32x4_t));
  int32x4_t vacc1x0123 = vacc0x0123;
  int32x4_t vacc1x4567 = vacc0x4567;
  int32x4_t vacc2x0123 = vacc0x0123;
  int32x4_t vacc2x4567 = vacc0x4567;
  int32x4_t vacc3x0123 = vacc0x0123;
  int32x4_t vacc3x4567 = vacc0x4567;

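  /* Outer loop over the ks kernel taps: each iteration consumes the next
     group of four input row pointers and accumulates kc products per row. */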
  do {
    const uint8_t* restrict a0 = *a++;
    const uint8_t* restrict a1 = *a++;
    const uint8_t* restrict a2 = *a++;
    const uint8_t* restrict a3 = *a++;

    size_t k = kc;
    for (; k >= 8; k -= 8) {
      const uint8x8_t va0 = vld1_u8(a0);
      a0 += 8;
      const uint8x8_t va1 = vld1_u8(a1);
      a1 += 8;
      const uint8x8_t va2 = vld1_u8(a2);
      a2 += 8;
      const uint8x8_t va3 = vld1_u8(a3);
      a3 += 8;
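      /* sub_zero_point widens each uint8 lane to uint16 while subtracting the
         input zero point; reinterpreting as int16 is safe since the result
         fits in [-255, 255]. */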
      const int16x8_t vxa0 =
          vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
      const int16x8_t vxa1 =
          vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
      const int16x8_t vxa2 =
          vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
      const int16x8_t vxa3 =
          vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
      }

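      /* The seven blocks below repeat the same multiply-accumulate for input
         channel positions 1..7, selecting the matching lane of vxa0..vxa3. */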
      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2);
      }

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3);
      }
    }
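    /* Remainder of 1-7 input channels: load the 8 bytes ending at the current
       read position, then shift right by 8 * (8 - k) bits so the k valid
       bytes land in the low lanes. The junk left in lanes k..7 is never
       consumed, because the nested blocks below only use lanes 0..k-1. */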
    if (k != 0) {
      const size_t a_predecrement = 8 - k;
      const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement);
      const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift));
      const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift));
      const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift));
      const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64(
          vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift));
      const int16x8_t vxa0 =
          vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
      const int16x8_t vxa1 =
          vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
      const int16x8_t vxa2 =
          vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
      const int16x8_t vxa3 =
          vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));

      {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0);
      }

      if (k >= 2) {
        const uint8x8_t vb01234567 = vld1_u8(w);
        w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
        const int16x8_t vxb01234567 =
            vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc0x4567 = vmlal_lane_s16(
            vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc1x4567 = vmlal_lane_s16(
            vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc2x4567 = vmlal_lane_s16(
            vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1);
        vacc3x4567 = vmlal_lane_s16(
            vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1);

        if (k > 2) {
          const uint8x8_t vb01234567 = vld1_u8(w);
          w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
          const int16x8_t vxb01234567 =
              vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

          vacc0x0123 = vmlal_lane_s16(
              vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2);
          vacc0x4567 = vmlal_lane_s16(
              vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2);
          vacc1x0123 = vmlal_lane_s16(
              vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2);
          vacc1x4567 = vmlal_lane_s16(
              vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2);
          vacc2x0123 = vmlal_lane_s16(
              vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2);
          vacc2x4567 = vmlal_lane_s16(
              vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2);
          vacc3x0123 = vmlal_lane_s16(
              vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2);
          vacc3x4567 = vmlal_lane_s16(
              vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2);

          if (k >= 4) {
            const uint8x8_t vb01234567 = vld1_u8(w);
            w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
            const int16x8_t vxb01234567 =
                vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

            vacc0x0123 = vmlal_lane_s16(
                vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3);
            vacc0x4567 = vmlal_lane_s16(
                vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3);
            vacc1x0123 = vmlal_lane_s16(
                vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3);
            vacc1x4567 = vmlal_lane_s16(
                vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3);
            vacc2x0123 = vmlal_lane_s16(
                vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3);
            vacc2x4567 = vmlal_lane_s16(
                vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3);
            vacc3x0123 = vmlal_lane_s16(
                vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3);
            vacc3x4567 = vmlal_lane_s16(
                vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3);

            if (k > 4) {
              const uint8x8_t vb01234567 = vld1_u8(w);
              w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
              const int16x8_t vxb01234567 =
                  vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

              vacc0x0123 = vmlal_lane_s16(
                  vacc0x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa0),
                  0);
              vacc0x4567 = vmlal_lane_s16(
                  vacc0x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa0),
                  0);
              vacc1x0123 = vmlal_lane_s16(
                  vacc1x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa1),
                  0);
              vacc1x4567 = vmlal_lane_s16(
                  vacc1x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa1),
                  0);
              vacc2x0123 = vmlal_lane_s16(
                  vacc2x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa2),
                  0);
              vacc2x4567 = vmlal_lane_s16(
                  vacc2x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa2),
                  0);
              vacc3x0123 = vmlal_lane_s16(
                  vacc3x0123,
                  vget_low_s16(vxb01234567),
                  vget_high_s16(vxa3),
                  0);
              vacc3x4567 = vmlal_lane_s16(
                  vacc3x4567,
                  vget_high_s16(vxb01234567),
                  vget_high_s16(vxa3),
                  0);

              if (k >= 6) {
                const uint8x8_t vb01234567 = vld1_u8(w);
                w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
                const int16x8_t vxb01234567 =
                    vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point));

                vacc0x0123 = vmlal_lane_s16(
                    vacc0x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa0),
                    1);
                vacc0x4567 = vmlal_lane_s16(
                    vacc0x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa0),
                    1);
                vacc1x0123 = vmlal_lane_s16(
                    vacc1x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa1),
                    1);
                vacc1x4567 = vmlal_lane_s16(
                    vacc1x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa1),
                    1);
                vacc2x0123 = vmlal_lane_s16(
                    vacc2x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa2),
                    1);
                vacc2x4567 = vmlal_lane_s16(
                    vacc2x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa2),
                    1);
                vacc3x0123 = vmlal_lane_s16(
                    vacc3x0123,
                    vget_low_s16(vxb01234567),
                    vget_high_s16(vxa3),
                    1);
                vacc3x4567 = vmlal_lane_s16(
                    vacc3x4567,
                    vget_high_s16(vxb01234567),
                    vget_high_s16(vxa3),
                    1);

                if (k > 6) {
                  const uint8x8_t vb01234567 = vld1_u8(w);
                  w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
                  const int16x8_t vxb01234567 = vreinterpretq_s16_u16(
                      vsubl_u8(vb01234567, vb_zero_point));

                  vacc0x0123 = vmlal_lane_s16(
                      vacc0x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa0),
                      2);
                  vacc0x4567 = vmlal_lane_s16(
                      vacc0x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa0),
                      2);
                  vacc1x0123 = vmlal_lane_s16(
                      vacc1x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa1),
                      2);
                  vacc1x4567 = vmlal_lane_s16(
                      vacc1x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa1),
                      2);
                  vacc2x0123 = vmlal_lane_s16(
                      vacc2x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa2),
                      2);
                  vacc2x4567 = vmlal_lane_s16(
                      vacc2x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa2),
                      2);
                  vacc3x0123 = vmlal_lane_s16(
                      vacc3x0123,
                      vget_low_s16(vxb01234567),
                      vget_high_s16(vxa3),
                      2);
                  vacc3x4567 = vmlal_lane_s16(
                      vacc3x4567,
                      vget_high_s16(vxb01234567),
                      vget_high_s16(vxa3),
                      2);
                }
              }
            }
          }
        }
      }
    }
  } while (--ks != 0);

  // Use two VLD1s instead of one VLD2: on Cortex-A75 VLD2 has a higher
  // latency (8 cycles vs. 5), while both have a throughput of 2 per cycle,
  // so this is probably the better choice.
  const float32x4_t requantization_scale_c0123 =
      vld1q_f32(
          &quantization_params->neon.requantization_scales[output_channel_index]
          );
  const float32x4_t requantization_scale_c4567 =
      vld1q_f32(
          &quantization_params->neon.requantization_scales[
              output_channel_index + 4]);

  /*
   * Convert the int32_t input to FP32 and multiply by the FP32 scale.
   * Both operations involve statistically unbiased rounding:
   * - Large int32_t values can't be exactly represented as FP32. The ARM NEON
   *   conversion instruction rounds them to the nearest FP32 value with ties
   *   to even.
   * - The product of two FP32 values is generally not exactly representable
   *   as an FP32 value, and will be rounded to the nearest FP32 value with
   *   ties to even.
   */
  const float32x4_t vacc0x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x0123), requantization_scale_c0123);
  const float32x4_t vacc1x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x0123), requantization_scale_c0123);
  const float32x4_t vacc2x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x0123), requantization_scale_c0123);
  const float32x4_t vacc3x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x0123), requantization_scale_c0123);
  const float32x4_t vacc0x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x4567), requantization_scale_c4567);
  const float32x4_t vacc1x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x4567), requantization_scale_c4567);
  const float32x4_t vacc2x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x4567), requantization_scale_c4567);
  const float32x4_t vacc3x4567_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x4567), requantization_scale_c4567);

#ifdef __aarch64__
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  /*
   * Leverage the "Floating-point Convert to Signed integer, rounding to
   * nearest with ties to even" instruction. This is an ARMv8 instruction
   * (always available in AArch64) that saturates the result on overflow.
   * We don't need to handle saturated results specially; they will be
   * clamped at the last stage.
   */
  vacc0x0123 = vcvtnq_s32_f32(vacc0x0123_f);
  vacc1x0123 = vcvtnq_s32_f32(vacc1x0123_f);
  vacc2x0123 = vcvtnq_s32_f32(vacc2x0123_f);
  vacc3x0123 = vcvtnq_s32_f32(vacc3x0123_f);
  vacc0x4567 = vcvtnq_s32_f32(vacc0x4567_f);
  vacc1x4567 = vcvtnq_s32_f32(vacc1x4567_f);
  vacc2x4567 = vcvtnq_s32_f32(vacc2x4567_f);
  vacc3x4567 = vcvtnq_s32_f32(vacc3x4567_f);

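  /* Narrow to int16 with saturation (vqmovn_s32 / vqmovn_high_s32), then add
     the output zero point with a saturating int16 add. vqmovun_s16 below
     narrows further to uint8 with unsigned saturation, packing two output
     rows into each 128-bit register. */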
  const int16x8_t vacc0x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
  const int16x8_t vacc1x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
  const int16x8_t vacc2x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
  const int16x8_t vacc3x01234567 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);

  uint8x16_t vout0x01234567_1x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567);
  uint8x16_t vout2x01234567_3x01234567 =
      vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567);

  const uint8x16_t voutput_min =
      vld1q_dup_u8(&quantization_params->neon.output_min);
  const uint8x16_t voutput_max =
      vld1q_dup_u8(&quantization_params->neon.output_max);

  vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min);
  vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min);
  vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max);
  vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max);
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);
  /*
   * ARMv7 NEON offers only a floating-point to integer conversion instruction
   * with rounding towards zero. In lieu of a conversion instruction with
   * rounding to nearest-even, we use the magic trick of adding a large number
   * (1.5 * 2**23) to the scaled value to cause rounding to an integer, and
   * then subtracting this magic number as an integer. This trick works only
   * in a limited range (the absolute value of the input must be less than
   * 2**22), so generally we have to clamp the input to this range before
   * using the magic. However, clamping to any smaller range works just as
   * well, and thus we clamp to the [qmin - zero point, qmax - zero point]
   * range so that after we add the zero point to the result, it lands in the
   * target [qmin, qmax] range.
   */
  const float32x4_t vacc0x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x0123_f, vfmin), vfmax);
  const float32x4_t vacc1x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x0123_f, vfmin), vfmax);
  const float32x4_t vacc2x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x0123_f, vfmin), vfmax);
  const float32x4_t vacc3x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x0123_f, vfmin), vfmax);
  const float32x4_t vacc0x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x4567_f, vfmin), vfmax);
  const float32x4_t vacc1x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x4567_f, vfmin), vfmax);
  const float32x4_t vacc2x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x4567_f, vfmin), vfmax);
  const float32x4_t vacc3x4567_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x4567_f, vfmin), vfmax);

  /*
   * Conversion to integer using the "magic trick". Rounding is performed by
   * the addition operation; the result is rounded to the nearest integer
   * with ties to even.
   */
  vacc0x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x0123_f_clamped, vfmagic)), vimagic);
  vacc1x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x0123_f_clamped, vfmagic)), vimagic);
  vacc2x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x0123_f_clamped, vfmagic)), vimagic);
  vacc3x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x0123_f_clamped, vfmagic)), vimagic);
  vacc0x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x4567_f_clamped, vfmagic)), vimagic);
  vacc1x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x4567_f_clamped, vfmagic)), vimagic);
  vacc2x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x4567_f_clamped, vfmagic)), vimagic);
  vacc3x4567 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x4567_f_clamped, vfmagic)), vimagic);

  const int16x8_t vacc0x01234567 =
      vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567));
  const int16x8_t vacc1x01234567 =
      vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567));
  const int16x8_t vacc2x01234567 =
      vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567));
  const int16x8_t vacc3x01234567 =
      vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567));

  uint8x16_t vout0x01234567_1x01234567 =
      vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567));
  uint8x16_t vout2x01234567_3x01234567 =
      vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567));
#endif

  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr != 4) {
    c3 = c2;
  }
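  /* For partial tiles (mr < 4) the extra row pointers alias the last valid
     row; the caller pads the indirection buffer so the redundant rows compute
     duplicate results, making the overlapping stores harmless. */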
  if (nr == 8) {
    vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567));
    vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567));
    vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567));
    vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567));
  } else {
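    /* Column tail (nr < 8): store 4, then 2, then 1 byte(s) per row, rotating
       the output vectors with vextq_u8 so the next unstored bytes move down
       to lane 0 (and lane 8 for the second row in each register). */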
    if (nr >= 4) {
      vst1q_lane_u32(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          0);
      c0 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u32_u8(vout0x01234567_1x01234567),
          2);
      c1 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          0);
      c2 += 4;
      vst1q_lane_u32(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u32_u8(vout2x01234567_3x01234567),
          2);
      c3 += 4;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
      nr -= 4;
    }
    if (nr >= 2) {
      vst1q_lane_u16(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          0);
      c0 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u16_u8(vout0x01234567_1x01234567),
          4);
      c1 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          0);
      c2 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u16_u8(vout2x01234567_3x01234567),
          4);
      c3 += 2;
      vout0x01234567_1x01234567 =
          vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
      vout2x01234567_3x01234567 =
          vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
      nr -= 2;
    }
    if (nr != 0) {
      vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
      vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
      vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
      vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
    }
  }
}
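
For reference, the "magic number" rounding used in the ARMv7 path above can be illustrated in scalar C. The sketch below is illustrative only: the helper name fp32_to_int_magic and the test values are invented for this example, and unlike the kernel, which folds the output zero point into vimagic, it subtracts the raw magic constant. It assumes IEEE-754 binary32 arithmetic with round-to-nearest-even mode.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/*
 * Scalar model of the ARMv7 requantization rounding: adding 1.5 * 2**23
 * forces the FPU to round the addend to an integer (ties to even);
 * reinterpreting the bits and subtracting the magic constant's bit pattern
 * as an integer then recovers that integer. Valid only for |x| < 2**22,
 * which is why the kernel clamps to [qmin - zero point, qmax - zero point]
 * first.
 */
static int32_t fp32_to_int_magic(float x) {
  const float fmagic = 12582912.0f; /* 1.5 * 2**23, cf. vfmagic */
  int32_t imagic;
  memcpy(&imagic, &fmagic, sizeof(imagic)); /* cf. vimagic (sans zero point) */
  const float y = x + fmagic; /* rounds x to an integer, ties to even */
  int32_t ibits;
  memcpy(&ibits, &y, sizeof(ibits)); /* cf. vreinterpretq_s32_f32 */
  return ibits - imagic; /* cf. vsubq_s32 */
}

int main(void) {
  const float inputs[] = {2.5f, 3.5f, -2.5f, 0.49f, 100.51f};
  for (size_t i = 0; i < sizeof(inputs) / sizeof(inputs[0]); i++) {
    /* Expect 2, 4, -2, 0, 101: halfway cases round to the nearest even. */
    printf("%8.2f -> %d\n", (double)inputs[i], (int)fp32_to_int_magic(inputs[i]));
  }
  return 0;
}

The bit-pattern subtraction works because for floats in [2**23, 2**24) the encoding is an affine function of the value with a unit step, so the integer difference of the bit patterns equals the numeric difference; the clamp guarantees the sum stays in that range.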