# Source code for plotting.utils

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import pandas as pd
import numpy as np
import os
import logging

# Module-level logger shared by the helpers in this file. The logger name is a
# fixed dotted string rather than being derived from __name__.
logger = logging.getLogger('imputation.plotting.utils')


def _ensure_directory(file_path):
    """
    Make sure the folder that will contain ``file_path`` is present.

    The directory component of the path is created, together with any
    missing intermediate directories, when it is not already on disk.
    A path without a directory component is left untouched.

    Parameters
    ----------
    file_path : str
        Path to a file (not a directory)
    """
    parent = os.path.dirname(file_path)

    # A bare filename has an empty directory part: nothing to create.
    if not parent:
        return

    # Already on disk: skip both the creation and the debug message.
    if os.path.exists(parent):
        return

    os.makedirs(parent, exist_ok=True)
    logger.debug(f"Created directory: {parent}")


def md_pattern_like(df):
    """
    Replicate the md.pattern() behavior from R's mice package.

    Shows missing data patterns as 1 (observed) and 0 (missing), together
    with counts per pattern and per column.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame with potential missing values. Must have at least
        two columns.

    Returns
    -------
    pandas.DataFrame
        One row per distinct missingness pattern, ordered by increasing
        number of missing cells. The columns are the input columns ordered
        by increasing missing count, followed by ``"#miss_row"`` (number of
        missing cells in that pattern). The index holds, as strings, the
        number of input rows that show each pattern. A final row labelled
        ``"#miss_col"`` holds the missing count of every column and, in the
        ``"#miss_row"`` position, the total number of missing cells.

    Raises
    ------
    ValueError
        If ``df`` is not a pandas DataFrame or has fewer than two columns.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame")
    if df.shape[1] < 2:
        raise ValueError("Data must have at least two columns")

    # Response indicator matrix: 1 = observed, 0 = missing.
    observed = df.notna().astype(int)

    # Order columns from most to least complete. A stable sort keeps the
    # input order for columns with the same number of missing values, so the
    # layout is reproducible.
    col_missing_counts = df.isna().sum()
    sorted_cols = col_missing_counts.sort_values(
        ascending=True, kind="mergesort"
    ).index.tolist()
    observed_sorted = observed[sorted_cols]

    # Encode every row as a string of 0/1 digits so identical patterns can be
    # counted; each cell contributes exactly one character.
    pattern_strings = observed_sorted.astype(str).agg(''.join, axis=1)
    pattern_counts = pattern_strings.value_counts()

    # Decode the distinct pattern strings back into one 0/1 row per pattern.
    pattern_matrix = pd.DataFrame(
        [[int(flag) for flag in pattern] for pattern in pattern_counts.index],
        columns=sorted_cols,
    )

    # Missing cells per pattern = number of columns - number observed. The
    # right-hand side is evaluated before "#miss_row" is added as a column.
    pattern_matrix["#miss_row"] = (
        pattern_matrix.shape[1] - pattern_matrix.sum(axis=1)
    )

    # Stable sort: patterns with the same number of missing cells keep the
    # frequency order produced by value_counts().
    pattern_matrix = pattern_matrix.sort_values(by="#miss_row", kind="mergesort")

    # pattern_matrix still carries its positional labels, so they can be used
    # to reorder the counts to match the sorted patterns.
    pattern_counts = pattern_counts.iloc[pattern_matrix.index]
    pattern_matrix.index = pattern_counts.values
    pattern_matrix.index.name = "#rows"

    # Bottom margin: missing count per column plus the grand total.
    col_missing_row = col_missing_counts[sorted_cols].to_list()
    col_missing_row.append(sum(col_missing_row))

    summary_df = pattern_matrix.copy()
    # Row labels become strings so the "#miss_col" label shares their type.
    summary_df.index = summary_df.index.map(str)
    summary_df.loc["#miss_col"] = col_missing_row

    return summary_df
def plot_missing_data_pattern(pattern_df, figsize=(8, 5), title="Missing Data Pattern",
                              rotate_names=False, save_path=None):
    """
    Plot the missing data pattern from a pattern dataframe.

    Parameters
    ----------
    pattern_df : pandas.DataFrame
        DataFrame containing the missing data pattern, typically generated
        by md_pattern_like(). Its last row must be labelled ``"#miss_col"``
        and its last column is read as the per-pattern missing count.
    figsize : tuple, optional
        Figure size in inches (width, height). Default is (8, 5)
    title : str, optional
        Title for the plot. Default is "Missing Data Pattern"
    rotate_names : bool, optional
        Whether to rotate column names 90 degrees. Default is False
    save_path : str, optional
        If provided, save the plot to this path instead of displaying it

    Returns
    -------
    pandas.DataFrame
        ``pattern_df`` itself, returned unchanged for textual output.

    Raises
    ------
    KeyError
        If ``pattern_df`` has no row labelled ``"#miss_col"``.
    """
    # The last row ("#miss_col") and the last column hold the margins; the
    # remaining cells are the 0/1 pattern matrix itself.
    data_only = pattern_df.iloc[:-1, :-1]
    row_counts = pattern_df.index[:-1]
    row_miss = pattern_df.iloc[:-1, -1]
    col_totals = pattern_df.loc["#miss_col"]
    col_miss = col_totals.iloc[:-1]
    n_rows, n_cols = data_only.shape

    cmap = ListedColormap([
        (101/255, 155/255, 213/255),  # blue for present
        (205/255, 100/255, 140/255)   # pink for missing
    ])

    # Create figure with adjusted size based on rotation
    if rotate_names:
        figsize = (figsize[0], figsize[1] + 0.5)  # Add extra height for rotated labels

    # Keep explicit figure/axes handles instead of relying on pyplot's
    # implicit "current figure" state.
    fig, ax = plt.subplots(figsize=figsize)

    # The matrix is inverted so that 0 = present and 1 = missing, matching the
    # colormap order. vmin/vmax pin that mapping: without them a matrix whose
    # cells are all identical (e.g. everything missing) would be auto-scaled
    # and drawn entirely in the first ("present") colour.
    sns.heatmap(1 - data_only.astype(int), cmap=cmap, cbar=False,
                linewidths=0.5, linecolor='black', square=True,
                vmin=0, vmax=1, ax=ax)

    # Add row counts on the left
    for i, count in enumerate(row_counts):
        ax.text(-0.7, i + 0.5, f"{count}", va='center', ha='right', fontsize=10)

    # Add missing counts on the right
    for i, miss in enumerate(row_miss):
        ax.text(n_cols + 0.1, i + 0.5, f"{int(miss)}", va='center', ha='left', fontsize=10)

    # Add column missing counts at the bottom
    for j, miss in enumerate(col_miss):
        ax.text(j + 0.5, n_rows + 0.5, f"{int(miss)}", ha='center', va='top', fontsize=10)

    # Add column names on top
    # NOTE(review): with va='top' a rotated label hangs downwards from
    # y=-0.5; confirm long names do not overlap the first heatmap row.
    for j, col_name in enumerate(data_only.columns):
        if rotate_names:
            ax.text(j + 0.5, -0.5, col_name, ha='right', va='top',
                    rotation=90, fontsize=10)
        else:
            ax.text(j + 0.5, -0.5, col_name, ha='center', va='top', fontsize=10)

    # Remove default x and y axis ticks (this also removes their labels)
    ax.set_xticks([])
    ax.set_yticks([])

    # Add total missing count at bottom right
    ax.text(n_cols + 0.1, n_rows + 0.5, f"{int(col_totals.iloc[-1])}",
            ha='left', va='top', fontsize=10)

    ax.set_title(title)
    fig.tight_layout(pad=0.4)

    if save_path:
        try:
            _ensure_directory(save_path)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when saving fails, so repeated calls
            # do not accumulate open figures.
            plt.close(fig)
    else:
        plt.show()

    # Return the pattern matrix for textual output
    return pattern_df