You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.6 KiB
137 lines
4.6 KiB
2 years ago
|
from collections import OrderedDict
|
||
|
from typing import Any, Callable, Dict, List, Optional
|
||
|
|
||
|
|
||
|
class AttrDict(OrderedDict):
    """
    An attribute dictionary that automatically handles nested keys joined by "/".

    ``d["a/b"] = 1`` creates (if needed) a nested ``AttrDict`` under ``"a"`` and
    stores ``1`` under ``"b"``; afterwards ``d.a.b`` and ``d["a/b"]`` both read it.
    Plain ``dict`` values (and plain dicts that are direct elements of a list
    value) are converted to ``AttrDict`` on assignment. Reading a missing key,
    by item or by attribute, returns ``None`` instead of raising.

    Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access
    """

    MARKER = object()

    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        """
        Build from at most one positional ``dict`` plus keyword arguments.

        Every entry goes through ``__setitem__``, so "/"-joined keys are nested
        and dict values are converted. Keyword arguments are applied after the
        positional mapping (like ``dict(mapping, **kwargs)``) rather than being
        silently ignored when a mapping is also given.
        """
        assert len(args) <= 1, "AttrDict accepts at most one positional argument"
        if args:
            assert isinstance(args[0], dict), "positional argument must be a dict"
            for key, value in args[0].items():
                self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def __contains__(self, key):
        """
        Return True if ``key`` is present.

        A "/"-joined key is checked level by level; every intermediate value
        must itself be a dict for the nested key to count as present.
        """
        if isinstance(key, str) and "/" in key:
            head, _, rest = key.partition("/")
            if head not in self:
                return False
            child = self[head]
            # Without the dict check, membership would defer to whatever the
            # child is (e.g. a substring test when the child is a str).
            return isinstance(child, dict) and rest in child
        return super().__contains__(key)

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``.

        A "/"-joined key creates intermediate ``AttrDict`` levels as needed.
        Plain dicts are converted to ``AttrDict``, and so are plain dicts that
        are direct elements of a list value.
        """
        if isinstance(key, str) and "/" in key:
            head, _, rest = key.partition("/")
            if head not in self:
                self[head] = AttrDict()
            self[head].__setitem__(rest, value)
            return

        if isinstance(value, dict) and not isinstance(value, AttrDict):
            # Positional rather than ``**value`` so non-string keys work too.
            value = AttrDict(value)
        if isinstance(value, list):
            value = [AttrDict(val) if isinstance(val, dict) else val for val in value]
        super().__setitem__(key, value)

    def __getitem__(self, key):
        """
        Return the value stored under ``key``, or ``None`` if it is missing.

        A "/"-joined key is resolved level by level.

        :raises ValueError: if an intermediate level of a "/"-joined key is not
            an ``AttrDict`` (including when that level is missing).
        """
        if isinstance(key, str) and "/" in key:
            head, _, rest = key.partition("/")
            child = self[head]
            if not isinstance(child, AttrDict):
                raise ValueError(f"cannot resolve {key!r}: {head!r} is not an AttrDict")
            return child.__getitem__(rest)
        return self.get(key, None)

    def all_keys(
        self,
        leaves_only: bool = False,
        parent: Optional[str] = None,
    ) -> List[str]:
        """
        List every key in depth-first order as a "/"-joined path.

        :param leaves_only: if True, omit keys whose value is a nested dict.
        :param parent: path prefix used during recursion.
        """
        keys = []
        for key, value in self.items():
            cur = key if parent is None else f"{parent}/{key}"
            is_branch = isinstance(value, dict)
            if not leaves_only or not is_branch:
                keys.append(cur)
            if is_branch:
                keys.extend(value.all_keys(leaves_only=leaves_only, parent=cur))
        return keys

    def dumpable(self, strip=True):
        """
        Convert into nested plain ``dict`` objects; lists are converted element-wise.

        :param strip: if True, drop keys starting with "_"; if False, keep them
            but replace their values with ``repr(value)``. The same treatment is
            applied at every nesting level.
        """

        def _is_private(key):
            return isinstance(key, str) and key.startswith("_")

        def _dump(val):
            if isinstance(val, AttrDict):
                # Propagate ``strip`` so nested levels are handled consistently.
                return val.dumpable(strip)
            if isinstance(val, list):
                return [_dump(v) for v in val]
            return val

        if strip:
            return {k: _dump(v) for k, v in self.items() if not _is_private(k)}
        return {k: _dump(repr(v) if _is_private(k) else v) for k, v in self.items()}

    def map(
        self,
        map_fn: Callable[[Any, Any], Any],
        should_map: Optional[Callable[[Any, Any], bool]] = None,
    ) -> "AttrDict":
        """
        Creates a copy of self where some or all values are transformed by
        map_fn.

        Nested ``AttrDict`` values are recursed into, not passed to ``map_fn``.

        :param map_fn: called as ``map_fn(key, value)`` for each leaf value.
        :param should_map: If provided, only those values that evaluate to true
            are converted; otherwise, all values are mapped.
        """

        def _apply(key, val):
            if isinstance(val, AttrDict):
                return val.map(map_fn, should_map)
            if should_map is None or should_map(key, val):
                return map_fn(key, val)
            return val

        return AttrDict({k: _apply(k, v) for k, v in self.items()})

    def __eq__(self, other):
        """
        Compare key sets and values, ignoring key order.

        Returns ``NotImplemented`` for objects without ``keys()`` so Python
        falls back to its default comparison instead of raising.
        """
        try:
            other_keys = other.keys()
        except AttributeError:
            return NotImplemented
        return self.keys() == other_keys and all(self[k] == other[k] for k in self.keys())

    def __ne__(self, other):
        """
        Negate ``__eq__``.

        Without this, the inherited ``OrderedDict.__ne__`` is order-sensitive,
        so two instances could be both ``==`` and ``!=``.
        """
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def combine(
        self,
        other: Dict[str, Any],
        combine_fn: Callable[[Optional[Any], Optional[Any]], Any],
    ) -> "AttrDict":
        """
        Some values may be missing, but the dictionary structures must be the
        same.

        Keys are taken from ``self`` in order, followed by keys found only in
        ``other`` (in ``other``'s order). A value missing on either side is
        passed to ``combine_fn`` as ``None``.

        :param other: a dict (plain or ``AttrDict``) with the same structure.
        :param combine_fn: a (possibly non-commutative) function to combine the
            values
        """

        def _apply(val, other_val):
            if isinstance(val, AttrDict):
                # Structures must match: a nested level pairs with a nested dict.
                assert isinstance(other_val, dict)
                return val.combine(other_val, combine_fn)
            return combine_fn(val, other_val)

        own_keys = list(self.keys())
        seen = set(own_keys)
        keys = own_keys + [k for k in other.keys() if k not in seen]
        # ``other.get`` so a plain-dict ``other`` with missing keys does not raise.
        return AttrDict({k: _apply(self[k], other.get(k)) for k in keys})

    __setattr__, __getattr__ = __setitem__, __getitem__
|