Eryn MCMC Sampler
Interface to Eryn, an ensemble MCMC sampler with optional parallel tempering and reversible jump (RJMCMC) capabilities. Use it for posterior sampling when you do not need evidence estimates.
Minimal Run
import numpy as np
from discoverysamplers.eryn_interface import DiscoveryErynBridge
def my_model(params):
    x, y = params['x'], params['y']
    return -0.5 * (x**2 + y**2)
priors = {'x': ('uniform', -5.0, 5.0), 'y': ('uniform', -5.0, 5.0)}
bridge = DiscoveryErynBridge(model=my_model, priors=priors)
sampler = bridge.create_sampler(nwalkers=32)
p0 = bridge.sample_priors(nwalkers=32) # initialize from prior
sampler.run_mcmc(p0, nsteps=5000, progress=True)
samples = sampler.get_chain(discard=1000, flat=True)
x_samples = samples[:, bridge.eryn_mapping['x']]
print(f"x mean = {np.mean(x_samples):.3f}")
Parallel Tempering
Parallel tempering helps sample multi-modal posteriors by running chains at different “temperatures.” Hot chains explore more broadly while cold chains sample the target distribution.
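To make the idea concrete, here is a minimal NumPy sketch of a temperature ladder and the standard swap rule between two chains. It illustrates the general technique and is not Eryn's internal implementation; the helper names `geometric_betas` and `swap_probability` and the geometric spacing are assumptions made for this example.

```python
import numpy as np

def geometric_betas(ntemps, tmax):
    """Inverse temperatures from T=1 (cold) to T=tmax (hot), geometrically spaced."""
    temps = np.geomspace(1.0, tmax, ntemps)
    return 1.0 / temps

def swap_probability(beta_i, beta_j, logl_i, logl_j):
    """Metropolis probability of exchanging the states of chains i and j."""
    log_ratio = (beta_i - beta_j) * (logl_j - logl_i)
    return float(np.exp(min(0.0, log_ratio)))

betas = geometric_betas(ntemps=8, tmax=100.0)

# Cold chain (beta=1) vs. its hotter neighbour: the swap is always accepted
# when the hotter chain holds the higher-likelihood state, and only sometimes otherwise.
p_up = swap_probability(betas[0], betas[1], logl_i=-10.0, logl_j=-8.0)
p_down = swap_probability(betas[0], betas[1], logl_i=-8.0, logl_j=-10.0)
print(betas)
print(p_up, p_down)
```

Closer spacing between adjacent temperatures raises the swap probability, which is why adding temperature levels tends to improve mixing.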
Basic Usage:
# Enable parallel tempering with tempering_kwargs
sampler = bridge.create_sampler(
    nwalkers=32,
    tempering_kwargs=dict(ntemps=8)  # 8 temperature chains
)
# Initial state must have shape (ntemps, nwalkers, ndim)
p0 = bridge.sample_priors(nwalkers=32, ntemps=8)
sampler.run_mcmc(p0, nsteps=5000, progress=True)
# get_chain returns only the cold chain (temp=0) by default
samples = sampler.get_chain(discard=1000, flat=True)
Configuration Options:
# Fine-tune tempering behavior
sampler = bridge.create_sampler(
    nwalkers=32,
    tempering_kwargs=dict(
        ntemps=8,       # Number of temperature levels
        Tmax=None,      # Maximum temperature (None = adaptive)
        adaptive=True,  # Adapt temperatures during run
    )
)
Tips for Parallel Tempering:
Start with ntemps=4-8 and increase if chains don’t mix well
Target swap acceptance rates around 20-40%
Use more temperatures for highly multi-modal posteriors
Monitor sampler.acceptance_fraction per temperature
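A per-temperature summary can be sketched as follows. The array below is a synthetic stand-in for `sampler.acceptance_fraction`, and its shape `(ntemps, nwalkers)` is an assumption; check the shape returned by your sampler before averaging.

```python
import numpy as np

# Synthetic stand-in for sampler.acceptance_fraction.
# The (ntemps, nwalkers) layout is assumed for this sketch.
rng = np.random.default_rng(0)
ntemps, nwalkers = 4, 32
acceptance_fraction = rng.uniform(0.1, 0.5, size=(ntemps, nwalkers))

# Average over walkers to get one acceptance rate per temperature.
per_temp = acceptance_fraction.mean(axis=1)
for t, acc in enumerate(per_temp):
    print(f"temp {t}: mean acceptance = {acc:.3f}")
```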
Reversible Jump MCMC (RJMCMC)
RJMCMC enables trans-dimensional sampling, in which the number of model components can vary. This is useful for model selection problems such as counting the number of gravitational-wave sources in pulsar timing array (PTA) data.
Setting Up RJMCMC:
import numpy as np
from discoverysamplers.eryn_RJ_interface import RJ_Discovery_model, DiscoveryErynRJBridge
# 1. Create your base Discovery model (must return (delay_fn, param_names) when called)
# This is typically a signal constructor like cw_delay
def signal_constructor(psr):
    # Returns (delay_function, param_names); delay_fn is a placeholder
    # for the delay function you build for this pulsar
    return delay_fn, ['f', 'h', 'phi']
# 2. Wrap in RJ_Discovery_model
rj_model = RJ_Discovery_model(
    signal_constructors=[signal_constructor],  # One per branch
    pulsar=psr,
    variable_component_numbers=[0, 1, 2, 3],  # Allowed component counts
)
# 3. Define branch-indexed priors
priors = {
    'cw': {
        0: {},  # No parameters when 0 components
        1: {
            'f': ('loguniform', 1e-9, 1e-7),
            'h': ('loguniform', 1e-20, 1e-14),
            'phi': ('uniform', 0, 2*np.pi),
        },
        2: {
            'f': ('loguniform', 1e-9, 1e-7),
            'h': ('loguniform', 1e-20, 1e-14),
            'phi': ('uniform', 0, 2*np.pi),
        },
        3: {
            'f': ('loguniform', 1e-9, 1e-7),
            'h': ('loguniform', 1e-20, 1e-14),
            'phi': ('uniform', 0, 2*np.pi),
        },
    }
}
# 4. Create bridge
bridge = DiscoveryErynRJBridge(
    discovery_model=rj_model,
    priors=priors,
    branch_names=['cw'],
    nleaves_min={'cw': 0},  # Minimum components
    nleaves_max={'cw': 3},  # Maximum components
)
# 5. Create sampler with RJMCMC moves enabled
sampler = bridge.create_sampler(
    nwalkers=32,
    tempering_kwargs=dict(ntemps=4),  # Often want tempering for RJMCMC
)
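The branch-indexed priors in step 3 repeat the same per-component entries for every count. One way to avoid the repetition is to build the dictionary programmatically. This is only a convenience sketch: the names `leaf_priors` and `max_components` are introduced here for illustration, and the resulting dictionary has the same content as the hand-written one.

```python
import numpy as np

# Priors for a single component (leaf), in the same tuple format as step 3.
leaf_priors = {
    'f': ('loguniform', 1e-9, 1e-7),
    'h': ('loguniform', 1e-20, 1e-14),
    'phi': ('uniform', 0, 2 * np.pi),
}

max_components = 3  # should match nleaves_max['cw']
priors = {
    'cw': {
        # Zero components means no parameters; otherwise copy the per-leaf priors.
        n: (dict(leaf_priors) if n > 0 else {})
        for n in range(max_components + 1)
    }
}
print(sorted(priors['cw']))
```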
Running and Analyzing RJMCMC:
# Initialize from priors
state = bridge.sample_priors(nwalkers=32, ntemps=4)
sampler.run_mcmc(state, nsteps=10000, progress=True)
# Get number of components per sample
nleaves = sampler.get_nleaves()['cw'] # Shape: (nsteps, ntemps, nwalkers)
# Plot component count histogram
import matplotlib.pyplot as plt
cold_nleaves = nleaves[:, 0, :].flatten() # Cold chain only
plt.hist(cold_nleaves, bins=np.arange(-0.5, 4.5), density=True)
plt.xlabel('Number of components')
plt.ylabel('Posterior probability')
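The histogram can also be reduced to numbers. The sketch below uses a synthetic array in place of `sampler.get_nleaves()['cw']` (same `(nsteps, ntemps, nwalkers)` layout as above); the burn-in length and the probabilities used to generate the fake data are arbitrary choices for illustration.

```python
import numpy as np

# Synthetic stand-in for sampler.get_nleaves()['cw'], shape (nsteps, ntemps, nwalkers).
rng = np.random.default_rng(1)
nleaves = rng.choice([0, 1, 2, 3], p=[0.1, 0.6, 0.25, 0.05], size=(2000, 4, 32))

burn = 500                                # illustrative burn-in length
cold = nleaves[burn:, 0, :].ravel()       # cold chain only, after burn-in

counts = np.bincount(cold, minlength=4)
probs = counts / counts.sum()
for k, p in enumerate(probs):
    print(f"P(n = {k} | data) ~ {p:.3f}")

# If the prior puts equal weight on each component count, the ratio of
# posterior probabilities estimates the Bayes factor between two counts.
if counts[0] > 0:
    print(f"odds, 1 vs 0 components: {counts[1] / counts[0]:.2f}")
```

Counts that the cold chain rarely visits give noisy ratios, so treat odds built from only a handful of samples with caution.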
Key RJMCMC Concepts:
Branches: Groups of parameters that can appear multiple times (e.g., 'cw' for continuous waves)
Leaves: Individual instances within a branch (e.g., each CW source is a leaf)
nleaves_min/max: Control the allowed range of component counts
Model format: Your logL(*params) receives nested lists where params[i][j] contains parameters for component j of branch i
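The nested-list model format can be illustrated with a toy function. This is a sketch of the calling convention only: `toy_logL` is a hypothetical name, the Gaussian form is arbitrary, and the exact container types the bridge passes are assumed to be plain nested sequences.

```python
import numpy as np

def toy_logL(*params):
    """Toy log-likelihood: one positional argument per branch, one item per component."""
    total = 0.0
    for branch in params:            # params[i]: all components of branch i
        for component in branch:     # params[i][j]: parameters of component j
            total += -0.5 * float(np.sum(np.asarray(component) ** 2))
    return total

# One branch holding two components, each with three parameters.
one_branch_two_leaves = [[0.1, 0.2, 0.3], [1.0, 0.0, -1.0]]
print(toy_logL(one_branch_two_leaves))

# One branch with zero components contributes nothing.
print(toy_logL([]))
```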
See the examples/RJ_MCMC.ipynb notebook for a complete working example.
Key Options
For create_sampler():
nwalkers: at least 2 * ndim (typically 32–64).
tempering_kwargs: dict with ntemps, Tmax, adaptive for parallel tempering.
moves: pass Eryn move objects/weights to customize proposals.
backend: use eryn.backends.HDFBackend to checkpoint chains.
For RJMCMC (DiscoveryErynRJBridge):
branch_names: list of branch names (e.g., ['cw']).
nleaves_min/max: dicts mapping branch names to min/max component counts.
RJMCMC moves are automatically enabled.
Quick Diagnostics
print(f"Acceptance: {np.mean(sampler.acceptance_fraction):.3f}")
# Autocorr may fail for short runs; guard with try/except
try:
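    tau = sampler.get_autocorr_time()
    print(f"Autocorr time: {tau}")
except Exception as exc:
    # Report the failure instead of silently ignoring it
    print(f"Autocorr estimate unavailable: {exc}")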
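To see what an autocorrelation time measures, the sketch below estimates it for a synthetic series with a known answer. It uses a common windowed estimator and is not a description of what `get_autocorr_time()` does internally; `autocorr_func`, `integrated_time`, and the window factor `c=5` are choices made for this example.

```python
import numpy as np

def autocorr_func(x):
    """Normalized autocorrelation function of a 1-D series, computed via FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    nfft = 1 << (2 * n - 1).bit_length()   # zero-pad to avoid wrap-around
    f = np.fft.rfft(x - x.mean(), n=nfft)
    acf = np.fft.irfft(f * np.conjugate(f), n=nfft)[:n]
    return acf / acf[0]

def integrated_time(x, c=5.0):
    """Integrated autocorrelation time, truncated at the first lag M >= c * tau(M)."""
    taus = 2.0 * np.cumsum(autocorr_func(x)) - 1.0
    lags = np.arange(len(taus))
    ok = lags >= c * taus
    window = int(np.argmax(ok)) if ok.any() else len(taus) - 1
    return float(taus[window])

# Synthetic AR(1) series; its exact tau is (1 + phi) / (1 - phi) = 19.
rng = np.random.default_rng(2)
phi, n = 0.9, 50_000
noise = rng.normal(size=n)
x = np.empty(n)
x[0] = noise[0]
for t in range(1, n):
    x[t] = phi * x[t - 1] + noise[t]

tau = integrated_time(x)
print(f"tau ~ {tau:.1f}, effective samples ~ {n / tau:.0f}")
```

Dividing the chain length by the autocorrelation time gives a rough count of independent samples, which is why short runs give unreliable estimates.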
Tips
Provide priors for all Discovery parameters; missing entries raise errors.
Use parallel tempering when you suspect multimodality; target swap rates ~20–40%.
Start from priors for broad exploration; from a Gaussian ball if you have a good initial guess.
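A Gaussian-ball start can be sketched with NumPy alone. The guess values and ball width below are hypothetical, the bounds are the uniform priors from the minimal run, and the column order is assumed to match bridge.eryn_mapping; the output has the `(ntemps, nwalkers, ndim)` shape required for tempered runs.

```python
import numpy as np

rng = np.random.default_rng(3)
ntemps, nwalkers = 4, 32
guess = np.array([0.5, -1.0])         # hypothetical best-guess values for x and y
scale = 1e-2 * np.ones_like(guess)    # hypothetical ball width per parameter
lower, upper = -5.0, 5.0              # uniform prior bounds from the minimal run

# Small Gaussian scatter around the guess, one point per (temperature, walker).
p0 = guess + scale * rng.normal(size=(ntemps, nwalkers, guess.size))

# Keep every starting point inside the prior support.
p0 = np.clip(p0, lower, upper)
print(p0.shape)  # (4, 32, 2)
```

Keep the ball wide enough that walkers start at distinct points; identical starting positions can stall ensemble moves.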
See Also
Prior Specification - Prior format for RJMCMC
Eryn Interface - API reference for fixed-dimensional MCMC
../api/eryn_RJ_interface - API reference for RJMCMC
Example Notebooks - Example notebooks including RJMCMC