Source code for jobshoplab.env.rendering.state_transitions_rendering

from functools import partial

import matplotlib.pyplot as plt
import networkx as nx

from jobshoplab.state_machine.core.transitions import (BufferTransition,
                                                       MachineTransition,
                                                       TransportTransition)


[docs] def render_state_transitions(config, loglevel, backend): """ Render the state transitions. Args: config (Config): The configuration object. loglevel (int | str): The log level backend (callable): The backend to use for rendering. """ graph = [] fig, axs = plt.subplots(1, 3, figsize=(15, 5)) for i, transition in enumerate( [BufferTransition(), TransportTransition(), MachineTransition()] ): G = nx.DiGraph() for node in transition.states: G.add_node(node, label=node) # Add node with label _trans = transition.transitions[node] for edge in _trans: G.add_edge(node, edge, label=edge) # Add edge with label pos = nx.circular_layout(G) # Set the positions of the nodes nx.draw_networkx_nodes( G, pos, node_size=1000, ax=axs[i], node_color="white", edgecolors="black" ) nx.draw_networkx_edges( G, pos, ax=axs[i], edge_color="black", arrows=True, arrowsize=10, min_source_margin=20, # Adjust the value to make self-loops smaller min_target_margin=20, # Adjust the value to make self-loops smaller connectionstyle="arc3,rad=0", ) # Show arrow heads nx.draw_networkx_labels( G, pos, labels=nx.get_node_attributes(G, "label"), ax=axs[i], font_color="black", font_size=8, font_weight="bold", verticalalignment="center", ) # Add node labels axs[i].set_title(transition.__class__.__name__) plt.tight_layout() backend(config=config, loglevel=loglevel, fig=fig)
if __name__ == "__main__": from jobshoplab.env.rendering.backends import save_to_file config = load_config() args = {"instance": None, "name": "state_transition_graph_representation"} backend = partial(save_to_file, **args) render_state_transitions(config, 0, backend=backend)