import pandas as pd
import warnings
from typing import Dict, Union, List
from .constants import ImputationMethod, SUPPORTED_METHODS, DEFAULT_METHOD, InitialMethod, SUPPORTED_INITIAL_METHODS, DEFAULT_INITIAL_METHOD, VisitSequence, SUPPORTED_VISIT_SEQUENCES
import numpy as np
import re
[docs]
def check_n_imputations(n_imputations: int) -> None:
"""
Check if the number of imputations is valid and provide a warning if it's high.
Parameters
----------
n_imputations : int
Number of imputations to perform
Raises
------
ValueError
If n_imputations is not a positive integer
"""
if not isinstance(n_imputations, int):
raise ValueError("n_imputations must be a positive integer")
if n_imputations <= 0:
raise ValueError("n_imputations must be a positive integer")
if n_imputations > 100:
print(f"Warning: {n_imputations} imputations is a large number. This might take a while to compute.")
[docs]
def check_maxit(maxit: int) -> None:
"""
Check if the maximum iterations parameter is valid and provide a warning if it's high.
Parameters
----------
maxit : int
Maximum number of iterations for each imputation cycle
Raises
------
ValueError
If maxit is not a positive integer
"""
if not isinstance(maxit, int):
raise ValueError("maxit must be an integer")
if maxit <= 0:
raise ValueError("maxit must be a positive integer")
if maxit > 50:
warnings.warn("maxit is greater than 50, imputations will take a lot of time", UserWarning)
[docs]
def check_method(method: Union[str, Dict[str, str]], columns: List[str]) -> Dict[str, str]:
"""
Check and process the method parameter for MICE imputation.
Parameters
----------
method : Union[str, Dict[str, str]]
Method specification. Can be:
- str: use the same method for all columns
- Dict[str, str]: dictionary mapping column names to their methods
columns : List[str]
List of column names in the data
Returns
-------
Dict[str, str]
Dictionary mapping each column to its imputation method
Raises
------
ValueError
If method is invalid or references non-existent columns
"""
# If method is a string, validate and use for all columns
if isinstance(method, str):
if not ImputationMethod.is_valid_method(method):
raise ValueError(f"Unsupported method: {method}. Supported methods are: {SUPPORTED_METHODS}")
return {col: method for col in columns}
# If method is a dictionary, validate each entry
if isinstance(method, dict):
# Check if all specified columns exist
invalid_cols = [col for col in method.keys() if col not in columns]
if invalid_cols:
raise ValueError(f"Columns not found in data: {invalid_cols}")
# Check if all methods are supported
invalid_methods = {col: m for col, m in method.items() if not ImputationMethod.is_valid_method(m)}
if invalid_methods:
raise ValueError(f"Unsupported methods: {invalid_methods}. Supported methods are: {SUPPORTED_METHODS}")
# Create result dict with default method for unspecified columns
# TODO: make default method dependent on column type or just handle it otherwise,
# e.g. not let method for a column be not specified
result = {col: DEFAULT_METHOD for col in columns} # Default method
result.update(method) # Override with specified methods
return result
raise ValueError("method must be either a string or a dictionary")
[docs]
def check_initial_method(initial_method: str) -> None:
"""
Check if the initial imputation method is valid.
Parameters
----------
initial_method : str
Initial imputation method to validate
Raises
------
ValueError
If initial_method is not a valid initial imputation method
"""
if not isinstance(initial_method, str):
raise ValueError("initial_method must be a string")
if not InitialMethod.is_valid_method(initial_method):
raise ValueError(f"Unsupported initial method: {initial_method}. Supported initial methods are: {SUPPORTED_INITIAL_METHODS}")
[docs]
def check_visit_sequence(
visit_sequence: Union[str, List[str]],
columns: List[str],
columns_with_missing: List[str] = None
) -> tuple:
"""
Check and process the visit sequence parameter for MICE imputation.
Parameters
----------
visit_sequence : Union[str, List[str]]
Visit sequence specification. Can be:
- str: "monotone" or "random" for predefined sequences
- List[str]: list of column names specifying the order to visit variables
columns : List[str]
List of all column names in the data
columns_with_missing : List[str], optional
List of columns that have missing values. If provided, will validate
that all these columns are included in a custom visit sequence.
Returns
-------
tuple
(validated_sequence, columns_without_missing) where:
- validated_sequence: the processed visit sequence (only for list input, None for string)
- columns_without_missing: list of columns in sequence that don't have missing values
Raises
------
ValueError
If visit_sequence is invalid or references non-existent columns
Notes
-----
For string visit sequences ("monotone", "random"), the actual sequence will be
generated in MICE._set_visit_sequence() based on the data.
For list visit sequences, this function validates that:
1. All columns exist in the data
2. No duplicate columns
3. All columns with missing values are included (if columns_with_missing provided)
"""
if isinstance(visit_sequence, str):
if not VisitSequence.is_valid_sequence(visit_sequence):
raise ValueError(f"Unsupported visit sequence: {visit_sequence}. Supported visit sequences are: {SUPPORTED_VISIT_SEQUENCES}")
return None, []
if isinstance(visit_sequence, list):
invalid_cols = [col for col in visit_sequence if col not in columns]
if invalid_cols:
raise ValueError(f"Visit sequence contains columns not in data: {invalid_cols}")
if len(visit_sequence) != len(set(visit_sequence)):
duplicates = [col for col in visit_sequence if visit_sequence.count(col) > 1]
raise ValueError(f"Visit sequence contains duplicate columns: {set(duplicates)}")
if columns_with_missing is not None:
missing_from_sequence = [col for col in columns_with_missing if col not in visit_sequence]
if missing_from_sequence:
raise ValueError(
f"Visit sequence is missing columns with missing values: {missing_from_sequence}. "
f"All columns with missing values must be included in the visit sequence. "
f"Columns with missing values: {columns_with_missing}"
)
cols_without_missing = [col for col in visit_sequence if col not in columns_with_missing]
validated_sequence = [col for col in visit_sequence if col in columns_with_missing]
return validated_sequence, cols_without_missing
return visit_sequence, []
raise ValueError("visit_sequence must be either a string or a list of strings")
[docs]
def validate_predictor_matrix(predictor_matrix: pd.DataFrame, data_columns: List[str], data: pd.DataFrame) -> pd.DataFrame:
"""
Validate predictor matrix for MICE imputation.
Parameters
----------
predictor_matrix : pd.DataFrame
Binary matrix indicating which variables should be used as predictors
for each target variable. Rows represent target variables, columns represent predictors.
A 1 indicates that the column variable is used as predictor for the index variable.
data_columns : List[str]
List of column names in the data to validate against
data : pd.DataFrame
The data to check for missing values
Returns
-------
pd.DataFrame
Validated predictor matrix
Raises
------
ValueError
If predictor_matrix has invalid structure or column names don't match data
"""
if not isinstance(predictor_matrix, pd.DataFrame):
raise ValueError("predictor_matrix must be a pandas DataFrame")
# Check if all column names in predictor matrix exist in data
predictor_cols = list(predictor_matrix.columns)
missing_cols = [col for col in predictor_cols if col not in data_columns]
if missing_cols:
raise ValueError(f"predictor_matrix contains columns not in data: {missing_cols}")
# Check if all row names (target variables) exist in data
target_vars = list(predictor_matrix.index)
missing_targets = [var for var in target_vars if var not in data_columns]
if missing_targets:
raise ValueError(f"predictor_matrix contains target variables not in data: {missing_targets}")
# Check if matrix contains only 0s and 1s
if not predictor_matrix.isin([0, 1]).all().all():
raise ValueError("predictor_matrix must contain only 0s and 1s")
# Check for columns without missing data that are being imputed (warning)
complete_targets = [var for var in target_vars if not data[var].isna().any()]
if complete_targets:
warnings.warn(f"Target variables without missing data are being imputed: {complete_targets}. This is unnecessary but allowed.")
# Check for columns with missing data that are used as predictors but not imputed (error)
predictor_only_vars = [col for col in predictor_cols if col not in target_vars]
incomplete_predictors = [var for var in predictor_only_vars if data[var].isna().any()]
if incomplete_predictors:
raise ValueError(f"Predictor variables with missing data are not being imputed: {incomplete_predictors}. All predictors must be complete or included as targets.")
return predictor_matrix
[docs]
def validate_columns(data: pd.DataFrame) -> pd.DataFrame:
"""
Validate and clean columns in the DataFrame.
Checks for columns with only NaN values and drops them with appropriate warnings.
Parameters
----------
data : pd.DataFrame
Input DataFrame to validate
Returns
-------
pd.DataFrame
DataFrame with invalid columns removed
Warns
-----
UserWarning
If columns with only NaN values are found and dropped
Notes
-----
Missing data values that are treated as NaN:
- pandas NaN (numpy.nan)
"""
# Check for columns with only NaN values
nan_only_cols = []
for col in data.columns:
if data[col].isna().all():
nan_only_cols.append(col)
# Drop columns with only NaN values and print warning
if nan_only_cols:
warnings.warn(f"Found columns with only NaN values: {nan_only_cols}. These columns will be dropped as they cannot be imputed.")
print(f"Dropping {len(nan_only_cols)} columns with only NaN values: {nan_only_cols}")
data = data.drop(columns=nan_only_cols)
return data
[docs]
def validate_dataframe(data) -> pd.DataFrame:
"""
Check and validate input data for MICE imputation.
Parameters
----------
data : Any
Input data to be checked and converted to DataFrame
Returns
-------
pd.DataFrame
Validated and cleaned DataFrame
Raises
------
ValueError
If data cannot be converted to DataFrame or has duplicate column names
Notes
-----
Missing data values that are treated as NaN:
- pandas NaN (numpy.nan)
"""
# Try to convert to DataFrame if it's not already one
try:
if not isinstance(data, pd.DataFrame):
data = pd.DataFrame(data)
except Exception as e:
raise ValueError(f"Input data cannot be converted to DataFrame: {str(e)}")
# Check for duplicate column names
duplicate_cols = data.columns[data.columns.duplicated()].tolist()
if duplicate_cols:
print(f"Found duplicate column names: {duplicate_cols}. Please make sure that the column names are unique.")
raise ValueError("DataFrame contains duplicate column names")
# Check for fully empty rows
n_rows_before = len(data)
data = data.dropna(how='all')
n_rows_after = len(data)
n_dropped = n_rows_before - n_rows_after
if n_dropped > 0:
print(f"Dropped {n_dropped} fully empty rows")
# Check for columns with no values
empty_cols = data.columns[data.isna().all()].tolist()
if empty_cols:
warnings.warn(f"Found columns with no values: {empty_cols}. These columns will be dropped as they cannot be imputed.")
print(f"Dropping {len(empty_cols)} columns with no values: {empty_cols}")
data = data.drop(columns=empty_cols)
return data