"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-aarch32-neon.S" (23 Jul 2021, 11849 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) PowerPC Assembler source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively you can here view or download the uninterpreted source code file.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <qnnpack/assembly.h>
#include <requantization/runtime-assembly.h>

# r0 mr
# r1 nr
# r2 packed_a
# r3 packed_w

# d16 a_zero_point
# d17 b_zero_point

## Stack
# 4     w_row_ptr
# 4     w_block_ids_ptr
# 4     b
# 4     c
# 4     c_stride
# 4     output channel index
# 4     quantization_params
# --

.syntax unified

#  Args passed via stack.
#  TOS
#  |----------------|
#  |w_row_ptr       | 0
#  |w_block_ids_ptr | 4
#  |b               | 8
#  |c               | 12
#  |c_stride        | 16
#  |out ch indx     | 20
#  |params          | 24
#  |----------------|

#  After pushing r4-r11, lr and d8-d15 on the stack the offsets become:
#  |----------------|
#  |d8 - d15        | 0
#  |r4 - r11,lr     | 64
#  |w_row_ptr       | 100
#  |w_block_ids_ptr | 104
#  |b               | 108
#  |c               | 112
#  |c_stride        | 116
#  |out ch indx     | 120
#  |params          | 124
#  |----------------|

# void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon(
#     size_t mr,
#     size_t nr,
#     const uint8_t* a_packed,
#     const uint8_t* packed_w,
#     const uint32_t* w_row_ptr,
#     const uint32_t* w_block_ids_ptr,
#     const float* b,
#     uint8_t* restrict c,
#     size_t c_stride,
#     size_t output_channel_index,
#     const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1])
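#
# Roughly, for each output channel n this kernel walks the nonzero 1x4 weight
# blocks of row n (block-CSR: w_row_ptr / w_block_ids_ptr), gathers the matching
# 8x4 packed block of A, accumulates in int32 and finally dequantizes to float.
# Illustrative C-style sketch only; packing and register blocking are omitted,
# and a_zero_point, w_zero_point and requant_scale stand for the fields of
# quantization_params:
#
#   for (size_t n = 0; n < nr; n++) {
#     int32_t acc[8] = {0};
#     for (uint32_t blk = w_row_ptr[n]; blk < w_row_ptr[n + 1]; blk++) {
#       const uint8_t* w_blk = packed_w + blk * 4;                    // 1x4 weight block
#       const uint8_t* a_blk = a_packed + w_block_ids_ptr[blk] * 32;  // 8x4 packed A block
#       for (size_t k = 0; k < 4; k++)
#         for (size_t m = 0; m < 8; m++)
#           acc[m] += (int32_t) (a_blk[k * 8 + m] - a_zero_point) *
#                     (int32_t) (w_blk[k] - w_zero_point[output_channel_index + n]);
#     }
#     for (size_t m = 0; m < mr; m++)  // the kernel stores float results into c
#       ((float*) c)[m * c_stride + n] =
#           (float) acc[m] * requant_scale[output_channel_index + n] + b[n];
#   }
#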
BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon
    .arm
#ifndef __APPLE__
    .arch armv7-a
    .fpu neon
#endif

    PUSH {r4, r5, r6, r7, r8, r9, r10, r11, lr}
    VPUSH {d8-d15}

    # Store nr in r11 as well for later use.
    MOV r11, r1
    # Load output channel index
    LDR r5, [sp, 120]
    # Load quantization params
    # - r7 = quantization_params
    LDR r7, [sp, 124]
    # Load input_zero_point
    VLD1.8 {d16[]}, [r7]
    ADD r7, r7, 4
    # Load pointer to per channel zero points array
    LDR r4, [r7]
    # Add output_channel_index to the b_zero_point pointer
    ADD r4, r4, r5

    # We enter the loop only if nr (r1) is at least 1.
    # r1 = r1 - 1 happens in the epilogue of the loop.
    CMP r1, 1
    BLO 7f

    # Load w_row_ptr + n
    LDR r5, [sp, 100]
    # r7 = w_block_ids_ptr
    LDR r7, [sp, 104]

    .p2align 5
0:
    VEOR q10, q10, q10
    VLD1.8 {d17[]}, [r4]!
    # ip = w_row_ptr[n], lr = w_row_ptr[n+1]
    # r5 = r5 + 4 to point to next n
    LDR ip, [r5], #4
    LDR lr, [r5]
    VEOR q11, q11, q11
    # r6 = temp_packed_w = packed_w + w_row_ptr[n] * 4
    # This points to the first block of nonzero values
    # for the nth row.
    ADD r6, r3, ip, LSL #2
    # r9 = temp_w_block_ids_ptr = w_block_ids_ptr (r7) + w_row_ptr[n]
    # LSL #2 because each element is 4 bytes
    # This points to the block id of the first block.
    # It should contain lr - ip block ids.
    ADD r9, r7, ip, LSL #2
    # r8 = num_blocks that need to be processed
    SUB r8, lr, ip
    SUBS r8, r8, 2
    BLO 1f

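    # Main loop over the nonzero 1x4 weight blocks of row n, two blocks per
    # iteration. Each block id selects one 8x4 (32-byte) packed block of A;
    # the four weights of each block are multiplied against the eight
    # activation rows of the matching k columns and accumulated into
    # q10 (rows 0-3) and q11 (rows 4-7).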
k_loop:
    # Load 2 nonzero blocks of weights. Each block = 1x4.
    VLD1.8 {d19}, [r6]!

    # ip = block_id_ptr[0]
    # lr = block_id_ptr[1]
    LDR ip, [r9], #4
    LDR lr, [r9], #4

    # Add offset to r2
    # Shift by 5 because each packed block is a block of 8x4,
    # which is 32 bytes
    ADD r10, r2, ip, LSL #5
    # q9 = vxb
    VSUBL.U8 q9, d19, d17

    VLD1.8 {d0}, [r10]!
    VLD1.8 {d1}, [r10]!
    VLD1.8 {d2}, [r10]!
    VLD1.8 {d3}, [r10]

    ADD r10, r2, lr, LSL #5

    VSUBL.U8 q4, d0, d16  // vxa0_t

    VLD1.8 {d4}, [r10]!
    VLD1.8 {d5}, [r10]!
    VLD1.8 {d6}, [r10]!
    VLD1.8 {d7}, [r10]

    VSUBL.U8 q5, d1, d16  // vxa1_t
    VSUBL.U8 q6, d2, d16  // vxa2_t
    VSUBL.U8 q7, d3, d16  // vxa3_t
    VSUBL.U8 q12, d4, d16  // vxa4_t
    VSUBL.U8 q13, d5, d16  // vxa5_t
    VSUBL.U8 q14, d6, d16  // vxa6_t
    VSUBL.U8 q15, d7, d16  // vxa7_t
    # This setup, without the VMOV, is a natural fit for double buffering:
    # the raw data is loaded into q0-q3 and the widened vxa* values live in
    # q4-q7 and q12-q15, so q0-q3 are free to load the next iteration.
    # We will do this as a later optimization.

    VMOV q0, q9

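    # d0[0..3] and d1[0..3] now hold the widened weights of the two 1x4
    # blocks. Each VMLAL multiplies one weight lane against the eight widened
    # activations of the matching k column (q4-q7 for the first block,
    # q12-q15 for the second), accumulating rows 0-3 into q10 and rows 4-7
    # into q11.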
    VMLAL.S16 q10, d8, d0[0]
    VMLAL.S16 q11, d9, d0[0]
    VMLAL.S16 q10, d10, d0[1]
    VMLAL.S16 q11, d11, d0[1]
    VMLAL.S16 q10, d12, d0[2]
    VMLAL.S16 q11, d13, d0[2]
    VMLAL.S16 q10, d14, d0[3]
    VMLAL.S16 q11, d15, d0[3]
    VMLAL.S16 q10, d24, d1[0]
    VMLAL.S16 q11, d25, d1[0]
    VMLAL.S16 q10, d26, d1[1]
    VMLAL.S16 q11, d27, d1[1]
    VMLAL.S16 q10, d28, d1[2]
    VMLAL.S16 q11, d29, d1[2]
    VMLAL.S16 q10, d30, d1[3]
    VMLAL.S16 q11, d31, d1[3]

    SUBS r8, r8, 2

    BHS k_loop
1:
    CMP r8, -2
    BEQ 2f

    # Load the last nonzero block.
    # For this we load four 8-bit values as one 32-bit value.
    VLD1.32 {d19[]}, [r6]!
    # q9 = vxb
    VSUBL.U8 q9, d19, d17

    # ip = block_id_ptr[0]
    LDR ip, [r9]

    # Add offset to r2
    # Shift by 5 because each packed block is a block of 8x4,
    # which is 32 bytes
    ADD r10, r2, ip, LSL #5

    VLD1.8 {d0}, [r10]!
    VLD1.8 {d1}, [r10]!
    VLD1.8 {d2}, [r10]!
    VLD1.8 {d3}, [r10]

    VSUBL.U8 q4, d0, d16  // vxa04_t
    VSUBL.U8 q5, d1, d16  // vxa15_t
    VSUBL.U8 q6, d2, d16  // vxa26_t
    VSUBL.U8 q7, d3, d16  // vxa37_t

    VMOV q0, q9

    VMLAL.S16 q10, d8, d0[0]
    VMLAL.S16 q11, d9, d0[0]
    VMLAL.S16 q10, d10, d0[1]
    VMLAL.S16 q11, d11, d0[1]
    VMLAL.S16 q10, d12, d0[2]
    VMLAL.S16 q11, d13, d0[2]
    VMLAL.S16 q10, d14, d0[3]
    VMLAL.S16 q11, d15, d0[3]

    .p2align 4
2:
    # Store the result on the stack

    # -12 because TOS - 4, TOS - 8 and TOS - 12 are reserved for mr, nr and
    # the pointer to the weight zero points,
    # plus 128 bytes of buffer for the case nr = 1.
    # This is needed because after processing all values of n we will load
    # 128 bytes from the stack (q10, q11 for the maximum nr of 4), i.e. the
    # accumulators are loaded back into q8-q15.
    # When nr < 4, the extra q values fetched from the stack could overlap
    # with other parts of the stack storing local variables. To avoid that we
    # keep a 128-byte buffer in between, so that the pointer increments never
    # produce an address beyond the stack frame of this function.
    SUB r9, sp, 140
    # Each iteration produces 8 values of 4 bytes each,
    # i.e. 8 x 4 = 32 bytes = 2^5.
    # The values are stored at:
    # 1st value: sp - 140 - r1 * 32
    # 2nd value: sp - 140 - (r1 - 1) * 32
    # and so on.
    SUB r9, r9, r1, LSL #5
    VST1.32 {q10}, [r9]!
    VST1.32 {q11}, [r9]
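    # For example, with nr = 4 the first channel (r1 = 4) stores q10/q11 at
    # sp - 268 and the last channel (r1 = 1) at sp - 172, so the reload below
    # reads 128 bytes from sp - 268 up to sp - 141, staying below the saved
    # registers at sp and above.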

    # Decrement nr; loop back while output channels remain to be processed.
    SUBS r1, r1, 1
    BHI 0b
3:
    # First load all the accumulators from the stack scratch area.
    # r11 still holds nr.
    SUB r9, sp, 140
    SUB r9, r9, r11, LSL #5
    # Now load q8-q15
    # This is a 4x8 block (nr x mr)
    # We will transpose this to 8x4 (mr x nr)
    # q8, q12  : x00, x10, x20, x30; x40, x50, x60, x70
    # q9, q13  : x01, x11, x21, x31; x41, x51, x61, x71
    # q10, q14 : x02, x12, x22, x32; x42, x52, x62, x72
    # q11, q15 : x03, x13, x23, x33; x43, x53, x63, x73
    VLD1.32 {q8}, [r9]!
    VLD1.32 {q12}, [r9]!
    VLD1.32 {q9}, [r9]!
    VLD1.32 {q13}, [r9]!
    VLD1.32 {q10}, [r9]!
    VLD1.32 {q14}, [r9]!
    VLD1.32 {q11}, [r9]!
    VLD1.32 {q15}, [r9]

    ## Now transpose q8-q11
    # VTRN.32 q8, q9
    # VTRN.32 q10, q11
    # q8 : X00, x01, x20, x21
    # q9 : X10, x11, x30, x31
    # q10: X02, x03, x22, x23
    # q11: X12, x13, x32, x33
    # VSWP d16, d17
    # q8 : x20, x21, x00, x01
    # VEXT.32 q6, q8, q10, 2
    # q6 : x00, x01, x02, x03
    # VEXT.32 q10, q10, q8, 2
    # q10: x22, x23, x20, x21
    # VSWP d20, d21
    # VMOV q8, q6
    # q8 : X00, x01, x02, x03
    # q10: x20, x21, x22, x23
    # VSWP d18, d19
    # q9 : x30, x31, x10, x11
    # VEXT.32 q6, q9, q11, 2
    # q6 : x10, x11, x12, x13
    # VEXT.32 q11, q11, q9, 2
    # q11: x32, x33, x30, x31
    # VSWP d22, d23
    # VMOV q9, q6
    # q9 : x10, x11, x12, x13
    # q11: x30, x31, x32, x33
    # Thus we have
    # q8 : X00, x01, x02, x03
    # q9 : X10, x11, x12, x13
    # q10: X20, x21, x22, x23
    # q11: X30, x31, x32, x33
    # Now we can do the same for q12-q15
    # q12: X40, X41, X42, X43
    # q13: X50, X51, X52, X53
    # q14: X60, X61, X62, X63
    # q15: X70, X71, X72, X73
    # NEED TO VALIDATE THIS
    VTRN.32 q8, q9
    VTRN.32 q10, q11
    VSWP d16, d17
    VEXT.32 q6, q8, q10, 2
    VEXT.32 q10, q10, q8, 2
    VSWP d20, d21
    VMOV q8, q6
    VSWP d18, d19
    VEXT.32 q6, q9, q11, 2
    VEXT.32 q11, q11, q9, 2
    VSWP d22, d23
    VMOV q9, q6

    VTRN.32 q12, q13
    VTRN.32 q14, q15
    VSWP d24, d25
    VEXT.32 q6, q12, q14, 2
    VEXT.32 q14, q14, q12, 2
    VSWP d28, d29
    VMOV q12, q6
    VSWP d26, d27
    VEXT.32 q6, q13, q15, 2
    VEXT.32 q15, q15, q13, 2
    VSWP d30, d31
    VMOV q13, q6
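    # After the transpose q8-q15 hold one output row each:
    # q8 = row 0 (x00..x03) through q11 = row 3, and q12 = row 4 through
    # q15 = row 7, i.e. 8 rows of up to 4 output channels.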

    # Load output channel index
    LDR r5, [sp, 120]
    # Load quantization params
    # - r7 = quantization_params
    LDR r7, [sp, 124]
    ADD r7, r7, 8
    # Load pointer to per channel requant scale
    LDR r7, [r7]
    # Now r7 has the base address + offset of the per channel multipliers
    ADD r7, r7, r5, LSL #2

    # Load b (bias) pointer
    LDR r6, [sp, 108]
    # Load q6: vmultiplier_c0123
    VLD1.32 {d12, d13}, [r7]!
    VCVT.F32.S32 q8, q8
    VCVT.F32.S32 q9, q9
    VCVT.F32.S32 q10, q10
    VLD1.32 {q0}, [r6]  // q0 = bias for the (up to 4) output channels

    VCVT.F32.S32 q11, q11
    VCVT.F32.S32 q12, q12
    VCVT.F32.S32 q13, q13
    VCVT.F32.S32 q14, q14
    VCVT.F32.S32 q15, q15

    VMUL.F32 q8, q8, q6
    VMUL.F32 q9, q9, q6
    VMUL.F32 q10, q10, q6
    VMUL.F32 q11, q11, q6
    VMUL.F32 q12, q12, q6
    VMUL.F32 q13, q13, q6
    VMUL.F32 q14, q14, q6
    VMUL.F32 q15, q15, q6

    VADD.F32 q8, q8, q0
    VADD.F32 q9, q9, q0
    VADD.F32 q10, q10, q0
    VADD.F32 q11, q11, q0
    VADD.F32 q12, q12, q0
    VADD.F32 q13, q13, q0
    VADD.F32 q14, q14, q0
    VADD.F32 q15, q15, q0

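    # Each row is now dequantized:
    # out[m][n] = acc[m][n] * requant_scale[output_channel_index + n] + b[n]
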
    # Load c, c_stride:
    # - r1 = c
    # - r9 = c_stride
    LDR r1, [sp, 112]
    LDR r9, [sp, 116]
    # Convert c_stride from elements to bytes (the output is float).
    LSL r9, r9, 2

    # r1 = c0 = c pointer

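    # Clamp the row pointers: when mr < 8, every out-of-range row pointer
    # below aliases the highest valid row, so the stores stay inside the
    # caller's output buffer.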
    CMP r0, 2
    # r2 = c1
    ADD r2, r1, r9
    MOVLO r2, r1

    # r3 = c2
    ADD r3, r2, r9
    MOVLS r3, r2

    CMP r0, 4
    # r4 = c3
    ADD r4, r3, r9
    MOVLO r4, r3

    # r5 = c4
    ADD r5, r4, r9
    MOVLS r5, r4

    CMP r0, 6
    # r6 = c5
    ADD r6, r5, r9
    MOVLO r6, r5

    # r7 = c6
    ADD r7, r6, r9
    MOVLS r7, r6

    CMP r0, 8
    # r8 = c7
    ADD r8, r7, r9
    MOVNE r8, r7

    # If nr == 4 store full rows of 4 floats; otherwise handle the
    # partial-nr tail at 4:.
    CMP r11, 4
    BNE 4f

    VST1.32 {q8}, [r1]
    VST1.32 {q9}, [r2]
    VST1.32 {q10}, [r3]
    VST1.32 {q11}, [r4]
    VST1.32 {q12}, [r5]
    VST1.32 {q13}, [r6]
    VST1.32 {q14}, [r7]
    VST1.32 {q15}, [r8]

    VPOP {d8-d15}
    POP {r4, r5, r6, r7, r8, r9, r10, r11, lr}
    BX lr

    .p2align 3
4:
    CMP r11, 2
    BLO 5f

    VST1.32 {d16}, [r1]!
    VST1.32 {d18}, [r2]!
    VST1.32 {d20}, [r3]!
    VST1.32 {d22}, [r4]!
    VST1.32 {d24}, [r5]!
    VST1.32 {d26}, [r6]!
    VST1.32 {d28}, [r7]!
    VST1.32 {d30}, [r8]!

    SUB r11, 2

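    # Move the remaining channel(s) from the upper half of each accumulator
    # into the lower half so the single-lane stores at 5: write the next
    # channel.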
    VMOV.32 d16, d17
    VMOV.32 d18, d19
    VMOV.32 d20, d21
    VMOV.32 d22, d23
    VMOV.32 d24, d25
    VMOV.32 d26, d27
    VMOV.32 d28, d29
    VMOV.32 d30, d31

5:
    CMP r11, 0
    BEQ 7f

    VST1.32 {d16[0]}, [r1]
    VST1.32 {d18[0]}, [r2]
    VST1.32 {d20[0]}, [r3]
    VST1.32 {d22[0]}, [r4]
    VST1.32 {d24[0]}, [r5]
    VST1.32 {d26[0]}, [r6]
    VST1.32 {d28[0]}, [r7]
    VST1.32 {d30[0]}, [r8]

7:
    VPOP {d8-d15}
    POP {r4, r5, r6, r7, r8, r9, r10, r11, lr}
    BX lr

END_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_8x4_packedA__aarch32_neon

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif