Example Notebooks

This page provides links to example Jupyter notebooks demonstrating various use cases of discoverysamplers.

Available Examples

The examples/ directory contains Jupyter notebooks showcasing different features and samplers:

Quick Start Guide

quick_start_guide.ipynb

Comprehensive introduction to all samplers:

  • Basic model setup

  • Using all four samplers (Nessai, JAX-NS, Eryn, GPry)

  • Parallel tempering with Eryn

  • Results analysis and visualization

# Location
examples/quick_start_guide.ipynb

Reversible Jump MCMC

RJ_MCMC.ipynb

Trans-dimensional sampling with Eryn’s RJMCMC:

  • Toy example with variable number of Gaussian components

  • Discovery PTA example with continuous wave sources

  • Branch-indexed prior specification

  • Component count analysis and plotting
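Component-count analysis usually comes down to tallying how many components are active in each posterior sample. The sketch below is a hedged illustration, not the notebook's code. It assumes you have already extracted a 1-D integer array of per-sample counts (reading them out of the sampler is not shown here). The helper names `component_count_posterior` and `count_bayes_factor` are hypothetical.

```python
import numpy as np

def component_count_posterior(counts, max_k=None):
    """Posterior probability of each component count k = 0, 1, ... from RJMCMC draws.

    `counts` holds the number of active components in each equally weighted
    posterior sample. Pass `max_k` to include trailing counts with zero samples.
    """
    counts = np.asarray(counts, dtype=int)
    minlength = 0 if max_k is None else max_k + 1
    freq = np.bincount(counts, minlength=minlength)
    return freq / freq.sum()

def count_bayes_factor(probs, k1, k2, prior_probs=None):
    """Bayes factor of k1 over k2 components: posterior odds divided by prior odds."""
    post_odds = probs[k1] / probs[k2]
    if prior_probs is None:
        return post_odds  # uniform prior over k: prior odds are 1
    return post_odds / (prior_probs[k1] / prior_probs[k2])

probs = component_count_posterior([1, 2, 2, 3, 2, 1, 2, 2])
print(probs)
print(count_bayes_factor(probs, 2, 1))
```

With a uniform prior over the number of components, the posterior odds equal the Bayes factor. Otherwise the second helper divides by the prior odds.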

# Location
examples/RJ_MCMC.ipynb

Eryn Examples

eryn_example.ipynb

Basic usage of the Eryn MCMC sampler:

  • Setting up a simple model

  • Configuring priors

  • Running MCMC sampling

  • Analyzing chains and convergence

  • Visualizing results
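One widely used convergence check is the Gelman-Rubin statistic R-hat. It compares between-chain and within-chain variance, and values close to 1 (common thresholds are roughly 1.01 to 1.1) suggest the chains agree. This is a hedged sketch, not the notebook's own code. It treats each walker as a chain, which is only approximate for ensemble samplers such as Eryn because walkers are not independent. The function name is hypothetical.

```python
import numpy as np

def gelman_rubin(chains):
    """Potential scale reduction factor R-hat for a single parameter.

    `chains` has shape (nsteps, nchains); here each walker is treated as a chain.
    """
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[0]
    chain_means = chains.mean(axis=0)
    W = chains.var(axis=0, ddof=1).mean()    # mean within-chain variance
    B = n * chain_means.var(ddof=1)          # between-chain variance
    var_hat = (n - 1) / n * W + B / n        # pooled estimate of posterior variance
    return float(np.sqrt(var_hat / W))

# Independent, identically distributed "chains" should give R-hat close to 1
rng = np.random.default_rng(0)
print(gelman_rubin(rng.standard_normal((5000, 8))))
```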

# Location
examples/eryn_example.ipynb

Nessai Examples

nessai_example.ipynb

Nested sampling with Nessai:

  • Flow-based nested sampling

  • Configuring flow parameters

  • Evidence calculation

  • Posterior sampling

  • Model comparison
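For model comparison, the log Bayes factor between two models is the difference of their log evidences, ln B_12 = ln Z_1 - ln Z_2. If the two evidence uncertainties are assumed independent, they add in quadrature. Below is a minimal sketch with a hypothetical helper name. You could pass it the `logZ` and `logZ_err` values printed in the Quick Start snippet. The numbers in the example are illustrative, not results.

```python
import math

def log_bayes_factor(logz_1, logz_1_err, logz_2, logz_2_err):
    """Return ln B_12 = ln Z_1 - ln Z_2 and its uncertainty.

    Errors are combined in quadrature, assuming the two runs are independent.
    """
    return logz_1 - logz_2, math.hypot(logz_1_err, logz_2_err)

# Illustrative inputs only
ln_b, ln_b_err = log_bayes_factor(-10.0, 0.3, -12.0, 0.4)
print(f"ln B = {ln_b:.2f} +/- {ln_b_err:.2f}")  # ln B = 2.00 +/- 0.50
```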

# Location
examples/nessai_example.ipynb

Running the Examples

Prerequisites

Install required dependencies:

# Install discoverysamplers with all samplers
pip install discoverysamplers eryn nessai jaxns gpry cobaya jax

# Install Jupyter
pip install jupyter matplotlib corner

Running Notebooks

# Navigate to examples directory
cd examples/

# Start Jupyter
jupyter notebook

# Open the desired notebook in your browser

Converting to Python Scripts

To run examples as scripts:

# Convert notebook to Python script
jupyter nbconvert --to python eryn_example.ipynb

# Run the script
python eryn_example.py

Example Snippets

Quick Start Example

A minimal working example:

import numpy as np
from discoverysamplers.nessai_interface import DiscoveryNessaiBridge

# Define a simple 2D Gaussian model
def gaussian_model(params):
    x = params['x']
    y = params['y']
    return -0.5 * (x**2 + y**2)

# Define priors
priors = {
    'x': ('uniform', -5, 5),
    'y': ('uniform', -5, 5),
}

# Create bridge
bridge = DiscoveryNessaiBridge(
    discovery_model=gaussian_model,
    priors=priors,
    jit=True
)

# Run sampler
results = bridge.run_sampler(
    nlive=1000,
    output='output/gaussian/',
    resume=False
)

# Print results
print(f"Log evidence: {results['logZ']:.2f} ± {results['logZ_err']:.2f}")
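Because this likelihood is separable, its evidence has a closed form, which gives a useful sanity check on the sampler's estimate. The sketch assumes that `('uniform', -5, 5)` means a uniform prior on [-5, 5], and it uses the unnormalised likelihood exactly as written in `gaussian_model`. The helper name is hypothetical.

```python
import math

def gaussian_logz_uniform_box(ndim, low, high):
    """Analytic ln Z for ln L = -0.5 * sum(x_i**2), with no normalising
    constant, under independent U(low, high) priors on each coordinate."""
    # Per-axis integral of exp(-x**2 / 2) over [low, high]
    integral = math.sqrt(math.pi / 2) * (
        math.erf(high / math.sqrt(2)) - math.erf(low / math.sqrt(2))
    )
    # Each axis contributes integral / (high - low); ln Z sums over axes
    return ndim * (math.log(integral) - math.log(high - low))

print(f"{gaussian_logz_uniform_box(2, -5, 5):.3f}")  # -2.767
```

A reported `logZ` within a few `logZ_err` of about -2.77 is consistent with the analytic value. The same function with `ndim=10` gives about -13.84, which applies to the 10-dimensional example.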

Multimodal Example

Handling multimodal distributions:

import numpy as np
from scipy.special import logsumexp
from discoverysamplers.eryn_interface import DiscoveryErynBridge

# Bimodal likelihood
def bimodal_model(params):
    x = params['x']
    y = params['y']

    # Two modes at (-2, -2) and (2, 2)
    log_L1 = -0.5 * ((x + 2)**2 + (y + 2)**2)
    log_L2 = -0.5 * ((x - 2)**2 + (y - 2)**2)

    return logsumexp([log_L1, log_L2]) - np.log(2)

priors = {
    'x': ('uniform', -6, 6),
    'y': ('uniform', -6, 6),
}

# Use parallel tempering for multimodality
bridge = DiscoveryErynBridge(bimodal_model, priors)

sampler = bridge.create_sampler(
    nwalkers=32,
    tempering_kwargs=dict(ntemps=8, Tmax=20.0)
)

initial = bridge.sample_priors(nwalkers=32, ntemps=8)
sampler.run_mcmc(initial, nsteps=10000)

# Get cold chain samples
samples = sampler.get_chain(discard=1000, flat=True)
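The two modes in `bimodal_model` carry equal weight. A well-mixed tempered run should therefore place roughly half of its cold-chain samples near each mode, and a strongly lopsided split suggests the walkers are stuck. Below is a minimal check with a hypothetical helper `mode_fractions`. It assumes the samples have been arranged as an (N, 2) array with columns (x, y). The exact layout returned by `get_chain` depends on the bridge and is not documented here.

```python
import numpy as np

def mode_fractions(samples):
    """Fractions of (x, y) samples near (-2, -2) and near (2, 2).

    The two modes are separated by the line x + y = 0. `samples` is assumed
    to be an (N, 2) array with columns ordered (x, y).
    """
    samples = np.asarray(samples, dtype=float)
    upper = float(np.mean(samples[:, 0] + samples[:, 1] > 0))
    return 1.0 - upper, upper

# Self-check on synthetic draws from an equal-weight mixture of the two modes
rng = np.random.default_rng(0)
fake = np.vstack([
    rng.normal(-2.0, 1.0, size=(5000, 2)),
    rng.normal(2.0, 1.0, size=(5000, 2)),
])
print(mode_fractions(fake))
```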

High-Dimensional Example

Sampling in higher dimensions with JAX-NS:

import jax.numpy as jnp
from discoverysamplers.jaxns_interface import DiscoveryJAXNSBridge

# 10-dimensional Gaussian
ndim = 10

def high_dim_model(params):
    x = jnp.array([params[f'x{i}'] for i in range(ndim)])
    return -0.5 * jnp.sum(x**2)

priors = {f'x{i}': ('uniform', -5, 5) for i in range(ndim)}

bridge = DiscoveryJAXNSBridge(high_dim_model, priors, jit=True)

# Enable vectorization for speed
bridge.configure_array_api(order=[f'x{i}' for i in range(ndim)])

results = bridge.run_sampler(
    nlive=1000,
    max_samples=20000,
    rng_seed=42
)
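Passing `order` to `configure_array_api` fixes which array position corresponds to which parameter name. This page does not show how the bridge uses that ordering internally. The NumPy sketch below only illustrates the general idea, and its function names are hypothetical. A likelihood written over an array with one column per parameter, in the given order, evaluates a whole batch in one call and agrees with the dictionary-based form used in `high_dim_model`.

```python
import numpy as np

ndim = 10
order = [f"x{i}" for i in range(ndim)]  # column j <-> parameter order[j]

def loglike_from_dict(params):
    # One point at a time, looked up by name (as in high_dim_model)
    x = np.array([params[name] for name in order])
    return -0.5 * np.sum(x**2)

def loglike_from_array(theta):
    # theta: (nbatch, ndim) array, columns ordered as in `order`
    theta = np.atleast_2d(theta)
    return -0.5 * np.sum(theta**2, axis=-1)

rng = np.random.default_rng(0)
theta = rng.uniform(-5, 5, size=(4, ndim))  # four points drawn from the prior box
batch = loglike_from_array(theta)
loop = np.array([loglike_from_dict(dict(zip(order, row))) for row in theta])
print(np.allclose(batch, loop))  # True
```

In JAX, the same batching is typically obtained by applying `jax.vmap` to a single-point array function.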

Realistic Science Example

Fitting a sinusoidal signal:

import numpy as np
import matplotlib.pyplot as plt

# Generate synthetic data
np.random.seed(42)
times = np.linspace(0, 10, 100)

true_params = {
    'amplitude': 2.5,
    'frequency': 1.5,
    'phase': 0.3,
    'noise_std': 0.5,
}

signal = (true_params['amplitude'] *
         np.sin(2*np.pi*true_params['frequency']*times + true_params['phase']))
noise = np.random.normal(0, true_params['noise_std'], len(times))
data = signal + noise

# Define model
def sinusoid_model(params):
    A = params['amplitude']
    f = params['frequency']
    phi = params['phase']
    sigma = params['noise_std']

    model_signal = A * np.sin(2*np.pi*f*times + phi)
    residuals = data - model_signal

    # Gaussian log-likelihood, up to the constant -N/2 * log(2*pi)
    log_L = -0.5 * np.sum((residuals/sigma)**2) - len(data)*np.log(sigma)
    return log_L

# Priors
priors = {
    'amplitude': ('uniform', 0, 5),
    'frequency': ('uniform', 0.1, 3),
    'phase': ('uniform', 0, 2*np.pi),
    'noise_std': ('loguniform', 0.01, 2),
}

# Run sampling
from discoverysamplers.nessai_interface import DiscoveryNessaiBridge

bridge = DiscoveryNessaiBridge(sinusoid_model, priors, jit=False)
results = bridge.run_sampler(nlive=1000, output='output/sinusoid/')

# Analyze results
posterior = results['posterior_samples']
weights = posterior['weights']

for param in ['amplitude', 'frequency', 'phase', 'noise_std']:
    samples = posterior[param]
    mean = np.average(samples, weights=weights)
    std = np.sqrt(np.average((samples - mean)**2, weights=weights))
    true_val = true_params[param]

    print(f"{param}:")
    print(f"  True: {true_val:.3f}")
    print(f"  Estimated: {mean:.3f} ± {std:.3f}")
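The phase is periodic on [0, 2π). A weighted arithmetic mean and standard deviation can mislead whenever posterior mass reaches both ends of that interval, for example when the true phase lies close to 0 or 2π. A common alternative is the weighted circular mean together with the circular standard deviation sqrt(-2 ln R), where R is the mean resultant length. Below is a minimal sketch; the helper name is hypothetical.

```python
import numpy as np

def weighted_circular_stats(angles, weights):
    """Weighted circular mean (in [0, 2*pi)) and circular std of angles in radians."""
    angles = np.asarray(angles, dtype=float)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    s = np.sum(w * np.sin(angles))
    c = np.sum(w * np.cos(angles))
    mean = np.mod(np.arctan2(s, c), 2 * np.pi)
    R = min(np.hypot(s, c), 1.0)  # mean resultant length, clipped against round-off
    std = np.sqrt(-2.0 * np.log(R))
    return mean, std

# Two samples straddling 0: np.mean gives pi, the circular mean is ~0 (mod 2*pi)
m, _ = weighted_circular_stats([0.1, 2 * np.pi - 0.1], [1.0, 1.0])
```

In the loop above, `phase` could use `weighted_circular_stats(posterior['phase'], weights)` in place of the arithmetic mean and standard deviation.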

Visualization Examples

Corner Plots

import corner
import numpy as np

# Get samples
posterior = results['posterior_samples']
samples_array = np.column_stack([
    posterior[name] for name in bridge.sampled_names
])
weights = posterior['weights']

# Create corner plot
fig = corner.corner(
    samples_array,
    weights=weights,
    labels=[bridge.latex_labels.get(n, n) for n in bridge.sampled_names],
    quantiles=[0.16, 0.5, 0.84],
    show_titles=True,
    title_kwargs={"fontsize": 12},
)

fig.savefig('corner_plot.png', dpi=300, bbox_inches='tight')
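The `quantiles=[0.16, 0.5, 0.84]` option marks weighted quantiles on each 1-D histogram. The sketch below reports the same kind of summary numerically. It uses a midpoint CDF convention, which may differ slightly from corner's own interpolation, and the function name is hypothetical.

```python
import numpy as np

def weighted_quantiles(x, q, weights):
    """Quantiles q (in [0, 1]) of weighted samples, midpoint CDF convention."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(weights, dtype=float)
    idx = np.argsort(x)
    x, w = x[idx], w[idx]
    # Place each sample's CDF value at the middle of its weight interval
    cdf = (np.cumsum(w) - 0.5 * w) / w.sum()
    return np.interp(q, cdf, x)

# Equal weights on 0..100: the median is 50
print(weighted_quantiles(np.arange(101), [0.5], np.ones(101)))
```

For example, `lo, med, hi = weighted_quantiles(posterior['amplitude'], [0.16, 0.5, 0.84], weights)` gives a median with a 68% interval.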

Chain Diagnostics

For MCMC (Eryn):

import matplotlib.pyplot as plt

# Get chain
chain = sampler.get_chain()  # Shape: (nsteps, nwalkers, ndim)

# Plot traces
fig, axes = plt.subplots(bridge.ndim, figsize=(10, 2*bridge.ndim))

for i, name in enumerate(bridge.sampled_names):
    ax = axes[i] if bridge.ndim > 1 else axes
    for walker in range(chain.shape[1]):
        ax.plot(chain[:, walker, i], alpha=0.3)
    ax.set_ylabel(bridge.latex_labels.get(name, name))
    ax.set_xlabel('Step')

plt.tight_layout()
plt.savefig('traces.png')
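Trace plots are a qualitative check. A quantitative companion is the integrated autocorrelation time τ, roughly the number of steps between effectively independent samples. A common rule of thumb, for example in emcee's documentation, is to run for at least about 50τ. The sketch below follows the FFT-based, walker-averaged estimator with Sokal's automatic window that emcee popularised. It is not part of discoverysamplers or Eryn. It assumes a per-parameter array of shape (nsteps, nwalkers), such as `chain[:, :, i]` in the layout commented above.

```python
import numpy as np

def autocorr_1d(x):
    """Normalised autocorrelation function of a 1-D series via FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    f = np.fft.fft(x - x.mean(), n=2 * n)  # zero-pad to avoid circular wrap-around
    acf = np.fft.ifft(f * np.conjugate(f))[:n].real
    return acf / acf[0]

def integrated_autocorr_time(chain, c=5.0):
    """Integrated autocorrelation time for one parameter.

    `chain` has shape (nsteps, nwalkers). The ACF is averaged over walkers and
    summed up to the first lag M with M >= c * tau(M) (Sokal's window).
    """
    chain = np.asarray(chain, dtype=float)
    rho = np.mean([autocorr_1d(chain[:, k]) for k in range(chain.shape[1])], axis=0)
    taus = 2.0 * np.cumsum(rho) - 1.0
    below = np.arange(len(taus)) < c * taus
    window = np.argmin(below) if not below.all() else len(taus) - 1
    return taus[window]

# White noise has tau close to 1
rng = np.random.default_rng(0)
print(integrated_autocorr_time(rng.standard_normal((4000, 16))))
```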

Model Predictions

# Generate predictions from posterior
n_posterior_samples = 100
predictions = []

indices = np.random.choice(
    len(posterior['weights']),
    size=n_posterior_samples,
    p=posterior['weights']/posterior['weights'].sum()
)

for idx in indices:
    params = {name: posterior[name][idx] for name in bridge.sampled_names}
    pred = generate_signal(params)  # Your signal generation
    predictions.append(pred)

predictions = np.array(predictions)

# Plot with credible intervals
plt.figure(figsize=(12, 6))
plt.plot(times, data, 'k.', label='Data', alpha=0.5)

median = np.median(predictions, axis=0)
lower = np.percentile(predictions, 16, axis=0)
upper = np.percentile(predictions, 84, axis=0)

plt.plot(times, median, 'r-', label='Median', lw=2)
plt.fill_between(times, lower, upper, alpha=0.3, label='68% CI')

plt.xlabel('Time')
plt.ylabel('Signal')
plt.legend()
plt.savefig('predictions.png')
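The prediction loop leaves `generate_signal` to the user. For the sinusoid fit on this page, one possible definition is a hypothetical helper that mirrors `sinusoid_model`. The time grid is repeated so the block runs on its own. This sketch draws bands for the noise-free signal only, so `noise_std` is ignored.

```python
import numpy as np

# Same time grid as the sinusoid example
times = np.linspace(0, 10, 100)

def generate_signal(params):
    """Noise-free sinusoid A * sin(2*pi*f*t + phi) evaluated on `times`."""
    return params["amplitude"] * np.sin(
        2 * np.pi * params["frequency"] * times + params["phase"]
    )

curve = generate_signal(
    {"amplitude": 2.5, "frequency": 1.5, "phase": 0.3, "noise_std": 0.5}
)
```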

Additional Resources

External Examples

Tutorials

See Also