import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import pandas as pd
import numpy as np
import os
import logging
# Module-level logger shared by the plotting helpers below.
logger = logging.getLogger(name="imputation.plotting.utils")
def _ensure_directory(file_path):
"""
Ensure the directory for a file path exists.
Creates parent directories if they don't exist.
Parameters
----------
file_path : str
Path to a file (not a directory)
"""
directory = os.path.dirname(file_path)
if directory and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.debug(f"Created directory: {directory}")
def md_pattern_like(df):
    """
    Replicates the md.pattern() behavior from R's mice package.

    Shows missing data patterns as 1 (observed) and 0 (missing),
    counts per pattern and per column.

    Columns are ordered by increasing number of missing values; columns
    with equal counts keep their original DataFrame order. Pattern rows are
    ordered by increasing number of missing variables; patterns with the
    same number keep the order produced by ``value_counts`` (most frequent
    first).

    Parameters:
    -----------
    df : pandas.DataFrame
        Input DataFrame with potential missing values

    Returns:
    --------
    pandas.DataFrame
        One row per distinct pattern, indexed by the (stringified) number of
        rows sharing that pattern, followed by a final "#miss_col" row.
        The data columns hold 1/0 observed flags. The extra "#miss_row"
        column holds the number of missing variables in each pattern.
        The "#miss_col" row holds the per-column missing counts, with the
        total number of missing cells in its "#miss_row" cell.

    Raises:
    -------
    ValueError
        If ``df`` is not a DataFrame or has fewer than two columns.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame")
    if df.shape[1] < 2:
        raise ValueError("Data must have at least two columns")

    # Indicator matrix: 1 = observed, 0 = missing.
    R = df.notna().astype(int)

    # Order columns from least to most missing. A stable sort keeps tied
    # columns in their original order (as R's order() does); the default
    # quicksort gives no such guarantee for larger inputs.
    col_missing_counts = (R == 0).sum()
    sorted_cols = col_missing_counts.sort_values(
        ascending=True, kind="mergesort"
    ).index.tolist()
    R_sorted = R[sorted_cols]

    # Encode every row as a string such as "1101" and count distinct patterns.
    pattern_strings = R_sorted.astype(str).agg(''.join, axis=1)
    pattern_counts = pattern_strings.value_counts()

    pattern_matrix = pd.DataFrame(
        [[int(ch) for ch in p] for p in pattern_counts.index],
        columns=sorted_cols,
    )
    # Missing variables per pattern; the right-hand side is evaluated before
    # the new column exists, so shape[1] is the number of data columns.
    pattern_matrix["#miss_row"] = pattern_matrix.shape[1] - pattern_matrix.sum(axis=1)

    # Stable sort so patterns with equal missingness keep their frequency order.
    pattern_matrix = pattern_matrix.sort_values(by="#miss_row", kind="mergesort")
    # The index is still the original RangeIndex, i.e. positions into pattern_counts.
    pattern_counts = pattern_counts.iloc[pattern_matrix.index]
    pattern_matrix.index = pattern_counts.values
    pattern_matrix.index.name = "#rows"

    # Bottom summary row: per-column missing counts plus the grand total.
    col_missing_row = col_missing_counts[sorted_cols].to_list()
    col_missing_row.append(sum(col_missing_row))

    summary_df = pattern_matrix.copy()
    summary_df.index = summary_df.index.map(str)
    summary_df.loc["#miss_col"] = col_missing_row
    return summary_df
def plot_missing_data_pattern(pattern_df, figsize=(8, 5), title="Missing Data Pattern", rotate_names=False, save_path=None):
    """
    Plots the missing data pattern from a pattern dataframe.

    Each grid row is one missingness pattern: blue cells are observed and
    pink cells are missing. The number of rows sharing the pattern is
    written on the left, and the number of missing variables in the pattern
    on the right. Per-column missing counts run along the bottom, with the
    grand total in the bottom-right corner.

    Parameters:
    -----------
    pattern_df : pandas.DataFrame
        DataFrame containing the missing data pattern, typically generated
        by md_pattern_like(). The last column must hold the per-pattern
        missing counts. The last row, labelled "#miss_col", must hold the
        per-column missing counts, with the grand total in its last cell.
    figsize : tuple, optional
        Figure size in inches (width, height). Default is (8, 5).
        When ``rotate_names`` is True, 0.5 inch is added to the height.
    title : str, optional
        Title for the plot. Default is "Missing Data Pattern"
    rotate_names : bool, optional
        Whether to rotate column names 90 degrees. Default is False
    save_path : str, optional
        If provided, save the plot to this path at 300 dpi (creating parent
        directories as needed) and close the figure instead of displaying it.

    Returns:
    --------
    pandas.DataFrame
        The input ``pattern_df``, unchanged, for textual output.

    Raises:
    -------
    KeyError
        If ``pattern_df`` has no "#miss_col" row.
    """
    if "#miss_col" not in pattern_df.index:
        raise KeyError(
            "pattern_df must contain a '#miss_col' summary row, "
            "as produced by md_pattern_like()"
        )

    # Split the summary frame into the pattern cells, the left labels
    # (pattern counts), the right labels (last column) and the bottom labels
    # (the "#miss_col" row, whose last cell is the grand total).
    data_only = pattern_df.iloc[:-1, :-1]
    row_counts = pattern_df.index[:-1]
    row_miss = pattern_df.iloc[:-1, -1]
    miss_col_row = pattern_df.loc["#miss_col"]
    col_miss = miss_col_row.iloc[:-1]
    total_missing = miss_col_row.iloc[-1]
    n_rows, n_cols = data_only.shape

    cmap = ListedColormap([
        (101/255, 155/255, 213/255),  # blue for present
        (205/255, 100/255, 140/255),  # pink for missing
    ])

    # Add extra height for rotated labels
    if rotate_names:
        figsize = (figsize[0], figsize[1] + 0.5)
    fig, ax = plt.subplots(figsize=figsize)

    # Pattern cells are 1 = observed, so invert them: 0 -> blue, 1 -> pink.
    # vmin/vmax are pinned so the colour mapping does not depend on the data.
    # Without them, a grid that is entirely missing (all 1s) would be
    # normalised onto the first colour and drawn as "present".
    sns.heatmap(1 - data_only.astype(int), ax=ax, cmap=cmap, vmin=0, vmax=1,
                cbar=False, linewidths=0.5, linecolor='black', square=True,
                xticklabels=False, yticklabels=False)

    # Row counts on the left, per-pattern missing counts on the right.
    for i, (count, miss) in enumerate(zip(row_counts, row_miss)):
        ax.text(-0.7, i + 0.5, f"{count}", va='center', ha='right', fontsize=10)
        ax.text(n_cols + 0.1, i + 0.5, f"{int(miss)}", va='center', ha='left', fontsize=10)

    # Per-column missing counts along the bottom, grand total in the corner.
    for j, miss in enumerate(col_miss):
        ax.text(j + 0.5, n_rows + 0.5, f"{int(miss)}", ha='center', va='top', fontsize=10)
    ax.text(n_cols + 0.1, n_rows + 0.5, f"{int(total_missing)}",
            ha='left', va='top', fontsize=10)

    # Column names as top tick labels centred on each column. Real tick labels
    # (rather than free-floating text) stay clear of the cells when rotated,
    # and matplotlib automatically moves the title above them.
    ax.xaxis.tick_top()
    ax.set_xticks([j + 0.5 for j in range(n_cols)])
    ax.set_xticklabels([str(c) for c in data_only.columns],
                       rotation=90 if rotate_names else 0, fontsize=10)
    ax.tick_params(axis='x', length=0)
    ax.set_yticks([])

    ax.set_title(title)
    fig.tight_layout(pad=0.4)

    if save_path:
        _ensure_directory(save_path)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

    # Return the pattern matrix for textual output
    return pattern_df