import argparse
import sys

sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}


def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    def compute(regid, InType, use_weights, isa, prefetch):
        code = []
        if InType == "float":
            code.append(
                "        vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"
                % (regid, regid, regid)
            )
        elif InType == "at::Half":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtph_ps(\n"
                "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "uint8_t":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
                "                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"
                "            _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
            )
        if prefetch:
            code.append(
                "        _mm_prefetch(\n"
                "            reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
                % (regid)
            )
        else:
            code.append("        // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid))
        return code
    code = []
    code.append("    // unrolling " + str(uf) + " times")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
    for i in range(0, uf):
        j = 8 * i
        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")

    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            """\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];"""
        )
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
100 " if (idx < 0 || idx >= data_size) {\n"
105 if InType ==
"uint8_t":
106 code.append(
" " + OutType +
" wgt = 1.f;")
107 code.append(
" " + OutType +
" bio;")
108 code.append(
" if (weights) {")
110 " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
115 " const float* scale_bias = reinterpret_cast<const float*>(\n"
116 " &input[idx * fused_block_size + block_size]);"
118 code.append(
" bio = wgt * scale_bias[1];")
119 code.append(
" wgt = wgt * scale_bias[0];")
121 code.append(
" bio = wgt * scale_bias[2 * idx + 1];")
122 code.append(
" wgt = wgt * scale_bias[2 * idx];")
123 code.append(
" __m256 vbio = _mm256_set1_ps(bio);")
125 code.append(
" " + OutType +
" wgt = 1.f;")
126 code.append(
" if (weights) {")
128 " wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
131 code.append(
" __m256 vwgt = _mm256_set1_ps(wgt);")
133 code.append(
" const {}* ip = &input[idx * fused_block_size];".
format(InType))
135 " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
136 " ? (dataInd + prefdist_T0)\n : dataInd;".
format(
140 code.append(
" const " + IndexType +
" idx_pref_T0 = indices[next_T0];")
142 " if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
148 " const {}* ip_next_T0 = "
149 "&input[idx_pref_T0 * fused_block_size];".
format(InType)
152 for i
in range(0, uf):
155 byteoffset = sizeof[InType] * j
156 prefetch = (byteoffset % cachelinesize) == 0
157 code.extend(compute(j, InType, use_weights, isa, prefetch))
    if use_offsets:
        code.append("      if (!normalize_by_lengths || length == 0) {")
    else:
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(0, uf):
        j = 8 * i
        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("      } else {")
    if use_offsets:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
    else:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
    for i in range(0, uf):
        j = 8 * i
        code.append(
            "        _mm256_storeu_ps(&op["
            + str(j)
            + "], _mm256_mul_ps("
            + "vop"
            + str(j)
            + ", vlen_inv));"
        )
    code.append("      }")
    code.append("    }")
    return code
def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    def compute(InType, use_weights, isa):
        code = []
        if InType == "float":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "at::Half":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtph_ps(_mm_loadu_si128(\n"
                "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "uint8_t":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"
                "                      reinterpret_cast<const __m128i*>(&ip[j])))),\n"
                "                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
            )
        code.append(
            "          _mm_prefetch(\n"
            "              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
        )
        return code
    code = []
    if InType == "at::Half":
        code.append("  alignas(64) at::Half vtmp1[8] = {0};")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")

    code.append("      int64_t j = 0;")
    code.append("      for (; j + 8 <= block_size; j += 8) {")
    code.append("        _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("      }")
    code.append("      for (; j < block_size; j++) {")
    code.append("        op[j] = 0.0f;")
    code.append("      }")

    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            """\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];"""
        )
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(IndexType)
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )
    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("        }")
    code.append("        for (; j < block_size; j++) {")
    if InType == "float":
        code.append("          op[j] = std::fma(wgt, ip[j], op[j]);")
    elif InType == "at::Half":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 =\n"
            "              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "uint8_t":
        code.append("          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
    code.append("        }")
    code.append("      }")

    if use_offsets:
        code.append("      if (normalize_by_lengths && length) {")
        code.append("        float len_inv = 1.0f / length;")
    else:
        code.append("      if (normalize_by_lengths && lengths[rangeIndex]) {")
        code.append("        float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("        __m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.append(
        "          _mm256_storeu_ps(\n"
        "              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
    )
    code.append("        }")
    code.append("        for (; j < block_size; j++) {")
    code.append("          op[j] = len_inv * op[j];")
    code.append("        }")
    code.append("      }")
    code.append("    }")
    return code


parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filename", help="file name")
parser.add_argument("--fused", action="store_true")
parser.add_argument("--use-offsets", action="store_true")
opts = parser.parse_args()

if opts.filename:
    filename = opts.filename
elif opts.fused:
    if opts.use_offsets:
        filename = "embedding_lookup_fused_8bit_rowwise_idx_avx2.cc"
    else:
        filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
else:
    if opts.use_offsets:
        filename = "embedding_lookup_idx_avx2.cc"
    else:
        filename = "embedding_lookup_avx2.cc"

options = [
    ["int32_t", "int", "float", "float", "float", "float"],
    ["int64_t", "int64_t", "float", "float", "float", "float"],
    ["int32_t", "int", "half", "at::Half", "float", "float"],
    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
    ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]

code = []
code.append("//// --------------------------")
code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY {}".format(sys.argv[0]))
code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n")

code.append("#include <c10/util/Half.h>")
code.append("#include <immintrin.h>")

code.append("namespace caffe2 {\n")
for o in options:
    [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o

    prefix = "Fused8BitRowwise" if opts.fused else ""
    code.append("template <bool IS_WEIGHT_POSITIONAL>")
    if opts.use_offsets:
        fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    else:
        fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    suffix = "__avx2_fma"
    fn = "static bool " + fn_base + suffix
    code.append(fn + "(")

    args = []
    args.append("    const int64_t block_size,")
    args.append("    const int64_t output_size,")
    args.append("    const int64_t index_size,")
    args.append("    const int64_t data_size,")
    args.append("    const " + InType + "* input,")
    args.append("    const " + IndexType + "* indices,")
    if opts.use_offsets:
        args.append("    const " + IndexType + "* offsets,")
    else:
        args.append("    const int* lengths,")
    args.append("    const float* weights,")
    if not opts.fused:
        args.append("    const float* scale_bias,")
    args.append("    bool normalize_by_lengths,")
    args.append("    " + OutType + "* out) {")
    code += args

    code.append("  const " + IndexType + " prefdist_T0 = 16;")
    offset = (8 // sizeof[InType]) if opts.fused else 0
    code.append(
        "  const {} fused_block_size = block_size + {};".format(IndexType, offset)
    )
    if opts.use_offsets:
        code.append("  int64_t dataInd = 0;")
    else:
        code.append("  " + IndexType + " dataInd = 0;")
    code.append("  if (block_size == 128) {")
    code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 64) {")
    code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 32) {")
    code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 16) {")
    code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else {")
    code.append("    // generic code")
    code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  }")
    code.append("  return dataInd == index_size;")
    code.append("}")

    for is_weight_positional in ["false", "true"]:
        code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args

        extra_space = "\n      "
        ret_string = "  return " + fn_base + suffix + "<" + is_weight_positional + ">("
        if len(ret_string) <= 80:
            code.append(ret_string)
        else:
            code.append(
                "  return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">("
            )

        code.append("      block_size,")
        code.append("      output_size,")
        code.append("      index_size,")
        code.append("      data_size,")
        code.append("      input,")
        code.append("      indices,")
        if opts.use_offsets:
            code.append("      offsets,")
        else:
            code.append("      lengths,")
        code.append("      weights,")
        if not opts.fused:
            code.append("      scale_bias,")
        code.append("      normalize_by_lengths,")
        code.append("      out);")
        code.append("}")

code.append("} // namespace caffe2")

with open(filename, "w") as fout:
    for c in code:
        fout.write(c + "\n")

print("Created " + filename)