"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-neon.c" (23 Jul 2021, 8041 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML format using (guessed) C and C++ source code syntax highlighting (style: standard), with prefixed line numbers and a code folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "16x9p8q-neon.c", see the Fossies "Dox" file reference documentation.

    1 /*
    2  * Copyright (c) Facebook, Inc. and its affiliates.
    3  * All rights reserved.
    4  *
    5  * This source code is licensed under the BSD-style license found in the
    6  * LICENSE file in the root directory of this source tree.
    7  */
    8 
    9 #include <assert.h>
   10 
   11 #include <arm_neon.h>
   12 
   13 #include <qnnpack/u8maxpool.h>
   14 
   /*
    * Clamps every lane of `v` to the closed range [vlo, vhi]
    * (the upper bound is applied first, then the lower bound).
    */
   static inline uint8x16_t clamp_u8x16(
       uint8x16_t v,
       uint8x16_t vlo,
       uint8x16_t vhi) {
     return vmaxq_u8(vminq_u8(v, vhi), vlo);
   }

   /*
    * NEON uint8 max-pooling micro-kernel ("16x9p8q": 16 channels per vector,
    * 9 pooling rows in the first pass, 8 rows in every further pass).
    *
    * For each of the `n` output pixels, writes `kc` channels to `output`. Each
    * channel is the maximum of that channel over `ks` input rows, clamped to
    * [params->neon.output_min, params->neon.output_max].
    *
    * Row pointers come from `input`. The first pass consumes 9 entries, and each
    * further pass (while pooling rows remain) consumes 8. Slots past the last
    * pooling row are pointed at the pass's first row, which leaves the maximum
    * unchanged. Passes after the first read back the partial result already
    * stored in `output` and fold it into the maximum, so `output` doubles as the
    * accumulator.
    *
    * After each pixel, `input` is advanced by `input_increment` BYTES beyond the
    * consumed entries. `output` is advanced by `output_increment` BYTES beyond the
    * `kc` channels just written.
    *
    * Requires n != 0, ks != 0 and kc >= 16 (asserted). When kc is not a multiple
    * of 16, the last vector is pulled back to cover channels [kc - 16, kc).
    * Re-writing the overlapped channels is harmless because max and the clamp are
    * idempotent.
    */
   void pytorch_u8maxpool_ukernel_16x9p8q__neon(
       size_t n,
       size_t ks,
       size_t kc,
       const uint8_t** input,
       uint8_t* output,
       size_t input_increment,
       size_t output_increment,
       const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
     assert(n != 0);
     assert(ks != 0);
     assert(kc >= 16);

     const uint8x16_t vlo = vld1q_dup_u8(&params->neon.output_min);
     const uint8x16_t vhi = vld1q_dup_u8(&params->neon.output_max);

     do {
       /* First pass: rows 0..8. Slots at index >= ks alias row 0. */
       {
         const uint8_t* r0 = input[0];
         const uint8_t* r1 = ks > 1 ? input[1] : r0;
         const uint8_t* r2 = ks > 2 ? input[2] : r0;
         const uint8_t* r3 = ks > 3 ? input[3] : r0;
         const uint8_t* r4 = ks > 4 ? input[4] : r0;
         const uint8_t* r5 = ks > 5 ? input[5] : r0;
         const uint8_t* r6 = ks > 6 ? input[6] : r0;
         const uint8_t* r7 = ks > 7 ? input[7] : r0;
         const uint8_t* r8 = ks > 8 ? input[8] : r0;
         input += 9;

         for (size_t c = 0; c < kc; c += 16) {
           /* A full vector at c, or the final vector pulled back to end at kc. */
           const size_t off = kc - c >= 16 ? c : kc - 16;

           const uint8x16_t v01 =
               vmaxq_u8(vld1q_u8(r0 + off), vld1q_u8(r1 + off));
           const uint8x16_t v23 =
               vmaxq_u8(vld1q_u8(r2 + off), vld1q_u8(r3 + off));
           const uint8x16_t v45 =
               vmaxq_u8(vld1q_u8(r4 + off), vld1q_u8(r5 + off));
           const uint8x16_t v67 =
               vmaxq_u8(vld1q_u8(r6 + off), vld1q_u8(r7 + off));
           const uint8x16_t v0to7 =
               vmaxq_u8(vmaxq_u8(v01, v23), vmaxq_u8(v45, v67));
           const uint8x16_t v0to8 = vmaxq_u8(v0to7, vld1q_u8(r8 + off));

           const uint8x16_t vout = clamp_u8x16(v0to8, vlo, vhi);
           vst1q_u8(output + off, vout);
         }
       }

       /* Further passes: 8 rows each, folded into what `output` already holds. */
       for (size_t done = 9; done < ks; done += 8) {
         /* Rows still to fold in (always > 0 here). Slots at index >= rest alias row 0. */
         const size_t rest = ks - done;
         const uint8_t* r0 = input[0];
         const uint8_t* r1 = rest > 1 ? input[1] : r0;
         const uint8_t* r2 = rest > 2 ? input[2] : r0;
         const uint8_t* r3 = rest > 3 ? input[3] : r0;
         const uint8_t* r4 = rest > 4 ? input[4] : r0;
         const uint8_t* r5 = rest > 5 ? input[5] : r0;
         const uint8_t* r6 = rest > 6 ? input[6] : r0;
         const uint8_t* r7 = rest > 7 ? input[7] : r0;
         input += 8;

         for (size_t c = 0; c < kc; c += 16) {
           const size_t off = kc - c >= 16 ? c : kc - 16;

           /* Partial maximum produced by the earlier passes for this pixel. */
           const uint8x16_t vacc = vld1q_u8(output + off);
           const uint8x16_t v01 =
               vmaxq_u8(vld1q_u8(r0 + off), vld1q_u8(r1 + off));
           const uint8x16_t v23 =
               vmaxq_u8(vld1q_u8(r2 + off), vld1q_u8(r3 + off));
           const uint8x16_t v45 =
               vmaxq_u8(vld1q_u8(r4 + off), vld1q_u8(r5 + off));
           const uint8x16_t v67 =
               vmaxq_u8(vld1q_u8(r6 + off), vld1q_u8(r7 + off));
           const uint8x16_t v0to7 =
               vmaxq_u8(vmaxq_u8(v01, v23), vmaxq_u8(v45, v67));

           const uint8x16_t vout = clamp_u8x16(vmaxq_u8(v0to7, vacc), vlo, vhi);
           vst1q_u8(output + off, vout);
         }
       }

       /* Both increments are byte offsets, hence the uintptr_t arithmetic. */
       input = (const uint8_t**)((uintptr_t)input + input_increment);
       output = (uint8_t*)((uintptr_t)(output + kc) + output_increment);
     } while (--n != 0);
   }