Source code for dantro.plot.funcs.graph

"""Plot functions to draw :py:class:`networkx.Graph` objects.

.. todo::

    Should really integrate utopya ``GraphPlot`` here!
"""

import copy
import logging
import os
from typing import Callable, Union

from ...tools import recursive_update as _recursive_update

# Module-level logger, named after this module's import path
log = logging.getLogger(__name__)

# -----------------------------------------------------------------------------


def _wiggle_pos(
    pos: dict, *, x: float = None, y: float = None, seed: int = None
) -> dict:
    """Randomly displaces positions along the x and/or y direction.

    Every affected coordinate is shifted by ``amplitude * u``, where ``u`` is
    a fresh draw from ``rng.uniform(-1, +1)``.

    Args:
        pos (dict): Positions dict; values are indexable with the x position
            at index 0 and the y position at index 1
        x (float, optional): Absolute wiggle amplitude in x direction. If
            None, x positions are left untouched.
        y (float, optional): Absolute wiggle amplitude in y direction. If
            None, y positions are left untouched.
        seed (int, optional): Seed for the
            :py:class:`numpy.random.RandomState` used for drawing the random
            numbers. Fix it to always get the same positions.

    Returns:
        dict: The wiggled positions, with two-element lists as values. If
            neither ``x`` nor ``y`` is given, ``pos`` itself is returned.
    """
    import numpy as np

    rng = np.random.RandomState(seed)

    def _shifted(coord, amplitude):
        return coord + amplitude * rng.uniform(-1, +1)

    # NOTE The x pass is completed before the y pass begins, such that the
    #      random numbers are consumed in a well-defined order: first one
    #      draw per node for x, then one draw per node for y.
    if x is not None:
        wiggled_x = {}
        for node, xy in pos.items():
            wiggled_x[node] = [_shifted(xy[0], x), xy[1]]
        pos = wiggled_x

    if y is not None:
        wiggled_y = {}
        for node, xy in pos.items():
            wiggled_y[node] = [xy[0], _shifted(xy[1], y)]
        pos = wiggled_y

    return pos
def _get_positions(
    g: "networkx.Graph",
    *,
    model: Union[str, Callable],
    wiggle: dict = None,
    **kwargs,
) -> dict:
    """Returns the positions dict for the given graph, created from a networkx
    layouting algorithm of a certain name or an arbitrary callable.

    Args:
        g (networkx.Graph): The graph object for which to create the layout
        model (Union[str, Callable]): Name of the layouting model or the
            layouting function itself.
            If starting with ``graphviz_<prog>``, will invoke
            :py:func:`networkx.drawing.nx_agraph.graphviz_layout` with the
            given value for ``prog``. Note that these only take a single
            keyword argument, ``args``.
            If it is any other string, it's looked up from the networkx
            namespace (all names ending in ``_layout``, with that suffix
            stripped).
            If it is a callable, it is invoked with ``g`` as only positional
            argument and ``**kwargs`` as keyword arguments.
        wiggle (dict, optional): If given, will postprocess the positions dict
            by randomly wiggling x and y coordinates according to the absolute
            amplitudes given as values.
        **kwargs: Passed on to the layouting algorithm.

    Returns:
        dict: The positions dict as returned by the layouting algorithm,
            optionally wiggled.

    Raises:
        ImportError: If the graphviz layout raised an ImportError, probably
            because pygraphviz is not installed.
        ValueError: If ``model`` is a string that does not correspond to any
            layouting model available in networkx.
    """
    import networkx as nx

    if callable(model):
        log.debug("Invoking callable for node layouting ...")
        pos = model(g, **kwargs)

    elif model.startswith("graphviz_"):
        log.debug("Invoking %s model for node layouting ...", model)
        prog = model[len("graphviz_") :]

        try:
            pos = nx.drawing.nx_agraph.graphviz_layout(g, prog=prog, **kwargs)

        except ImportError as err:
            raise ImportError(
                "Could not apply graphviz layout, probably because "
                "pygraphviz is not installed!"
            ) from err

    else:
        _suffix = "_layout"
        POSITIONING_MODELS_NETWORKX = {
            name[: -len(_suffix)]: getattr(nx, name)
            for name in dir(nx)
            if name.endswith(_suffix)
        }

        log.debug("Invoking %s model for node layouting ...", model)

        # Only guard the lookup itself: a KeyError raised from *within* the
        # layouting algorithm must not be mistaken for an unknown model name.
        try:
            layout_func = POSITIONING_MODELS_NETWORKX[model]

        except KeyError as err:
            _avail = ", ".join(POSITIONING_MODELS_NETWORKX)
            raise ValueError(
                f"No layouting model '{model}' available in networkx! \n"
                f"Available models: {_avail}"
            ) from err

        pos = layout_func(g, **kwargs)

    if wiggle:
        pos = _wiggle_pos(pos, **wiggle)

    return pos
def get_positions(
    g: "networkx.Graph",
    *,
    model: Union[str, Callable] = "spring",
    model_kwargs: dict = {},
    fallback: Union[str, dict] = None,
    silent_fallback: bool = False,
    **kwargs,
) -> dict:
    """Computes node positions for a graph, optionally retrying with a
    fallback layouting model if the primary one fails.

    This wraps :py:func:`._get_positions`: the primary ``model`` is tried
    first; if it raises any exception and a ``fallback`` is configured, the
    fallback model is invoked instead. Without a fallback, the original
    exception is re-raised.

    Args:
        g (networkx.Graph): The graph object for which to create the layout
        model (Union[str, Callable], optional): Name of the layouting model
            or the layouting function itself; passed on to
            :py:func:`._get_positions`, see there for the supported values.
        model_kwargs (dict, optional): Maps names of layouting models to the
            parameters that are to be passed to the respective layouting
            function. May contain entries for more models than are used;
            only the entry for the selected model name is looked up. These
            defaults are not considered when the model is a callable.
        fallback (Union[str, dict], optional): The fallback model name (if a
            string) or a dict containing the key ``model`` and further
            kwargs.
        silent_fallback (bool, optional): Whether to log a visible message
            about the fallback or a more discrete one.
        **kwargs: Passed on to the layouting algorithm in addition to the
            selected entry from ``model_kwargs``. Keys given here update
            those from ``model_kwargs``. These are *not* passed on to the
            fallback invocation.

    Returns:
        dict: The positions dict
    """

    def _resolve_kwargs(model: Union[str, Callable], **overrides) -> dict:
        """Looks up the defaults for ``model`` and updates them"""
        if callable(model):
            return overrides

        defaults = copy.deepcopy(model_kwargs.get(model, {}))
        return _recursive_update(defaults, overrides)

    # A plain model name is shorthand for a dict with only the model key
    if isinstance(fallback, str):
        fallback = dict(model=fallback)

    try:
        layout_kwargs = _resolve_kwargs(model, **kwargs)
        return _get_positions(g, model=model, **layout_kwargs)

    except Exception as exc:
        if not fallback:
            raise

        if silent_fallback:
            log.remark(
                "Node layouting with '%s' model failed; using fallback: %s",
                model,
                fallback,
            )
        else:
            log.caution(
                "Node layouting with '%s' model failed with a %s: %s",
                model,
                type(exc).__name__,
                exc,
            )
            log.remark("Invoking fallback layouting: %s", fallback)

        # The fallback is invoked from within the exception handler, such
        # that a failing fallback is chained to the original error.
        fallback_model = fallback["model"]
        fallback_kwargs = _resolve_kwargs(**fallback)
        return _get_positions(g, model=fallback_model, **fallback_kwargs)
# .............................................................................
def _draw_graph(
    g: "networkx.Graph",
    *,
    ax: "matplotlib.axes.Axes" = None,
    drawing: dict = {},
    layout: dict = {},
) -> list:
    """Draws a graph using
    :py:func:`networkx.drawing.nx_pylab.draw_networkx_nodes`,
    :py:func:`networkx.drawing.nx_pylab.draw_networkx_edges`, and
    :py:func:`networkx.drawing.nx_pylab.draw_networkx_labels`.

    .. warning::

        This function is not yet completed and may change anytime.

    Args:
        g (networkx.Graph): The graph to draw
        ax (matplotlib.axes.Axes, optional): The axes to draw on. If not
            given, the current axes (``plt.gca()``) are used.
        drawing (dict, optional): Drawing arguments, containing the
            ``nodes``, ``edges`` and ``labels`` keys. The ``labels`` key can
            contain the ``from_attr`` key which will read the attribute
            specified there and use it for the label. The given dicts are
            not modified.
        layout (dict, optional): Used to generate node positions via the
            :py:func:`~dantro.plot.funcs.graph.get_positions` function.

    Returns:
        list: The artists created by the node, edge and label drawing calls
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    def draw(
        g, *, ax, pos, nodes: dict = {}, edges: dict = {}, labels: dict = {}
    ):
        # Work on a shallow copy: the dict is adjusted below and the caller's
        # configuration should stay reusable for further graphs.
        labels = dict(labels)

        # Parse some attributes
        if "from_attr" in labels:
            labels["labels"] = nx.get_node_attributes(
                g, labels.pop("from_attr")
            )

        # Now draw
        _nodes = nx.draw_networkx_nodes(g, pos=pos, ax=ax, **nodes)
        _edges = nx.draw_networkx_edges(g, pos=pos, ax=ax, **edges)
        _labels = nx.draw_networkx_labels(g, pos=pos, ax=ax, **labels)

        # Gather artists, adapting to the specific way they are passed back
        artists = [_nodes]

        # Edges may come back as a sequence of artists or as a single
        # collection artist; the latter cannot be used to extend a list.
        if isinstance(_edges, (list, tuple)):
            artists.extend(_edges)
        else:
            artists.append(_edges)

        artists.extend(_labels.values())
        return artists

    ax = ax if ax is not None else plt.gca()
    pos = get_positions(g, **layout)

    # Draw
    artists = draw(g, ax=ax, pos=pos, **drawing)

    # Post-process
    ax.axis("off")

    return artists