Extending discoverysamplers

This guide shows how to extend discoverysamplers to support new samplers or customize existing interfaces.

Creating a New Sampler Interface

Basic Structure

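All sampler interfaces follow a common pattern: parse the priors into fixed and sampled parameters, set up the parameter names and mappings, and expose a `run_sampler` method that runs the sampler and returns a results dictionary: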

from typing import Dict, Any, Optional
import numpy as np

class DiscoveryNewSamplerBridge:
    """Bridge interface for NewSampler.

    Template for a new sampler bridge: fill in ``_parse_priors``,
    ``_setup_parameters`` and ``run_sampler`` for the target sampler.
    """

    def __init__(
        self,
        discovery_model,
        priors: Dict[str, Any],
        latex_labels: Optional[Dict[str, str]] = None,
        jit: bool = False
    ):
        """
        Initialize bridge.

        Parameters
        ----------
        discovery_model : callable
            Model function accepting parameter dict, returning log-likelihood
        priors : dict
            Prior specifications
        latex_labels : dict, optional
            LaTeX labels for parameters (defaults to an empty dict)
        jit : bool
            Enable JIT compilation if applicable
        """
        self.model = discovery_model
        self.jit = jit

        # Store LaTeX labels *before* the setup hooks run, so that
        # _parse_priors / _setup_parameters implementations can rely on them.
        self.latex_labels = latex_labels or {}

        # Parse priors (separate fixed vs sampled)
        self._parse_priors(priors)

        # Setup parameter names and mappings
        self._setup_parameters()

    def _parse_priors(self, priors):
        """Parse prior specifications.

        Hook: separate fixed parameters from sampled parameters and store the
        result on ``self``. The template implementation does nothing.
        """
        # Implement prior parsing
        # Separate fixed parameters from sampled parameters

    def _setup_parameters(self):
        """Setup parameter name lists and mappings.

        Hook: build the parameter name lists and index mappings. The template
        implementation does nothing.
        """
        # Create parameter name lists
        # Setup index mappings

    def run_sampler(self, **kwargs):
        """
        Run the sampler.

        Parameters
        ----------
        **kwargs
            Sampler-specific options

        Returns
        -------
        dict
            Results dictionary

        Raises
        ------
        NotImplementedError
            Until a concrete sampler is wired in. Raising (rather than
            silently returning None) makes a forgotten implementation fail
            loudly instead of looking like an empty run.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.run_sampler must be implemented"
        )

Example: Minimal Sampler Interface

from discoverysamplers.nessai_interface import _split_priors, _parse_single_prior

class DiscoveryMinimalBridge:
    """Minimal sampler bridge example.

    Splits the priors into sampled and fixed parameters, converts between
    parameter dicts and arrays of sampled values, and draws samples directly
    from the priors in ``run_sampler``.
    """

    def __init__(self, discovery_model, priors, latex_labels=None):
        """
        Parameters
        ----------
        discovery_model : callable
            Model function accepting a parameter dict, returning log-likelihood.
        priors : dict
            Prior specifications, split by ``_split_priors`` into sampled
            priors and fixed values.
        latex_labels : dict, optional
            LaTeX labels for parameters. Parameters without a label fall back
            to their own name.
        """
        self.model = discovery_model

        # Parse priors
        self.sampled_prior_dict, self.fixed_param_dict = _split_priors(priors)

        # Parameter names: sampled first, then fixed
        self.sampled_names = list(self.sampled_prior_dict.keys())
        self.fixed_names = list(self.fixed_param_dict.keys())
        self.discovery_paramnames = self.sampled_names + self.fixed_names

        # Counts
        self.n_sampled = len(self.sampled_names)
        self.n_fixed = len(self.fixed_names)

        # LaTeX labels: default every parameter to its own name, then apply
        # user-supplied labels on top, so a partial dict does not drop the
        # defaults for unlabelled parameters.
        self.latex_labels = {name: name for name in self.discovery_paramnames}
        if latex_labels:
            self.latex_labels.update(latex_labels)

    def dict_to_array(self, params_dict):
        """Convert a parameter dict to an array of sampled values.

        Values are ordered as in ``self.sampled_names``; fixed parameters
        are not included.
        """
        return np.array([params_dict[name] for name in self.sampled_names])

    def array_to_dict(self, params_array):
        """Convert an array of sampled values to a full parameter dict.

        Values are matched to ``self.sampled_names`` by position, then the
        fixed parameters are added.

        Raises
        ------
        ValueError
            If ``params_array`` does not have exactly one entry per sampled
            parameter (``zip`` would otherwise silently truncate).
        """
        if len(params_array) != len(self.sampled_names):
            raise ValueError(
                f"Expected {len(self.sampled_names)} sampled values, "
                f"got {len(params_array)}"
            )
        params_dict = dict(zip(self.sampled_names, params_array))
        params_dict.update(self.fixed_param_dict)
        return params_dict

    def log_likelihood(self, params_array):
        """Evaluate the model's log-likelihood for an array of sampled values."""
        params_dict = self.array_to_dict(params_array)
        return self.model(params_dict)

    def run_sampler(self, n_samples=1000):
        """Draw independent samples from the priors (example).

        Every prior draw is kept and its log-likelihood is recorded. There is
        no rejection or weighting step, so ``samples`` follow the prior, not
        the posterior.

        Parameters
        ----------
        n_samples : int
            Number of prior draws; must be non-negative.

        Returns
        -------
        dict
            ``'samples'``: array of shape ``(n_samples, len(sampled_names))``;
            ``'log_likes'``: array of shape ``(n_samples,)``.
        """
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")

        samples = []
        log_likes = []
        for _ in range(n_samples):
            # One draw per sampled parameter, in sampled_names order
            params_array = np.array([
                self._sample_single_prior(self.sampled_prior_dict[name])
                for name in self.sampled_names
            ])
            samples.append(params_array)
            log_likes.append(self.log_likelihood(params_array))

        # reshape keeps 'samples' 2-D even when n_samples == 0
        # (np.array([]) alone would have shape (0,)).
        return {
            'samples': np.array(samples).reshape(n_samples, len(self.sampled_names)),
            'log_likes': np.array(log_likes),
        }

    def _sample_single_prior(self, prior_spec):
        """Draw one value from a single prior specification.

        Must be implemented for the prior format in use. Raising here avoids
        silently passing ``None`` values into the model.
        """
        raise NotImplementedError(
            f"{type(self).__name__}._sample_single_prior must be implemented "
            f"for prior specification {prior_spec!r}"
        )

Reusing Common Components

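Instead of reimplementing prior parsing and model adaptation, import the helpers from an existing interface. Note that these are private (underscore-prefixed) helpers, so their signatures may change between releases: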

from discoverysamplers.nessai_interface import (
    _split_priors,
    _parse_single_prior,
    ParsedPrior,
    _DiscoveryAdapter
)

class DiscoveryNewBridge:
    """Bridge that reuses the nessai interface's prior parsing and adapter."""

    def __init__(self, model, priors, latex_labels=None):
        """
        Parameters
        ----------
        model : callable
            Discovery model, wrapped by ``_DiscoveryAdapter``.
        priors : dict
            Prior specifications, split by ``_split_priors`` into sampled
            priors and fixed values.
        latex_labels : dict, optional
            LaTeX labels for parameters. Parameters without a label fall back
            to their own name. (Previously this argument was accepted but
            silently ignored.)
        """
        self.model = model

        # Use existing prior parsing
        self.sampled_prior_dict, self.fixed_param_dict = _split_priors(priors)

        # Parse each prior
        self.parsed_priors = {
            name: _parse_single_prior(spec, name)
            for name, spec in self.sampled_prior_dict.items()
        }

        # LaTeX labels: default to parameter names, user labels override
        self.latex_labels = {
            name: name
            for name in list(self.sampled_prior_dict) + list(self.fixed_param_dict)
        }
        if latex_labels:
            self.latex_labels.update(latex_labels)

        # Use discovery adapter
        self.adapter = _DiscoveryAdapter(
            model,
            self.fixed_param_dict,
            allow_array_api=True
        )

Adding New Prior Types

Extend Prior Parsing

To support new prior types, wrap the standard parser: try it first, and fall back to your own handling only for distributions it does not recognise:

def _parse_single_prior_extended(prior_spec, param_name):
    """Extended prior parser with new types.

    Tries the standard ``_parse_single_prior`` first and only falls back to
    the extra distribution types below if it fails.

    Extra types (``prior_spec`` must be a dict with a ``'dist'`` key):

    - ``{'dist': 'beta', 'a': a, 'b': b}`` with ``a > 0`` and ``b > 0``
    - ``{'dist': 'gamma', 'shape': k, 'scale': theta}`` with ``k > 0`` and
      ``theta > 0`` (``scale`` defaults to 1.0)

    Raises
    ------
    ValueError
        If the specification matches neither the standard nor the extended
        types (chained from the standard parser's error), or if a beta/gamma
        parameter is not positive.
    """
    from discoverysamplers.nessai_interface import _parse_single_prior, ParsedPrior

    # First try standard parsing. Catch Exception rather than using a bare
    # ``except:`` so KeyboardInterrupt/SystemExit still propagate, and keep
    # the error so it can be chained if no extended type matches either.
    try:
        return _parse_single_prior(prior_spec, param_name)
    except Exception as exc:
        standard_error = exc

    # NOTE(review): assumes ParsedPrior accepts these keyword arguments
    # (dist_type, a, b, shape, scale, bounds) — verify against
    # discoverysamplers.nessai_interface.ParsedPrior.
    if isinstance(prior_spec, dict):
        dist_type = prior_spec.get('dist', '')

        if dist_type == 'beta':
            # Beta distribution on [0, 1]
            a, b = prior_spec['a'], prior_spec['b']
            if not (a > 0 and b > 0):
                raise ValueError(
                    f"Beta prior for '{param_name}' needs a > 0 and b > 0, "
                    f"got a={a}, b={b}"
                )
            return ParsedPrior(
                dist_type='beta',
                a=a,
                b=b,
                bounds=(0, 1)
            )

        elif dist_type == 'gamma':
            # Gamma distribution on [0, inf)
            shape = prior_spec['shape']
            scale = prior_spec.get('scale', 1.0)
            if not (shape > 0 and scale > 0):
                raise ValueError(
                    f"Gamma prior for '{param_name}' needs shape > 0 and "
                    f"scale > 0, got shape={shape}, scale={scale}"
                )
            return ParsedPrior(
                dist_type='gamma',
                shape=shape,
                scale=scale,
                bounds=(0, np.inf)
            )

    # If nothing worked, raise error (chained so the standard parser's
    # diagnostic is not lost)
    raise ValueError(
        f"Cannot parse prior specification: {prior_spec}"
    ) from standard_error

Custom Prior Container

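For samplers that expect their own prior objects (exposing `sample` and `logpdf` methods) rather than plain specifications, wrap the priors in a container: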

class CustomPriorContainer:
    """Container for prior distributions.

    Holds one prior object per parameter. The objects are built by
    ``_create_prior`` and must provide ``sample(n)`` and ``logpdf(value)``,
    which ``sample`` and ``logpdf`` below call.
    """

    def __init__(self, prior_dict):
        """
        Initialize from prior dictionary.

        Parameters
        ----------
        prior_dict : dict
            Dictionary mapping parameter names to prior specs
        """
        self.priors = {
            name: self._create_prior(spec) for name, spec in prior_dict.items()
        }

    def _create_prior(self, spec):
        """Create a prior object from a specification.

        Must be implemented for the target sampler. Raising here fails fast
        instead of storing ``None`` and failing later with an obscure
        AttributeError in ``sample``/``logpdf``.
        """
        raise NotImplementedError(
            f"{type(self).__name__}._create_prior must convert {spec!r} "
            "into a prior object"
        )

    def sample(self, n=1):
        """Draw ``n`` samples from every prior, keyed by parameter name."""
        return {name: prior.sample(n) for name, prior in self.priors.items()}

    def logpdf(self, params):
        """Compute the joint log prior probability.

        Sums each held prior's ``logpdf`` at ``params[name]``. Entries in
        ``params`` without a prior (e.g. fixed parameters) are ignored. A
        parameter that has a prior but is missing from ``params`` raises
        ``KeyError`` instead of being silently left out of the sum.
        """
        log_p = 0.0
        for name, prior in self.priors.items():
            log_p += prior.logpdf(params[name])
        return log_p

See Also