"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-psimd.c" (23 Jul 2021, 6296 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and a code folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "6x8-psimd.c", see the Fossies "Dox" file reference documentation.

    1 /*
    2  * Copyright (c) Facebook, Inc. and its affiliates.
    3  * All rights reserved.
    4  *
    5  * This source code is licensed under the BSD-style license found in the
    6  * LICENSE file in the root directory of this source tree.
    7  */
    8 
    9 #include <psimd.h>
   10 
   11 #include <qnnpack/sgemm.h>
   12 
   13 void pytorch_sgemm_ukernel_6x8__psimd(
   14     size_t mr,
   15     size_t nr,
   16     size_t k,
   17     const float* restrict a,
   18     size_t a_stride,
   19     const float* restrict w,
   20     float* restrict c,
   21     size_t c_stride,
   22     const struct pytorch_qnnp_fp32_clamping_params
   23         clamping_params[restrict static 1]) {
   24   psimd_f32 vacc0x0123 = psimd_load_f32(w);
   25   w += 4;
   26   psimd_f32 vacc0x4567 = psimd_load_f32(w);
   27   w += 4;
   28   psimd_f32 vacc1x0123 = vacc0x0123;
   29   psimd_f32 vacc1x4567 = vacc0x4567;
   30   psimd_f32 vacc2x0123 = vacc0x0123;
   31   psimd_f32 vacc2x4567 = vacc0x4567;
   32   psimd_f32 vacc3x0123 = vacc0x0123;
   33   psimd_f32 vacc3x4567 = vacc0x4567;
   34   psimd_f32 vacc4x0123 = vacc0x0123;
   35   psimd_f32 vacc4x4567 = vacc0x4567;
   36   psimd_f32 vacc5x0123 = vacc0x0123;
   37   psimd_f32 vacc5x4567 = vacc0x4567;
   38 
   39   const float* a0 = a;
   40   const float* a1 = (const float*)((uintptr_t)a0 + a_stride);
   41   if (mr < 2) {
   42     a1 = a0;
   43   }
   44   const float* a2 = (const float*)((uintptr_t)a1 + a_stride);
   45   if (mr <= 2) {
   46     a2 = a1;
   47   }
   48   const float* a3 = (const float*)((uintptr_t)a2 + a_stride);
   49   if (mr < 4) {
   50     a3 = a2;
   51   }
   52   const float* a4 = (const float*)((uintptr_t)a3 + a_stride);
   53   if (mr <= 4) {
   54     a4 = a3;
   55   }
   56   const float* a5 = (const float*)((uintptr_t)a4 + a_stride);
   57   if (mr != 6) {
   58     a5 = a4;
   59   }
   60 
   61   do {
   62     const psimd_f32 va0 = psimd_splat_f32(*a0);
   63     a0 += 1;
   64     const psimd_f32 va1 = psimd_splat_f32(*a1);
   65     a1 += 1;
   66     const psimd_f32 va2 = psimd_splat_f32(*a2);
   67     a2 += 1;
   68     const psimd_f32 va3 = psimd_splat_f32(*a3);
   69     a3 += 1;
   70     const psimd_f32 va4 = psimd_splat_f32(*a4);
   71     a4 += 1;
   72     const psimd_f32 va5 = psimd_splat_f32(*a5);
   73     a5 += 1;
   74 
   75     const psimd_f32 vb0123 = psimd_load_f32(w);
   76     w += 4;
   77     const psimd_f32 vb4567 = psimd_load_f32(w);
   78     w += 4;
   79 
   80     vacc0x0123 += vb0123 * va0;
   81     vacc0x4567 += vb4567 * va0;
   82     vacc1x0123 += vb0123 * va1;
   83     vacc1x4567 += vb4567 * va1;
   84     vacc2x0123 += vb0123 * va2;
   85     vacc2x4567 += vb4567 * va2;
   86     vacc3x0123 += vb0123 * va3;
   87     vacc3x4567 += vb4567 * va3;
   88     vacc4x0123 += vb0123 * va4;
   89     vacc4x4567 += vb4567 * va4;
   90     vacc5x0123 += vb0123 * va5;
   91     vacc5x4567 += vb4567 * va5;
   92   } while (--k != 0);
   93 
   94   const psimd_f32 vmax = psimd_splat_f32(clamping_params->max);
   95   vacc0x0123 = psimd_min_f32(vacc0x0123, vmax);
   96   vacc0x4567 = psimd_min_f32(vacc0x4567, vmax);
   97   vacc1x0123 = psimd_min_f32(vacc1x0123, vmax);
   98   vacc1x4567 = psimd_min_f32(vacc1x4567, vmax);
   99   vacc2x0123 = psimd_min_f32(vacc2x0123, vmax);
  100   vacc2x4567 = psimd_min_f32(vacc2x4567, vmax);
  101   vacc3x0123 = psimd_min_f32(vacc3x0123, vmax);
  102   vacc3x4567 = psimd_min_f32(vacc3x4567, vmax);
  103   vacc4x0123 = psimd_min_f32(vacc4x0123, vmax);
  104   vacc4x4567 = psimd_min_f32(vacc4x4567, vmax);
  105   vacc5x0123 = psimd_min_f32(vacc5x0123, vmax);
  106   vacc5x4567 = psimd_min_f32(vacc5x4567, vmax);
  107 
  108   const psimd_f32 vmin = psimd_splat_f32(clamping_params->min);
  109   vacc0x0123 = psimd_max_f32(vacc0x0123, vmin);
  110   vacc0x4567 = psimd_max_f32(vacc0x4567, vmin);
  111   vacc1x0123 = psimd_max_f32(vacc1x0123, vmin);
  112   vacc1x4567 = psimd_max_f32(vacc1x4567, vmin);
  113   vacc2x0123 = psimd_max_f32(vacc2x0123, vmin);
  114   vacc2x4567 = psimd_max_f32(vacc2x4567, vmin);
  115   vacc3x0123 = psimd_max_f32(vacc3x0123, vmin);
  116   vacc3x4567 = psimd_max_f32(vacc3x4567, vmin);
  117   vacc4x0123 = psimd_max_f32(vacc4x0123, vmin);
  118   vacc4x4567 = psimd_max_f32(vacc4x4567, vmin);
  119   vacc5x0123 = psimd_max_f32(vacc5x0123, vmin);
  120   vacc5x4567 = psimd_max_f32(vacc5x4567, vmin);
  121 
  122   float* c0 = c;
  123   float* c1 = (float*)((uintptr_t)c0 + c_stride);
  124   if (mr < 2) {
  125     c1 = c0;
  126   }
  127   float* c2 = (float*)((uintptr_t)c1 + c_stride);
  128   if (mr <= 2) {
  129     c2 = c1;
  130   }
  131   float* c3 = (float*)((uintptr_t)c2 + c_stride);
  132   if (mr < 4) {
  133     c3 = c2;
  134   }
  135   float* c4 = (float*)((uintptr_t)c3 + c_stride);
  136   if (mr <= 4) {
  137     c4 = c3;
  138   }
  139   float* c5 = (float*)((uintptr_t)c4 + c_stride);
  140   if (mr != 6) {
  141     c5 = c4;
  142   }
  143   if (nr == 8) {
  144     psimd_store_f32(c0, vacc0x0123);
  145     c0 += 4;
  146     psimd_store_f32(c1, vacc1x0123);
  147     c1 += 4;
  148     psimd_store_f32(c2, vacc2x0123);
  149     c2 += 4;
  150     psimd_store_f32(c3, vacc3x0123);
  151     c3 += 4;
  152     psimd_store_f32(c4, vacc4x0123);
  153     c4 += 4;
  154     psimd_store_f32(c5, vacc5x0123);
  155     c5 += 4;
  156 
  157     psimd_store_f32(c0, vacc0x4567);
  158     psimd_store_f32(c1, vacc1x4567);
  159     psimd_store_f32(c2, vacc2x4567);
  160     psimd_store_f32(c3, vacc3x4567);
  161     psimd_store_f32(c4, vacc4x4567);
  162     psimd_store_f32(c5, vacc5x4567);
  163   } else {
  164     if (nr >= 4) {
  165       psimd_store_f32(c0, vacc0x0123);
  166       c0 += 4;
  167       psimd_store_f32(c1, vacc1x0123);
  168       c1 += 4;
  169       psimd_store_f32(c2, vacc2x0123);
  170       c2 += 4;
  171       psimd_store_f32(c3, vacc3x0123);
  172       c3 += 4;
  173       psimd_store_f32(c4, vacc4x0123);
  174       c4 += 4;
  175       psimd_store_f32(c5, vacc5x0123);
  176       c5 += 4;
  177       vacc0x0123 = vacc0x4567;
  178       vacc1x0123 = vacc1x4567;
  179       vacc2x0123 = vacc2x4567;
  180       vacc3x0123 = vacc3x4567;
  181       vacc4x0123 = vacc4x4567;
  182       vacc5x0123 = vacc5x4567;
  183       nr -= 4;
  184     }
  185     if (nr >= 2) {
  186       psimd_store2_f32(c0, vacc0x0123);
  187       c0 += 2;
  188       psimd_store2_f32(c1, vacc1x0123);
  189       c1 += 2;
  190       psimd_store2_f32(c2, vacc2x0123);
  191       c2 += 2;
  192       psimd_store2_f32(c3, vacc3x0123);
  193       c3 += 2;
  194       psimd_store2_f32(c4, vacc4x0123);
  195       c4 += 2;
  196       psimd_store2_f32(c5, vacc5x0123);
  197       c5 += 2;
  198       vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123);
  199       vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123);
  200       vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123);
  201       vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123);
  202       vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123);
  203       vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123);
  204       nr -= 2;
  205     }
  206     if (nr != 0) {
  207       psimd_store1_f32(c0, vacc0x0123);
  208       psimd_store1_f32(c1, vacc1x0123);
  209       psimd_store1_f32(c2, vacc2x0123);
  210       psimd_store1_f32(c3, vacc3x0123);
  211       psimd_store1_f32(c4, vacc4x0123);
  212       psimd_store1_f32(c5, vacc5x0123);
  213     }
  214   }
  215 }