"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/2x4c8-sse2.c" (23 Jul 2021, 8597 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "2x4c8-sse2.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8gemm.h>
#include <requantization/runtime-sse2.h>

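/*
 * Horizontally reduce four vectors of four int32 lanes each: output lane i
 * holds the sum of the four lanes of the i-th input, i.e. the result is
 * ( sum(x), sum(y), sum(z), sum(w) ) from lane 0 to lane 3.
 */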
static inline __m128i pytorch_sse_reduce4_i32(
    __m128i x,
    __m128i y,
    __m128i z,
    __m128i w) {
#if defined(__SSSE3__) && !defined(__ANDROID__)
  /* xxyy = ( y2 + y3, y0 + y1, x2 + x3, x0 + x1 ) */
  const __m128i xxyy = _mm_hadd_epi32(x, y);
  /* zzww = ( w2 + w3, w0 + w1, z2 + z3, z0 + z1 ) */
  const __m128i zzww = _mm_hadd_epi32(z, w);
  /* xyzw = ( w0 + w1 + w2 + w3, z0 + z1 + z2 + z3, y0 + y1 + y2 + y3,
   * x0 + x1 + x2 + x3 ) */
  return _mm_hadd_epi32(xxyy, zzww);
#else
  /* xzxz = ( z1 + z3, x1 + x3, z0 + z2, x0 + x2 ) */
  const __m128i xzxz =
      _mm_add_epi32(_mm_unpacklo_epi32(x, z), _mm_unpackhi_epi32(x, z));
  /* ywyw = ( w1 + w3, y1 + y3, w0 + w2, y0 + y2 ) */
  const __m128i ywyw =
      _mm_add_epi32(_mm_unpacklo_epi32(y, w), _mm_unpackhi_epi32(y, w));
  /* xyzw = ( w0 + w2 + w1 + w3, z0 + z2 + z1 + z3, y0 + y2 + y1 + y3,
   * x0 + x2 + x1 + x3 ) */
  return _mm_add_epi32(
      _mm_unpacklo_epi32(xzxz, ywyw), _mm_unpackhi_epi32(xzxz, ywyw));
#endif
}
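
/*
 * For reference only: a minimal scalar sketch of the reduction above (a
 * hypothetical helper added for exposition, not part of QNNPACK). It makes
 * the lane ordering of the result explicit.
 */
static inline __m128i pytorch_sse_reduce4_i32_scalar_sketch(
    __m128i x,
    __m128i y,
    __m128i z,
    __m128i w) {
  int32_t lanes[4][4];
  _mm_storeu_si128((__m128i*)lanes[0], x);
  _mm_storeu_si128((__m128i*)lanes[1], y);
  _mm_storeu_si128((__m128i*)lanes[2], z);
  _mm_storeu_si128((__m128i*)lanes[3], w);
  int32_t sums[4];
  for (int i = 0; i < 4; i++) {
    /* Output lane i is the horizontal sum of input vector i. */
    sums[i] = lanes[i][0] + lanes[i][1] + lanes[i][2] + lanes[i][3];
  }
  return _mm_loadu_si128((const __m128i*)sums);
}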

/*
 * Q8 GEMM microkernel computing a 2x4 output tile: mr (<= 2) rows of the
 * quantized input a against nr (<= 4) columns of the packed weights w,
 * accumulating over k in groups of 8 ("c8"). w holds four int32 bias values
 * followed by the packed 8-bit weights.
 */
void pytorch_q8gemm_ukernel_2x4c8__sse2(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
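  /*
   * Seed every accumulator with the corresponding output channel's int32
   * bias from the head of w; both rows of the tile share the same biases.
   */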
  __m128i vacc00 = _mm_cvtsi32_si128((int)((const int32_t*)w)[0]);
  __m128i vacc01 = _mm_cvtsi32_si128((int)((const int32_t*)w)[1]);
  __m128i vacc02 = _mm_cvtsi32_si128((int)((const int32_t*)w)[2]);
  __m128i vacc03 = _mm_cvtsi32_si128((int)((const int32_t*)w)[3]);
  __m128i vacc10 = vacc00;
  __m128i vacc11 = vacc01;
  __m128i vacc12 = vacc02;
  __m128i vacc13 = vacc03;
  w = (const void*)((uintptr_t)w + 16);

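  /*
   * Row and column pointer setup. For partial tiles (mr < 2 or nr < 4) the
   * unused pointers alias the previous valid one, so the loads below never
   * run out of bounds; the duplicated results are simply not stored.
   */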
  const uint8_t* a0 = a;
  const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride);
  if (mr != 2) {
    a1 = a0;
  }

  const uint8_t* b0 = w;
  const uint8_t* b1 = b0 + 8;
  if (nr < 2) {
    b1 = b0;
  }
  const uint8_t* b2 = b1 + 8;
  if (nr <= 2) {
    b2 = b1;
  }
  const uint8_t* b3 = b2 + 8;
  if (nr != 4) {
    b3 = b2;
  }
  const size_t b_stride = nr * 8;

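  /*
   * Broadcast the input zero point and the four per-output-channel kernel
   * zero points. Below, the 8-bit values are widened to 16 bits and their
   * zero points subtracted before the multiply-accumulate.
   */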
  const __m128i va_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.input_zero_point);
  const __m128i vb_zero_point_0 = _mm_set1_epi16(
      (int16_t)(uint16_t)quantization_params->sse2.kernel_zero_points[
        output_channel_index]);
  // Assumes the kernel_zero_points allocation is padded to a multiple of
  // nr = 4 entries, so reading all four indices here is always in bounds.
  const __m128i vb_zero_point_1 = _mm_set1_epi16(
      (int16_t)(uint16_t)quantization_params->sse2.kernel_zero_points[
        output_channel_index + 1]);
  const __m128i vb_zero_point_2 = _mm_set1_epi16(
      (int16_t)(uint16_t)quantization_params->sse2.kernel_zero_points[
        output_channel_index + 2]);
  const __m128i vb_zero_point_3 = _mm_set1_epi16(
      (int16_t)(uint16_t)quantization_params->sse2.kernel_zero_points[
        output_channel_index + 3]);
  const __m128i vzero = _mm_setzero_si128();
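  /*
   * Main loop: consume 8 elements of k per iteration. _mm_madd_epi16 forms
   * eight 32-bit products of the 16-bit inputs and adds adjacent pairs,
   * leaving four int32 partial sums per accumulator; the cross-lane
   * reduction happens once, after the loops.
   */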
  for (; k >= 8; k -= 8) {
    const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0);
    const __m128i vxa0 =
        sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
    a0 += 8;
    const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1);
    const __m128i vxa1 =
        sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
    a1 += 8;

    const __m128i vb0 = _mm_loadl_epi64((const __m128i*)b0);
    const __m128i vxb0 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point_0);
    b0 += b_stride;
    const __m128i vb1 = _mm_loadl_epi64((const __m128i*)b1);
    const __m128i vxb1 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point_1);
    b1 += b_stride;
    const __m128i vb2 = _mm_loadl_epi64((const __m128i*)b2);
    const __m128i vxb2 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point_2);
    b2 += b_stride;
    const __m128i vb3 = _mm_loadl_epi64((const __m128i*)b3);
    const __m128i vxb3 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point_3);
    b3 += b_stride;

    vacc00 = _mm_add_epi32(vacc00, _mm_madd_epi16(vxa0, vxb0));
    vacc01 = _mm_add_epi32(vacc01, _mm_madd_epi16(vxa0, vxb1));
    vacc02 = _mm_add_epi32(vacc02, _mm_madd_epi16(vxa0, vxb2));
    vacc03 = _mm_add_epi32(vacc03, _mm_madd_epi16(vxa0, vxb3));
    vacc10 = _mm_add_epi32(vacc10, _mm_madd_epi16(vxa1, vxb0));
    vacc11 = _mm_add_epi32(vacc11, _mm_madd_epi16(vxa1, vxb1));
    vacc12 = _mm_add_epi32(vacc12, _mm_madd_epi16(vxa1, vxb2));
    vacc13 = _mm_add_epi32(vacc13, _mm_madd_epi16(vxa1, vxb3));
  }
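  /*
   * Remainder: handle the final 1-7 elements of k. The 8-byte load ends at
   * the current read position and is shifted right so the k valid bytes land
   * in the low lanes; the zero point is shifted the same way so the
   * zero-filled high lanes subtract to exactly zero.
   */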
  if (k != 0) {
    const size_t a_predecrement = 8 - k;
    const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement);

    const __m128i va_zero_point_partial = _mm_unpacklo_epi8(
        _mm_srl_epi64(_mm_packus_epi16(va_zero_point, va_zero_point), va_shift),
        vzero);

    const __m128i va0 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift);
    const __m128i vxa0 =
        sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point_partial);
    const __m128i va1 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift);
    const __m128i vxa1 =
        sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point_partial);

    const __m128i vb0 = _mm_loadl_epi64((const __m128i*)b0);
    const __m128i vxb0 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point_0);
    const __m128i vb1 = _mm_loadl_epi64((const __m128i*)b1);
    const __m128i vxb1 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point_1);
    const __m128i vb2 = _mm_loadl_epi64((const __m128i*)b2);
    const __m128i vxb2 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point_2);
    const __m128i vb3 = _mm_loadl_epi64((const __m128i*)b3);
    const __m128i vxb3 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point_3);

    vacc00 = _mm_add_epi32(vacc00, _mm_madd_epi16(vxa0, vxb0));
    vacc01 = _mm_add_epi32(vacc01, _mm_madd_epi16(vxa0, vxb1));
    vacc02 = _mm_add_epi32(vacc02, _mm_madd_epi16(vxa0, vxb2));
    vacc03 = _mm_add_epi32(vacc03, _mm_madd_epi16(vxa0, vxb3));
    vacc10 = _mm_add_epi32(vacc10, _mm_madd_epi16(vxa1, vxb0));
    vacc11 = _mm_add_epi32(vacc11, _mm_madd_epi16(vxa1, vxb1));
    vacc12 = _mm_add_epi32(vacc12, _mm_madd_epi16(vxa1, vxb2));
    vacc13 = _mm_add_epi32(vacc13, _mm_madd_epi16(vxa1, vxb3));
  }

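  /*
   * Each accumulator still holds four partial sums for a single
   * (row, column) output; reduce across lanes so each row's four outputs
   * share one vector.
   */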
  __m128i vacc0x0123 = pytorch_sse_reduce4_i32(vacc00, vacc01, vacc02, vacc03);
  __m128i vacc1x0123 = pytorch_sse_reduce4_i32(vacc10, vacc11, vacc12, vacc13);

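  /*
   * Requantize: scale the int32 accumulators by the per-channel float
   * multipliers, then round back to int32 (_mm_cvtps_epi32 rounds to
   * nearest-even under the default MXCSR rounding mode).
   */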
  const __m128 vmultiplier = _mm_loadu_ps(
      &quantization_params->sse2.requantization_scales[output_channel_index]);

  vacc0x0123 =
      _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc0x0123), vmultiplier));
  vacc1x0123 =
      _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc1x0123), vmultiplier));

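  /*
   * Add the output zero point with signed 16-bit saturation, pack to uint8
   * with unsigned saturation, and clamp to the configured
   * [output_min, output_max] activation range.
   */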
  const __m128i voutput_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.output_zero_point);
  const __m128i vacc01x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
  __m128i vout = _mm_packus_epi16(vacc01x0123, vacc01x0123);
  vout = _mm_min_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
  vout = _mm_max_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

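  /*
   * Store nr bytes per row: row 0 sits in byte lanes 0-3 of vout and row 1
   * in lanes 4-7. Full tiles are stored as one 32-bit write per row; partial
   * tiles fall back to 2-byte and then 1-byte writes.
   */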
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr != 2) {
    c1 = c0;
  }
  if (nr == 4) {
    *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout);
    *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
  } else {
    if (nr >= 2) {
      *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0);
      c0 += 2;
      *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2);
      c1 += 2;
      vout = _mm_srli_epi32(vout, 16);
      nr -= 2;
    }
    if (nr != 0) {
      *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout);
      *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2);
    }
  }
}
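
/*
 * Sketch of the packed weight layout this kernel assumes, inferred from the
 * pointer arithmetic above (not an authoritative spec of the QNNPACK packing
 * routines): four int32 bias slots, then for each group of 8 k-elements one
 * contiguous 8-byte run per column, i.e. roughly
 *
 *   int32_t bias[4];
 *   uint8_t b[ceil(k / 8)][nr][8];   // b_stride = nr * 8 bytes per k-group
 */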