a fork of shap-e for gc
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

136 lines
4.6 KiB

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
class AttrDict(OrderedDict):
    """
    An attribute dictionary that automatically handles nested keys joined by "/".

    Values can be read and written by item access (``d["a/b"]``) or by
    attribute access (``d.a.b``). Plain ``dict`` values, and ``dict`` elements
    of ``list`` values, are converted to ``AttrDict`` when they are assigned.
    Reading a key that is not present returns ``None`` rather than raising
    ``KeyError``.

    Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access
    """

    MARKER = object()

    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        """
        Build the dictionary from at most one positional mapping and/or
        keyword arguments.

        Every entry goes through ``__setitem__`` so that "/"-joined keys and
        nested dicts are normalized. Keyword arguments are applied after the
        positional mapping and therefore win on conflicts.
        """
        if args:
            assert len(args) == 1
            assert isinstance(args[0], (dict, AttrDict))
            for key, value in args[0].items():
                self[key] = value
        # Previously keyword arguments were silently dropped when a positional
        # mapping was also given; merge them like dict(mapping, **kwargs) does.
        for key, value in kwargs.items():
            self[key] = value

    def __contains__(self, key):
        """
        Return whether ``key`` is present; "/"-joined keys are resolved level
        by level through nested ``AttrDict`` values.
        """
        if "/" in key:
            head, _, rest = key.partition("/")
            if not super().__contains__(head):
                return False
            child = self[head]
            # Only descend into nested AttrDicts: running ``in`` on an int
            # would raise, and on a str/list would test something unrelated.
            return isinstance(child, AttrDict) and rest in child
        return super().__contains__(key)

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``, creating intermediate ``AttrDict``
        levels for "/"-joined keys and converting plain dicts (including
        dicts inside lists) to ``AttrDict``.
        """
        if "/" in key:
            head, _, rest = key.partition("/")
            if head not in self:
                self[head] = AttrDict()
            self[head].__setitem__(rest, value)
            return
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            # Pass the mapping positionally: expanding it with ** fails when
            # it contains a key named "self" (clashes with __init__'s
            # parameter).
            value = AttrDict(value)
        if isinstance(value, list):
            value = [AttrDict(val) if isinstance(val, dict) else val for val in value]
        super().__setitem__(key, value)

    def __getitem__(self, key):
        """
        Return the value stored under ``key``, or ``None`` if it is absent.

        :raises ValueError: if a "/"-joined key passes through a value that is
                            not an ``AttrDict`` (including a missing level).
        """
        if "/" in key:
            head, _, rest = key.partition("/")
            child = self[head]
            if not isinstance(child, AttrDict):
                raise ValueError(
                    f"cannot resolve {rest!r} below {head!r}: value is not an AttrDict"
                )
            return child.__getitem__(rest)
        return self.get(key, None)

    def all_keys(
        self,
        leaves_only: bool = False,
        parent: Optional[str] = None,
    ) -> List[str]:
        """
        List the "/"-joined paths of all keys, depth first.

        :param leaves_only: if true, omit paths whose value is a nested dict.
        :param parent: path prefix prepended to every returned key; used by
                       the recursion.
        """
        keys = []
        for key, value in self.items():
            cur = key if parent is None else f"{parent}/{key}"
            is_nested = isinstance(value, dict)
            if not leaves_only or not is_nested:
                keys.append(cur)
            if is_nested:
                keys.extend(value.all_keys(leaves_only=leaves_only, parent=cur))
        return keys

    def dumpable(self, strip=True):
        """
        Casts into a plain (nested) dict and handles internal attributes.

        :param strip: if true, keys starting with "_" are removed; otherwise
                      they are kept with their values replaced by ``repr``.
                      The setting applies to nested AttrDicts as well.
        """

        def _dump(val):
            if isinstance(val, AttrDict):
                # Propagate ``strip`` so nested levels follow the same policy.
                return val.dumpable(strip=strip)
            elif isinstance(val, list):
                return [_dump(v) for v in val]
            return val

        if strip:
            return {k: _dump(v) for k, v in self.items() if not k.startswith("_")}
        return {k: _dump(v if not k.startswith("_") else repr(v)) for k, v in self.items()}

    def map(
        self,
        map_fn: Callable[[Any, Any], Any],
        should_map: Optional[Callable[[Any, Any], bool]] = None,
    ) -> "AttrDict":
        """
        Creates a copy of self where some or all values are transformed by
        map_fn. Nested AttrDicts are mapped recursively.

        :param map_fn: called as map_fn(key, value) for each mapped value.
        :param should_map: If provided, only those values that evaluate to true
                           are converted; otherwise, all values are mapped.
        """

        def _apply(key, val):
            if isinstance(val, AttrDict):
                return val.map(map_fn, should_map)
            elif should_map is None or should_map(key, val):
                return map_fn(key, val)
            return val

        return AttrDict({k: _apply(k, v) for k, v in self.items()})

    def __eq__(self, other):
        """
        Compare by key set and per-key values, ignoring insertion order.

        Objects without a ``keys`` attribute are not comparable and yield
        ``NotImplemented`` (so ``AttrDict() == None`` is simply ``False``).
        """
        if not hasattr(other, "keys"):
            return NotImplemented
        return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys())

    def __ne__(self, other):
        """
        Negation of ``__eq__``.

        Defined explicitly because the ``__ne__`` inherited from
        ``OrderedDict`` does not go through the ``__eq__`` above, so ``!=``
        could disagree with ``==``.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def combine(
        self,
        other: Dict[str, Any],
        combine_fn: Callable[[Optional[Any], Optional[Any]], Any],
    ) -> "AttrDict":
        """
        Some values may be missing, but the dictionary structures must be the
        same.

        Keys of the result follow the order of self, followed by keys that
        only appear in other. A value missing on either side is passed to
        combine_fn as None.

        :param other: the dictionary to combine with; nested plain dicts are
                      converted to AttrDict where self has an AttrDict.
        :param combine_fn: a (possibly non-commutative) function to combine the
                           values
        """

        def _apply(val, other_val):
            if isinstance(val, AttrDict):
                if isinstance(other_val, dict) and not isinstance(other_val, AttrDict):
                    other_val = AttrDict(other_val)
                assert isinstance(other_val, AttrDict)
                return val.combine(other_val, combine_fn)
            return combine_fn(val, other_val)

        # dict.fromkeys gives an order-preserving union (a set union would
        # scramble the key order).
        keys = dict.fromkeys([*self.keys(), *other.keys()])
        # other may be a plain dict, so use .get() to tolerate missing keys.
        return AttrDict({k: _apply(self[k], other.get(k)) for k in keys})

    __setattr__, __getattr__ = __setitem__, __getitem__