A long time ago (somewhere in 2015), I wrote an article about Dijkstra's algorithm and the shortest path. In that article, I drew a graph in a notebook, just to illustrate what I was trying to traverse:

A picture of a directed graph, from an article written in 2015.
Today, I was thinking of starting a set of YouTube videos explaining a bit of graph theory the way I see it. To do that, I needed some kind of visualization tool, because drawing a graph by hand every time can be pretty tiring. So, I found a way to draw it:

This one comes mostly out of the box, with the magic of nx.DiGraph().
So, after spending a few hours preparing for a YouTube video, this is what I came up with:
```python
import networkx as nx
import matplotlib.pyplot as plt


def generate_graph_plot(data_string):
    """Build a matplotlib figure for a weighted directed graph.

    Expected text layout (as parsed below): a header line whose second number
    is the edge count, then one "u v weight" line per edge, then "node h" lines,
    and finally a "start end" line.
    """
    lines = data_string.strip().split('\n')
    header = lines[0].split()
    n_edges = int(header[1])
    edge_lines = lines[1 : 1 + n_edges]
    node_lines = lines[1 + n_edges : -1]
    path_info = lines[-1].split()
    start_node = path_info[0]
    end_node = path_info[1]

    # Weighted edges
    G = nx.DiGraph()
    for line in edge_lines:
        parts = line.split()
        u, v, w = parts[0], parts[1], int(parts[2])
        G.add_edge(u, v, weight=w)

    # Per-node "h" values, shown in the node labels
    for line in node_lines:
        parts = line.split()
        node = parts[0]
        h_value = int(parts[1])
        if node not in G:
            G.add_node(node)
        G.nodes[node]['h'] = h_value

    # Column of each node = length of the longest edge chain reaching it
    layers = {n: 0 for n in G.nodes()}
    try:
        sorted_nodes = list(nx.topological_sort(G))
        for node in sorted_nodes:
            current_depth = layers[node]
            for child in G.successors(node):
                # Push child to the right if needed
                if layers[child] < current_depth + 1:
                    layers[child] = current_depth + 1
    except nx.NetworkXUnfeasible:
        # The graph has a cycle, so there is no topological order:
        # every node stays in column 0
        pass

    for node, depth in layers.items():
        G.nodes[node]['subset'] = depth

    # Start node orange, end node red, everything else light blue
    node_colors = []
    for n in G.nodes():
        if n == start_node:
            node_colors.append("orange")
        elif n == end_node:
            node_colors.append("red")
        else:
            node_colors.append('lightblue')

    node_labels = {}
    for n in G.nodes():
        h_val = G.nodes[n].get('h', '?')
        node_labels[n] = f"{n}\n(h={h_val})"

    edge_labels = nx.get_edge_attributes(G, "weight")

    fig, ax = plt.subplots(figsize=(14, 8))
    # Columns come from the 'subset' attribute set above
    pos = nx.multipartite_layout(G, subset_key='subset')
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=1500, ax=ax)
    nx.draw_networkx_edges(G, pos, edge_color='gray', arrowsize=20, ax=ax)
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, ax=ax)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8, ax=ax)
    ax.set_title(f"Graph Visualization: {start_node} to {end_node}")
    ax.axis('off')
    return fig
```
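The `layers` loop in generate_graph_plot() is a longest-path layering: a node's column is the length of the longest chain of edges leading to it, so every edge goes from a lower column to a higher one. Below is a minimal sketch of the same idea using only the standard library, assuming an acyclic graph given as (u, v) pairs. `longest_path_layers` is a hypothetical helper, and it uses Python 3.9's graphlib in place of nx.topological_sort. Unlike the original, which leaves every node in column 0 when it meets a cycle, this sketch lets graphlib.CycleError propagate.

```python
from graphlib import TopologicalSorter


def longest_path_layers(edges):
    """Hypothetical helper: map each node to the length of the longest edge chain reaching it.

    Assumes an acyclic graph; graphlib raises CycleError otherwise.
    """
    # graphlib expects a mapping: node -> set of its predecessors.
    predecessors = {}
    for u, v in edges:
        predecessors.setdefault(u, set())
        predecessors.setdefault(v, set()).add(u)

    layers = {}
    # static_order() yields every node only after all of its predecessors,
    # so their layers are already known when we reach it.
    for node in TopologicalSorter(predecessors).static_order():
        layers[node] = max((layers[p] + 1 for p in predecessors[node]), default=0)
    return layers


edges = [("A", "B"), ("A", "C"), ("A", "D"), ("D", "P"), ("A", "P")]
print(sorted(longest_path_layers(edges).items()))
# [('A', 0), ('B', 1), ('C', 1), ('D', 1), ('P', 2)]
```

P lands in column 2 rather than 1: the chain A → D → P is longer than the direct edge A → P, and the longest chain wins. This "pull from predecessors" form gives the same result as the "push to successors" update in the original loop.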
And calling it is also easy:
```python
data = """
2 5
A B 1
A C 5
A D 10
D P 15
A P 100
A 50
A P
"""

figure_object = generate_graph_plot(data)
plt.show()
```

A large visualization of a tiny graph.
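The orange start node and red end node frame a shortest-path question, which ties back to the 2015 Dijkstra article. Here is a minimal stdlib sketch that parses the same text format and runs Dijkstra on it. `parse_graph_data` and `dijkstra_path` are hypothetical helpers, not part of the code above; the h values are parsed but not used, since plain Dijkstra only needs the edge weights.

```python
import heapq


def parse_graph_data(data_string):
    """Hypothetical helper: split the text format into edges, h values and (start, end).

    Uses the same slicing as generate_graph_plot(): the second number on the
    header line is the edge count.
    """
    lines = data_string.strip().split("\n")
    n_edges = int(lines[0].split()[1])
    edges = []
    for line in lines[1 : 1 + n_edges]:
        u, v, w = line.split()
        edges.append((u, v, int(w)))
    h_values = {}
    for line in lines[1 + n_edges : -1]:
        node, h = line.split()
        h_values[node] = int(h)
    start, end = lines[-1].split()
    return edges, h_values, (start, end)


def dijkstra_path(edges, start, end):
    """Hypothetical helper: return (cost, path) of the cheapest start->end path, or (inf, [])."""
    adjacency = {}
    for u, v, w in edges:
        adjacency.setdefault(u, []).append((v, w))

    # Min-heap of (cost so far, node, path used to reach it).
    queue = [(0, start, [start])]
    visited = set()
    while queue:
        cost, node, path = heapq.heappop(queue)
        if node in visited:
            continue  # a cheaper route to this node was already settled
        visited.add(node)
        if node == end:
            return cost, path
        for neighbour, weight in adjacency.get(node, []):
            if neighbour not in visited:
                heapq.heappush(queue, (cost + weight, neighbour, path + [neighbour]))
    return float("inf"), []


data = """
2 5
A B 1
A C 5
A D 10
D P 15
A P 100
A 50
A P
"""

edges, h_values, (start, end) = parse_graph_data(data)
print(dijkstra_path(edges, start, end))  # (25, ['A', 'D', 'P'])
```

The expensive direct edge A → P (100) loses to A → D → P (10 + 15 = 25). If networkx is already in use, nx.dijkstra_path(G, "A", "P") on a DiGraph built the same way as in generate_graph_plot() should return the same path, since it reads the 'weight' edge attribute by default.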
Of course, the video walks through the code before it is wrapped in a function, so it is probably more interesting. The complete code (both generate_graph_plot() and the version from the video) is on GitHub here:
Enjoy it! 🙂
