"""
Discovery ↔︎ nessai Interface
This module provides a light wrapper that adapts a Discovery-style model
(a callable that returns a log-probability/log-likelihood given a
parameter dictionary) to the API expected by `nessai`'s `FlowSampler`.
Design goals
------------
- Keep input style similar to the (intended) Eryn interface: pass a
`discovery_model` and a `prior` dictionary.
- Minimise hard-coding so this can be refactored later into a common
abstract `SamplerInterface`.
- Clean separation between: parsing priors, converting
(dict ↔︎ structured array), and running the sampler.
Notes on nessai model API
-------------------------
`nessai` expects a subclass of `nessai.model.Model` that defines:
- `names: list[str]` – parameter names
- `bounds: dict[str, tuple[float, float]]` – hard bounds used for sampling
- `log_prior(x) -> float` – log prior for a single live point `x`
- `log_likelihood(x) -> float` – log likelihood for a single live point
`x` is a *structured numpy array* with fields equal to `names`.
This wrapper constructs such a model on the fly, based on a
`prior`-specification dictionary.
Prior specification
-------------------
The `prior` dictionary maps parameter names to specs. Supported forms:
1) Distribution dicts:
{
'dist': 'uniform' | 'loguniform' | 'normal' | 'fixed',
# parameters depend on the dist (see `_make_prior`)
}
2) Shorthand tuples for common cases:
- ('uniform', min, max)
- ('loguniform', a, b)
- ('normal', mean, sigma)
- ('fixed', value)
3) A callable prior: any object with `logpdf(value)` and, for non-fixed
parameters, hard bounds provided via `'bounds': (min, max)`.
Fixed parameters are separated out and always injected into the model
inputs before calling the Discovery model.
Example
-------
>>> bridge = DiscoveryNessaiBridge(
... discovery_model=my_model, # callable or object with log_prob/log_likelihood
... prior={
... 'm1': {'dist': 'uniform', 'min': 5, 'max': 50},
... 'm2': ('loguniform', 1e-1, 10.0),
... 'z': ('fixed', 0.2),
... },
... )
>>> # Run the sampler
>>> results = bridge.run_sampler(nlive=1000, max_iterations=50_000, output='./out')
"""
from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
import math
import os
import numpy as np
import numpy.lib.recfunctions as rfn
from .priors import ParsedPrior, PriorParsingError, _parse_single_prior, _split_priors, ParamName, PriorSpec
from .likelihood import LikelihoodWrapper
try:
# These imports are only needed when actually running the sampler
import nessai
from nessai.flowsampler import FlowSampler
from nessai.model import Model as NessaiModel
import jax.numpy as jnp
except Exception: # pragma: no cover - allow import without nessai installed
FlowSampler = None # type: ignore
NessaiModel = object # type: ignore
jnp = None # type: ignore
# -------------------------- Utilities & Types --------------------------- #
def _as_batch_struct(x: np.ndarray) -> np.ndarray:
"""Ensure x is a 1D structured array (length N)."""
if x.dtype.names is None:
raise TypeError("Expected a structured numpy array with named fields.")
# np.atleast_1d keeps scalar structured arrays as shape (1,)
xb = np.atleast_1d(x)
if xb.ndim != 1:
# nessai expects a 1D batch of structured samples
xb = xb.reshape(-1)
return xb
# ------------------------ Nessai-specific prior utilities --------------- #
def _convert_parsed_prior_to_nessai(parsed: ParsedPrior, name: str) -> Tuple[Optional[Tuple[float, float]], Dict[str, float]]:
"""
Convert a ParsedPrior from the common module to nessai-compatible format.
Returns
-------
bounds : tuple or None
(min, max) bounds for the parameter
params : dict
Parameter-specific values (min, max for uniform, a, b for loguniform, etc.)
"""
if parsed.dist_type == 'uniform':
return parsed.bounds, {"min": parsed.bounds[0], "max": parsed.bounds[1]}
elif parsed.dist_type == 'loguniform':
return parsed.bounds, {"a": parsed.bounds[0], "b": parsed.bounds[1]}
elif parsed.dist_type == 'normal':
return parsed.bounds, {"mean": parsed.mean, "sigma": parsed.sigma}
elif parsed.dist_type == 'fixed':
return None, {"value": parsed.value}
elif parsed.dist_type == 'callable':
return parsed.bounds, {}
else:
raise PriorParsingError(f"Unsupported prior type '{parsed.dist_type}' for '{name}'")
def _split_priors_nessai(prior: Mapping[ParamName, PriorSpec]):
"""
Split priors into sampled and fixed parameters for nessai.
This is a nessai-specific wrapper around the common _split_priors function.
It handles nessai's requirement for explicit bounds on normal priors.
Returns
-------
sampled_names : list[str]
fixed : dict[str, float]
bounds : dict[str, tuple[float, float]]
logprior_fns : dict[str, Callable[[float], float]]
"""
# Use the common splitting function
sampled_names, fixed, bounds, logprior_fns = _split_priors(prior)
# Nessai-specific: check that all sampled parameters have finite bounds
for name in sampled_names:
if bounds[name] == (-np.inf, np.inf):
raise PriorParsingError(
f"Nessai requires finite bounds for all parameters. "
f"Parameter '{name}' has infinite bounds. "
f"For normal priors, specify bounds explicitly."
)
return sampled_names, fixed, bounds, logprior_fns
# --------------------------- nessai Model ------------------------------- #
[docs]
class DiscoveryNessaiModel(NessaiModel):
"""A nessai `Model` that wraps a Discovery-style model.
Parameters
----------
names : list[str]
Names of *sampled* parameters (fixed parameters are injected internally).
bounds : dict[str, tuple[float, float]]
Sampling bounds for each sampled parameter.
logprior_fns : dict[str, Callable[[float], float]]
Per-parameter log-prior functions.
fixed_params : dict[str, float]
Parameters that are not sampled.
discovery_adapter : LikelihoodWrapper
Adapter to call the Discovery model with a parameter dict.
"""
[docs]
def __init__(self,
names: List[str],
bounds: Dict[str, Tuple[float, float]],
logprior_fns: Dict[str, Callable[[float], float]],
fixed_params: Dict[str, float],
discovery_adapter: LikelihoodWrapper):
super().__init__()
self.names = list(names)
self._names_tuple = tuple(self.names) # optional: faster internal iteration
self.bounds = dict(bounds)
self._logprior_fns = dict(logprior_fns)
self._fixed = dict(fixed_params)
self._adapter = discovery_adapter
# Allow non-deterministic likelihoods (disable verification check)
self.allow_multi_valued_likelihood = True
# Keep a fixed column order for packed matrices
self._all_names = self._names_tuple + tuple(self._fixed.keys())
if hasattr(self._adapter, "configure_array_api"):
self._adapter.configure_array_api(self._all_names)
[docs]
def log_prior(self, x: np.ndarray) -> np.ndarray:
xb = _as_batch_struct(x) # shape (N,)
N = xb.shape[0]
total = np.zeros(N, dtype=float)
for n in self.names:
fn = self._logprior_fns[n]
vals = xb[n] # shape (N,)
# Evaluate per-sample; robust if fn is scalar-only
contrib = np.array([float(fn(float(v))) for v in vals], dtype=float)
total += contrib
# Replace any non-finite totals with -inf (per-sample)
total[~np.isfinite(total)] = -np.inf
# Preserve scalar-like return for N=1 while keeping ndarray type
return total if N > 1 else np.array(total[0])
[docs]
def log_likelihood(self, x: np.ndarray) -> np.ndarray:
xb = _as_batch_struct(x) # shape (N,)
N = xb.shape[0]
# Build one dense host array (N, D_s)
cols = [xb[n].astype(np.float64, copy=False) for n in self._names_tuple]
X = np.column_stack(cols) # (N, D_s)
# Append fixed params on host
if self._fixed:
fvec = np.array([self._fixed[k] for k in self._fixed], dtype=np.float64) # (D_f,)
if fvec.size:
F = np.broadcast_to(fvec, (N, fvec.size)) # (N, D_f)
X = np.concatenate([X, F], axis=1) # (N, D)
# One device put, zero dicts:
Xj = jnp.asarray(X, dtype=jnp.float64)
return self._adapter.log_likelihood_matrix(Xj)
# nessai calls these with a *structured* numpy array `x`
def _log_prior(self, x: np.ndarray) -> float:
total = 0.0
for n in self.names:
total += float(self._logprior_fns[n](float(x[n])))
if not np.isfinite(total):
return -np.inf
return np.array(total)
def _log_likelihood(self, x: np.ndarray) -> float:
# Scalar path: pack one row -> array -> compiled row call (no dicts here either)
vals = [float(x[n]) for n in self._names_tuple]
if self._fixed:
vals.extend(float(self._fixed[k]) for k in self._fixed)
row = jnp.asarray(np.asarray(vals, dtype=np.float64))
return self._adapter.log_likelihood_row(row)
# ------------------------- Public bridge class ------------------------- #
[docs]
class DiscoveryNessaiBridge:
"""Bridge between a Discovery-style model and `nessai`'s FlowSampler.
Parameters
----------
discovery_model : callable | object
A callable or object with `log_prob` or `log_likelihood`.
priors : Mapping[str, PriorSpec]
Dictionary describing priors (see module docstring). Includes fixed
parameters.
labels : Optional[Mapping[str, str]]
Optional display labels per parameter (not used internally yet).
Attributes
----------
sampled_names : list[str]
fixed_params : dict[str, float]
bounds : dict[str, tuple[float, float]]
model : DiscoveryNessaiModel
"""
[docs]
def __init__(self,
discovery_model: Any,
priors: Mapping[str, PriorSpec],
latex_labels: Optional[Mapping[str, str]] = None,
jit: bool = True):
self.adapter = LikelihoodWrapper(discovery_model, jit=jit, fixed_params=None, allow_array_api=True)
snames, fixed, bounds, lpfns = _split_priors_nessai(priors)
if not snames:
raise ValueError("No sampled parameters defined (all fixed?)")
self.sampled_names = snames
self.fixed_params = fixed
self.bounds = bounds
self.latex_labels = dict(latex_labels) if latex_labels else {n: n for n in snames}
# Keep original order from the priors dict for "all params"
self.discovery_paramnames = list(priors.keys())
self.fixed_names = list(self.fixed_params.keys())
# Build label lists
# self.latex_labels already exists in your class; ensure it's a mapping name->label
# If a label is missing, fall back to the name.
self.latex_list = [self.latex_labels.get(n, n) for n in self.discovery_paramnames]
self.sampled_names_latex = [self.latex_labels.get(n, n) for n in self.sampled_names]
self.fixed_names_latex = [self.latex_labels.get(n, n) for n in self.fixed_names]
self.model = DiscoveryNessaiModel(
names=self.sampled_names,
bounds=self.bounds,
logprior_fns=lpfns,
fixed_params=self.fixed_params,
discovery_adapter=self.adapter
)
# Convenience helpers -------------------------------------------------
[docs]
def dict_to_livepoint(self, d: Mapping[str, float]) -> np.ndarray:
"""Convert a parameter dict to a nessai live point (structured array)."""
dt = [(n, "f8") for n in self.sampled_names] + [("logP", "f8"), ("logL", "f8")]
lp = np.zeros((), dtype=dt)
for n in self.sampled_names:
lp[n] = float(d[n])
lp["logP"] = 0.0
lp["logL"] = 0.0
return lp
[docs]
def livepoint_to_dict(self, x: np.ndarray) -> Dict[str, float]:
return {n: float(x[n]) for n in self.sampled_names}
# Running the sampler -------------------------------------------------
[docs]
def run_sampler(self, *,
nlive: int = 1000,
output: str = "./nessai_out",
resume: bool = False,
**kwargs: Any) -> Any:
"""Run `nessai.FlowSampler` with this model.
Parameters
----------
nlive : int
Number of live points.
output : str
Output directory for nessai (checkpoints, samples, plots).
max_iterations : Optional[int]
Maximum number of iterations. If `None`, uses nessai default.
seed : Optional[int]
Random seed passed to the sampler.
resume : bool
Resume from previous run if possible.
**kwargs : Any
Forwarded directly to `FlowSampler`.
Returns
-------
The object returned by `FlowSampler.run()`, typically a results
dictionary including posterior samples and evidences.
"""
if FlowSampler is None:
raise RuntimeError("nessai is not installed. Please `pip install nessai`." )
self.sampler = FlowSampler(
self.model,
output=output,
nlive=nlive,
resume=resume,
**kwargs,
)
self.results = self.sampler.run()
# If run() returns None (newer nessai versions), construct results dict from sampler state
if self.results is None and hasattr(self.sampler, 'ns') and hasattr(self.sampler.ns, 'state'):
state = self.sampler.ns.state
self.results = {
'logZ': state.log_evidence if hasattr(state, 'log_evidence') else None,
'logZ_err': state.log_evidence_error if hasattr(state, 'log_evidence_error') else None,
'nested_samples': getattr(self.sampler.ns, 'nested_samples', None),
'posterior_samples': getattr(self.sampler, 'posterior_samples', None),
}
return self.results
[docs]
def return_logZ(self, *, results: Optional[Mapping[str, Any]] = None) -> Dict[str, float]:
"""
Return the log evidence and its uncertainty from nested sampling.
Parameters
----------
results : dict, optional
Results dict from run_sampler(). If None, uses stored results.
Returns
-------
dict
Dictionary containing:
- 'logZ': float - the log evidence estimate
- 'logZ_err': float - uncertainty on logZ
Raises
------
RuntimeError
If no results are available (run_sampler not called)
"""
res = results if results is not None else self.results
if res is None:
raise RuntimeError("No results available. Run `run_sampler()` first.")
# Try to extract logZ from results dict
logZ = None
logZ_err = None
# Check common keys in results dict
for key in ('logZ', 'log_evidence', 'log_Z', 'evidence'):
if key in res:
logZ = float(res[key])
break
for key in ('logZ_err', 'log_evidence_error', 'log_Z_err', 'evidence_error', 'logZerr'):
if key in res:
logZ_err = float(res[key])
break
# If not found in results, try sampler state
if logZ is None and self.sampler is not None:
if hasattr(self.sampler, 'ns') and hasattr(self.sampler.ns, 'state'):
state = self.sampler.ns.state
if hasattr(state, 'log_evidence'):
logZ = float(state.log_evidence)
if hasattr(state, 'log_evidence_error'):
logZ_err = float(state.log_evidence_error)
if logZ is None:
raise RuntimeError("Could not find log evidence in results. Check that sampling completed successfully.")
return {'logZ': logZ, 'logZ_err': logZ_err}
# ------------------------ Results & samples ------------------------ #
def _posterior_struct_array(self, results: Optional[Mapping[str, Any]] = None) -> np.ndarray:
"""
Return a *structured* numpy array of posterior samples (one row per sample)
with fields that include all `self.sampled_names` (plus possibly logL/logP/weights).
Tries several common locations and errors with guidance if not found.
"""
# 1) explicitly provided results dict
if results is not None:
for k in ("posterior_samples", "samples", "posterior"):
if k in results and isinstance(results[k], np.ndarray) and results[k].dtype.names:
return results[k]
# 2) whatever run() returned last time
if self.results is not None:
for k in ("posterior_samples", "samples", "posterior"):
v = self.results.get(k) if isinstance(self.results, Mapping) else None
if isinstance(v, np.ndarray) and v.dtype.names:
return v
# 3) look on the sampler object (common on recent nessai)
if self.sampler is not None:
for attr in ("posterior_samples", "posterior", "samples", "samples_posterior"):
v = getattr(self.sampler, attr, None)
if isinstance(v, np.ndarray) and v.dtype.names:
return v
# 4) fall back to the output directory (best effort; optional)
out = getattr(self.sampler, "output", None)
if isinstance(out, str) and os.path.isdir(out):
for fname in ("posterior_samples.npy", "posterior.npy", "samples_post.npy"):
path = os.path.join(out, fname)
if os.path.exists(path):
arr = np.load(path, allow_pickle=False)
if isinstance(arr, np.ndarray) and arr.dtype.names:
return arr
raise RuntimeError(
"Could not locate posterior samples. "
"Run `run_sampler()` first, or pass the results dict returned by `.run_sampler()`."
)
def _stack_columns(self, struct: np.ndarray, names: Iterable[str]) -> np.ndarray:
"""Return a (nsamples, len(names)) float array by selecting fields from a structured array."""
cols = [np.asarray(struct[n], dtype=float).reshape(-1) for n in names]
return np.stack(cols, axis=1) if cols else np.empty((len(struct), 0), dtype=float)
# --------------------------- Public API ---------------------------- #
[docs]
def return_sampled_samples(self, *, results: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]:
"""
Returns sampled parameters only.
Returns
-------
dict with keys:
- 'names' : list[str]
- 'labels': list[str] (LaTeX or names)
- 'chain' : ndarray (nsamples, n_sampled)
"""
struct = self._posterior_struct_array(results)
chain = self._stack_columns(struct, self.sampled_names)
return {"names": self.sampled_names, "labels": self.sampled_names_latex, "chain": chain}
[docs]
def return_all_samples(self, *, results: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]:
"""
Returns sampled + fixed parameters, arranged in the original `priors` order.
Fixed parameters are filled with their constant values.
"""
struct = self._posterior_struct_array(results)
ns = struct.shape[0]
chain_all = np.zeros((ns, len(self.discovery_paramnames)), dtype=float)
# Fill sampled
sampled_cols = self._stack_columns(struct, self.sampled_names)
for j, name in enumerate(self.sampled_names):
idx = self.discovery_paramnames.index(name)
chain_all[:, idx] = sampled_cols[:, j]
# Fill fixed
for name in self.fixed_names:
idx = self.discovery_paramnames.index(name)
chain_all[:, idx] = float(self.fixed_params[name])
return {"names": self.discovery_paramnames, "labels": self.latex_list, "chain": chain_all}
# ------------------------------ Plots ------------------------------ #
[docs]
def plot_trace(self, *, burn: int = 0, plot_fixed: bool = False,
results: Optional[Mapping[str, Any]] = None, **kwargs):
"""
Plot trace of samples vs sample index.
Parameters
----------
burn : int, optional
Number of initial samples to discard, by default 0.
plot_fixed : bool, optional
If True, includes fixed parameters in the plot, by default False.
results : dict, optional
Results dict from run_sampler(). If None, uses stored results.
**kwargs
Additional keyword arguments passed to plots.plot_trace().
Returns
-------
matplotlib.figure.Figure
Figure containing the trace plots.
"""
from .plots import plot_trace
data = self.return_all_samples(results=results) if plot_fixed else self.return_sampled_samples(results=results)
return plot_trace(
data,
burn=burn,
fixed_params=self.fixed_params,
fixed_names=self.fixed_names,
**kwargs
)
[docs]
def plot_corner(self, *, burn: int = 0, results: Optional[Mapping[str, Any]] = None, **kwargs):
"""
Corner plot of sampled parameters.
Parameters
----------
burn : int, optional
Number of initial samples to discard, by default 0.
results : dict, optional
Results dict from run_sampler(). If None, uses stored results.
**kwargs
Additional keyword arguments passed to corner.corner().
Returns
-------
matplotlib.figure.Figure
Corner plot figure.
"""
from .plots import plot_corner
data = self.return_sampled_samples(results=results)
return plot_corner(data, burn=burn, **kwargs)
__all__ = [
"DiscoveryNessaiBridge",
"DiscoveryNessaiModel",
]