"Fossies" - the Fresh Open Source Software Archive

Member "pytorch-1.8.2/tools/fast_nvcc/fast_nvcc.py" (23 Jul 2021, 15528 Bytes) of package /linux/misc/pytorch-1.8.2.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) Python source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively you can here view or download the uninterpreted source code file. For more information about "fast_nvcc.py" see the Fossies "Dox" file reference documentation and the last Fossies "Diffs" side-by-side code changes report: 1.12.1_vs_1.13.0.

#!/usr/bin/env python3

import argparse
import asyncio
import collections
import csv
import hashlib
import itertools
import os
import pathlib
import re
import shlex
import shutil
import subprocess
import sys
import time


help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]...

Run the commands given by nvcc --dryrun, in parallel.

All flags for this script itself (see the "optional arguments" section
of --help) must be passed before the first "--". Everything after that
first "--" is passed directly to nvcc, with the --dryrun argument added.

This script only works with the "normal" execution path of nvcc, so for
instance passing --help (after "--") doesn't work since the --help
execution path doesn't compile anything, so adding --dryrun there gives
nothing in stderr.
'''
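# Illustrative invocation (added comment; the file names and flags here are
# hypothetical, not from the original source):
#
#   tools/fast_nvcc/fast_nvcc.py --graph out.dot --table times.csv \
#       -- -o foo.o -c foo.cu
#
# Everything before the first "--" configures this wrapper; everything after
# it is handed to nvcc with the --dryrun argument added.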
parser = argparse.ArgumentParser(
    description=help_msg,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    '--faithful',
    action='store_true',
    help="don't modify the commands given by nvcc (slower)",
)
parser.add_argument(
    '--graph',
    metavar='FILE.dot',
    help='write Graphviz DOT file with execution graph',
)
parser.add_argument(
    '--nvcc',
    metavar='PATH',
    default='nvcc',
    help='path to nvcc (default is just "nvcc")',
)
parser.add_argument(
    '--save',
    metavar='DIR',
    help='copy intermediate files from each command into DIR',
)
parser.add_argument(
    '--sequential',
    action='store_true',
    help='sequence commands instead of using the graph (slower)',
)
parser.add_argument(
    '--table',
    metavar='FILE.csv',
    help='write CSV with times and intermediate file sizes',
)
parser.add_argument(
    '--verbose',
    metavar='FILE.txt',
    help='like nvcc --verbose, but expanded and into a file',
)
default_config = parser.parse_args([])


# docs about temporary directories used by NVCC
url_base = 'https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html'
url_vars = f'{url_base}#keeping-intermediate-phase-files'


# regex for temporary file names
re_tmp = r'(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)'
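# Illustration (added comment, not in the original): on a command string like
#   gcc "x.cpp" > "/tmp/tmpxft_0000_x.cpp4.ii"
# re.findall(re_tmp, ...) yields ['tmpxft_0000_x.cpp4.ii']; the optional
# /tmp/ prefix is matched but not captured, and the lookbehind keeps the
# pattern from firing in the middle of a longer path component.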


def fast_nvcc_warn(warning):
    """
    Warn the user about something regarding fast_nvcc.
    """
    print(f'warning (fast_nvcc): {warning}', file=sys.stderr)


def warn_if_windows():
    """
    Warn the user that using fast_nvcc on Windows might not work.
    """
    # use os.name instead of platform.system() because there is a
    # platform.py file in this directory, making it very difficult to
    # import the platform module from the Python standard library
    if os.name == 'nt':
        fast_nvcc_warn("untested on Windows, might not work; see this URL:")
        fast_nvcc_warn(url_vars)


def warn_if_tmpdir_flag(args):
    """
    Warn the user that using fast_nvcc with some flags might not work.
    """
    file_path_specs = 'file-and-path-specifications'
    guiding_driver = 'options-for-guiding-compiler-driver'
    scary_flags = {
        '--objdir-as-tempdir': file_path_specs,
        '-objtemp': file_path_specs,
        '--keep': guiding_driver,
        '-keep': guiding_driver,
        '--keep-dir': guiding_driver,
        '-keep-dir': guiding_driver,
        '--save-temps': guiding_driver,
        '-save-temps': guiding_driver,
    }
    for arg in args:
        for flag, frag in scary_flags.items():
            if re.match(fr'^{re.escape(flag)}(?:=.*)?$', arg):
                fast_nvcc_warn(f'{flag} not supported since it interacts with')
                fast_nvcc_warn('TMPDIR, so fast_nvcc may break; see this URL:')
                fast_nvcc_warn(f'{url_base}#{frag}')

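# Note (added, illustrative): the trailing (?:=.*)? in the pattern above means
# both the bare flag and its value-carrying form are caught, e.g. '--keep-dir'
# as well as '--keep-dir=/some/dir', while '--keep' does not falsely match
# '--keep-dir'.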

def nvcc_dryrun_data(binary, args):
    """
    Return parsed environment variables and commands from nvcc --dryrun.
    """
    result = subprocess.run(
        [binary, '--dryrun'] + args,
        capture_output=True,
        encoding='ascii',  # this is just a guess
    )
    print(result.stdout, end='')
    env = {}
    commands = []
    for line in result.stderr.splitlines():
        match = re.match(r'^#\$ (.*)$', line)
        if match:
            stripped, = match.groups()
            mapping = re.match(r'^(\w+)=(.*)$', stripped)
            if mapping:
                name, val = mapping.groups()
                env[name] = val
            else:
                commands.append(stripped)
        else:
            print(line, file=sys.stderr)
    return {'env': env, 'commands': commands, 'exit_code': result.returncode}

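# Illustration (added comment): nvcc --dryrun prefixes each line of its plan
# with "#$ ". A line such as
#   #$ PATH=/usr/local/cuda/bin:...
# is parsed into env, while a line such as
#   #$ gcc -E -x c++ ... "foo.cu" > "/tmp/tmpxft_..."
# is appended to commands; anything without the "#$ " prefix is passed through
# to stderr unchanged.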

def warn_if_tmpdir_set(env):
    """
    Warn the user that setting TMPDIR with fast_nvcc might not work.
    """
    if os.getenv('TMPDIR') or 'TMPDIR' in env:
        fast_nvcc_warn("TMPDIR is set, might not work; see this URL:")
        fast_nvcc_warn(url_vars)


def contains_non_executable(commands):
    """
    Return True iff some command is a non-executable nvcc directive.
    """
    for command in commands:
        # This deals with special lines in the NVCC dry-run output such as:
        # ```
        # #$ "/lib64/ccache"/c++ -std=c++11 -E -x c++ -D__CUDACC__ -D__NVCC__  -fPIC -fvisibility=hidden -O3 \
        #   -I ... -m64 "reduce_scatter.cu" > "/tmp/tmpxft_0037fae3_00000000-5_reduce_scatter.cpp4.ii
        # #$ -- Filter Dependencies -- > ... pytorch/build/nccl/obj/collectives/device/reduce_scatter.dep.tmp
        # ```
        if command.startswith("--"):
            return True
    return False


def module_id_contents(command):
    """
    Guess the contents of the .module_id file contained within command.
    """
    if command[0] == 'cicc':
        path = command[-3]
    elif command[0] == 'cudafe++':
        path = command[-1]
    middle = pathlib.PurePath(path).name.replace('-', '_').replace('.', '_')
    # this suffix is very wrong (the real one is far less likely to be
    # unique), but it seems difficult to find a rule that reproduces the
    # real suffixes, so here's one that, while inaccurate, is at least
    # hopefully as straightforward as possible
    suffix = hashlib.md5(str.encode(middle)).hexdigest()[:8]
    return f'_{len(middle)}_{middle}_{suffix}'

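# Illustration (added comment; the filename is hypothetical): for a cicc
# command whose third-to-last argument is "/tmp/tmpxft_0000-7_foo.cpp1.ii",
# middle becomes 'tmpxft_0000_7_foo_cpp1_ii' and the result is
# '_<len(middle)>_<middle>_' followed by the first 8 hex digits of the MD5
# of middle.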

def unique_module_id_files(commands):
    """
    Give each command its own .module_id filename instead of sharing.
    """
    module_id = None
    uniqueified = []
    for i, line in enumerate(commands):
        arr = []

        def uniqueify(s):
            filename = re.sub(r'\-(\d+)', r'-\1-' + str(i), s.group(0))
            arr.append(filename)
            return filename

        line = re.sub(re_tmp + r'.module_id', uniqueify, line)
        line = re.sub(r'\s*\-\-gen\_module\_id\_file\s*', ' ', line)
        if arr:
            filename, = arr
            if not module_id:
                module_id = module_id_contents(shlex.split(line))
            uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
        uniqueified.append(line)
    return uniqueified

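# Illustration (added comment, schematic): for a command mentioning a shared
# file like tmpxft_0000-7_bar.module_id and passing --gen_module_id_file, the
# rewritten output contains two lines: an
#   echo -n '<guessed module id>' > '<same name but with -7-<i>>'
# command, followed by the original command with the .module_id filename
# uniqueified in the same way and the --gen_module_id_file flag removed.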

def make_rm_force(commands):
    """
    Add --force to all rm commands.
    """
    return [f'{c} --force' if c.startswith('rm ') else c for c in commands]


def print_verbose_output(*, env, commands, filename):
    """
    Human-readably write nvcc --dryrun data to the given file.
    """
    padding = len(str(len(commands) - 1))
    with open(filename, 'w') as f:
        for name, val in env.items():
            print(f'#{" "*padding}$ {name}={val}', file=f)
        for i, command in enumerate(commands):
            prefix = f'{str(i).rjust(padding)}$ '
            print(f'#{prefix}{command[0]}', file=f)
            for part in command[1:]:
                print(f'#{" "*len(prefix)}{part}', file=f)


def straight_line_dependencies(commands):
    """
    Return a straight-line dependency graph.
    """
    return [({i - 1} if i > 0 else set()) for i in range(len(commands))]

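# Illustration (added comment): straight_line_dependencies(['a', 'b', 'c'])
# returns [set(), {0}, {1}], i.e. each command depends only on the one
# before it.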

def files_mentioned(command):
    """
    Return fully-qualified names of all tmp files referenced by command.
    """
    return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)]


def nvcc_data_dependencies(commands):
    """
    Return a list of the set of dependencies for each command.
    """
    # fatbin needs to be treated specially. The cicc steps do refer to
    # .fatbin.c files, but only through the --include_file_name option:
    # they generate files that refer to .fatbin.c file(s) that will later
    # be created by the fatbinary step. So for most files we make a data
    # dependency from the later step to the earlier step, but for
    # .fatbin.c files the data dependency is effectively flipped: the
    # steps that use the files generated by cicc need to wait for the
    # fatbinary step to finish first.
    tmp_files = {}
    fatbins = collections.defaultdict(set)
    graph = []
    for i, line in enumerate(commands):
        deps = set()
        for tmp in files_mentioned(line):
            if tmp in tmp_files:
                dep = tmp_files[tmp]
                deps.add(dep)
                if dep in fatbins:
                    for filename in fatbins[dep]:
                        if filename in tmp_files:
                            deps.add(tmp_files[filename])
            if tmp.endswith('.fatbin.c') and not line.startswith('fatbinary'):
                fatbins[i].add(tmp)
            else:
                tmp_files[tmp] = i
        if line.startswith('rm ') and not deps:
            deps.add(i - 1)
        graph.append(deps)
    return graph

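# Illustration (added comment; commands are schematic, not real nvcc output).
# Given roughly:
#   0: cicc ... "tmpxft_00_a.fatbin.c" ... -o "/tmp/tmpxft_00_a.cudafe1.cpp"
#   1: fatbinary ... --embedded-fatbin="/tmp/tmpxft_00_a.fatbin.c" ...
#   2: gcc ... "/tmp/tmpxft_00_a.cudafe1.cpp" -o "/tmp/tmpxft_00_a.o"
# the resulting graph is [set(), set(), {0, 1}]: step 2 waits both for the
# cicc step that wrote its input and for the fatbinary step that produces the
# .fatbin.c file which that input #includes.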

def is_weakly_connected(graph):
    """
    Return true iff graph is weakly connected.
    """
    if not graph:
        return True
    neighbors = [set() for _ in graph]
    for node, predecessors in enumerate(graph):
        for pred in predecessors:
            neighbors[pred].add(node)
            neighbors[node].add(pred)
    # assume nonempty graph
    stack = [0]
    found = {0}
    while stack:
        node = stack.pop()
        for neighbor in neighbors[node]:
            if neighbor not in found:
                found.add(neighbor)
                stack.append(neighbor)
    return len(found) == len(graph)

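# Illustration (added comment): is_weakly_connected([set(), {0}]) is True,
# while is_weakly_connected([set(), {0}, set()]) is False because node 2 has
# no edge (in either direction) to the component containing 0 and 1.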

def warn_if_not_weakly_connected(graph):
    """
    Warn the user if the execution graph is not weakly connected.
    """
    if not is_weakly_connected(graph):
        fast_nvcc_warn('execution graph is not (weakly) connected')


def print_dot_graph(*, commands, graph, filename):
    """
    Print a DOT file displaying short versions of the commands in graph.
    """
    def name(k):
        return f'"{k} {os.path.basename(commands[k][0])}"'
    with open(filename, 'w') as f:
        print('digraph {', file=f)
        # print all nodes, in case it's disconnected
        for i in range(len(graph)):
            print(f'    {name(i)};', file=f)
        for i, deps in enumerate(graph):
            for j in deps:
                print(f'    {name(j)} -> {name(i)};', file=f)
        print('}', file=f)


async def run_command(command, *, env, deps, gather_data, i, save):
    """
    Run the command with the given env after waiting for deps.
    """
    for task in deps:
        dep_result = await task
        # abort if a previous step failed
        if 'exit_code' not in dep_result or dep_result['exit_code'] != 0:
            return {}
    if gather_data:
        t1 = time.monotonic()
    proc = await asyncio.create_subprocess_shell(
        command,
        env=env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()
    code = proc.returncode
    results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr}
    if gather_data:
        t2 = time.monotonic()
        results['time'] = t2 - t1
        sizes = {}
        for tmp_file in files_mentioned(command):
            if os.path.exists(tmp_file):
                sizes[tmp_file] = os.path.getsize(tmp_file)
            else:
                sizes[tmp_file] = 0
        results['files'] = sizes
    if save:
        dest = pathlib.Path(save) / str(i)
        dest.mkdir()
        for src in map(pathlib.Path, files_mentioned(command)):
            if src.exists():
                shutil.copy2(src, dest / (src.name))
    return results


async def run_graph(*, env, commands, graph, gather_data=False, save=None):
    """
    Return outputs/errors (and optionally time/file info) from commands.
    """
    tasks = []
    for i, (command, indices) in enumerate(zip(commands, graph)):
        deps = {tasks[j] for j in indices}
        tasks.append(asyncio.create_task(run_command(
            command,
            env=env,
            deps=deps,
            gather_data=gather_data,
            i=i,
            save=save,
        )))
    return [await task for task in tasks]

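# Illustration (added comment): with graph [set(), {0}, {0}], the tasks for
# commands 1 and 2 both await the task for command 0 and then run
# concurrently; if command 0 exits nonzero, its dependents return {} and are
# skipped.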

def print_command_outputs(command_results):
    """
    Print captured stdout and stderr from commands.
    """
    for result in command_results:
        sys.stdout.write(result.get('stdout', b'').decode('ascii'))
        sys.stderr.write(result.get('stderr', b'').decode('ascii'))


def write_log_csv(command_parts, command_results, *, filename):
    """
    Write a CSV file of the times and /tmp file sizes from each command.
    """
    tmp_files = []
    for result in command_results:
        tmp_files.extend(result.get('files', {}).keys())
    with open(filename, 'w', newline='') as csvfile:
        fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files))
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for i, result in enumerate(command_results):
            command = f'{i} {os.path.basename(command_parts[i][0])}'
            row = {'command': command, 'seconds': result.get('time', 0)}
            writer.writerow({**row, **result.get('files', {})})


def exit_code(results):
    """
    Aggregate individual exit codes into a single code.
    """
    for result in results:
        code = result.get('exit_code', 0)
        if code != 0:
            return code
    return 0


def wrap_nvcc(args, config=default_config):
    """
    Fall back to running the real nvcc binary with the given args.
    """
    return subprocess.call([config.nvcc] + args)


def fast_nvcc(args, *, config=default_config):
    """
    Emulate the result of calling the given nvcc binary with args.

    Should run faster than plain nvcc.
    """
    warn_if_windows()
    warn_if_tmpdir_flag(args)
    dryrun_data = nvcc_dryrun_data(config.nvcc, args)
    env = dryrun_data['env']
    warn_if_tmpdir_set(env)
    commands = dryrun_data['commands']
    if not config.faithful:
        commands = make_rm_force(unique_module_id_files(commands))

    if contains_non_executable(commands):
        return wrap_nvcc(args, config)

    command_parts = list(map(shlex.split, commands))
    if config.verbose:
        print_verbose_output(
            env=env,
            commands=command_parts,
            filename=config.verbose,
        )
    graph = nvcc_data_dependencies(commands)
    warn_if_not_weakly_connected(graph)
    if config.graph:
        print_dot_graph(
            commands=command_parts,
            graph=graph,
            filename=config.graph,
        )
    if config.sequential:
        graph = straight_line_dependencies(commands)
    results = asyncio.run(run_graph(
        env=env,
        commands=commands,
        graph=graph,
        gather_data=bool(config.table),
        save=config.save,
    ))
    print_command_outputs(results)
    if config.table:
        write_log_csv(command_parts, results, filename=config.table)
    return exit_code([dryrun_data] + results)


def our_arg(arg):
    """
    Return True iff arg is not the "--" separator.
    """
    return arg != '--'


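# Illustration (added comment): given
#   fast_nvcc.py --table times.csv -- -c foo.cu
# (hypothetical arguments), us == ['--table', 'times.csv'] is parsed by this
# script's argparse parser, and them[1:] == ['-c', 'foo.cu'] is what gets
# passed to nvcc.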
if __name__ == '__main__':
    argv = sys.argv[1:]
    us = list(itertools.takewhile(our_arg, argv))
    them = list(itertools.dropwhile(our_arg, argv))
    sys.exit(fast_nvcc(them[1:], config=parser.parse_args(us)))