Source code for liesel_gam.registry

"""Variable registry for managing data variables and transformations."""

from __future__ import annotations

import hashlib
import inspect
import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal, assert_never

import jax.numpy as jnp
import liesel.model as lsl
import numpy as np
import pandas as pd

from .category_mapping import CategoryMapping, series_is_categorical

logger = logging.getLogger(__name__)

Array = Any


class CannotHashValueError(Exception):
    """Signal that a value could not be turned into a hash.

    The offending object is kept on the exception as ``value`` so that callers
    can inspect it (e.g. to report its type).
    """

    def __init__(self, value: Any):
        type_name = type(value).__name__
        message = f"Cannot hash value of type '{type_name}'"
        super().__init__(message)
        # keep a reference to the object that could not be hashed
        self.value = value


@dataclass
class VarAndMapping:
    """Pair a variable with an optional category mapping."""

    # the wrapped variable
    var: lsl.Var
    # mapping for the variable, or None when there is no mapping
    mapping: CategoryMapping | None = None

    @property
    def is_categorical(self) -> bool:
        """Return True exactly when a mapping is attached."""
        has_mapping = self.mapping is not None
        return has_mapping


[docs] class PandasRegistry: """Registry for managing variables and their transformations. Handles conversion from `pandas.DataFrame` to `liesel.Var` objects, applies transformations, and caches results for efficiency. """ def __init__( self, data: pd.DataFrame, na_action: Literal["error", "drop", "ignore"] = "error", prefix_names_by: str = "", ): """Initialize the variable registry. Args: data: pandas DataFrame containing model variables na_action: How to handle NaN values. Either "error", "drop", or "ignore" """ if na_action not in ["error", "drop", "ignore"]: raise ValueError("na_action must be 'error', 'drop', or 'ignore'") self.original_data = data.copy() self.na_action = na_action self.data = self._validate_data(data) self._var_cache: dict[str, lsl.Var] = {} self._derived_cache: dict[str, lsl.Var] = {} self.prefix = prefix_names_by def _validate_data(self, data: pd.DataFrame) -> pd.DataFrame: """Validate data and handle NaN values according to policy.""" if data.isna().any().any(): if self.na_action == "error": na_cols = data.columns[data.isna().any()].tolist() raise ValueError( f"Data contains NaN values in columns: {na_cols}. " "Use na_action='drop' to automatically remove rows with NaN values." 
) elif self.na_action == "drop": clean_data = data.dropna() if len(clean_data) == 0: raise ValueError("No rows remaining after dropping NaN values") return clean_data elif self.na_action == "ignore": pass else: assert_never() return data.copy() @property def columns(self) -> list[str]: """Get list of available column names.""" return list(self.data.columns) @property def shape(self) -> tuple[int, int]: """Get shape of the data after NA handling.""" return self.data.shape def _to_jax(self, values: Any, var_name: str) -> Array: """Check if values are compatible with JAX.""" try: array = jnp.asarray(values) except Exception as e: raise TypeError( f"Variable '{var_name}' could not convert to JAX array" ) from e return array def _is_closure(self, func: Callable) -> bool: """Check if function is a closure (captures variables from outer scope).""" return func.__closure__ is not None def _hash_closure_value(self, value: Any) -> str: """Create hash for closure values, specifically supporting JAX arrays.""" try: # try direct hashing first return str(hash(value)) except TypeError: # handle unhashable types if isinstance(value, jnp.ndarray): # JAX arrays: hash shape, dtype, and content return f"jax_array_{value.shape}_{value.dtype}_{hash(value.tobytes())}" else: # unsupported type - signal to skip caching raise CannotHashValueError(value) def _hash_function(self, func: Callable) -> str | None: """Create hash for function, or use object ID for methods/callable objects.""" if inspect.isfunction(func): # Regular functions: hash source code and closures source = inspect.getsource(func) if self._is_closure(func): # for mypy assert func.__closure__ is not None, "Closure should have a closure" # hash closure variables closure_names = func.__code__.co_freevars closure_values = [cell.cell_contents for cell in func.__closure__] closure_hashes = [] for name, value in zip(closure_names, closure_values): try: value_hash = self._hash_closure_value(value) 
closure_hashes.append(f"{name}:{value_hash}") except CannotHashValueError: # unsupported closure variable, skip caching warnings.warn( f"Function uses unsupported closure variable type " f"'{type(value).__name__}'. Provide explicit cache_key " f"for caching.", UserWarning, stacklevel=3, ) return None closure_signature = ",".join(sorted(closure_hashes)) else: closure_signature = "" # combine source and closure state combined = f"{source}|{closure_signature}" return hashlib.md5(combined.encode()).hexdigest() elif inspect.ismethod(func): # Bound method: use object ID + method name for consistent caching obj_id = id(func.__self__) method_name = func.__name__ return f"method_{obj_id}_{method_name}" elif hasattr(func, "__call__"): # Callable objects, lambdas, etc.: use object ID return f"obj_id_{id(func)}" else: raise TypeError(f"Unsupported function type: {type(func)}")
[docs] def get_obs( self, name: str, ) -> lsl.Var: """Get or create a liesel Var for a data column. Args: name: Column name in the data Returns: liesel.Var object """ if name not in self.data.columns: available = list(self.data.columns) raise KeyError( f"Variable '{name}' not found in data. " f"Available variables: {sorted(available)}" ) varname = self.prefix + name # check if already cached if name in self._var_cache: var = self._var_cache[name] else: # get raw values values = self._to_jax(self.data[name].to_numpy(), name) var = lsl.Var.new_obs(values, name=varname) self._var_cache[name] = var return var
def _make_derived_var( self, base_var: lsl.Var, transform: Callable, var_name: str | None ) -> lsl.Var: """Apply a transformation to a base variable and return a new Var.""" if var_name is None: var_name = ( f"{base_var.name}_{getattr(transform, '__name__', str(transform))}" ) try: derived_var = lsl.Var.new_calc(transform, base_var, name=var_name) except Exception as e: transformation_name = getattr(transform, "__name__", str(transform)) raise ValueError( f"Failed to apply transformation '{transformation_name}' " f"to variable '{base_var.name}': {str(e)}" ) return derived_var
[docs] def get_calc( self, name: str, transform: Callable, var_name: str | None = None, cache_key: str | None = None, ) -> lsl.Var: """Get a derived version of the variable. Derived variables are cached when possible. Creates a lsl.new_obs for the base variable and a lsl.new_calc for the derived variable. Args: name: Column name in the data frame transform: Callable transformation function to apply var_name: Custom name for the resulting variable cache_key: Explicit cache key. If provided, skips function hashing. Returns: liesel.Var object with transformed values """ # get base var base_var = self.get_obs(name) # generate cache key if cache_key is not None: # explicit cache key provided full_cache_key = f"{name}_{cache_key}_{var_name or 'default'}" else: # try to hash the function func_hash = self._hash_function(transform) if func_hash is None: # caching not possible, return derived var without caching return self._make_derived_var(base_var, transform, var_name) full_cache_key = f"{name}_{func_hash}_{var_name or 'default'}" # check cache first if full_cache_key in self._derived_cache: return self._derived_cache[full_cache_key] # cache miss var = self._make_derived_var(base_var, transform, var_name) self._derived_cache[full_cache_key] = var return var
[docs] def get_calc_centered(self, name: str, var_name: str | None = None) -> lsl.Var: """Get a centered version of the variable: x - mean(x). note, mean(x) is computed from the original data and cached. Args: name: Column name in the data var_name: Custom name for the resulting variable Returns: liesel.Var object with centered values """ base_var = self.get_obs(name) values = base_var.value mean_val = float(np.mean(values)) def center_transform(x): return x - mean_val center_transform.__name__ = "centered" return self._make_derived_var( base_var, center_transform, var_name or f"{name}_centered" )
[docs] def get_calc_standardized(self, name: str, var_name: str | None = None) -> lsl.Var: """Get a standardized version of the variable: (x - mean(x)) / std(x). note, mean(x) and std(x) are computed from the original data and cached. Args: name: Column name in the data var_name: Custom name for the resulting variable Returns: liesel.Var object with standardized values """ base_var = self.get_obs(name) values = base_var.value mean_val = float(np.mean(values)) std_val = float(np.std(values)) if std_val == 0: raise ValueError( f"Failed to apply transformation 'standardization' to variable " f"'{name}': standard deviation is zero (constant variable)" ) def std_transform(x): return (x - mean_val) / std_val std_transform.__name__ = "std" return self._make_derived_var( base_var, std_transform, var_name or f"{name}_std" )
[docs] def get_calc_dummymatrix( self, name: str, var_name_prefix: str | None = None ) -> lsl.Var: """Get dummy matrix for a categorical column using standard dummy coding. Drops the column of the first category. Args: name: Column name in the data var_name_prefix: Prefix for dummy variable names Returns: Dictionary mapping category names to liesel.Var objects """ base_var, mapping = self.get_categorical_obs(name) base_var.name = base_var.name = f"{name}_codes" codebook = mapping.labels_to_integers_map if len(codebook) < 2: raise ValueError( f"Failed to apply transformation 'dummy encoding' to variable " f"'{name}': only {len(codebook)} unique value(s) found" ) # jax-compatible dummy coding transformation n_categories = len(codebook) def dummy_transform(codes): # create dummy matrix with standard dummy coding (drop first category) # use float32 to support NaN for unknown codes dummy_matrix = jnp.zeros( (codes.shape[0], n_categories - 1), dtype=jnp.float32 ) for i in range(1, n_categories): # only a few cat, so for loop is fine dummy_matrix = dummy_matrix.at[:, i - 1].set(codes == i) # set rows with unknown codes (>= n_categories or < 0) to NaN unknown_mask = (codes >= n_categories) | (codes < 0) dummy_matrix = jnp.where(unknown_mask[:, None], jnp.nan, dummy_matrix) return dummy_matrix dummy_transform.__name__ = f"{name}_dummy" # create dummy matrix variable prefix = var_name_prefix or f"{name}_" dummy_matrix_name = f"{prefix}matrix" dummy_matrix_var = lsl.Var.new_calc( dummy_transform, base_var, name=dummy_matrix_name ) return dummy_matrix_var
[docs] def is_numeric(self, name: str) -> bool: """Check if a variable is numeric. Args: name: Column name in the data Returns: True if variable is numeric, False otherwise """ if name not in self.data.columns: available = list(self.data.columns) raise KeyError( f"Variable '{name}' not found in data. " f"Available variables: {sorted(available)}" ) return pd.api.types.is_numeric_dtype(self.data[name])
[docs] def is_categorical(self, name: str) -> bool: """Check if a variable is categorical. Args: name: Column name in the data Returns: True if variable is categorical, False otherwise """ if name not in self.data.columns: available = list(self.data.columns) raise KeyError( f"Variable '{name}' not found in data. " f"Available variables: {sorted(available)}" ) return series_is_categorical(self.data[name])
[docs] def is_boolean(self, name: str) -> bool: """Check if a variable is boolean. Args: name: Column name in the data Returns: True if variable is boolean, False otherwise """ if name not in self.data.columns: available = list(self.data.columns) raise KeyError( f"Variable '{name}' not found in data. " f"Available variables: {sorted(available)}" ) return pd.api.types.is_bool_dtype(self.data[name])
[docs] def get_numeric_obs(self, name: str) -> lsl.Var: """Get a variable and ensure it is numeric. Args: name: Variable name to retrieve Returns: liesel.Var object for the numeric variable Raises: TypeError: If the variable is not numeric """ if not self.is_numeric(name): raise TypeError( f"Type mismatch for variable '{name}': expected numeric, " f"got {str(self.data[name].dtype)}" ) return self.get_obs(name)
[docs] def get_categorical_obs(self, name: str) -> tuple[lsl.Var, CategoryMapping]: """Get a variable and ensure it is categorical. Each variable is converted to integer codes. Args: name: Variable name to retrieve Returns: liesel.Var object for the categorical variable and a CategoryMapping. Raises: TypeError: If any variable is not categorical """ series = self.data[name] if not self.is_categorical(name): raise TypeError( f"Type mismatch for variable '{name}': expected categorical, " f"got {str(series.dtype)}" ) mapping = CategoryMapping.from_series(series) if name in self._var_cache: var = self._var_cache[name] else: # convert categorical variables to integer codes category_codes = mapping.labels_to_integers(series) jax_codes = self._to_jax(category_codes, name) varname = self.prefix + name var = lsl.Var.new_obs(jax_codes, name=varname) self._var_cache[name] = var # now some exception handling # only emitted once nparams = len(mapping.labels_to_integers_map) n_observed_clusters = jnp.unique(var.value).size observed_clusters = np.unique(var.value).tolist() clusters = list(mapping.integers_to_labels_map) clusters_not_in_data = [c for c in clusters if c not in observed_clusters] if n_observed_clusters != nparams: logger.info( f"For {name}, there are {nparams} categories, but the " f"data contain observations for only {n_observed_clusters}. The " f"categories without observations are: {clusters_not_in_data}. " "If this is intended, you can ignore this warning. " "Be aware, that parameters for the unobserved categories may be " "included in the model." ) return var, mapping
[docs] def get_boolean_obs(self, name: str) -> lsl.Var: """Get a variable and ensure it is boolean. Args: name: Variable name to retrieve Returns: liesel.Var object for the boolean variable Raises: TypeError: If the variable is not boolean """ if not self.is_boolean(name): raise TypeError( f"Type mismatch for variable '{name}': expected boolean, " f"got {str(self.data[name].dtype)}" ) return self.get_obs(name)
[docs] def get_obs_and_mapping(self, name: str) -> VarAndMapping: """ Get an observed variable. Returns a wrapper that holds the variable and, if the variable is categorical, the :class:`.CategoryMapping` between labels and integer codes. """ if self.is_categorical(name): var, mapping = self.get_categorical_obs(name) else: var = self.get_obs(name) mapping = None return VarAndMapping(var, mapping)