# Source code for dantro.groups.psp

"""This module implements :py:class:`~dantro.base.BaseDataContainer`
specializations that make use of features from the
`paramspace <https://gitlab.com/blsqr/paramspace>`_ package, in particular the
:py:class:`~paramspace.paramspace.ParamSpace` class.
"""

import copy
import logging
from typing import Dict, List, Sequence, Union

import numpy as np
import numpy.ma
from paramspace import ParamSpace

from .._import_tools import LazyLoader
from ..base import PATH_JOIN_CHAR
from ..containers import XrDataContainer
from ..data_ops.arr_ops import multi_concat as _multi_concat
from ..mixins import PaddedIntegerItemAccessMixin
from . import is_group
from .ordered import IndexedDataGroup, OrderedDataGroup

# Lazily-loaded handle to the xarray package; presumably the actual import is
# deferred until an attribute is first accessed — see dantro._import_tools
xr = LazyLoader("xarray")

# Logger for this module
log = logging.getLogger(name=__name__)

# -----------------------------------------------------------------------------


@is_group
class ParamSpaceStateGroup(OrderedDataGroup):
    """A group representing a single state (point) of a parameter space.

    It is meant to be used as a member group of the
    :py:class:`~dantro.groups.psp.ParamSpaceGroup`. The group's *own* name
    has to be interpretable as a non-negative integer (the state number);
    this is enforced in the enclosing
    :py:class:`~dantro.groups.psp.ParamSpaceGroup` and, additionally, by
    :py:meth:`._check_name` of this class. The *members* of this group, in
    contrast, may carry any name.
    """

    # Class used when new groups are created inside this group
    _NEW_GROUP_CLS = OrderedDataGroup

    def _check_name(self, name: str) -> None:
        """Validates the name of this group; called during ``__init__``.

        In addition to the checks of the parent class, the name needs to be
        convertible to a non-negative integer.

        Args:
            name (str): The name to validate

        Raises:
            ValueError: If ``name`` cannot be converted to an integer, or if
                the resulting integer is negative.
        """
        try:
            state_no = int(name)
        except ValueError as err:
            raise ValueError(
                "Only names that are representible as integers are possible "
                f"for the name of {self.classname}, got '{name}'!"
            ) from err

        # Zero is accepted; only negative state numbers are rejected
        if state_no < 0:
            raise ValueError(
                f"Name for {self.classname} needs to be positive when "
                f"converted to integer, was: {name}"
            )

        # Let the parent class apply its own checks as well
        super()._check_name(name)

    @property
    def coords(self) -> dict:
        """The coordinates of this group within the parameter space of the
        enclosing :py:class:`~dantro.groups.psp.ParamSpaceGroup`.

        The coordinates are determined by looking up this group's state
        number (its name, converted to integer) in the parent's parameter
        space state map.

        Returns:
            dict: Mapping of dimension names to this group's coordinate
                value along that dimension.
        """
        state_map = self.parent.pspace.state_map
        matches_this_state = state_map == int(self.name)
        selection = state_map.where(matches_this_state, drop=True)
        return {dim: coord.item() for dim, coord in selection.coords.items()}
@is_group
class ParamSpaceGroup(PaddedIntegerItemAccessMixin, IndexedDataGroup):
    """The ParamSpaceGroup is associated with a
    :py:class:`paramspace.paramspace.ParamSpace` object and the loaded results
    of an iteration over this parameter space.

    Thus, the groups that are stored in the ParamSpaceGroup all need to relate
    to a state of the parameter space, identified by a zero-padded string
    name. In fact, this group allows no other kinds of groups stored inside.

    To make access to a specific state easier, it allows accessing a state by
    its state number as integer.
    """

    # Configure the class variables that define some of the behaviour
    # Define which .attrs entry to return from the `pspace` property
    _PSPGRP_PSPACE_ATTR_NAME = "pspace"

    # A transformation callable that can be used during data selection
    _PSPGRP_TRANSFORMATOR = None

    # Define the class to use for the direct members of this group
    _NEW_GROUP_CLS = ParamSpaceStateGroup

    # Define allowed container types
    _ALLOWED_CONT_TYPES = (ParamSpaceStateGroup,)

    # .........................................................................

    def __init__(
        self,
        *,
        name: str,
        pspace: ParamSpace = None,
        containers: list = None,
        **kwargs,
    ):
        """Initialize a ParamSpaceGroup from the list of given containers.

        Args:
            name (str): The name of this group.
            pspace (paramspace.paramspace.ParamSpace, optional): Can already
                pass a ParamSpace object here.
            containers (list, optional): A list of containers to add, which
                need to be :py:class:`~dantro.groups.psp.ParamSpaceStateGroup`
                objects.
            **kwargs: Further initialisation kwargs, e.g. ``attrs`` ...
        """
        # Initialize with parent method, which will call .add(*containers)
        super().__init__(name=name, containers=containers, **kwargs)

        # If given, associate the parameter space object
        if pspace is not None:
            self.pspace = pspace

    # Properties ..............................................................

    @property
    def pspace(self) -> Union[ParamSpace, None]:
        """Reads the entry named ``_PSPGRP_PSPACE_ATTR_NAME`` in ``.attrs``
        and returns a :py:class:`~paramspace.paramspace.ParamSpace` object,
        if available there.

        Returns:
            Union[paramspace.paramspace.ParamSpace, None]: The associated
                parameter space, or None, if there is none associated yet.
        """
        return self.attrs.get(self._PSPGRP_PSPACE_ATTR_NAME, None)

    @pspace.setter
    def pspace(self, val: ParamSpace):
        """If not already set, sets the entry in the attributes that is
        accessed by the ``.pspace`` property.

        Raises:
            RuntimeError: If a parameter space was already associated.
            TypeError: If ``val`` is not a ParamSpace (or subclass) object.
        """
        if self.pspace is not None:
            raise RuntimeError(
                "The attribute for the parameter space of this "
                f"{self.logstr} was already set, cannot set it again!"
            )

        elif not isinstance(val, ParamSpace):
            raise TypeError(
                f"The attribute for the parameter space of {self.logstr} "
                f"needs to be a ParamSpace-derived object, was {type(val)}!"
            )

        # Checked it, now set it
        self.attrs[self._PSPGRP_PSPACE_ATTR_NAME] = val
        log.debug("Associated %s with %s", val, self.logstr)

    @property
    def only_default_data_present(self) -> bool:
        """Returns true if only data for the default point in parameter space
        is available in this group, i.e. if this group has exactly one member
        and that member is state number 0.
        """
        return (len(self) == 1) and (0 in self)

    # Data access .............................................................

    def select(
        self,
        *,
        field: Union[str, List[str]] = None,
        fields: Dict[str, List[str]] = None,
        subspace: dict = None,
        method: str = "concat",
        idx_as_label: bool = False,
        base_path: str = None,
        **kwargs,
    ) -> "xarray.Dataset":
        """Selects a multi-dimensional slab of this ParamSpaceGroup and the
        specified fields and returns them bundled into an
        :py:class:`xarray.Dataset` with labelled dimensions and coordinates.

        Args:
            field (Union[str, List[str]], optional): The field of data to
                select. Should be a path (string or sequence of strings) that
                points to an entry in the data tree; may also be a dict with
                a ``path`` key and further options (see ``fields``). The name
                of the resulting variable is the last path segment. To select
                multiple fields, do not pass this argument but use the
                ``fields`` argument.
            fields (Dict[str, List[str]], optional): A dict specifying the
                fields that are to be loaded into the dataset. Keys will be
                the names of the resulting variables, while values should
                specify the path to the field in the data tree. Thus, they can
                be strings, lists of strings or dicts with the ``path`` key
                present. In the latter case, a dtype can be specified via the
                ``dtype`` key, dimension names via the ``dims`` key, and
                transformation arguments via ``transform`` and any further
                keys. The passed dict is not modified.
            subspace (dict, optional): Selector for a subspace of the
                parameter space. Adheres to the parameter space's
                :py:meth:`~paramspace.paramspace.ParamSpace.activate_subspace`
                signature.
            method (str, optional): How to combine the selected datasets.

                - ``concat``: concatenate sequentially along all parameter
                  space dimensions. This can preserve the data type but it
                  does not work if one data point is missing.
                - ``merge``: merge always works, even if data points are
                  missing, but will convert all dtypes to float.

            idx_as_label (bool, optional): If true, adds the trivial indices
                as labels for those dimensions where coordinate labels were
                not extractable from the loaded field. This allows merging for
                data with different extends in an unlabelled dimension.
            base_path (str, optional): If given, ``path`` specifications for
                each field are interpreted as relative to this path. May be a
                string or a sequence of path segments.
            **kwargs: Accepted for interface compatibility.
                NOTE(review): these are currently *not* forwarded to
                ``xr.concat`` or ``xr.merge``; they are ignored.

        Raises:
            KeyError: On invalid subspace selector.
            ValueError: Raised in multiple scenarios: If no
                :py:class:`~paramspace.paramspace.ParamSpace` was associated
                with this group, for wrong argument values, if the data to
                select cannot be extracted with the given argument values,
                exceptions passed on from xarray.

        Returns:
            xarray.Dataset: The selected hyperslab of the parameter space,
                holding the desired fields.
        """

        def parse_fields(*, field, fields) -> dict:
            """Parses the ``field`` and ``fields`` arguments into a *new*,
            uniform dict; a ``fields`` dict given by the caller is left
            untouched.

            Return value is a dict of the following structure:

                <var_name_1>:
                  path: <list of strings>
                  dtype: <str, optional>
                  dims: <list of strings, optional>
                  ... further
                <var_name_2>:
                  ...

            TODO Change such that its using strings for paths, not sequences.
            """
            if field is not None and fields is not None:
                raise ValueError(
                    "Can only specify either of the arguments "
                    "`field` or `fields`, got both!"
                )

            elif field is None and fields is None:
                raise ValueError(
                    "Need to specify one of the arguments "
                    "`field` or `fields`, got neither of them!"
                )

            elif field is not None:
                # Generate a fields dict from the single field argument such
                # that it can be processed like the rest; the variable name
                # is derived from the last path segment.
                if isinstance(field, str):
                    path = field.split(PATH_JOIN_CHAR)
                    extra = {}

                elif isinstance(field, dict):
                    path = field["path"]
                    # Not using .pop here in order to not change the dict.
                    extra = {k: v for k, v in field.items() if k != "path"}

                    if isinstance(path, str):
                        path = path.split(PATH_JOIN_CHAR)

                else:
                    path = list(field)
                    extra = {}

                fields = {path[-1]: dict(path=path, **extra)}

            # The fields variable is now available; make sure it is a dict
            if not isinstance(fields, dict):
                raise TypeError(
                    "Argument `fields` needs to be a dict, "
                    f"but was {type(fields)}!"
                )

            def normalize(spec) -> dict:
                """Brings a single field specification into dict form"""
                if isinstance(spec, str):
                    return dict(path=spec.split(PATH_JOIN_CHAR))
                elif isinstance(spec, dict):
                    # Already a dict; parameters carried over. Shallow-copy
                    # so the caller's object is never aliased.
                    return dict(spec)
                # Assume this is a sequence of path segments
                return dict(path=list(spec))

            # Build a new dict instead of rewriting the caller's one in place
            return {name: normalize(spec) for name, spec in fields.items()}

        def get_state_grp(state_no: int) -> ParamSpaceStateGroup:
            """Returns the group corresponding to the given state"""
            try:
                return self[state_no]

            except (KeyError, ValueError) as err:
                # TODO use custom exception class, e.g. from DataManager?
                raise ValueError(
                    f"No state {state_no} available in {self.logstr}! Make "
                    "sure the data was fully loaded."
                ) from err

        def get_var(
            state_grp: ParamSpaceStateGroup,
            *,
            path: List[str],
            base_path: List[str] = None,
            dtype: str = None,
            dims: List[str] = None,
            transform: Sequence[dict] = None,
            **transform_kwargs,
        ) -> Union["xr.Variable", "xr.DataArray"]:
            """Extracts the field specified by the given path and returns it
            as either an xr.Variable or (for supported containers) directly as
            an xr.DataArray.

            We are using xr.Variables as defaults here, as they provide higher
            performance than xr.DataArrays; the latter have to be frequently
            unpacked and restructured in the merge operations.

            Args:
                state_grp (ParamSpaceStateGroup): The group to search `path` in
                path (List[str]): The path to a data container; a string is
                    split at the path join character, other sequences are
                    converted to a list.
                base_path (List[str], optional): Will be prepended to the
                    given path, if given.
                dtype (str, optional): The desired dtype for the data.
                dims (List[str], optional): A list of dimension names for the
                    extracted data. If not given, will name them manually as
                    dim_0, dim_1, ...
                transform (Sequence[dict], optional): Optional transform
                    arguments; passed on to transformator as *args.
                **transform_kwargs: Passed on to the transformator as **kwargs.

            Returns:
                Union[xr.Variable, xr.DataArray]: The extracted data, which
                    can be either a data array (if the path led to an
                    xarray-interface supporting container) or an xr.Variable
                    (if not).

            Raises:
                ValueError: Missing transformator
            """

            def convert_dtype(data, dtype, *, path):
                """Change the dtype of the data, if it does not match the
                specified one.
                """
                if data.dtype == dtype:
                    return data

                log.debug(
                    "Converting data from '%s' with dtype %s to %s ...",
                    PATH_JOIN_CHAR.join(path),
                    data.dtype,
                    dtype,
                )
                return data.astype(dtype)

            # Prepare the path, ensuring to work on the list representation.
            # Non-string sequences (e.g. tuples) are accepted as well.
            if isinstance(path, str):
                path = path.split(PATH_JOIN_CHAR)
            else:
                path = list(path)

            if base_path:
                path = base_path + path

            # Now, get the desired container
            cont = state_grp[path]

            # Apply the transformator on the container, if arguments given
            if transform or transform_kwargs:
                if self._PSPGRP_TRANSFORMATOR is None:
                    raise ValueError(
                        "Got transform arguments or kwargs, but "
                        "no transformator callable was defined "
                        "as class variable!"
                    )

                # Invoke the transformator on the container. `transform` may
                # be None if only keyword arguments were given; unpacking
                # None would raise a TypeError, so fall back to no args.
                cont = self._PSPGRP_TRANSFORMATOR(
                    cont, *(transform or ()), **transform_kwargs
                )

            # Shortcut: specialised containers might already supply all the
            # information, including coordinates. In that case, return it as
            # a data array.
            if isinstance(cont, XrDataContainer):
                # Will return the underlying data. See if some dtype change
                # or dimension name relabelling was specified
                darr = cont.data

                if dtype is not None:
                    darr = convert_dtype(darr, dtype, path=path)

                if dims is not None:
                    darr = darr.rename(
                        {old: new for old, new in zip(darr.dims, dims)}
                    )

                return darr

            elif isinstance(cont, (xr.DataArray, xr.Dataset)):
                # Actually was not a container but already the data; skip below
                return cont

            # If this was not the case, xr.Variable will have to be constructed
            # manually from the container.
            # The only pre-requisite for the data is that it is np.array-like,
            # which is always possible; worst case: scalar of dtype "object".
            data = np.array(cont.data)

            # Check the dtype and convert, if needed
            if dtype is not None:
                data = convert_dtype(data, dtype, path=path)

            # Get the attributes
            attrs = dict(cont.attrs.items())

            # Generate dimension names, if not explicitly given.
            if dims is None:
                dims = [f"dim_{i}" for i in range(data.ndim)]

            # Check whether indices are to be added (var from outer scope!)
            if not idx_as_label:
                # Can use these to construct an xr.Variable
                return xr.Variable(dims, data, attrs=attrs)

            # else: will need to be a DataArray; Variable does not hold coords
            # For each dimension, add trivial coordinates
            coords = {d: range(data.shape[i]) for i, d in enumerate(dims)}

            return xr.DataArray(data, dims=dims, coords=coords, attrs=attrs)

        def combine(
            *, method: str, dsets: np.ndarray, psp: ParamSpace
        ) -> "xr.Dataset":
            """Tries to combine the given datasets either by concatenation or
            by merging and returns a combined xr.Dataset
            """
            # NOTE the check for a valid `method` value is carried out before
            #      this function is called.

            # Merging . . . . . . . . . . . . . . . . . . . . . . . . . . . .
            if method == "merge":
                log.remark("Combining datasets by merging ...")
                # TODO consider warning about dtype changes?!

                dset = xr.merge(dsets.flat)

                log.remark("Merge successful.")
                return dset

            # else: Concatenation . . . . . . . . . . . . . . . . . . . . . .
            log.remark(
                "Combining %d datasets by concatenation along %d "
                "dimensions ...",
                dsets.size,
                len(dsets.shape),
            )

            # Reduce the dsets array to one dimension by applying xr.concat
            # along each axis. The returned object contains the combined data.
            reduced = _multi_concat(dsets, dims=psp.dims.keys())

            log.remark("Concatenation successful.")
            return reduced

        # End of definition of helper functions.
        # . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

        # Some initial checks
        if self.pspace is None:
            raise ValueError(
                f"Cannot get data from {self.logstr} without having a "
                "parameter space associated!"
            )

        elif method not in ("concat", "merge"):
            raise ValueError(
                f"Invalid value for argument `method`: '{method}'. Can "
                "be: 'concat' (default), 'merge'"
            )

        # Pre-process arguments . . . . . . . . . . . . . . . . . . . . . . .
        # From field and fields arguments, generate a fields dict, such that
        # it can be handled uniformly below.
        fields = parse_fields(field=field, fields=fields)

        # Prepare the base path as a list of segments. An empty/None base
        # path is left as-is, such that it is ignored in get_var.
        if base_path:
            if isinstance(base_path, str):
                base_path = base_path.split(PATH_JOIN_CHAR)
            else:
                base_path = list(base_path)

        # Work on a copy of the parameter space and apply the subspace masks
        psp = copy.deepcopy(self.pspace)

        if subspace:
            # Need the parameter space to be of non-zero volume
            if psp.volume == 0:
                raise ValueError(
                    "Cannot select a subspace because the "
                    "associated parameter space has no "
                    "dimensions defined! Remove the `subspace` "
                    "argument in the call to this method."
                )

            try:
                psp.activate_subspace(**subspace)

            except KeyError as err:
                _dim_names = ", ".join(psp.dims.keys())
                raise KeyError(
                    "Could not select a subspace! "
                    f"{type(err).__name__}: {err}\n"
                    "Make sure your subspace selector contains "
                    "only valid dimension names and coordinates. "
                    f"Available dimension names: {_dim_names}"
                ) from err

        # Now, the data needs to be collected from each point in this subspace
        # and associated with the correct coordinate, such that the datasets
        # can later be merged and aligned by that coordinate.
        if psp.volume > 0:
            log.info(
                "Collecting data for %d fields from %d points in "
                "parameter space ...",
                len(fields),
                psp.volume,
            )
        else:
            log.info(
                "Collecting data for %d fields from a dimensionless "
                "parameter space ...",
                len(fields),
            )

        # Gather them in a multi-dimensional array
        dsets = np.zeros(psp.shape, dtype="object")
        dsets.fill(dict())  # these are ignored in xr.merge

        # Prepare the iterators
        psp_it = psp.iterator(
            with_info=("state_no", "current_coords"), omit_pt=True
        )
        arr_it = np.nditer(dsets, flags=("multi_index", "refs_ok"))

        for (_state_no, _coords), _ in zip(psp_it, arr_it):
            # Select the corresponding state group
            try:
                _state_grp = get_state_grp(_state_no)

            except ValueError:
                if method == "merge":
                    # In merge, this will merely lead to a NaN ...
                    log.warning("Missing state group: %d", _state_no)
                    continue

                # ...but for concatenation, it will result in an error.
                raise

            # Get the variables for all fields
            _vars = {
                k: get_var(_state_grp, **f, base_path=base_path)
                for k, f in fields.items()
            }

            # Construct a dataset from that ...
            _dset = xr.Dataset(_vars)

            # ... and expand its dimensions to accommodate the point in pspace
            _dset = _dset.expand_dims({k: [v] for k, v in _coords.items()})

            # Store it in the array of datasets
            dsets[arr_it.multi_index] = _dset

        # All data points collected now.
        log.info("Data collected.")

        # Finally, combine all the datasets together into a dataset with
        # potentially non-homogeneous data type. This will have at least the
        # dimensions given by the parameter space aligned, but there could
        # be potentially more dimensions.
        try:
            dset = combine(method=method, dsets=dsets, psp=psp)

        except ValueError as err:
            raise ValueError(
                "Combination of datasets failed; see below. This "
                "is probably due to a failure of alignment, "
                "which can be resolved by adding trivial "
                "coordinates (i.e.: the indices) to unlabelled "
                "dimensions by setting the `idx_as_label` "
                "argument to True."
            ) from err

        log.info("Data selected.")
        log.note(
            "Available data variables: %s", ", ".join(dset.data_vars)
        )
        log.note(
            "Dataset dimensions and sizes: %s",
            ", ".join("{}: {}".format(*kv) for kv in dset.sizes.items()),
        )

        return dset