bolojax

Bolometric sensitivity calculator for CMB instruments, built on JAX.

bolojax models the full radiative transfer chain of a CMB telescope and computes noise-equivalent temperature (NET), noise-equivalent power (NEP), and mapping speed. Because the compute path is written in pure JAX, the entire forward model is automatically differentiable, enabling gradient-based fitting, Fisher forecasting, MCMC sampling, and potentially inverse design.
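The sketch below is a minimal illustration of what a pure-JAX compute path provides. It writes a toy single-mode, single-polarization sensitivity model and then JIT-compiles and differentiates it with `jax.jit` and `jax.grad`. The model uses Rayleigh-Jeans loading, a photon-noise NEP with shot and bunching terms in one common convention, and NET_RJ = NEP / (sqrt(2) * dP/dT). The function `toy_net_rj`, its default band (150 GHz, 25% bandwidth), and these physics simplifications are assumptions made for illustration. This is not bolojax's implementation or API.

```python
import jax
import jax.numpy as jnp

# Double precision keeps tiny noise powers (~1e-17 W/sqrt(Hz)) well resolved.
jax.config.update("jax_enable_x64", True)

H_PLANCK = 6.62607015e-34  # Planck constant, J s
K_B = 1.380649e-23  # Boltzmann constant, J/K


def toy_net_rj(efficiency, t_rj, nu=150e9, frac_bw=0.25):
    """Toy single-mode, single-polarization sensitivity model (hypothetical).

    Returns optical power [W], photon-noise NEP [W/sqrt(Hz)] and
    Rayleigh-Jeans NET [K sqrt(s)]. Not bolojax's implementation.
    """
    dnu = nu * frac_bw
    power = efficiency * K_B * t_rj * dnu  # Rayleigh-Jeans optical loading
    # Photon noise: shot term plus wave-bunching term.
    nep = jnp.sqrt(2.0 * H_PLANCK * nu * power + 2.0 * power**2 / dnu)
    dp_dt = efficiency * K_B * dnu  # responsivity to sky RJ temperature, W/K
    net = nep / (jnp.sqrt(2.0) * dp_dt)  # sqrt(2) converts 1/sqrt(Hz) to sqrt(s)
    return power, nep, net


def net_only(efficiency, t_rj):
    """Scalar output so jax.grad can differentiate it."""
    return toy_net_rj(efficiency, t_rj)[2]


fast_net = jax.jit(net_only)  # compiled forward model
dnet_deff = jax.jit(jax.grad(net_only))  # compiled d(NET)/d(efficiency)

power, nep, net = toy_net_rj(0.5, 10.0)
print(f"P = {float(power) * 1e12:.3f} pW, NET_RJ = {float(net) * 1e6:.1f} uK-rts")
print(f"dNET/d(efficiency) = {float(dnet_deff(0.5, 10.0)):.3e} K-rts")
```

Because `toy_net_rj` is an ordinary JAX function, the same `jax.grad` call returns derivatives with respect to any other input (for example `t_rj` via `argnums=1`) without any hand-written derivative code.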

History

This package descends from BoloCalc by Charlie Hill (arXiv:1806.04316), which was subsequently forked and restructured by Eric Charles at KIPAC/bolo-calc. bolojax is a fork of bolo-calc that replaces the configuration layer with pydantic and the numerical backend with JAX, equinox, and zodiax, making the sensitivity calculation JIT-compiled and fully differentiable.

What you can do

| Task | Before (BoloCalc / bolo-calc) | With bolojax |
|---|---|---|
| Forward modeling | Supported (YAML config only) | Supported, JIT-compiled |
| Monte Carlo (posterior predictive) | Supported (sample parameters, run many realizations) | Supported, faster with JIT |
| Least-squares fitting | Subprocess wrapper, finite-difference Jacobians | Exact Jacobians through autodiff |
| Fisher analysis | Limited by finite-difference accuracy and subprocess overhead | Exact Fisher matrix in a single JIT-compiled pass |
| MCMC / HMC | Not practical (no gradients) | Enabled by autodiff |
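The sketch below makes the "exact Jacobians" and "exact Fisher matrix" rows concrete in plain JAX. `jax.jacfwd` computes the Jacobian of a model with respect to its parameters. For independent Gaussian noise of width sigma, the Fisher matrix is then F = J^T J / sigma^2. The linear loading model `toy_loading`, the load temperatures, the noise level, and the fiducial values are all hypothetical. The block shows the general technique and is not a bolojax function.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the matrix inverse

K_B = 1.380649e-23  # Boltzmann constant, J/K
DNU = 37.5e9  # assumed bandwidth, Hz


def toy_loading(theta, t_load):
    """Hypothetical model of measured power in pW: P = P0 + eta * k_B * dnu * T_load."""
    p0_pw, eta = theta
    return p0_pw + eta * K_B * DNU * t_load * 1e12


t_load = jnp.array([4.0, 10.0, 20.0, 40.0, 77.0])  # load temperatures, K
sigma_pw = 0.05  # assumed white Gaussian noise per measurement, pW
theta0 = jnp.array([1.0, 0.4])  # fiducial parameters (P0 in pW, eta)

# Exact Jacobian of the model w.r.t. theta, shape (n_measurements, n_params).
jac = jax.jacfwd(toy_loading)(theta0, t_load)

# Fisher matrix for independent Gaussian noise: F = J^T J / sigma^2.
fisher = jac.T @ jac / sigma_pw**2

# Forecast marginalized 1-sigma errors from the inverse Fisher matrix.
errors = jnp.sqrt(jnp.diag(jnp.linalg.inv(fisher)))
print("Fisher matrix:\n", fisher)
print("sigma(P0) [pW], sigma(eta):", errors)
```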

Installation

uv pip install bolojax

For GPU support:

uv pip install 'bolojax[gpu]'

Architecture

bolojax separates configuration (pydantic) from computation (JAX/zodiax):

Config layer (pydantic)         Compute layer (zodiax/JAX)
───────────────────────         ──────────────────────────
ExperimentConfig                Experiment
  SimConfig                       Instrument
  Universe                          elements: {name: Element}
  InstrumentConfig                  Tc, bath_temp, psat, ...
    Optics                        fsky, obs_time, obs_effic
    CameraConfig                SensitivityResult
      ChannelConfig               NET, NEP, powers, ...

ExperimentConfig.setup() bridges the two layers, returning an Experiment zodiax pytree that you compute with, differentiate through, and modify with .set().
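The config/compute split can be illustrated with a small, self-contained sketch. A pydantic model validates user input, and its `setup()` method converts the validated floats into a JAX pytree that `jax.grad` can differentiate. `ToyChannelConfig` and `ToyChannel` are hypothetical stand-ins for `ExperimentConfig` and `Experiment`. The sketch also makes two substitutions: it uses a `typing.NamedTuple` (which JAX treats as a pytree) instead of a zodiax/equinox module, and `NamedTuple._replace` in place of zodiax's `.set()`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from pydantic import BaseModel, Field

K_B = 1.380649e-23  # Boltzmann constant, J/K


class ToyChannel(NamedTuple):
    """Compute-side object: a pytree of arrays (namedtuples are pytrees in JAX)."""

    band_center: jnp.ndarray  # Hz
    frac_bw: jnp.ndarray  # dimensionless
    efficiency: jnp.ndarray  # dimensionless

    def optical_power(self, t_rj):
        """Rayleigh-Jeans optical loading in W from a source at temperature t_rj."""
        return self.efficiency * K_B * t_rj * self.band_center * self.frac_bw


class ToyChannelConfig(BaseModel):
    """Config-side object: validates user input and does no numerics."""

    band_center: float = Field(gt=0)
    frac_bw: float = Field(gt=0, lt=1)
    efficiency: float = Field(ge=0, le=1)

    def setup(self) -> ToyChannel:
        # Bridge between layers: validated Python floats become JAX arrays.
        return ToyChannel(
            band_center=jnp.asarray(self.band_center),
            frac_bw=jnp.asarray(self.frac_bw),
            efficiency=jnp.asarray(self.efficiency),
        )


cfg = ToyChannelConfig(band_center=150e9, frac_bw=0.25, efficiency=0.5)
channel = cfg.setup()

# Gradients come back with the same structure as the compute object.
grads = jax.grad(lambda ch: ch.optical_power(10.0))(channel)

# Functional update: returns a new object and leaves `channel` unchanged.
channel2 = channel._replace(efficiency=jnp.asarray(0.6))
```

Keeping validation out of the compute object means the traced function never has to parse strings or units. That is what lets `jax.jit` and `jax.grad` treat it as plain array code.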

Quick start

YAML-driven

import bolojax

config = bolojax.ExperimentConfig.from_yaml("config/example.yaml")
experiment = config.setup()

# Compute and export
ds = experiment.to_dataset()
ds.to_netcdf("results.nc")

Or from the command line:

bolojax -i config/example.yaml -o results.nc

Programmatic

import bolojax
import equinox as eqx

# 1. Load configuration and set up the compute object
config = bolojax.ExperimentConfig.from_yaml("config/example.yaml")
experiment = config.setup()

# 2. Compute sensitivity
result = experiment.compute()
print(f"NET = {result.NET.squeeze() * 1e6:.2f} uK-rts")

# 3. Get labeled xarray output
ds = experiment.to_dataset()
ds.to_netcdf("results.nc")

# 4. Modify parameters and recompute
exp2 = experiment.set("instrument.elements.window.loss_tangent", 2.5e-4)
result2 = exp2.compute()

# 5. Differentiate (filter_grad skips non-float leaves like ndet)
@eqx.filter_grad
def grad_net(exp):
    return exp.compute().NET.squeeze()

g = grad_net(experiment)
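The comment in step 5 notes that `eqx.filter_grad` skips non-float leaves such as `ndet`. Plain `jax.grad` raises a `TypeError` when any input leaf is an integer array, so a compute object that carries a detector count cannot be passed to it directly. The sketch below imitates the filtering idea using only `jax.tree_util`: it differentiates the floating-point leaves and returns `None` for the rest. `grad_inexact_only`, `array_net`, and the `params` dictionary are hypothetical, and equinox's actual implementation differs in detail.

```python
import jax
import jax.numpy as jnp


def is_inexact(leaf):
    """True for floating-point (or complex) arrays, the leaves a gradient applies to."""
    return hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.inexact)


def grad_inexact_only(fn, tree):
    """Hypothetical helper: gradient of scalar fn(tree) w.r.t. its float leaves only.

    Non-float leaves (e.g. an integer detector count) are held fixed and come
    back as None, similar in spirit to equinox.filter_grad.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    mask = [is_inexact(x) for x in leaves]
    diff_leaves = [x for x, m in zip(leaves, mask) if m]

    def fn_of_diff(diff):
        # Re-insert the (possibly traced) float leaves among the fixed ones.
        it = iter(diff)
        merged = [next(it) if m else x for x, m in zip(leaves, mask)]
        return fn(jax.tree_util.tree_unflatten(treedef, merged))

    diff_grads = iter(jax.grad(fn_of_diff)(diff_leaves))
    out = [next(diff_grads) if m else None for m in mask]
    return jax.tree_util.tree_unflatten(treedef, out)


# Hypothetical parameters: a per-detector NET (float) and a detector count (int).
params = {"net_per_det": jnp.asarray(300e-6), "ndet": jnp.asarray(1000)}


def array_net(p):
    """Array NET for ndet identical, uncorrelated detectors."""
    return p["net_per_det"] / jnp.sqrt(p["ndet"])


g = grad_inexact_only(array_net, params)
print(g)
```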

Optical element types

Each element in the optical chain has a type field and computes its own emissivity and transmission from its physical properties. Quantities may carry units, written as strings such as "240 K":

elements:
  - forebaffle:
      temperature: "240 K"
      scatter_frac: 0.013
  - window:
      type: dielectric
      temperature: "260 K"
      thickness: "0.025 m"
      index: 1.5246
      loss_tangent: 1.38e-4
      reflection: 0.01
  - primary:
      type: mirror
      temperature: "273 K"
      conductivity: 3.6e7
  - aperture:
      type: aperture_stop
      temperature: "5.5 K"
      spillover: 0.15
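The sketch below gives some intuition for what the dielectric and mirror entries control. It evaluates two standard single-frequency approximations:

- absorption in a low-loss dielectric slab, emissivity ~ 1 - exp(-2 pi n tan(delta) nu t / c);
- ohmic loss in a metal mirror at normal incidence in the Hagen-Rubens limit, emissivity ~ sqrt(16 pi eps0 nu / sigma).

The helper names, the 150 GHz frequency, and the assumption that `conductivity` is in S/m are illustrative. bolojax's element models may use different or band-integrated forms, and may also account for reflection, spillover, and scattering.

```python
import math

C_LIGHT = 299_792_458.0  # speed of light, m/s
EPS0 = 8.8541878128e-12  # vacuum permittivity, F/m


def dielectric_emissivity(freq_hz, thickness_m, index, loss_tangent):
    """Absorptive emissivity of a low-loss dielectric slab at one frequency.

    Power absorption coefficient alpha = 2*pi*n*tan(delta)*nu/c; the absorbed
    fraction over the slab thickness is 1 - exp(-alpha*t). Reflection is ignored.
    Hypothetical helper, not part of bolojax.
    """
    alpha = 2.0 * math.pi * index * loss_tangent * freq_hz / C_LIGHT  # 1/m
    return 1.0 - math.exp(-alpha * thickness_m)


def mirror_emissivity(freq_hz, conductivity_s_per_m):
    """Ohmic emissivity of a metal mirror at normal incidence (Hagen-Rubens limit).

    epsilon ~ 4 R_s / Z0 = sqrt(16*pi*eps0*nu/sigma). Hypothetical helper.
    """
    return math.sqrt(16.0 * math.pi * EPS0 * freq_hz / conductivity_s_per_m)


nu = 150e9  # Hz, assumed observing frequency
# Values taken from the window and primary entries in the YAML above.
eps_window = dielectric_emissivity(nu, 0.025, 1.5246, 1.38e-4)
eps_primary = mirror_emissivity(nu, 3.6e7)
print(f"window emissivity  ~ {eps_window:.4f}")
print(f"primary emissivity ~ {eps_primary:.2e}")
```

In these approximations, the window's absorption grows roughly linearly with thickness and loss tangent, while the mirror's loss scales as the square root of frequency.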

License

BSD 3-Clause. See LICENSE for details.