"Fossies" - the Fresh Open Source Software Archive

Member "mesa-20.1.8/src/compiler/spirv/vtn_subgroup.c" (16 Sep 2020, 15258 Bytes) of package /linux/misc/mesa-20.1.8.tar.xz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) C and C++ source code syntax highlighting (style: standard), with prefixed line numbers and a code-folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "vtn_subgroup.c", see the Fossies "Dox" file reference documentation and the latest Fossies "Diffs" side-by-side code changes report: 20.1.5_vs_20.2.0-rc1.

    1 /*
    2  * Copyright © 2016 Intel Corporation
    3  *
    4  * Permission is hereby granted, free of charge, to any person obtaining a
    5  * copy of this software and associated documentation files (the "Software"),
    6  * to deal in the Software without restriction, including without limitation
    7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
    8  * and/or sell copies of the Software, and to permit persons to whom the
    9  * Software is furnished to do so, subject to the following conditions:
   10  *
   11  * The above copyright notice and this permission notice (including the next
   12  * paragraph) shall be included in all copies or substantial portions of the
   13  * Software.
   14  *
   15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
   16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
   17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
   18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
   19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
   20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
   21  * IN THE SOFTWARE.
   22  */
   23 
   24 #include "vtn_private.h"
   25 
   26 static void
   27 vtn_build_subgroup_instr(struct vtn_builder *b,
   28                          nir_intrinsic_op nir_op,
   29                          struct vtn_ssa_value *dst,
   30                          struct vtn_ssa_value *src0,
   31                          nir_ssa_def *index,
   32                          unsigned const_idx0,
   33                          unsigned const_idx1)
   34 {
   35    /* Some of the subgroup operations take an index.  SPIR-V allows this to be
   36     * any integer type.  To make things simpler for drivers, we only support
   37     * 32-bit indices.
   38     */
   39    if (index && index->bit_size != 32)
   40       index = nir_u2u32(&b->nb, index);
   41 
   42    vtn_assert(dst->type == src0->type);
   43    if (!glsl_type_is_vector_or_scalar(dst->type)) {
   44       for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
   45          vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
   46                                   src0->elems[i], index,
   47                                   const_idx0, const_idx1);
   48       }
   49       return;
   50    }
   51 
   52    nir_intrinsic_instr *intrin =
   53       nir_intrinsic_instr_create(b->nb.shader, nir_op);
   54    nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
   55                               dst->type, NULL);
   56    intrin->num_components = intrin->dest.ssa.num_components;
   57 
   58    intrin->src[0] = nir_src_for_ssa(src0->def);
   59    if (index)
   60       intrin->src[1] = nir_src_for_ssa(index);
   61 
   62    intrin->const_index[0] = const_idx0;
   63    intrin->const_index[1] = const_idx1;
   64 
   65    nir_builder_instr_insert(&b->nb, &intrin->instr);
   66 
   67    dst->def = &intrin->dest.ssa;
   68 }
   69 
   70 void
   71 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
   72                     const uint32_t *w, unsigned count)
   73 {
   74    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   75 
   76    val->ssa = vtn_create_ssa_value(b, val->type->type);
   77 
   78    switch (opcode) {
   79    case SpvOpGroupNonUniformElect: {
   80       vtn_fail_if(val->type->type != glsl_bool_type(),
   81                   "OpGroupNonUniformElect must return a Bool");
   82       nir_intrinsic_instr *elect =
   83          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
   84       nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
   85                                  val->type->type, NULL);
   86       nir_builder_instr_insert(&b->nb, &elect->instr);
   87       val->ssa->def = &elect->dest.ssa;
   88       break;
   89    }
   90 
   91    case SpvOpGroupNonUniformBallot: ++w;
   92    case SpvOpSubgroupBallotKHR: {
   93       vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
   94                   "OpGroupNonUniformBallot must return a uvec4");
   95       nir_intrinsic_instr *ballot =
   96          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
   97       ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
   98       nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
   99       ballot->num_components = 4;
  100       nir_builder_instr_insert(&b->nb, &ballot->instr);
  101       val->ssa->def = &ballot->dest.ssa;
  102       break;
  103    }
  104 
  105    case SpvOpGroupNonUniformInverseBallot: {
  106       /* This one is just a BallotBitfieldExtract with subgroup invocation.
  107        * We could add a NIR intrinsic but it's easier to just lower it on the
  108        * spot.
  109        */
  110       nir_intrinsic_instr *intrin =
  111          nir_intrinsic_instr_create(b->nb.shader,
  112                                     nir_intrinsic_ballot_bitfield_extract);
  113 
  114       intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
  115       intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
  116 
  117       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
  118                                  val->type->type, NULL);
  119       nir_builder_instr_insert(&b->nb, &intrin->instr);
  120 
  121       val->ssa->def = &intrin->dest.ssa;
  122       break;
  123    }
  124 
  125    case SpvOpGroupNonUniformBallotBitExtract:
  126    case SpvOpGroupNonUniformBallotBitCount:
  127    case SpvOpGroupNonUniformBallotFindLSB:
  128    case SpvOpGroupNonUniformBallotFindMSB: {
  129       nir_ssa_def *src0, *src1 = NULL;
  130       nir_intrinsic_op op;
  131       switch (opcode) {
  132       case SpvOpGroupNonUniformBallotBitExtract:
  133          op = nir_intrinsic_ballot_bitfield_extract;
  134          src0 = vtn_ssa_value(b, w[4])->def;
  135          src1 = vtn_ssa_value(b, w[5])->def;
  136          break;
  137       case SpvOpGroupNonUniformBallotBitCount:
  138          switch ((SpvGroupOperation)w[4]) {
  139          case SpvGroupOperationReduce:
  140             op = nir_intrinsic_ballot_bit_count_reduce;
  141             break;
  142          case SpvGroupOperationInclusiveScan:
  143             op = nir_intrinsic_ballot_bit_count_inclusive;
  144             break;
  145          case SpvGroupOperationExclusiveScan:
  146             op = nir_intrinsic_ballot_bit_count_exclusive;
  147             break;
  148          default:
  149             unreachable("Invalid group operation");
  150          }
  151          src0 = vtn_ssa_value(b, w[5])->def;
  152          break;
  153       case SpvOpGroupNonUniformBallotFindLSB:
  154          op = nir_intrinsic_ballot_find_lsb;
  155          src0 = vtn_ssa_value(b, w[4])->def;
  156          break;
  157       case SpvOpGroupNonUniformBallotFindMSB:
  158          op = nir_intrinsic_ballot_find_msb;
  159          src0 = vtn_ssa_value(b, w[4])->def;
  160          break;
  161       default:
  162          unreachable("Unhandled opcode");
  163       }
  164 
  165       nir_intrinsic_instr *intrin =
  166          nir_intrinsic_instr_create(b->nb.shader, op);
  167 
  168       intrin->src[0] = nir_src_for_ssa(src0);
  169       if (src1)
  170          intrin->src[1] = nir_src_for_ssa(src1);
  171 
  172       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
  173                                  val->type->type, NULL);
  174       nir_builder_instr_insert(&b->nb, &intrin->instr);
  175 
  176       val->ssa->def = &intrin->dest.ssa;
  177       break;
  178    }
  179 
  180    case SpvOpGroupNonUniformBroadcastFirst: ++w;
  181    case SpvOpSubgroupFirstInvocationKHR:
  182       vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
  183                                val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
  184       break;
  185 
  186    case SpvOpGroupNonUniformBroadcast:
  187    case SpvOpGroupBroadcast: ++w;
  188    case SpvOpSubgroupReadInvocationKHR:
  189       vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
  190                                val->ssa, vtn_ssa_value(b, w[3]),
  191                                vtn_ssa_value(b, w[4])->def, 0, 0);
  192       break;
  193 
  194    case SpvOpGroupNonUniformAll:
  195    case SpvOpGroupNonUniformAny:
  196    case SpvOpGroupNonUniformAllEqual:
  197    case SpvOpGroupAll:
  198    case SpvOpGroupAny:
  199    case SpvOpSubgroupAllKHR:
  200    case SpvOpSubgroupAnyKHR:
  201    case SpvOpSubgroupAllEqualKHR: {
  202       vtn_fail_if(val->type->type != glsl_bool_type(),
  203                   "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
  204       nir_intrinsic_op op;
  205       switch (opcode) {
  206       case SpvOpGroupNonUniformAll:
  207       case SpvOpGroupAll:
  208       case SpvOpSubgroupAllKHR:
  209          op = nir_intrinsic_vote_all;
  210          break;
  211       case SpvOpGroupNonUniformAny:
  212       case SpvOpGroupAny:
  213       case SpvOpSubgroupAnyKHR:
  214          op = nir_intrinsic_vote_any;
  215          break;
  216       case SpvOpSubgroupAllEqualKHR:
  217          op = nir_intrinsic_vote_ieq;
  218          break;
  219       case SpvOpGroupNonUniformAllEqual:
  220          switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
  221          case GLSL_TYPE_FLOAT:
  222          case GLSL_TYPE_FLOAT16:
  223          case GLSL_TYPE_DOUBLE:
  224             op = nir_intrinsic_vote_feq;
  225             break;
  226          case GLSL_TYPE_UINT:
  227          case GLSL_TYPE_INT:
  228          case GLSL_TYPE_UINT8:
  229          case GLSL_TYPE_INT8:
  230          case GLSL_TYPE_UINT16:
  231          case GLSL_TYPE_INT16:
  232          case GLSL_TYPE_UINT64:
  233          case GLSL_TYPE_INT64:
  234          case GLSL_TYPE_BOOL:
  235             op = nir_intrinsic_vote_ieq;
  236             break;
  237          default:
  238             unreachable("Unhandled type");
  239          }
  240          break;
  241       default:
  242          unreachable("Unhandled opcode");
  243       }
  244 
  245       nir_ssa_def *src0;
  246       if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
  247           opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
  248           opcode == SpvOpGroupNonUniformAllEqual) {
  249          src0 = vtn_ssa_value(b, w[4])->def;
  250       } else {
  251          src0 = vtn_ssa_value(b, w[3])->def;
  252       }
  253       nir_intrinsic_instr *intrin =
  254          nir_intrinsic_instr_create(b->nb.shader, op);
  255       intrin->num_components = src0->num_components;
  256       intrin->src[0] = nir_src_for_ssa(src0);
  257       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
  258                                  val->type->type, NULL);
  259       nir_builder_instr_insert(&b->nb, &intrin->instr);
  260 
  261       val->ssa->def = &intrin->dest.ssa;
  262       break;
  263    }
  264 
  265    case SpvOpGroupNonUniformShuffle:
  266    case SpvOpGroupNonUniformShuffleXor:
  267    case SpvOpGroupNonUniformShuffleUp:
  268    case SpvOpGroupNonUniformShuffleDown: {
  269       nir_intrinsic_op op;
  270       switch (opcode) {
  271       case SpvOpGroupNonUniformShuffle:
  272          op = nir_intrinsic_shuffle;
  273          break;
  274       case SpvOpGroupNonUniformShuffleXor:
  275          op = nir_intrinsic_shuffle_xor;
  276          break;
  277       case SpvOpGroupNonUniformShuffleUp:
  278          op = nir_intrinsic_shuffle_up;
  279          break;
  280       case SpvOpGroupNonUniformShuffleDown:
  281          op = nir_intrinsic_shuffle_down;
  282          break;
  283       default:
  284          unreachable("Invalid opcode");
  285       }
  286       vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
  287                                vtn_ssa_value(b, w[5])->def, 0, 0);
  288       break;
  289    }
  290 
  291    case SpvOpGroupNonUniformQuadBroadcast:
  292       vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
  293                                val->ssa, vtn_ssa_value(b, w[4]),
  294                                vtn_ssa_value(b, w[5])->def, 0, 0);
  295       break;
  296 
  297    case SpvOpGroupNonUniformQuadSwap: {
  298       unsigned direction = vtn_constant_uint(b, w[5]);
  299       nir_intrinsic_op op;
  300       switch (direction) {
  301       case 0:
  302          op = nir_intrinsic_quad_swap_horizontal;
  303          break;
  304       case 1:
  305          op = nir_intrinsic_quad_swap_vertical;
  306          break;
  307       case 2:
  308          op = nir_intrinsic_quad_swap_diagonal;
  309          break;
  310       default:
  311          vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
  312       }
  313       vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
  314                                NULL, 0, 0);
  315       break;
  316    }
  317 
  318    case SpvOpGroupNonUniformIAdd:
  319    case SpvOpGroupNonUniformFAdd:
  320    case SpvOpGroupNonUniformIMul:
  321    case SpvOpGroupNonUniformFMul:
  322    case SpvOpGroupNonUniformSMin:
  323    case SpvOpGroupNonUniformUMin:
  324    case SpvOpGroupNonUniformFMin:
  325    case SpvOpGroupNonUniformSMax:
  326    case SpvOpGroupNonUniformUMax:
  327    case SpvOpGroupNonUniformFMax:
  328    case SpvOpGroupNonUniformBitwiseAnd:
  329    case SpvOpGroupNonUniformBitwiseOr:
  330    case SpvOpGroupNonUniformBitwiseXor:
  331    case SpvOpGroupNonUniformLogicalAnd:
  332    case SpvOpGroupNonUniformLogicalOr:
  333    case SpvOpGroupNonUniformLogicalXor:
  334    case SpvOpGroupIAdd:
  335    case SpvOpGroupFAdd:
  336    case SpvOpGroupFMin:
  337    case SpvOpGroupUMin:
  338    case SpvOpGroupSMin:
  339    case SpvOpGroupFMax:
  340    case SpvOpGroupUMax:
  341    case SpvOpGroupSMax:
  342    case SpvOpGroupIAddNonUniformAMD:
  343    case SpvOpGroupFAddNonUniformAMD:
  344    case SpvOpGroupFMinNonUniformAMD:
  345    case SpvOpGroupUMinNonUniformAMD:
  346    case SpvOpGroupSMinNonUniformAMD:
  347    case SpvOpGroupFMaxNonUniformAMD:
  348    case SpvOpGroupUMaxNonUniformAMD:
  349    case SpvOpGroupSMaxNonUniformAMD: {
  350       nir_op reduction_op;
  351       switch (opcode) {
  352       case SpvOpGroupNonUniformIAdd:
  353       case SpvOpGroupIAdd:
  354       case SpvOpGroupIAddNonUniformAMD:
  355          reduction_op = nir_op_iadd;
  356          break;
  357       case SpvOpGroupNonUniformFAdd:
  358       case SpvOpGroupFAdd:
  359       case SpvOpGroupFAddNonUniformAMD:
  360          reduction_op = nir_op_fadd;
  361          break;
  362       case SpvOpGroupNonUniformIMul:
  363          reduction_op = nir_op_imul;
  364          break;
  365       case SpvOpGroupNonUniformFMul:
  366          reduction_op = nir_op_fmul;
  367          break;
  368       case SpvOpGroupNonUniformSMin:
  369       case SpvOpGroupSMin:
  370       case SpvOpGroupSMinNonUniformAMD:
  371          reduction_op = nir_op_imin;
  372          break;
  373       case SpvOpGroupNonUniformUMin:
  374       case SpvOpGroupUMin:
  375       case SpvOpGroupUMinNonUniformAMD:
  376          reduction_op = nir_op_umin;
  377          break;
  378       case SpvOpGroupNonUniformFMin:
  379       case SpvOpGroupFMin:
  380       case SpvOpGroupFMinNonUniformAMD:
  381          reduction_op = nir_op_fmin;
  382          break;
  383       case SpvOpGroupNonUniformSMax:
  384       case SpvOpGroupSMax:
  385       case SpvOpGroupSMaxNonUniformAMD:
  386          reduction_op = nir_op_imax;
  387          break;
  388       case SpvOpGroupNonUniformUMax:
  389       case SpvOpGroupUMax:
  390       case SpvOpGroupUMaxNonUniformAMD:
  391          reduction_op = nir_op_umax;
  392          break;
  393       case SpvOpGroupNonUniformFMax:
  394       case SpvOpGroupFMax:
  395       case SpvOpGroupFMaxNonUniformAMD:
  396          reduction_op = nir_op_fmax;
  397          break;
  398       case SpvOpGroupNonUniformBitwiseAnd:
  399       case SpvOpGroupNonUniformLogicalAnd:
  400          reduction_op = nir_op_iand;
  401          break;
  402       case SpvOpGroupNonUniformBitwiseOr:
  403       case SpvOpGroupNonUniformLogicalOr:
  404          reduction_op = nir_op_ior;
  405          break;
  406       case SpvOpGroupNonUniformBitwiseXor:
  407       case SpvOpGroupNonUniformLogicalXor:
  408          reduction_op = nir_op_ixor;
  409          break;
  410       default:
  411          unreachable("Invalid reduction operation");
  412       }
  413 
  414       nir_intrinsic_op op;
  415       unsigned cluster_size = 0;
  416       switch ((SpvGroupOperation)w[4]) {
  417       case SpvGroupOperationReduce:
  418          op = nir_intrinsic_reduce;
  419          break;
  420       case SpvGroupOperationInclusiveScan:
  421          op = nir_intrinsic_inclusive_scan;
  422          break;
  423       case SpvGroupOperationExclusiveScan:
  424          op = nir_intrinsic_exclusive_scan;
  425          break;
  426       case SpvGroupOperationClusteredReduce:
  427          op = nir_intrinsic_reduce;
  428          assert(count == 7);
  429          cluster_size = vtn_constant_uint(b, w[6]);
  430          break;
  431       default:
  432          unreachable("Invalid group operation");
  433       }
  434 
  435       vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
  436                                NULL, reduction_op, cluster_size);
  437       break;
  438    }
  439 
  440    default:
  441       unreachable("Invalid SPIR-V opcode");
  442    }
  443 }