"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S" (23 Jul 2021, 10735 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) PowerPC Assembler source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively you can here view or download the uninterpreted source code file.

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <qnnpack/assembly.h>
#include <requantization/runtime-assembly.h>

# r0 mr
# r1 nr
# r2 packed_a
# r3 packed_w

# d16 a_zero_point
# d17 b_zero_point

## Stack
# 4     w_row_ptr
# 4     w_block_ids_ptr
# 4     b
# 4     c
# 4     c_stride
# 4     output channel index
# 4     quantization_params
# --

.syntax unified

#  Args passed via stack.
#  TOS
#  |----------------|
#  |w_row_ptr       | 0
#  |w_block_ids_ptr | 4
#  |b               | 8
#  |c               | 12
#  |c_stride        | 16
#  |out ch indx     | 20
#  |params          | 24
#  |----------------|
#

#  After pushing r4-r11, lr and d8-d15 on the stack
#  |----------------|
#  |d8 - d15        | 0
#  |r4 - r11,lr     | 64
#  |w_row_ptr       | 100
#  |w_block_ids_ptr | 104
#  |b               | 108
#  |c               | 112
#  |c_stride        | 116
#  |out ch indx     | 120
#  |params          | 124
#  |----------------|
#

# void pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon(
#     size_t mr,
#     size_t nr,
#     const uint8_t* a_packed,
#     const uint8_t* packed_w,
#     const uint32_t* w_row_ptr,
#     const uint32_t* w_block_ids_ptr,
#     const float* b,
#     float* restrict c,
#     size_t c_stride,
#     size_t output_channel_index,
#     const union pytorch_qnnp_conv_dynamic_quantization_params quantization_params[restrict static 1])
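#
# Rough C model of what this kernel computes (an illustrative sketch only, not
# part of the build; local names such as acc, w, a are hypothetical, and the
# quantization_params field names are written to match the 0/4/8 byte offsets
# the code loads). packed_w holds the nonzero 1x4 weight blocks back to back,
# w_row_ptr[n]..w_row_ptr[n+1] selects the blocks belonging to output channel
# n, w_block_ids_ptr gives each block's position along the reduction
# dimension, and a_packed holds the activations as transposed 4x4 blocks of
# 16 bytes:
#
#   for (size_t n = 0; n < nr; n++) {
#     int32_t acc[4] = {0, 0, 0, 0};
#     const uint8_t w_zp =
#         quantization_params->kernel_zero_points[output_channel_index + n];
#     for (uint32_t blk = w_row_ptr[n]; blk < w_row_ptr[n + 1]; blk++) {
#       const uint8_t* w = packed_w + blk * 4;                    // 1x4 block
#       const uint8_t* a = a_packed + w_block_ids_ptr[blk] * 16;  // 4x4 block
#       for (int k = 0; k < 4; k++)
#         for (int m = 0; m < 4; m++)
#           acc[m] += (int32_t)(a[k * 4 + m] - quantization_params->input_zero_point)
#                   * (int32_t)(w[k] - w_zp);
#     }
#     const float scale =
#         quantization_params->multipliers[output_channel_index + n];
#     for (size_t m = 0; m < mr; m++)
#       c[m * c_stride + n] = (float)acc[m] * scale + b[n];
#   }
#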
BEGIN_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon
    .arm
#ifndef __APPLE__
    .arch armv7-a
    .fpu neon
#endif

    PUSH {r4, r5, r6, r7, r8, r9, r10, r11, lr}
    VPUSH {d8-d15}

    # Store nr in r11 as well for later use.
    MOV r11, r1
    # Load output channel index
    LDR r5, [sp, 120]
    # Load quantization params
    # - r7 = quantization_params
    LDR r7, [sp, 124]
    # Load and duplicate input_zero_point into d16
    VLD1.8 {d16[]}, [r7]
    ADD r7, r7, 4
    # Load pointer to the per channel weight zero points array
    LDR r4, [r7]
    # Add output_channel_index to the b_zero_point pointer
    ADD r4, r4, r5

    # We enter the loop only if r1 is at least 1.
    # r1 = r1 - 1 will happen in the epilogue
    # of the loop
    CMP r1, 1
    BLO 7f

    # r5 = w_row_ptr
    LDR r5, [sp, 100]
    # r7 = w_block_ids_ptr
    LDR r7, [sp, 104]

    .p2align 5
0:
    # Clear the accumulator for output channel n
    VEOR q10, q10, q10
    # Load and duplicate the weight zero point for channel n into d17
    VLD1.8 {d17[]}, [r4]!
    # ip = w_row_ptr[n], lr = w_row_ptr[n+1]
    # r5 = r5 + 4 to point to the next n
    LDR ip, [r5], #4
    LDR lr, [r5]
    # r6 = temp_packed_w = packed_w + w_row_ptr[n] * 4
    # This points to the first block of nonzero values
    # for the nth row.
    ADD r6, r3, ip, LSL #2
    # r9 = temp_w_block_ids_ptr = w_block_ids_ptr (r7) + w_row_ptr[n]
    # LSL #2 because each element is 4 bytes
    # This points to the block id of the first block
    # It should contain (lr - ip) block ids
    ADD r9, r7, ip, LSL #2
    # r8 = number of blocks that need to be processed
    SUB r8, lr, ip
    SUBS r8, r8, 2
    BLO 1f

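    # As a concrete (hypothetical) example of the setup above: if w_row_ptr[n] = 5
    # and w_row_ptr[n+1] = 9, then row n has 9 - 5 = 4 nonzero 1x4 blocks,
    # r6 = packed_w + 5 * 4 bytes, r9 = w_block_ids_ptr + 5 entries, and the loop
    # below consumes those blocks two at a time, with any leftover block handled
    # after label 1.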
k_loop:
    # Load 2 nonzero blocks of weights. Each block = 1x4.
    VLD1.8 {d0}, [r6]!

    # ip = block_id_ptr[0]
    # lr = block_id_ptr[1]
    LDR ip, [r9], #4
    LDR lr, [r9], #4

    # Add offset to r2
    # Shift by 4 because each packed block is a block of 4x4,
    # which is 16 bytes
    ADD r10, r2, ip, LSL #4
    # q0 = vxb = weights minus the weight zero point, widened to 16 bits
    VSUBL.U8 q0, d0, d17

    # d2, d3 = 4x4 transposed
    VLD1.8 {d2}, [r10]!
    VLD1.8 {d3}, [r10]

    ADD r10, r2, lr, LSL #4

    VSUBL.U8 q4, d2, d16  // vxa0_t

    # d4, d5 = next 4x4 transposed
    VLD1.8 {d4}, [r10]!
    VLD1.8 {d5}, [r10]

    VSUBL.U8 q5, d3, d16  // vxa1_t
    VSUBL.U8 q6, d4, d16  // vxa4_t
    VSUBL.U8 q7, d5, d16  // vxa5_t

    # q4, q5 = first 4x4 block (16 values, each 16 bits)
    # q6, q7 = second 4x4 block (16 values, each 16 bits)

    VMLAL.S16 q10, d8, d0[0]
    VMLAL.S16 q10, d9, d0[1]
    VMLAL.S16 q10, d10, d0[2]
    VMLAL.S16 q10, d11, d0[3]
    VMLAL.S16 q10, d12, d1[0]
    VMLAL.S16 q10, d13, d1[1]
    VMLAL.S16 q10, d14, d1[2]
    VMLAL.S16 q10, d15, d1[3]

    SUBS r8, r8, 2

    BHS k_loop
1:
    CMP r8, -2
    BEQ 2f

    # Load the last nonzero block
    # For this we load 4 8-bit values as one 32-bit value
    VLD1.32 {d0[]}, [r6]!
    # q0 = vxb
    VSUBL.U8 q0, d0, d17

    # ip = block_id_ptr[0]
    LDR ip, [r9]

    # Add offset to r2
    # Shift by 4 because each packed block is a block of 4x4,
    # which is 16 bytes
    ADD r10, r2, ip, LSL #4

    VLD1.8 {d2}, [r10]!
    VLD1.8 {d3}, [r10]

    VSUBL.U8 q4, d2, d16  // vxa0_t
    VSUBL.U8 q5, d3, d16  // vxa1_t

    VMLAL.S16 q10, d8, d0[0]
    VMLAL.S16 q10, d9, d0[1]
    VMLAL.S16 q10, d10, d0[2]
    VMLAL.S16 q10, d11, d0[3]

    .p2align 4
2:
    # Store the result on the stack

    # -12 because TOS - 4, TOS - 8, and TOS - 12 store mr, nr, and the pointer
    # to the weight zero points,
    # + 128 bytes of buffer when nr = 1.
    # This is needed because after processing all nrs we will
    # load 128 bytes from the stack: one q register (q10) per nr, for a max nr of 8.
    # The accumulators are then loaded back into q8-q15.
    # When nr < 8, the extra q values would be fetched from stack locations that may
    # overlap with other parts of the stack storing local variables. To avoid that we
    # create a buffer of 128 bytes in between to make sure the pointer increment
    # never produces an address beyond the stack frame of this function.
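    # A worked example of the addressing below (numbers only, for illustration):
    # with the maximum nr of 8 the first store lands at sp - 140 - 8*16 = sp - 268
    # and the last at sp - 156, so the eight 16 byte reloads (128 bytes) read back
    # exactly sp - 268 .. sp - 141. With nr = 1 only sp - 156 is written and the
    # reloads read up to sp - 29, which the 128 byte pad keeps below the 12 bytes
    # reserved directly under sp.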
    SUB r9, sp, 140
    # Each iteration produces 4 values, each of 4 bytes.
    # Thus 4 x 4 = 16 bytes (2^4).
    # In this implementation the values are stored at:
    # 1st value: sp - 140 - r1 * 16
    # 2nd value: sp - 140 - (r1 - 1) * 16
    # and so on.
    SUB r9, r9, r1, LSL #4
    VST1.32 {q10}, [r9]

    # Loop back if the remaining nr is at least 1
    SUBS r1, r1, 1
    BHI 0b
3:
    # First load all the accumulators from the stack.
    # r11 still holds the original nr.
    SUB r9, sp, 140
    SUB r9, r9, r11, LSL #4
    # Now load q8-q15
    # This is an 8x4 block (nr x mr)
    # We will transpose this to 4x8 (mr x nr)
    # q8, q12  : x00, x10, x20, x30; x04, x14, x24, x34
    # q9, q13  : x01, x11, x21, x31; x05, x15, x25, x35
    # q10, q14 : x02, x12, x22, x32; x06, x16, x26, x36
    # q11, q15 : x03, x13, x23, x33; x07, x17, x27, x37
    VLD1.32 {q8}, [r9]!
    VLD1.32 {q9}, [r9]!
    VLD1.32 {q10}, [r9]!
    VLD1.32 {q11}, [r9]!
    VLD1.32 {q12}, [r9]!
    VLD1.32 {q13}, [r9]!
    VLD1.32 {q14}, [r9]!
    VLD1.32 {q15}, [r9]

    ## Now transpose q8-q11
    # VTRN.32 q8, q9
    # VTRN.32 q10, q11
    # q8 : X00, x01, x20, x21
    # q9 : X10, x11, x30, x31
    # q10: X02, x03, x22, x23
    # q11: X12, x13, x32, x33
    # VSWP d16, d17
    # q8 : x20, x21, x00, x01
    # VEXT.32 q6, q8, q10, 2
    # q6 : x00, x01, x02, x03
    # VEXT.32 q10, q10, q8, 2
    # q10: x22, x23, x20, x21
    # VSWP d20, d21
    # VMOV q8, q6
    # q8 : X00, x01, x02, x03
    # q10: x20, x21, x22, x23
    # VSWP d18, d19
    # q9 : x30, x31, x10, x11
    # VEXT.32 q6, q9, q11, 2
    # q6 : x10, x11, x12, x13
    # VEXT.32 q11, q11, q9, 2
    # q11: x32, x33, x30, x31
    # VSWP d22, d23
    # VMOV q9, q6
    # q9 : x10, x11, x12, x13
    # q11: x30, x31, x32, x33
    # Thus we have
    # q8 : X00, x01, x02, x03
    # q9 : X10, x11, x12, x13
    # q10: X20, x21, x22, x23
    # q11: X30, x31, x32, x33
    # Now we can do the same for q12-q15
    # q12: X04, X05, X06, X07
    # q13: X14, X15, X16, X17
    # q14: X24, X25, X26, X27
    # q15: X34, X35, X36, X37

    VTRN.32 q8, q9
    VTRN.32 q10, q11
    VSWP d16, d17
    VEXT.32 q6, q8, q10, 2
    VEXT.32 q10, q10, q8, 2
    VSWP d20, d21
    VMOV q8, q6
    VSWP d18, d19
    VEXT.32 q6, q9, q11, 2
    VEXT.32 q11, q11, q9, 2
    VSWP d22, d23
    VMOV q9, q6

    VTRN.32 q12, q13
    VTRN.32 q14, q15
    VSWP d24, d25
    VEXT.32 q6, q12, q14, 2
    VEXT.32 q14, q14, q12, 2
    VSWP d28, d29
    VMOV q12, q6
    VSWP d26, d27
    VEXT.32 q6, q13, q15, 2
    VEXT.32 q15, q15, q13, 2
    VSWP d30, d31
    VMOV q13, q6

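    # Equivalently, in rough C (sketch only, hypothetical arrays): each group of
    # four q registers undergoes a 4x4 32-bit transpose,
    #
    #   // in[j][m]: accumulator of output channel j for row m
    #   // out[m][j]: row m of the output tile across 4 channels
    #   for (int m = 0; m < 4; m++)
    #     for (int j = 0; j < 4; j++)
    #       out[m][j] = in[j][m];
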
    # Load output channel index
    LDR r5, [sp, 120]
    # Load quantization params
    # - r7 = quantization_params
    LDR r7, [sp, 124]
    ADD r7, r7, 8
    # Load pointer to the per channel requant scales
    LDR r7, [r7]
    # Now r7 = multipliers base address + output channel offset
    ADD r7, r7, r5, LSL #2

    # Load pointer to b (the float bias)
    LDR r6, [sp, 108]
    # Load q6: vmultiplier_c0123
    VLD1.32 {d12, d13}, [r7]!
    # Load q7: vmultiplier_c4567
    VLD1.32 {d14, d15}, [r7]
    VCVT.F32.S32 q8, q8
    VCVT.F32.S32 q9, q9
    VCVT.F32.S32 q10, q10
    # q0 = b[0..3], q1 = b[4..7]
    VLD1.32 {q0}, [r6]!
    VLD1.32 {q1}, [r6]

    VCVT.F32.S32 q11, q11
    VCVT.F32.S32 q12, q12
    VCVT.F32.S32 q13, q13
    VCVT.F32.S32 q14, q14
    VCVT.F32.S32 q15, q15

    # Scale by the per channel multipliers
    VMUL.F32 q8, q8, q6
    VMUL.F32 q9, q9, q6
    VMUL.F32 q10, q10, q6
    VMUL.F32 q11, q11, q6
    VMUL.F32 q12, q12, q7
    VMUL.F32 q13, q13, q7
    VMUL.F32 q14, q14, q7
    VMUL.F32 q15, q15, q7

    # Add the bias
    VADD.F32 q8, q8, q0
    VADD.F32 q9, q9, q0
    VADD.F32 q10, q10, q0
    VADD.F32 q11, q11, q0
    VADD.F32 q12, q12, q1
    VADD.F32 q13, q13, q1
    VADD.F32 q14, q14, q1
    VADD.F32 q15, q15, q1

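    # In scalar form (sketch), the sequence above computes, per row m and output
    # channel n of this tile:
    #   c[m][n] = (float)acc[m][n] * multiplier[n] + b[n]
    # with multiplier[n] taken from the per channel requant scales at
    # output_channel_index + n.
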
    # Load c, c_stride:
    # - r1 = c
    # - r9 = c_stride
    LDR r1, [sp, 112]
    LDR r9, [sp, 116]
    # c_stride is in elements; convert to bytes (4 bytes per float)
    LSL r9, r9, 2

    # r1 = c0 = c pointer

    CMP r0, 2
    # r2 = c1
    ADD r2, r1, r9
    MOVLO r2, r1

    # r3 = c2
    ADD r3, r2, r9
    MOVLS r3, r2

    CMP r0, 4
    # r4 = c3
    ADD r4, r3, r9
    MOVNE r4, r3

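    # Roughly, in C (sketch only; out[][] is the transposed tile computed above),
    # the pointer setup above and the stores below amount to:
    #
    #   float* c0 = c;
    #   float* c1 = (mr >= 2) ? c0 + c_stride : c0;
    #   float* c2 = (mr >  2) ? c1 + c_stride : c1;
    #   float* c3 = (mr == 4) ? c2 + c_stride : c2;
    #   // Stores of width 8, 4, 2 and 1 cover any nr in 1..8; rows beyond mr
    #   // simply alias the last valid row pointer.
    #   for (size_t n = 0; n < nr; n++) {
    #     c0[n] = out[0][n]; c1[n] = out[1][n];
    #     c2[n] = out[2][n]; c3[n] = out[3][n];
    #   }
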
    # Fast path: nr == 8, each row gets two full quad stores
    CMP r11, 8
    BNE 4f

    VST1.32 {q8}, [r1]!
    VST1.32 {q9}, [r2]!
    VST1.32 {q10}, [r3]!
    VST1.32 {q11}, [r4]!
    VST1.32 {q12}, [r1]
    VST1.32 {q13}, [r2]
    VST1.32 {q14}, [r3]
    VST1.32 {q15}, [r4]

    VPOP {d8-d15}
    POP {r4, r5, r6, r7, r8, r9, r10, r11, lr}
    BX lr

    .p2align 3
4:
    # nr < 8: store the remaining columns in chunks of 4, 2 and 1
    CMP r11, 4
    BLO 5f

    VST1.32 {q8}, [r1]!
    VST1.32 {q9}, [r2]!
    VST1.32 {q10}, [r3]!
    VST1.32 {q11}, [r4]!

    SUB r11, 4

    VMOV.32 q8, q12
    VMOV.32 q9, q13
    VMOV.32 q10, q14
    VMOV.32 q11, q15

5:
    CMP r11, 2
    BLO 6f

    VST1.32 {d16}, [r1]!
    VST1.32 {d18}, [r2]!
    VST1.32 {d20}, [r3]!
    VST1.32 {d22}, [r4]!

    SUB r11, 2

    VEXT.32 q8, q8, 2
    VEXT.32 q9, q9, 2
    VEXT.32 q10, q10, 2
    VEXT.32 q11, q11, 2

6:
    TEQ r11, 0
    BEQ 7f

    VST1.32 {d16[0]}, [r1]
    VST1.32 {d18[0]}, [r2]
    VST1.32 {d20[0]}, [r3]
    VST1.32 {d22[0]}, [r4]

7:
    VPOP {d8-d15}
    POP {r4, r5, r6, r7, r8, r9, r10, r11, lr}
    BX lr

END_FUNCTION pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA__aarch32_neon

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif