JAX-NS Nested Sampler
=====================
Interface to `JAX-NS `_, a pure-JAX nested sampler. Use it for JAX-native models, vectorized likelihoods, and GPU/TPU acceleration.
Minimal Run
-----------
.. code-block:: python
import jax
import jax.numpy as jnp
from discoverysamplers.jaxns_interface import DiscoveryJAXNSBridge
jax.config.update("jax_enable_x64", True)
def my_model(params):
x, y = params['x'], params['y']
return -0.5 * (jnp.square(x) + jnp.square(y))
priors = {'x': ('uniform', -5.0, 5.0), 'y': ('uniform', -5.0, 5.0)}
# Create the bridge (accepts callable or object with .logL attribute)
bridge = DiscoveryJAXNSBridge(
discovery_model=my_model,
priors=priors,
latex_labels={'x': r'$x$', 'y': r'$y$'},
jit=True,
)
results = bridge.run_sampler(
nlive=800,
max_samples=10000,
termination_frac=0.01,
rng_seed=42,
)
print(f"logZ = {results['logZ']} ± {results['logZerr']}")
Using with Discovery Likelihoods
--------------------------------
.. code-block:: python
import discovery as ds
# Create Discovery likelihood
psr = ds.Pulsar.read_feather('path/to/pulsar.feather')
likelihood = ds.PulsarLikelihood([...])
# Pass the likelihood object directly
bridge = DiscoveryJAXNSBridge(
discovery_model=likelihood, # or likelihood.logL - both work
priors=priors,
jit=True
)
Vectorized Evaluation
---------------------
Enable batching when your model supports array inputs:
.. code-block:: python
bridge.configure_array_api(order=['x', 'y']) # parameter order for arrays
results = bridge.run_sampler(nlive=800, max_samples=8000)
Key Options
-----------
- ``nlive``: accuracy vs. cost (start 500–1000).
- ``max_samples``: hard cap; set high enough to reach termination.
- ``termination_frac``: smaller = more accurate evidence.
- ``jit``: keep True for JAX models; disable for pure NumPy.
- Priors must be uniform/loguniform/normal/fixed; callable priors are not supported by the JAX-NS bridge.
Reading Results
---------------
.. code-block:: python
samples = results['samples']
weights = results['weights']
x_mean = jnp.average(samples['x'], weights=weights)
x_std = jnp.sqrt(jnp.average(jnp.square(samples['x'] - x_mean), weights=weights))
Tips
----
- Always enable x64 for nested sampling precision.
- Keep priors bounded when possible for stable transforms.
- Monitor ``results['ESS']`` and ``results['logZerr']`` to judge run quality.
- Use GPU automatically if available (``jax.devices()``).
See Also
--------
- :doc:`../api/jaxns_interface` - API reference
- :doc:`../advanced/performance` - Performance optimization
- `JAX documentation `_ - JAX ecosystem docs