"Fossies" - the Fresh Open Source Software Archive

Member "mesa-20.1.8/src/compiler/nir/nir_lower_bool_to_bitsize.c" (16 Sep 2020, 14330 Bytes) of package /linux/misc/mesa-20.1.8.tar.xz:


As a special service, "Fossies" has tried to format the requested source page as HTML, using (guessed) C and C++ source code syntax highlighting (style: standard) with prefixed line numbers and a code-folding option. Alternatively, you can view or download the uninterpreted source code file here. For more information about "nir_lower_bool_to_bitsize.c", see the Fossies "Dox" file reference documentation.

    1 /*
    2  * Copyright © 2018 Intel Corporation
    3  *
    4  * Permission is hereby granted, free of charge, to any person obtaining a
    5  * copy of this software and associated documentation files (the "Software"),
    6  * to deal in the Software without restriction, including without limitation
    7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
    8  * and/or sell copies of the Software, and to permit persons to whom the
    9  * Software is furnished to do so, subject to the following conditions:
   10  *
   11  * The above copyright notice and this permission notice (including the next
   12  * paragraph) shall be included in all copies or substantial portions of the
   13  * Software.
   14  *
   15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
   16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
   17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
   18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
   19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
   20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
   21  * IN THE SOFTWARE.
   22  */
   23 
   24 #include "nir.h"
   25 #include "nir_builder.h"
   26 
   27 static bool
   28 assert_ssa_def_is_not_1bit(nir_ssa_def *def, UNUSED void *unused)
   29 {
   30    assert(def->bit_size > 1);
   31    return true;
   32 }
   33 
   34 static bool
   35 rewrite_1bit_ssa_def_to_32bit(nir_ssa_def *def, void *_progress)
   36 {
   37    bool *progress = _progress;
   38    if (def->bit_size == 1) {
   39       def->bit_size = 32;
   40       *progress = true;
   41    }
   42    return true;
   43 }
   44 
   45 static uint32_t
   46 get_bool_convert_opcode(uint32_t dst_bit_size)
   47 {
   48    switch (dst_bit_size) {
   49    case 32: return nir_op_i2i32;
   50    case 16: return nir_op_i2i16;
   51    case 8:  return nir_op_i2i8;
   52    default:
   53       unreachable("invalid boolean bit-size");
   54    }
   55 }
   56 
   57 static void
   58 make_sources_canonical(nir_builder *b, nir_alu_instr *alu, uint32_t start_idx)
   59 {
   60    /* TODO: for now we take the bit-size of the first source as the canonical
   61     * form but we could try to be smarter.
   62     */
   63    const nir_op_info *op_info = &nir_op_infos[alu->op];
   64    uint32_t bit_size = nir_src_bit_size(alu->src[start_idx].src);
   65    for (uint32_t i = start_idx + 1; i < op_info->num_inputs; i++) {
   66       if (nir_src_bit_size(alu->src[i].src) != bit_size) {
   67          b->cursor = nir_before_instr(&alu->instr);
   68          nir_op convert_op = get_bool_convert_opcode(bit_size);
   69          nir_ssa_def *new_src =
   70             nir_build_alu(b, convert_op, alu->src[i].src.ssa, NULL, NULL, NULL);
   71          /* Retain the write mask and swizzle of the original instruction so
   72           * that we don’t unnecessarily create a vectorized instruction.
   73           */
   74          nir_alu_instr *conv_instr =
   75             nir_instr_as_alu(nir_builder_last_instr(b));
   76          conv_instr->dest.write_mask = alu->dest.write_mask;
   77          conv_instr->dest.dest.ssa.num_components =
   78             alu->dest.dest.ssa.num_components;
   79          memcpy(conv_instr->src[0].swizzle,
   80                 alu->src[i].swizzle,
   81                 sizeof(conv_instr->src[0].swizzle));
   82          nir_instr_rewrite_src(&alu->instr,
   83                                &alu->src[i].src, nir_src_for_ssa(new_src));
   84          /* The swizzle will have been handled by the conversion instruction
   85           * so we can reset it back to the default
   86           */
   87          for (unsigned j = 0; j < NIR_MAX_VEC_COMPONENTS; j++)
   88             alu->src[i].swizzle[j] = j;
   89       }
   90    }
   91 }
   92 
   93 static bool
   94 lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
   95 {
   96    const nir_op_info *op_info = &nir_op_infos[alu->op];
   97 
   98    /* For operations that can take multiple boolean sources we need to ensure
   99     * that all booleans have the same bit-size
  100     */
  101    switch (alu->op) {
  102    case nir_op_mov:
  103    case nir_op_vec2:
  104    case nir_op_vec3:
  105    case nir_op_vec4:
  106    case nir_op_vec8:
  107    case nir_op_vec16:
  108    case nir_op_inot:
  109    case nir_op_iand:
  110    case nir_op_ior:
  111    case nir_op_ixor:
  112       if (nir_dest_bit_size(alu->dest.dest) > 1)
  113          break; /* Not a boolean instruction */
  114       /* Fallthrough */
  115 
  116    case nir_op_ball_fequal2:
  117    case nir_op_ball_fequal3:
  118    case nir_op_ball_fequal4:
  119    case nir_op_bany_fnequal2:
  120    case nir_op_bany_fnequal3:
  121    case nir_op_bany_fnequal4:
  122    case nir_op_ball_iequal2:
  123    case nir_op_ball_iequal3:
  124    case nir_op_ball_iequal4:
  125    case nir_op_bany_inequal2:
  126    case nir_op_bany_inequal3:
  127    case nir_op_bany_inequal4:
  128    case nir_op_ieq:
  129    case nir_op_ine:
  130       make_sources_canonical(b, alu, 0);
  131       break;
  132 
  133    case nir_op_bcsel:
  134       /* bcsel may be choosing between boolean sources too */
  135       if (nir_dest_bit_size(alu->dest.dest) == 1)
  136          make_sources_canonical(b, alu, 1);
  137       break;
  138 
  139    default:
  140       break;
  141    }
  142 
  143    /* Now that we have a canonical boolean bit-size, go on and rewrite the
  144     * instruction to match the canonical bit-size.
  145     */
  146    uint32_t bit_size = nir_src_bit_size(alu->src[0].src);
  147    assert(bit_size > 1);
  148 
  149    nir_op opcode = alu->op;
  150    switch (opcode) {
  151    case nir_op_mov:
  152    case nir_op_vec2:
  153    case nir_op_vec3:
  154    case nir_op_vec4:
  155    case nir_op_vec8:
  156    case nir_op_vec16:
  157    case nir_op_inot:
  158    case nir_op_iand:
  159    case nir_op_ior:
  160    case nir_op_ixor:
  161       /* Nothing to do here, we do not specialize these opcodes by bit-size */
  162       break;
  163 
  164    case nir_op_f2b1:
  165       opcode = bit_size == 8 ? nir_op_f2b8 :
  166                                bit_size == 16 ? nir_op_f2b16 : nir_op_f2b32;
  167       break;
  168 
  169    case nir_op_i2b1:
  170       opcode = bit_size == 8 ? nir_op_i2b8 :
  171                                bit_size == 16 ? nir_op_i2b16 : nir_op_i2b32;
  172       break;
  173 
  174    case nir_op_b2b1:
  175       /* Since the canonical bit size is the size of the src, it's a no-op */
  176       opcode = nir_op_mov;
  177       break;
  178 
  179    case nir_op_b2b32:
  180       /* For up-converting booleans, sign-extend */
  181       opcode = nir_op_i2i32;
  182       break;
  183 
  184    case nir_op_flt:
  185       opcode = bit_size == 8 ? nir_op_flt8 :
  186                                bit_size == 16 ? nir_op_flt16 : nir_op_flt32;
  187       break;
  188 
  189    case nir_op_fge:
  190       opcode = bit_size == 8 ? nir_op_fge8 :
  191                                bit_size == 16 ? nir_op_fge16 : nir_op_fge32;
  192       break;
  193 
  194    case nir_op_feq:
  195       opcode = bit_size == 8 ? nir_op_feq8 :
  196                                bit_size == 16 ? nir_op_feq16 : nir_op_feq32;
  197       break;
  198 
  199    case nir_op_fne:
  200       opcode = bit_size == 8 ? nir_op_fne8 :
  201                                bit_size == 16 ? nir_op_fne16 : nir_op_fne32;
  202       break;
  203 
  204    case nir_op_ilt:
  205       opcode = bit_size == 8 ? nir_op_ilt8 :
  206                                bit_size == 16 ? nir_op_ilt16 : nir_op_ilt32;
  207       break;
  208 
  209    case nir_op_ige:
  210       opcode = bit_size == 8 ? nir_op_ige8 :
  211                                bit_size == 16 ? nir_op_ige16 : nir_op_ige32;
  212       break;
  213 
  214    case nir_op_ieq:
  215       opcode = bit_size == 8 ? nir_op_ieq8 :
  216                                bit_size == 16 ? nir_op_ieq16 : nir_op_ieq32;
  217       break;
  218 
  219    case nir_op_ine:
  220       opcode = bit_size == 8 ? nir_op_ine8 :
  221                                bit_size == 16 ? nir_op_ine16 : nir_op_ine32;
  222       break;
  223 
  224    case nir_op_ult:
  225       opcode = bit_size == 8 ? nir_op_ult8 :
  226                                bit_size == 16 ? nir_op_ult16 : nir_op_ult32;
  227       break;
  228 
  229    case nir_op_uge:
  230       opcode = bit_size == 8 ? nir_op_uge8 :
  231                                bit_size == 16 ? nir_op_uge16 : nir_op_uge32;
  232       break;
  233 
  234    case nir_op_ball_fequal2:
  235       opcode = bit_size == 8 ? nir_op_b8all_fequal2 :
  236                                bit_size == 16 ? nir_op_b16all_fequal2 :
  237                                                 nir_op_b32all_fequal2;
  238       break;
  239 
  240    case nir_op_ball_fequal3:
  241       opcode = bit_size == 8 ? nir_op_b8all_fequal3 :
  242                                bit_size == 16 ? nir_op_b16all_fequal3 :
  243                                                 nir_op_b32all_fequal3;
  244       break;
  245 
  246    case nir_op_ball_fequal4:
  247       opcode = bit_size == 8 ? nir_op_b8all_fequal4 :
  248                                bit_size == 16 ? nir_op_b16all_fequal4 :
  249                                                 nir_op_b32all_fequal4;
  250       break;
  251 
  252    case nir_op_bany_fnequal2:
  253       opcode = bit_size == 8 ? nir_op_b8any_fnequal2 :
  254                                bit_size == 16 ? nir_op_b16any_fnequal2 :
  255                                                 nir_op_b32any_fnequal2;
  256       break;
  257 
  258    case nir_op_bany_fnequal3:
  259       opcode = bit_size == 8 ? nir_op_b8any_fnequal3 :
  260                                bit_size == 16 ? nir_op_b16any_fnequal3 :
  261                                                 nir_op_b32any_fnequal3;
  262       break;
  263 
  264    case nir_op_bany_fnequal4:
  265       opcode = bit_size == 8 ? nir_op_b8any_fnequal4 :
  266                                bit_size == 16 ? nir_op_b16any_fnequal4 :
  267                                                 nir_op_b32any_fnequal4;
  268       break;
  269 
  270    case nir_op_ball_iequal2:
  271       opcode = bit_size == 8 ? nir_op_b8all_iequal2 :
  272                                bit_size == 16 ? nir_op_b16all_iequal2 :
  273                                                 nir_op_b32all_iequal2;
  274       break;
  275 
  276    case nir_op_ball_iequal3:
  277       opcode = bit_size == 8 ? nir_op_b8all_iequal3 :
  278                                bit_size == 16 ? nir_op_b16all_iequal3 :
  279                                                 nir_op_b32all_iequal3;
  280       break;
  281 
  282    case nir_op_ball_iequal4:
  283       opcode = bit_size == 8 ? nir_op_b8all_iequal4 :
  284                                bit_size == 16 ? nir_op_b16all_iequal4 :
  285                                                 nir_op_b32all_iequal4;
  286       break;
  287 
  288    case nir_op_bany_inequal2:
  289       opcode = bit_size == 8 ? nir_op_b8any_inequal2 :
  290                                bit_size == 16 ? nir_op_b16any_inequal2 :
  291                                                 nir_op_b32any_inequal2;
  292       break;
  293 
  294    case nir_op_bany_inequal3:
  295       opcode = bit_size == 8 ? nir_op_b8any_inequal3 :
  296                                bit_size == 16 ? nir_op_b16any_inequal3 :
  297                                                 nir_op_b32any_inequal3;
  298       break;
  299 
  300    case nir_op_bany_inequal4:
  301       opcode = bit_size == 8 ? nir_op_b8any_inequal4 :
  302                                bit_size == 16 ? nir_op_b16any_inequal4 :
  303                                                 nir_op_b32any_inequal4;
  304       break;
  305 
  306    case nir_op_bcsel:
  307       opcode = bit_size == 8 ? nir_op_b8csel :
  308                                bit_size == 16 ? nir_op_b16csel : nir_op_b32csel;
  309 
  310       /* The destination of the selection may have a different bit-size from
  311        * the bcsel condition.
  312        */
  313       bit_size = nir_src_bit_size(alu->src[1].src);
  314       break;
  315 
  316    default:
  317       assert(alu->dest.dest.ssa.bit_size > 1);
  318       for (unsigned i = 0; i < op_info->num_inputs; i++)
  319          assert(alu->src[i].src.ssa->bit_size > 1);
  320       return false;
  321    }
  322 
  323    alu->op = opcode;
  324 
  325    if (alu->dest.dest.ssa.bit_size == 1)
  326       alu->dest.dest.ssa.bit_size = bit_size;
  327 
  328    return true;
  329 }
  330 
  331 static bool
  332 lower_load_const_instr(nir_load_const_instr *load)
  333 {
  334    bool progress = false;
  335 
  336    if (load->def.bit_size > 1)
  337       return progress;
  338 
  339    /* TODO: It is not clear if there is any case in which we can ever hit
  340     * this path, so for now we just provide a 32-bit default.
  341     *
  342     * TODO2: after some changed on nir_const_value and other on upstream, we
  343     * removed the initialization of a general value like this:
  344     *   nir_const_value value = load->value
  345     *
  346     * to initialize per value component. Need to confirm if that is correct,
  347     * but look at the TOO before.
  348     */
  349    for (unsigned i = 0; i < load->def.num_components; i++) {
  350       load->value[i].u32 = load->value[i].b ? NIR_TRUE : NIR_FALSE;
  351       load->def.bit_size = 32;
  352       progress = true;
  353    }
  354 
  355    return progress;
  356 }
  357 
  358 static bool
  359 lower_phi_instr(nir_builder *b, nir_phi_instr *phi)
  360 {
  361    if (nir_dest_bit_size(phi->dest) != 1)
  362       return false;
  363 
  364    /* Ensure all phi sources have a canonical bit-size. We choose the
  365     * bit-size of the first phi source as the canonical form.
  366     *
  367     * TODO: maybe we can be smarter about how we choose the canonical form.
  368     */
  369    uint32_t dst_bit_size = 0;
  370    nir_foreach_phi_src(phi_src, phi) {
  371       uint32_t src_bit_size = nir_src_bit_size(phi_src->src);
  372       if (dst_bit_size == 0) {
  373          dst_bit_size = src_bit_size;
  374       } else if (src_bit_size != dst_bit_size) {
  375          assert(phi_src->src.is_ssa);
  376          b->cursor = nir_before_src(&phi_src->src, false);
  377          nir_op convert_op = get_bool_convert_opcode(dst_bit_size);
  378          nir_ssa_def *new_src =
  379             nir_build_alu(b, convert_op, phi_src->src.ssa, NULL, NULL, NULL);
  380          nir_instr_rewrite_src(&phi->instr, &phi_src->src,
  381                                nir_src_for_ssa(new_src));
  382       }
  383    }
  384 
  385    phi->dest.ssa.bit_size = dst_bit_size;
  386 
  387    return true;
  388 }
  389 
  390 static bool
  391 nir_lower_bool_to_bitsize_impl(nir_builder *b, nir_function_impl *impl)
  392 {
  393    bool progress = false;
  394 
  395    nir_foreach_block(block, impl) {
  396       nir_foreach_instr_safe(instr, block) {
  397          switch (instr->type) {
  398          case nir_instr_type_alu:
  399             progress |= lower_alu_instr(b, nir_instr_as_alu(instr));
  400             break;
  401 
  402          case nir_instr_type_load_const:
  403             progress |= lower_load_const_instr(nir_instr_as_load_const(instr));
  404             break;
  405 
  406          case nir_instr_type_phi:
  407             progress |= lower_phi_instr(b, nir_instr_as_phi(instr));
  408             break;
  409 
  410          case nir_instr_type_ssa_undef:
  411          case nir_instr_type_intrinsic:
  412          case nir_instr_type_tex:
  413             nir_foreach_ssa_def(instr, rewrite_1bit_ssa_def_to_32bit,
  414                                 &progress);
  415             break;
  416 
  417          default:
  418             nir_foreach_ssa_def(instr, assert_ssa_def_is_not_1bit, NULL);
  419          }
  420       }
  421    }
  422 
  423    if (progress) {
  424       nir_metadata_preserve(impl, nir_metadata_block_index |
  425                                   nir_metadata_dominance);
  426    }
  427 
  428    return progress;
  429 }
  430 
  431 bool
  432 nir_lower_bool_to_bitsize(nir_shader *shader)
  433 {
  434    nir_builder b;
  435    bool progress = false;
  436 
  437    nir_foreach_function(function, shader) {
  438       if (function->impl) {
  439          nir_builder_init(&b, function->impl);
  440          progress = nir_lower_bool_to_bitsize_impl(&b, function->impl) || progress;
  441       }
  442    }
  443 
  444    return progress;
  445 }