JAX-LPT: a fast & differentiable code for Lagrangian Perturbation Theory

Axel Lapel - PhD @IAP

Supervision: Guilhem Lavaux, Pauline Zarrouk & Karim Benabed

Cosmological simulations

Statistical inference

Who is the intended audience for this talk?

Theory

  • Initial conditions
  • Structure formation
  • Cosmological model

Methodology

  • Forward modeling
  • Gradient-based inference

F. Leclercq et al. (2021)

Lagrangian Perturbation Theory

  • Analytical prescription for modeling the large-scale structure (LSS): perturbations in the matter density field drive structure formation in an expanding Universe.
    • Bridges the gap between linear theory and non-linear numerical codes
    • Tracks the (ballistic) Lagrangian trajectories of mass elements, mapping Lagrangian coordinates q to Eulerian coordinates x through the displacement field Ψ:
\boldsymbol{x}(\boldsymbol{q}, t) = \boldsymbol{q} + \boldsymbol{\Psi}(\boldsymbol{q}, t)
    • Relies on a perturbative expansion of the displacement field in powers of the linear growth factor D_+(t):
\boldsymbol{\Psi}(\boldsymbol{q}, t) = D_+(t) \boldsymbol{\Psi}^{(1)}(\boldsymbol{q}) + D_+^2(t) \boldsymbol{\Psi}^{(2)}(\boldsymbol{q}) + \ldots

Lagrangian Perturbation Theory

  • 1LPT (Zel'dovich approximation):
\boldsymbol{\Psi}^{(1)}(\boldsymbol{q}) = - \nabla_\boldsymbol{q} \phi^{(1)}(\boldsymbol{q}) \quad \text{with} \quad \Delta_\boldsymbol{q} \phi^{(1)}(\boldsymbol{q}) = \delta_{\text{ini}}(\boldsymbol{q})
    • Trajectories follow the gradient of a Lagrangian potential sourced by the initial density:
\boldsymbol{x}(\boldsymbol{q}, t) = \boldsymbol{q} - D_+(t) \nabla_\boldsymbol{q} \Delta^{-1}_\boldsymbol{q} \delta_{\text{ini}}(\boldsymbol{q})
    • \Delta^{-1}_\boldsymbol{q} \delta_{\text{ini}} is (up to a constant) the initial gravitational potential; D_+(t) encodes growth (matter content, expansion, gravity, ...)
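A minimal NumPy sketch of the 1LPT step, assuming a periodic cubic box with n cells per side: Δφ^(1) = δ_ini is inverted in Fourier space (Δ → -k²) and Ψ^(1) = -∇φ^(1) is obtained with one inverse FFT per component. The function names `wavevectors` and `zeldovich_displacement` are hypothetical; this illustrates the equations above and is not JAX-LPT's implementation.

```python
import numpy as np

# Illustrative sketch (assumption: periodic n^3 grid; not JAX-LPT's code)

def wavevectors(n, box_size):
    """Angular wavevector components on the rfftn grid of a periodic n^3 box."""
    k = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    k_half = 2 * np.pi * np.fft.rfftfreq(n, d=box_size / n)  # last axis is halved by rfftn
    return np.meshgrid(k, k, k_half, indexing="ij")

def zeldovich_displacement(delta_ini, box_size):
    """Psi^(1) = -grad(phi1) with Laplacian(phi1) = delta_ini, solved with FFTs."""
    kvec = wavevectors(delta_ini.shape[0], box_size)
    k2 = kvec[0] ** 2 + kvec[1] ** 2 + kvec[2] ** 2
    k2[0, 0, 0] = 1.0                       # avoid 0/0; the k = 0 mode is zeroed below
    phi1_k = -np.fft.rfftn(delta_ini) / k2  # Laplacian -> -k^2 in Fourier space
    phi1_k[0, 0, 0] = 0.0                   # the mean density generates no displacement
    # -grad -> -i k in Fourier space; one inverse FFT per Cartesian component
    return np.stack([np.fft.irfftn(-1j * ki * phi1_k, s=delta_ini.shape) for ki in kvec])
```

By construction ∇·Ψ^(1) = -δ_ini (linearized continuity), and the Zel'dovich positions follow as x = q + D_+(t) Ψ^(1).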

  • 2LPT:
\boldsymbol{\Psi}^{(2)}(\boldsymbol{q}) = \nabla_\boldsymbol{q} \phi^{(2)}(\boldsymbol{q}) \quad \text{with} \quad \Delta_\boldsymbol{q} \phi^{(2)}(\boldsymbol{q}) = -\dfrac{3}{7}\sum_{i>j}\left(\phi^{(1)}_{,ii}(\boldsymbol{q})\phi^{(1)}_{,jj}(\boldsymbol{q})-\left[\phi^{(1)}_{,ij}(\boldsymbol{q})\right]^2 \right)
    • Corrects the Zel'dovich displacements for gravitational tidal effects
    • More costly to compute, but a significant improvement in accuracy
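The second-order source term can be evaluated with the same spectral tools: the six derivatives φ^(1)_{,ij} come from multiplying by -k_i k_j in Fourier space, the sum over i > j is formed on the grid, and the Laplacian is inverted once more. The sketch below follows the signs of the equations above; the function name is hypothetical, a periodic cubic grid is assumed, and aliasing of the quadratic source term is ignored.

```python
import numpy as np

# Illustrative sketch (assumption: periodic n^3 grid; not JAX-LPT's code)

def second_order_displacement(delta_ini, box_size):
    """Return (Psi^(2), source) with Psi^(2) = grad(phi2) and
    Laplacian(phi2) = source = -3/7 * sum_{i>j} (phi1_ii phi1_jj - phi1_ij^2)."""
    n = delta_ini.shape[0]
    k = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    k_half = 2 * np.pi * np.fft.rfftfreq(n, d=box_size / n)
    kvec = np.meshgrid(k, k, k_half, indexing="ij")
    k2 = kvec[0] ** 2 + kvec[1] ** 2 + kvec[2] ** 2
    k2[0, 0, 0] = 1.0                         # avoid 0/0; k = 0 modes zeroed below
    phi1_k = -np.fft.rfftn(delta_ini) / k2    # Laplacian(phi1) = delta_ini
    phi1_k[0, 0, 0] = 0.0
    # Second derivatives phi1_{,ij}: (i k_i)(i k_j) = -k_i k_j in Fourier space
    d2 = {(i, j): np.fft.irfftn(-kvec[i] * kvec[j] * phi1_k, s=delta_ini.shape)
          for i in range(3) for j in range(i + 1)}
    source = np.zeros_like(delta_ini, dtype=float)
    for i in range(3):
        for j in range(i):                    # pairs with i > j
            source += d2[i, i] * d2[j, j] - d2[i, j] ** 2
    source *= -3.0 / 7.0
    phi2_k = -np.fft.rfftn(source) / k2       # invert Laplacian(phi2) = source
    phi2_k[0, 0, 0] = 0.0
    # Psi^(2) = +grad(phi2) -> +i k phi2_k
    psi2 = np.stack([np.fft.irfftn(1j * ki * phi2_k, s=delta_ini.shape) for ki in kvec])
    return psi2, source
```

For a planar (1D) perturbation every term of the sum vanishes, so Ψ^(2) = 0; the correction only appears when the field varies along at least two directions. Positions to second order then read x = q + D_+ Ψ^(1) + D_+² Ψ^(2).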

Lagrangian Perturbation Theory

Advantages

  • Fast analytical predictions
  • Non-linear structure formation model
  • Initial conditions for N-body simulations:
    • MUSIC (O. Hahn et al. 2011)
    • 2LPTic (R. Scoccimarro et al. 2006)
    • N-GenIC (V. Springel 2005)
      ...
  • BAO reconstruction

Limits

  • Breaks down at small scales
  • Higher-order calculations become increasingly involved
  • Shell-crossing

O. Hahn et al. (2016)
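To make the shell-crossing limit concrete, consider a 1D toy model (an assumption for illustration, not a JAX-LPT feature) with δ_ini(q) = A cos(kq). The 1LPT formulas give Ψ^(1)(q) = -(A/k) sin(kq), hence dx/dq = 1 - D_+ A cos(kq): once D_+ A reaches 1 the map q → x is no longer one-to-one, and the 1D density 1/(dx/dq) - 1 diverges. The hypothetical helpers below check this numerically.

```python
import numpy as np

# Illustrative 1D toy model (assumption: single cosine mode, periodic domain)

def zeldovich_1d(q, growth, amplitude, k):
    """1D Zel'dovich positions x = q + D_+ Psi^(1) for delta_ini(q) = amplitude * cos(k q)."""
    psi1 = -amplitude * np.sin(k * q) / k     # Psi^(1) = -d/dq (Delta^{-1} delta_ini)
    return q + growth * psi1

def has_shell_crossed(x):
    """Shell-crossing: the map q -> x stops being monotonic (dx/dq <= 0 somewhere)."""
    return bool(np.any(np.diff(x) <= 0))
```

With A = 0.5 the crossing is expected at D_+ = 1/A = 2: trajectories are still ordered at D_+ = 1.9 but have crossed at D_+ = 2.1, beyond which LPT no longer describes the multi-stream flow.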

Differentiability: why should we bother?

  • Modern cosmology hunts for information in the small, non-linear scales.
  • Complex inference problems require advanced techniques.
  • Differentiable cosmological simulations provide the gradients needed for optimization and gradient-based sampling.

Ex: the Hamiltonian Monte Carlo sampler

H(\mathbf{x},\mathbf{p})= -\ln \mathcal{L}(\mathbf{x}) + \dfrac{1}{2}\mathbf{p}^\text{T} M^{-1}\mathbf{p}
\dfrac{\text{d} x_i}{\text{d} t}=\dfrac{\partial H}{\partial p_i}, \qquad \dfrac{\text{d} p_i}{\text{d} t}=-\dfrac{\partial H}{\partial x_i}

An active field: rise of field-level inference (Y. Li et al. 2022; F. Lanusse et al.)
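Hamilton's equations above are usually integrated with the leapfrog scheme, followed by a Metropolis accept/reject step that corrects for the discretization error. The minimal NumPy sketch below assumes a diagonal mass matrix M = mass × identity and uses hypothetical function names; it is a generic textbook HMC, not BORG's sampler. It shows where the gradient of U = -ln L enters: this is exactly the quantity a differentiable forward model has to provide.

```python
import numpy as np

# Generic HMC sketch (assumption: M = mass * identity; not BORG's implementation)

def leapfrog(x, p, grad_u, step, n_steps, inv_mass):
    """Kick-drift-kick integration of dx/dt = M^-1 p, dp/dt = -grad U(x)."""
    p = p - 0.5 * step * grad_u(x)          # half kick
    for i in range(n_steps):
        x = x + step * inv_mass * p         # drift: dx/dt = dH/dp
        if i < n_steps - 1:
            p = p - step * grad_u(x)        # full kick: dp/dt = -dH/dx
    p = p - 0.5 * step * grad_u(x)          # final half kick
    return x, p

def hmc_sample(u, grad_u, x0, n_samples, step=0.1, n_steps=20, mass=1.0, seed=0):
    """Draw samples from exp(-U(x)), where U(x) = -ln L(x)."""
    rng = np.random.default_rng(seed)
    inv_mass = 1.0 / mass
    x = np.asarray(x0, dtype=float)
    samples = np.empty((n_samples,) + x.shape)
    for s in range(n_samples):
        p = rng.normal(scale=np.sqrt(mass), size=x.shape)   # p ~ N(0, M)
        h_old = u(x) + 0.5 * inv_mass * np.sum(p**2)
        x_new, p_new = leapfrog(x, p, grad_u, step, n_steps, inv_mass)
        h_new = u(x_new) + 0.5 * inv_mass * np.sum(p_new**2)
        # Metropolis step: accept with probability min(1, exp(H_old - H_new))
        if rng.uniform() < np.exp(min(0.0, h_old - h_new)):
            x = x_new
        samples[s] = x
    return samples
```

In a field-level setting x would be the initial density field, and `grad_u` would be obtained by differentiating the likelihood through the forward model, for instance with JAX's `jax.grad`.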

How to extract (maximum) information from cosmological fields?

  • 2pt statistics (Planck Collaboration)
  • 3pt statistics (adapted from C. Hahn et al. 2020)
  • Field-level inference: a physical forward model maps the initial conditions to the final field, and statistical inference is performed on the field itself. The forward model must be fast and differentiable.

An example framework: BORG (Bayesian Origin Reconstruction from Galaxies)

  • (Analytical) differentiable model linking initial conditions, final field(s) and observations
  • HMC sampler

    BORG infers:
    1. Initial conditions
    2. Bias parameters
    3. Cosmological parameters

Figure courtesy of F. Leclercq; Jasche & Wandelt (2013), Jasche & Lavaux (2019)

JAX-LPT for a versatile and modular accelerated gravity model

JAX-LPT: in a nutshell

Pipeline: cosmology (\Omega_m, \Omega_b, \Omega_\Lambda, H_0, \sigma_8, n_s, + f_{R0}, \sum m_\nu) → transfer functions → initial density → 1LPT, 2LPT displacements (with growth) → mass assignment → final density, differentiable through autodiff.

  • Lagrangian displacements (1st & 2nd order)
  • Automatic differentiation
  • Fast generation of multiple realizations with JIT compilation
  • Acceleration through GPU support
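The mass-assignment step turns the displaced particles x = q + Ψ back into a density field on the mesh. The slides do not say which scheme JAX-LPT uses; the NumPy sketch below assumes cloud-in-cell (CIC), a common choice, with unit-mass particles, a periodic cubic box and mesh nodes at x = i·L/n. The function name `cic_mass` is hypothetical.

```python
import numpy as np

# Illustrative CIC sketch (assumption: JAX-LPT's actual scheme may differ)

def cic_mass(positions, n, box_size):
    """Cloud-in-cell assignment of unit-mass particles to a periodic n^3 mesh.

    positions: (N, 3) array; mesh nodes sit at x = i * box_size / n."""
    s = np.asarray(positions, dtype=float) / (box_size / n)  # positions in cell units
    i0 = np.floor(s).astype(int)                             # lower neighbouring node
    f = s - i0                                               # fractional offset in [0, 1)
    rho = np.zeros((n, n, n))
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                # Trilinear weight of node i0 + (dx, dy, dz); the 8 weights sum to 1
                w = ((f[:, 0] if dx else 1 - f[:, 0])
                     * (f[:, 1] if dy else 1 - f[:, 1])
                     * (f[:, 2] if dz else 1 - f[:, 2]))
                idx = ((i0[:, 0] + dx) % n, (i0[:, 1] + dy) % n, (i0[:, 2] + dz) % n)
                np.add.at(rho, idx, w)                       # unbuffered scatter-add
    return rho
```

The density contrast then follows as `rho / rho.mean() - 1`. Because each particle's eight weights sum to one, total mass is conserved exactly. In JAX, the scatter-add would typically be written with the `.at[...].add(...)` indexed-update syntax, which is traceable by `jax.grad`.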

JAX-LPT: example run

# 1. Modules
from jax_lpt.lpt import JaxLptSolver, Jax2LptSolver
from jax_lpt.simgrid import Box
from jax_lpt.cosmology import CosmologicalParameters
from jax_lpt.utils import generate_initial_density

# 2. Initialization
# Simulation box: side length and mesh resolution
Lbox = 1024.0
Nmesh = 128
box = Box(Lbox, Nmesh)

# Scale-factors (initial and final)
a_init = 0.01
a_final = 1.0

# Fiducial cosmology (CosmologicalParameters is imported directly above)
cosmo = CosmologicalParameters()

# JAX-LPT gravity models
jax_1lpt = JaxLptSolver(box, cosmo, a_init, a_final)
jax_2lpt = Jax2LptSolver(box, cosmo, a_init, a_final)

# Initial density field
delta_init = generate_initial_density(Lbox, Nmesh, cosmo, a_init)

# 3. Forward run
# Evolved density fields
delta_1lpt = jax_1lpt.run(delta_init)
delta_2lpt = jax_2lpt.run(delta_init)

JAX-LPT: simulation products

Power spectra

Density fields

JAX-LPT: diagnostics - power spectrum

Panels: 1LPT, 2LPT

  • 1LPT: agreement at the ~1 ppb (part per billion) level
  • 2LPT: agreement at the ~per-mille level

The 2pt statistics of JAX-LPT are consistent with those of the BORG LPT models.
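Comparisons like this one rely on a binned estimate of P(k) = ⟨|δ_k|²⟩/V, the relative difference between two codes being P_a(k)/P_b(k) - 1 bin by bin. The NumPy sketch below uses one common convention (continuous-FT normalization, spherical bins of width k_F = 2π/L, no shot-noise or mass-assignment window correction); the function name is hypothetical and JAX-LPT's own estimator may differ.

```python
import numpy as np

# Illustrative P(k) estimator (assumption: conventions may differ from JAX-LPT's)

def power_spectrum(delta, box_size):
    """Spherically averaged P(k) of a periodic n^3 field, in bins of width k_F = 2*pi/box_size.

    Returns (mean k per bin, P(k) per bin, number of modes per bin)."""
    n = delta.shape[0]
    volume = box_size ** 3
    delta_k = np.fft.fftn(delta) * (box_size / n) ** 3       # approximates the continuous FT
    pk_modes = (np.abs(delta_k) ** 2 / volume).ravel()       # |delta_k|^2 / V for each mode
    k1d = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    kx, ky, kz = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    kmag = np.sqrt(kx**2 + ky**2 + kz**2).ravel()
    k_fund = 2 * np.pi / box_size
    edges = k_fund * np.arange(0.5, n // 2 + 1.0)            # bins centred on 1, 2, ..., n/2 times k_F
    n_bins = len(edges) - 1
    idx = np.digitize(kmag, edges) - 1
    valid = (idx >= 0) & (idx < n_bins)                      # drops k = 0 and modes beyond the last edge
    counts = np.bincount(idx[valid], minlength=n_bins)
    pk_sum = np.bincount(idx[valid], weights=pk_modes[valid], minlength=n_bins)
    k_sum = np.bincount(idx[valid], weights=kmag[valid], minlength=n_bins)
    safe = np.maximum(counts, 1)
    return k_sum / safe, pk_sum / safe, counts
```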


JAX-LPT: diagnostics - bispectrum

Triangle configurations (k_1, k_2, k_3); panels: 1LPT, 2LPT

3-point statistics also match those of the BORG LPT models.

  • Non-Gaussian information
  • Field-level prescription for late-time cosmological fields, valid up to shell-crossing
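For 3-point diagnostics, a standard FFT-based estimator averages δ_{k1}δ_{k2}δ_{k3} over all closed triangles k1 + k2 + k3 = 0 whose sides fall into thin shells around the target lengths. The sketch below implements that idea in NumPy for a single triangle configuration; the function name, the shell width `dk` and the normalization convention are assumptions, and the estimator used for the JAX-LPT/BORG comparison may differ.

```python
import numpy as np

# Illustrative FFT shell estimator (assumption: not the estimator used in the slides)

def bispectrum(delta, box_size, k1, k2, k3, dk):
    """Estimate B(k1, k2, k3) of a periodic n^3 field from modes in shells of width dk."""
    n = delta.shape[0]
    n_cells = n ** 3
    volume = box_size ** 3
    k1d = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    kx, ky, kz = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    kmag = np.sqrt(kx**2 + ky**2 + kz**2)
    delta_k = np.fft.fftn(delta)

    def shell(field_k, k):
        # Keep only modes with | |k| - k | < dk/2, then transform back to real space
        return np.fft.ifftn(np.where(np.abs(kmag - k) < dk / 2, field_k, 0.0)).real

    ones = np.ones_like(delta_k)
    # sum_x d1*d2*d3 = (sum over closed triangles of delta_k products) / n_cells^2
    triple = np.sum(shell(delta_k, k1) * shell(delta_k, k2) * shell(delta_k, k3))
    # sum_x i1*i2*i3 = (number of closed triangles) / n_cells^2
    n_triangles = np.sum(shell(ones, k1) * shell(ones, k2) * shell(ones, k3))
    # Mean raw-FFT triple product -> continuous-FT units (V/N)^3, divided by V
    return triple / n_triangles * volume ** 2 / n_cells ** 3
```

For a Gaussian field the estimate scatters around zero, so a non-zero B directly measures the non-Gaussian information generated by gravitational evolution.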


JAX-LPT: ongoing projects

Field-level inference with f(R) gravity

(A. Lapel, D. Bartlett, H. Desmond, A. Kostić)

Neural N-body emulator from LPT displacements

(L. Doeser, D. Jamieson, J. Jasche, G. Lavaux)

Field-level signature of massive neutrinos

(A. Lapel, G. Lavaux, P. Zarrouk, K. Benabed)

Feel free to contact me if you want to play with JAX-LPT or its branches

axel.lapel@iap.fr
