"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-neon.c" (23 Jul 2021, 23664 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "4x8-dq-neon.c" see the Fossies "Dox" file reference documentation.

    1 /*
    2  * Copyright (c) Facebook, Inc. and its affiliates.
    3  * All rights reserved.
    4  *
    5  * This source code is licensed under the BSD-style license found in the
    6  * LICENSE file in the root directory of this source tree.
    7  */
    8 
    9 #include <arm_neon.h>
   10 
   11 #include <qnnpack/q8gemm.h>
   12 #include <requantization/runtime-neon.h>
   13 
   14 void pytorch_q8gemm_dq_ukernel_4x8__neon(
   15     size_t mr,
   16     size_t nr,
   17     size_t k,
   18     const uint8_t* restrict a,
   19     size_t a_stride,
   20     const void* restrict w,
   21     const float* restrict b,
   22     float* restrict c,
   23     size_t c_stride,
   24     size_t output_channel_index,
   25     const struct pytorch_qnnp_conv_dynamic_quantization_params
   26         quantization_params[RESTRICT_STATIC 1]) {
   27   int32x4_t vacc0x0123 = {};
   28   int32x4_t vacc0x4567 = {};
   29   int32x4_t vacc1x0123 = {};
   30   int32x4_t vacc1x4567 = {};
   31   int32x4_t vacc2x0123 = {};
   32   int32x4_t vacc2x4567 = {};
   33   int32x4_t vacc3x0123 = {};
   34   int32x4_t vacc3x4567 = {};
   35   w = (const void*)((uintptr_t)w + 32);
   36 
   37   const uint8_t* a0 = a;
   38   const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride);
   39   if (mr < 2) {
   40     a1 = a0;
   41   }
   42   const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride);
   43   if (mr <= 2) {
   44     a2 = a1;
   45   }
   46   const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride);
   47   if (mr != 4) {
   48     a3 = a2;
   49   }
   50 
   51   const uint8x8_t va_zero_point =
   52       vld1_dup_u8((const uint8_t*)&quantization_params->input_zero_point);
   53   // Assumes that kernel_zero_points is an array padded with necessary elements
   54   // in order to make it multiple of 8.
   55   const uint8x8_t vb_zero_point =
   56       vld1_u8((const uint8_t*)&quantization_params->kernel_zero_points
   57           [output_channel_index]);
   58 
   59   const float32x4_t vmultiplier_c0123 =
   60       vld1q_f32(&quantization_params->multipliers[output_channel_index]);
   61   const float32x4_t vmultiplier_c4567 =
   62       vld1q_f32(&quantization_params->multipliers[output_channel_index + 4]);
   63   const float32x4_t vbias[] = {
   64     vld1q_f32(b),
   65     vld1q_f32(b + 4),
   66   };
   67 
   68   for (; k >= 8; k -= 8) {
   69     const uint8x8_t va0 = vld1_u8(a0);
   70     a0 += 8;
   71     const int16x8_t vxa0 =
   72         vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
   73     const uint8x8_t va1 = vld1_u8(a1);
   74     a1 += 8;
   75     const int16x8_t vxa1 =
   76         vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
   77     const uint8x8_t va2 = vld1_u8(a2);
   78     a2 += 8;
   79     const int16x8_t vxa2 =
   80         vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
   81     const uint8x8_t va3 = vld1_u8(a3);
   82     a3 += 8;
   83     const int16x8_t vxa3 =
   84         vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));
   85 
   86     const uint8x8_t vb01234567c0 = vld1_u8(w);
   87     w = (const void*)((uintptr_t)w + 8);
   88     const int16x8_t vxb01234567c0 =
   89         vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));
   90 
   91     vacc0x0123 = vmlal_lane_s16(
   92         vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
   93     vacc0x4567 = vmlal_lane_s16(
   94         vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
   95     vacc1x0123 = vmlal_lane_s16(
   96         vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
   97     vacc1x4567 = vmlal_lane_s16(
   98         vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
   99     vacc2x0123 = vmlal_lane_s16(
  100         vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
  101     vacc2x4567 = vmlal_lane_s16(
  102         vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
  103     vacc3x0123 = vmlal_lane_s16(
  104         vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
  105     vacc3x4567 = vmlal_lane_s16(
  106         vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
  107 
  108     const uint8x8_t vb01234567c1 = vld1_u8(w);
  109     w = (const void*)((uintptr_t)w + 8);
  110     const int16x8_t vxb01234567c1 =
  111         vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));
  112 
  113     vacc0x0123 = vmlal_lane_s16(
  114         vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
  115     vacc0x4567 = vmlal_lane_s16(
  116         vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
  117     vacc1x0123 = vmlal_lane_s16(
  118         vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
  119     vacc1x4567 = vmlal_lane_s16(
  120         vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
  121     vacc2x0123 = vmlal_lane_s16(
  122         vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
  123     vacc2x4567 = vmlal_lane_s16(
  124         vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
  125     vacc3x0123 = vmlal_lane_s16(
  126         vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
  127     vacc3x4567 = vmlal_lane_s16(
  128         vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
  129 
  130     const uint8x8_t vb01234567c2 = vld1_u8(w);
  131     w = (const void*)((uintptr_t)w + 8);
  132     const int16x8_t vxb01234567c2 =
  133         vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));
  134 
  135     vacc0x0123 = vmlal_lane_s16(
  136         vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
  137     vacc0x4567 = vmlal_lane_s16(
  138         vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
  139     vacc1x0123 = vmlal_lane_s16(
  140         vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
  141     vacc1x4567 = vmlal_lane_s16(
  142         vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
  143     vacc2x0123 = vmlal_lane_s16(
  144         vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
  145     vacc2x4567 = vmlal_lane_s16(
  146         vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
  147     vacc3x0123 = vmlal_lane_s16(
  148         vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
  149     vacc3x4567 = vmlal_lane_s16(
  150         vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
  151 
  152     const uint8x8_t vb01234567c3 = vld1_u8(w);
  153     w = (const void*)((uintptr_t)w + 8);
  154     const int16x8_t vxb01234567c3 =
  155         vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));
  156 
  157     vacc0x0123 = vmlal_lane_s16(
  158         vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
  159     vacc0x4567 = vmlal_lane_s16(
  160         vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
  161     vacc1x0123 = vmlal_lane_s16(
  162         vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
  163     vacc1x4567 = vmlal_lane_s16(
  164         vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
  165     vacc2x0123 = vmlal_lane_s16(
  166         vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
  167     vacc2x4567 = vmlal_lane_s16(
  168         vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
  169     vacc3x0123 = vmlal_lane_s16(
  170         vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
  171     vacc3x4567 = vmlal_lane_s16(
  172         vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
  173 
  174     const uint8x8_t vb01234567c4 = vld1_u8(w);
  175     w = (const void*)((uintptr_t)w + 8);
  176     const int16x8_t vxb01234567c4 =
  177         vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));
  178 
  179     vacc0x0123 = vmlal_lane_s16(
  180         vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
  181     vacc0x4567 = vmlal_lane_s16(
  182         vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0);
  183     vacc1x0123 = vmlal_lane_s16(
  184         vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
  185     vacc1x4567 = vmlal_lane_s16(
  186         vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0);
  187     vacc2x0123 = vmlal_lane_s16(
  188         vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
  189     vacc2x4567 = vmlal_lane_s16(
  190         vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0);
  191     vacc3x0123 = vmlal_lane_s16(
  192         vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
  193     vacc3x4567 = vmlal_lane_s16(
  194         vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0);
  195 
  196     const uint8x8_t vb01234567c5 = vld1_u8(w);
  197     w = (const void*)((uintptr_t)w + 8);
  198     const int16x8_t vxb01234567c5 =
  199         vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));
  200 
  201     vacc0x0123 = vmlal_lane_s16(
  202         vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
  203     vacc0x4567 = vmlal_lane_s16(
  204         vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1);
  205     vacc1x0123 = vmlal_lane_s16(
  206         vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
  207     vacc1x4567 = vmlal_lane_s16(
  208         vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1);
  209     vacc2x0123 = vmlal_lane_s16(
  210         vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
  211     vacc2x4567 = vmlal_lane_s16(
  212         vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1);
  213     vacc3x0123 = vmlal_lane_s16(
  214         vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
  215     vacc3x4567 = vmlal_lane_s16(
  216         vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1);
  217 
  218     const uint8x8_t vb01234567c6 = vld1_u8(w);
  219     w = (const void*)((uintptr_t)w + 8);
  220     const int16x8_t vxb01234567c6 =
  221         vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point));
  222 
  223     vacc0x0123 = vmlal_lane_s16(
  224         vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
  225     vacc0x4567 = vmlal_lane_s16(
  226         vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2);
  227     vacc1x0123 = vmlal_lane_s16(
  228         vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
  229     vacc1x4567 = vmlal_lane_s16(
  230         vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2);
  231     vacc2x0123 = vmlal_lane_s16(
  232         vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
  233     vacc2x4567 = vmlal_lane_s16(
  234         vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2);
  235     vacc3x0123 = vmlal_lane_s16(
  236         vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
  237     vacc3x4567 = vmlal_lane_s16(
  238         vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2);
  239 
  240     const uint8x8_t vb01234567c7 = vld1_u8(w);
  241     w = (const void*)((uintptr_t)w + 8);
  242     const int16x8_t vxb01234567c7 =
  243         vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point));
  244 
  245     vacc0x0123 = vmlal_lane_s16(
  246         vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
  247     vacc0x4567 = vmlal_lane_s16(
  248         vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3);
  249     vacc1x0123 = vmlal_lane_s16(
  250         vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
  251     vacc1x4567 = vmlal_lane_s16(
  252         vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3);
  253     vacc2x0123 = vmlal_lane_s16(
  254         vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
  255     vacc2x4567 = vmlal_lane_s16(
  256         vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3);
  257     vacc3x0123 = vmlal_lane_s16(
  258         vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
  259     vacc3x4567 = vmlal_lane_s16(
  260         vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3);
  261   }
  262   if (k != 0) {
  263     const size_t a_predecrement = 8 - k;
  264     const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement);
  265     const uint8x8_t va0 = vreinterpret_u8_u64(
  266         vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift));
  267     const int16x8_t vxa0 =
  268         vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
  269     const uint8x8_t va1 = vreinterpret_u8_u64(
  270         vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift));
  271     const int16x8_t vxa1 =
  272         vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
  273     const uint8x8_t va2 = vreinterpret_u8_u64(
  274         vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift));
  275     const int16x8_t vxa2 =
  276         vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
  277     const uint8x8_t va3 = vreinterpret_u8_u64(
  278         vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift));
  279     const int16x8_t vxa3 =
  280         vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));
  281 
  282     const uint8x8_t vb01234567c0 = vld1_u8(w);
  283     w = (const void*)((uintptr_t)w + 8);
  284     const int16x8_t vxb01234567c0 =
  285         vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point));
  286 
  287     vacc0x0123 = vmlal_lane_s16(
  288         vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
  289     vacc0x4567 = vmlal_lane_s16(
  290         vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0);
  291     vacc1x0123 = vmlal_lane_s16(
  292         vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
  293     vacc1x4567 = vmlal_lane_s16(
  294         vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0);
  295     vacc2x0123 = vmlal_lane_s16(
  296         vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
  297     vacc2x4567 = vmlal_lane_s16(
  298         vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0);
  299     vacc3x0123 = vmlal_lane_s16(
  300         vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
  301     vacc3x4567 = vmlal_lane_s16(
  302         vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0);
  303 
  304     if (k >= 2) {
  305       const uint8x8_t vb01234567c1 = vld1_u8(w);
  306       w = (const void*)((uintptr_t)w + 8);
  307       const int16x8_t vxb01234567c1 =
  308           vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point));
  309 
  310       vacc0x0123 = vmlal_lane_s16(
  311           vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
  312       vacc0x4567 = vmlal_lane_s16(
  313           vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1);
  314       vacc1x0123 = vmlal_lane_s16(
  315           vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
  316       vacc1x4567 = vmlal_lane_s16(
  317           vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1);
  318       vacc2x0123 = vmlal_lane_s16(
  319           vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
  320       vacc2x4567 = vmlal_lane_s16(
  321           vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1);
  322       vacc3x0123 = vmlal_lane_s16(
  323           vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
  324       vacc3x4567 = vmlal_lane_s16(
  325           vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1);
  326 
  327       if (k >= 3) {
  328         const uint8x8_t vb01234567c2 = vld1_u8(w);
  329         w = (const void*)((uintptr_t)w + 8);
  330         const int16x8_t vxb01234567c2 =
  331             vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point));
  332 
  333         vacc0x0123 = vmlal_lane_s16(
  334             vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
  335         vacc0x4567 = vmlal_lane_s16(
  336             vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2);
  337         vacc1x0123 = vmlal_lane_s16(
  338             vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
  339         vacc1x4567 = vmlal_lane_s16(
  340             vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2);
  341         vacc2x0123 = vmlal_lane_s16(
  342             vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
  343         vacc2x4567 = vmlal_lane_s16(
  344             vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2);
  345         vacc3x0123 = vmlal_lane_s16(
  346             vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
  347         vacc3x4567 = vmlal_lane_s16(
  348             vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2);
  349 
  350         if (k >= 4) {
  351           const uint8x8_t vb01234567c3 = vld1_u8(w);
  352           w = (const void*)((uintptr_t)w + 8);
  353           const int16x8_t vxb01234567c3 =
  354               vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point));
  355 
  356           vacc0x0123 = vmlal_lane_s16(
  357               vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
  358           vacc0x4567 = vmlal_lane_s16(
  359               vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3);
  360           vacc1x0123 = vmlal_lane_s16(
  361               vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
  362           vacc1x4567 = vmlal_lane_s16(
  363               vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3);
  364           vacc2x0123 = vmlal_lane_s16(
  365               vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
  366           vacc2x4567 = vmlal_lane_s16(
  367               vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3);
  368           vacc3x0123 = vmlal_lane_s16(
  369               vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
  370           vacc3x4567 = vmlal_lane_s16(
  371               vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3);
  372 
  373           if (k >= 5) {
  374             const uint8x8_t vb01234567c4 = vld1_u8(w);
  375             w = (const void*)((uintptr_t)w + 8);
  376             const int16x8_t vxb01234567c4 =
  377                 vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point));
  378 
  379             vacc0x0123 = vmlal_lane_s16(
  380                 vacc0x0123,
  381                 vget_low_s16(vxb01234567c4),
  382                 vget_high_s16(vxa0),
  383                 0);
  384             vacc0x4567 = vmlal_lane_s16(
  385                 vacc0x4567,
  386                 vget_high_s16(vxb01234567c4),
  387                 vget_high_s16(vxa0),
  388                 0);
  389             vacc1x0123 = vmlal_lane_s16(
  390                 vacc1x0123,
  391                 vget_low_s16(vxb01234567c4),
  392                 vget_high_s16(vxa1),
  393                 0);
  394             vacc1x4567 = vmlal_lane_s16(
  395                 vacc1x4567,
  396                 vget_high_s16(vxb01234567c4),
  397                 vget_high_s16(vxa1),
  398                 0);
  399             vacc2x0123 = vmlal_lane_s16(
  400                 vacc2x0123,
  401                 vget_low_s16(vxb01234567c4),
  402                 vget_high_s16(vxa2),
  403                 0);
  404             vacc2x4567 = vmlal_lane_s16(
  405                 vacc2x4567,
  406                 vget_high_s16(vxb01234567c4),
  407                 vget_high_s16(vxa2),
  408                 0);
  409             vacc3x0123 = vmlal_lane_s16(
  410                 vacc3x0123,
  411                 vget_low_s16(vxb01234567c4),
  412                 vget_high_s16(vxa3),
  413                 0);
  414             vacc3x4567 = vmlal_lane_s16(
  415                 vacc3x4567,
  416                 vget_high_s16(vxb01234567c4),
  417                 vget_high_s16(vxa3),
  418                 0);
  419 
  420             if (k >= 6) {
  421               const uint8x8_t vb01234567c5 = vld1_u8(w);
  422               w = (const void*)((uintptr_t)w + 8);
  423               const int16x8_t vxb01234567c5 =
  424                   vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point));
  425 
  426               vacc0x0123 = vmlal_lane_s16(
  427                   vacc0x0123,
  428                   vget_low_s16(vxb01234567c5),
  429                   vget_high_s16(vxa0),
  430                   1);
  431               vacc0x4567 = vmlal_lane_s16(
  432                   vacc0x4567,
  433                   vget_high_s16(vxb01234567c5),
  434                   vget_high_s16(vxa0),
  435                   1);
  436               vacc1x0123 = vmlal_lane_s16(
  437                   vacc1x0123,
  438                   vget_low_s16(vxb01234567c5),
  439                   vget_high_s16(vxa1),
  440                   1);
  441               vacc1x4567 = vmlal_lane_s16(
  442                   vacc1x4567,
  443                   vget_high_s16(vxb01234567c5),
  444                   vget_high_s16(vxa1),
  445                   1);
  446               vacc2x0123 = vmlal_lane_s16(
  447                   vacc2x0123,
  448                   vget_low_s16(vxb01234567c5),
  449                   vget_high_s16(vxa2),
  450                   1);
  451               vacc2x4567 = vmlal_lane_s16(
  452                   vacc2x4567,
  453                   vget_high_s16(vxb01234567c5),
  454                   vget_high_s16(vxa2),
  455                   1);
  456               vacc3x0123 = vmlal_lane_s16(
  457                   vacc3x0123,
  458                   vget_low_s16(vxb01234567c5),
  459                   vget_high_s16(vxa3),
  460                   1);
  461               vacc3x4567 = vmlal_lane_s16(
  462                   vacc3x4567,
  463                   vget_high_s16(vxb01234567c5),
  464                   vget_high_s16(vxa3),
  465                   1);
  466 
  467               if (k >= 7) {
  468                 const uint8x8_t vb01234567c6 = vld1_u8(w);
  469                 w = (const void*)((uintptr_t)w + 8);
  470                 const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(
  471                     vsubl_u8(vb01234567c6, vb_zero_point));
  472 
  473                 vacc0x0123 = vmlal_lane_s16(
  474                     vacc0x0123,
  475                     vget_low_s16(vxb01234567c6),
  476                     vget_high_s16(vxa0),
  477                     2);
  478                 vacc0x4567 = vmlal_lane_s16(
  479                     vacc0x4567,
  480                     vget_high_s16(vxb01234567c6),
  481                     vget_high_s16(vxa0),
  482                     2);
  483                 vacc1x0123 = vmlal_lane_s16(
  484                     vacc1x0123,
  485                     vget_low_s16(vxb01234567c6),
  486                     vget_high_s16(vxa1),
  487                     2);
  488                 vacc1x4567 = vmlal_lane_s16(
  489                     vacc1x4567,
  490                     vget_high_s16(vxb01234567c6),
  491                     vget_high_s16(vxa1),
  492                     2);
  493                 vacc2x0123 = vmlal_lane_s16(
  494                     vacc2x0123,
  495                     vget_low_s16(vxb01234567c6),
  496                     vget_high_s16(vxa2),
  497                     2);
  498                 vacc2x4567 = vmlal_lane_s16(
  499                     vacc2x4567,
  500                     vget_high_s16(vxb01234567c6),
  501                     vget_high_s16(vxa2),
  502                     2);
  503                 vacc3x0123 = vmlal_lane_s16(
  504                     vacc3x0123,
  505                     vget_low_s16(vxb01234567c6),
  506                     vget_high_s16(vxa3),
  507                     2);
  508                 vacc3x4567 = vmlal_lane_s16(
  509                     vacc3x4567,
  510                     vget_high_s16(vxb01234567c6),
  511                     vget_high_s16(vxa3),
  512                     2);
  513               }
  514             }
  515           }
  516         }
  517       }
  518     }
  519   }
  520 
  521   float32x4_t vout0[] = {
  522     vaddq_f32(vmulq_f32(vmultiplier_c0123, vcvtq_f32_s32(vacc0x0123)), vbias[0]),
  523     vaddq_f32(vmulq_f32(vmultiplier_c4567, vcvtq_f32_s32(vacc0x4567)), vbias[1]),
  524   };
  525   float32x4_t vout1[] = {
  526     vaddq_f32(vmulq_f32(vmultiplier_c0123, vcvtq_f32_s32(vacc1x0123)), vbias[0]),
  527     vaddq_f32(vmulq_f32(vmultiplier_c4567, vcvtq_f32_s32(vacc1x4567)), vbias[1]),
  528   };
  529   float32x4_t vout2[] = {
  530     vaddq_f32(vmulq_f32(vmultiplier_c0123, vcvtq_f32_s32(vacc2x0123)), vbias[0]),
  531     vaddq_f32(vmulq_f32(vmultiplier_c4567, vcvtq_f32_s32(vacc2x4567)), vbias[1]),
  532   };
  533   float32x4_t vout3[] = {
  534     vaddq_f32(vmulq_f32(vmultiplier_c0123, vcvtq_f32_s32(vacc3x0123)), vbias[0]),
  535     vaddq_f32(vmulq_f32(vmultiplier_c4567, vcvtq_f32_s32(vacc3x4567)), vbias[1]),
  536   };
  537 
  538   float32x4_t * vout0_ptr = vout0;
  539   float32x4_t * vout1_ptr = vout1;
  540   float32x4_t * vout2_ptr = vout2;
  541   float32x4_t * vout3_ptr = vout3;
  542 
  543   float* c0 = c;
  544   float* c1 = c0 + c_stride;
  545   if (mr < 2) {
  546     c1 = c0;
  547   }
  548   float* c2 = c1 + c_stride;
  549   if (mr <= 2) {
  550     c2 = c1;
  551   }
  552   float* c3 = c2 + c_stride;
  553   if (mr != 4) {
  554     c3 = c2;
  555   }
  556 
  557   for (; nr >= 4; nr -= 4) {
  558     vst1q_f32(c0, *vout0_ptr++);
  559     vst1q_f32(c1, *vout1_ptr++);
  560     vst1q_f32(c2, *vout2_ptr++);
  561     vst1q_f32(c3, *vout3_ptr++);
  562 
  563     c0 += 4;
  564     c1 += 4;
  565     c2 += 4;
  566     c3 += 4;
  567   }
  568 
  569   if (nr >= 2) {
  570     vst1_f32(c0, vget_low_f32(*vout0_ptr));
  571     vst1_f32(c1, vget_low_f32(*vout1_ptr));
  572     vst1_f32(c2, vget_low_f32(*vout2_ptr));
  573     vst1_f32(c3, vget_low_f32(*vout3_ptr));
  574 
  575     c0 += 2;
  576     (*vout0_ptr)[0] = (*vout0_ptr)[2];
  577     c1 += 2;
  578     (*vout1_ptr)[0] = (*vout1_ptr)[2];
  579     c2 += 2;
  580     (*vout2_ptr)[0] = (*vout2_ptr)[2];
  581     c3 += 2;
  582     (*vout3_ptr)[0] = (*vout3_ptr)[2];
  583 
  584     nr -= 2;
  585   }
  586 
  587   if (nr != 0) {
  588     vst1q_lane_f32(c0, *vout0_ptr, 0);
  589     vst1q_lane_f32(c1, *vout1_ptr, 0);
  590     vst1q_lane_f32(c2, *vout2_ptr, 0);
  591     vst1q_lane_f32(c3, *vout3_ptr, 0);
  592   }
  593 }