# Source code for discoverysamplers.jaxns_interface


"""
Discovery ↔︎ JAX-NS Interface

This module provides a bridge between Discovery-style models and JAX-NS nested sampling.
"""
from __future__ import annotations

from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
import math

import numpy as np
import numpy.lib.recfunctions as rfn

from .priors import ParsedPrior, PriorParsingError, _parse_single_prior, _split_priors, ParamName, PriorSpec
from .likelihood import LikelihoodWrapper

# --------------------------- JAX-NS Bridge ----------------------------- #

try:
    # Only needed when actually running the sampler
    import jax
    import jax.numpy as jnp
    from jax.scipy.special import erfinv
    # jaxns imports are version dependent; we try common entry points.
    import jaxns as jns  # type: ignore
    from jaxns import NestedSampler  # type: ignore
    from jaxns import Model, Prior  # type: ignore
    import tensorflow_probability.substrates.jax as tfp  # type: ignore
    tfpd = tfp.distributions  # type: ignore
except Exception:  # pragma: no cover - allow import without jaxns installed
    jns = None      # type: ignore
    NestedSampler = None  # type: ignore
    Model = None  # type: ignore
    Prior = None  # type: ignore
    tfpd = None  # type: ignore

# -------------------------- JAX-NS specific utilities ------------------- #


def _ndtri(u: "jnp.ndarray") -> "jnp.ndarray":
    """Inverse Normal CDF using erfinv: Φ^{-1}(u) = sqrt(2) * erfinv(2u-1)."""
    return jnp.sqrt(2.0) * erfinv(2.0 * u - 1.0)


def _make_prior_transform(sampled_names, parsed_specs_dict):
    """
    Build a vectorised prior transform f: U in [0,1]^D -> theta in R^D,
    matching the order in `sampled_names`.

    `parsed_specs_dict` is a dict[name] -> ParsedPrior from _parse_single_prior.
    This function adapts the common ParsedPrior format to JAX-NS requirements.
    """
    # Precompute constants for speed; everything JAX-friendly
    kinds = []
    params_list = []
    for n in sampled_names:
        p = parsed_specs_dict[n]
        kinds.append(p.dist_type)

        # Convert ParsedPrior to parameter dict for JAX-NS
        if p.dist_type == "uniform":
            params_list.append({"min": p.bounds[0], "max": p.bounds[1]})
        elif p.dist_type == "loguniform":
            params_list.append({"a": p.bounds[0], "b": p.bounds[1]})
        elif p.dist_type == "normal":
            params_list.append({"mean": p.mean, "sigma": p.sigma})
        else:
            raise ValueError(
                f"JAX-NS bridge: prior '{n}' with dist '{p.dist_type}' "
                "is not supported for transform-based sampling. "
                "Use a standard distribution (uniform/loguniform/normal)."
            )

    kinds_tuple = tuple(kinds)  # for Python-side dispatch

    def transform(uvec: "jnp.ndarray") -> "jnp.ndarray":
        # uvec shape (..., D); returns same leading shape with D parameters
        def one(u):
            outs = []
            for i, kind in enumerate(kinds_tuple):
                ui = u[i]
                par = params_list[i]
                if kind == "uniform":
                    a = par["min"]; b = par["max"]
                    outs.append(a + ui * (b - a))
                elif kind == "loguniform":
                    a = par["a"]; b = par["b"]
                    # exp( log(a) + u*(log(b)-log(a)) )
                    outs.append(jnp.exp(jnp.log(a) + ui * (jnp.log(b) - jnp.log(a))))
                elif kind == "normal":
                    mu = par["mean"]; sigma = par["sigma"]
                    outs.append(mu + sigma * _ndtri(ui))
                else:
                    # Should never reach here due to check above
                    raise ValueError(
                        f"JAX-NS bridge: prior '{sampled_names[i]}' with dist '{kind}' "
                        "is not supported for transform-based sampling."
                    )
            return jnp.stack(outs, axis=0)

        uvec = jnp.asarray(uvec, dtype=jnp.float64)
        if uvec.ndim == 1:
            return one(uvec)
        return jax.vmap(one)(uvec)

    return transform


class DiscoveryJAXNSBridge:
    """
    Bridge between a Discovery-style model and JAX-NS NestedSampler.

    Mirrors DiscoveryNessaiBridge API where possible.

    Parameters
    ----------
    discovery_model : callable | object
        Callable or object with `.logL(params_dict) -> float`.
    priors : Mapping[str, PriorSpec]
        Same schema you use for the nessai bridge.
    latex_labels : Optional[Mapping[str, str]]
        Optional labels used for plotting/exports.
    jit : bool
        JIT the discovery model for fast likelihood calls.
    """

    def __init__(self, discovery_model, priors, latex_labels=None, jit: bool = True):
        """
        Parse the priors, wrap the likelihood and build the JAX-NS callables.

        Raises
        ------
        RuntimeError
            If jaxns could not be imported.
        ValueError
            If `priors` defines no sampled parameter.
        """
        if jns is None or NestedSampler is None:
            raise RuntimeError("jaxns is not installed. Please `pip install jaxns`.")

        # Parse priors once; keep the order from the input dict
        sampled_names, fixed, bounds, _logprior_fns = _split_priors(priors)
        if not sampled_names:
            raise ValueError("No sampled parameters defined (all fixed?)")

        # Reuse the shared adapter and prior splitter - pass fixed params to adapter
        self.adapter = LikelihoodWrapper(discovery_model, jit=jit, fixed_params=fixed, allow_array_api=True)
        self.sampled_names = list(sampled_names)
        self.fixed_params = dict(fixed)
        self.bounds = dict(bounds)  # not strictly used by JAX-NS, but kept for parity
        self.discovery_paramnames = list(priors.keys())
        self.fixed_names = list(self.fixed_params.keys())

        # Labels default to the raw parameter names.
        self.latex_labels = dict(latex_labels) if latex_labels else {n: n for n in self.discovery_paramnames}
        self.latex_list = [self.latex_labels.get(n, n) for n in self.discovery_paramnames]
        self.sampled_names_latex = [self.latex_labels.get(n, n) for n in self.sampled_names]
        self.fixed_names_latex = [self.latex_labels.get(n, n) for n in self.fixed_names]

        # Build a map name -> ParsedPrior to feed the transform.
        # Every entry of `priors` goes through the common parser; only the
        # sampled ones are kept.
        parsed_specs = {n: _parse_single_prior(n, priors[n]) for n in priors}
        self._parsed_sampled = {n: parsed_specs[n] for n in self.sampled_names}

        # Prior transform over unit-cube -> sampled parameter vector (in the sampled order)
        self._prior_transform_vec = _make_prior_transform(self.sampled_names, self._parsed_sampled)

        # Configure the discovery adapter's array API (so the hot path is vectorised)
        all_order = tuple(self.sampled_names) + tuple(self.fixed_names)
        if hasattr(self.adapter, "configure_array_api"):
            self.adapter.configure_array_api(all_order)

        # Likelihood wrappers ------------------------------------------------
        # Vector form: theta_sampled (D_s,) -> scalar logL
        # Expand with fixed params then call compiled row function.
        def _loglik_row(theta_s: "jnp.ndarray") -> "jnp.float64":
            # Row layout matches `all_order`: sampled values, then fixed values.
            vals = [theta_s[i] for i in range(len(self.sampled_names))]
            vals.extend(self.fixed_params.values())
            row = jnp.array(vals, dtype=jnp.float64)
            return self.adapter.log_likelihood_row(row)

        # Batch form: map over leading axis with vmap
        self._loglik_row = jax.jit(_loglik_row)
        self._loglik_batch = jax.jit(jax.vmap(self._loglik_row))

        # Callables with the shape contract:
        #  - loglikelihood(theta) that accepts theta either shape (D,) or (N,D)
        #  - prior_transform(u) from unit-cube to theta, with the same shape contract
        def _loglik(theta):
            theta = jnp.asarray(theta, dtype=jnp.float64)
            if theta.ndim == 1:
                return self._loglik_row(theta)
            return self._loglik_batch(theta)

        def _prior_transform(u):
            return self._prior_transform_vec(u)

        self._jaxns_loglik = _loglik
        self._jaxns_prior_transform = _prior_transform

        # Sampler placeholder
        self.ns = None
        self.results = None

    # ---------------------------- Running ------------------------------- #
    def run_sampler(self, *, nlive: int = 800, max_samples: Optional[int] = None,
                    termination_frac: float = 0.01, rng_seed: Optional[int] = None,
                    sampler_kwargs: Optional[dict] = None):
        """
        Run JAX-NS NestedSampler.

        Parameters
        ----------
        nlive : int
            Number of live points.
        max_samples : int | None
            Optional hard cap on the number of samples (pass through when supported).
        termination_frac : float
            Evidence tolerance fraction at which to terminate (version dependent).
            If the sampler call rejects this keyword with a TypeError, the
            run is repeated without it.
        rng_seed : int | None
            Seed for JAX PRNG. ``None`` is treated as 0.
        sampler_kwargs : dict | None
            Extra kwargs forwarded to NestedSampler (e.g., sampler implementation).
            Entries named ``max_samples`` or ``num_live_points`` override
            the explicit arguments above.

        Returns
        -------
        results : object
            When the sampler returns a ``(termination_reason, state)`` pair,
            a dict with keys ``termination_reason``, ``state``, ``logZ``,
            ``logZerr``, ``log_Z_mean``, ``log_Z_uncert``, ``samples`` and,
            if the sampler's ``to_results`` succeeded, ``jaxns_results``.
            Otherwise whatever the sampler returned, unchanged.
        """
        key = jax.random.PRNGKey(0 if rng_seed is None else rng_seed)

        # Explicit arguments act as defaults; `sampler_kwargs` wins on a
        # clash instead of raising a duplicate-keyword TypeError.
        ns_options = {"max_samples": max_samples, "num_live_points": nlive}
        ns_options.update(sampler_kwargs or {})

        # Create JAX-NS Model (new API)
        # prior_model must be a generator function that yields priors
        def prior_model():
            drawn = []
            for name in self.sampled_names:
                parsed = self._parsed_sampled[name]
                if parsed.dist_type == "uniform":
                    prior = Prior(tfpd.Uniform(
                        low=parsed.bounds[0],
                        high=parsed.bounds[1]
                    ), name=name)
                elif parsed.dist_type == "loguniform":
                    # LogUniform = exp(Uniform(log(a), log(b)))
                    prior = Prior(tfpd.TransformedDistribution(
                        distribution=tfpd.Uniform(
                            low=jnp.log(parsed.bounds[0]),
                            high=jnp.log(parsed.bounds[1])
                        ),
                        bijector=tfp.bijectors.Exp()
                    ), name=name)
                elif parsed.dist_type == "normal":
                    prior = Prior(tfpd.Normal(
                        loc=parsed.mean,
                        scale=parsed.sigma
                    ), name=name)
                else:
                    raise ValueError(f"Unsupported prior type '{parsed.dist_type}' for parameter '{name}'")
                drawn.append((yield prior))
            return tuple(drawn)

        # Likelihood over the unpacked sampled values, in the order of
        # self.sampled_names. Only sampled params go in the dict; the
        # adapter was given the fixed ones at construction.
        def log_likelihood(*sampled_params):
            sampled_dict = dict(zip(self.sampled_names, sampled_params))
            return self.adapter.log_likelihood(sampled_dict)

        model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
        self.ns = NestedSampler(model=model, **ns_options)

        # Run the sampler. `termination_frac` is forwarded on a best-effort
        # basis: a TypeError is taken to mean "keyword not supported".
        if termination_frac is not None:
            try:
                raw_results = self.ns(key, termination_frac=float(termination_frac))
            except TypeError:
                raw_results = self.ns(key)
        else:
            raw_results = self.ns(key)

        # Process results - newer JAXNS returns (termination_reason, state) tuple
        if isinstance(raw_results, tuple) and len(raw_results) == 2:
            termination_reason, state = raw_results
            log_Z, log_Z_uncert, summary = self._estimate_log_evidence(termination_reason, state)

            # Results dict for compatibility; the evidence is stored under
            # two naming conventions.
            self.results = {
                'termination_reason': termination_reason,
                'state': state,
                'logZ': log_Z,
                'logZerr': log_Z_uncert,
                'log_Z_mean': log_Z,
                'log_Z_uncert': log_Z_uncert,
                'samples': state.sample_collection,
            }
            if summary is not None:
                self.results['jaxns_results'] = summary
        else:
            # Older API or different structure
            self.results = raw_results

        return self.results

    def _estimate_log_evidence(self, termination_reason, state):
        """
        Return ``(log_Z, log_Z_uncert, summary)`` for a finished run.

        The sampler's own ``to_results`` is preferred when it exists;
        ``summary`` is then the object it returned. Otherwise (or if that
        call fails) a crude estimate is computed from the sorted
        log-likelihoods and ``summary`` is ``None``.

        Raises
        ------
        RuntimeError
            If the fallback is needed but the state holds no samples.
        """
        to_results = getattr(self.ns, "to_results", None)
        if callable(to_results):
            try:
                summary = to_results(termination_reason=termination_reason, state=state)
                return float(summary.log_Z_mean), float(summary.log_Z_uncert), summary
            except Exception as exc:
                # Deliberate best-effort: losing a finished run because the
                # summary step failed would be worse than a rough number.
                import warnings
                warnings.warn(
                    f"JAX-NS `to_results` failed ({exc!r}); "
                    "falling back to a rough evidence estimate.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        from jax.scipy.special import logsumexp

        # IMPORTANT: state.num_samples can be larger than the actual array size,
        # so we use the minimum of num_samples and the actual array length
        log_L_all = state.sample_collection.log_L
        n = min(int(state.num_samples), log_L_all.shape[0])
        if n == 0:
            raise RuntimeError("JAX-NS returned no samples; cannot estimate the evidence.")
        log_L_sorted = jnp.sort(log_L_all[:n])

        # Assumed prior-mass grid: X falls linearly from 1 to 1/n as the
        # likelihood rises. Z ~= sum_i L_i * (X_{i-1} - X_i) with X_{-1} = 1.
        # The widths are differences in X itself, not in log X.
        # NOTE(review): this grid ignores the number of live points, so the
        # value is only a rough stand-in for the sampler's own estimate.
        X = jnp.linspace(1.0, 1.0 / n, n)
        widths = jnp.concatenate([jnp.ones(1), X])[:-1] - X
        log_Z = logsumexp(log_L_sorted, b=widths)

        # Rough uncertainty estimate (1/sqrt(n))
        return float(log_Z), 1.0 / math.sqrt(n), None

    # --------------- Results extraction & convenience ------------------ #
    def _posterior_from_results(self, results=None):
        """
        Locate a 2-D array of samples (nsamples, D_s) in the results.

        Sources are tried in order: the ``state`` entry of a results dict,
        equal-weight helper methods on the sampler, then common field names
        on the results object.

        Returns (chain_sampled, weight_array_or_None); the weights are
        currently always ``None``.

        Raises
        ------
        RuntimeError
            If no results are available or no sample array can be found.
        ValueError
            If a sampled prior has a kind that cannot be inverted.
        """
        res = results if results is not None else self.results
        if res is None:
            raise RuntimeError("No results available. Run `run_sampler()` first.")

        chain = self._chain_from_state(res)
        if chain is None:
            chain = self._chain_from_sampler_helpers(res)
        if chain is None:
            chain = self._chain_from_common_fields(res)
        if chain is None:
            raise RuntimeError(
                "Could not locate posterior samples in JAX-NS results. "
                "If your version exposes a different field, pass it to "
                "`return_sampled_samples(results=...)`."
            )
        return chain, None

    def _chain_from_state(self, res):
        """Map ``res['state'].sample_collection.U_samples`` through the priors, or return None."""
        if not (isinstance(res, Mapping) and 'state' in res):
            return None
        sc = getattr(res['state'], 'sample_collection', None)
        if sc is None or not hasattr(sc, 'U_samples'):
            return None

        # Get actual number of samples (min of num_samples and array size)
        n = sc.log_L.shape[0]
        if hasattr(res['state'], 'num_samples'):
            n = min(int(res['state'].num_samples), n)

        # NOTE(review): every stored row is mapped through the prior and
        # returned without weights or resampling - confirm whether callers
        # need posterior-weighted draws here.
        unit = np.asarray(sc.U_samples, dtype=float)[:n]
        chain = np.zeros_like(unit)
        for i, name in enumerate(self.sampled_names):
            parsed = self._parsed_sampled[name]
            u = unit[:, i]
            # Apply the inverse CDF of each prior to its unit-cube column.
            if parsed.dist_type == 'uniform':
                low, high = parsed.bounds
                chain[:, i] = low + u * (high - low)
            elif parsed.dist_type == 'loguniform':
                low, high = parsed.bounds
                log_low, log_high = np.log(low), np.log(high)
                chain[:, i] = np.exp(log_low + u * (log_high - log_low))
            elif parsed.dist_type == 'normal':
                # Imported lazily so scipy is only needed for normal priors.
                from scipy import stats
                chain[:, i] = stats.norm.ppf(u, loc=parsed.mean, scale=parsed.sigma)
            else:
                # Fail loudly rather than leave an all-zero column.
                raise ValueError(f"Unsupported prior type '{parsed.dist_type}' for parameter '{name}'")
        return chain

    def _chain_from_sampler_helpers(self, res):
        """Try equal-weight helper methods on the sampler; return an ndarray or None."""
        import inspect

        for attr in ("equal_weight_posterior", "get_equal_weight_posterior", "posterior_equal_weights"):
            fn = getattr(self.ns, attr, None)
            if not callable(fn):
                continue
            # Inspect real parameters; local variable names must not count.
            try:
                takes_key = "key" in inspect.signature(fn).parameters
            except (TypeError, ValueError):
                takes_key = False
            try:
                arr = fn(res, key=jax.random.PRNGKey(0)) if takes_key else fn(res)
                if isinstance(arr, (np.ndarray, jnp.ndarray)):
                    return np.asarray(arr, dtype=float)
            except Exception:
                # Best-effort probing of version-dependent helpers.
                continue
        return None

    def _chain_from_common_fields(self, res):
        """Look for a (nsamples, D_s) array under common field names; return it or None."""
        candidates = []
        for k in ("posterior_samples", "samples", "theta", "posterior"):
            val = getattr(res, k, None)
            if val is not None:
                candidates.append(val)
            if isinstance(res, Mapping) and k in res:
                candidates.append(res[k])

        for c in candidates:
            try:
                arr = np.asarray(c, dtype=float)
            except Exception:
                # Not array-like (e.g. a structured sample container); skip it.
                continue
            if arr.ndim == 2 and arr.shape[1] == len(self.sampled_names):
                return arr
        return None

    def _stack_with_fixed(self, sampled_chain: np.ndarray) -> np.ndarray:
        """Combine sampled columns with fixed params in original priors order."""
        sampled_chain = np.asarray(sampled_chain, dtype=float)
        nsamp = sampled_chain.shape[0]
        chain_all = np.zeros((nsamp, len(self.discovery_paramnames)), dtype=float)
        # fill sampled
        for j, name in enumerate(self.sampled_names):
            idx = self.discovery_paramnames.index(name)
            chain_all[:, idx] = sampled_chain[:, j]
        # fill fixed (each value is broadcast down its column)
        for name in self.fixed_names:
            idx = self.discovery_paramnames.index(name)
            chain_all[:, idx] = float(self.fixed_params[name])
        return chain_all

    # Public API (mirrors nessai bridge) ---------------------------------
    def return_sampled_samples(self, *, results=None) -> Dict[str, Any]:
        """
        Return the chain of sampled parameters.

        Returns
        -------
        dict
            ``names`` (sampled parameter names), ``labels`` (their labels)
            and ``chain`` (float ndarray, one column per sampled parameter).
        """
        chain, _ = self._posterior_from_results(results)
        return {
            "names": self.sampled_names,
            "labels": self.sampled_names_latex,
            "chain": np.asarray(chain, dtype=float),
        }

    def return_all_samples(self, *, results=None) -> Dict[str, Any]:
        """
        Return the chain with fixed parameters inserted as constant columns.

        Returns
        -------
        dict
            ``names`` (all parameter names in priors order), ``labels`` and
            ``chain`` (float ndarray, one column per parameter).
        """
        chain_s, _ = self._posterior_from_results(results)
        chain_all = self._stack_with_fixed(chain_s)
        return {
            "names": self.discovery_paramnames,
            "labels": self.latex_list,
            "chain": chain_all,
        }

    def return_logZ(self, *, results=None) -> Dict[str, Optional[float]]:
        """
        Return the log evidence and its uncertainty from nested sampling.

        Parameters
        ----------
        results : dict, optional
            Results dict from run_sampler(). If None, uses stored results.

        Returns
        -------
        dict
            Dictionary containing:
            - 'logZ': float - the log evidence estimate
            - 'logZ_err': float or None - uncertainty on logZ, None when
              the results carry no recognised uncertainty field

        Raises
        ------
        RuntimeError
            If no results are available (run_sampler not called), or no
            log-evidence field can be found.
        """
        res = results if results is not None else self.results
        if res is None:
            raise RuntimeError("No results available. Run `run_sampler()` first.")

        def _first_value(keys, attrs):
            # Mapping keys take precedence over attributes; entries holding
            # None are skipped instead of crashing in float().
            if isinstance(res, Mapping):
                for key in keys:
                    if key in res and res[key] is not None:
                        return float(res[key])
            for attr in attrs:
                value = getattr(res, attr, None)
                if value is not None:
                    return float(value)
            return None

        logZ = _first_value(
            ('logZ', 'log_Z_mean', 'log_evidence', 'log_Z'),
            ('log_Z_mean', 'logZ', 'log_evidence', 'log_Z'),
        )
        logZ_err = _first_value(
            ('logZerr', 'logZ_err', 'log_Z_uncert', 'log_evidence_error', 'log_Z_err'),
            ('log_Z_uncert', 'logZerr', 'log_evidence_error', 'logZ_err', 'log_Z_err'),
        )

        if logZ is None:
            raise RuntimeError("Could not find log evidence in results. Check that sampling completed successfully.")

        return {'logZ': logZ, 'logZ_err': logZ_err}

    # ------------------------------ Plots ------------------------------ #
    def plot_trace(self, *, burn: int = 0, plot_fixed: bool = False, results=None, **kwargs):
        """
        Plot trace of samples vs sample index.

        Parameters
        ----------
        burn : int, optional
            Number of initial samples to discard, by default 0.
        plot_fixed : bool, optional
            If True, includes fixed parameters in the plot, by default False.
        results : optional
            Results from run_sampler(). If None, uses stored results.
        **kwargs
            Additional keyword arguments passed to plots.plot_trace().

        Returns
        -------
        matplotlib.figure.Figure
            Figure containing the trace plots.
        """
        from .plots import plot_trace as _plot_trace

        if plot_fixed:
            data = self.return_all_samples(results=results)
        else:
            data = self.return_sampled_samples(results=results)
        return _plot_trace(
            data,
            burn=burn,
            fixed_params=self.fixed_params,
            fixed_names=self.fixed_names,
            **kwargs
        )

    def plot_corner(self, *, burn: int = 0, results=None, **kwargs):
        """
        Corner plot of sampled parameters.

        Parameters
        ----------
        burn : int, optional
            Number of initial samples to discard, by default 0.
        results : optional
            Results from run_sampler(). If None, uses stored results.
        **kwargs
            Additional keyword arguments passed to plots.plot_corner().

        Returns
        -------
        matplotlib.figure.Figure
            Corner plot figure.
        """
        from .plots import plot_corner as _plot_corner

        data = self.return_sampled_samples(results=results)
        return _plot_corner(data, burn=burn, **kwargs)