pytorch  1.8.2
About: PyTorch provides Tensor computation (like NumPy) with strong GPU acceleration and Deep Neural Networks (in Python) built on a tape-based autograd system. LTS (Long Term Support) release.
  Fossies Dox: pytorch-1.8.2.tar.gz  ("unofficial" and yet experimental doxygen-generated source code documentation)  

hp_emblookup_codegen.py
import argparse
import sys


sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
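
# This script writes a C++ source file (filename chosen below from the --fused
# and --use-offsets flags) implementing caffe2's EmbeddingLookup kernels with
# AVX2/FMA intrinsics, one specialization per row of the `options` table.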


def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
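    # Generates a fully unrolled inner loop for block_size == uf * 8: one
    # __m256 accumulator (vop0, vop8, ...) per group of 8 output floats.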
    def compute(regid, InType, use_weights, isa, prefetch):
        code = []

        if InType == "float":
            code.append(
                "        vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"  # noqa
                % (regid, regid, regid)
            )
        elif InType == "at::Half":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtph_ps(\n"
                "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "uint8_t":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
                "                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"  # noqa
                "            _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
            )
        else:
            assert False

        if prefetch:
            code.append(
                "        _mm_prefetch(\n"
                "            reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
                % (regid)
            )
        else:
            code.append(
                "        // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
            )

        return code

    code = []
    code.append("    // unrolling " + str(uf) + " times")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
    for i in range(0, uf):
        j = 8 * i
        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # inner loop
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

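    # uint8_t rows are dequantized on the fly: the per-row scale is folded into
    # wgt and the per-row bias (scaled by wgt) is broadcast into vbio. With the
    # fused layout the scale/bias pair sits at the end of each input row;
    # otherwise it is read from the separate scale_bias array.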
    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

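    # One fmadd per group of 8 floats; the next row (ip_next_T0) is prefetched
    # only for groups whose byte offset starts a new 64-byte cache line.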
    for i in range(0, uf):
        j = 8 * i
        cachelinesize = 64
        byteoffset = sizeof[InType] * j
        prefetch = (byteoffset % cachelinesize) == 0
        code.extend(compute(j, InType, use_weights, isa, prefetch))
    code.append("      }")

    if use_offsets:
        code.append("      if (!normalize_by_lengths || length == 0) {")
    else:
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(0, uf):
        j = 8 * i
        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("      } else {")
    # inv of length
    if use_offsets:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
    else:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
    for i in range(0, uf):
        j = 8 * i
        code.append(
            "        _mm256_storeu_ps(&op["
            + str(j)
            + "], _mm256_mul_ps("
            + "vop"
            + str(j)
            + ", vlen_inv));"
        )
    code.append("      }")

    code.append("    }")
    return code


def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
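    # Fallback for arbitrary block_size: rows are processed 8 floats at a time
    # with a scalar std::fma tail instead of being fully unrolled.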
    def compute(InType, use_weights, isa):
        code = []
        if InType == "float":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"  # noqa
            )
        elif InType == "at::Half":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtph_ps(_mm_loadu_si128(\n"
                "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "uint8_t":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"  # noqa
                "                      reinterpret_cast<const __m128i*>(&ip[j])))),\n"
                "                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
            )
        else:
            assert False

        code.append(
            "          _mm_prefetch(\n"
            "              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
        )

        return code

    code = []
    if InType == "at::Half":
        code.append("    alignas(64) at::Half vtmp1[8] = {0};")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )

    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")

    # initialize to 0
    code.append("      int64_t j = 0;")
    code.append("      for (; j + 8 <= block_size; j += 8) {")
    code.append("        _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("      }")
    code.append("      for (; j < block_size; j++) {")
    code.append("        op[j] = 0.0f;")
    code.append("      }")

    # inner loop
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")

    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )
    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

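    # Vectorized main loop over 8 floats at a time, then a scalar tail; for
    # at::Half the tail element is converted through vtmp1 with
    # _mm256_cvtph_ps, since only the vector conversion is used here.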
    # compute and store main loop
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("        }")
    # leftover
    code.append("        for (; j < block_size; j++) {")
    if InType == "float":
        code.append("          op[j] = std::fma(wgt, ip[j], op[j]);")
    elif InType == "at::Half":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 =\n"
            "              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "uint8_t":
        code.append("          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
    else:
        assert False

    code.append("        }")

    code.append("      }")

    if use_offsets:
        code.append("      if (normalize_by_lengths && length) {")
        code.append("        float len_inv = 1.0f / length;")
    else:
        code.append("      if (normalize_by_lengths && lengths[rangeIndex]) {")
        code.append("        float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("        __m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.append(
        "          _mm256_storeu_ps(\n"
        "              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
    )
    code.append("        }")
    code.append("        for (; j < block_size; j++) {")
    code.append("          op[j] = len_inv * op[j];")
    code.append("        }")

    code.append("      }")

    code.append("    }")
    return code


# start main code
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filename", help="file name")
parser.add_argument("--fused", action="store_true")
parser.add_argument("--use-offsets", action="store_true")
opts = parser.parse_args()
if opts.filename:
    filename = opts.filename
elif opts.fused:
    if opts.use_offsets:
        filename = "embedding_lookup_fused_8bit_rowwise_idx_avx2.cc"
    else:
        filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
else:
    if opts.use_offsets:
        filename = "embedding_lookup_idx_avx2.cc"
    else:
        filename = "embedding_lookup_avx2.cc"

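# Each row is [IndexTypeName, IndexType, InTypeName, InType, OutTypeName,
# OutType]; the *Name entries only appear in the generated function names.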
options = [
    ["int32_t", "int", "float", "float", "float", "float"],
    ["int64_t", "int64_t", "float", "float", "float", "float"],
    ["int32_t", "int", "half", "at::Half", "float", "float"],
    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
    ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]

code = []
# includes
code.append("//// --------------------------")
code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY {}".format(sys.argv[0]))
code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n")

code.append("#include <c10/util/Half.h>")
code.append("#include <immintrin.h>")

code.append("namespace caffe2 {\n")
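# For every option row, emit one templated static kernel plus two thin
# wrappers (..._false__avx2_fma / ..._true__avx2_fma) that pin the
# IS_WEIGHT_POSITIONAL template argument.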
for o in options:
    [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o

    prefix = "Fused8BitRowwise" if opts.fused else ""
    code.append("template <bool IS_WEIGHT_POSITIONAL>")
    if opts.use_offsets:
        fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    else:
        fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    suffix = "__avx2_fma"
    fn = "static bool " + fn_base + suffix
    code.append(fn + "(")

    args = []
    args.append("    const int64_t block_size,")
    args.append("    const int64_t output_size,")
    args.append("    const int64_t index_size,")
    args.append("    const int64_t data_size,")
    args.append("    const " + InType + "* input,")
    args.append("    const " + IndexType + "* indices,")
    if opts.use_offsets:
        args.append("    const " + IndexType + "* offsets,")
    else:
        args.append("    const int* lengths,")
    args.append("    const float* weights,")
    if not opts.fused:
        args.append("    const float* scale_bias,")
    args.append("    bool normalize_by_lengths,")
    args.append("    " + OutType + "* out) {")
    code += args

    code.append("  const " + IndexType + " prefdist_T0 = 16;")
    # block_size is the number of elements and fused_block_size is the size of
    # an entire row, including scale and bias.
    offset = (8 // sizeof[InType]) if opts.fused else 0
    code.append(
        "  const {} fused_block_size = block_size + {};".format(IndexType, offset)
    )
    if opts.use_offsets:
        code.append("  int64_t dataInd = 0;")
    else:
        code.append("  " + IndexType + " dataInd = 0;")

    # code.append("printf(\"calling " + fn + "\\n\");");

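    # Dispatch on block_size: 128/64/32/16 get the unrolled variants
    # (16/8/4/2 accumulators of 8 floats), everything else takes the generic
    # path.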
    code.append("  if (block_size == 128) {")
    code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 64) {")
    code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 32) {")
    code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 16) {")
    code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else {")
    code.append("    // generic code")
    code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  }")
    code.append("  return dataInd == index_size;")

    code.append("}")

    for is_weight_positional in ["false", "true"]:
        code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args
        # Resolve the Lint warnings: Limit of 80 characters in one line.
        extra_space = "\n      "
        ret_string = "  return " + fn_base + suffix + "<" + is_weight_positional + ">("
        if len(ret_string) <= 80:
            code.append(ret_string)
        else:
            code.append("  return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(")
        code.append("      block_size,")
        code.append("      output_size,")
        code.append("      index_size,")
        code.append("      data_size,")
        code.append("      input,")
        code.append("      indices,")
        if opts.use_offsets:
            code.append("      offsets,")
        else:
            code.append("      lengths,")
        code.append("      weights,")
        if not opts.fused:
            code.append("      scale_bias,")
        code.append("      normalize_by_lengths,")
        code.append("      out);")
        code.append("}")

    code.append("")

code.append("} // namespace caffe2")

with open(filename, "w") as fout:
    for c in code:
        # print(c, file = fout)
        fout.write(c + "\n")


print("Created " + filename)
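
# Usage: with no flags the script writes embedding_lookup_avx2.cc;
# --use-offsets selects embedding_lookup_idx_avx2.cc, --fused selects
# embedding_lookup_fused_8bit_rowwise_avx2.cc, and both together select
# embedding_lookup_fused_8bit_rowwise_idx_avx2.cc. -f/--filename overrides
# the output path, e.g. `python hp_emblookup_codegen.py --fused --use-offsets`.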