# Source code for simuk.sbc

"""Simulation-based calibration checking (SBC) for PyMC, Bambi, and NumPyro.

Implements both Prior SBC (Talts et al., 2020) and Posterior SBC
(Säilynoja et al., 2025).

References
----------
.. [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2020).
   Validating Bayesian Inference Algorithms with Simulation-Based Calibration.
   arXiv:1804.06788.
.. [2] Säilynoja, T., Schmitt, M., Bürkner, P.-C., & Vehtari, A. (2025).
   Posterior SBC: Simulation-Based Calibration Checking Conditional on Data.
   arXiv:2502.03279.
"""

import logging
import traceback
from copy import copy
from importlib.metadata import version

try:
    import pymc as pm
except ImportError:
    pass
try:
    import jax
    from numpyro.handlers import seed, trace
    from numpyro.infer import MCMC, Predictive
    from numpyro.infer.mcmc import MCMCKernel
except ImportError:
    pass

import inspect
from collections.abc import Mapping

import numpy as np
from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro
from tqdm import tqdm


class quiet_logging:
    """Decorator that silences the given loggers while the wrapped call runs.

    Every named logger is set to ``logging.CRITICAL`` for the duration of the
    call and restored to its previous level afterwards, even when the call
    raises (including ``KeyboardInterrupt``). In this module it is used to quiet
    PyMC, PyTensor, Bambi and NumPyro output during simulations.

    Parameters
    ----------
    *libraries : str
        Logger names passed to ``logging.getLogger``.
    """

    def __init__(self, *libraries):
        # One ``logging.Logger`` per library name passed to the decorator.
        self.loggers = [logging.getLogger(library) for library in libraries]

    def __call__(self, func):
        import functools

        # ``wraps`` keeps the decorated method's name and docstring (e.g. for docs).
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Capture every level before changing any, so a logger listed twice
            # is still restored to its original level.
            levels = [logger.level for logger in self.loggers]
            for logger in self.loggers:
                logger.setLevel(logging.CRITICAL)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore on success, error or interruption alike.
                for logger, level in zip(self.loggers, levels):
                    logger.setLevel(level)

        return wrapped


class SBC:
    r"""Simulation-based calibration checking (SBC).

    Supports two modes of operation:

    - **Prior SBC** (``method="prior"``, default): validates that the inference
      algorithm is calibrated across the prior. Reference draws come from the
      prior and replicated data from the prior predictive (Talts et al., 2020 [1]_).
    - **Posterior SBC** (``method="posterior"``): validates that the inference
      algorithm is calibrated across the posterior. Reference draws come from the
      original posterior and replicated data from the posterior predictive. The
      model is then re-fit on the concatenation of the original observations and
      the replicated data (Säilynoja et al., 2025 [2]_).

    Parameters
    ----------
    model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel
        A PyMC, Bambi model or NumPyro MCMC kernel. If a PyMC model, the data
        needs to be defined as mutable data.
    method : {"prior", "posterior"}, default "prior"
        Which variant of SBC to perform (case-insensitive). Any other value
        raises ``ValueError``.
    num_simulations : int, default 1000
        How many SBC iterations to run.
    sample_kwargs : dict, optional
        Keyword arguments forwarded to ``pymc.sample`` (PyMC and Bambi models)
        or ``numpyro.infer.MCMC`` (NumPyro kernels). The dictionary is copied;
        the caller's object is not modified.
    seed : int, optional
        Random seed. This persists even if running the simulations is paused
        for whatever reason.
    data_dir : dict, optional
        Keyword arguments passed to the NumPyro model, intended for use when
        providing an MCMC kernel model.
    simulator : callable, optional
        A custom data-generating function. It receives the model parameter
        values as keyword arguments plus a ``seed`` integer, and must return a
        ``dict`` mapping observed-variable names to numpy arrays.
    trace : arviz.InferenceData, optional
        Required for ``method="posterior"``. An InferenceData object that
        contains both the ``posterior`` and ``observed_data`` groups. The number
        of posterior draws per chain must be at least ``num_simulations``.
    augment_observed : callable, optional
        *Posterior SBC only.* Signature:
        ``(model, observed_data, replicated_data, simulation_idx) -> dict``.
        Builds the augmented observed data that the model will be conditioned
        on. ``observed_data`` is the xarray Dataset from
        ``trace["observed_data"]``, and ``replicated_data`` is a
        ``dict[str, np.ndarray]`` of the simulated observations from the
        original posterior predictive for the current iteration. The returned
        ``dict`` maps variable names to the augmented data. The **default**
        behaviour concatenates the original and replicated observations along
        the first axis for each variable. Provide this callback when simple
        concatenation is not valid, e.g. for structured data.
    update_data : callable, optional
        *Posterior SBC only.* Signature:
        ``(model, augmented_data, simulation_idx) -> None``. Called *before*
        conditioning the model on the augmented data. Use this to resize
        covariates, coordinate labels, or other ``pm.Data`` containers so that
        the model is consistent with the augmented dataset.
    transform : callable, optional
        A transform applied to both the reference draw and the posterior draws
        before computing the rank statistic. Signature:
        ``(param_name, param_value) -> transformed_value``. Useful for defining
        scalar test quantities (e.g.
        ``lambda param_name, param_value: np.mean(param_value)`` to test the
        mean of a vector parameter). The return values must be comparable with
        the ``<`` operator. The default is the identity (rank on the raw
        parameter values).
    keep_fits : bool, default True
        Whether to store posteriors to allow re-evaluation of rank statistics
        using a different quantity (``compute_rank_statistics``) without
        needing to run the simulations again.
    progress_bar : bool, default True
        Whether to display a progress bar while running the simulations.

    Notes
    -----
    **Prior SBC** exploits the self-consistency of Bayesian updating: if
    :math:`\theta' \sim \pi(\theta)` and :math:`y' \sim \pi(y \mid \theta')`,
    then :math:`\theta'` is also a draw from :math:`\pi(\theta \mid y')`.
    See Talts et al., 2020 [1]_.

    **Posterior SBC** uses the same self-consistency after conditioning on
    observed data :math:`y_{\text{obs}}`. A draw
    :math:`\theta'_i \sim \pi(\theta \mid y_{\text{obs}})` and a replicated
    dataset :math:`y_i \sim \pi(y \mid \theta'_i)` are combined so that
    :math:`\theta'_i` is also a draw from
    :math:`\pi(\theta \mid y_i, y_{\text{obs}})`. The rank of :math:`\theta'_i`
    among augmented-posterior draws should be uniformly distributed if the
    inference is calibrated. See Säilynoja et al., 2025 [2]_.

    References
    ----------
    .. [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2020).
       Validating Bayesian Inference Algorithms with Simulation-Based Calibration.
       arXiv:1804.06788.
    .. [2] Säilynoja, T., Schmitt, M., Bürkner, P.-C., & Vehtari, A. (2025).
       Posterior SBC: Simulation-Based Calibration Checking Conditional on Data.
       arXiv:2502.03279.

    Examples
    --------
    **Prior SBC** (default):

    .. code-block:: python

        import pymc as pm
        import simuk

        with pm.Model() as model:
            x = pm.Normal('x')
            y = pm.Normal('y', mu=2 * x, observed=obs)

        sbc = simuk.SBC(model, num_simulations=200)
        sbc.run_simulations()

    **Posterior SBC** – validate inference conditional on observed data:

    .. code-block:: python

        import pymc as pm
        import simuk

        with pm.Model() as model:
            x = pm.Normal('x')
            y = pm.Normal('y', mu=2 * x, observed=obs)
            # 1. Obtain posterior samples from the real data
            trace = pm.sample()

        # 2. Run posterior SBC
        sbc = simuk.SBC(
            model,
            method="posterior",
            trace=trace,
            num_simulations=200,
        )
        sbc.run_simulations()
    """

    def __init__(
        self,
        model,
        method="prior",
        num_simulations=1000,
        sample_kwargs=None,
        seed=None,
        data_dir=None,
        simulator=None,
        trace=None,
        augment_observed=None,
        update_data=None,
        transform=None,
        keep_fits=True,
        progress_bar=True,
    ):
        # Normalise and validate ``method`` first, so every later check (including
        # the engine restriction for Posterior SBC) sees the lower-cased value.
        if not isinstance(method, str) or method.lower() not in ("prior", "posterior"):
            raise ValueError(f"`method` should be 'prior' or 'posterior', got {method!r}")
        method = method.lower()

        if hasattr(model, "basic_RVs") and isinstance(model, pm.Model):
            self.engine = "pymc"
            self.model = model
        elif hasattr(model, "formula"):
            self.engine = "bambi"
            model.build()
            self.bambi_model = model
            self.model = model.backend.model
            self.formula = model.formula
            self.new_data = copy(model.data)
        elif self._is_numpyro_kernel(model):
            self.engine = "numpyro"
            self.numpyro_model = model
            self.model = self.numpyro_model.model
            self.run_simulations = self._run_simulations_numpyro
            self.data_dir = data_dir if data_dir is not None else {}
        else:
            raise ValueError(
                "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel"
            )

        if method == "posterior" and self.engine != "pymc":
            raise NotImplementedError("Currently, Posterior SBC is only implemented for PyMC")

        self.progress_bar = progress_bar
        # Copy so the defaults set below never leak into the caller's dictionary.
        sample_kwargs = dict(sample_kwargs) if sample_kwargs is not None else {}
        if self.engine == "numpyro":
            sample_kwargs.setdefault("num_warmup", 1000)
            sample_kwargs.setdefault("num_samples", 1000)
            sample_kwargs.setdefault("progress_bar", False)
        else:
            sample_kwargs.setdefault("progressbar", False)
            sample_kwargs.setdefault("compute_convergence_checks", False)
        self.sample_kwargs = sample_kwargs
        self.num_simulations = num_simulations
        self.seed = seed
        self._seeds = self._get_seeds()
        self._extract_model_info()
        self.simulations = {name: [] for name in self.var_names}
        self._simulations_complete = 0
        self.posteriors = []
        self.keep_fits = keep_fits
        self.ref_params = None
        # Indices of the reference draws used by ``run_simulations``. Cached so a
        # resumed run pairs the already-stored fits with the same reference draws.
        self._sample_indices = None
        # Per-variable rank lists saved when a ``keep_fits=False`` run stops, so a
        # later call to ``run_simulations`` can continue from them.
        self._partial_ranks = None

        if simulator is not None and not callable(simulator):
            raise ValueError("simulator should be a function or None")
        if simulator is not None and self.observed_vars:
            logging.warning(
                "Provided model contains both observed variables and a simulator. "
                "Ignoring observed variables and using the simulator instead."
            )
        if simulator is None and not self.observed_vars and self.engine == "pymc":
            # Ideally, we could raise an error early for `numpyro` also,
            # but `factor` also produces 'observed_vars'
            raise ValueError(
                "There are no observed variables, and PyMC will not generate predictive "
                "samples for both Prior and Posterior SBC. Either change the model or "
                "specify a simulator with the `simulator` argument."
            )
        if simulator is None and self.engine == "numpyro":
            if not self.observed_model_vars:
                raise ValueError(
                    "There are no observed variables we can condition on, and NumPyro "
                    "will not generate prior predictive samples. Either change the model "
                    "or specify a simulator with the `simulator` argument."
                )
            missing = [name for name in self.observed_model_vars if name not in self.data_dir]
            if missing:
                raise ValueError(
                    "The following model parameters are missing from data_dir: "
                    + ", ".join(sorted(missing))
                )
        self.simulator = simulator

        self._transform = lambda param_name, param_value: param_value
        if transform is not None:
            if not callable(transform):
                raise ValueError("`transform` should be a function or None")
            self._transform = transform

        self.method = method
        if self.method == "posterior":
            if trace is None:
                raise ValueError(
                    "When performing Posterior SBC, posterior samples from the "
                    "original posterior are required to generate replicate datasets"
                )
            if "posterior" not in trace:
                raise ValueError("`trace` should contain 'posterior' group")
            if "observed_data" not in trace:
                raise ValueError("`trace` should contain 'observed_data' group")
            if self.num_simulations > trace["posterior"].sizes["draw"]:
                raise ValueError(
                    "posterior samples in `trace` should have more draws per "
                    "chain than `num_simulations`. This is required to obtain enough "
                    "posterior predictive samples"
                )
            self.trace = trace

            if augment_observed is not None and not callable(augment_observed):
                raise ValueError("`augment_observed` should be a function or None")
            self.augment_observed = augment_observed

            if update_data is not None and not callable(update_data):
                raise ValueError("`update_data` should be a function or None")
            self.update_data = update_data
        else:
            if update_data is not None:
                logging.warning(
                    "`update_data` is only supported for Posterior SBC. Ignoring...\n"
                    "Prior SBC does not augment observations, so there is no need to "
                    "update model data."
                )
            if augment_observed is not None:
                logging.warning(
                    "`augment_observed` is only supported for Posterior SBC. Ignoring...\n"
                    "Prior SBC does not augment observations, so there is no need to "
                    "augment observed data and replicated data"
                )
            if trace is not None:
                logging.warning("`trace` is only used for Posterior SBC. Ignoring...")

    @staticmethod
    def _is_numpyro_kernel(model):
        """Return True if ``model`` is a NumPyro ``MCMCKernel``.

        Returns False instead of raising ``NameError`` when the optional NumPyro
        import at the top of the module failed.
        """
        try:
            return isinstance(model, MCMCKernel)
        except NameError:
            return False

    def _extract_model_info(self):
        """Extract observed and free variables from the model.

        For PyMC/Bambi models this also records which observed variables are
        backed by ``pm.Data`` and the baseline data/coords that
        ``_get_posterior_samples`` restores after every fit.
        """
        if self.engine == "numpyro":
            self.model_params = set(inspect.signature(self.model).parameters.keys())
            with trace() as tr:
                with seed(rng_seed=int(self._seeds[0])):
                    self.numpyro_model.model(**self.data_dir)
            self.var_names = [
                name
                for name, site in tr.items()
                if site["type"] == "sample" and not site.get("is_observed", False)
            ]
            self.observed_vars = [
                name
                for name, site in tr.items()
                if site["type"] == "sample" and site.get("is_observed", False)
            ]
            # Observed model variables are those that are marked as observed
            # and are also model function parameters in order to be able to condition on them.
            # For instance, this is used to filter out factor variables that are marked as observed
            # but cannot be conditioned on.
            self.observed_model_vars = [
                name for name in self.observed_vars if name in self.model_params
            ]
        else:
            observed_var_nodes = list(self.model.observed_RVs)
            self.observed_vars = [obs.name for obs in observed_var_nodes]
            self.var_names = [v.name for v in self.model.free_RVs]
            # Maps each observed variable to the name of the pm.Data node holding
            # its values, or None when the observations are not pm.Data.
            self.observed_rvs_to_pm_data = {
                var.name: (
                    self.model.rvs_to_values[var].name
                    if hasattr(self.model.rvs_to_values[var], "get_value")
                    else None
                )
                for var in observed_var_nodes
            }
            self.model_baseline_state = self._get_baseline_state(self.model)

    def _get_baseline_state(self, model):
        """Extract the current mutable data and coordinates from a PyMC model."""
        baseline_data = {}
        # Extract mutable data (shared variables expose ``get_value``)
        for var in model.data_vars:
            if hasattr(var, "get_value"):
                baseline_data[var.name] = var.get_value(borrow=False)
        # Convert the internal PyMC coordinate object to a standard dictionary
        baseline_coords = dict(model.coords)
        return {"data": baseline_data, "coords": baseline_coords}

    def _reset_model_state(self, model, model_state):
        """Reset the state of PyMC model."""
        with model:
            pm.set_data(model_state["data"], coords=model_state["coords"])

    def _get_seeds(self):
        """Set the random seed, and generate seeds for all the simulations."""
        rng = np.random.default_rng(self.seed)
        return rng.integers(0, 2**30, size=self.num_simulations)

    def _simulator_seeds(self, size):
        """Return ``size`` integer seeds for simulator calls.

        The first ``num_simulations`` seeds are exactly ``self._seeds``. When more
        are needed (e.g. a multi-chain posterior yields chains x num_simulations
        samples), the extra ones are drawn from a generator seeded with
        ``self._seeds``, so they are deterministic as well.
        """
        if size <= len(self._seeds):
            return self._seeds[:size]
        extra_rng = np.random.default_rng([int(s) for s in self._seeds])
        extra = extra_rng.integers(0, 2**30, size=size - len(self._seeds))
        return np.concatenate([self._seeds, extra])

    def _get_simulator_data(self, free_rv_samples):
        """Run the user-defined simulator to obtain predictive samples.

        These samples can be generated from either prior or posterior samples.
        """
        num_samples = free_rv_samples.sizes["sample"]
        seeds = self._simulator_seeds(num_samples)
        pred = []
        for i in range(num_samples):
            params = {
                var: free_rv_samples[var].isel(sample=i).values
                for var in free_rv_samples.data_vars
            }
            params["seed"] = seeds[i]
            try:
                res = self.simulator(**params)
            except Exception as e:
                raise ValueError(
                    f"Error generating prior predictive sample with parameters {params}: {e}."
                ) from e
            if not isinstance(res, Mapping):
                raise TypeError(f"Simulator must return a dictionary, got {type(res)}")
            pred.append(res)
        return dict_to_dataset(
            {key: np.stack([pp[key] for pp in pred]) for key in pred[0]},
            sample_dims=["sample"],
            coords={**free_rv_samples.coords},
        )

    def _get_prior_predictive_samples(self):
        """Generate samples to use for the simulations."""
        with self.model:
            idata = pm.sample_prior_predictive(
                draws=self.num_simulations, random_seed=self._seeds[0]
            )
        prior = extract(idata, group="prior", keep_dataset=True)
        if self.simulator is None:
            prior_pred = extract(idata, group="prior_predictive", keep_dataset=True)
            return prior, prior_pred
        prior_pred = self._get_simulator_data(prior)
        return prior, prior_pred

    def _get_prior_predictive_samples_numpyro(self):
        """Generate samples to use for the simulations using numpyro."""
        predictive = Predictive(self.model, num_samples=self.num_simulations)
        free_vars_data = {
            k: v
            for k, v in self.data_dir.items()
            if k not in self.observed_vars and k in self.model_params
        }
        samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data)
        prior = {k: v for k, v in samples.items() if k not in self.observed_vars}
        if self.simulator:
            results = []
            for i, vals in enumerate(zip(*prior.values())):
                params = dict(zip(prior.keys(), vals))
                params["seed"] = self._seeds[i]
                results.append(self.simulator(**params))
            prior_pred = {key: [result[key] for result in results] for key in results[0]}
        else:
            prior_pred = {k: v for k, v in samples.items() if k in self.observed_model_vars}
        return prior, prior_pred

    def _augment_observations(self, replicated_data):
        """Combine the original observations with one replicated dataset (Posterior SBC).

        Uses ``augment_observed`` when provided, otherwise concatenates the
        original and replicated observations along the first axis. Afterwards
        ``update_data`` (if given) may adapt the model to the new sizes.
        """
        observed_data = self.trace["observed_data"]
        if self.augment_observed is not None:
            augmented_data = self.augment_observed(
                self.model, observed_data, replicated_data, self._simulations_complete
            )
        else:
            augmented_data = {
                var_name: np.concatenate(
                    [observed_data[var_name].values, replicated_data[var_name]]
                )
                for var_name in self.observed_vars
            }
        if self.update_data is not None:
            with self.model:
                self.update_data(self.model, augmented_data, self._simulations_complete)
        return augmented_data

    def _set_observed_pm_data(self, vars_to_observations):
        """Write observations into observed ``pm.Data`` nodes still at their baseline value.

        Nodes that already differ from the baseline were updated elsewhere (e.g.
        by the user's ``update_data`` hook); setting them again could cause
        PyMC size-mismatch errors, so they are left alone.
        """
        with self.model:
            for rv, data_node in self.observed_rvs_to_pm_data.items():
                if data_node is not None and np.array_equal(
                    self.model.named_vars[data_node].get_value(),
                    self.model_baseline_state["data"][data_node],
                ):
                    pm.set_data(new_data={data_node: vars_to_observations[rv]})

    def _get_posterior_samples(self, replicated_data):
        """Fit the model and return posterior draws for one SBC iteration.

        For **Prior SBC** the model is conditioned on the replicated data alone.
        For **Posterior SBC** the original observed data and the replicated data
        are combined (via ``augment_observed`` or the default simple
        concatenation) and the model is conditioned on the augmented dataset.

        Parameters
        ----------
        replicated_data : dict[str, np.ndarray]
            Simulated observations for the current iteration, keyed by
            observed-variable name.

        Returns
        -------
        xarray.Dataset
            Posterior draws from the (augmented) model.
        """
        try:
            if self.method == "posterior":
                vars_to_observations = self._augment_observations(replicated_data)
            else:
                # Prior SBC simply uses the generated prior predictive replicated data
                vars_to_observations = replicated_data
            self._set_observed_pm_data(vars_to_observations)
            new_model = pm.observe(self.model, vars_to_observations=vars_to_observations)
            with new_model:
                check = pm.sample(
                    **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete]
                )
            return extract(check, group="posterior", keep_dataset=True)
        finally:
            # Always reset the model to its un-augmented baseline state -- also when
            # augmentation, ``update_data`` or sampling fails -- so the next iteration
            # (or a resumed run) isn't corrupted by this iteration's data.
            self._reset_model_state(self.model, self.model_baseline_state)

    def _get_posterior_samples_numpyro(self, prior_predictive_draw):
        """Generate posterior samples using numpyro conditioned to a prior predictive sample."""
        mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
        rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete])
        free_vars_data = {
            k: v
            for k, v in self.data_dir.items()
            if k not in self.observed_model_vars and k in self.model_params
        }
        prior_predictive_args = {
            k: v for k, v in prior_predictive_draw.items() if k in self.observed_model_vars
        }
        mcmc.run(rng_seed, **free_vars_data, **prior_predictive_args)
        return from_numpyro(mcmc)["posterior"]

    def _get_posterior_predictive_samples(self):
        """Thin the original posterior and generate the matching replicated data.

        Returns
        -------
        posterior : xarray.Dataset
            ``num_simulations`` evenly spaced draws per chain of ``self.trace``,
            flattened into a ``sample`` dimension by ``extract``.
        posterior_pred : xarray.Dataset
            Posterior predictive draws (or simulator output) for those samples.
        """
        num_draws = self.trace["posterior"].sizes["draw"]
        draw_indices = np.linspace(0, num_draws - 1, self.num_simulations, dtype=int)
        thinned_idata = self.trace.isel(draw=draw_indices)
        posterior = extract(thinned_idata, group="posterior", keep_dataset=True)
        if self.simulator is not None:
            return posterior, self._get_simulator_data(posterior)
        with self.model:
            pm.sample_posterior_predictive(
                thinned_idata,
                extend_inferencedata=True,
                random_seed=self._seeds[0],
                progressbar=self.progress_bar,
            )
        posterior_pred = extract(thinned_idata, group="posterior_predictive", keep_dataset=True)
        return posterior, posterior_pred

    def _convert_to_datatree(self):
        """Pack the rank-statistic arrays into an xarray DataTree.

        Creates a group named ``"prior_sbc"`` or ``"posterior_sbc"`` (depending on
        ``self.method``) inside ``self.simulations``.
        """
        group_name = "prior_sbc" if self.method == "prior" else "posterior_sbc"
        # NOTE(review): the "inferece_*" attribute keys are misspelled; they are kept
        # unchanged in case other code reads them -- confirm before renaming.
        self.simulations = from_dict(
            {group_name: self.simulations},
            attrs={
                "/": {
                    "inferece_library": self.engine,
                    "inferece_library_version": version(self.engine),
                    "modeling_interface": "simuk",
                    "modeling_interface_version": version("simuk"),
                }
            },
        )

    def compute_rank_statistics(self, transform=None):
        """Compute the rank statistic for the reference parameters.

        This method computes the rank of each reference parameter value relative
        to the newly sampled posterior draws for each simulation. This allows
        users to recompute rank statistics rapidly using a different parameter
        transformation without needing to rerun the simulations.

        Parameters
        ----------
        transform : callable, optional
            A function that accepts two arguments: `(param_name, param_value)`.
            This function is applied to both the posterior draws and the
            reference parameter draws before computing the rank. For instance, it
            can be used to take the mean over a vectorized parameter grouping.
            If None, defaults to the `transform` passed during class initialization.

        Returns
        -------
        xarray.DataTree
            An xarray.DataTree containing the computed rank statistics, matching
            the output structure generated by `run_simulations`.

        Raises
        ------
        ValueError
            If ``keep_fits`` is False, ``transform`` is not callable, or no fits
            have been stored yet.
        """
        if not self.keep_fits:
            raise ValueError("calling `compute_rank_statistics` requires `keep_fits` to be True")
        if transform is None:
            transform = self._transform
        elif not callable(transform):
            raise ValueError("`transform` should be a function or None")
        if not self.posteriors:
            raise ValueError(
                "No stored fits found; call `run_simulations` before computing rank statistics"
            )
        self.simulations = {name: [] for name in self.var_names}
        for idx, posterior in enumerate(self.posteriors):
            self._compute_single_rank(idx, posterior, transform)
        self.simulations = {k: np.stack(v)[None, :] for k, v in self.simulations.items()}
        self._convert_to_datatree()
        return self.simulations

    def _compute_single_rank(self, simulation_idx, posterior, transform):
        """Append the rank of reference draw ``simulation_idx`` for every variable.

        The rank is the number of posterior draws whose transformed value is
        strictly smaller than the transformed reference value (element-wise when
        ``transform`` returns arrays).
        """
        for name in self.var_names:
            if self.engine == "numpyro":
                # NumPyro fits keep separate ``chain``/``draw`` dimensions; pool them
                # so multi-chain runs rank against every draw, like the PyMC path.
                draws = posterior[name].stack(sample=("chain", "draw"))
                reference = np.asarray(self.ref_params[name][simulation_idx])
            else:
                # PyMC/Bambi fits were already flattened into ``sample`` by ``extract``.
                draws = posterior[name]
                reference = self.ref_params[name].isel(sample=simulation_idx).values
            transformed_posterior = np.array(
                [
                    transform(name, draws.isel(sample=i).values)
                    for i in range(draws.sizes["sample"])
                ]
            )
            self.simulations[name].append(
                (transformed_posterior < transform(name, reference)).sum(axis=0)
            )

    def _get_sample_indices(self, num_available):
        """Choose which reference draws are used, caching the choice for resumed runs."""
        if self._sample_indices is None:
            rng = np.random.default_rng(self.seed)
            self._sample_indices = rng.choice(
                num_available, size=self.num_simulations, replace=False
            )
        return self._sample_indices

    def _resume_rank_lists(self):
        """Return per-variable rank lists, restoring ranks saved by an earlier call."""
        saved = self._partial_ranks or {}
        return {name: list(saved.get(name, [])) for name in self.var_names}

    def _finalize_simulations(self):
        """Turn the results gathered so far into ``self.simulations`` (a DataTree).

        Called whenever a simulation loop stops (finished, failed or interrupted).
        Nothing is converted if no simulation has completed.
        """
        completed = self._simulations_complete
        # Drop a fit stored by an iteration interrupted before it was counted, so it
        # is not duplicated when the run is resumed.
        del self.posteriors[completed:]
        if not completed:
            return
        if self.keep_fits:
            self.compute_rank_statistics()
        else:
            # Keep plain lists (trimmed to completed iterations) so a later call to
            # ``run_simulations`` can continue appending to them.
            self._partial_ranks = {
                name: list(ranks[:completed]) for name, ranks in self.simulations.items()
            }
            self.simulations = {
                name: np.stack(ranks)[None, :] for name, ranks in self._partial_ranks.items()
            }
            self._convert_to_datatree()

    @quiet_logging("pymc", "pytensor.gof.compilelock", "bambi")
    def run_simulations(self):
        """Run all SBC iterations (Prior or Posterior SBC).

        For each iteration the method:

        1. Draws a reference parameter vector and a replicated dataset (from the
           prior / prior-predictive for Prior SBC, or from the original
           posterior / posterior-predictive for Posterior SBC).
        2. Fits the model to the (possibly augmented) replicated data.
        3. Computes the rank of the reference draw among the new (augmented)
           posterior draws.

        The results are stored in ``self.simulations`` as an ArviZ DataTree with
        group ``"prior_sbc"`` or ``"posterior_sbc"``.

        This method can be stopped and restarted on the same instance: you can
        keyboard-interrupt part way through, inspect the partial results, and
        then call ``run_simulations()`` again to continue. If a seed was passed
        at init, reproducibility is preserved.
        """
        if self.method == "prior":
            # In Prior SBC, the reference parameter draws are from the prior,
            # the predictive samples are from the prior predictive
            ref_params, predictive = self._get_prior_predictive_samples()
        else:
            # In Posterior SBC, the reference parameter draws are from the original posterior,
            # the predictive samples are from the original posterior predictive
            ref_params, predictive = self._get_posterior_predictive_samples()

        sample_indices = self._get_sample_indices(ref_params.sizes["sample"])
        self.ref_params = ref_params.isel(sample=sample_indices)
        predictive = predictive.isel(sample=sample_indices)

        # if simulator is used, ignore observed_vars
        if self.simulator is not None:
            self.observed_vars = list(predictive.data_vars)
            self.var_names = [
                var_name for var_name in ref_params.data_vars if var_name not in self.observed_vars
            ]
        self.simulations = self._resume_rank_lists()

        progress = tqdm(
            initial=self._simulations_complete,
            total=self.num_simulations,
            disable=not self.progress_bar,
        )
        try:
            while self._simulations_complete < self.num_simulations:
                idx = self._simulations_complete
                replicated_data = {
                    var_name: predictive[var_name].isel(sample=idx).values
                    for var_name in self.observed_vars
                }
                posterior = self._get_posterior_samples(replicated_data)
                if self.keep_fits:
                    self.posteriors.append(posterior)
                else:
                    self._compute_single_rank(idx, posterior, self._transform)
                self._simulations_complete += 1
                progress.update()
        except Exception:
            logging.error("Stopping simulation. An error occurred during simulations:")
            traceback.print_exc()
        finally:
            self._finalize_simulations()
            progress.close()

    @quiet_logging("numpyro")
    def _run_simulations_numpyro(self):
        """Run all the simulations for Numpyro Model."""
        prior, prior_pred = self._get_prior_predictive_samples_numpyro()
        self.ref_params = prior

        # if simulator is used, ignore observed_vars
        if self.simulator is not None:
            self.observed_vars = list(prior_pred.keys())
            self.observed_model_vars = [
                name for name in self.observed_vars if name in self.model_params
            ]
            if not self.observed_model_vars:
                raise ValueError("No observed variables to condition on")
            self.var_names = [
                var_name for var_name in prior.keys() if var_name not in self.observed_vars
            ]
        self.simulations = self._resume_rank_lists()

        progress = tqdm(
            initial=self._simulations_complete,
            total=self.num_simulations,
            disable=not self.progress_bar,
        )
        try:
            while self._simulations_complete < self.num_simulations:
                idx = self._simulations_complete
                prior_predictive_draw = {k: v[idx] for k, v in prior_pred.items()}
                posterior = self._get_posterior_samples_numpyro(prior_predictive_draw)
                if self.keep_fits:
                    self.posteriors.append(posterior)
                else:
                    self._compute_single_rank(idx, posterior, self._transform)
                self._simulations_complete += 1
                progress.update()
        finally:
            self._finalize_simulations()
            progress.close()