JAX-NS Interface
This module provides the bridge interface connecting Discovery models to the JAX-NS nested sampler. JAX-NS is a pure JAX implementation that supports GPU acceleration and vectorized likelihood evaluation.
Hint
For best performance with JAX-NS, make sure your likelihood is JAX-compatible
and leave JIT compilation enabled (jit=True, the default).
Discovery ↔︎ JAX-NS Interface
This module provides a bridge between Discovery-style models and JAX-NS nested sampling.
- class discoverysamplers.jaxns_interface.DiscoveryJAXNSBridge(discovery_model, priors, latex_labels=None, jit=True)[source]
Bases: object
Bridge between a Discovery-style model and the JAX-NS NestedSampler.
Mirrors DiscoveryNessaiBridge API where possible.
- Parameters:
discovery_model (callable | object) – Callable or object with .logL(params_dict) -> float.
priors (Mapping[str, PriorSpec]) – Same schema you use for the nessai bridge.
latex_labels (Optional[Mapping[str, str]]) – Optional labels used for plotting/exports.
jit (bool) – JIT-compile the discovery model for faster likelihood calls.
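As a rough illustration of the discovery_model contract described above (an object exposing .logL(params_dict) -> float), the sketch below defines a toy model. ToyGaussianModel, its data, and the parameter name "mu" are hypothetical and not part of discoverysamplers or Discovery. The priors argument is deliberately left out because the PriorSpec schema is not documented in this section.

```python
import math


class ToyGaussianModel:
    """Hypothetical stand-in for a Discovery-style model (not a library class)."""

    def __init__(self, data, sigma=1.0):
        self.data = list(data)
        self.sigma = sigma
        # Constant normalisation term, computed once outside logL.
        self._log_norm = -0.5 * len(self.data) * math.log(2.0 * math.pi * sigma**2)

    def logL(self, params_dict):
        """Gaussian log-likelihood of the data for an unknown mean 'mu'."""
        mu = params_dict["mu"]
        # Only arithmetic operators touch `mu`, so the same code should also
        # work if `mu` is a JAX array or tracer rather than a Python float.
        chi2 = sum(((x - mu) / self.sigma) ** 2 for x in self.data)
        return self._log_norm - 0.5 * chi2


model = ToyGaussianModel(data=[0.9, 1.1, 1.0], sigma=0.5)
log_l_at_truth = model.logL({"mu": 1.0})  # mean close to the data
log_l_off = model.logL({"mu": 2.0})       # mean far from the data
```

Keeping parameter-dependent work to plain arithmetic (or jax.numpy calls) rather than math functions is one way to stay compatible with jit=True, since math.log and similar functions cannot operate on traced values.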
- run_sampler(*, nlive=800, max_samples=None, termination_frac=0.01, rng_seed=None, sampler_kwargs=None)[source]
Run JAX-NS NestedSampler.
- Parameters:
nlive (int) – Number of live points.
max_samples (int | None) – Optional hard cap on the number of samples (passed through to JAX-NS when supported).
termination_frac (float) – Evidence tolerance fraction at which to terminate (version dependent).
rng_seed (int | None) – Seed for JAX PRNG.
sampler_kwargs (dict | None) – Extra kwargs forwarded to NestedSampler (e.g., sampler implementation).
- Returns:
results – JAX-NS results object (version-dependent structure).
- Return type:
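To give some intuition for termination_frac, here is a minimal sketch of a stopping rule commonly used in nested sampling: stop once the evidence still held by the live points falls below a given fraction of the total. This is an assumption about the general technique, not a description of what JAX-NS does internally, which the parameter description above notes is version dependent. The helper names log_add_exp and should_terminate are hypothetical.

```python
import math


def log_add_exp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    hi, lo = max(a, b), min(a, b)
    if hi == -math.inf:
        return -math.inf
    return hi + math.log1p(math.exp(lo - hi))


def should_terminate(log_z_accumulated, log_z_live, termination_frac=0.01):
    """Hypothetical helper: True once live points hold < termination_frac of Z."""
    log_z_total = log_add_exp(log_z_accumulated, log_z_live)
    live_fraction = math.exp(log_z_live - log_z_total)
    return live_fraction < termination_frac


# Early in a run the live points dominate the evidence, so sampling continues.
early = should_terminate(log_z_accumulated=-12.0, log_z_live=-10.0)
# Later the live points contribute very little, so the rule says stop.
late = should_terminate(log_z_accumulated=-10.0, log_z_live=-16.0)
```

Under this rule a smaller termination_frac means a longer run and a tighter evidence estimate; the default of 0.01 in the signature above would correspond to stopping when about 1% of the evidence remains unaccounted for.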
- return_logZ(*, results=None)[source]
Return the log evidence and its uncertainty from nested sampling.
- Parameters:
results (dict, optional) – Results dict from run_sampler(). If None, uses stored results.
- Returns:
Dictionary containing:
  - 'logZ': float, the log evidence estimate
  - 'logZ_err': float, the uncertainty on logZ
- Return type:
dict
- Raises:
RuntimeError – If no results are available (run_sampler() has not been called).
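One typical use of the returned dictionary is model comparison. The sketch below combines two such dictionaries into a log Bayes factor, assuming the two evidence errors are independent so they add in quadrature. The function log_bayes_factor is a hypothetical helper, and the numbers are placeholders for illustration, not results from any run.

```python
import math


def log_bayes_factor(evidence_a, evidence_b):
    """Return (ln B_ab, uncertainty) from two {'logZ', 'logZ_err'} dicts."""
    delta = evidence_a["logZ"] - evidence_b["logZ"]
    # Assumes independent errors, so they combine in quadrature.
    err = math.hypot(evidence_a["logZ_err"], evidence_b["logZ_err"])
    return delta, err


# Placeholder values standing in for two return_logZ() outputs.
evidence_signal = {"logZ": -120.0, "logZ_err": 0.3}
evidence_noise = {"logZ": -125.0, "logZ_err": 0.4}

log_bf, log_bf_err = log_bayes_factor(evidence_signal, evidence_noise)
```

A positive log Bayes factor favors the first model; comparing it against its uncertainty gives a quick sense of whether the preference is significant or whether more live points are needed.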
- plot_trace(*, burn=0, plot_fixed=False, results=None, **kwargs)[source]
Plot trace of samples vs sample index.
- Parameters:
burn (int, optional) – Number of initial samples to discard, by default 0.
plot_fixed (bool, optional) – If True, includes fixed parameters in the plot, by default False.
results (optional) – Results from run_sampler(). If None, uses stored results.
**kwargs – Additional keyword arguments passed to plots.plot_trace().
- Returns:
Figure containing the trace plots.
- Return type:
matplotlib.figure.Figure
- plot_corner(*, burn=0, results=None, **kwargs)[source]
Corner plot of sampled parameters.
- Parameters:
burn (int, optional) – Number of initial samples to discard, by default 0.
results (optional) – Results from run_sampler(). If None, uses stored results.
**kwargs – Additional keyword arguments passed to corner.corner().
- Returns:
Corner plot figure.
- Return type:
matplotlib.figure.Figure
See Also
JAX-NS Nested Sampler - User guide for JAX-NS
Performance Optimization - Performance optimization