JAX-NS Nested Sampler
Interface to JAX-NS, a pure-JAX nested sampler. Use it for JAX-native models, vectorized likelihoods, and GPU/TPU acceleration.
Minimal Run
```python
import jax
import jax.numpy as jnp
from discoverysamplers.jaxns_interface import DiscoveryJAXNSBridge

# Enable double precision before any arrays are created
jax.config.update("jax_enable_x64", True)

def my_model(params):
    """Log-likelihood of a 2D standard normal (up to an additive constant)."""
    x, y = params['x'], params['y']
    return -0.5 * (jnp.square(x) + jnp.square(y))

priors = {'x': ('uniform', -5.0, 5.0), 'y': ('uniform', -5.0, 5.0)}

# Create the bridge (accepts a callable or an object with a .logL attribute)
bridge = DiscoveryJAXNSBridge(
    discovery_model=my_model,
    priors=priors,
    latex_labels={'x': r'$x$', 'y': r'$y$'},
    jit=True,
)

results = bridge.run_sampler(
    nlive=800,
    max_samples=10000,
    termination_frac=0.01,
    rng_seed=42,
)
print(f"logZ = {results['logZ']} ± {results['logZerr']}")
```
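Because this toy likelihood is Gaussian, its evidence has a closed form that makes a handy sanity check for the sampler's logZ. With the uniform prior on [-5, 5]² (density 1/100), Z = (1/100)·[∫ exp(-t²/2) dt]², where each 1D integral over [-5, 5] equals √(2π)·erf(5/√2). The sketch below evaluates this with the standard library only; it does not call the bridge, and the function name is illustrative.

```python
import math

def analytic_log_evidence(half_width=5.0, ndim=2):
    """log Z for L = exp(-0.5 * |theta|^2) under a uniform prior on [-w, w]^ndim."""
    # 1D integral of exp(-t^2 / 2) over [-w, w]
    one_dim = math.sqrt(2.0 * math.pi) * math.erf(half_width / math.sqrt(2.0))
    prior_volume = (2.0 * half_width) ** ndim
    return ndim * math.log(one_dim) - math.log(prior_volume)

log_z = analytic_log_evidence()
print(f"analytic logZ = {log_z:.4f}")  # compare with results['logZ']
```

For the Minimal Run settings this gives logZ ≈ -2.767, so a healthy run should agree with it to within a few results['logZerr'].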
Using with Discovery Likelihoods
```python
import discovery as ds

# Load a pulsar and build a Discovery likelihood
psr = ds.Pulsar.read_feather('path/to/pulsar.feather')
likelihood = ds.PulsarLikelihood([...])

# Pass the likelihood object directly
bridge = DiscoveryJAXNSBridge(
    discovery_model=likelihood,  # or likelihood.logL; both work
    priors=priors,               # a priors dict covering this likelihood's parameters
    jit=True,
)
```
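The bridge accepts either a plain callable or an object exposing a .logL attribute. A minimal sketch of that kind of dispatch follows; resolve_loglike and FakeLikelihood are hypothetical names for illustration, not part of discoverysamplers or Discovery, and the bridge's real logic may differ.

```python
def resolve_loglike(model):
    """Return a log-likelihood callable from a callable or an object with .logL (hypothetical)."""
    logl = getattr(model, "logL", None)
    if callable(logl):
        return logl          # object exposing .logL (a method, or a property returning a function)
    if callable(model):
        return model         # plain function of a parameter dict
    raise TypeError("model must be callable or expose a callable .logL attribute")

class FakeLikelihood:
    """Stand-in for a Discovery likelihood object (hypothetical)."""
    def logL(self, params):
        return -0.5 * params["x"] ** 2

def plain_model(params):
    return -0.5 * params["x"] ** 2

for m in (FakeLikelihood(), plain_model):
    print(resolve_loglike(m)({"x": 2.0}))  # -2.0 in both cases
```

Checking for .logL first means an object that is both callable and carries a .logL attribute is evaluated through .logL.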
Vectorized Evaluation
Enable batching when your model supports array inputs:
```python
bridge.configure_array_api(order=['x', 'y'])  # parameter order for array inputs
results = bridge.run_sampler(nlive=800, max_samples=8000)
```
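The order list fixes which array column holds which parameter. The NumPy sketch below illustrates the idea with a hypothetical batched version of my_model that takes an array of shape (n_points, n_params) whose columns follow the given order; how the bridge actually packs arrays internally is not shown here and may differ. For a JAX model written against a parameter dict, jax.vmap(my_model) applied to a dict of 1D arrays gives a batched version without rewriting the function.

```python
import numpy as np

ORDER = ["x", "y"]  # plays the role of configure_array_api(order=...)

def dicts_to_array(param_dicts, order):
    """Stack parameter dicts into an (n_points, n_params) array with columns in `order`."""
    return np.array([[p[name] for name in order] for p in param_dicts], dtype=float)

def my_model_batched(theta, order=ORDER):
    """Vectorized log-likelihood returning one value per row of theta."""
    x = theta[:, order.index("x")]
    y = theta[:, order.index("y")]
    return -0.5 * (np.square(x) + np.square(y))

batch = dicts_to_array([{"x": 0.0, "y": 0.0}, {"x": 1.0, "y": 2.0}], ORDER)
print(my_model_batched(batch))  # one value per row: 0.0 and -2.5
```

If the order used to build the array disagrees with the order the model reads, parameters are silently swapped, which is why the order is declared explicitly.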
Key Options
- nlive: trades accuracy against cost (start with 500–1000).
- max_samples: hard cap on the number of samples; set it high enough for the run to reach termination.
- termination_frac: smaller values give a more accurate evidence estimate.
- jit: keep True for JAX models; disable it for pure NumPy models.
- Priors must be uniform, loguniform, normal or fixed; callable priors are not supported by the JAX-NS bridge.
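Nested samplers commonly draw points from the unit hypercube and map each coordinate through the prior's inverse CDF. The sketch below shows that mapping for the four supported prior kinds; it is illustrative rather than the bridge's actual code, the tuple layouts for loguniform, normal and fixed are assumptions patterned on the ('uniform', low, high) example, and the function names are hypothetical.

```python
import math
from statistics import NormalDist

def transform_one(u, spec):
    """Map u in the open interval (0, 1) to a parameter value (hypothetical spec tuples)."""
    kind = spec[0]
    if kind == "uniform":          # ('uniform', low, high)
        low, high = spec[1], spec[2]
        return low + u * (high - low)
    if kind == "loguniform":       # ('loguniform', low, high), both > 0
        low, high = math.log(spec[1]), math.log(spec[2])
        return math.exp(low + u * (high - low))
    if kind == "normal":           # ('normal', mean, sigma)
        return NormalDist(spec[1], spec[2]).inv_cdf(u)
    raise ValueError(f"unsupported prior kind: {kind!r}")

def transform_cube(cube, priors):
    """Map a unit-cube point to a parameter dict; fixed parameters use no cube dimension."""
    params, i = {}, 0
    for name, spec in priors.items():
        if spec[0] == "fixed":     # ('fixed', value)
            params[name] = spec[1]
        else:
            params[name] = transform_one(cube[i], spec)
            i += 1
    return params

priors = {'x': ('uniform', -5.0, 5.0), 'y': ('uniform', -5.0, 5.0)}
print(transform_cube([0.5, 0.75], priors))  # {'x': 0.0, 'y': 2.5}
```

Bounded priors such as uniform and loguniform map the whole cube onto a finite range, while a normal prior's inverse CDF diverges as u approaches 0 or 1.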
Reading Results
```python
samples = results['samples']
weights = results['weights']

# Weighted posterior mean and standard deviation of x
x_mean = jnp.average(samples['x'], weights=weights)
x_std = jnp.sqrt(jnp.average(jnp.square(samples['x'] - x_mean), weights=weights))
```
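Many plotting and post-processing tools expect equally weighted posterior samples. One common approach is to resample with probabilities proportional to the weights. The NumPy sketch below assumes, as the weighted averages above do, that weights is a 1D array of non-negative weights aligned with each samples[...] array; resample_equal is a hypothetical helper.

```python
import numpy as np

def resample_equal(samples, weights, n_draws=None, seed=0):
    """Draw equally weighted samples from weighted nested-sampling output."""
    w = np.asarray(weights, dtype=float)
    p = w / w.sum()                                    # normalize to probabilities
    n = len(w) if n_draws is None else n_draws
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(w), size=n, replace=True, p=p)
    return {name: np.asarray(vals)[idx] for name, vals in samples.items()}
```

For example, equal = resample_equal(results['samples'], results['weights']) followed by np.percentile(equal['x'], [16, 50, 84]) gives a median and a central 68% interval for x.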
Tips
- Always enable x64; nested sampling needs double precision for accurate evidence estimates.
- Keep priors bounded when possible for stable transforms.
- Monitor results['ESS'] and results['logZerr'] to judge run quality.
- JAX runs on a GPU automatically when one is available; check which devices it sees with jax.devices().
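Two standard diagnostics help interpret those numbers. Kish's effective sample size, (Σw)²/Σw², counts how many equally weighted samples a weighted set is worth; whether the bridge's results['ESS'] uses exactly this definition is an assumption. Skilling's rough estimate σ(logZ) ≈ √(H/nlive), where H is the information (KL divergence from prior to posterior) in nats, shows why quadrupling nlive roughly halves the logZ error. The helper names below are hypothetical.

```python
import numpy as np

def kish_ess(weights):
    """Kish effective sample size of a set of non-negative weights."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / np.square(w).sum()

def expected_logz_error(information_nats, nlive):
    """Skilling's rough statistical error on log Z: sqrt(H / nlive)."""
    return np.sqrt(information_nats / nlive)

print(kish_ess(np.ones(1000)))         # equal weights: ESS equals the sample count
print(kish_ess([1.0, 0.0, 0.0, 0.0]))  # one sample carries all the weight: ESS is 1
# H is about 1.8 nats for the Minimal Run example; compare nlive=800 with nlive=3200
print(expected_logz_error(1.8, 800), expected_logz_error(1.8, 3200))
```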
See Also
JAX-NS Interface - API reference
Performance Optimization - Performance optimization
JAX documentation - JAX ecosystem docs