1from typing
import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional
Contains utility functions for working with nested Python data structures.
A *pytree* is a nested Python data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.
pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_flatten` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
This pytree implementation is not very performant due to Python overhead.
To improve the performance, we can move parts of the implementation to C++.
29FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
30UnflattenFunc = Callable[[List, Context], PyTree]
33 flatten_fn: FlattenFunc
34 unflatten_fn: UnflattenFunc
36SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
39 SUPPORTED_NODES[typ] =
NodeDef(flatten_fn, unflatten_fn)
42 return list(d.values()),
list(d.keys())
45 return {key: value
for key, value
in zip(context, values)}
66 return type(pytree)
not in SUPPORTED_NODES.keys()
75 def __init__(self, typ: Any, context: Context, children_specs: List[
'TreeSpec']) ->
None:
79 self.
num_leavesnum_leaves: int =
sum([spec.num_leaves
for spec
in children_specs])
82 return f
'TreeSpec({self.type.__name__}, {self.context}, {self.children_specs})'
84 def __eq__(self, other: Any) -> bool:
85 result = self.
typetype == other.type
and self.
contextcontext == other.context \
89 return cast(bool, result)
91 def __ne__(self, other: Any) -> bool:
92 return not self.
__eq____eq__(other)
105 """Flattens a pytree into a list of values and a TreeSpec that can be used
106 to reconstruct the pytree.
111 flatten_fn = SUPPORTED_NODES[
type(pytree)].flatten_fn
112 child_pytrees, context = flatten_fn(pytree)
115 result : List[Any] = []
116 children_specs : List[
'TreeSpec'] = []
117 for child
in child_pytrees:
120 children_specs.append(child_spec)
122 return result,
TreeSpec(
type(pytree), context, children_specs)
126 """Given a list of values and a TreeSpec, builds a pytree.
127 This is the inverse operation of `tree_flatten`.
131 f
'tree_unflatten(values, spec): Expected `spec` to be instance of '
132 f
'TreeSpec but got item of type {type(spec)}.')
133 if len(values) != spec.num_leaves:
135 f
'tree_unflatten(values, spec): `values` has length {len(values)} '
136 f
'but the spec refers to a pytree that holds {spec.num_leaves} '
141 unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn
147 for child_spec
in spec.children_specs:
148 end += child_spec.num_leaves
149 child_pytrees.append(
tree_unflatten(values[start:end], child_spec))
152 return unflatten_fn(child_pytrees, spec.context)
167 return [pytree] * spec.num_leaves
170 if type(pytree) != spec.type:
173 flatten_fn = SUPPORTED_NODES[
type(pytree)].flatten_fn
174 child_pytrees, ctx = flatten_fn(pytree)
177 if len(child_pytrees) !=
len(spec.children_specs)
or ctx != spec.context:
181 result : List[Any] = []
182 for child, child_spec
in zip(child_pytrees, spec.children_specs):
None __init__(self, Any typ, Context context, List['TreeSpec'] children_specs)
bool __eq__(self, Any other)
bool __ne__(self, Any other)
T & cast(const Tensor &packed)
constexpr Symbol len(static_cast< unique_t >(_keys::aten_len))
constexpr Symbol zip(static_cast< unique_t >(_keys::prim_zip))
constexpr Symbol list(static_cast< unique_t >(_keys::prim_list))
constexpr Symbol isinstance(static_cast< unique_t >(_keys::prim_isinstance))
def NamedTuple(name_prefix, *fields)
computes the sum of all elements per channel and the sum of all elements squared per channel These values can be reduced across multiple batches and used to obtain the mean and variance across the full set of batches Using the new mean and variance as input to SpatialBN has the effect of changing the batch size over which SpatialBN is applied DOC sum
List[Any] _list_unflatten(List[Any] values, Context context)
Tuple[Any,...] _tuple_unflatten(List[Any] values, Context context)
Tuple[List[Any], TreeSpec] tree_flatten(PyTree pytree)
Optional[List[Any]] _broadcast_to_and_flatten(PyTree pytree, TreeSpec spec)
bool _is_leaf(PyTree pytree)
Tuple[List[Any], Context] _list_flatten(List[Any] d)
Tuple[List[Any], Context] _dict_flatten(Dict[Any, Any] d)
None _register_pytree_node(Any typ, FlattenFunc flatten_fn, UnflattenFunc unflatten_fn)
Dict[Any, Any] _dict_unflatten(List[Any] values, Context context)
PyTree tree_unflatten(List[Any] values, TreeSpec spec)
Tuple[List[Any], Context] _tuple_flatten(Tuple[Any,...] d)