"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c" (23 Jul 2021, 12958 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "4x4c2-sse2.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8gemm.h>
#include <requantization/runtime-sse2.h>
void pytorch_q8gemm_ukernel_4x4c2__sse2(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
  /* The packed weights begin with four int32 bias values; seed the
   * accumulators of all four rows with them. */
  __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*)w);
  __m128i vacc1x0123 = vacc0x0123;
  __m128i vacc2x0123 = vacc0x0123;
  __m128i vacc3x0123 = vacc0x0123;
  w = (const void*)((uintptr_t)w + 16);

  /* Input row pointers. For mr < 4, clamp the unused rows onto the previous
   * row so loads stay in bounds; the store pointers are clamped identically
   * below, so duplicated rows just rewrite the same output bytes. */
  const uint8_t* a0 = a;
  const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride);
  if (mr < 2) {
    a1 = a0;
  }
  const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride);
  if (mr <= 2) {
    a2 = a1;
  }
  const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride);
  if (mr != 4) {
    a3 = a2;
  }

  /* Activation zero point (broadcast) and the per-channel kernel zero points
   * of the four output channels in this tile; each kernel zero point is
   * duplicated so it lines up with the paired-k weight layout. */
  const __m128i va_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.input_zero_point);
  const int16_t vb_zero_point_0 =
      (int16_t)(uint16_t)quantization_params->sse2
          .kernel_zero_points[output_channel_index];
  const int16_t vb_zero_point_1 =
      (int16_t)(uint16_t)quantization_params->sse2
          .kernel_zero_points[output_channel_index + 1];
  const int16_t vb_zero_point_2 =
      (int16_t)(uint16_t)quantization_params->sse2
          .kernel_zero_points[output_channel_index + 2];
  const int16_t vb_zero_point_3 =
      (int16_t)(uint16_t)quantization_params->sse2
          .kernel_zero_points[output_channel_index + 3];

  const __m128i vb_zero_point = _mm_set_epi16(
      vb_zero_point_3,
      vb_zero_point_3,
      vb_zero_point_2,
      vb_zero_point_2,
      vb_zero_point_1,
      vb_zero_point_1,
      vb_zero_point_0,
      vb_zero_point_0);
  const __m128i vzero = _mm_setzero_si128();
  /* Main loop: consume k in groups of 8. Each iteration widens 8 activation
   * bytes per row to int16 and subtracts the input zero point, then walks
   * four 8-byte groups of packed weights. Each group holds two consecutive
   * k-steps for all four output channels, so one PMADDWD per row accumulates
   * that k-pair into the four int32 channel lanes. */
  for (; k >= 8; k -= 8) {
    const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0);
    const __m128i vxa0 =
        sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
    a0 += 8;
    const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1);
    const __m128i vxa1 =
        sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
    a1 += 8;
    const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2);
    const __m128i vxa2 =
        sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
    a2 += 8;
    const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3);
    const __m128i vxa3 =
        sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);
    a3 += 8;

    /* Weights for k-steps 0 and 1 of this group, all four channels. */
    const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
    const __m128i vxb0 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

    /* Weights for k-steps 2 and 3. */
    const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
    const __m128i vxb1 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

    /* Weights for k-steps 4 and 5. */
    const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
    const __m128i vxb2 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

    /* Weights for k-steps 6 and 7; advance w past the 32 bytes consumed. */
    const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
    const __m128i vxb3 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
    w = (const void*)((uintptr_t)w + 32);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
  }
  /* Remainder: 1 <= k < 8 columns left. Load the 8 activation bytes ending
   * at the current position and shift right so the k valid bytes sit at the
   * bottom with zeros above, then apply only the k-pair steps that cover
   * the remaining columns. */
  if (k != 0) {
    const size_t a_predecrement = 8 - k;
    const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement);

    const __m128i va0 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift);
    const __m128i vxa0 =
        sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
    const __m128i va1 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift);
    const __m128i vxa1 =
        sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
    const __m128i va2 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift);
    const __m128i vxa2 =
        sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
    const __m128i va3 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift);
    const __m128i vxa3 =
        sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);

    const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
    const __m128i vxb0 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

    /* Second k-pair only needed when more than 2 columns remain. */
    if (k > 2) {
      const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
      const __m128i vxb1 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);

      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

      if (k > 4) {
        const __m128i vb2 =
            _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
        const __m128i vxb2 =
            _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);

        vacc0x0123 = _mm_add_epi32(
            vacc0x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        vacc1x0123 = _mm_add_epi32(
            vacc1x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        vacc2x0123 = _mm_add_epi32(
            vacc2x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        vacc3x0123 = _mm_add_epi32(
            vacc3x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

        if (k > 6) {
          const __m128i vb3 =
              _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
          const __m128i vxb3 =
              _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);

          vacc0x0123 = _mm_add_epi32(
              vacc0x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          vacc1x0123 = _mm_add_epi32(
              vacc1x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          vacc2x0123 = _mm_add_epi32(
              vacc2x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          vacc3x0123 = _mm_add_epi32(
              vacc3x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
        }
      }
    }
  }

  /* Requantize: scale each int32 accumulator by its channel's float scale,
   * then round back to int32 (CVTPS2DQ uses the current rounding mode,
   * round-to-nearest-even by default). */
  const __m128 vmultiplier = _mm_loadu_ps(
      &quantization_params->sse2.requantization_scales[output_channel_index]);

  vacc0x0123 = _mm_cvtps_epi32(
      _mm_mul_ps(_mm_cvtepi32_ps(vacc0x0123), vmultiplier));
  vacc1x0123 = _mm_cvtps_epi32(
      _mm_mul_ps(_mm_cvtepi32_ps(vacc1x0123), vmultiplier));
  vacc2x0123 = _mm_cvtps_epi32(
      _mm_mul_ps(_mm_cvtepi32_ps(vacc2x0123), vmultiplier));
  vacc3x0123 = _mm_cvtps_epi32(
      _mm_mul_ps(_mm_cvtepi32_ps(vacc3x0123), vmultiplier));

  /* Narrow to uint8: saturating pack to int16, add the output zero point
   * with saturation, pack to uint8, then clamp to [output_min, output_max]. */
  const __m128i voutput_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.output_zero_point);
  const __m128i vacc01x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
  const __m128i vacc23x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point);
  __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123);
  vout = _mm_min_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
  vout = _mm_max_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

  /* Output row pointers, clamped for mr < 4 exactly like the input rows, so
   * duplicated rows rewrite the same bytes instead of running past the
   * buffer. */
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr != 4) {
    c3 = c2;
  }
  /* vout holds the 4x4 tile as 16 bytes, one 4-byte row per 32-bit lane.
   * Full tiles store 4 bytes per row; narrower tiles store 2 bytes when
   * nr >= 2, then one final byte when nr is odd. */
  if (nr == 4) {
    *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout);
    *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
    *((uint32_t*)c2) =
        (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout));
    *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12));
  } else {
    if (nr >= 2) {
      *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0);
      c0 += 2;
      *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2);
      c1 += 2;
      *((uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4);
      c2 += 2;
      *((uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6);
      c3 += 2;
      vout = _mm_srli_epi32(vout, 16);
      nr -= 2;
    }
    if (nr != 0) {
      *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout);
      *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2);
      *((uint8_t*)c2) = (uint8_t)_mm_extract_epi16(vout, 4);
      *((uint8_t*)c3) = (uint8_t)_mm_extract_epi16(vout, 6);
    }
  }
}
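
For readers tracing the intrinsics above, here is a stand-alone scalar sketch of the arithmetic this micro-kernel implements: a quantized GEMM with per-channel requantization. It is an illustration written for this page, not part of QNNPACK; the function name ref_q8gemm_4x4 and its flat row-major arguments are hypothetical, and it takes unpacked weights rather than QNNPACK's 4x4c2-packed blob.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Hypothetical scalar reference. For m < mr and n < nr:
 *   C[m][n] = clamp(zp_out + round(scale[n] *
 *             (bias[n] + sum_i (A[m][i] - zp_a) * (B[i][n] - zp_b[n]))))
 * Weights are plain k x 4 row-major here, not 4x4c2-packed. */
static void ref_q8gemm_4x4(
    size_t mr, size_t nr, size_t k,
    const uint8_t* a, size_t a_stride, uint8_t zp_a,
    const uint8_t* b, const uint8_t* zp_b,
    const int32_t* bias, const float* scale,
    uint8_t zp_out, uint8_t out_min, uint8_t out_max,
    uint8_t* c, size_t c_stride) {
  for (size_t m = 0; m < mr; m++) {
    for (size_t n = 0; n < nr; n++) {
      int32_t acc = bias[n]; /* the packed w blob starts with 4 int32 biases */
      for (size_t i = 0; i < k; i++) {
        acc += ((int32_t)a[m * a_stride + i] - zp_a) *
               ((int32_t)b[i * 4 + n] - zp_b[n]);
      }
      /* lrintf matches CVTPS2DQ under the default round-to-nearest-even. */
      long q = lrintf((float)acc * scale[n]) + zp_out;
      if (q < out_min) q = out_min;
      if (q > out_max) q = out_max;
      c[m * c_stride + n] = (uint8_t)q;
    }
  }
}

int main(void) {
  /* 1x1 tile, k = 1: (2 - 1) * (5 - 3) = 2; round(2 * 0.5) + 10 = 11. */
  const uint8_t a[1] = {2}, b[4] = {5, 0, 0, 0}, zp_b[4] = {3, 0, 0, 0};
  const int32_t bias[4] = {0};
  const float scale[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  uint8_t out[4] = {0};
  ref_q8gemm_4x4(1, 1, 1, a, 1, 1, b, zp_b, bias, scale, 10, 0, 255, out, 4);
  printf("%d\n", out[0]); /* prints 11 */
  return 0;
}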