Source code for simuk.sbc

"""Simulation based calibration (Talts et. al. 2018) in PyMC."""

import logging
from copy import copy
from importlib.metadata import version

try:
    import pymc as pm
except ImportError:
    pass
try:
    import jax
    from numpyro.handlers import seed, trace
    from numpyro.infer import MCMC, Predictive
    from numpyro.infer.mcmc import MCMCKernel
except ImportError:
    pass

import numpy as np
from arviz_base import extract, from_dict, from_numpyro
from tqdm import tqdm


class quiet_logging:
    """Turn off logging for PyMC, Bambi and PyTensor."""

    def __init__(self, *libraries):
        self.loggers = [logging.getLogger(library) for library in libraries]

    def __call__(self, func):
        def wrapped(cls, *args, **kwargs):
            levels = []
            for logger in self.loggers:
                levels.append(logger.level)
                logger.setLevel(logging.CRITICAL)
            res = func(cls, *args, **kwargs)
            for logger, level in zip(self.loggers, levels):
                logger.setLevel(level)
            return res

        return wrapped


[docs] class SBC: """Set up class for doing SBC. Parameters ---------- model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel A PyMC, Bambi model or Numpyro MCMC kernel. If a PyMC model the data needs to be defined as mutable data. num_simulations : int How many simulations to run sample_kwargs : dict[str] -> Any Arguments passed to pymc.sample or bambi.Model.fit seed : int (optional) Random seed. This persists even if running the simulations is paused for whatever reason. data_dir : dict Keyword arguments passed to numpyro model, intended for use when providing an MCMC Kernel model. Example ------- .. code-block :: python with pm.Model() as model: x = pm.Normal('x') y = pm.Normal('y', mu=2 * x, observed=obs) sbc = SBC(model) sbc.run_simulations() sbc.plot_results() """ def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None): if hasattr(model, "basic_RVs") and isinstance(model, pm.Model): self.engine = "pymc" self.model = model elif hasattr(model, "formula"): self.engine = "bambi" model.build() self.bambi_model = model self.model = model.backend.model self.formula = model.formula self.new_data = copy(model.data) elif isinstance(model, MCMCKernel): self.engine = "numpyro" self.numpyro_model = model self.model = self.numpyro_model.model self.run_simulations = self._run_simulations_numpyro self.data_dir = data_dir else: raise ValueError( "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel" ) self.num_simulations = num_simulations if sample_kwargs is None: sample_kwargs = {} if self.engine == "numpyro": sample_kwargs.setdefault("num_warmup", 1000) sample_kwargs.setdefault("num_samples", 1000) sample_kwargs.setdefault("progress_bar", False) else: sample_kwargs.setdefault("progressbar", False) sample_kwargs.setdefault("compute_convergence_checks", False) self.sample_kwargs = sample_kwargs self.seed = seed self._seeds = self._get_seeds() self._extract_variable_names() self.simulations = {name: [] for name in self.var_names} self._simulations_complete = 0 def _extract_variable_names(self): """Extract observed and free variables from the model.""" if self.engine == "numpyro": with trace() as tr: with seed(rng_seed=int(self._seeds[0])): self.numpyro_model.model(**self.data_dir) self.var_names = [ name for name, site in tr.items() if site["type"] == "sample" and not site.get("is_observed", False) ] self.observed_vars = [ name for name, site in tr.items() if site["type"] == "sample" and site.get("is_observed", False) ] else: self.observed_vars = [obs.name for obs in self.model.observed_RVs] self.var_names = [v.name for v in self.model.free_RVs] def _get_seeds(self): """Set the random seed, and generate seeds for all the simulations.""" rng = np.random.default_rng(self.seed) return rng.integers(0, 2**30, size=self.num_simulations) def _get_prior_predictive_samples(self): """Generate samples to use for the simulations.""" with self.model: idata = pm.sample_prior_predictive( samples=self.num_simulations, random_seed=self._seeds[0] ) prior_pred = extract(idata, group="prior_predictive", keep_dataset=True) prior = extract(idata, group="prior", keep_dataset=True) return prior, prior_pred def _get_prior_predictive_samples_numpyro(self): """Generate samples to use for the simulations using numpyro.""" predictive = Predictive(self.model, num_samples=self.num_simulations) free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} return prior, prior_pred def _get_posterior_samples(self, prior_predictive_draw): """Generate posterior samples conditioned to a prior predictive sample.""" new_model = pm.observe(self.model, prior_predictive_draw) with new_model: check = pm.sample( **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete] ) posterior = extract(check, group="posterior", keep_dataset=True) return posterior def _get_posterior_samples_numpyro(self, prior_predictive_draw): """Generate posterior samples using numpyro conditioned to a prior predictive sample.""" mcmc = MCMC(self.numpyro_model, **self.sample_kwargs) rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete]) free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw) return from_numpyro(mcmc)["posterior"] def _convert_to_datatree(self): self.simulations = from_dict( {"prior_sbc": self.simulations}, attrs={ "/": { "inferece_library": self.engine, "inferece_library_version": version(self.engine), "modeling_interface": "simuk", "modeling_interface_version": version("simuk"), } }, )
[docs] @quiet_logging("pymc", "pytensor.gof.compilelock", "bambi") def run_simulations(self): """Run all the simulations. This function can be stopped and restarted on the same instance, so you can keyboard interrupt part way through, look at the plot, and then resume. If a seed was passed initially, it will still be respected (that is, the resulting simulations will be identical to running without pausing in the middle). """ prior, prior_pred = self._get_prior_predictive_samples() progress = tqdm( initial=self._simulations_complete, total=self.num_simulations, ) try: while self._simulations_complete < self.num_simulations: idx = self._simulations_complete prior_predictive_draw = { var_name: prior_pred[var_name].sel(chain=0, draw=idx).values for var_name in self.observed_vars } posterior = self._get_posterior_samples(prior_predictive_draw) for name in self.var_names: self.simulations[name].append( (posterior[name] < prior[name].sel(chain=0, draw=idx)).sum("sample").values ) self._simulations_complete += 1 progress.update() finally: self.simulations = { k: np.stack(v[: self._simulations_complete])[None, :] for k, v in self.simulations.items() } self._convert_to_datatree() progress.close()
@quiet_logging("numpyro") def _run_simulations_numpyro(self): """Run all the simulations for Numpyro Model.""" prior, prior_pred = self._get_prior_predictive_samples_numpyro() progress = tqdm( initial=self._simulations_complete, total=self.num_simulations, ) try: while self._simulations_complete < self.num_simulations: idx = self._simulations_complete prior_predictive_draw = {k: v[idx] for k, v in prior_pred.items()} posterior = self._get_posterior_samples_numpyro(prior_predictive_draw) for name in self.var_names: self.simulations[name].append( (posterior[name].sel(chain=0) < prior[name][idx]).sum(axis=0).values ) self._simulations_complete += 1 progress.update() finally: self.simulations = { k: np.stack(v[: self._simulations_complete])[None, :] for k, v in self.simulations.items() } self._convert_to_datatree() progress.close()