# Source code for pytspl.plot.plot

"""Module for plotting simplicial complexes."""

from collections.abc import Iterable
from numbers import Number

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

from pytspl.decomposition.frequency_component import FrequencyComponent
from pytspl.simplicial_complex import SimplicialComplex


class SCPlot:
    """Class for plotting simplicial complexes."""

    def __init__(
        self,
        simplicial_complex: SimplicialComplex,
        coordinates: dict = None,
    ) -> None:
        """
        Args:
            simplicial_complex (SimplicialComplex): The simplicial complex
                network object.
            coordinates (dict, optional): Dict of positions
                [node_id : (x, y)] used for placing the 0-simplices. If
                None, a networkx spring layout is computed on the first
                draw call and cached in ``self.pos``.
        """
        self.sc = simplicial_complex
        self.pos = coordinates

    @staticmethod
    def _is_numeric_sequence(values) -> bool:
        """
        Check whether ``values`` should be mapped through a colormap.

        Args:
            values: A color specification (a single color or one value per
                element).

        Returns:
            bool: True if ``values`` is a non-string, non-empty iterable
            whose elements are all ``numbers.Number`` instances.
        """
        if isinstance(values, str) or not np.iterable(values):
            return False
        items = list(values)
        return len(items) > 0 and all(isinstance(v, Number) for v in items)

    def _init_axes(self, ax) -> dict:
        """
        Initialize the axes for the plot.

        If no coordinates are known yet, a spring layout is computed, cached
        in ``self.pos`` and the axis limits are set to [-1.1, 1.1]. Otherwise
        the axis limits are set to the bounding box of the nodes plus 5%
        padding. Ticks and the axis frame are hidden in both cases.

        Args:
            ax (matplotlib.axes.Axes): The axes object.

        Returns:
            dict: The layout of the nodes.
        """
        layout = self.pos
        if self.pos is None:
            # use spring layout if no coordinates are provided; nodes are
            # added explicitly so isolated 0-simplices also get a position
            G = nx.Graph()
            G.add_nodes_from(self.sc.nodes)
            G.add_edges_from(self.sc.edges)
            layout = nx.spring_layout(G)
            self.pos = layout
            # set the axis limits to a square
            ax.set_xlim([-1.1, 1.1])
            ax.set_ylim([-1.1, 1.1])
        elif len(self.pos) > 0:
            xs = [coord[0] for coord in self.pos.values()]
            ys = [coord[1] for coord in self.pos.values()]
            min_x, max_x = min(xs), max(xs)
            min_y, max_y = min(ys), max(ys)

            # add 5% padding to the bounding box; when the nodes have zero
            # span along one axis (single node / collinear nodes), borrow the
            # other axis' padding (or 1.0) so the limits never collapse
            x_padding = (max_x - min_x) * 0.05
            y_padding = (max_y - min_y) * 0.05
            fallback_padding = max(x_padding, y_padding) or 1.0
            x_padding = x_padding or fallback_padding
            y_padding = y_padding or fallback_padding

            # set the axis limits according to the bounding box of the nodes
            ax.set_xlim([min_x - x_padding, max_x + x_padding])
            ax.set_ylim([min_y - y_padding, max_y + y_padding])

        # layout configuration
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.axis("off")

        return layout

    def create_edge_flow(self, flow: np.ndarray) -> dict:
        """
        Create a dictionary of edge flows from the flow array.

        Args:
            flow (np.ndarray): The flow on the edges, one value per edge in
                the order of ``self.sc.edges``.

        Returns:
            dict: The edge flow dictionary mapping each edge to its flow.

        Raises:
            ValueError: If ``flow`` does not contain exactly one value per
                edge.
        """
        edges = list(self.sc.edges)
        # zip() would silently truncate a mismatched flow vector and drop
        # edges from the plot, so reject it explicitly
        if len(flow) != len(edges):
            raise ValueError(
                f"Expected a flow with {len(edges)} values (one per edge), "
                f"got {len(flow)}."
            )
        return dict(zip(edges, flow))

    def draw_sc_nodes(
        self,
        node_size: int = 300,
        node_color: str = "#ff7f0e",
        node_edge_colors: str = "black",
        font_size: float = 12,
        font_color: str = "k",
        font_weight: str = "normal",
        cmap=plt.cm.Blues,
        vmin=None,
        vmax=None,
        alpha: float = 0.8,
        margins=None,
        with_labels: bool = False,
        ax=None,
    ) -> None:
        """
        Draw the nodes of the simplicial complex.

        Args:
            node_size (int, optional): The size of the nodes. Defaults to 300.
            node_color (str or sequence of numbers, optional): The color of
                the nodes. A sequence of numbers (one per node) is mapped
                through ``cmap`` and a colorbar is added. Defaults to
                '#ff7f0e'.
            node_edge_colors (str, optional): The color of the node edges.
                Defaults to 'black'.
            font_size (float, optional): The font size of the node labels.
                Defaults to 12.
            font_color (str, optional): The color of the node labels.
                Defaults to 'k'.
            font_weight (str, optional): The font weight of the node labels.
                Defaults to 'normal'.
            cmap (mpl.colors.Colormap, optional): The color map used for
                numeric node colors. Defaults to plt.cm.Blues.
            vmin (float, optional): The minimum value for the color map.
                Defaults to None.
            vmax (float, optional): The maximum value for the color map.
                Defaults to None.
            alpha (float, optional): The transparency of the nodes.
                Defaults to 0.8.
            margins (float or tuple, optional): The margins of the plot.
                Defaults to None.
            with_labels (bool, optional): Whether to show the node labels.
                Defaults to False.
            ax (matplotlib.axes.Axes, optional): The axes object. Defaults to
                the current axes.
        """
        if ax is None:
            ax = plt.gca()

        self._init_axes(ax=ax)

        # colormap handed to scatter; only set for numeric colors because
        # matplotlib ignores (and warns about) cmap for plain color specs
        scatter_cmap = None
        if self._is_numeric_sequence(node_color):
            if cmap is not None:
                assert isinstance(cmap, mpl.colors.Colormap)
            else:
                cmap = plt.get_cmap()

            if vmin is None:
                # for more contrast
                vmin = min(node_color) - abs(min(node_color)) * 0.5
            if vmax is None:
                vmax = max(node_color)

            scatter_cmap = cmap

            # add colorbar (same cmap and limits as the scatter below)
            color_map = mpl.cm.ScalarMappable(cmap=cmap)
            color_map.set_clim(vmin=vmin, vmax=vmax)
            fig = ax.get_figure()
            fig.colorbar(mappable=color_map, ax=ax)

        nodes = self.sc.nodes
        node_collection = ax.scatter(
            [self.pos[node_id][0] for node_id in nodes],
            [self.pos[node_id][1] for node_id in nodes],
            s=node_size,
            c=node_color,
            cmap=scatter_cmap,
            edgecolors=node_edge_colors,
            vmin=vmin,
            vmax=vmax,
            alpha=alpha,
        )

        if margins is not None:
            if isinstance(margins, Iterable):
                ax.margins(*margins)
            else:
                ax.margins(margins)

        if with_labels:
            self._draw_node_labels(
                font_size=font_size,
                font_weight=font_weight,
                font_color=font_color,
                alpha=alpha,
                ax=ax,
            )

        # keep the nodes above the edges and triangles
        node_collection.set_zorder(2)

    def _draw_node_labels(
        self,
        font_size: float = 12,
        font_color: str = "k",
        font_weight: str = "normal",
        alpha=None,
        ax=None,
    ) -> None:
        """
        Draw the labels of the nodes, centered on each node position.

        Args:
            font_size (float, optional): The font size of the node labels.
                Defaults to 12.
            font_color (str, optional): The color of the node labels.
                Defaults to 'k'.
            font_weight (str, optional): The font weight of the node labels.
                Defaults to 'normal'.
            alpha (float, optional): The transparency of the node labels.
                Defaults to None.
            ax (matplotlib.axes.Axes, optional): The axes to draw on.
                Defaults to the current axes.
        """
        if ax is None:
            ax = plt.gca()

        for node_id in self.sc.nodes:
            (x, y) = self.pos[node_id]
            ax.text(
                x,
                y,
                node_id,
                fontsize=font_size,
                color=font_color,
                weight=font_weight,
                ha="center",
                va="center",
                alpha=alpha,
            )

    def draw_sc_edges(
        self,
        edge_flow: dict = None,
        edge_color: str = "lightblue",
        edge_width: float = 1.0,
        arrowsize: int = 10,
        edge_cmap=plt.cm.Blues,
        edge_vmin=None,
        edge_vmax=None,
        directed: bool = True,
        alpha: float = 0.8,
        ax=None,
    ) -> None:
        """
        Draw the edges of the simplicial complex and fill its triangles.

        Args:
            edge_flow (dict, optional): The flow of the edges.
                e.g. {(0, 1): 0.5, (1, 2): 0.3, (2, 0): 0.2}. When given, only
                these edges are drawn and their flows override
                ``edge_color``. Defaults to None.
            edge_color (str or sequence of numbers, optional): The color of
                the edges. Numeric values (one per edge) are mapped through
                ``edge_cmap`` and a colorbar is added. Defaults to
                'lightblue'.
            edge_width (float, optional): The width of the edges.
                Defaults to 1.0.
            arrowsize (int, optional): The size of the arrows. Defaults to 10.
            edge_cmap (mpl.colors.Colormap, optional): The color map of the
                edges. Defaults to plt.cm.Blues.
            edge_vmin (float, optional): The minimum value for the color map.
                Defaults to None.
            edge_vmax (float, optional): The maximum value for the color map.
                Defaults to None.
            directed (bool, optional): Whether the edges are directed.
                Defaults to True.
            alpha (float, optional): The transparency of the edges.
                Defaults to 0.8.
            ax (matplotlib.axes.Axes, optional): The axes object. Defaults to
                the current axes.
        """
        if edge_flow:
            assert isinstance(edge_flow, dict)

        if ax is None:
            ax = plt.gca()

        fig = ax.get_figure()
        _, fig_height = fig.get_size_inches()

        self._init_axes(ax=ax)

        # if an edge flow is provided, use it to color the edges
        if edge_flow is not None:
            edges = list(edge_flow.keys())
            edge_color = list(edge_flow.values())
        else:
            edges = self.sc.edges

        # create a graph
        graph = nx.DiGraph()
        graph.add_edges_from(edges)

        # edge color is a non-empty iterable of numbers -> use the colormap
        if self._is_numeric_sequence(edge_color):
            # check if edge_cmap is a colormap
            if edge_cmap is not None:
                assert isinstance(edge_cmap, mpl.colors.Colormap)
            else:
                edge_cmap = plt.get_cmap()

            # set the color map limits
            if edge_vmin is None:
                # for more contrast
                edge_vmin = min(edge_color) - abs(min(edge_color)) * 0.5
            if edge_vmax is None:
                edge_vmax = max(edge_color)

            # add colorbar
            color_map = mpl.cm.ScalarMappable(cmap=edge_cmap)
            color_map.set_clim(vmin=edge_vmin, vmax=edge_vmax)
            fig.colorbar(
                mappable=color_map,
                ax=ax,
            ).ax.tick_params(labelsize=fig_height * 2)

            # networkx draws edges in graph.edges() order, which may differ
            # from the input order: realign the colors via a lookup table
            # (first occurrence wins, like list.index) instead of an O(E^2)
            # repeated list.index scan
            color_of_edge = {}
            for edge, color in zip(edges, edge_color):
                color_of_edge.setdefault(edge, color)
            edge_color = [color_of_edge[edge] for edge in graph.edges()]

        # draw edges
        nx.draw_networkx_edges(
            graph,
            pos=self.pos,
            edge_color=edge_color,
            width=edge_width,
            alpha=alpha,
            arrowsize=arrowsize,
            edge_cmap=edge_cmap,
            edge_vmin=edge_vmin,
            edge_vmax=edge_vmax,
            ax=ax,
            arrows=directed,
        )

        # fill the 2-simplices (triangles) below edges and nodes
        for i, j, k in self.sc.triangles:
            (x0, y0) = self.pos[i]
            (x1, y1) = self.pos[j]
            (x2, y2) = self.pos[k]
            tri = plt.Polygon(
                [[x0, y0], [x1, y1], [x2, y2]],
                edgecolor="k",
                facecolor=plt.cm.Blues(0.4),
                alpha=0.3,
                lw=0.5,
                zorder=0,
            )
            ax.add_patch(tri)

    def _calculate_edge_label_position(
        self, src: tuple, dest: tuple, offset: float
    ) -> tuple:
        """
        Calculate the position of an edge label next to the edge midpoint.

        Args:
            src: The source node id of the edge.
            dest: The destination node id of the edge.
            offset (float): How far to shift the label away from the edge.

        Returns:
            tuple: The (x, y) position of the label.
        """
        center_coeff = 0.5
        (x1, y1) = self.pos[src]
        (x2, y2) = self.pos[dest]

        mid_x = x1 * center_coeff + x2 * (1.0 - center_coeff)
        mid_y = y1 * center_coeff + y2 * (1.0 - center_coeff)

        # calculate the slope of the edge; a vertical edge (x1 == x2) is
        # assigned slope 0 to avoid dividing by zero
        if x2 - x1 == 0:
            slope = 0
        else:
            slope = (y2 - y1) / (x2 - x1)

        if np.abs(slope) == 0:
            if x1 == x2:
                # vertical edge: shift the label to the right
                return mid_x + (offset / 3), mid_y
            # horizontal edge: shift the label downwards
            return mid_x, mid_y - offset

        if np.abs(slope) <= 0.1:
            # almost horizontal edge: shift the label upwards
            return mid_x, mid_y + (offset / 2)

        # diagonal edge: shift the label up and to the left
        return mid_x - offset, mid_y + offset

    def draw_edge_labels(
        self,
        edge_labels: dict,
        font_size: int = 10,
        font_color: str = "k",
        font_weight: str = "normal",
        offset=0.15,
        alpha=None,
        ax=None,
    ) -> dict:
        """
        Draw the labels (flow) of the edges.

        Args:
            edge_labels (dict): The labels of the edges.
                e.g. {(0, 1): 0.5, (1, 2): 0.3, (2, 0): 0.2}
            font_size (int, optional): The font size of the labels.
                Defaults to 10.
            font_color (str, optional): The color of the labels.
                Defaults to 'k'.
            font_weight (str, optional): The font weight of the labels.
                Defaults to 'normal'.
            offset (float, optional): The offset of the labels from the
                center of the edge. Defaults to 0.15.
            alpha (float, optional): The transparency of the labels.
                Defaults to None.
            ax (matplotlib.axes.Axes, optional): The axes object. Defaults to
                the current axes.

        Returns:
            dict: Maps each (src, dest) edge to the created
            ``matplotlib.text.Text`` object.
        """
        assert isinstance(edge_labels, dict)

        if ax is None:
            ax = plt.gca()

        self._init_axes(ax=ax)

        edge_items = {}
        for (src, dest), label in edge_labels.items():
            (x, y) = self._calculate_edge_label_position(
                src=src, dest=dest, offset=offset
            )
            edge_items[(src, dest)] = ax.text(
                x,
                y,
                label,
                size=font_size,
                color=font_color,
                weight=font_weight,
                alpha=alpha,
                zorder=1,
            )

        return edge_items

    def draw_network(
        self,
        edge_flow=None,
        directed: bool = True,
        with_labels: bool = True,
        ax=None,
        **kwargs,
    ) -> None:
        """
        Draw the simplicial complex network with edge flow.

        If the flow is not provided, the network is drawn without flow.

        Args:
            edge_flow (dict, np.ndarray, list, tuple, optional): The labels
                of the edges. e.g. {(0, 1): 0.5, (1, 2): 0.3, (2, 0): 0.2}.
                You can also provide an array/list of the flow with one
                value per edge in the order of ``self.sc.edges``.
                Defaults to None.
            directed (bool, optional): Whether the edges are directed.
                Defaults to True.
            with_labels (bool, optional): Whether to show the node labels
                (and numeric edge labels). Defaults to True.
            ax (matplotlib.axes.Axes, optional): The axes object. Defaults to
                the current axes.

        Keyword arguments are routed to ``draw_sc_nodes``, ``draw_sc_edges``
        and ``draw_edge_labels`` by parameter name; a name accepted by
        several of them (e.g. ``font_size``, ``alpha``) is passed to each.
        If NO keyword argument is given, the presentation defaults
        node_size=400, edge_width=5, arrowsize=30, font_size=12 are used;
        passing any keyword argument disables all of these defaults.

        Node kwargs:
            node_size (int, optional): The size of the nodes. Defaults to 300.
            node_color (str, optional): The color of the nodes.
                Defaults to '#ff7f0e'.
            node_edge_colors (str, optional): The color of the node edges.
                Defaults to 'black'.
            font_size (float, optional): The font size of the node labels.
                Defaults to 12.
            font_color (str, optional): The color of the node labels.
                Defaults to 'k'.
            font_weight (str, optional): The font weight of the node labels.
                Defaults to 'normal'.
            cmap (mpl.colors.Colormap, optional): The color map.
                Defaults to plt.cm.Blues.
            vmin (float, optional): The minimum value for the color map.
                Defaults to None.
            vmax (float, optional): The maximum value for the color map.
                Defaults to None.
            alpha (float, optional): The transparency of the nodes.
                Defaults to 0.8.
            margins (float, optional): The margins of the plot.
                Defaults to None.

        Edge kwargs:
            edge_color (str, optional): The color of the edges.
                Defaults to 'lightblue'.
            edge_width (float, optional): The width of the edges.
                Defaults to 1.0.
            arrowsize (int, optional): The size of the arrows. Defaults to 10.
            edge_cmap (mpl.colors.Colormap, optional): The color map of the
                edges. Defaults to plt.cm.Blues.
            edge_vmin (float, optional): The minimum value for the color map.
                Defaults to None.
            edge_vmax (float, optional): The maximum value for the color map.
                Defaults to None.
            alpha (float, optional): The transparency of the edges.
                Defaults to 0.8.

        Edge label kwargs:
            font_size (int, optional): The font size of the labels.
                Defaults to 10.
            font_color (str, optional): The color of the labels.
                Defaults to 'k'.
            font_weight (str, optional): The font weight of the labels.
                Defaults to 'normal'.
            offset (float, optional): The offset of the labels from the
                center of the edge. Defaults to 0.15.
            alpha (float, optional): The transparency of the labels.
                Defaults to None.

        Raises:
            ValueError: If an unknown keyword argument is passed.
        """
        from inspect import signature

        # no kwargs passed at all -> use presentation defaults
        if len(kwargs) == 0:
            kwargs = {
                "node_size": 400,
                "edge_width": 5,
                "arrowsize": 30,
                "font_size": 12,
            }

        # collect the parameter names accepted by each drawing method
        node_params = signature(self.draw_sc_nodes).parameters.keys()
        edge_params = signature(self.draw_sc_edges).parameters.keys()
        label_params = signature(self.draw_edge_labels).parameters.keys()

        # arguments handled explicitly by this method are not valid kwargs
        valid_kwargs = (node_params | edge_params | label_params) - {
            "edge_flow",
            "directed",
            "with_labels",
            "ax",
        }

        invalid = [k for k in kwargs if k not in valid_kwargs]
        if invalid:
            invalid_args = ", ".join(invalid)
            raise ValueError(f"Invalid arguments: {invalid_args}")

        node_kwargs = {k: v for k, v in kwargs.items() if k in node_params}
        edge_kwargs = {k: v for k, v in kwargs.items() if k in edge_params}
        label_kwargs = {k: v for k, v in kwargs.items() if k in label_params}

        # initialize the axes
        if ax is None:
            ax = plt.gca()

        # positional flow vectors are mapped onto self.sc.edges
        if isinstance(edge_flow, (np.ndarray, list, tuple)):
            edge_flow = self.create_edge_flow(flow=edge_flow)

        # the flow values decide whether numeric edge labels can be drawn
        if not np.iterable(edge_flow):
            edge_color = "lightblue"
        else:
            edge_color = list(edge_flow.values())

        # draw the nodes
        self.draw_sc_nodes(with_labels=with_labels, ax=ax, **node_kwargs)

        # draw the edges
        self.draw_sc_edges(
            edge_flow=edge_flow,
            directed=directed,
            ax=ax,
            **edge_kwargs,
        )

        # plot edge labels
        if with_labels and self._is_numeric_sequence(edge_color):
            self.draw_edge_labels(edge_labels=edge_flow, ax=ax, **label_kwargs)

    def draw_hodge_decomposition(
        self,
        flow: np.ndarray,
        component=None,
        round_fig: bool = True,
        round_sig_fig: int = 2,
        figsize=(15, 5),
        font_dict={"fontsize": 20},
    ) -> None:
        """
        Draw the Hodge decomposition of the flow in a new figure.

        Args:
            flow (np.ndarray): The flow on the edges.
            component (str, optional): The component of the flow to draw.
                If None, all three components (gradient, curl, harmonic)
                are drawn side by side. Defaults to None.
            round_fig (bool, optional): Whether to round the figures.
                Defaults to True.
            round_sig_fig (int, optional): The rounding precision forwarded
                to ``get_component_flow``. Defaults to 2.
            figsize (tuple, optional): The size of the figure.
                Defaults to (15, 5).
            font_dict (dict, optional): The font dictionary used for the
                subplot titles (not modified). Defaults to {"fontsize": 20}.

        Raises:
            ValueError: If an invalid component is provided (the check
                itself lives in ``self.sc.get_component_flow``).
        """
        fig = plt.figure(figsize=figsize)

        if component is not None:
            component_flow = self.sc.get_component_flow(
                flow=flow,
                component=component,
                round_fig=round_fig,
                round_sig_fig=round_sig_fig,
            )
            # create a single figure
            ax = fig.add_subplot(1, 1, 1)
            ax.set_title(
                rf"$\mathbf{{f_{{{component[0].upper()}}}}}$",
                fontdict=font_dict,
            )
            self.draw_network(edge_flow=component_flow, ax=ax)
        else:
            # compute all three components before drawing anything
            panels = [
                (
                    label,
                    self.sc.get_component_flow(
                        flow=flow,
                        component=freq_component.value,
                        round_fig=round_fig,
                        round_sig_fig=round_sig_fig,
                    ),
                )
                for label, freq_component in (
                    ("G", FrequencyComponent.GRADIENT),
                    ("C", FrequencyComponent.CURL),
                    ("H", FrequencyComponent.HARMONIC),
                )
            ]

            # gradient, curl and harmonic flow side by side
            for position, (label, component_flow) in enumerate(
                panels, start=1
            ):
                ax = fig.add_subplot(1, 3, position)
                ax.set_title(
                    rf"$\mathbf{{f_{{{label}}}}}$", fontdict=font_dict
                )
                self.draw_network(edge_flow=component_flow, ax=ax)

        plt.show()

    def draw_eigenvectors(
        self,
        component: str,
        eigenvector_indices: np.ndarray = [],
        round_fig: bool = True,
        round_sig_fig: int = 2,
        with_labels: bool = True,
        figsize=(15, 5),
        font_dict={"fontsize": 20},
    ) -> None:
        """
        Draw the eigenvectors for the given component and eigenvalue indices
        using eigendecomposition, in a new figure with up to three plots per
        row.

        Args:
            component (str): The component of the eigenvectors to draw.
            eigenvector_indices (np.ndarray, optional): The indices of the
                eigenvectors to draw. If empty, all eigenvectors are drawn.
                Defaults to [] (never mutated).
            round_fig (bool, optional): Whether to round the eigenvector
                values shown on the edges. Defaults to True.
            round_sig_fig (int, optional): The number of decimals used when
                rounding (titles and, if ``round_fig``, edge values).
                Defaults to 2.
            with_labels (bool, optional): Whether to show the node labels.
                Defaults to True.
            figsize (tuple, optional): The size of one row of the figure.
                Defaults to (15, 5).
            font_dict (dict, optional): The font dictionary used for the
                subplot titles. Defaults to {"fontsize": 20}.
        """
        viz_per_row = 3

        U, eigenvals = self.sc.get_component_eigenpair(component=component)

        # if no eigenvector indices are provided, draw all eigenvectors
        if len(eigenvector_indices) == 0:
            eigenvector_indices = range(len(eigenvals))

        num_plots = len(eigenvector_indices)
        if num_plots == 0:
            # empty eigenspace: nothing to lay out (and num_cols would be 0)
            import warnings

            warnings.warn(
                f"No eigenvectors to draw for component '{component}'."
            )
            return

        num_cols = min(num_plots, viz_per_row)
        num_rows = -(-num_plots // num_cols)  # ceiling division

        # adjust the figure size to fit all the plots
        if num_rows > 1:
            figsize = (figsize[0], figsize[1] * num_rows)
        elif num_cols != 1:
            figsize = ((figsize[0] / viz_per_row) * num_cols, figsize[1])

        # always start a fresh figure instead of reusing figure number 1
        fig = plt.figure(figsize=figsize)

        for position, eig_vec in enumerate(eigenvector_indices, start=1):
            ax = fig.add_subplot(num_rows, num_cols, position)
            ax.set_title(
                rf"$\lambda_{{{component[0].upper()}}}$"
                + f" = {round(eigenvals[eig_vec], round_sig_fig)}",
                fontdict=font_dict,
            )

            flow = U[:, eig_vec]
            if round_fig:
                flow = np.round(flow, round_sig_fig)

            self.draw_network(edge_flow=flow, ax=ax, with_labels=with_labels)

        fig.tight_layout()
        plt.show()