"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-dq-sse2.c" (23 Jul 2021, 11857 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "4x4c2-dq-sse2.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8gemm.h>
#include <requantization/runtime-sse2.h>

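/*
 * 4x4 GEMM microkernel with dynamic ("dq") requantization, SSE2 variant.
 *
 * Computes an mr x nr tile (mr, nr <= 4) of C = A * B: A holds quantized
 * uint8 activations with a single zero point, w points at packed quantized
 * uint8 weights with per-output-channel zero points, and the int32
 * accumulators are dequantized to float with per-channel multipliers before
 * the float bias b is added. The packed B layout groups two consecutive k
 * values per output channel (the "c2" in the kernel name), which matches
 * the 16-bit pair handling of _mm_madd_epi16 below.
 */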
void pytorch_q8gemm_dq_ukernel_4x4c2__sse2(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    const float* restrict b,
    float* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const struct pytorch_qnnp_conv_dynamic_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
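  /*
   * Zero the four int32 accumulator rows and skip the first 16 bytes of the
   * packed weights; that slot presumably carries the packed per-channel bias
   * words used by the non-dynamic kernels, whereas this variant takes its
   * bias from the separate float array b.
   */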
  __m128i vacc0x0123 = _mm_setzero_si128();
  __m128i vacc1x0123 = _mm_setzero_si128();
  __m128i vacc2x0123 = _mm_setzero_si128();
  __m128i vacc3x0123 = _mm_setzero_si128();
  w = (const void*)((uintptr_t)w + 16);

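  /*
   * Set up one A row pointer per accumulator row. Rows beyond mr alias the
   * previous row, so all four rows can be processed unconditionally without
   * reading outside the caller's buffers.
   */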
  const uint8_t* a0 = a;
  const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride);
  if (mr < 2) {
    a1 = a0;
  }
  const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride);
  if (mr <= 2) {
    a2 = a1;
  }
  const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride);
  if (mr != 4) {
    a3 = a2;
  }

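  /*
   * Broadcast the activation zero point and fetch the per-output-channel
   * kernel zero points for the four columns covered by this tile.
   */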
  const __m128i va_zero_point = _mm_set1_epi16(quantization_params->input_zero_point);
  const int16_t vb_zero_point_0 =
    (int16_t)(uint16_t)quantization_params->kernel_zero_points[
    output_channel_index];
  const int16_t vb_zero_point_1 =
      (int16_t)(uint16_t)quantization_params->kernel_zero_points[
        output_channel_index + 1];
  const int16_t vb_zero_point_2 =
      (int16_t)(uint16_t)quantization_params->kernel_zero_points[
        output_channel_index + 2];
  const int16_t vb_zero_point_3 =
      (int16_t)(uint16_t)quantization_params->kernel_zero_points[
        output_channel_index + 3];

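  /*
   * Duplicate each channel's zero point into a pair of adjacent 16-bit
   * lanes, matching the packed B layout of two consecutive k values per
   * output channel.
   */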
  __m128i vb_zero_point = _mm_set_epi16(vb_zero_point_3,
                                        vb_zero_point_3,
                                        vb_zero_point_2,
                                        vb_zero_point_2,
                                        vb_zero_point_1,
                                        vb_zero_point_1,
                                        vb_zero_point_0,
                                        vb_zero_point_0
                                        );
  const __m128 vmultiplier =
      _mm_loadu_ps(&quantization_params->multipliers[output_channel_index]);

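  /*
   * Float bias for the four output channels; _mm_load_ps is an aligned load,
   * so b is expected to be 16-byte aligned.
   */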
  const __m128 vbias = _mm_load_ps(b);

  const __m128i vzero = _mm_setzero_si128();
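  /*
   * Main loop: consume 8 k values per iteration. Each A row contributes
   * 8 bytes, widened to int16 and zero-point adjusted via sub_zero_point;
   * B contributes 32 packed bytes in four groups of 2 k values x 4 channels.
   * _mm_madd_epi16 multiplies adjacent int16 pairs and accumulates into the
   * int32 rows.
   */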
  for (; k >= 8; k -= 8) {
    const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0);
    const __m128i vxa0 =
        sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
    a0 += 8;
    const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1);
    const __m128i vxa1 =
        sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
    a1 += 8;
    const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2);
    const __m128i vxa2 =
        sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
    a2 += 8;
    const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3);
    const __m128i vxa3 =
        sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);
    a3 += 8;

    const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
    const __m128i vxb0 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

    const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
    const __m128i vxb1 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

    const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
    const __m128i vxb2 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

    const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
    const __m128i vxb3 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
    w = (const void*)((uintptr_t)w + 32);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
  }
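  /*
   * Remainder: 1 to 7 leftover k values. Each A row is loaded with a small
   * pre-decrement so the 8-byte load ends at the last remaining byte, and a
   * right shift discards the extra low bytes; only the B groups actually
   * covered by the remaining k values are processed.
   */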
  if (k != 0) {
    const size_t a_predecrement = 8 - k;
    const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement);

    const __m128i va0 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift);
    const __m128i vxa0 =
        sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
    const __m128i va1 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift);
    const __m128i vxa1 =
        sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
    const __m128i va2 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift);
    const __m128i vxa2 =
        sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
    const __m128i va3 = _mm_srl_epi64(
        _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift);
    const __m128i vxa3 =
        sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);

    const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
    const __m128i vxb0 =
        _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);

    vacc0x0123 = _mm_add_epi32(
        vacc0x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc1x0123 = _mm_add_epi32(
        vacc1x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc2x0123 = _mm_add_epi32(
        vacc2x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
    vacc3x0123 = _mm_add_epi32(
        vacc3x0123,
        _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

    if (k > 2) {
      const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
      const __m128i vxb1 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);

      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

      if (k > 4) {
        const __m128i vb2 =
            _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
        const __m128i vxb2 =
            _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);

        vacc0x0123 = _mm_add_epi32(
            vacc0x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        vacc1x0123 = _mm_add_epi32(
            vacc1x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        vacc2x0123 = _mm_add_epi32(
            vacc2x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        vacc3x0123 = _mm_add_epi32(
            vacc3x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

        if (k > 6) {
          const __m128i vb3 =
              _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
          const __m128i vxb3 =
              _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);

          vacc0x0123 = _mm_add_epi32(
              vacc0x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          vacc1x0123 = _mm_add_epi32(
              vacc1x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          vacc2x0123 = _mm_add_epi32(
              vacc2x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          vacc3x0123 = _mm_add_epi32(
              vacc3x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
        }
      }
    }
  }

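  /*
   * Dequantize: convert the int32 accumulators to float, scale by the
   * per-channel multipliers, and add the float bias.
   */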
  __m128 vout0 = _mm_mul_ps(vmultiplier, _mm_cvtepi32_ps(vacc0x0123));
  __m128 vout1 = _mm_mul_ps(vmultiplier, _mm_cvtepi32_ps(vacc1x0123));
  __m128 vout2 = _mm_mul_ps(vmultiplier, _mm_cvtepi32_ps(vacc2x0123));
  __m128 vout3 = _mm_mul_ps(vmultiplier, _mm_cvtepi32_ps(vacc3x0123));

  vout0 = _mm_add_ps(vout0, vbias);
  vout1 = _mm_add_ps(vout1, vbias);
  vout2 = _mm_add_ps(vout2, vbias);
  vout3 = _mm_add_ps(vout3, vbias);

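  /*
   * Output row pointers mirror the A row setup: rows beyond mr alias the
   * previous row and simply rewrite the same values.
   */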
  float* c0 = c;
  float* c1 = c0 + c_stride;
  if (mr < 2) {
    c1 = c0;
  }
  float* c2 = c1 + c_stride;
  if (mr <= 2) {
    c2 = c1;
  }
  float* c3 = c2 + c_stride;
  if (mr != 4) {
    c3 = c2;
  }

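  /*
   * Store the tile: a full 4-wide unaligned store per row when nr == 4,
   * otherwise two elements followed by an optional scalar tail.
   */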
  if (nr == 4) {
    _mm_storeu_ps(c0, vout0);
    _mm_storeu_ps(c1, vout1);
    _mm_storeu_ps(c2, vout2);
    _mm_storeu_ps(c3, vout3);
  } else {
    if (nr >= 2) {
      _mm_storel_pi((__m64*)c0, vout0);
      _mm_storel_pi((__m64*)c1, vout1);
      _mm_storel_pi((__m64*)c2, vout2);
      _mm_storel_pi((__m64*)c3, vout3);

      c0 += 2;
      vout0 = _mm_shuffle_ps(vout0, vout0, _MM_SHUFFLE(2, 2, 2, 2));
      c1 += 2;
      vout1 = _mm_shuffle_ps(vout1, vout1, _MM_SHUFFLE(2, 2, 2, 2));
      c2 += 2;
      vout2 = _mm_shuffle_ps(vout2, vout2, _MM_SHUFFLE(2, 2, 2, 2));
      c3 += 2;
      vout3 = _mm_shuffle_ps(vout3, vout3, _MM_SHUFFLE(2, 2, 2, 2));

      nr -= 2;
    }
    if (nr != 0) {
      *c0 = _mm_cvtss_f32(vout0);
      *c1 = _mm_cvtss_f32(vout1);
      *c2 = _mm_cvtss_f32(vout2);
      *c3 = _mm_cvtss_f32(vout3);
    }
  }
}