"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x4c2-sse2.c" (23 Jul 2021, 13518 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "4x4c2-sse2.c" see the Fossies "Dox" file reference documentation.

    1 /*
    2  * Copyright (c) Facebook, Inc. and its affiliates.
    3  * All rights reserved.
    4  *
    5  * This source code is licensed under the BSD-style license found in the
    6  * LICENSE file in the root directory of this source tree.
    7  */
    8 
    9 #include <immintrin.h>
   10 
   11 #include <qnnpack/q8conv.h>
   12 #include <requantization/runtime-sse2.h>
   13 
/*
 * SSE2 4x4c2 micro-kernel for quantized (uint8) convolution.
 *
 * Computes a 4-row by 4-output-channel tile of the convolution as a small
 * GEMM over an indirection buffer:
 *   mr - number of valid output rows in this tile (1..4)
 *   nr - number of valid output channels in this tile (1..4)
 *   kc - number of input channels (reduction length per kernel element)
 *   ks - number of kernel elements (the outer do/while iterates once per
 *        kernel spatial position, consuming 4 row pointers from `a` each time)
 *   a  - indirection buffer: ks groups of 4 pointers to input rows
 *   w  - packed weights; the first 16 bytes are four int32 accumulator
 *        initializers (presumably the packed per-channel bias -- standard
 *        QNNPACK packing; confirm against the packing routine), followed by
 *        interleaved uint8 kernel data
 *   c  - output pointer, rows separated by c_stride bytes
 *   output_channel_index - base channel for per-channel quantization params
 *
 * "c2" refers to processing input channels two at a time: each
 * _mm_madd_epi16 multiplies 2 adjacent (channel) int16 products and sums
 * them into one int32 lane, so one 8-channel chunk takes 4 madd rounds.
 */
void pytorch_q8conv_ukernel_4x4c2__sse2(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const uint8_t** restrict a,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
  /* Initialize all 4 row accumulators from the first 16 bytes of w
   * (four int32 values, shared across rows). */
  __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*)w);
  __m128i vacc1x0123 = vacc0x0123;
  __m128i vacc2x0123 = vacc0x0123;
  __m128i vacc3x0123 = vacc0x0123;
  w = (const void*)((uintptr_t)w + 16);

  const __m128i va_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.input_zero_point);
  /* Per-output-channel kernel zero points for the 4 channels of this tile. */
  const int16_t vb_zero_point_0 =
    quantization_params->sse2.kernel_zero_points[output_channel_index];
  const int16_t vb_zero_point_1 =
      quantization_params->sse2.kernel_zero_points[output_channel_index + 1];
  const int16_t vb_zero_point_2 =
      quantization_params->sse2.kernel_zero_points[output_channel_index + 2];
  const int16_t vb_zero_point_3 =
      quantization_params->sse2.kernel_zero_points[output_channel_index + 3];

  /* Each zero point is duplicated into adjacent int16 lanes to match the
   * channel-pair layout consumed by _mm_madd_epi16 (lanes: 0,0,1,1,2,2,3,3). */
  const __m128i vb_zero_point = _mm_set_epi16(vb_zero_point_3,
                                              vb_zero_point_3,
                                              vb_zero_point_2,
                                              vb_zero_point_2,
                                              vb_zero_point_1,
                                              vb_zero_point_1,
                                              vb_zero_point_0,
                                              vb_zero_point_0
                                              );
  const __m128i vzero = _mm_setzero_si128();
  /* Outer loop: one iteration per kernel element; pulls the next 4 input-row
   * pointers from the indirection buffer. */
  do {
    const uint8_t* restrict a0 = *a++;
    const uint8_t* restrict a1 = *a++;
    const uint8_t* restrict a2 = *a++;
    const uint8_t* restrict a3 = *a++;

    size_t k = kc;
    /* Main loop: 8 input channels per iteration. Each row's 8 uint8 values
     * are widened to int16 and the input zero point is subtracted. */
    for (; k >= 8; k -= 8) {
      const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0);
      const __m128i vxa0 =
          sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
      a0 += 8;
      const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1);
      const __m128i vxa1 =
          sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
      a1 += 8;
      const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2);
      const __m128i vxa2 =
          sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
      a2 += 8;
      const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3);
      const __m128i vxa3 =
          sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);
      a3 += 8;

      /* Round 0: weights for input-channel pair 0 of all 4 output channels.
       * _mm_shuffle_epi32(vxa, 0) broadcasts the row's channel pair 0 (two
       * int16 lanes) across the register; madd then produces one int32 dot
       * product per output channel. */
      const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
      const __m128i vxb0 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

      /* Round 1: input-channel pair 1. */
      const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
      const __m128i vxb1 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

      /* Round 2: input-channel pair 2. */
      const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
      const __m128i vxb2 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

      /* Round 3: input-channel pair 3. */
      const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
      const __m128i vxb3 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));

      /* 4 rounds x 8 bytes of packed weights consumed. */
      w = (void*)((uintptr_t)w + 32);
    }
    /* Remainder: 1..7 trailing input channels. Load the last 8 bytes ending
     * at the current position (a - predecrement) and shift right so the
     * valid bytes land in the low lanes; the out-of-range high lanes are
     * then zero. NOTE(review): this reads up to 7 bytes *before* the current
     * pointer, which QNNPACK's buffer layout makes safe -- do not "fix". */
    if (k != 0) {
      const size_t a_predecrement = 8 - k;
      const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement);

      const __m128i va0 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift);
      const __m128i vxa0 =
          sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
      const __m128i va1 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift);
      const __m128i vxa1 =
          sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
      const __m128i va2 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift);
      const __m128i vxa2 =
          sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
      const __m128i va3 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift);
      const __m128i vxa3 =
          sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);

      /* At least one channel pair remains (k >= 1). Weight rows are still
       * consumed 8 bytes at a time; w advances only for pairs actually
       * processed. */
      const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
      const __m128i vxb0 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
      w = (void*)((uintptr_t)w + 8);

      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

      /* Second channel pair present only when k > 2. */
      if (k > 2) {
        const __m128i vb1 = _mm_loadl_epi64((const __m128i*)w);
        const __m128i vxb1 =
            _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
        w = (void*)((uintptr_t)w + 8);

        vacc0x0123 = _mm_add_epi32(
            vacc0x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
        vacc1x0123 = _mm_add_epi32(
            vacc1x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
        vacc2x0123 = _mm_add_epi32(
            vacc2x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
        vacc3x0123 = _mm_add_epi32(
            vacc3x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

        /* Third channel pair present only when k > 4. */
        if (k > 4) {
          const __m128i vb2 = _mm_loadl_epi64((const __m128i*)w);
          const __m128i vxb2 =
              _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
          w = (void*)((uintptr_t)w + 8);

          vacc0x0123 = _mm_add_epi32(
              vacc0x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
          vacc1x0123 = _mm_add_epi32(
              vacc1x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
          vacc2x0123 = _mm_add_epi32(
              vacc2x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
          vacc3x0123 = _mm_add_epi32(
              vacc3x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

          /* Fourth channel pair present only when k > 6. */
          if (k > 6) {
            const __m128i vb3 = _mm_loadl_epi64((const __m128i*)w);
            const __m128i vxb3 =
                _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
            w = (void*)((uintptr_t)w + 8);

            vacc0x0123 = _mm_add_epi32(
                vacc0x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
            vacc1x0123 = _mm_add_epi32(
                vacc1x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
            vacc2x0123 = _mm_add_epi32(
                vacc2x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
            vacc3x0123 = _mm_add_epi32(
                vacc3x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          }
        }
      }
    }
  } while (--ks != 0);

  /* Requantization: scale each int32 accumulator by the per-channel float
   * requantization scale, then round-convert back to int32
   * (_mm_cvtps_epi32 uses the current rounding mode, round-to-nearest-even
   * by default). */
  const __m128 vmultiplier =
      _mm_loadu_ps(&quantization_params->sse2.requantization_scales
          [output_channel_index]);

  vacc0x0123 = _mm_cvtps_epi32(
                _mm_mul_ps(
                  _mm_cvtepi32_ps(vacc0x0123),
                  vmultiplier
                  )
                );
  vacc1x0123 = _mm_cvtps_epi32(
                _mm_mul_ps(
                  _mm_cvtepi32_ps(vacc1x0123),
                  vmultiplier
                  )
                );
  vacc2x0123 = _mm_cvtps_epi32(
                _mm_mul_ps(
                  _mm_cvtepi32_ps(vacc2x0123),
                  vmultiplier
                  )
                );
  vacc3x0123 = _mm_cvtps_epi32(
                _mm_mul_ps(
                  _mm_cvtepi32_ps(vacc3x0123),
                  vmultiplier
                  )
                );

  /* Add the output zero point with int16 saturation, pack all 16 results
   * to uint8 with saturation, then clamp to [output_min, output_max]. */
  const __m128i voutput_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.output_zero_point);
  const __m128i vacc01x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
  const __m128i vacc23x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point);
  __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123);
  vout = _mm_min_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
  vout = _mm_max_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

  /* Row-pointer setup: when mr < 4, extra row pointers alias the last valid
   * row so the out-of-range stores below simply overwrite valid data. */
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr != 4) {
    c3 = c2;
  }
  /* vout layout: bytes [4r + n] = row r, channel n. Full tile stores 4 bytes
   * per row; the partial path stores 2 bytes then 1 byte per row as needed. */
  if (nr == 4) {
    *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout);
    *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
    *((uint32_t*)c2) =
        (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout));
    *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12));
  } else {
    if (nr >= 2) {
      *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0);
      c0 += 2;
      *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2);
      c1 += 2;
      *((uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4);
      c2 += 2;
      *((uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6);
      c3 += 2;
      /* Shift consumed channels out so the single-byte path below reads
       * the next channel from the same lane positions. */
      vout = _mm_srli_epi32(vout, 16);
      nr -= 2;
    }
    if (nr != 0) {
      *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout);
      *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2);
      *((uint8_t*)c2) = (uint8_t)_mm_extract_epi16(vout, 4);
      *((uint8_t*)c3) = (uint8_t)_mm_extract_epi16(vout, 6);
    }
  }
}