Sankey Diagram

ML Training Pipeline

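A Sankey diagram of a machine-learning data flow: 1M raw records pass through cleaning, 100K are removed as noise, and the remaining 900K clean records are split into a 180K test set and a 720K training set.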

Output
ML Training Pipeline
Python
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.path import Path

def draw_flow(ax, x0, y0, x1, y1, w0, w1, color, alpha=0.6):
    """Add one curved, filled flow band to *ax*.

    The band starts at ``x0`` centred vertically on ``y0`` with thickness
    ``w0`` and ends at ``x1`` centred on ``y1`` with thickness ``w1``.
    Its upper and lower edges are cubic curves whose two control points
    both sit at the horizontal midpoint, one at the source height and one
    at the destination height; the right-hand end is closed with a
    straight segment.

    Args:
        ax: Axes-like object providing ``add_patch``.
        x0, y0: Centre of the band's left (source) end.
        x1, y1: Centre of the band's right (destination) end.
        w0: Band thickness at the source end.
        w1: Band thickness at the destination end.
        color: Fill colour of the band.
        alpha: Fill opacity (default 0.6).
    """
    mid_x = (x0 + x1) / 2

    # Vertical extents of the band at each end.
    src_top, src_bottom = y0 + w0/2, y0 - w0/2
    dst_top, dst_bottom = y1 + w1/2, y1 - w1/2

    # Left-to-right along the top edge, then right-to-left along the bottom.
    upper_edge = [(x0, src_top), (mid_x, src_top), (mid_x, dst_top), (x1, dst_top)]
    lower_edge = [(x1, dst_bottom), (mid_x, dst_bottom), (mid_x, src_bottom), (x0, src_bottom)]
    outline = upper_edge + lower_edge + [(x0, src_top)]  # final vertex pairs with CLOSEPOLY

    commands = [
        Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4,
        Path.LINETO, Path.CURVE4, Path.CURVE4, Path.CURVE4,
        Path.CLOSEPOLY,
    ]

    band = mpatches.PathPatch(Path(outline, commands), fc=color, alpha=alpha, ec='none')
    ax.add_patch(band)

def draw_node(ax, x, y, w, h, color, label):
    """Add a rounded, labelled node box to *ax*, centred on ``(x, y)``.

    Args:
        ax: Axes-like object providing ``add_patch`` and ``text``.
        x, y: Centre of the box; the label is centred on the same point.
        w: Box width passed to the patch.
        h: Box height passed to the patch.
        color: Fill colour of the box (the outline is white).
        label: Text drawn over the box in small, bold, white type.
    """
    # The patch is anchored at its lower-left corner, so shift by half a size.
    lower_left = (x - w/2, y - h/2)
    box = mpatches.FancyBboxPatch(
        lower_left, w, h,
        boxstyle="round,pad=0.02",
        fc=color, ec='white', lw=1.5,
    )
    ax.add_patch(box)

    ax.text(
        x, y, label,
        ha='center', va='center',
        fontsize=8, color='white', fontweight='bold',
    )

# Dark canvas: the figure and the axes share the same background colour.
fig, ax = plt.subplots(figsize=(14, 8), facecolor='#0a0a0f')
ax.set_facecolor('#0a0a0f')

# Scale from a record count (in thousands, judging by the node labels)
# to a band thickness in y-axis units.
s = 0.004

# Each flow: (x0, y0, x1, y1, count, color). Flows are drawn before the
# nodes, in this order, and keep a constant thickness from end to end.
flows = (
    (0.5, 5, 2.5, 5, 1000, '#F527B0'),   # raw data -> cleaning
    (3.5, 5.5, 5.5, 7, 100, '#C82909'),  # cleaning -> noise removed
    (3.5, 4.5, 5.5, 4, 900, '#27D3F5'),  # cleaning -> clean data
    (6.5, 4.5, 8.5, 6, 180, '#F5B027'),  # clean data -> test set
    (6.5, 3.5, 8.5, 3, 720, '#6CF527'),  # clean data -> train set
)
for x_from, y_from, x_to, y_to, count, colour in flows:
    thickness = count * s
    draw_flow(ax, x_from, y_from, x_to, y_to, thickness, thickness, colour, 0.7)

# Each node: (x, y, count, height multiplier, color, label). The height
# is the scaled count times a per-node multiplier; every box is 0.6 wide.
nodes = (
    (0, 5, 1000, 1.2, '#F527B0', 'Raw Data\n1M'),
    (3, 5, 1000, 1.2, '#4927F5', 'Cleaning\n1M'),
    (6, 7, 100, 3, '#C82909', 'Noise\n100K'),
    (6, 4, 900, 1.3, '#27D3F5', 'Clean\n900K'),
    (9, 6, 180, 2, '#F5B027', 'Test Set\n180K'),
    (9, 3, 720, 1.3, '#6CF527', 'Train Set\n720K'),
)
for node_x, node_y, count, stretch, colour, caption in nodes:
    draw_node(ax, node_x, node_y, 0.6, count * s * stretch, colour, caption)

# Title, fixed data limits, and no axis decorations.
ax.set_title('ML Training Pipeline Flow', fontsize=16, color='white', fontweight='bold', pad=20)
ax.set_xlim(-1, 10)
ax.set_ylim(0, 9)
ax.axis('off')

plt.tight_layout()
plt.show()
Library

Matplotlib

Category

Part-to-Whole

Did this help you?

Support PyLucid to keep it free & growing

Support