"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/6x4-neon.c" (23 Jul 2021, 23968 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:



/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8gemm.h>
#include <requantization/runtime-neon.h>

void pytorch_q8gemm_ukernel_6x4__neon(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[restrict static 1]) {
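  // The packed weight buffer starts with the four per-output-channel int32
  // bias values; load them once and use them to seed the accumulators for
  // all six rows of the tile.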
  int32x4_t vacc0x0123 = vld1q_s32(w);
  w = (const void*)((uintptr_t)w + 16);
  int32x4_t vacc1x0123 = vacc0x0123;
  int32x4_t vacc2x0123 = vacc0x0123;
  int32x4_t vacc3x0123 = vacc0x0123;
  int32x4_t vacc4x0123 = vacc0x0123;
  int32x4_t vacc5x0123 = vacc0x0123;

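  // Set up the six input row pointers. Rows beyond mr alias the previous
  // valid row, so they read valid memory and merely recompute the same
  // values; the corresponding output pointers are aliased the same way below.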
  const uint8_t* a0 = a;
  const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride);
  if (mr < 2) {
    a1 = a0;
  }
  const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride);
  if (mr <= 2) {
    a2 = a1;
  }
  const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride);
  if (mr < 4) {
    a3 = a2;
  }
  const uint8_t* a4 = (const uint8_t*)((uintptr_t)a3 + a_stride);
  if (mr <= 4) {
    a4 = a3;
  }
  const uint8_t* a5 = (const uint8_t*)((uintptr_t)a4 + a_stride);
  if (mr != 6) {
    a5 = a4;
  }

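  // Load the input (activation) zero point broadcast across all lanes, and
  // the kernel zero points for the four output channels starting at
  // output_channel_index.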
  const uint8x8_t va_zero_point =
      vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
  uint8x8_t vb_zero_point =
      vld1_u8((const uint8_t*)&quantization_params->neon.kernel_zero_points
          [output_channel_index]);
  // Only the lower 4 values are used in this kernel, so replicate the lower 4
  // values into the upper 4. We still end up loading 8 values, assuming the
  // zero point array length is always a multiple of 8.
  vb_zero_point = vset_lane_u8(vget_lane_u8(vb_zero_point, 0), vb_zero_point, 4);
  vb_zero_point = vset_lane_u8(vget_lane_u8(vb_zero_point, 1), vb_zero_point, 5);
  vb_zero_point = vset_lane_u8(vget_lane_u8(vb_zero_point, 2), vb_zero_point, 6);
  vb_zero_point = vset_lane_u8(vget_lane_u8(vb_zero_point, 3), vb_zero_point, 7);
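  // Main loop: consume the K dimension eight elements at a time. Each 8-byte
  // weight load covers two K steps for output channels 0-3 (e.g. vb0123c01
  // holds channels 0-3 for K steps 0 and 1). Inputs and weights are widened
  // to 16 bits with their zero points subtracted, then multiply-accumulated
  // into the 32-bit accumulators with vmlal_lane_s16.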
  for (; k >= 8; k -= 8) {
    const uint8x8_t va0 = vld1_u8(a0);
    a0 += 8;
    const int16x8_t vxa0 =
        vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
    const uint8x8_t va1 = vld1_u8(a1);
    a1 += 8;
    const int16x8_t vxa1 =
        vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
    const uint8x8_t va2 = vld1_u8(a2);
    a2 += 8;
    const int16x8_t vxa2 =
        vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
    const uint8x8_t va3 = vld1_u8(a3);
    a3 += 8;
    const int16x8_t vxa3 =
        vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));
    const uint8x8_t va4 = vld1_u8(a4);
    a4 += 8;
    const int16x8_t vxa4 =
        vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point));
    const uint8x8_t va5 = vld1_u8(a5);
    a5 += 8;
    const int16x8_t vxa5 =
        vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point));

    const uint8x8_t vb0123c01 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb0123c01 =
        vreinterpretq_s16_u16(vsubl_u8(vb0123c01, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa0), 0);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa1), 0);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa2), 0);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa3), 0);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa4), 0);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa5), 0);

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa0), 1);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa1), 1);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa2), 1);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa3), 1);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa4), 1);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa5), 1);

    const uint8x8_t vb0123c23 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb0123c23 =
        vreinterpretq_s16_u16(vsubl_u8(vb0123c23, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa0), 2);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa1), 2);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa2), 2);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa3), 2);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa4), 2);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa5), 2);

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa0), 3);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa1), 3);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa2), 3);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa3), 3);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa4), 3);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa5), 3);

    const uint8x8_t vb0123c45 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb0123c45 =
        vreinterpretq_s16_u16(vsubl_u8(vb0123c45, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa0), 0);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa1), 0);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa2), 0);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa3), 0);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa4), 0);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa5), 0);

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa0), 1);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa1), 1);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa2), 1);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa3), 1);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa4), 1);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa5), 1);

    const uint8x8_t vb0123c67 = vld1_u8(w);
    w = (const void*)((uintptr_t)w + 8);
    const int16x8_t vxb0123c67 =
        vreinterpretq_s16_u16(vsubl_u8(vb0123c67, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa0), 2);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa1), 2);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa2), 2);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa3), 2);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa4), 2);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa5), 2);

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa0), 3);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa1), 3);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa2), 3);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa3), 3);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa4), 3);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa5), 3);
  }
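  // Remainder: handle the final 1-7 elements of K. Each input row is loaded
  // with a full 8-byte read that starts (8 - k) bytes early and is shifted
  // right so the k valid bytes land in the low lanes; the nested branches
  // below then apply exactly k K steps of weights.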
  if (k != 0) {
    const size_t a_predecrement = 8 - k;
    const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement);
    const uint8x8_t va0 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift));
    const int16x8_t vxa0 =
        vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point));
    const uint8x8_t va1 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift));
    const int16x8_t vxa1 =
        vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point));
    const uint8x8_t va2 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift));
    const int16x8_t vxa2 =
        vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point));
    const uint8x8_t va3 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift));
    const int16x8_t vxa3 =
        vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point));
    const uint8x8_t va4 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a4 - a_predecrement)), va_shift));
    const int16x8_t vxa4 =
        vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point));
    const uint8x8_t va5 = vreinterpret_u8_u64(
        vshl_u64(vreinterpret_u64_u8(vld1_u8(a5 - a_predecrement)), va_shift));
    const int16x8_t vxa5 =
        vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point));

    const uint8x8_t vb0123c0 = vreinterpret_u8_u32(vld1_dup_u32(w));
    w = (const void*)((uintptr_t)w + 4);
    const int16x8_t vxb0123c0 =
        vreinterpretq_s16_u16(vsubl_u8(vb0123c0, vb_zero_point));

    vacc0x0123 = vmlal_lane_s16(
        vacc0x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa0), 0);
    vacc1x0123 = vmlal_lane_s16(
        vacc1x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa1), 0);
    vacc2x0123 = vmlal_lane_s16(
        vacc2x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa2), 0);
    vacc3x0123 = vmlal_lane_s16(
        vacc3x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa3), 0);
    vacc4x0123 = vmlal_lane_s16(
        vacc4x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa4), 0);
    vacc5x0123 = vmlal_lane_s16(
        vacc5x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa5), 0);

    if (k >= 2) {
      const uint8x8_t vb0123c1 = vreinterpret_u8_u32(vld1_dup_u32(w));
      w = (const void*)((uintptr_t)w + 4);
      const int16x8_t vxb0123c1 =
          vreinterpretq_s16_u16(vsubl_u8(vb0123c1, vb_zero_point));

      vacc0x0123 = vmlal_lane_s16(
          vacc0x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa0), 1);
      vacc1x0123 = vmlal_lane_s16(
          vacc1x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa1), 1);
      vacc2x0123 = vmlal_lane_s16(
          vacc2x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa2), 1);
      vacc3x0123 = vmlal_lane_s16(
          vacc3x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa3), 1);
      vacc4x0123 = vmlal_lane_s16(
          vacc4x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa4), 1);
      vacc5x0123 = vmlal_lane_s16(
          vacc5x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa5), 1);

      if (k > 2) {
        const uint8x8_t vb0123c2 = vreinterpret_u8_u32(vld1_dup_u32(w));
        w = (const void*)((uintptr_t)w + 4);
        const int16x8_t vxb0123c2 =
            vreinterpretq_s16_u16(vsubl_u8(vb0123c2, vb_zero_point));

        vacc0x0123 = vmlal_lane_s16(
            vacc0x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa0), 2);
        vacc1x0123 = vmlal_lane_s16(
            vacc1x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa1), 2);
        vacc2x0123 = vmlal_lane_s16(
            vacc2x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa2), 2);
        vacc3x0123 = vmlal_lane_s16(
            vacc3x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa3), 2);
        vacc4x0123 = vmlal_lane_s16(
            vacc4x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa4), 2);
        vacc5x0123 = vmlal_lane_s16(
            vacc5x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa5), 2);

        if (k >= 4) {
          const uint8x8_t vb0123c3 = vreinterpret_u8_u32(vld1_dup_u32(w));
          w = (const void*)((uintptr_t)w + 4);
          const int16x8_t vxb0123c3 =
              vreinterpretq_s16_u16(vsubl_u8(vb0123c3, vb_zero_point));

          vacc0x0123 = vmlal_lane_s16(
              vacc0x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa0), 3);
          vacc1x0123 = vmlal_lane_s16(
              vacc1x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa1), 3);
          vacc2x0123 = vmlal_lane_s16(
              vacc2x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa2), 3);
          vacc3x0123 = vmlal_lane_s16(
              vacc3x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa3), 3);
          vacc4x0123 = vmlal_lane_s16(
              vacc4x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa4), 3);
          vacc5x0123 = vmlal_lane_s16(
              vacc5x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa5), 3);

          if (k > 4) {
            const uint8x8_t vb0123c4 = vreinterpret_u8_u32(vld1_dup_u32(w));
            w = (const void*)((uintptr_t)w + 4);
            const int16x8_t vxb0123c4 =
                vreinterpretq_s16_u16(vsubl_u8(vb0123c4, vb_zero_point));

            vacc0x0123 = vmlal_lane_s16(
                vacc0x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa0), 0);
            vacc1x0123 = vmlal_lane_s16(
                vacc1x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa1), 0);
            vacc2x0123 = vmlal_lane_s16(
                vacc2x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa2), 0);
            vacc3x0123 = vmlal_lane_s16(
                vacc3x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa3), 0);
            vacc4x0123 = vmlal_lane_s16(
                vacc4x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa4), 0);
            vacc5x0123 = vmlal_lane_s16(
                vacc5x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa5), 0);

            if (k >= 6) {
              const uint8x8_t vb0123c5 = vreinterpret_u8_u32(vld1_dup_u32(w));
              w = (const void*)((uintptr_t)w + 4);
              const int16x8_t vxb0123c5 =
                  vreinterpretq_s16_u16(vsubl_u8(vb0123c5, vb_zero_point));

              vacc0x0123 = vmlal_lane_s16(
                  vacc0x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa0), 1);
              vacc1x0123 = vmlal_lane_s16(
                  vacc1x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa1), 1);
              vacc2x0123 = vmlal_lane_s16(
                  vacc2x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa2), 1);
              vacc3x0123 = vmlal_lane_s16(
                  vacc3x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa3), 1);
              vacc4x0123 = vmlal_lane_s16(
                  vacc4x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa4), 1);
              vacc5x0123 = vmlal_lane_s16(
                  vacc5x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa5), 1);

              if (k > 6) {
                const uint8x8_t vb0123c6 = vreinterpret_u8_u32(vld1_dup_u32(w));
                const int16x8_t vxb0123c6 =
                    vreinterpretq_s16_u16(vsubl_u8(vb0123c6, vb_zero_point));

                vacc0x0123 = vmlal_lane_s16(
                    vacc0x0123,
                    vget_low_s16(vxb0123c6),
                    vget_high_s16(vxa0),
                    2);
                vacc1x0123 = vmlal_lane_s16(
                    vacc1x0123,
                    vget_low_s16(vxb0123c6),
                    vget_high_s16(vxa1),
                    2);
                vacc2x0123 = vmlal_lane_s16(
                    vacc2x0123,
                    vget_low_s16(vxb0123c6),
                    vget_high_s16(vxa2),
                    2);
                vacc3x0123 = vmlal_lane_s16(
                    vacc3x0123,
                    vget_low_s16(vxb0123c6),
                    vget_high_s16(vxa3),
                    2);
                vacc4x0123 = vmlal_lane_s16(
                    vacc4x0123,
                    vget_low_s16(vxb0123c6),
                    vget_high_s16(vxa4),
                    2);
                vacc5x0123 = vmlal_lane_s16(
                    vacc5x0123,
                    vget_low_s16(vxb0123c6),
                    vget_high_s16(vxa5),
                    2);
              }
            }
          }
        }
      }
    }
  }

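  // Requantization: convert the int32 accumulators to float and multiply by
  // the per-output-channel requantization scale.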
  const float32x4_t requantization_scale_v =
      vld1q_f32(
          &quantization_params->neon.requantization_scales[
              output_channel_index]);

  const float32x4_t vacc0x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc0x0123), requantization_scale_v);
  const float32x4_t vacc1x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc1x0123), requantization_scale_v);
  const float32x4_t vacc2x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc2x0123), requantization_scale_v);
  const float32x4_t vacc3x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc3x0123), requantization_scale_v);
  const float32x4_t vacc4x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc4x0123), requantization_scale_v);
  const float32x4_t vacc5x0123_f =
    vmulq_f32(vcvtq_f32_s32(vacc5x0123), requantization_scale_v);

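  // Convert the scaled floats back to quantized uint8. On AArch64 this uses
  // round-to-nearest conversion (vcvtnq_s32_f32) followed by saturating
  // narrows; on 32-bit ARM the kernel instead clamps to the representable
  // output range and rounds via the float "magic number" trick.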
#ifdef __aarch64__
  const int16x8_t voutput_zero_point =
      vld1q_dup_s16(&quantization_params->neon.output_zero_point);
  vacc0x0123 = vcvtnq_s32_f32(vacc0x0123_f);
  vacc1x0123 = vcvtnq_s32_f32(vacc1x0123_f);
  vacc2x0123 = vcvtnq_s32_f32(vacc2x0123_f);
  vacc3x0123 = vcvtnq_s32_f32(vacc3x0123_f);
  vacc4x0123 = vcvtnq_s32_f32(vacc4x0123_f);
  vacc5x0123 = vcvtnq_s32_f32(vacc5x0123_f);

  const int16x8_t vacc01x0123 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc1x0123), voutput_zero_point);
  const int16x8_t vacc23x0123 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc3x0123), voutput_zero_point);
  const int16x8_t vacc45x0123 = vqaddq_s16(
      vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc5x0123), voutput_zero_point);

  uint8x16_t vout0123x0123 =
      vqmovun_high_s16(vqmovun_s16(vacc01x0123), vacc23x0123);
  uint8x8_t vout45x0123 = vqmovun_s16(vacc45x0123);

  const uint8x16_t voutput_min =
      vld1q_dup_u8(&quantization_params->neon.output_min);
  const uint8x16_t voutput_max =
      vld1q_dup_u8(&quantization_params->neon.output_max);

  vout0123x0123 = vmaxq_u8(vout0123x0123, voutput_min);
  vout45x0123 = vmax_u8(vout45x0123, vget_low_u8(voutput_min));
  vout0123x0123 = vminq_u8(vout0123x0123, voutput_max);
  vout45x0123 = vmin_u8(vout45x0123, vget_low_u8(voutput_max));
#else
  const float32x4_t vfmin = vdupq_n_f32(quantization_params->neon.vfmin);
  const float32x4_t vfmax = vdupq_n_f32(quantization_params->neon.vfmax);
  const float32x4_t vfmagic = vdupq_n_f32(quantization_params->neon.vfmagic);
  const int32x4_t vimagic = vdupq_n_s32(quantization_params->neon.vimagic);

  const float32x4_t vacc0x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc0x0123_f, vfmin), vfmax);
  const float32x4_t vacc1x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc1x0123_f, vfmin), vfmax);
  const float32x4_t vacc2x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc2x0123_f, vfmin), vfmax);
  const float32x4_t vacc3x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc3x0123_f, vfmin), vfmax);
  const float32x4_t vacc4x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc4x0123_f, vfmin), vfmax);
  const float32x4_t vacc5x0123_f_clamped =
      vminq_f32(vmaxq_f32(vacc5x0123_f, vfmin), vfmax);

  vacc0x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc0x0123_f_clamped, vfmagic)), vimagic);
  vacc1x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc1x0123_f_clamped, vfmagic)), vimagic);
  vacc2x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc2x0123_f_clamped, vfmagic)), vimagic);
  vacc3x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc3x0123_f_clamped, vfmagic)), vimagic);
  vacc4x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc4x0123_f_clamped, vfmagic)), vimagic);
  vacc5x0123 = vsubq_s32(
      vreinterpretq_s32_f32(vaddq_f32(vacc5x0123_f_clamped, vfmagic)), vimagic);

  const int16x8_t vacc01x0123 =
      vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc1x0123));
  const int16x8_t vacc23x0123 =
      vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc3x0123));
  const int16x8_t vacc45x0123 =
      vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc5x0123));

  uint8x16_t vout0123x0123 =
      vcombine_u8(vqmovun_s16(vacc01x0123), vqmovun_s16(vacc23x0123));
  uint8x8_t vout45x0123 = vqmovun_s16(vacc45x0123);
#endif

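  // Set up the six output row pointers with the same aliasing scheme used
  // for the input rows.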
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr < 4) {
    c3 = c2;
  }
  uint8_t* c4 = (uint8_t*)((uintptr_t)c3 + c_stride);
  if (mr <= 4) {
    c4 = c3;
  }
  uint8_t* c5 = (uint8_t*)((uintptr_t)c4 + c_stride);
  if (mr != 6) {
    c5 = c4;
  }
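  // Store the up-to-6x4 tile. With all four output channels present, each
  // row is written as a single unaligned 32-bit lane store; otherwise the
  // tail is written as a 16-bit store followed, if needed, by a single-byte
  // store.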
  if (nr == 4) {
    vst1q_lane_u32(
        __builtin_assume_aligned(c0, 1),
        vreinterpretq_u32_u8(vout0123x0123),
        0);
    vst1q_lane_u32(
        __builtin_assume_aligned(c1, 1),
        vreinterpretq_u32_u8(vout0123x0123),
        1);
    vst1q_lane_u32(
        __builtin_assume_aligned(c2, 1),
        vreinterpretq_u32_u8(vout0123x0123),
        2);
    vst1q_lane_u32(
        __builtin_assume_aligned(c3, 1),
        vreinterpretq_u32_u8(vout0123x0123),
        3);
    vst1_lane_u32(
        __builtin_assume_aligned(c4, 1), vreinterpret_u32_u8(vout45x0123), 0);
    vst1_lane_u32(
        __builtin_assume_aligned(c5, 1), vreinterpret_u32_u8(vout45x0123), 1);
  } else {
    if (nr >= 2) {
      vst1q_lane_u16(
          __builtin_assume_aligned(c0, 1),
          vreinterpretq_u16_u8(vout0123x0123),
          0);
      c0 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c1, 1),
          vreinterpretq_u16_u8(vout0123x0123),
          2);
      c1 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c2, 1),
          vreinterpretq_u16_u8(vout0123x0123),
          4);
      c2 += 2;
      vst1q_lane_u16(
          __builtin_assume_aligned(c3, 1),
          vreinterpretq_u16_u8(vout0123x0123),
          6);
      c3 += 2;
      vst1_lane_u16(
          __builtin_assume_aligned(c4, 1), vreinterpret_u16_u8(vout45x0123), 0);
      c4 += 2;
      vst1_lane_u16(
          __builtin_assume_aligned(c5, 1), vreinterpret_u16_u8(vout45x0123), 2);
      c5 += 2;
      vout0123x0123 = vextq_u8(vout0123x0123, vout0123x0123, 2);
      vout45x0123 = vext_u8(vout45x0123, vout45x0123, 2);
      nr -= 2;
    }
    if (nr != 0) {
      vst1q_lane_u8(__builtin_assume_aligned(c0, 1), vout0123x0123, 0);
      vst1q_lane_u8(__builtin_assume_aligned(c1, 1), vout0123x0123, 4);
      vst1q_lane_u8(__builtin_assume_aligned(c2, 1), vout0123x0123, 8);
      vst1q_lane_u8(__builtin_assume_aligned(c3, 1), vout0123x0123, 12);
      vst1_lane_u8(__builtin_assume_aligned(c4, 1), vout45x0123, 0);
      vst1_lane_u8(__builtin_assume_aligned(c5, 1), vout45x0123, 4);
    }
  }
}