#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import inspect
import json
from copy import deepcopy
from functools import wraps
from trixi.util.util import ModuleMultiTypeDecoder, ModuleMultiTypeEncoder
class Config(dict):
    """
    Config is the main object used to store configurations. As a rule of thumb, anything you might
    want to change in your experiment should go into the Config. It's basically a :class:`dict`,
    but vastly more powerful. Key features are

    - Access keys as attributes
        ``Config["a"]["b"]["c"]`` is the same as ``Config.a.b.c``.
        Can also be used for setting if the second to last key exists. Only works for keys that
        conform with Python syntax (``Config.myattr-1`` is not allowed).

    - Advanced de-/serialization
        Using specialized JSON encoders and decoders, almost anything can be serialized and
        deserialized. This includes types, functions (except lambdas) and modules. For example,
        you could have something like::

            c = Config(model=MyModel)
            c.dump("somewhere")

        and end up with a JSON file that looks like this::

            {
                "model": "__type__(your.model.module.MyModel)"
            }

        and vice versa. We use double underscores and parentheses for serialization,
        so it's probably a good idea to not use this pattern for other stuff!

    - Automatic CLI exposure
        If desired, the Config will create an ArgumentParser that contains all keys in the
        Config as arguments in the form ``--key``, so you can run your experiment from the command
        line and manually overwrite certain values. Deeper levels are also accessible via dot
        notation ``--key_with_dict_value.inner_key``.

    - Comparison
        Compare any number of Configs and get a new Config containing only the values that
        differ among input Configs.

    Args:
        file_ (str): Load Config from this file.
        config (Config): Update with values from this Config (can be combined with :attr:`file_`).
            Will by default only make shallow copies, see :attr:`deep`.
        update_from_argv (bool): Update values from argv. Will automatically expose keys to the
            CLI as ``--key``.
        deep (bool): Make deep copies if :attr:`config` is given.
        **kwargs: Additional key/value pairs, applied after :attr:`config`.
    """

    def __init__(self, file_=None, config=None, update_from_argv=False, deep=False, **kwargs):
        super(Config, self).__init__()

        # Alias the attribute namespace to the mapping itself, so keys can be
        # read and written as attributes whenever the key is a valid identifier:
        #   config["a"] = 1   -> config.a == 1
        #   config["a-b"] = 2 -> config.a-b (not possible)
        # This is purely for convenience.
        self.__dict__ = self

        if file_ is not None:
            self.load(file_)

        if config is not None:
            if deep:
                # Convert to Config first (the input may be a plain dict),
                # then take a deep copy of the whole tree.
                config = Config(config=config, update_from_argv=False, deep=False)
                config = config.deepcopy()
            self.update(config, deep=False)

        if kwargs:
            if deep:
                kwargs = Config(config=kwargs, update_from_argv=False, deep=False)
                kwargs = kwargs.deepcopy()
            self.update(kwargs, deep=False)

        if update_from_argv:
            update_from_sys_argv(self)

    def update(self, dict_like, deep=False, ignore=None, allow_dict_overwrite=True):
        """Update entries in the Config.

        Args:
            dict_like (dict or derivative thereof): Update source.
            deep (bool): Make deep copies of all references in the source.
            ignore (iterable): Iterable of keys to ignore in update. Entries
                may use dot notation (``"a.b"``) to ignore keys at deeper levels.
            allow_dict_overwrite (bool): Allow overwriting with dict.
                Regular dicts only update on the highest level while we recurse
                and merge Configs. This flag decides whether it is possible to
                overwrite a 'regular' value with a dict/Config at lower levels.
                See examples for an illustration of the difference.

        Raises:
            AttributeError: If an existing non-dict value would have to be
                merged with a dict and ``allow_dict_overwrite`` is False.

        Examples:
            The following illustrates the update behaviour if
            :obj:`allow_dict_overwrite` is active. If it isn't, an AttributeError
            would be raised, originating from trying to update "string"::

                config1 = Config(config={
                    "lvl0": {
                        "lvl1": "string",
                        "something": "else"
                    }
                })

                config2 = Config(config={
                    "lvl0": {
                        "lvl1": {
                            "lvl2": "string"
                        }
                    }
                })

                config1.update(config2, allow_dict_overwrite=True)

                >>> config1
                {
                    "lvl0": {
                        "lvl1": {
                            "lvl2": "string"
                        },
                        "something": "else"
                    }
                }
        """
        if ignore is None:
            ignore = ()

        if deep:
            # Build a deep-copied Config from the source once, then merge it
            # with the shallow code path below.
            update_config = Config(config=dict_like, deep=True)
            self.update(update_config, deep=False, ignore=ignore, allow_dict_overwrite=allow_dict_overwrite)
            return

        for key, value in dict_like.items():
            if key in ignore:
                continue
            if key in self and isinstance(value, dict):
                try:
                    # Strip the leading "<key>." from dotted ignore entries so
                    # they apply relative to the nested level.
                    new_ignore = []
                    for ignore_key in ignore:
                        ignore_key = ignore_key.split(".")
                        if ignore_key[0] == key:
                            new_ignore.append(".".join(ignore_key[1:]))
                    self[key].update(
                        value, deep=False, ignore=new_ignore, allow_dict_overwrite=allow_dict_overwrite
                    )
                except AttributeError:
                    # The existing value cannot be merged (e.g. it is a string
                    # without an ``update`` method).
                    if allow_dict_overwrite:
                        self[key] = value
                    else:
                        raise
            else:
                self[key] = value

    def deepupdate(self, dict_like, ignore=None, allow_dict_overwrite=True):
        """Identical to :meth:`update` with `deep=True`.

        Args:
            dict_like (dict or derivative thereof): Update source.
            ignore (iterable): Iterable of keys to ignore in update.
            allow_dict_overwrite (bool): Allow overwriting with dict.
                Regular dicts only update on the highest level while we recurse
                and merge Configs. This flag decides whether it is possible to
                overwrite a 'regular' value with a dict/Config at lower levels.
                See :meth:`update` for an illustration of the difference.
        """
        self.update(dict_like, deep=True, ignore=ignore, allow_dict_overwrite=allow_dict_overwrite)

    def __setattr__(self, key, value):
        """Modified to automatically convert `dict` to Config."""
        # Exact type check on purpose: only plain dicts are converted, dict
        # subclasses (including Config itself) are stored unchanged.
        if type(value) is dict:
            new_config = Config()
            new_config.update(value, deep=False)
            super(Config, self).__setattr__(key, new_config)
        else:
            super(Config, self).__setattr__(key, value)

    def __getitem__(self, key):
        """Allows convenience access to deeper levels using dots to separate
        levels, for example `config["a.b.c"]`.

        An empty string resolves to the only key of a single-element Config.

        Raises:
            KeyError: If the key is missing, or if the empty string is used on
                a Config that does not hold exactly one element.
        """
        if key == "":
            if len(self.keys()) == 1:
                key = list(self.keys())[0]
            else:
                raise KeyError("Empty string only works for single element Configs.")

        if type(key) is str and "." in key:
            superkey, _, subkeys = key.partition(".")
            if superkey not in self:
                # this part enables ints in the access chain, e.g. a.1.b
                try:
                    intkey = int(superkey)
                    if intkey in self:
                        superkey = intkey
                except ValueError:
                    # if we can't convert to int, just continue so a KeyError will be raised
                    pass
            if type(self[superkey]) in (list, tuple):
                # Sequences are indexed by position when the remainder is an int.
                try:
                    subkeys = int(subkeys)
                except ValueError:
                    pass
            return self[superkey][subkeys]

        return super(Config, self).__getitem__(key)

    def __setitem__(self, key, value):
        """Allows convenience access to deeper levels using dots to separate
        levels, for example `config["a.b.c"]`.

        Missing intermediate levels are created as empty Configs and plain
        dict values are converted to Config.

        Raises:
            KeyError: If the empty string is used on a Config that does not
                hold exactly one element.
        """
        if key == "":
            if len(self.keys()) == 1:
                key = list(self.keys())[0]
            else:
                raise KeyError("Empty string only works for single element Configs.")

        if type(key) is str and "." in key:
            superkey, _, subkeys = key.partition(".")
            if superkey != "" and superkey not in self:
                self[superkey] = Config()
            if type(self[superkey]) is list:
                # Lists are indexed by position when the remainder is an int.
                try:
                    subkeys = int(subkeys)
                except ValueError:
                    pass
            self[superkey][subkeys] = value
        elif type(value) is dict:
            super(Config, self).__setitem__(key, Config(config=value))
        else:
            super(Config, self).__setitem__(key, value)

    def __delitem__(self, key):
        """Allows convenience access to deeper levels using dots to separate
        levels, for example `config["a.b.c"]`.

        Raises:
            KeyError: If the key (or its first level) does not exist.
        """
        if type(key) is str and "." in key and key not in self:
            superkey, _, subkeys = key.partition(".")
            if superkey not in self:
                raise KeyError(superkey + " not found.")
            self[superkey].__delitem__(subkeys)
        else:
            super().__delitem__(key)

    def set_with_decode(self, key, value, stringify_value=False):
        """Set single value, using :class:`.ModuleMultiTypeDecoder` to interpret
        key and value strings by creating a temporary JSON string.

        Args:
            key (str): Config key. Dots create nested levels.
            value (str): New value key will map to.
            stringify_value (bool): If `True`, will insert the value into the
                temporary JSON as a real string. See examples!

        Raises:
            TypeError: If `key` or `value` is not a string.

        Examples:
            Example for when you need to set `stringify_value=True`::

                config.set_with_decode("key", "__type__(trixi.util.config.Config)", stringify_value=True)

            Example for when you need to set `stringify_value=False`::

                config.set_with_decode("key", "[1, 2, 3]")
        """
        if type(key) is not str:
            # We could encode the key if it's not a string, but for now raise
            raise TypeError("set_with_decode requires string as key.")
        if type(value) is not str:
            raise TypeError("set_with_decode requires string as value.")

        # Build '{"a":{"b":<value>}}' for key "a.b".
        key_split = key.split(".")
        opening = "".join('{{"{}":'.format(k) for k in key_split)
        payload = '"{}"'.format(value) if stringify_value else "{}".format(value)
        closing = "}" * len(key_split)

        self.loads(opening + payload + closing)

    def set_from_string(self, str_, stringify_value=False):
        """Set a value from a single string, separated with "=".
        Uses :meth:`set_with_decode`.

        Only the first "=" separates key and value, so the value itself may
        contain further "=" characters.

        Args:
            str_ (str): String that looks like "key=value".
            stringify_value (bool): Passed on to :meth:`set_with_decode`.

        Raises:
            ValueError: If `str_` contains no "=".
        """
        key, value = str_.split("=", 1)
        self.set_with_decode(key, value, stringify_value)

    def update_missing(self, dict_like, deep=False, ignore=None):
        """Recursively insert values that do not yet exist.

        Args:
            dict_like (dict or derivative thereof): Update source.
            deep (bool): Make deep copies of all references in the source.
            ignore (iterable): Iterable of keys to ignore in update. Entries
                may use dot notation (``"a.b"``) to ignore keys at deeper levels.
        """
        if ignore is None:
            ignore = ()

        for key, value in dict_like.items():
            if key in ignore:
                continue
            if key not in self:
                if not deep:
                    self[key] = value
                elif type(value) is Config:
                    self[key] = value.deepcopy()
                else:
                    self[key] = deepcopy(value)
            elif isinstance(value, dict) and isinstance(self[key], dict):
                # Make dotted ignore entries relative to the nested level.
                sub_ignore = []
                for entry in ignore:
                    if not isinstance(entry, str):
                        continue
                    head, sep, tail = entry.partition(".")
                    if sep and head == key:
                        sub_ignore.append(tail)
                self[key].update_missing(Config(config=value, deep=deep), ignore=sub_ignore)

    def dump(self, file_, indent=4, separators=(",", ": "), **kwargs):
        """Write config to file using :meth:`json.dump`.

        Args:
            file_ (str or File): Write to this location. Anything with a
                ``write`` attribute is used as an open file object, everything
                else is treated as a path.
            indent (int): Formatting option.
            separators (iterable): Formatting option.
            **kwargs: Will be passed to :meth:`json.dump`.
        """
        if hasattr(file_, "write"):
            json.dump(self, file_, cls=ModuleMultiTypeEncoder, indent=indent, separators=separators, **kwargs)
        else:
            with open(file_, "w") as file_object:
                json.dump(
                    self, file_object, cls=ModuleMultiTypeEncoder, indent=indent, separators=separators, **kwargs
                )

    def dumps(self, indent=4, separators=(",", ": "), **kwargs):
        """Get string representation using :meth:`json.dumps`.

        Args:
            indent (int): Formatting option.
            separators (iterable): Formatting option.
            **kwargs: Will be passed to :meth:`json.dumps`.

        Returns:
            str: JSON representation of this Config.
        """
        return json.dumps(self, cls=ModuleMultiTypeEncoder, indent=indent, separators=separators, **kwargs)

    def load(self, file_, raise_=True, decoder_cls_=ModuleMultiTypeDecoder, **kwargs):
        """Load config from file using :meth:`json.load`.

        Args:
            file_ (str or File): Read from this location. Anything with a
                ``read`` attribute is used as an open file object, everything
                else is treated as a path.
            raise_ (bool): Raise errors. If False, a failed load leaves the
                Config unchanged.
            decoder_cls_ (type): Class that is used to decode JSON string.
            **kwargs: Will be passed to :meth:`json.load`.
        """
        try:
            if hasattr(file_, "read"):
                new_dict = json.load(file_, cls=decoder_cls_, **kwargs)
            else:
                with open(file_, "r") as file_object:
                    new_dict = json.load(file_object, cls=decoder_cls_, **kwargs)
        except Exception:
            if raise_:
                raise
            # Nothing was loaded, so there is nothing to merge.
            return

        self.update(new_dict)

    def loads(self, json_str, decoder_cls_=ModuleMultiTypeDecoder, **kwargs):
        """Load config from JSON string using :meth:`json.loads`.

        Missing outer braces are added before parsing.

        Args:
            json_str (str): Interpret this string.
            decoder_cls_ (type): Class that is used to decode JSON string.
            **kwargs: Will be passed to :meth:`json.loads`.
        """
        if not json_str.startswith("{"):
            json_str = "{" + json_str
        if not json_str.endswith("}"):
            json_str = json_str + "}"

        new_dict = json.loads(json_str, cls=decoder_cls_, **kwargs)
        self.update(new_dict)

    def hasattr_not_none(self, key):
        """Check whether a key exists and maps to something other than None.

        Args:
            key: Key to look up, dot notation is allowed.

        Returns:
            bool: False if the lookup raises a KeyError or yields None,
            otherwise True.
        """
        try:
            result = self[key]
        except KeyError:
            return False
        return result is not None

    def contains(self, dict_like):
        """Check whether all items in a dictionary-like object match the ones in
        this Config.

        Args:
            dict_like (dict or derivative thereof): Returns True if this is
                contained in this Config.

        Returns:
            bool: True if dict_like is contained in self, otherwise False.
        """
        dict_like_config = Config(config=dict_like)

        for key, val in dict_like_config.items():
            if key not in self:
                return False
            own = self[key]
            if isinstance(val, dict):
                # A nested mapping can only be contained in another mapping.
                if not isinstance(own, dict):
                    return False
                if not isinstance(own, Config):
                    own = Config(config=own)
                if not own.contains(val):
                    return False
            elif not own == val:
                return False

        return True

    def deepcopy(self):
        """Get a deep copy of this Config.

        Values that cannot be deep-copied (``copy.deepcopy`` raises a
        TypeError) are shared by reference instead.

        Returns:
            Config: A deep copy of self.
        """

        def _deepcopy(source, target):
            for key, val in source.items():
                if isinstance(val, dict):
                    # Nested mappings are rebuilt as Configs, level by level.
                    target[key] = Config()
                    _deepcopy(source[key], target[key])
                else:
                    try:
                        target[key] = deepcopy(val)
                    except TypeError:
                        target[key] = val

        new_config = Config()
        _deepcopy(self, new_config)
        return new_config

    @staticmethod
    def init_objects(config):
        """Returns a new Config with types converted to instances.

        Any value that is a Config and contains a type key will be converted to
        an instance of that type::

            {
                "stuff": "also_stuff",
                "convert_me": {
                    type: {
                        "param": 1,
                        "other_param": 2
                    },
                    "something_else": "hopefully_useless"
                }
            }

        becomes::

            {
                "stuff": "also_stuff",
                "convert_me": type(param=1, other_param=2)
            }

        Note that additional entries can be lost as shown above.

        Args:
            config (Config): New Config will be built from this one

        Returns:
            Config: A new config with instances made from type entries.
        """

        def init_sub_objects(objs):
            if isinstance(objs, dict):
                ret_dict = Config()
                for key, val in objs.items():
                    if isinstance(key, type):
                        # The first type key replaces the whole mapping by an
                        # instance; dict values become keyword arguments,
                        # sequences positional arguments.
                        init_param = init_sub_objects(val)
                        if isinstance(init_param, dict):
                            return key(**init_param)
                        if isinstance(init_param, (list, tuple, set)):
                            return key(*init_param)
                        return key()
                    if isinstance(val, (dict, list, tuple, set)):
                        ret_dict[key] = init_sub_objects(val)
                    else:
                        ret_dict[key] = val
                return ret_dict

            if isinstance(objs, (list, tuple, set)):
                # Rebuild the container with its original type.
                orig_type = type(objs)
                return orig_type([init_sub_objects(el) for el in objs])

            return objs

        return init_sub_objects(config)

    def __str__(self):
        """Return the JSON representation with sorted keys."""
        return self.dumps(sort_keys=True)

    def difference_config(self, *other_configs):
        """Get the difference of this and any number of other configs.
        See :meth:`difference_config_static` for more information.

        Args:
            *other_configs (Config): Compare these configs and self.

        Returns:
            Config: Difference of self and the other configs.
        """
        return self.difference_config_static(self, *other_configs)

    @staticmethod
    def difference_config_static(*configs, only_set=False, encode=True):
        """Make a Config of all elements that differ between N configs.

        The resulting Config looks like this::

            {
                key: (config1[key], config2[key], ...)
            }

        If the key is missing, None will be inserted. The inputs will not be
        modified. If every input holds a Config at a key, the comparison
        recurses into that level.

        Args:
            configs (Config): Any number of Configs
            only_set (bool): If only the set of different values should be
                returned or for each config the corresponding one
            encode (bool): If True, values will be encoded the same way as they
                are when exported to disk (e.g. "__type__(MyClass)")

        Returns:
            Config: Possibly empty Config
        """
        difference = dict()
        mmte = ModuleMultiTypeEncoder()

        all_keys = set()
        for config in configs:
            all_keys.update(config.keys())

        for key in all_keys:
            current_values = []
            all_equal = True
            all_configs = True

            for config in configs:
                if key not in config:
                    all_equal = False
                    all_configs = False
                    current_values.append(None)
                    continue

                value = mmte._encode(config[key]) if encode else config[key]
                current_values.append(value)
                if len(current_values) >= 2 and current_values[-1] != current_values[-2]:
                    all_equal = False
                # Every value (including the first one) has to be a Config
                # for the recursion below to be valid.
                if type(value) is not Config:
                    all_configs = False

            if all_equal:
                continue

            if all_configs:
                difference[key] = Config.difference_config_static(
                    *current_values, only_set=only_set, encode=encode
                )
            elif not only_set:
                difference[key] = tuple(current_values)
            else:
                try:
                    difference[key] = tuple(set(current_values))
                except TypeError:
                    # Unhashable values (e.g. lists): deduplicate by equality,
                    # keeping the first occurrence of each value.
                    unique = []
                    for value in current_values:
                        if value not in unique:
                            unique.append(value)
                    difference[key] = tuple(unique)

        return Config(config=difference)

    def flat(self, keep_lists=True, max_split_size=10, flatten_int=False):
        """Returns a flattened version of the Config as dict.

        Nested Configs will be replaced by concatenated keys. Lists and tuples
        are either kept as they are or replaced by one entry per element::

            {
                "a": 1,
                "b": [2, 3],
                "c": {
                    "x": 4,
                    "y": {
                        "z": 5
                    }
                },
                "d": (6, 7)
            }

        Becomes (with the default `keep_lists=True`)::

            {
                "a": 1,
                "b": [2, 3],
                "c.x": 4,
                "c.y.z": 5,
                "d": (6, 7)
            }

        With `keep_lists=False`, "b" is replaced by "b.0" and "b.1" and "d" by
        "d.0" and "d.1".

        We return a dict because dots are disallowed within Config keys.

        Args:
            keep_lists: Keep lists and tuples as values instead of unpacking them
            max_split_size: List longer than this will not be unpacked
            flatten_int: Integer keys will be treated as strings

        Returns:
            dict: A flattened version of self
        """

        def flat_(obj):
            def items():
                for key, val in obj.items():
                    if isinstance(val, dict) and (
                        isinstance(key, str) or (isinstance(key, int) and flatten_int)
                    ):
                        # Sub-entries whose keys cannot be joined with a dot
                        # are collected and kept as a dict under this key.
                        intermediate_dict = {}
                        for subkey, subval in flat_(val).items():
                            if isinstance(subkey, str):
                                yield str(key) + "." + subkey, subval
                            elif isinstance(subkey, int) and flatten_int:
                                yield str(key) + "." + str(subkey), subval
                            else:
                                intermediate_dict[subkey] = subval
                        if len(intermediate_dict) > 0:
                            yield str(key), intermediate_dict
                    elif isinstance(val, (list, tuple)):
                        keep_this = (
                            keep_lists
                            or not isinstance(key, (str, int))
                            or (isinstance(key, int) and not flatten_int)
                        )
                        if max_split_size not in (None, False) and len(val) > max_split_size:
                            keep_this = True
                        if keep_this:
                            yield key, val
                        else:
                            for i, subval in enumerate(val):
                                yield str(key) + "." + str(i), subval
                    else:
                        yield key, val

            return dict(items())

        return flat_(self)

    def to_cmd_args_str(self):
        """Create a string representing what one would need to pass to the
        command line. Does not yet use JSON encoding!

        Returns:
            str: Command line string
        """
        str_list = []
        for key, val in self.flat().items():
            if isinstance(val, (list, tuple)):
                # Sequences become space separated values after the option.
                val_str = " ".join(str(v) for v in val)
            else:
                val_str = str(val)
            str_list.append("--{} {}".format(key, val_str))

        return " ".join(str_list)
def update_from_sys_argv(config, warn=False):
    """Updates Config with the arguments passed as args when running the
    program. Keys will be converted to command line options, then matching
    options in `sys.argv` will be used to update the Config.

    Only string keys of the flattened Config are exposed. Nothing happens if
    `sys.argv` holds no arguments besides the program name.

    Args:
        config (Config): Update this Config.
        warn (bool): Raise warnings if there are unknown options. Turn this on
            if you don't use any :class:`argparse.ArgumentParser` after to
            check for possible errors.
    """
    import sys
    import argparse
    import warnings

    def str2bool(v):
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    if len(sys.argv) <= 1:
        return

    parser = argparse.ArgumentParser(allow_abbrev=False)
    encoder = ModuleMultiTypeEncoder()
    decoder = ModuleMultiTypeDecoder()

    config_flat = config.flat()
    for key, val in config_flat.items():
        if not isinstance(key, str):
            # Non-string keys cannot be mapped back from the parsed
            # namespace to the flat Config, so they are not exposed.
            continue

        name = "--{}".format(key)
        # Pin ``dest`` to the exact key: argparse would otherwise derive it
        # from the option string and replace "-" with "_", which breaks the
        # lookups in ``config_flat`` below.
        if val is None:
            parser.add_argument(name, dest=key)
        elif type(val) is bool:
            parser.add_argument(name, dest=key, type=str2bool, default=val)
        elif isinstance(val, (list, tuple)):
            if len(val) > 0 and type(val[0]) is not type:
                parser.add_argument(name, dest=key, nargs="+", type=type(val[0]), default=val)
            else:
                parser.add_argument(name, dest=key, nargs="+", default=val)
        else:
            if type(val) is type:
                # Types are exposed in their encoded string form.
                val = encoder._encode(val)
            parser.add_argument(name, dest=key, type=type(val), default=val)

    # parse args
    param, unknown = parser.parse_known_args()
    param = vars(param)

    if len(unknown) > 0 and warn:
        warnings.warn("Called with unknown arguments: {}".format(unknown), RuntimeWarning)

    # calc diff between configs
    diff_keys = list(Config.difference_config_static(param, config_flat).flat().keys())

    # convert type args
    ignore_ = []
    for key, val in param.items():
        if val in ("none", "None"):
            param[key] = None
        if type(config_flat[key]) is type:
            if isinstance(val, str):
                val = val.replace("'", "")
                val = val.replace('"', "")
            param[key] = decoder._decode(val)
        try:
            # Keys ending in an integer ("some.list.0") are dropped when
            # their parent ("--some.list") was given on the command line.
            key_split = key.split(".")
            list_object, _ = ".".join(key_split[:-1]), int(key_split[-1])
            if "--" + list_object in sys.argv:
                ignore_.append(key)
        except ValueError:
            # Last key segment is not an integer, nothing to do.
            pass

    for i in ignore_:
        del param[i]

    # Delete not changed entries
    for i in list(param.keys()):
        if i not in diff_keys and i in param:
            del param[i]

    # update dict
    config.update(param)
def monkey_patch_fn_args_as_config(f):
    """Decorator: Monkey patches, aka adds a variable 'fn_args_as_config' to globals,
    so that it can be accessed by the decorated function.

    Adds all function parameters (including defaults) to a Config
    'fn_args_as_config', which can be accessed by the function while it runs.
    If a 'self' argument is present, it is replaced by its class. After the
    call, the previous state of the global is restored, so nested or recursive
    calls of decorated functions do not clobber each other.

    Be careful using it!

    Args:
        f (callable): Function to decorate.

    Returns:
        callable: The wrapped function.
    """
    sig = inspect.signature(f)
    # Sentinel to tell "global was not set" apart from "global was None".
    missing = object()

    @wraps(f)
    def wrapper(*args, **kwargs):
        bound_arguments = sig.bind(*args, **kwargs)
        bound_arguments.apply_defaults()

        c = Config(config=bound_arguments.arguments)
        if "self" in c:
            c["self"] = c["self"].__class__

        g = f.__globals__
        previous = g.get("fn_args_as_config", missing)
        g["fn_args_as_config"] = c
        try:
            return f(*args, **kwargs)
        finally:
            # Restore whatever an enclosing call (or the module) had bound.
            if previous is missing:
                g.pop("fn_args_as_config", None)
            else:
                g["fn_args_as_config"] = previous

    return wrapper