"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x-sumrows-neon.c" (23 Jul 2021, 4862 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "4x-sumrows-neon.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/q8gemm.h>

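/*
 * Computes the sum of the first k elements of each of the m rows of the
 * uint8_t matrix `a` (m <= 4, rows `stride` bytes apart), scales each row
 * sum by `multiplier`, and writes the m results to `a_sum`.
 */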
void pytorch_q8sumrows_ukernel_4x__neon(
    const uint8_t* restrict a,
    size_t m,
    size_t k,
    size_t stride,
    const int32_t multiplier,
    int32_t* restrict a_sum) {
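  // a0..a3 point at up to four rows; rows past the m-th alias the previous
  // pointer, so the vector code below can always operate on four rows and the
  // redundant sums are simply not stored at the end.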
  const uint8_t* a0 = a;
  const uint8_t* a1 = a0;
  if (m >= 2) {
    a1 += stride;
  }
  const uint8_t* a2 = a1;
  if (m > 2) {
    a2 += stride;
  }
  const uint8_t* a3 = a2;
  if (m == 4) {
    a3 += stride;
  }

  uint32x4_t vacc0x0123 = vmovq_n_u32(0); // row 0
  uint32x4_t vacc1x0123 = vmovq_n_u32(0); // row 1
  uint32x4_t vacc2x0123 = vmovq_n_u32(0); // row 2
  uint32x4_t vacc3x0123 = vmovq_n_u32(0); // row 3
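  // Main loop: consume 16 bytes per row per iteration. The low and high
  // 8-byte halves are added with widening into eight uint16_t values, which
  // are then pairwise-accumulated into the four uint32_t lanes of the row's
  // accumulator.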
  for (; k >= 16; k -= 16) {
    // row 0
    const uint8x16_t va0x0_15 = vld1q_u8(a0);
    a0 += 16;
    vacc0x0123 = vpadalq_u16(
        vacc0x0123, vaddl_u8(vget_low_u8(va0x0_15), vget_high_u8(va0x0_15)));

    // row 1
    const uint8x16_t va1x0_15 = vld1q_u8(a1);
    a1 += 16;
    vacc1x0123 = vpadalq_u16(
        vacc1x0123, vaddl_u8(vget_low_u8(va1x0_15), vget_high_u8(va1x0_15)));

    // row 2
    const uint8x16_t va2x0_15 = vld1q_u8(a2);
    a2 += 16;
    vacc2x0123 = vpadalq_u16(
        vacc2x0123, vaddl_u8(vget_low_u8(va2x0_15), vget_high_u8(va2x0_15)));

    // row 3
    const uint8x16_t va3x0_15 = vld1q_u8(a3);
    a3 += 16;
    vacc3x0123 = vpadalq_u16(
        vacc3x0123, vaddl_u8(vget_low_u8(va3x0_15), vget_high_u8(va3x0_15)));
  }

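  // If at least 8 bytes remain: pairwise-add 8 bytes of each row into four
  // uint16_t values and widen-add them into the accumulators.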
  if (k >= 8) {
    vacc0x0123 = vaddw_u16(vacc0x0123, vpaddl_u8(vld1_u8(a0)));
    a0 += 8;
    vacc1x0123 = vaddw_u16(vacc1x0123, vpaddl_u8(vld1_u8(a1)));
    a1 += 8;
    vacc2x0123 = vaddw_u16(vacc2x0123, vpaddl_u8(vld1_u8(a2)));
    a2 += 8;
    vacc3x0123 = vaddw_u16(vacc3x0123, vpaddl_u8(vld1_u8(a3)));
    a3 += 8;
    k -= 8;
  }

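  // If at least 4 bytes remain: load 4 bytes of each row (unaligned load of
  // one uint32_t), widen them to uint16_t, and add them into the accumulators.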
  if (k >= 4) {
    vacc0x0123 = vaddw_u16(
        vacc0x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a0, 1))))));
    a0 += 4;
    vacc1x0123 = vaddw_u16(
        vacc1x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a1, 1))))));
    a1 += 4;
    vacc2x0123 = vaddw_u16(
        vacc2x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a2, 1))))));
    a2 += 4;
    vacc3x0123 = vaddw_u16(
        vacc3x0123,
        vget_low_u16(vmovl_u8(vreinterpret_u8_u32(
            vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a3, 1))))));
    a3 += 4;
    k -= 4;
  }

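  // Reduce each row's four partial sums to a single sum and gather the four
  // row sums into one vector:
  // vacc0123 = [sum(row 0), sum(row 1), sum(row 2), sum(row 3)].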
  const uint32x2_t vsum0x01 =
      vpadd_u32(vget_low_u32(vacc0x0123), vget_high_u32(vacc0x0123));
  const uint32x2_t vsum1x01 =
      vpadd_u32(vget_low_u32(vacc1x0123), vget_high_u32(vacc1x0123));
  const uint32x2_t vsum2x01 =
      vpadd_u32(vget_low_u32(vacc2x0123), vget_high_u32(vacc2x0123));
  const uint32x2_t vsum3x01 =
      vpadd_u32(vget_low_u32(vacc3x0123), vget_high_u32(vacc3x0123));
  uint32x4_t vacc0123 = vcombine_u32(
      vpadd_u32(vsum0x01, vsum1x01), vpadd_u32(vsum2x01, vsum3x01));

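  // If at least 2 bytes remain: load the remaining byte pair of each row,
  // gather the four pairs into one vector ordered row 0..row 3, then
  // pairwise-add and widen-add one partial sum per row into vacc0123.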
  if (k >= 2) {
    const uint8x8_t va0x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1)));
    a0 += 2;
    const uint8x8_t va1x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1)));
    a1 += 2;
    const uint8x8_t va2x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1)));
    a2 += 2;
    const uint8x8_t va3x01010101 = vreinterpret_u8_u16(
        vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1)));
    a3 += 2;
    const uint8x8_t va0x01_1x010101 = vext_u8(va0x01010101, va1x01010101, 2);
    const uint8x8_t va2x01_3x010101 = vext_u8(va2x01010101, va3x01010101, 6);
    const uint8x8_t va0123x01 = vext_u8(va0x01_1x010101, va2x01_3x010101, 4);
    vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(va0123x01));
    k -= 2;
  }

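  // If one byte remains: load it from each row into an even lane of a zeroed
  // vector so the pairwise widening add contributes exactly that byte to each
  // row's sum.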
  if (k > 0) {
    uint8x8_t vax0x1x2x3 = vmov_n_u8(0);
    vax0x1x2x3 = vld1_lane_u8(a0, vax0x1x2x3, 0);
    vax0x1x2x3 = vld1_lane_u8(a1, vax0x1x2x3, 2);
    vax0x1x2x3 = vld1_lane_u8(a2, vax0x1x2x3, 4);
    vax0x1x2x3 = vld1_lane_u8(a3, vax0x1x2x3, 6);
    vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(vax0x1x2x3));
  }

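  // Scale the four row sums by `multiplier` and store the first m results to
  // a_sum.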
  int32x4_t vsum0123 = vmulq_n_s32(vreinterpretq_s32_u32(vacc0123), multiplier);
  if (m == 4) {
    vst1q_s32(a_sum, vsum0123);
  } else {
    if (m >= 2) {
      vst1_s32(a_sum, vget_low_s32(vsum0123));
      a_sum += 2;
      vsum0123 = vextq_s32(vsum0123, vsum0123, 2);
      m -= 2;
    }
    if (m != 0) {
      vst1q_lane_s32(a_sum, vsum0123, 0);
    }
  }
}
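
/*
 * For reference only (not part of the upstream file): a scalar sketch of the
 * computation performed by the kernel above, assuming the same arguments.
 * The helper name is hypothetical.
 *
 *   void q8sumrows_scalar_reference(
 *       const uint8_t* a, size_t m, size_t k, size_t stride,
 *       int32_t multiplier, int32_t* a_sum) {
 *     for (size_t i = 0; i < m; i++) {
 *       int32_t sum = 0;
 *       for (size_t j = 0; j < k; j++) {
 *         sum += (int32_t)a[i * stride + j];  // accumulate one row
 *       }
 *       a_sum[i] = sum * multiplier;  // same scaling as vmulq_n_s32 above
 *     }
 *   }
 */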