Visualizing Decision Trees with Python

Master decision tree visualization in Python. Learn plot_tree, export_graphviz, text rules, decision boundaries, feature importance plots, and interactive tree tools with full code examples.

Python offers multiple ways to visualize decision trees: sklearn.tree.plot_tree() renders the tree directly in matplotlib with colored nodes; export_graphviz() produces publication-quality Graphviz diagrams; export_text() generates human-readable rule text; and decision boundary plots show how the tree partitions 2D feature space into rectangular regions. Each method reveals a different aspect of what the tree has learned.

Introduction

One of the defining advantages of decision trees over most machine learning models is that their logic is completely transparent — every prediction follows a path of explicit rules that can be read, explained, and critiqued. But this transparency only becomes useful when it is presented in a form that humans can actually process. A tree with 500 nodes printed as a Python dictionary is technically complete but practically useless.

Visualization transforms the decision tree’s formal structure into something comprehensible. A well-designed visualization lets you immediately see which features matter most (they appear high in the tree), how the model partitions the feature space (through decision boundary plots), which nodes contain the most impurity (revealed by node coloring), and where the tree makes confident versus uncertain predictions (leaf node class distributions).

This article covers every major visualization technique for decision trees in Python: the built-in matplotlib rendering, Graphviz for publication-quality output, text-based rule extraction for non-technical audiences, decision boundary visualization for 2D datasets, feature importance charts, path tracing for individual predictions, and tree complexity analysis. Every example includes complete, runnable code.

Setup and Data Preparation

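Most examples in this article use three built-in scikit-learn datasets (Iris, Breast Cancer, and Wine) that highlight different aspects of visualization; the depth-comparison plot later on adds a synthetic two-moons dataset.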

Python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
from sklearn.tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                           plot_tree, export_text, export_graphviz)
from sklearn.datasets import (load_iris, load_wine, load_breast_cancer,
                               make_classification, make_moons, make_circles)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# ── Dataset 1: Iris (classic, small, 4 features, 3 classes) ───────────────
iris  = load_iris()
X_ir, y_ir = iris.data, iris.target
X_tr_ir, X_te_ir, y_tr_ir, y_te_ir = train_test_split(
    X_ir, y_ir, test_size=0.25, random_state=42, stratify=y_ir
)

# ── Dataset 2: Breast Cancer (30 features, binary, medical context) ────────
cancer = load_breast_cancer()
X_ca, y_ca = cancer.data, cancer.target
X_tr_ca, X_te_ca, y_tr_ca, y_te_ca = train_test_split(
    X_ca, y_ca, test_size=0.25, random_state=42, stratify=y_ca
)

# ── Dataset 3: Wine (13 features, 3 classes) ───────────────────────────────
wine  = load_wine()
X_wi, y_wi = wine.data, wine.target
X_tr_wi, X_te_wi, y_tr_wi, y_te_wi = train_test_split(
    X_wi, y_wi, test_size=0.25, random_state=42, stratify=y_wi
)

# Train trees at different complexities
dt_ir_d3 = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_tr_ir, y_tr_ir)
dt_ir_d5 = DecisionTreeClassifier(max_depth=5, random_state=42).fit(X_tr_ir, y_tr_ir)
dt_ca_d4 = DecisionTreeClassifier(max_depth=4, criterion='gini', random_state=42).fit(X_tr_ca, y_tr_ca)
dt_wi_d4 = DecisionTreeClassifier(max_depth=4, random_state=42).fit(X_tr_wi, y_tr_wi)

print("Trees trained:")
# Pair each tree with its own test split instead of selecting it by name
for name, dt, X_te, y_te in [("Iris d=3",   dt_ir_d3, X_te_ir, y_te_ir),
                             ("Iris d=5",   dt_ir_d5, X_te_ir, y_te_ir),
                             ("Cancer d=4", dt_ca_d4, X_te_ca, y_te_ca),
                             ("Wine d=4",   dt_wi_d4, X_te_wi, y_te_wi)]:
    print(f"  {name:<12}: {dt.get_n_leaves()} leaves, "
          f"depth={dt.get_depth()}, "
          f"test acc={dt.score(X_te, y_te):.4f}")

Method 1: plot_tree() — Inline Matplotlib Visualization

sklearn.tree.plot_tree() renders the tree directly as a matplotlib figure. It is the fastest way to visualize a tree without any external dependencies.

Python
from sklearn.tree import plot_tree

def visualize_with_plot_tree(dt, feature_names, class_names, title,
                               figsize=(20, 10), fontsize=11, max_depth=None):
    """
    Render a decision tree inline using matplotlib's plot_tree.
    
    Args:
        dt:            Fitted DecisionTreeClassifier or Regressor
        feature_names: List of feature name strings
        class_names:   List of class name strings (None for regression)
        title:         Plot title
        figsize:       Figure dimensions in inches
        fontsize:      Font size for node text
        max_depth:     Limit displayed depth (None = full tree)
    
    Returns:
        Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)
    
    plot_tree(
        dt,
        max_depth=max_depth,        # Limit displayed levels
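        # Plain lists: some scikit-learn releases reject NumPy arrays here
        feature_names=(list(feature_names) if feature_names is not None else None),
        class_names=(list(class_names) if class_names is not None else None),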
        filled=True,                # Color nodes by majority class
        rounded=True,               # Rounded node boxes
        impurity=True,              # Show Gini/entropy at each node
        proportion=False,           # Show counts, not proportions
        precision=3,                # Decimal places for thresholds
        fontsize=fontsize,
        ax=ax,
    )
    
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    plt.tight_layout()
    return fig


# Visualization 1: Iris, depth 3 — the classic readable tree
fig1 = visualize_with_plot_tree(
    dt_ir_d3,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    title='Decision Tree: Iris Classification (max_depth=3)\n'
          'Color intensity indicates class purity; deeper blue/orange/green = purer node',
    figsize=(18, 9), fontsize=11
)
fig1.savefig('dt_iris_depth3.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: dt_iris_depth3.png")


# Visualization 2: Breast Cancer, depth 4 — medical context
fig2 = visualize_with_plot_tree(
    dt_ca_d4,
    feature_names=cancer.feature_names,
    class_names=cancer.target_names,
    title='Decision Tree: Breast Cancer Diagnosis (max_depth=4)\n'
          'Binary classification: Malignant vs Benign',
    figsize=(22, 10), fontsize=9
)
fig2.savefig('dt_cancer_depth4.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: dt_cancer_depth4.png")

Understanding the Node Information

Each node in a plot_tree visualization shows up to five pieces of information:

  • Splitting condition (internal nodes only): The feature name and threshold used for this split, e.g., “petal length (cm) <= 2.45”
  • Impurity: The Gini or entropy value at this node; lower means purer
  • Samples: Number of training samples that reached this node
  • Value: Array of sample counts per class at this node, e.g., [50, 0, 0] means all 50 samples are class 0
  • Class: The majority class label at this node. At a leaf, this is the prediction for any sample that reaches it

The color intensity encodes purity: a deeply colored node is nearly or completely pure (one dominant class), while a nearly white node has mixed class representation.
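
The impurity shown in a node can be reproduced by hand from its Value array. The following is a minimal sketch, assuming the default Gini criterion; the helper name `gini` is illustrative and not part of scikit-learn. Gini impurity is one minus the sum of squared class proportions:

```python
def gini(counts):
    """Gini impurity for a list of per-class sample counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

# A pure node: every sample belongs to one class
pure = gini([50, 0, 0])

# A perfectly balanced three-class node: 1 - 3 * (1/3)^2 = 2/3
balanced = gini([50, 50, 50])

# A mostly pure node falls between the two extremes
mixed = gini([0, 49, 5])

print(f"pure={pure:.3f}  balanced={balanced:.3f}  mixed={mixed:.3f}")
```

Nodes with impurity near zero are the ones drawn in the most saturated color; nodes near the balanced value are drawn close to white.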

Python
def annotate_node_statistics(dt, X_train, y_train, class_names):
    """
    Print detailed statistics for each node in the tree.
    Useful for understanding what each node captures.
    """
    tree = dt.tree_
    n_nodes = tree.node_count
    
    print(f"=== Node-Level Statistics ===\n")
    print(f"  Total nodes: {n_nodes}")
    print(f"  Leaf nodes:  {dt.get_n_leaves()}")
    print(f"  Internal:    {n_nodes - dt.get_n_leaves()}\n")
    
    # Traverse nodes
    feature = tree.feature
    threshold = tree.threshold
    
    print(f"  {'Node':>5} | {'Type':>8} | {'Depth':>6} | {'Samples':>8} | "
          f"{'Impurity':>9} | {'Majority':>15} | {'Purity %':>9}")
    print("  " + "-" * 72)
    
    # DFS traversal using a stack
    stack = [(0, 0)]  # (node_id, depth)
    while stack:
        node_id, depth = stack.pop()
        
        is_leaf    = (tree.children_left[node_id] == -1)
        n_samples  = tree.n_node_samples[node_id]
        impurity   = tree.impurity[node_id]
        class_counts = tree.value[node_id][0]  # Shape: (n_classes,)
        majority_cls = class_names[np.argmax(class_counts)]
        purity_pct   = class_counts.max() / class_counts.sum() * 100
        node_type    = "Leaf" if is_leaf else "Split"
        
        print(f"  {node_id:>5} | {node_type:>8} | {depth:>6} | {n_samples:>8} | "
              f"{impurity:>9.4f} | {majority_cls:>15} | {purity_pct:>9.1f}%")
        
        if not is_leaf:
            stack.append((tree.children_right[node_id], depth + 1))
            stack.append((tree.children_left[node_id],  depth + 1))

annotate_node_statistics(dt_ir_d3, X_tr_ir, y_tr_ir, iris.target_names)

Method 2: export_text() — Human-Readable Rule Text

For non-technical audiences or when you need to embed tree logic in documentation, export_text() produces a clean indented text representation that reads as a series of if-then conditions.

Python
from sklearn.tree import export_text

def print_tree_rules(dt, feature_names, class_names=None,
                      max_depth=None, spacing=3):
    """
    Print decision tree as human-readable if-then rules.
    
    Args:
        dt:            Fitted decision tree
        feature_names: Feature name strings
        class_names:   Class name strings (None for regression)
        max_depth:     Maximum depth to display
        spacing:       Indentation spacing per level
    """
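    # export_text expects an integer max_depth, so pass it only when set;
    # class_names is accepted from scikit-learn 1.3 onward
    kwargs = {}
    if max_depth is not None:
        kwargs['max_depth'] = max_depth
    if class_names is not None:
        kwargs['class_names'] = list(class_names)
    rules = export_text(
        dt,
        feature_names=list(feature_names),
        spacing=spacing,
        show_weights=True,     # Show per-class sample weights at each leaf
        decimals=3,
        **kwargs,
    )
    print(rules)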


# Iris tree: easy to read for any audience
print("=== Iris Decision Tree Rules (depth=3) ===\n")
print_tree_rules(dt_ir_d3, iris.feature_names,
                  class_names=iris.target_names)


def extract_decision_rules(dt, feature_names, class_names):
    """
    Extract all root-to-leaf paths as explicit if-then rules.
    
    Returns a list of (conditions, prediction, n_samples, confidence)
    tuples representing each complete decision path.
    """
    tree = dt.tree_
    
    rules = []
    
    def recurse(node_id, conditions):
        """DFS to collect all root-to-leaf paths."""
        is_leaf = (tree.children_left[node_id] == -1)
        
        if is_leaf:
            raw          = tree.value[node_id][0]
            # tree_.value holds counts or class fractions depending on the
            # scikit-learn version; rescaling yields counts in both cases
            class_counts = raw / raw.sum() * tree.weighted_n_node_samples[node_id]
            n_samples    = int(tree.n_node_samples[node_id])
            majority_idx = np.argmax(class_counts)
            confidence   = class_counts[majority_idx] / class_counts.sum()
            
            rules.append({
                'conditions':  list(conditions),
                'prediction':  class_names[majority_idx],
                'n_samples':   n_samples,
                'confidence':  confidence,
                'class_dist':  {class_names[i]: int(round(c))
                                for i, c in enumerate(class_counts)},
            })
            return
        
        feat  = feature_names[tree.feature[node_id]]
        thresh = tree.threshold[node_id]
        
        # Left branch: condition is True
        recurse(tree.children_left[node_id],
                conditions + [f"{feat} <= {thresh:.3f}"])
        
        # Right branch: condition is False
        recurse(tree.children_right[node_id],
                conditions + [f"{feat} > {thresh:.3f}"])
    
    recurse(0, [])
    return rules


# Extract and display all rules
print("\n=== All Decision Rules: Iris (depth=3) ===\n")
rules_list = extract_decision_rules(
    dt_ir_d3, iris.feature_names, iris.target_names
)

for i, rule in enumerate(rules_list, 1):
    print(f"  Rule {i}: Predict '{rule['prediction']}' "
          f"(confidence={rule['confidence']*100:.1f}%, n={rule['n_samples']})")
    for cond in rule['conditions']:
        print(f"    IF {cond}")
    dist = rule['class_dist']
    print(f"    Distribution: {dist}")
    print()


def rules_to_python_function(dt, feature_names, class_names, func_name="predict"):
    """
    Convert a decision tree to a standalone Python function.
    
    Useful for deploying the tree without sklearn dependency,
    or for embedding in microservices, SQL, or other environments.
    """
    rules = extract_decision_rules(dt, feature_names, class_names)
    
    def clean(name):
        """Turn a feature name into a valid Python identifier."""
        return (name.replace(' ', '_').replace('(', '')
                    .replace(')', '').replace('/', '_'))
    
    lines = [
        f"def {func_name}({', '.join(clean(f) for f in feature_names)}):",
        '    """Auto-generated decision tree classifier."""',
    ]
    
    indent = "    "
    for rule in rules:
        if rule['conditions']:
            cond_parts = []
            for cond in rule['conditions']:
                # Split from the right: feature names may contain spaces
                feat, op, thresh = cond.rsplit(' ', 2)
                cond_parts.append(f"{clean(feat)} {op} {thresh}")
            
            lines.append(f"{indent}if ({' and '.join(cond_parts)}):")
            lines.append(f"{indent}    return '{rule['prediction']}'")
    
    # Fallback (not reached when the rules cover every path)
    lines.append(f"{indent}return '{rules[-1]['prediction']}'  # default")
    
    code = "\n".join(lines)
    print("=== Auto-generated Python Function ===\n")
    print(code[:800] + "\n  ..." if len(code) > 800 else code)
    return code


rules_to_python_function(dt_ir_d3, iris.feature_names, iris.target_names)
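
A generated function is only useful if it reproduces the tree's predictions, so it is worth checking the two against each other before deploying. The sketch below is one way to do that, under some assumptions: it uses a simplified, hypothetical generator (`tree_to_source`, which emits nested if/else blocks over an indexable row `x`) rather than the function above, writes thresholds at full precision with `repr`, and loads the generated source with `exec` purely for the comparison.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)


def tree_to_source(clf, func_name="predict_one"):
    """Hypothetical helper: emit nested if/else source for a fitted tree."""
    t = clf.tree_
    lines = [f"def {func_name}(x):"]

    def emit(node, level):
        pad = "    " * level
        if t.children_left[node] == -1:            # leaf: return majority class index
            lines.append(f"{pad}return {int(np.argmax(t.value[node][0]))}")
            return
        feat = int(t.feature[node])
        thresh = float(t.threshold[node])
        lines.append(f"{pad}if x[{feat}] <= {thresh!r}:")   # repr keeps full precision
        emit(t.children_left[node], level + 1)
        lines.append(f"{pad}else:")
        emit(t.children_right[node], level + 1)

    emit(0, 1)
    return "\n".join(lines)


source = tree_to_source(clf)
namespace = {}
exec(source, namespace)                 # load the generated function for testing
predict_one = namespace["predict_one"]

generated = np.array([predict_one(row) for row in iris.data])
agreement = float(np.mean(generated == clf.predict(iris.data)))
print(f"Agreement with clf.predict: {agreement:.3f}")
```

If thresholds are rounded for readability (for example to three decimals), the same comparison shows whether any sample near a split boundary changes its prediction.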

Method 3: export_graphviz() — Publication-Quality Diagrams

For reports, papers, or presentations, Graphviz produces the highest-quality decision tree visualizations. The output is a .dot file that can be rendered to SVG, PNG, or PDF.
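
Before involving the Graphviz renderer at all, it can help to look at the DOT text itself. This minimal sketch (the variable names are illustrative) asks export_graphviz for a string instead of a file and counts the edges it describes; it needs only scikit-learn, not the Graphviz binaries.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# out_file=None returns the DOT source as a string instead of writing a file
dot_source = export_graphviz(
    clf,
    out_file=None,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
)

# Edge lines contain "->"; a tree has one edge fewer than it has nodes
edge_lines = [ln for ln in dot_source.splitlines() if "->" in ln]

print(dot_source.splitlines()[0])
print(f"{len(edge_lines)} edges for {clf.tree_.node_count} nodes")
```

Because the result is plain text, it can be versioned, diffed, or post-processed before rendering.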

Python
from sklearn.tree import export_graphviz
import os

def export_to_graphviz(dt, feature_names, class_names, filename,
                         filled=True, rounded=True, special_characters=True):
    """
    Export tree to Graphviz .dot format and render to PNG.
    
    Requires graphviz to be installed:
        pip install graphviz
        brew install graphviz  (macOS)
        apt-get install graphviz  (Ubuntu)
    
    Args:
        dt:            Fitted decision tree
        feature_names: Feature names list
        class_names:   Class names list
        filename:      Base filename (without extension)
    """
    dot_data = export_graphviz(
        dt,
        out_file=None,
        feature_names=feature_names,
        class_names=class_names,
        filled=filled,
        rounded=rounded,
        special_characters=special_characters,
        impurity=True,
        proportion=False,
        max_depth=None,
        precision=3,
    )
    
    # Save .dot file
    dot_filename = f"{filename}.dot"
    with open(dot_filename, 'w') as f:
        f.write(dot_data)
    print(f"Saved DOT file: {dot_filename}")
    
    # Try to render using graphviz Python package
    try:
        import graphviz
        graph = graphviz.Source(dot_data)
        graph.render(filename, format='png', cleanup=True)
        print(f"Rendered: {filename}.png")
    except ImportError:
        print("graphviz package not installed.")
        print(f"To render: run `dot -Tpng {dot_filename} -o {filename}.png`")
    except Exception as e:
        print(f"Rendering failed: {e}")
        print(f"DOT file saved — render manually with Graphviz CLI")
    
    return dot_data


# Export trees
dot_iris = export_to_graphviz(
    dt_ir_d3, iris.feature_names, iris.target_names,
    filename='dt_iris_graphviz'
)

# Show what the DOT format looks like
print("\n=== First 30 lines of DOT file ===\n")
for line in dot_iris.split('\n')[:30]:
    print(f"  {line}")

Customizing Graphviz Output

Python
def custom_graphviz_colors(dt, feature_names, class_names,
                             class_colors=None, filename="dt_custom"):
    """
    Create a Graphviz diagram with custom node colors.
    
    Default sklearn coloring uses a fixed palette.
    This function allows custom colors per class.
    """
    tree = dt.tree_
    
    if class_colors is None:
        # Default: one distinct, high-contrast color per class
        class_colors = ['#2196F3', '#FF5722', '#4CAF50', '#9C27B0']
    
    def node_color(node_id):
        """Compute node color based on class distribution."""
        counts = tree.value[node_id][0].astype(float)
        total  = counts.sum()
        if total == 0:
            return '#FFFFFF'
        probs      = counts / total
        majority   = np.argmax(probs)
        purity     = probs[majority]
        # Interpolate between white and class color by purity
        base_color = class_colors[majority % len(class_colors)]
        # Parse hex color
        r = int(base_color[1:3], 16)
        g = int(base_color[3:5], 16)
        b = int(base_color[5:7], 16)
        # Interpolate with white (255, 255, 255)
        alpha = purity
        r2 = int(r * alpha + 255 * (1 - alpha))
        g2 = int(g * alpha + 255 * (1 - alpha))
        b2 = int(b * alpha + 255 * (1 - alpha))
        return f'#{r2:02X}{g2:02X}{b2:02X}'
    
    lines = ['digraph Tree {']
    lines.append('  node [shape=box, style="filled,rounded", '
                 'fontname="Helvetica"];')
    lines.append('  edge [fontname="Helvetica"];')
    
    def add_node(node_id, depth=0):
        feat    = tree.feature[node_id]
        thresh  = tree.threshold[node_id]
        is_leaf = (tree.children_left[node_id] == -1)
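        raw     = tree.value[node_id][0]
        # Rescale: tree_.value may hold class fractions instead of counts
        counts  = np.rint(raw / raw.sum()
                          * tree.weighted_n_node_samples[node_id]).astype(int)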
        n_samp  = int(tree.n_node_samples[node_id])
        impurity = tree.impurity[node_id]
        color    = node_color(node_id)
        
        if is_leaf:
            majority = class_names[np.argmax(counts)]
            label    = (f"{majority}\\nn={n_samp}\\n"
                        f"gini={impurity:.3f}\\n"
                        f"[{', '.join(map(str, counts))}]")
        else:
            feat_name = feature_names[feat]
            label     = (f"{feat_name}\\n≤ {thresh:.3f}\\n"
                         f"n={n_samp}, gini={impurity:.3f}")
        
        lines.append(f'  {node_id} [label="{label}", fillcolor="{color}"];')
        
        if not is_leaf:
            left_id  = tree.children_left[node_id]
            right_id = tree.children_right[node_id]
            lines.append(f'  {node_id} -> {left_id} [label="True"];')
            lines.append(f'  {node_id} -> {right_id} [label="False"];')
            add_node(left_id,  depth + 1)
            add_node(right_id, depth + 1)
    
    add_node(0)
    lines.append('}')
    
    dot_str = '\n'.join(lines)
    # Labels contain "≤", so write UTF-8 regardless of the platform default
    with open(f"{filename}.dot", 'w', encoding='utf-8') as f:
        f.write(dot_str)
    print(f"Custom DOT saved: {filename}.dot")
    return dot_str


custom_dot = custom_graphviz_colors(
    dt_ir_d3, iris.feature_names, iris.target_names,
    class_colors=['#1565C0', '#C62828', '#2E7D32'],
    filename='dt_iris_custom_colors'
)

Method 4: Decision Boundary Visualization

Decision boundary plots reveal how the tree partitions 2D feature space into rectangular prediction regions. Since trees split on one feature at a time, these boundaries are always axis-aligned.
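
The axis-aligned claim can be checked directly from the fitted tree. In this sketch (illustrative names, two-moons data chosen only as an example), the split thresholds on each feature are collected into a grid of vertical and horizontal cuts, and a few interior points of every grid cell are probed to confirm that the prediction never changes inside a cell.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=200, noise=0.25, random_state=0)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

tree = clf.tree_
is_split = tree.children_left != -1

# Every boundary is a vertical line (split on feature 0)
# or a horizontal line (split on feature 1)
cuts_x = np.unique(tree.threshold[is_split & (tree.feature == 0)])
cuts_y = np.unique(tree.threshold[is_split & (tree.feature == 1)])

# Add outer bounds so the cuts delimit a grid of rectangular cells
edges_x = np.concatenate(([X[:, 0].min() - 1.0], cuts_x, [X[:, 0].max() + 1.0]))
edges_y = np.concatenate(([X[:, 1].min() - 1.0], cuts_y, [X[:, 1].max() + 1.0]))

# Probe interior points of each cell: the predicted class should be constant
fractions = np.array([0.25, 0.5, 0.75])
cells_constant = True
for x_lo, x_hi in zip(edges_x[:-1], edges_x[1:]):
    for y_lo, y_hi in zip(edges_y[:-1], edges_y[1:]):
        xs = x_lo + (x_hi - x_lo) * fractions
        ys = y_lo + (y_hi - y_lo) * fractions
        probes = np.array([(a, b) for a in xs for b in ys])
        if len(np.unique(clf.predict(probes))) != 1:
            cells_constant = False

print(f"{len(cuts_x)} vertical cuts, {len(cuts_y)} horizontal cuts")
print("Prediction constant inside every grid cell:", cells_constant)
```

The grid is finer than the tree's actual partition, since a leaf region may span several cells, but every leaf region is a union of these rectangles.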

Python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.tree import DecisionTreeClassifier


def plot_decision_boundary(dt, X, y, feature_idx=(0, 1),
                             feature_names=None, class_names=None,
                             title="Decision Tree Boundary",
                             resolution=300, margin=0.5, figsize=(10, 7)):
    """
    Visualize decision tree boundary in a 2D feature subspace.
    
    Args:
        dt:          Fitted decision tree (uses all features)
        X:           Full feature matrix
        y:           Class labels
        feature_idx: Tuple of two feature indices to plot
        resolution:  Grid resolution for boundary
        margin:      Extra margin around data range
    """
    f1, f2 = feature_idx
    X2     = X[:, [f1, f2]]   # 2D subspace for plotting
    
    x_min, x_max = X2[:, 0].min() - margin, X2[:, 0].max() + margin
    y_min, y_max = X2[:, 1].min() - margin, X2[:, 1].max() + margin
    
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                          np.linspace(y_min, y_max, resolution))
    
    # Build full-feature grid (using mean for non-plotted features)
    X_means = X.mean(axis=0)
    grid_full = np.tile(X_means, (resolution * resolution, 1))
    grid_full[:, f1] = xx.ravel()
    grid_full[:, f2] = yy.ravel()
    
    Z = dt.predict(grid_full).reshape(xx.shape)
    
    # Colors
    n_classes = len(np.unique(y))
    palette_bg  = ['#d0e8f8', '#f8d0d0', '#d0f8d0', '#f8f0d0', '#e8d0f8'][:n_classes]
    palette_pts = ['steelblue', 'coral', 'mediumseagreen', 'goldenrod', 'mediumpurple'][:n_classes]
    
    cmap_bg = ListedColormap(palette_bg)
    
    fig, ax = plt.subplots(figsize=figsize)
    
    ax.contourf(xx, yy, Z, alpha=0.4, cmap=cmap_bg)
    ax.contour(xx, yy, Z, colors='black', linewidths=0.8, alpha=0.4)
    
    classes = np.unique(y)
    for cls, color in zip(classes, palette_pts):
        mask = y == cls
        label = (class_names[cls] if class_names is not None
                 else f"Class {cls}")
        ax.scatter(X2[mask, 0], X2[mask, 1], c=color,
                   edgecolors='white', s=55, linewidth=0.5,
                   label=label, alpha=0.85, zorder=3)
    
    f1_name = (feature_names[f1] if feature_names is not None
               else f"Feature {f1}")
    f2_name = (feature_names[f2] if feature_names is not None
               else f"Feature {f2}")
    
    ax.set_xlabel(f1_name, fontsize=12)
    ax.set_ylabel(f2_name, fontsize=12)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.legend(fontsize=10, loc='upper right')
    ax.grid(True, alpha=0.2)
    
    plt.tight_layout()
    return fig, ax


# 2D: Iris with petal features (most discriminative pair)
fig_ir, _ = plot_decision_boundary(
    DecisionTreeClassifier(max_depth=4, random_state=42).fit(X_ir[:, 2:], y_ir),
    X_ir[:, 2:], y_ir,
    feature_idx=(0, 1),
    feature_names=['Petal Length (cm)', 'Petal Width (cm)'],
    class_names=iris.target_names,
    title='Decision Tree Boundary: Iris (Petal Features, depth=4)\n'
          'Axis-aligned splits create rectangular prediction regions',
)
fig_ir.savefig('dt_boundary_iris_petal.png', dpi=150)
plt.show()
print("Saved: dt_boundary_iris_petal.png")


# Boundary comparison: different depths
def plot_depth_comparison(X, y, depths, feature_names=None,
                           class_names=None, title="Depth Comparison"):
    """
    Side-by-side decision boundary plots for different max_depth values.
    Shows how depth controls the complexity of the decision surface.
    """
    n_cols = len(depths)
    fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5), sharey=True)
    
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 250),
                          np.linspace(y_min, y_max, 250))
    grid = np.c_[xx.ravel(), yy.ravel()]
    
    palette_bg  = ['#d0e8f8', '#f8d0d0']
    palette_pts = ['steelblue', 'coral']
    cmap_bg     = ListedColormap(palette_bg)
    
    for ax, d in zip(axes, depths):
        dt_d = DecisionTreeClassifier(max_depth=d, random_state=42)
        dt_d.fit(X, y)
        
        Z = dt_d.predict(grid).reshape(xx.shape)
        
        ax.contourf(xx, yy, Z, alpha=0.35, cmap=cmap_bg)
        ax.contour(xx, yy, Z, colors='black', linewidths=0.8, alpha=0.4)
        
        for cls, color in zip(np.unique(y), palette_pts):
            mask = y == cls
            ax.scatter(X[mask, 0], X[mask, 1], c=color, edgecolors='white',
                       s=35, linewidth=0.4, alpha=0.85)
        
        train_acc = dt_d.score(X, y)
        n_leaves  = dt_d.get_n_leaves()
        
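        # max_depth=None means unlimited: report the depth the tree actually reached
        depth_str = str(d) if d is not None else f"None (grew to {dt_d.get_depth()})"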
        ax.set_title(f"depth = {depth_str}\n"
                     f"{n_leaves} leaves | Train acc = {train_acc:.3f}",
                     fontsize=10, fontweight='bold')
        ax.set_xlabel(feature_names[0] if feature_names else "X1", fontsize=9)
        if ax == axes[0]:
            ax.set_ylabel(feature_names[1] if feature_names else "X2", fontsize=9)
        ax.grid(True, alpha=0.2)
    
    plt.suptitle(title, fontsize=12, fontweight='bold', y=1.02)
    plt.tight_layout()
    return fig


np.random.seed(42)
X_moons, y_moons = make_moons(n_samples=300, noise=0.25, random_state=42)

fig_d_comp = plot_depth_comparison(
    X_moons, y_moons,
    depths=[1, 3, 5, 10, None],
    feature_names=['Feature 1', 'Feature 2'],
    title='Decision Tree Boundary: Two Moons Dataset\n'
          '(Note: straight axis-aligned cuts approximating a curved boundary)'
)
fig_d_comp.savefig('dt_boundary_depth_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: dt_boundary_depth_comparison.png")

Method 5: Feature Importance Visualization

Feature importances summarize which features the tree relied on most across all splits. Visualizing them reveals both feature relevance and which features were ignored entirely.
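
Two properties of the built-in importances are worth verifying before plotting them: they are normalized to sum to one, and any feature that never appears in a split gets exactly zero. A minimal sketch, with illustrative variable names:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(cancer.data, cancer.target)

importances = clf.feature_importances_

# tree_.feature holds the split feature per node; leaves carry a negative marker
split_features = clf.tree_.feature[clf.tree_.feature >= 0]
used = np.unique(split_features)
unused = np.setdiff1d(np.arange(cancer.data.shape[1]), used)

print(f"Features used in splits: {len(used)} of {cancer.data.shape[1]}")
print(f"Sum of importances: {importances.sum():.3f}")
for idx in np.argsort(importances)[::-1][:len(used)]:
    print(f"  {cancer.feature_names[idx]:<25} {importances[idx]:.3f}")
```

A zero here means only that this particular tree did not split on the feature, not that the feature carries no signal; a correlated feature may have been chosen in its place.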

Python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.inspection import permutation_importance
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def plot_feature_importances(dt, feature_names, X_test=None, y_test=None,
                               top_n=20, title="Feature Importances"):
    """
    Plot decision tree feature importances with optional permutation importance.
    
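    Two types of importance:
    1. MDI (Mean Decrease in Impurity): built-in and fast, but derived from
       training data and biased toward high-cardinality features
    2. Permutation importance: measured on held-out data and slower;
       can understate features that are strongly correlated with others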
    
    Args:
        dt:           Fitted decision tree
        feature_names: Feature name strings
        X_test, y_test: Test data for permutation importance (optional)
        top_n:        Number of top features to display
    """
    importances_mdi = dt.feature_importances_
    top_n           = min(top_n, len(importances_mdi))  # never plot more bars than features
    sorted_idx_mdi  = np.argsort(importances_mdi)[::-1][:top_n]
    
    if X_test is not None and y_test is not None:
        fig, axes = plt.subplots(1, 2, figsize=(16, max(6, top_n * 0.4)))
    else:
        fig, axes = plt.subplots(1, 1, figsize=(10, max(6, top_n * 0.4)))
        axes = [axes]
    
    # MDI Importances
    ax = axes[0]
    colors = plt.cm.Blues(np.linspace(0.4, 0.9, top_n))[::-1]
    bars = ax.barh(range(top_n),
                   importances_mdi[sorted_idx_mdi],
                   color=colors, edgecolor='white', linewidth=0.5)
    ax.set_yticks(range(top_n))
    ax.set_yticklabels([feature_names[i] for i in sorted_idx_mdi], fontsize=9)
    ax.set_xlabel('Mean Decrease in Impurity', fontsize=11)
    ax.set_title(f'{title}\nMDI Feature Importances', fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')
    ax.invert_yaxis()  # Highest importance at top
    
    # Add value labels on bars
    for bar, idx in zip(bars, sorted_idx_mdi):
        width = bar.get_width()
        if width > 0.01:
            ax.text(width + 0.003, bar.get_y() + bar.get_height() / 2,
                    f'{width:.3f}', va='center', fontsize=7)
    
    # Permutation Importances (if test data provided)
    if X_test is not None and y_test is not None:
        perm_imp = permutation_importance(
            dt, X_test, y_test, n_repeats=30, random_state=42, n_jobs=-1
        )
        sorted_idx_perm = np.argsort(perm_imp.importances_mean)[::-1][:top_n]
        
        ax2 = axes[1]
        for rank, feat_idx in enumerate(sorted_idx_perm):
            mean_imp = perm_imp.importances_mean[feat_idx]
            std_imp  = perm_imp.importances_std[feat_idx]
            ax2.barh(rank, mean_imp, xerr=std_imp,
                     color=plt.cm.Oranges(0.5 + 0.5 * mean_imp / perm_imp.importances_mean.max()),
                     edgecolor='white', linewidth=0.5,
                     error_kw={'linewidth': 1.5, 'ecolor': 'gray'})
        
        ax2.set_yticks(range(top_n))
        ax2.set_yticklabels([feature_names[i] for i in sorted_idx_perm], fontsize=9)
        ax2.set_xlabel('Accuracy Decrease (mean ± std)', fontsize=11)
        ax2.set_title('Permutation Importances\n(Computed on held-out data)',
                      fontsize=11, fontweight='bold')
        ax2.grid(True, alpha=0.3, axis='x')
        ax2.invert_yaxis()
    
    plt.tight_layout()
    return fig


cancer = load_breast_cancer()
X_ca, y_ca = cancer.data, cancer.target
X_tr_ca, X_te_ca, y_tr_ca, y_te_ca = train_test_split(
    X_ca, y_ca, test_size=0.25, random_state=42, stratify=y_ca  # same split as in Setup
)
dt_ca_full = DecisionTreeClassifier(max_depth=5, random_state=42).fit(X_tr_ca, y_tr_ca)

fig_imp = plot_feature_importances(
    dt_ca_full,
    feature_names=cancer.feature_names,
    X_test=X_te_ca, y_test=y_te_ca,
    top_n=15,
    title="Breast Cancer Decision Tree"
)
fig_imp.savefig('dt_feature_importances.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: dt_feature_importances.png")

Method 6: Tracing Individual Predictions

When explaining a specific prediction to a stakeholder — “why did the model classify this patient as high risk?” — you need to trace the exact path from root to leaf for that sample.
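
scikit-learn already exposes the raw ingredients for such a trace. As a minimal sketch (variable names are illustrative), `decision_path` reports which nodes a sample visits and `apply` reports the leaf it lands in. The ids are sorted so the root, node 0, comes first; in scikit-learn's numbering a child's id is larger than its parent's.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

sample = iris.data[[75]]               # 2D array holding a single row

# decision_path returns a sparse indicator matrix:
# one row per sample, one column per node
indicator = clf.decision_path(sample)
node_ids = sorted(indicator.indices[indicator.indptr[0]:indicator.indptr[1]])

# apply returns the id of the leaf each sample ends up in
leaf_id = clf.apply(sample)[0]

print("Nodes visited:", [int(n) for n in node_ids])
print("Leaf reached: ", int(leaf_id))
```

A hand-written traversal such as the one that follows adds the feature names, thresholds, and values that turn these node ids into an explanation.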

Python
import numpy as np


def trace_prediction_path(dt, x, feature_names, class_names):
    """
    Trace the decision path for a single sample through the tree.
    
    Prints each decision made from root to leaf in plain language,
    suitable for explaining individual predictions to non-technical users.
    
    Args:
        dt:            Fitted decision tree
        x:             Single sample (1D array of features)
        feature_names: Feature name strings
        class_names:   Class name strings
    
    Returns:
        Predicted class, confidence, list of decision steps
    """
    tree    = dt.tree_
    x       = np.array(x).ravel()
    node_id = 0
    steps   = []
    depth   = 0
    
    while True:
        is_leaf = (tree.children_left[node_id] == -1)
        
        if is_leaf:
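            raw        = tree.value[node_id][0]
            # Rescale: tree_.value may hold class fractions instead of counts
            counts     = np.rint(raw / raw.sum()
                                 * tree.weighted_n_node_samples[node_id]).astype(int)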
            majority   = np.argmax(counts)
            confidence = counts[majority] / counts.sum()
            prediction = class_names[majority]
            
            steps.append({
                'type':       'leaf',
                'depth':      depth,
                'prediction': prediction,
                'confidence': confidence,
                'counts':     {class_names[i]: int(c) for i, c in enumerate(counts)},
                'n_samples':  int(counts.sum()),
            })
            break
        
        feat   = tree.feature[node_id]
        thresh = tree.threshold[node_id]
        val    = x[feat]
        
        goes_left = val <= thresh
        direction = "Yes (≤)" if goes_left else "No (>)"
        
        steps.append({
            'type':      'split',
            'depth':     depth,
            'feature':   feature_names[feat],
            'threshold': thresh,
            'value':     val,
            'direction': direction,
            'node_id':   node_id,
        })
        
        node_id = (tree.children_left[node_id] if goes_left
                   else tree.children_right[node_id])
        depth += 1
    
    # Print the path
    leaf = steps[-1]
    splits = steps[:-1]
    
    print(f"=== Prediction Path ===\n")
    print(f"  Prediction: '{leaf['prediction']}' "
          f"(confidence: {leaf['confidence']*100:.1f}%)")
    print(f"  Training samples at this leaf: {leaf['n_samples']}")
    print(f"  Class distribution: {leaf['counts']}\n")
    print(f"  Decision path ({len(splits)} splits):")
    
    for step in splits:
        indent = "    " * (step['depth'] + 1)
        print(f"{indent}[Depth {step['depth']}] "
              f"{step['feature']} = {step['value']:.4f}  "
              f"→ Is ≤ {step['threshold']:.4f}? → {step['direction']}")
    
    return leaf['prediction'], leaf['confidence'], steps


# Example: trace a specific iris sample
sample_idx = 75
x_sample = X_ir[sample_idx]
true_label = iris.target_names[y_ir[sample_idx]]

print(f"Sample #{sample_idx} | True class: '{true_label}'\n")
pred, conf, path = trace_prediction_path(
    dt_ir_d3, x_sample, iris.feature_names, iris.target_names
)

# Visualize the path on the tree
def highlight_prediction_path(dt, x, feature_names, class_names, figsize=(18, 9)):
    """
    Visualize the full tree with the prediction path for sample x highlighted.
    """
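    import matplotlib.pyplot as plt
    from sklearn.tree import plot_tree
    
    # Get node path
    decision_path = dt.decision_path([x])
    node_indicator = decision_path.toarray()[0].astype(bool)
    nodes_on_path  = np.where(node_indicator)[0]
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # plot_tree returns one annotation per drawn node. For a tree grown
    # depth-first (no max_leaf_nodes) and drawn in full, annotation i is node i.
    annotations = plot_tree(
        dt,
        feature_names=feature_names,
        class_names=class_names,
        filled=True,
        rounded=True,
        impurity=True,
        precision=3,
        fontsize=10,
        ax=ax,
    )
    
    # Outline the boxes of the nodes on the prediction path
    for nid in nodes_on_path:
        box = annotations[nid].get_bbox_patch()
        box.set_edgecolor('red')
        box.set_linewidth(3)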
    
    ax.set_title(
        f"Prediction Path Highlighted (sample #{sample_idx})\n"
        f"True: '{true_label}' | Predicted: '{pred}' ({conf*100:.1f}% confidence)\n"
        f"Path visits {len(nodes_on_path)} nodes (nodes on path = depth + 1)",
        fontsize=11, fontweight='bold'
    )
    
    plt.tight_layout()
    plt.savefig('dt_prediction_path.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: dt_prediction_path.png")
    
    print(f"\n  Nodes visited: {nodes_on_path}")
    print(f"  Total nodes in tree: {dt.tree_.node_count}")
    print(f"  Path efficiency: {len(nodes_on_path)}/{dt.tree_.node_count} nodes visited")


highlight_prediction_path(dt_ir_d3, x_sample, iris.feature_names, iris.target_names)
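
When several predictions need explaining at once, the same information can be read from `decision_path()` and `apply()` without walking the tree by hand. The following is a minimal sketch, not part of the original example: it refits a small iris tree so that it runs on its own, and the three sample indices are arbitrary choices.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

X_batch = iris.data[[0, 75, 140]]             # three arbitrary samples (assumption)
node_indicator = clf.decision_path(X_batch)   # sparse (n_samples, n_nodes) matrix
leaf_ids = clf.apply(X_batch)                 # leaf node reached by each sample

paths = []
for i in range(X_batch.shape[0]):
    # In CSR format, the column indices stored for row i are the visited node ids
    start, stop = node_indicator.indptr[i], node_indicator.indptr[i + 1]
    path = node_indicator.indices[start:stop]
    paths.append(path)
    print(f"sample {i}: nodes {path.tolist()} -> leaf {leaf_ids[i]}")

# Nodes visited by every sample in the batch (the root is always one of them)
common = np.flatnonzero(node_indicator.toarray().all(axis=0))
print("shared nodes:", common.tolist())
```

The shared nodes show how far the samples travel together before the tree separates them, which is often the quickest way to explain why two similar-looking cases received different predictions.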

Method 7: Tree Complexity Analysis Plots

Understanding how tree complexity evolves with depth helps you make principled decisions about regularization. The following visualization suite provides a comprehensive complexity dashboard.

Python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, cross_val_score


def tree_complexity_dashboard(X_train, y_train, X_test, y_test,
                                feature_names, class_names,
                                max_depth_range=range(1, 25),
                                cv_folds=5, figsize=(16, 12)):
    """
    Comprehensive dashboard showing tree complexity metrics vs max_depth.
    
    Panels:
    1. Train vs test accuracy (bias-variance tradeoff)
    2. Number of leaves vs depth
    3. Number of nodes vs depth
    4. Mean depth of leaves (how deep is the average prediction path?)
    5. Feature coverage (fraction of features used)
    6. Cross-validation accuracy with confidence band
    """
    depths      = list(max_depth_range)
    train_accs  = []
    test_accs   = []
    n_leaves    = []
    n_nodes     = []
    mean_depths = []
    feat_cov    = []
    cv_means    = []
    cv_stds     = []
    
    for d in depths:
        dt = DecisionTreeClassifier(max_depth=d, random_state=42)
        dt.fit(X_train, y_train)
        
        train_accs.append(dt.score(X_train, y_train))
        test_accs.append(dt.score(X_test, y_test))
        n_leaves.append(dt.get_n_leaves())
        n_nodes.append(dt.tree_.node_count)
        
        # Mean depth of leaf nodes
        tree = dt.tree_
        leaf_depths = []
        stack = [(0, 0)]
        while stack:
            nid, dep = stack.pop()
            if tree.children_left[nid] == -1:
                leaf_depths.append(dep)
            else:
                stack.append((tree.children_right[nid], dep + 1))
                stack.append((tree.children_left[nid],  dep + 1))
        mean_depths.append(np.mean(leaf_depths))
        
        # Feature coverage
        used = (dt.feature_importances_ > 0).sum()
        feat_cov.append(used / X_train.shape[1] * 100)
        
        # CV accuracy (training data only, so the test split stays unseen)
        cv_sc = cross_val_score(dt, X_train, y_train,
                                 cv=cv_folds, scoring='accuracy')
        cv_means.append(cv_sc.mean())
        cv_stds.append(cv_sc.std())
    
    depths      = np.array(depths)
    train_accs  = np.array(train_accs)
    test_accs   = np.array(test_accs)
    cv_means    = np.array(cv_means)
    cv_stds     = np.array(cv_stds)
    
    best_test_idx = np.argmax(test_accs)
    best_cv_idx   = np.argmax(cv_means)
    
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    axes = axes.flatten()
    
    # Panel 1: Accuracy
    ax = axes[0]
    ax.plot(depths, train_accs, 'o-', color='steelblue', lw=2, label='Train')
    ax.plot(depths, test_accs,  's-', color='coral',     lw=2, label='Test')
    ax.axvline(depths[best_test_idx], color='green', linestyle='--', lw=1.5,
               label=f'Best test depth={depths[best_test_idx]}')
    ax.set_xlabel('max_depth'); ax.set_ylabel('Accuracy')
    ax.set_title('Train vs Test Accuracy', fontweight='bold')
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3)
    
    # Panel 2: Leaves
    ax = axes[1]
    ax.semilogy(depths, n_leaves, 'o-', color='mediumpurple', lw=2)
    ax.axvline(depths[best_test_idx], color='green', linestyle='--', lw=1.5)
    ax.set_xlabel('max_depth'); ax.set_ylabel('Number of Leaves (log)')
    ax.set_title('Leaf Count vs Depth\n(log scale; flattens once leaves are pure)',
                 fontweight='bold')
    ax.grid(True, alpha=0.3)
    
    # Panel 3: Nodes
    ax = axes[2]
    ax.semilogy(depths, n_nodes, 'o-', color='goldenrod', lw=2)
    ax.axvline(depths[best_test_idx], color='green', linestyle='--', lw=1.5)
    ax.set_xlabel('max_depth'); ax.set_ylabel('Total Nodes (log)')
    ax.set_title('Node Count vs Depth', fontweight='bold')
    ax.grid(True, alpha=0.3)
    
    # Panel 4: Mean leaf depth
    ax = axes[3]
    ax.plot(depths, mean_depths, 'o-', color='teal', lw=2)
    ax.plot(depths, depths, 'k--', lw=1, alpha=0.4, label='Max possible')
    ax.axvline(depths[best_test_idx], color='green', linestyle='--', lw=1.5)
    ax.set_xlabel('max_depth'); ax.set_ylabel('Mean Leaf Depth')
    ax.set_title('Mean Prediction Path Length\n(Actual vs maximum possible)',
                 fontweight='bold')
    ax.legend(fontsize=8); ax.grid(True, alpha=0.3)
    
    # Panel 5: Feature coverage
    ax = axes[4]
    ax.plot(depths, feat_cov, 'o-', color='darkorange', lw=2)
    ax.axhline(y=100, color='gray', linestyle=':', lw=1, alpha=0.5,
               label='100% coverage')
    ax.set_xlabel('max_depth'); ax.set_ylabel('Features Used (%)')
    ax.set_title('Feature Coverage vs Depth\n(% of features appearing in tree)',
                 fontweight='bold')
    ax.set_ylim([0, 105]); ax.legend(fontsize=8); ax.grid(True, alpha=0.3)
    
    # Panel 6: CV accuracy
    ax = axes[5]
    ax.plot(depths, cv_means, 'o-', color='mediumseagreen', lw=2, label='CV mean')
    ax.fill_between(depths, cv_means - cv_stds, cv_means + cv_stds,
                    alpha=0.2, color='mediumseagreen', label='±1 std')
    ax.axvline(depths[best_cv_idx], color='coral', linestyle='--', lw=2,
               label=f'Best CV depth={depths[best_cv_idx]}')
    ax.set_xlabel('max_depth'); ax.set_ylabel('CV Accuracy')
    ax.set_title(f'{cv_folds}-Fold CV Accuracy', fontweight='bold')
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3)
    
    plt.suptitle(f'Decision Tree Complexity Dashboard\n'
                 f'(Best test depth={depths[best_test_idx]}, '
                 f'Best CV depth={depths[best_cv_idx]})',
                 fontsize=13, fontweight='bold', y=1.01)
    plt.tight_layout()
    plt.savefig('dt_complexity_dashboard.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: dt_complexity_dashboard.png")
    
    return {
        'best_test_depth': depths[best_test_idx],
        'best_cv_depth':   depths[best_cv_idx],
        'best_test_acc':   test_accs[best_test_idx],
        'best_cv_acc':     cv_means[best_cv_idx],
    }


wine = load_wine()
X_wi_full, y_wi_full = wine.data, wine.target
X_tr_w2, X_te_w2, y_tr_w2, y_te_w2 = train_test_split(
    X_wi_full, y_wi_full, test_size=0.25, random_state=42, stratify=y_wi_full
)

results = tree_complexity_dashboard(
    X_tr_w2, y_tr_w2, X_te_w2, y_te_w2,
    feature_names=wine.feature_names,
    class_names=wine.target_names,
    max_depth_range=range(1, 20)
)
print(f"\n  Best test depth: {results['best_test_depth']} "
      f"(acc={results['best_test_acc']:.4f})")
print(f"  Best CV depth:   {results['best_cv_depth']} "
      f"(acc={results['best_cv_acc']:.4f})")
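
If only the accuracy-versus-depth curves are needed, scikit-learn's `validation_curve` produces the train and cross-validation scores in one call. The sketch below is an illustrative alternative to the dashboard loop, not a replacement for it. The depth range of 1 to 10 is an arbitrary choice, and the test split is used once at the end rather than for choosing the depth.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

depth_range = np.arange(1, 11)   # illustrative range (assumption)
train_scores, val_scores = validation_curve(
    DecisionTreeClassifier(random_state=42),
    X_train, y_train,
    param_name="max_depth",
    param_range=depth_range,
    cv=5,
    scoring="accuracy",
)

# Each row corresponds to one depth, each column to one CV fold
val_mean = val_scores.mean(axis=1)
best_depth = int(depth_range[np.argmax(val_mean)])

# Refit at the selected depth and touch the test split only once
final = DecisionTreeClassifier(max_depth=best_depth, random_state=42)
final.fit(X_train, y_train)
test_acc = final.score(X_test, y_test)

print(f"best CV depth: {best_depth}")
print(f"CV accuracy at that depth: {val_mean.max():.3f}")
print(f"held-out test accuracy:    {test_acc:.3f}")
```

Selecting the depth from cross-validation on the training split keeps the test accuracy an honest estimate; picking the depth that maximizes test accuracy would make that number optimistic.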

Method 8: Visualizing Regression Trees

Regression trees predict continuous values rather than class labels. Visualizing them requires a different lens — instead of class purity, we care about the mean and variance of target values at each node.

Python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor, plot_tree

# 1D regression: piecewise constant predictions with split markers
np.random.seed(42)
X_1d = np.sort(np.random.uniform(0, 4 * np.pi, 200)).reshape(-1, 1)
y_1d = np.sin(X_1d.ravel()) + 0.5 * np.cos(2 * X_1d.ravel()) + \
       np.random.normal(0, 0.15, 200)

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
x_plot = np.linspace(0, 4 * np.pi, 1000).reshape(-1, 1)
y_true_plot = np.sin(x_plot.ravel()) + 0.5 * np.cos(2 * x_plot.ravel())

for ax, d in zip(axes.flatten(), [1, 2, 3, 5, 8, 15]):
    dtr = DecisionTreeRegressor(max_depth=d, random_state=42)
    dtr.fit(X_1d, y_1d)
    y_pred = dtr.predict(x_plot)

    ax.scatter(X_1d, y_1d, color='steelblue', s=12, alpha=0.5, label='Data')
    ax.plot(x_plot, y_true_plot, 'k--', lw=1.5, alpha=0.5, label='True f(x)')
    ax.plot(x_plot, y_pred, 'coral', lw=2.5, label=f'DT (depth={d})')

    # Mark split thresholds as vertical dotted lines
    tree = dtr.tree_
    stack = [0]
    while stack:
        nid = stack.pop()
        if tree.children_left[nid] != -1:
            ax.axvline(x=tree.threshold[nid], color='gray',
                       linestyle=':', lw=0.8, alpha=0.4)
            stack.append(tree.children_left[nid])
            stack.append(tree.children_right[nid])

    r2 = dtr.score(X_1d, y_1d)
    ax.set_title(f'Depth={d} | {dtr.get_n_leaves()} leaves | R²={r2:.3f}',
                 fontsize=10, fontweight='bold')
    ax.set_xlabel('X', fontsize=9); ax.set_ylabel('y', fontsize=9)
    ax.legend(fontsize=7, loc='upper right'); ax.grid(True, alpha=0.2)

plt.suptitle('Regression Tree: Piecewise Constant Approximation\n'
             'Gray dotted lines = split thresholds; deeper trees = more steps',
             fontsize=13, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig('dt_regression_piecewise.png', dpi=150, bbox_inches='tight')
plt.show()
print("Saved: dt_regression_piecewise.png")


def visualize_regression_leaf_distributions(dt_reg, X_train, y_train, n_cols=4):
    """
    Show the distribution of target values at each leaf node.
    The leaf's mean is the prediction; spread shows unexplained variance.
    """
    leaf_ids      = dt_reg.apply(X_train)
    unique_leaves = np.unique(leaf_ids)
    n_leaves      = len(unique_leaves)
    n_rows        = int(np.ceil(n_leaves / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    axes = axes.flatten()

    for ax, leaf_id in zip(axes, unique_leaves):
        mask   = leaf_ids == leaf_id
        y_leaf = y_train[mask]

        ax.hist(y_leaf, bins=min(15, max(5, len(y_leaf))),
                color='steelblue', edgecolor='white', alpha=0.8)
        ax.axvline(y_leaf.mean(), color='coral', lw=2.5,
                   label=f'μ={y_leaf.mean():.2f}')
        ax.set_title(f'Leaf {leaf_id}  n={len(y_leaf)}  σ={y_leaf.std():.2f}',
                     fontsize=8, fontweight='bold')
        ax.legend(fontsize=6); ax.grid(True, alpha=0.3, axis='y')
        ax.tick_params(labelsize=7)

    for ax in axes[n_leaves:]:
        ax.set_visible(False)

    plt.suptitle('Regression Tree: Target Distribution per Leaf\n'
                 '(Each leaf predicts its mean; large σ = poorly fit region)',
                 fontsize=11, fontweight='bold')
    plt.tight_layout()
    plt.savefig('dt_regression_leaf_distributions.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: dt_regression_leaf_distributions.png")


dtr_shallow = DecisionTreeRegressor(max_depth=3, random_state=42)
dtr_shallow.fit(X_1d, y_1d)
visualize_regression_leaf_distributions(dtr_shallow, X_1d, y_1d, n_cols=4)

The regression tree visualization reveals the key limitation of the algorithm: each leaf’s prediction is a flat horizontal line at the mean of its training samples. Even with depth=15 the curve is a staircase, never a smooth line. This directly motivates gradient boosted trees, which sum many shallow trees to approximate smooth functions more efficiently than any single deep tree.
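
The claim that each leaf predicts the mean of its training samples can be checked directly. The following sketch assumes the default squared-error criterion and uses its own synthetic data (a noisy sine, chosen only for illustration) so that it runs independently of the plots above.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 4 * np.pi, 200)).reshape(-1, 1)
y = np.sin(X.ravel()) + rng.normal(0, 0.15, 200)

reg = DecisionTreeRegressor(max_depth=3, random_state=42).fit(X, y)
leaf_ids = reg.apply(X)   # leaf index for every training sample

# Mean training target inside each leaf
leaf_means = {leaf: y[leaf_ids == leaf].mean() for leaf in np.unique(leaf_ids)}
expected = np.array([leaf_means[leaf] for leaf in leaf_ids])

# The tree's prediction for a training sample equals its leaf's mean
matches = np.allclose(reg.predict(X), expected)

# On a dense grid the prediction takes at most one distinct value per leaf
x_grid = np.linspace(0, 4 * np.pi, 1000).reshape(-1, 1)
n_levels = len(np.unique(reg.predict(x_grid)))

print(f"predictions equal leaf means: {matches}")
print(f"{n_levels} distinct predicted values from {reg.get_n_leaves()} leaves")
```

However fine the input grid, the number of distinct outputs is capped by the number of leaves, which is the staircase seen in the plots.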

Method 9: Confusion Matrix Per Leaf

For classification trees, computing a confusion matrix for each leaf node reveals which leaves make systematic errors versus which are reliably pure. This is especially useful for diagnosing class imbalance problems or identifying subtrees that would benefit from pruning.

Python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix


def confusion_matrix_per_leaf(dt, X, y, class_names, n_cols=4):
    """
    Plot a small confusion matrix for every leaf node.
    
    Impure leaves (high off-diagonal counts) are candidates for:
    - Additional depth (if training data permits)
    - Pruning (if the node is too small to split reliably)
    - Feature engineering (if no existing feature separates them)
    """
    leaf_ids      = dt.apply(X)
    unique_leaves = np.unique(leaf_ids)
    n_leaves      = len(unique_leaves)
    n_classes     = len(class_names)
    y_pred        = dt.predict(X)

    n_rows = int(np.ceil(n_leaves / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols,
                              figsize=(4 * n_cols, 3.5 * n_rows))
    axes = axes.flatten()

    print(f"  {'Leaf':>6} | {'n':>5} | {'Acc %':>7} | {'Majority':>15} | Purity")
    print("  " + "-" * 48)

    for ax, leaf_id in zip(axes, unique_leaves):
        mask        = leaf_ids == leaf_id
        y_true_leaf = y[mask]
        y_pred_leaf = y_pred[mask]
        n_samples   = mask.sum()

        cm  = confusion_matrix(y_true_leaf, y_pred_leaf,
                                labels=np.arange(n_classes))
        acc = (y_true_leaf == y_pred_leaf).mean()
        majority = class_names[np.argmax(np.bincount(y_true_leaf.astype(int),
                                                      minlength=n_classes))]

        cmap = 'Greens' if acc > 0.9 else ('YlOrBr' if acc > 0.7 else 'Reds')
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        ax.set_xticks(np.arange(n_classes))
        ax.set_yticks(np.arange(n_classes))
        # Keep enough characters to tell labels apart (e.g. 'class_0' vs 'class_1')
        ax.set_xticklabels([str(c)[:8] for c in class_names], fontsize=6, rotation=45)
        ax.set_yticklabels([str(c)[:8] for c in class_names], fontsize=6)
        for i in range(n_classes):
            for j in range(n_classes):
                ax.text(j, i, str(cm[i, j]), ha='center', va='center',
                        fontsize=8, color='white' if cm[i, j] > cm.max()/2 else 'black')

        status = "✓" if acc > 0.9 else ("~" if acc > 0.7 else "✗")
        color  = '#22aa22' if acc > 0.9 else ('#aaaa22' if acc > 0.7 else '#aa2222')
        ax.set_title(f'{status} Leaf {leaf_id}  n={n_samples}\nAcc={acc*100:.0f}%',
                     fontsize=8, fontweight='bold', color=color)
        ax.set_xlabel('Pred', fontsize=6); ax.set_ylabel('True', fontsize=6)

        flag = "✓ Pure" if acc > 0.9 else ("~ Mixed" if acc > 0.7 else "✗ Impure")
        print(f"  {leaf_id:>6} | {n_samples:>5} | {acc*100:>6.0f}% | "
              f"{majority:>15} | {flag}")

    for ax in axes[n_leaves:]:
        ax.set_visible(False)

    plt.suptitle('Confusion Matrix per Leaf\n'
                 'Green=pure (>90%), Yellow=mixed, Red=impure (<70%)',
                 fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.savefig('dt_confusion_per_leaf.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\nSaved: dt_confusion_per_leaf.png")


wine  = load_wine()
X_wi2, y_wi2 = wine.data, wine.target
X_tr_wi2, X_te_wi2, y_tr_wi2, y_te_wi2 = train_test_split(
    X_wi2, y_wi2, test_size=0.25, random_state=42, stratify=y_wi2
)
dt_wi_vis = DecisionTreeClassifier(max_depth=4, random_state=42).fit(X_tr_wi2, y_tr_wi2)

print("=== Confusion Matrix per Leaf: Wine Dataset ===\n")
confusion_matrix_per_leaf(dt_wi_vis, X_te_wi2, y_te_wi2,
                           wine.target_names, n_cols=4)

The per-leaf confusion matrix immediately highlights which parts of the tree are working well (pure leaves colored green) and which need attention (mixed leaves colored yellow or red). An impure leaf with many samples is a strong candidate for additional depth or feature engineering — it has enough data to support further splitting but lacks a good separating feature in the current tree. A small impure leaf with only a handful of samples may be better pruned entirely, defaulting to the parent’s majority class.
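
Two scikit-learn options address small, unreliable leaves: `min_samples_leaf` prevents them from being created, and `ccp_alpha` (cost-complexity pruning) collapses subtrees after the tree is grown. The sketch below is illustrative only. The value `min_samples_leaf=5` and the choice of the middle alpha from the pruning path are arbitrary assumptions; in practice both would be tuned with cross-validation.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

# Unrestricted tree: grows until every leaf is pure
full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Candidate pruning strengths computed from the training data
path = full.cost_complexity_pruning_path(X_train, y_train)
alpha = max(0.0, float(path.ccp_alphas[len(path.ccp_alphas) // 2]))  # not tuned

pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
pruned.fit(X_train, y_train)

# Alternative: never create leaves with fewer than 5 training samples
small_leaf = DecisionTreeClassifier(random_state=42, min_samples_leaf=5)
small_leaf.fit(X_train, y_train)

for name, model in [("full", full), ("ccp_alpha", pruned),
                    ("min_samples_leaf=5", small_leaf)]:
    print(f"{name:<20} leaves={model.get_n_leaves():>3}  "
          f"test acc={model.score(X_test, y_test):.3f}")
```

Re-running the per-leaf confusion matrix on the pruned tree shows whether the small impure leaves have been merged into larger, more reliable ones.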

Choosing the Right Visualization

Different visualization needs call for different tools. Here is a practical decision guide:

| Situation | Best Method | Why |
|---|---|---|
| Quick exploratory check | plot_tree() | Instant, no dependencies |
| Understanding node structure | export_text() | Readable, no graphics needed |
| Presentation to management | export_graphviz() | Polished, professional quality |
| Explaining single prediction | trace_prediction_path() | Shows exact logic for one case |
| Understanding feature space | Decision boundary plot | Reveals how space is partitioned |
| Feature selection | Feature importance plot | Shows which features matter |
| Choosing max_depth | Complexity dashboard | Multi-metric view across depths |
| Code deployment | rules_to_python_function() | Sklearn-independent portable rules |

Visualization Best Practices

Limit depth for readability. A tree with more than 5–6 levels becomes impossible to read on screen. For deep trees, visualize only the top levels using max_depth in plot_tree(), then investigate specific subtrees separately.
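
As a minimal sketch of that advice, the snippet below grows an unrestricted tree on the wine data and draws only its top levels. The output filename and figure size are arbitrary, and the `Agg` backend line is there only so the example also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (assumption: no display needed)
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier, plot_tree

wine = load_wine()
clf = DecisionTreeClassifier(random_state=42).fit(wine.data, wine.target)

fig, ax = plt.subplots(figsize=(12, 6))
drawn = plot_tree(
    clf,
    max_depth=2,                          # draw the root plus two levels only
    feature_names=wine.feature_names,
    class_names=list(wine.target_names),
    filled=True,
    rounded=True,
    fontsize=9,
    ax=ax,
)
ax.set_title(f"Top levels of a depth-{clf.get_depth()} tree")
fig.savefig("dt_top_levels.png", dpi=150, bbox_inches="tight")

print(f"tree depth: {clf.get_depth()}, total nodes: {clf.tree_.node_count}, "
      f"boxes drawn: {len(drawn)}")
```

Branches that continue below the cutoff are drawn as "(...)" placeholder boxes, so the figure still shows where the tree goes deeper.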

Use filled=True for faster intuition. Color intensity immediately communicates node purity — you can spot the impure (mixed) nodes before reading any numbers.

Export text rules for stakeholders. Technical visualizations overwhelm most business audiences. The plain-English rules from export_text() or extract_decision_rules() communicate the model’s logic without requiring knowledge of tree diagrams.

Show both MDI and permutation importance. MDI importances are fast but biased toward high-cardinality features, and they are computed from the training data alone. Permutation importances avoid that bias but are slower, are best computed on held-out data, and can understate features that are strongly correlated with each other. Showing both and comparing them reveals whether high-cardinality features are artificially inflated.
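
A side-by-side comparison takes only a few lines. This is a sketch under stated assumptions: a depth-4 tree on the wine data, 20 permutation repeats, and a 25% held-out split, all chosen for illustration.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.25, random_state=42,
    stratify=wine.target
)
clf = DecisionTreeClassifier(max_depth=4, random_state=42).fit(X_train, y_train)

# MDI: impurity decrease accumulated during training (sums to 1)
mdi = clf.feature_importances_

# Permutation importance: drop in held-out accuracy when a feature is shuffled
perm = permutation_importance(clf, X_test, y_test, n_repeats=20, random_state=42)

order = np.argsort(mdi)[::-1]
print(f"{'feature':<30} {'MDI':>7} {'perm':>7}")
for i in order[:5]:
    print(f"{wine.feature_names[i]:<30} {mdi[i]:>7.3f} "
          f"{perm.importances_mean[i]:>7.3f}")
```

A feature that ranks high on MDI but shows a permutation importance near zero is a candidate for the inflation described above.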

Trace predictions for debugging. When a prediction seems wrong, the path trace immediately shows which feature value triggered the unexpected branch — far more diagnostic than a simple probability output.

Summary

Visualization is what converts a decision tree from a black box of numbers into an interpretable, communicable, and auditable model. The methods covered in this article serve different purposes: plot_tree for exploration, export_text for documentation and stakeholder communication, Graphviz for reports and publications, decision boundary plots for understanding the geometric structure, feature importance charts for variable selection, prediction path tracing for debugging and explanation, complexity dashboards for hyperparameter selection, regression tree plots for seeing the piecewise-constant fit, and per-leaf confusion matrices for locating systematic errors.

The axis-aligned, rectangular nature of decision tree boundaries — visible immediately in 2D boundary plots — is both the tree’s characteristic fingerprint and its key limitation. It explains why trees can struggle with diagonal boundaries and why ensembles that average many trees (each seeing a different perspective on the data) produce smoother, more generalizable boundaries.

Every visualization in this article serves a practical purpose beyond aesthetics: together they make the model’s logic auditable, its failures diagnosable, and its decisions explainable to the humans who must trust and act on its outputs.
