pytorch 1.8.2
About: PyTorch provides Tensor computation (like NumPy) with strong GPU acceleration and Deep Neural Networks (in Python) built on a tape-based autograd system. LTS (Long Term Support) release.

_pytree.py
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional

"""
Contains utility functions for working with nested Python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_flatten` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""

# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
UnflattenFunc = Callable[[List, Context], PyTree]

class NodeDef(NamedTuple):
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc

SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}

def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
    SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn)

def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())

def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    return {key: value for key, value in zip(context, values)}

def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    return d, None

def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
    return list(values)

def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    return list(d), None

def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
    return tuple(values)

_register_pytree_node(dict, _dict_flatten, _dict_unflatten)
_register_pytree_node(list, _list_flatten, _list_unflatten)
_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
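
# Illustrative sketch: a custom container type could be registered the same
# way, by supplying its own flatten/unflatten pair. `Point` here is a
# hypothetical type used only for the example, not something this module
# defines.
#
#     from collections import namedtuple
#     Point = namedtuple('Point', ['x', 'y'])
#     _register_pytree_node(
#         Point,
#         lambda p: ([p.x, p.y], None),
#         lambda values, context: Point(*values))
#     # tree_flatten(Point(1, 2)) now returns ([1, 2], TreeSpec(Point, None, [*, *]))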


# A leaf is defined as anything that is not a Node.
def _is_leaf(pytree: PyTree) -> bool:
    return type(pytree) not in SUPPORTED_NODES.keys()


# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_leaves: the number of leaves
class TreeSpec:
    def __init__(self, typ: Any, context: Context, children_specs: List['TreeSpec']) -> None:
        self.type = typ
        self.context = context
        self.children_specs = children_specs
        self.num_leaves: int = sum([spec.num_leaves for spec in children_specs])

    def __repr__(self) -> str:
        return f'TreeSpec({self.type.__name__}, {self.context}, {self.children_specs})'

    def __eq__(self, other: Any) -> bool:
        result = self.type == other.type and self.context == other.context \
            and self.children_specs == other.children_specs \
            and self.num_leaves == other.num_leaves
        # This should really not be necessary, but mypy errors out without it.
        return cast(bool, result)

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)


class LeafSpec(TreeSpec):
    def __init__(self) -> None:
        super().__init__(None, None, [])
        self.num_leaves = 1

    def __repr__(self) -> str:
        return '*'
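
# Illustrative sketch of the spec structure: a two-element list is described
# by a TreeSpec whose children are two LeafSpecs, printed via the __repr__
# methods above.
#
#     spec = TreeSpec(list, None, [LeafSpec(), LeafSpec()])
#     # repr(spec) == 'TreeSpec(list, None, [*, *])'
#     # spec.num_leaves == 2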

def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    flatten_fn = SUPPORTED_NODES[type(pytree)].flatten_fn
    child_pytrees, context = flatten_fn(pytree)

    # Recursively flatten the children
    result : List[Any] = []
    children_specs : List['TreeSpec'] = []
    for child in child_pytrees:
        flat, child_spec = tree_flatten(child)
        result += flat
        children_specs.append(child_spec)

    return result, TreeSpec(type(pytree), context, children_specs)
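
# Illustrative sketch of the leaf case: anything that is not a registered
# container type is returned as a single leaf with a LeafSpec.
#
#     values, spec = tree_flatten(0)
#     # values == [0], repr(spec) == '*'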

def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.
    """
    if not isinstance(spec, TreeSpec):
        raise ValueError(
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')
    if len(values) != spec.num_leaves:
        raise ValueError(
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')
    if isinstance(spec, LeafSpec):
        return values[0]

    unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn

    # Recursively unflatten the children
    start = 0
    end = 0
    child_pytrees = []
    for child_spec in spec.children_specs:
        end += child_spec.num_leaves
        child_pytrees.append(tree_unflatten(values[start:end], child_spec))
        start = end

    return unflatten_fn(child_pytrees, spec.context)

# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# this returns [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves
    if isinstance(spec, LeafSpec):
        return None
    if type(pytree) != spec.type:
        return None

    flatten_fn = SUPPORTED_NODES[type(pytree)].flatten_fn
    child_pytrees, ctx = flatten_fn(pytree)

    # Check if the Node is different from the spec
    if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
        return None

    # Recursively flatten the children
    result : List[Any] = []
    for child, child_spec in zip(child_pytrees, spec.children_specs):
        flat = _broadcast_to_and_flatten(child, child_spec)
        if flat is not None:
            result += flat
        else:
            return None

    return result
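
# Illustrative sketch of broadcasting (mirroring the vmap in_dims use case
# described above): a single leaf broadcasts against every leaf of the spec,
# while a structural mismatch yields None.
#
#     _, spec = tree_flatten([[1, 2], 3])
#     _broadcast_to_and_flatten(0, spec)           # [0, 0, 0]
#     _broadcast_to_and_flatten([0, 1], spec)      # [0, 0, 1]
#     _broadcast_to_and_flatten([0, 1, 2], spec)   # None (wrong number of children)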