JAX-NS Interface

This module provides the bridge interface connecting Discovery models to the JAX-NS nested sampler. JAX-NS is a pure JAX implementation that supports GPU acceleration and vectorized likelihood evaluation.

Hint

For best performance with JAX-NS, make sure your likelihood is JAX-compatible (traceable by JAX transformations) and keep JIT compilation enabled with jit=True, which is the default.


class discoverysamplers.jaxns_interface.DiscoveryJAXNSBridge(discovery_model, priors, latex_labels=None, jit=True)

Bases: object

Bridge between a Discovery-style model and JAX-NS NestedSampler.

Mirrors the DiscoveryNessaiBridge API where possible.

Parameters:
  • discovery_model (callable | object) – Either a callable that maps a parameter dict to a log-likelihood, or an object with a .logL(params_dict) -> float method.

  • priors (Mapping[str, PriorSpec]) – Prior specification for each parameter, keyed by parameter name; uses the same schema as the nessai bridge (DiscoveryNessaiBridge).

  • latex_labels (Optional[Mapping[str, str]]) – Optional LaTeX labels keyed by parameter name, used for plotting and exports.

  • jit (bool) – If True (the default), JIT-compile the Discovery likelihood for faster repeated calls.
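The discovery_model argument accepts two shapes: a plain callable, or an object exposing .logL(params_dict). A minimal sketch of both, assuming a toy 1-D Gaussian likelihood in a parameter named "x", plus a hypothetical helper as_loglike showing one way a bridge could normalize either shape to a single callable. ToyGaussianModel, toy_logl and as_loglike are illustrative names, not part of discoverysamplers.

```python
import math
from typing import Any, Callable, Mapping


class ToyGaussianModel:
    """Hypothetical model object exposing the .logL(params_dict) -> float interface."""

    def __init__(self, mean: float = 0.0, sigma: float = 1.0):
        self.mean = mean
        self.sigma = sigma

    def logL(self, params: Mapping[str, float]) -> float:
        # Normalized 1-D Gaussian log-likelihood in the parameter "x".
        z = (params["x"] - self.mean) / self.sigma
        return -0.5 * z * z - math.log(self.sigma * math.sqrt(2.0 * math.pi))


def toy_logl(params: Mapping[str, float]) -> float:
    """The same likelihood written as a plain callable."""
    return ToyGaussianModel().logL(params)


def as_loglike(model: Any) -> Callable[[Mapping[str, float]], float]:
    """Hypothetical helper: return a callable params_dict -> float for either shape."""
    # Prefer an explicit .logL method; otherwise accept any plain callable.
    if hasattr(model, "logL"):
        return model.logL
    if callable(model):
        return model
    raise TypeError("discovery_model must be callable or expose .logL(params_dict)")
```

Checking for .logL before callable() means an object that happens to define both is treated as a model object, which matches the parameter description above.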

__init__(discovery_model, priors, latex_labels=None, jit=True)
run_sampler(*, nlive=800, max_samples=None, termination_frac=0.01, rng_seed=None, sampler_kwargs=None)

Run JAX-NS NestedSampler.

Parameters:
  • nlive (int) – Number of live points.

  • max_samples (int | None) – Optional hard cap on the number of samples; passed through to JAX-NS only when the installed version supports it.

  • termination_frac (float) – Fractional evidence tolerance used as the termination criterion; its exact meaning depends on the JAX-NS version.

  • rng_seed (int | None) – Seed for the JAX PRNG key.

  • sampler_kwargs (dict | None) – Extra kwargs forwarded to NestedSampler (e.g., sampler implementation).

Returns:

results – JAX-NS results object (version-dependent structure).

Return type:

object
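To make nlive, termination_frac and the logZ/logZ_err pair concrete, here is a toy 1-D nested sampler in pure Python. It is not JAX-NS and does not reflect how JAX-NS proposes new points: it uses naive rejection sampling from the prior and Python's random module in place of a JAX PRNG. The stopping rule (halt once the live points could contribute at most termination_frac of the current evidence) is one common form of an evidence-tolerance criterion, assumed here for illustration; the rule JAX-NS applies is version dependent.

```python
import math
import random
from typing import Callable, Dict, Optional


def log_add_exp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def toy_nested_sampling(log_like: Callable[[float], float],
                        prior_transform: Callable[[float], float],
                        nlive: int = 100,
                        termination_frac: float = 0.01,
                        rng_seed: Optional[int] = None,
                        max_iter: int = 100_000) -> Dict[str, float]:
    """Toy 1-D nested sampler returning {'logZ', 'logZ_err'}; NOT JAX-NS."""
    rng = random.Random(rng_seed)
    # Initial live points: (logL, theta) pairs drawn from the prior.
    live = []
    for _ in range(nlive):
        theta = prior_transform(rng.random())
        live.append((log_like(theta), theta))

    log_z, h, log_x = -math.inf, 0.0, 0.0  # log evidence, information H, log prior volume

    def accumulate(log_w: float, log_l: float) -> None:
        # Add one weighted point to Z and update the information H incrementally.
        nonlocal log_z, h
        log_z_new = log_add_exp(log_z, log_w)
        old = math.exp(log_z - log_z_new) * (h + log_z) if log_z > -math.inf else 0.0
        h = math.exp(log_w - log_z_new) * log_l + old - log_z_new
        log_z = log_z_new

    for _ in range(max_iter):
        live.sort(key=lambda p: p[0])
        # Stop once the live points could add at most termination_frac of Z.
        if live[-1][0] + log_x < math.log(termination_frac) + log_z:
            break
        log_l_min = live[0][0]
        # Expected shrinkage X_{i+1} = X_i * exp(-1/nlive); weight = L_min * (X_i - X_{i+1}).
        log_dx = log_x + math.log1p(-math.exp(-1.0 / nlive))
        accumulate(log_l_min + log_dx, log_l_min)
        log_x -= 1.0 / nlive
        # Replace the worst point by rejection sampling from the prior above L_min.
        while True:
            theta = prior_transform(rng.random())
            log_l = log_like(theta)
            if log_l > log_l_min:
                live[0] = (log_l, theta)
                break

    # Remaining live points share the final prior volume equally.
    for log_l, _ in live:
        accumulate(log_l + log_x - math.log(nlive), log_l)
    return {"logZ": log_z, "logZ_err": math.sqrt(max(h, 0.0) / nlive)}
```

The uncertainty sqrt(H / nlive) is the standard nested-sampling estimate, which is why increasing nlive tightens logZ_err at the cost of more likelihood calls. For a unit Gaussian likelihood under a uniform prior on [-5, 5], the analytic evidence is very close to 0.1, so a call such as toy_nested_sampling(lambda x: -0.5 * x * x - 0.5 * math.log(2 * math.pi), lambda u: 10 * u - 5, rng_seed=0) should land near log(0.1) within a few logZ_err.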

return_sampled_samples(*, results=None)
Return type:

Dict[str, Any]

return_all_samples(*, results=None)
Return type:

Dict[str, Any]

return_logZ(*, results=None)

Return the log evidence and its uncertainty from nested sampling.

Parameters:

results (optional) – Results returned by run_sampler(). If None, uses the stored results.

Returns:

Dictionary containing:
  • ‘logZ’ (float) – the log evidence estimate.
  • ‘logZ_err’ (float) – the uncertainty on logZ.

Return type:

dict

Raises:

RuntimeError – If no results are available, i.e. run_sampler() has not been called.
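A typical use of the logZ/logZ_err pair is model comparison through the log Bayes factor, ln B = logZ_1 - logZ_2. A minimal sketch, assuming two dicts shaped like the return value of return_logZ() from independent runs, so their uncertainties add in quadrature; log_bayes_factor and the numbers below are hypothetical, not output from a real analysis.

```python
import math
from typing import Dict, Mapping


def log_bayes_factor(res_1: Mapping[str, float], res_2: Mapping[str, float]) -> Dict[str, float]:
    """ln B_12 = logZ_1 - logZ_2, with independent errors added in quadrature."""
    ln_b = res_1["logZ"] - res_2["logZ"]
    err = math.hypot(res_1["logZ_err"], res_2["logZ_err"])
    return {"lnB": ln_b, "lnB_err": err}


# Hypothetical return_logZ() outputs for two competing models.
signal = {"logZ": -1203.4, "logZ_err": 0.3}
noise = {"logZ": -1207.1, "logZ_err": 0.4}
print(log_bayes_factor(signal, noise))
```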

plot_trace(*, burn=0, plot_fixed=False, results=None, **kwargs)

Plot the trace of each parameter's samples against sample index.

Parameters:
  • burn (int, optional) – Number of initial samples to discard, by default 0.

  • plot_fixed (bool, optional) – If True, include fixed parameters in the plot, by default False.

  • results (optional) – Results from run_sampler(). If None, uses stored results.

  • **kwargs – Additional keyword arguments passed to plots.plot_trace().

Returns:

Figure containing the trace plots.

Return type:

matplotlib.figure.Figure
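The burn and plot_fixed options amount to simple filtering before anything is drawn. A minimal sketch, assuming samples are held as a dict mapping parameter name to a sequence in sample order and that fixed parameters are known by name; select_trace_data is a hypothetical helper, not the bridge's internal code.

```python
from typing import Dict, Iterable, List, Mapping, Sequence


def select_trace_data(samples: Mapping[str, Sequence[float]],
                      fixed: Iterable[str] = (),
                      burn: int = 0,
                      plot_fixed: bool = False) -> Dict[str, List[float]]:
    """Drop the first `burn` samples and, unless plot_fixed, the fixed parameters."""
    if burn < 0:
        raise ValueError("burn must be non-negative")
    fixed_names = set(fixed)
    return {
        name: list(values[burn:])  # discard the initial `burn` entries
        for name, values in samples.items()
        if plot_fixed or name not in fixed_names
    }
```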

plot_corner(*, burn=0, results=None, **kwargs)

Make a corner plot of the sampled parameters.

Parameters:
  • burn (int, optional) – Number of initial samples to discard, by default 0.

  • results (optional) – Results from run_sampler(). If None, uses stored results.

  • **kwargs – Additional keyword arguments passed to corner.corner().

Returns:

Corner plot figure.

Return type:

matplotlib.figure.Figure
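The latex_labels mapping from the constructor is described as labels for plotting. One plausible way to turn it into the list accepted by the labels keyword of corner.corner(), falling back to the raw parameter name when no LaTeX label is given, is sketched below; corner_labels is a hypothetical helper and the parameter names are illustrative assumptions.

```python
from typing import List, Mapping, Optional, Sequence


def corner_labels(param_names: Sequence[str],
                  latex_labels: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return one axis label per parameter, preferring LaTeX labels when available."""
    latex_labels = latex_labels or {}
    # Keep the parameter order so labels line up with the sample columns.
    return [latex_labels.get(name, name) for name in param_names]
```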

See Also