"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-neon.c" (23 Jul 2021, 10669 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and code folding option. Alternatively you can here view or download the uninterpreted source code file. For more information about "6x8-neon.c" see the Fossies "Dox" file reference documentation.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <arm_neon.h>

#include <qnnpack/sgemm.h>
void pytorch_sgemm_ukernel_6x8__neon(
    size_t mr,
    size_t nr,
    size_t k,
    const float* restrict a,
    size_t a_stride,
    const float* restrict w,
    float* restrict c,
    size_t c_stride,
    const struct pytorch_qnnp_fp32_clamping_params
        clamping_params[restrict static 1]) {
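  /* All six accumulator rows start from the same 8 packed values at the
   * head of w (the shared per-column bias). */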
  float32x4_t vacc0x0123 = vld1q_f32(w);
  w += 4;
  float32x4_t vacc0x4567 = vld1q_f32(w);
  w += 4;
  float32x4_t vacc1x0123 = vacc0x0123;
  float32x4_t vacc1x4567 = vacc0x4567;
  float32x4_t vacc2x0123 = vacc0x0123;
  float32x4_t vacc2x4567 = vacc0x4567;
  float32x4_t vacc3x0123 = vacc0x0123;
  float32x4_t vacc3x4567 = vacc0x4567;
  float32x4_t vacc4x0123 = vacc0x0123;
  float32x4_t vacc4x4567 = vacc0x4567;
  float32x4_t vacc5x0123 = vacc0x0123;
  float32x4_t vacc5x4567 = vacc0x4567;

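  /*
   * Per-row pointers into A. When mr < 6, pointers for the missing rows
   * alias the last valid row, so the vector loads below never read out of
   * bounds; the duplicated rows are rendered harmless by the matching
   * aliasing of the output pointers at store time.
   */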
  const float* a0 = a;
  const float* a1 = (const float*)((uintptr_t)a0 + a_stride);
  if (mr < 2) {
    a1 = a0;
  }
  const float* a2 = (const float*)((uintptr_t)a1 + a_stride);
  if (mr <= 2) {
    a2 = a1;
  }
  const float* a3 = (const float*)((uintptr_t)a2 + a_stride);
  if (mr < 4) {
    a3 = a2;
  }
  const float* a4 = (const float*)((uintptr_t)a3 + a_stride);
  if (mr <= 4) {
    a4 = a3;
  }
  const float* a5 = (const float*)((uintptr_t)a4 + a_stride);
  if (mr != 6) {
    a5 = a4;
  }

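  /*
   * Main loop: two reduction steps per iteration. Each iteration loads two
   * consecutive A values per row and two 8-wide slices of packed weights,
   * accumulating against lane 0 and then lane 1 of each A pair.
   */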
  for (; k >= 2; k -= 2) {
    const float32x2_t va0 = vld1_f32(a0);
    a0 += 2;
    const float32x2_t va1 = vld1_f32(a1);
    a1 += 2;
    const float32x2_t va2 = vld1_f32(a2);
    a2 += 2;
    const float32x2_t va3 = vld1_f32(a3);
    a3 += 2;
    const float32x2_t va4 = vld1_f32(a4);
    a4 += 2;
    const float32x2_t va5 = vld1_f32(a5);
    a5 += 2;

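    /*
     * First reduction step: multiply-accumulate with lane 0 of each A pair.
     * AArch64 uses the fused vfmaq_lane_f32; on 32-bit ARM the code falls
     * back to vmlaq_lane_f32, a non-fused multiply then add.
     */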
    {
      const float32x4_t vb0123 = vld1q_f32(w);
      w += 4;
      const float32x4_t vb4567 = vld1q_f32(w);
      w += 4;

#if defined(__aarch64__)
      vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123, va0, 0);
      vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567, va0, 0);
      vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123, va1, 0);
      vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567, va1, 0);
      vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123, va2, 0);
      vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567, va2, 0);
      vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123, va3, 0);
      vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567, va3, 0);
      vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123, va4, 0);
      vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567, va4, 0);
      vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123, va5, 0);
      vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567, va5, 0);
#else
      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, va0, 0);
      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, va0, 0);
      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, va1, 0);
      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, va1, 0);
      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, va2, 0);
      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, va2, 0);
      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, va3, 0);
      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, va3, 0);
      vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, va4, 0);
      vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, va4, 0);
      vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123, va5, 0);
      vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567, va5, 0);
#endif
    }

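    /* Second reduction step: same pattern, using lane 1 of each A pair. */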
    {
      const float32x4_t vb0123 = vld1q_f32(w);
      w += 4;
      const float32x4_t vb4567 = vld1q_f32(w);
      w += 4;

#if defined(__aarch64__)
      vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123, va0, 1);
      vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567, va0, 1);
      vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123, va1, 1);
      vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567, va1, 1);
      vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123, va2, 1);
      vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567, va2, 1);
      vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123, va3, 1);
      vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567, va3, 1);
      vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123, va4, 1);
      vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567, va4, 1);
      vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123, va5, 1);
      vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567, va5, 1);
#else
      vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, va0, 1);
      vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, va0, 1);
      vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, va1, 1);
      vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, va1, 1);
      vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, va2, 1);
      vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, va2, 1);
      vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, va3, 1);
      vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, va3, 1);
      vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, va4, 1);
      vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, va4, 1);
      vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123, va5, 1);
      vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567, va5, 1);
#endif
    }
  }
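  /*
   * Remainder: if k was odd, one reduction step is left. vld1q_dup_f32
   * broadcasts the single remaining A value of each row across a full
   * vector, so the plain (non-lane) multiply-accumulate can be used.
   */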
  if (k != 0) {
    const float32x4_t va0 = vld1q_dup_f32(a0);
    const float32x4_t va1 = vld1q_dup_f32(a1);
    const float32x4_t va2 = vld1q_dup_f32(a2);
    const float32x4_t va3 = vld1q_dup_f32(a3);
    const float32x4_t va4 = vld1q_dup_f32(a4);
    const float32x4_t va5 = vld1q_dup_f32(a5);

    const float32x4_t vb0123 = vld1q_f32(w);
    w += 4;
    const float32x4_t vb4567 = vld1q_f32(w);
    w += 4;

#if defined(__aarch64__)
    vacc0x0123 = vfmaq_f32(vacc0x0123, vb0123, va0);
    vacc0x4567 = vfmaq_f32(vacc0x4567, vb4567, va0);
    vacc1x0123 = vfmaq_f32(vacc1x0123, vb0123, va1);
    vacc1x4567 = vfmaq_f32(vacc1x4567, vb4567, va1);
    vacc2x0123 = vfmaq_f32(vacc2x0123, vb0123, va2);
    vacc2x4567 = vfmaq_f32(vacc2x4567, vb4567, va2);
    vacc3x0123 = vfmaq_f32(vacc3x0123, vb0123, va3);
    vacc3x4567 = vfmaq_f32(vacc3x4567, vb4567, va3);
    vacc4x0123 = vfmaq_f32(vacc4x0123, vb0123, va4);
    vacc4x4567 = vfmaq_f32(vacc4x4567, vb4567, va4);
    vacc5x0123 = vfmaq_f32(vacc5x0123, vb0123, va5);
    vacc5x4567 = vfmaq_f32(vacc5x4567, vb4567, va5);
#else
    vacc0x0123 = vmlaq_f32(vacc0x0123, vb0123, va0);
    vacc0x4567 = vmlaq_f32(vacc0x4567, vb4567, va0);
    vacc1x0123 = vmlaq_f32(vacc1x0123, vb0123, va1);
    vacc1x4567 = vmlaq_f32(vacc1x4567, vb4567, va1);
    vacc2x0123 = vmlaq_f32(vacc2x0123, vb0123, va2);
    vacc2x4567 = vmlaq_f32(vacc2x4567, vb4567, va2);
    vacc3x0123 = vmlaq_f32(vacc3x0123, vb0123, va3);
    vacc3x4567 = vmlaq_f32(vacc3x4567, vb4567, va3);
    vacc4x0123 = vmlaq_f32(vacc4x0123, vb0123, va4);
    vacc4x4567 = vmlaq_f32(vacc4x4567, vb4567, va4);
    vacc5x0123 = vmlaq_f32(vacc5x0123, vb0123, va5);
    vacc5x4567 = vmlaq_f32(vacc5x4567, vb4567, va5);
#endif
  }
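  /* Clamp the accumulators to the requested [min, max] output range. */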
  const float32x4_t vmax = vld1q_dup_f32(&clamping_params->max);
  vacc0x0123 = vminq_f32(vacc0x0123, vmax);
  vacc0x4567 = vminq_f32(vacc0x4567, vmax);
  vacc1x0123 = vminq_f32(vacc1x0123, vmax);
  vacc1x4567 = vminq_f32(vacc1x4567, vmax);
  vacc2x0123 = vminq_f32(vacc2x0123, vmax);
  vacc2x4567 = vminq_f32(vacc2x4567, vmax);
  vacc3x0123 = vminq_f32(vacc3x0123, vmax);
  vacc3x4567 = vminq_f32(vacc3x4567, vmax);
  vacc4x0123 = vminq_f32(vacc4x0123, vmax);
  vacc4x4567 = vminq_f32(vacc4x4567, vmax);
  vacc5x0123 = vminq_f32(vacc5x0123, vmax);
  vacc5x4567 = vminq_f32(vacc5x4567, vmax);

  const float32x4_t vmin = vld1q_dup_f32(&clamping_params->min);
  vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
  vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);
  vacc1x0123 = vmaxq_f32(vacc1x0123, vmin);
  vacc1x4567 = vmaxq_f32(vacc1x4567, vmin);
  vacc2x0123 = vmaxq_f32(vacc2x0123, vmin);
  vacc2x4567 = vmaxq_f32(vacc2x4567, vmin);
  vacc3x0123 = vmaxq_f32(vacc3x0123, vmin);
  vacc3x4567 = vmaxq_f32(vacc3x4567, vmin);
  vacc4x0123 = vmaxq_f32(vacc4x0123, vmin);
  vacc4x4567 = vmaxq_f32(vacc4x4567, vmin);
  vacc5x0123 = vmaxq_f32(vacc5x0123, vmin);
  vacc5x4567 = vmaxq_f32(vacc5x4567, vmin);

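  /*
   * Per-row output pointers, aliased the same way as the A pointers: rows
   * beyond mr point at the last valid row and simply rewrite the same
   * values, so no out-of-bounds store can occur.
   */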
  float* c0 = c;
  float* c1 = (float*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  float* c2 = (float*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  float* c3 = (float*)((uintptr_t)c2 + c_stride);
  if (mr < 4) {
    c3 = c2;
  }
  float* c4 = (float*)((uintptr_t)c3 + c_stride);
  if (mr <= 4) {
    c4 = c3;
  }
  float* c5 = (float*)((uintptr_t)c4 + c_stride);
  if (mr != 6) {
    c5 = c4;
  }
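  /*
   * Store the tile. A full 8-column tile is written as two vectors per row;
   * narrower tiles are written in 4-, 2- and 1-column pieces, with
   * vextq_f32 rotating the remaining lanes down after each partial store.
   */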
  if (nr == 8) {
    vst1q_f32(c0, vacc0x0123);
    c0 += 4;
    vst1q_f32(c1, vacc1x0123);
    c1 += 4;
    vst1q_f32(c2, vacc2x0123);
    c2 += 4;
    vst1q_f32(c3, vacc3x0123);
    c3 += 4;
    vst1q_f32(c4, vacc4x0123);
    c4 += 4;
    vst1q_f32(c5, vacc5x0123);
    c5 += 4;

    vst1q_f32(c0, vacc0x4567);
    vst1q_f32(c1, vacc1x4567);
    vst1q_f32(c2, vacc2x4567);
    vst1q_f32(c3, vacc3x4567);
    vst1q_f32(c4, vacc4x4567);
    vst1q_f32(c5, vacc5x4567);
  } else {
    if (nr >= 4) {
      vst1q_f32(c0, vacc0x0123);
      c0 += 4;
      vst1q_f32(c1, vacc1x0123);
      c1 += 4;
      vst1q_f32(c2, vacc2x0123);
      c2 += 4;
      vst1q_f32(c3, vacc3x0123);
      c3 += 4;
      vst1q_f32(c4, vacc4x0123);
      c4 += 4;
      vst1q_f32(c5, vacc5x0123);
      c5 += 4;
      vacc0x0123 = vacc0x4567;
      vacc1x0123 = vacc1x4567;
      vacc2x0123 = vacc2x4567;
      vacc3x0123 = vacc3x4567;
      vacc4x0123 = vacc4x4567;
      vacc5x0123 = vacc5x4567;
      nr -= 4;
    }
    if (nr >= 2) {
      vst1_f32(c0, vget_low_f32(vacc0x0123));
      c0 += 2;
      vst1_f32(c1, vget_low_f32(vacc1x0123));
      c1 += 2;
      vst1_f32(c2, vget_low_f32(vacc2x0123));
      c2 += 2;
      vst1_f32(c3, vget_low_f32(vacc3x0123));
      c3 += 2;
      vst1_f32(c4, vget_low_f32(vacc4x0123));
      c4 += 2;
      vst1_f32(c5, vget_low_f32(vacc5x0123));
      c5 += 2;
      vacc0x0123 = vextq_f32(vacc0x0123, vacc0x0123, 2);
      vacc1x0123 = vextq_f32(vacc1x0123, vacc1x0123, 2);
      vacc2x0123 = vextq_f32(vacc2x0123, vacc2x0123, 2);
      vacc3x0123 = vextq_f32(vacc3x0123, vacc3x0123, 2);
      vacc4x0123 = vextq_f32(vacc4x0123, vacc4x0123, 2);
      vacc5x0123 = vextq_f32(vacc5x0123, vacc5x0123, 2);
      nr -= 2;
    }
    if (nr != 0) {
      vst1q_lane_f32(c0, vacc0x0123, 0);
      vst1q_lane_f32(c1, vacc1x0123, 0);
      vst1q_lane_f32(c2, vacc2x0123, 0);
      vst1q_lane_f32(c3, vacc3x0123, 0);
      vst1q_lane_f32(c4, vacc4x0123, 0);
      vst1q_lane_f32(c5, vacc5x0123, 0);
    }
  }
}
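
For orientation, here is a minimal, hypothetical driver showing how this micro-kernel might be invoked for a full 6x8 tile. It assumes the QNNPACK headers declaring the kernel and the clamping-params struct are on the include path (the params.h location is an assumption); the sizes, fill values, and expected result are made up for the example, not taken from the package.

#include <stdio.h>

#include <qnnpack/params.h> /* struct pytorch_qnnp_fp32_clamping_params (assumed location) */
#include <qnnpack/sgemm.h>  /* declares pytorch_sgemm_ukernel_6x8__neon */

int main(void) {
  enum { M = 6, N = 8, K = 3 }; /* hypothetical tile size and depth */

  float a[M][K];      /* left-hand rows, row-major */
  float w[8 + K * 8]; /* packed: 8 biases, then 8 weights per reduction step */
  float c[M][N];

  for (int m = 0; m < M; m++)
    for (int kk = 0; kk < K; kk++)
      a[m][kk] = (float)(m + 1);
  for (int n = 0; n < 8; n++)
    w[n] = 0.5f; /* bias row */
  for (int kk = 0; kk < K; kk++)
    for (int n = 0; n < 8; n++)
      w[8 + kk * 8 + n] = (float)(kk + 1); /* weights for step kk */

  const struct pytorch_qnnp_fp32_clamping_params params = {
      .max = 100.0f, .min = -100.0f};

  pytorch_sgemm_ukernel_6x8__neon(
      M, N, K,
      &a[0][0], K * sizeof(float), /* strides are in bytes */
      w,
      &c[0][0], N * sizeof(float),
      &params);

  /* Expect c[m][n] = 0.5 + (m + 1) * (1 + 2 + 3) for every column n. */
  printf("c[0][0] = %.1f\n", c[0][0]);
  return 0;
}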