"""
Discovery ↔︎ GPry bridge (via **Cobaya**)
This version is tailored to **GPry** and Cobaya's model wrapper.
It builds a Cobaya `model` (with your likelihood + priors) and passes that
`model` directly to `gpry.Runner`.
Key references:
- GPry Runner accepts a Cobaya `model` as its first argument; in that case it
reads priors and names from the model (no `bounds` needed).
- In Cobaya, an external likelihood can be a plain Python function placed under
the `likelihood` block, and priors are defined in the `params` block using
scipy.stats names (e.g. `uniform`, `norm`, `loguniform`) or `min`/`max`.
What you get
------------
- `build_cobaya_info(...)` → a ready-to-use Cobaya `info` dict (likelihood + params)
- `get_cobaya_model(info)` → a `cobaya.model.Model`
- `DiscoveryGPryCobayaBridge` → convenience class that:
1) adapts your Discovery-style likelihood
2) builds the Cobaya model with proper priors
3) launches `gpry.Runner(model, ...)` and exposes results
Assumptions
-----------
- Your *Discovery* model is either a callable `loglike(params_dict) -> float`
or an object with `.log_prob(params_dict)` / `.log_likelihood(params_dict)`.
- You provide priors in a compact mapping per parameter.
Supported entries here map to Cobaya priors:
* ("uniform", a, b) → `prior: {min: a, max: b}`
* ("loguniform", a, b) → `prior: {dist: loguniform, a: a, b: b}`
* ("normal", mean, sigma[, ...]) → `prior: {dist: norm, loc: mean, scale: sigma}`
* ("fixed", value) → `value: value` (non-sampled)
(These match Cobaya's documented syntax and SciPy parameterization. citeturn6search0)
If you need custom/callable priors, you can add them under Cobaya's top-level
`prior:` block—see the TODO note in `build_cobaya_info`.
Notes on parameter naming
-------------------------
Cobaya requires parameter names to be valid Python identifiers (no special chars
like `+`, `-`, etc.). This module automatically sanitizes parameter names by
replacing invalid characters with underscores, and maintains a mapping to convert
back to the original names when calling the Discovery likelihood.
"""
from __future__ import annotations
import re
from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Sequence, Tuple, List
import numpy as np
from .priors import ParsedPrior, PriorParsingError, _parse_single_prior, ParamName, PriorSpec
# ------------------------- Name sanitization ---------------------------- #
def _sanitize_param_name(name: str) -> str:
"""
Convert a parameter name to a valid Python identifier for Cobaya.
Replaces any character that's not alphanumeric or underscore with underscore.
Ensures the name starts with a letter or underscore.
"""
# Replace invalid characters with underscore
sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter or underscore
if sanitized and sanitized[0].isdigit():
sanitized = '_' + sanitized
return sanitized
def _create_name_mappings(param_names: List[str]) -> Tuple[Dict[str, str], Dict[str, str]]:
"""
Create bidirectional mappings between original and sanitized parameter names.
Returns
-------
orig_to_sanitized : dict
Maps original names to sanitized names
sanitized_to_orig : dict
Maps sanitized names back to original names
"""
orig_to_sanitized = {}
sanitized_to_orig = {}
for name in param_names:
sanitized = _sanitize_param_name(name)
# Handle potential collisions by adding suffix
base_sanitized = sanitized
counter = 1
while sanitized in sanitized_to_orig and sanitized_to_orig[sanitized] != name:
sanitized = f"{base_sanitized}_{counter}"
counter += 1
orig_to_sanitized[name] = sanitized
sanitized_to_orig[sanitized] = name
return orig_to_sanitized, sanitized_to_orig
# --------------------------- GPry-specific adapter ---------------------- #
class _GPryDiscoveryAdapter:
"""
Simplified Discovery adapter for GPry/Cobaya interface.
This is a lightweight wrapper that doesn't use JAX JIT or array APIs,
suitable for the Cobaya integration.
"""
def __init__(self, model: Any):
self.model = model
if callable(model):
self._fn = model
elif hasattr(model, "log_prob") and callable(getattr(model, "log_prob")):
self._fn = getattr(model, "log_prob")
elif hasattr(model, "logL") and callable(getattr(model, "logL")):
self._fn = getattr(model, "logL")
else:
raise TypeError(
"Model must be callable or expose .log_prob/.logL(params_dict)."
)
def __call__(self, params: Mapping[str, float]) -> float:
return float(self._fn(params))
# ------------------------------ Priors ---------------------------------- #
def split_priors(priors: Mapping[str, Any]) -> Tuple[List[str], Dict[str, float], Dict[str, ParsedPrior]]:
"""
Split priors into sampled, fixed parameters, and parsed prior objects.
Uses the common _parse_single_prior function but returns a format
suitable for GPry/Cobaya.
Returns
-------
sampled : list of str
Names of sampled parameters
fixed : dict
Fixed parameter values
parsed : dict
ParsedPrior objects for each parameter
"""
sampled: List[str] = []
fixed: Dict[str, float] = {}
parsed: Dict[str, ParsedPrior] = {}
for name, spec in priors.items():
p = _parse_single_prior(name, spec)
parsed[name] = p
if p.dist_type == "fixed":
fixed[name] = p.value
else:
sampled.append(name)
return sampled, fixed, parsed
# --------------------------- Cobaya builders ---------------------------- #
def _make_likelihood_func(parsed_prior: Mapping[str, ParsedPrior], adapter: _GPryDiscoveryAdapter,
sanitized_to_orig: Mapping[str, str],
fixed_params: Mapping[str, float]) -> Tuple[Callable[..., float], Tuple[str, ...]]:
"""Return a Cobaya-friendly likelihood.
Cobaya accepts external likelihood *functions*; it will pass the sampled
parameters by name as kwargs. We accept **kwargs to be robust to future
changes in the parameterization. (Cobaya can introspect function args, but
since we declare params explicitly in the `params` block, **kwargs is fine.)
See Cobaya's advanced example for function likelihoods. citeturn5view0
Parameters
----------
parsed_prior : dict
Mapping from original parameter names to ParsedPrior objects
adapter : _GPryDiscoveryAdapter
The likelihood adapter that accepts original parameter names
sanitized_to_orig : dict
Mapping from sanitized names (used by Cobaya) back to original names
fixed_params : dict
Mapping from original parameter names to fixed values
Returns
-------
loglike : callable
Likelihood function accepting sanitized parameter names as kwargs
sampled_names : tuple
Tuple of sanitized parameter names for sampled (non-fixed) parameters
"""
# Get sanitized names for sampled parameters only
orig_to_sanitized = {v: k for k, v in sanitized_to_orig.items()}
sampled_sanitized = tuple(
orig_to_sanitized[name] for name, p in parsed_prior.items()
if p.dist_type != "fixed"
)
# Store fixed params to include in every likelihood call
_fixed_params = dict(fixed_params)
def loglike(**kwargs: float) -> float:
# Start with fixed parameters
pd: Dict[str, float] = dict(_fixed_params)
# Add sampled parameters (convert sanitized names back to original)
for sanitized_name, value in kwargs.items():
if sanitized_name in sanitized_to_orig:
orig_name = sanitized_to_orig[sanitized_name]
pd[orig_name] = float(value)
return adapter(pd)
return loglike, sampled_sanitized
def _prior_entry_from_parsed(p: ParsedPrior) -> Dict[str, Any] | float:
"""
Convert a ParsedPrior from the common module to Cobaya format.
Parameters
----------
p : ParsedPrior
Parsed prior from common module
Returns
-------
dict or float
Cobaya prior specification
"""
if p.dist_type == "uniform":
a, b = p.bounds
return {"prior": {"min": a, "max": b}}
if p.dist_type == "loguniform":
a, b = p.bounds
# Cobaya delegates to SciPy: loguniform takes 'a' (min) and 'b' (max).
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.loguniform.html
return {"prior": {"dist": "loguniform", "a": a, "b": b}}
if p.dist_type == "normal":
mu, sigma = p.mean, p.sigma
return {"prior": {"dist": "norm", "loc": mu, "scale": sigma}}
if p.dist_type == "fixed":
return p.value # fixed parameter
raise ValueError(f"Unsupported prior type for GPry/Cobaya: {p.dist_type}")
[docs]
def build_cobaya_info(
*,
discovery_model: Any,
priors: Mapping[str, Any],
latex_labels: Optional[Mapping[str, str]] = None,
like_name: str = "discovery_like",
) -> Tuple[Dict[str, Any], Dict[str, str], Dict[str, str], Dict[str, float]]:
"""Create a Cobaya `info` dict with one external likelihood and a `params` block.
Parameters
----------
discovery_model : Any
Discovery model (callable or object with logL method)
priors : dict
Prior specifications keyed by original parameter names
latex_labels : dict, optional
LaTeX labels keyed by original parameter names
like_name : str
Name for the likelihood in the Cobaya info
Returns
-------
info : dict
Cobaya info dict ready for get_model()
orig_to_sanitized : dict
Mapping from original parameter names to sanitized names
sanitized_to_orig : dict
Mapping from sanitized names back to original names
fixed_params : dict
Fixed parameter values (original names)
Notes
-----
- Parameter names are automatically sanitized to be valid Python identifiers
- Fixed parameters are baked into the likelihood function and NOT included
in the Cobaya params block (to avoid "unused parameter" errors)
- If you need external/callable *global* priors (beyond 1D priors in params),
add them to `info['prior']` after this returns (see Cobaya docs). citeturn5view0
"""
adapter = _GPryDiscoveryAdapter(discovery_model)
sampled, fixed, parsed = split_priors(priors)
# Create name mappings (original <-> sanitized)
all_param_names = list(priors.keys())
orig_to_sanitized, sanitized_to_orig = _create_name_mappings(all_param_names)
# Likelihood function (external function style)
# Pass fixed params so they are baked into the likelihood
loglike, sampled_names = _make_likelihood_func(parsed, adapter, sanitized_to_orig, fixed)
info: Dict[str, Any] = {"likelihood":
{like_name:
{"external": loglike,
"input_params": sampled_names}}}
# Params block (using sanitized names, ONLY for sampled parameters)
pblock: Dict[str, Any] = {}
for orig_name, p in parsed.items():
# Skip fixed parameters - they're baked into the likelihood
if p.dist_type == "fixed":
continue
sanitized_name = orig_to_sanitized[orig_name]
entry = _prior_entry_from_parsed(p)
# add latex label if provided (use original name for lookup)
if isinstance(entry, dict):
if latex_labels and orig_name in latex_labels:
entry = {**entry, "latex": latex_labels[orig_name]}
elif orig_name != sanitized_name:
# Use original name as latex label if no explicit label provided
entry = {**entry, "latex": orig_name}
pblock[sanitized_name] = entry
info["params"] = pblock
return info, orig_to_sanitized, sanitized_to_orig, fixed
[docs]
def get_cobaya_model(info: Mapping[str, Any]):
from cobaya.model import get_model
return get_model(info)
# ------------------------------ GPry glue -------------------------------- #
[docs]
class DiscoveryGPryCobayaBridge:
"""Bridge that:
- builds a Cobaya model from a Discovery-style likelihood + priors
- runs GPry by passing the Cobaya model to `gpry.Runner`
Parameter names are automatically sanitized to be valid Python identifiers
(required by Cobaya). The original names are stored and can be used for
accessing results.
Attributes
----------
orig_param_names : list
Original parameter names from the priors dict
sanitized_param_names : list
Sanitized parameter names used internally by Cobaya
orig_to_sanitized : dict
Mapping from original to sanitized names
sanitized_to_orig : dict
Mapping from sanitized to original names
"""
[docs]
def __init__(
self,
discovery_model: Any,
priors: Mapping[str, Any],
*,
latex_labels: Optional[Mapping[str, str]] = None,
like_name: str = "discovery_like",
runner_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
# Store original parameter names
self.orig_param_names = list(priors.keys())
# Build Cobaya info with sanitized names
self.info, self.orig_to_sanitized, self.sanitized_to_orig, self.fixed_param_dict = build_cobaya_info(
discovery_model=discovery_model,
priors=priors,
latex_labels=latex_labels,
like_name=like_name,
)
# Store sanitized names (only sampled params are in info["params"])
self.sanitized_param_names = list(self.info["params"].keys())
self.sampled_names_sanitized = list(self.info["params"].keys())
self.fixed_names_sanitized = [self.orig_to_sanitized[n] for n in self.fixed_param_dict.keys()]
# Also provide original names for convenience
self.sampled_names = [self.sanitized_to_orig[n] for n in self.sampled_names_sanitized]
self.fixed_names = list(self.fixed_param_dict.keys())
self.model = get_cobaya_model(self.info)
self.runner_kwargs = dict(runner_kwargs or {})
self.runner = None
self.results = None
[docs]
def run_sampler(self, *, checkpoint: Optional[str] = None, **run_kwargs: Any):
"""Create and run `gpry.Runner` with the Cobaya model.
Per GPry docs, when passing a Cobaya `model` as first argument, GPry
uses the model's prior and parameter names automatically. citeturn2view0
Returns
-------
info : dict
The Cobaya info dict used to create the model
sampler : gpry.Runner
The GPry Runner object after running
"""
from gpry import Runner
rkwargs = dict(self.runner_kwargs)
if checkpoint is not None:
rkwargs.setdefault("checkpoint", checkpoint)
runner = Runner(self.model, **rkwargs)
self.runner = runner
self.results = runner.run(**run_kwargs)
return self.info, runner
# Convenience accessors (these use GPry's common post-run attributes)
[docs]
def posterior_samples(self) -> Optional[np.ndarray]:
if self.runner is None:
return None
# GPry Cobaya wrapper stores samples via Cobaya products; but the Runner
# can also expose `surrogate_samples` after MC on the GP. Try common names.
for attr in ("posterior_samples", "surrogate_samples", "samples"):
if hasattr(self.runner, attr):
return np.asarray(getattr(self.runner, attr))
return None
[docs]
def return_logZ(self, *, results=None) -> Dict[str, float]:
"""
Return the log evidence estimate.
Note: GPry uses Gaussian Process surrogate modeling and does not directly
compute the Bayesian evidence in the same way nested samplers do.
This method is provided for API consistency but raises NotImplementedError.
Raises
------
NotImplementedError
Always raised - GPry does not compute evidence in the standard nested sampling sense
"""
raise NotImplementedError(
"GPry uses Gaussian Process surrogate modeling and does not compute the "
"Bayesian evidence (logZ) in the standard nested sampling sense. "
"Use a nested sampling method (Nessai, JAX-NS) if you need evidence estimates."
)
[docs]
def return_sampled_samples(self, *, results=None) -> Dict[str, Any]:
"""
Return the sampled parameter chains.
Returns
-------
dict
Dictionary containing:
- 'names': list of original parameter names
- 'labels': list of parameter labels
- 'chain': ndarray of shape (nsamples, n_sampled_params)
Raises
------
RuntimeError
If no results are available
"""
samples = self.posterior_samples()
if samples is None:
raise RuntimeError("No posterior samples available. Run `run_sampler()` first.")
# GPry returns samples with sanitized names - we map back to original
return {
"names": self.sampled_names,
"labels": self.sampled_names, # Use original names as labels
"chain": samples
}
[docs]
def return_all_samples(self, *, results=None) -> Dict[str, Any]:
"""
Return all samples including fixed parameters.
Returns
-------
dict
Dictionary containing:
- 'names': list of all parameter names
- 'labels': list of parameter labels
- 'chain': ndarray of shape (nsamples, n_all_params)
Raises
------
RuntimeError
If no results are available
"""
samples = self.posterior_samples()
if samples is None:
raise RuntimeError("No posterior samples available. Run `run_sampler()` first.")
nsamp = samples.shape[0]
chain_all = np.zeros((nsamp, len(self.orig_param_names)), dtype=float)
# Fill sampled parameters
for j, name in enumerate(self.sampled_names):
idx = self.orig_param_names.index(name)
chain_all[:, idx] = samples[:, j]
# Fill fixed parameters
for name in self.fixed_names:
idx = self.orig_param_names.index(name)
chain_all[:, idx] = float(self.fixed_param_dict[name])
return {
"names": self.orig_param_names,
"labels": self.orig_param_names,
"chain": chain_all
}
# ------------------------------ Plots ------------------------------ #
[docs]
def plot_trace(self, *, burn: int = 0, plot_fixed: bool = False, **kwargs):
"""
Plot trace of samples vs sample index.
Parameters
----------
burn : int, optional
Number of initial samples to discard, by default 0.
plot_fixed : bool, optional
If True, includes fixed parameters in the plot, by default False.
**kwargs
Additional keyword arguments passed to plots.plot_trace().
Returns
-------
matplotlib.figure.Figure
Figure containing the trace plots.
"""
from .plots import plot_trace
data = self.return_all_samples() if plot_fixed else self.return_sampled_samples()
return plot_trace(
data,
burn=burn,
fixed_params=self.fixed_param_dict,
fixed_names=self.fixed_names,
**kwargs
)
[docs]
def plot_corner(self, *, burn: int = 0, **kwargs):
"""
Corner plot of sampled parameters.
Parameters
----------
burn : int, optional
Number of initial samples to discard, by default 0.
**kwargs
Additional keyword arguments passed to corner.corner().
Returns
-------
matplotlib.figure.Figure
Corner plot figure.
"""
from .plots import plot_corner
data = self.return_sampled_samples()
return plot_corner(data, burn=burn, **kwargs)
__all__ = [
"build_cobaya_info",
"get_cobaya_model",
"DiscoveryGPryCobayaBridge",
]