S2SCAT is a Python package for computing scattering covariances on the sphere (Mousset et al. 2024) using JAX. It exploits autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs), leveraging the differentiable and accelerated spherical harmonic and wavelet transforms implemented in S2FFT and S2WAV, respectively. Scattering covariances are useful both for field-level generative modelling of complex non-Gaussian textures and for statistical compression of high-dimensional field-level data, a key step in, for example, simulation-based inference.
Important
It is worth highlighting that the inputs to S2SCAT are spherical harmonic coefficients, which can be generated with whichever software package you prefer, e.g. S2FFT or healpy. Just ensure your harmonic coefficients are indexed using our convention; helper functions for this reindexing can be found in S2FFT.
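To illustrate what that reindexing involves, here is a minimal sketch. It assumes that S2FFT's two-dimensional layout stores a_{ℓm} at position [ℓ, L−1+m] of an (L, 2L−1) array. It also assumes the input comes from healpy, which stores only the m ≥ 0 coefficients of a real field in its own ordering. The function names `healpy_index` and `healpy_to_2d` are hypothetical and not part of S2FFT or S2SCAT. In practice, prefer the reindexing helpers that ship with S2FFT.

```python
import numpy as np


def healpy_index(ell: int, m: int, ell_max: int) -> int:
    """Position of a_{ell m} (m >= 0) in a healpy-ordered alm array.

    Same ordering as healpy's Alm.getidx: m-major, then ell.
    """
    return m * (2 * ell_max + 1 - m) // 2 + ell


def healpy_to_2d(alm_hp: np.ndarray, L: int) -> np.ndarray:
    """Expand healpy (m >= 0) coefficients of a REAL field to an (L, 2L-1) array.

    Assumed layout: flm[ell, L - 1 + m]; entries with |m| > ell stay zero.
    Negative m are filled from the reality condition
    a_{ell,-m} = (-1)^m conj(a_{ell,m}).
    """
    ell_max = L - 1
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for ell in range(L):
        for m in range(ell + 1):
            value = alm_hp[healpy_index(ell, m, ell_max)]
            flm[ell, L - 1 + m] = value
            if m > 0:
                flm[ell, L - 1 - m] = (-1) ** m * np.conj(value)
    return flm
```

For example, with L = 3 a healpy array holds (L)(L+1)/2 = 6 coefficients, and `healpy_to_2d(alm_hp, 3)` returns a (3, 5) complex array.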
Tip
At launch S2SCAT provides two core transform modes: on-the-fly, which performs the underlying spherical harmonic and Wigner transforms through the Price & McEwen recursion; and precompute, which a priori computes and caches all Wigner elements required. The precompute approach will be faster but can only be run up to
Ballpark compute times (when running on a 40GB A100 GPU) and compression levels are given in the table below.
| Method | Resolution | Forward pass | Gradient pass | JIT compilation | Input parameters | Anisotropic statistics (compression) | Isotropic statistics (compression) |
|---|---|---|---|---|---|---|---|
| Precompute | L=512, N=3 | ~90 ms | ~190 ms | ~20 s | 2,618,880 | ~63,000 (97.594%) | ~504 (99.981%) |
| On-the-fly | L=2048, N=3 | ~18 s | ~40 s | ~5 min | 41,932,800 | ~123,750 (99.705%) | ~990 (99.998%) |
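The compression percentages in the table appear to be computed as 1 − (number of output statistics)/(number of input parameters). The short check below recomputes them from the table's own figures. The `compression` function is a name chosen here for illustration and is not part of S2SCAT.

```python
def compression(n_statistics: int, n_inputs: int) -> float:
    """Fraction of the input size removed when n_inputs values become n_statistics."""
    return 1.0 - n_statistics / n_inputs


# (number of statistics, number of input parameters), read from the table above.
table_rows = {
    "Precompute, anisotropic": (63_000, 2_618_880),
    "Precompute, isotropic": (504, 2_618_880),
    "On-the-fly, anisotropic": (123_750, 41_932_800),
    "On-the-fly, isotropic": (990, 41_932_800),
}

for label, (n_stats, n_inputs) in table_rows.items():
    print(f"{label}: {100 * compression(n_stats, n_inputs):.3f}%")
```

Running it reproduces the table's percentages: 97.594%, 99.981%, 99.705% and 99.998%.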
Note that these times are for unbatched evaluations, so in practice throughput may be substantially higher. For example, with a large batch size at
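The trade-off between the two transform modes can be made concrete with a toy analogue. The example uses the standard three-term recursion for Legendre polynomials, not the Price & McEwen Wigner recursion that S2SCAT actually uses. The on-the-fly function reruns the recursion on every call and holds only two degrees in memory. The precomputed class evaluates every degree once and stores the whole table. The names `legendre_on_the_fly` and `PrecomputedLegendre` are hypothetical and unrelated to S2SCAT's API.

```python
import numpy as np


def legendre_on_the_fly(ell: int, x) -> np.ndarray:
    """On-the-fly mode: rerun the recursion up to degree ell on every call.

    Only the two most recent degrees are held in memory at any time.
    """
    x = np.asarray(x, dtype=float)
    p_prev, p_curr = np.ones_like(x), x
    if ell == 0:
        return p_prev
    for k in range(1, ell):
        # (k + 1) P_{k+1}(x) = (2k + 1) x P_k(x) - k P_{k-1}(x)
        p_prev, p_curr = p_curr, ((2 * k + 1) * x * p_curr - k * p_prev) / (k + 1)
    return p_curr


class PrecomputedLegendre:
    """Precompute mode: evaluate every degree once and cache the full table."""

    def __init__(self, ell_max: int, x) -> None:
        x = np.asarray(x, dtype=float)
        # Memory cost paid up front: (ell_max + 1) * x.size floats.
        self.table = np.empty((ell_max + 1,) + x.shape)
        self.table[0] = 1.0
        if ell_max >= 1:
            self.table[1] = x
        for k in range(1, ell_max):
            self.table[k + 1] = (
                (2 * k + 1) * x * self.table[k] - k * self.table[k - 1]
            ) / (k + 1)

    def __call__(self, ell: int) -> np.ndarray:
        # Each lookup is a cheap index into the cached table.
        return self.table[ell]


# Both modes agree; they differ only in when the work and memory are spent.
x = np.linspace(-1.0, 1.0, 5)
cached = PrecomputedLegendre(ell_max=8, x=x)
assert all(np.allclose(cached(ell), legendre_on_the_fly(ell, x)) for ell in range(9))
```

The cached table grows with the bandlimit and the number of evaluation points. That growth is plausibly a rough analogue of why caching every Wigner element limits the resolution the precompute mode can reach.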
We introduce scattering covariances on the sphere in Mousset et al. (2024), extending to the spherical setting the analogous scattering statistics introduced for 1D signals by Morel et al. (2023) and for planar 2D signals by Cheng et al. (2023). Scattering covariances
where
This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scales and directions, which can effectively capture complex non-Gaussian structural information, e.g. filamentary structure.
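As a rough guide, the scattering covariance families in the 1D and planar constructions of Morel et al. (2023) and Cheng et al. (2023) take the form below. This is not necessarily the exact normalisation or set of conventions used by S2SCAT. The notation is assumed here for illustration. W^λ I denotes the wavelet coefficients of the field I at scale j and direction γ, grouped into a single label λ = (j, γ). ⟨·⟩ denotes a spatial average over the sphere.

```math
\begin{aligned}
S_1^{\lambda} &= \big\langle |W^{\lambda} I| \big\rangle, \\
S_2^{\lambda} &= \big\langle |W^{\lambda} I|^2 \big\rangle, \\
S_3^{\lambda_1, \lambda_2} &= \mathrm{Cov}\big[\, W^{\lambda_1} I,\; W^{\lambda_1} |W^{\lambda_2} I| \,\big], \\
S_4^{\lambda_1, \lambda_2, \lambda_3} &= \mathrm{Cov}\big[\, W^{\lambda_1} |W^{\lambda_3} I|,\; W^{\lambda_1} |W^{\lambda_2} I| \,\big].
\end{aligned}
```

Roughly, S2 tracks the power per scale and direction. The ratio S1 / sqrt(S2) is smaller for sparser coefficient maps. S3 and S4 encode the couplings between scales and directions.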
In the S2SCAT code we use the recently released JAX spherical harmonic code S2FFT (Price & McEwen 2024) and spherical wavelet transform code S2WAV (Price et al. 2024) to extend scattering covariances to the sphere, as required for their application to generative modelling of wide-field cosmological fields (Mousset et al. 2024).
Importing and using S2SCAT is as simple as follows:
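```python
import jax
import s2scat

L = 512   # Harmonic bandlimit (illustrative value).
N = 3     # Directional (azimuthal) bandlimit of the wavelets (illustrative value).
seed = 0  # Random seed (illustrative value).
# alm: spherical harmonic coefficients of your field, in S2FFT indexing.

# For statistical compression
encoder = s2scat.build_encoder(L, N)  # Returns a callable compression model.
covariance_statistics = encoder(alm)  # Generate statistics (can be batched).

# For generative modelling
key = jax.random.PRNGKey(seed)
generator = s2scat.build_generator(alm, L, N)  # Returns a callable generative model.
new_samples = generator(key, 10)  # Generate 10 new spherical textures.
```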
For further details on usage see the documentation and associated notebooks.
```
s2scat/
├── representation.py     # - Scattering covariance transform.
├── compression.py        # - Statistical compression functions.
├── optimisation.py       # - Optimisation algorithm wrappers.
├── generation.py         # - Latent encoder and generative decoder.
│
├── operators/            # Internal functionality:
│   ├─ spherical.py       # - Specific spherical operations, e.g. batched SHTs.
│   ├─ matrices.py        # - Wrappers to generate cached values.
│
├── utility/              # Convenience functionality:
│   ├─ reorder.py         # - Reindexing and converting lists and arrays.
│   ├─ statistics.py      # - Calculation of covariance statistics.
│   ├─ normalisation.py   # - Normalisation functions for covariance statistics.
│   ├─ plotting.py        # - Plotting functions for signals and statistics.
```
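The minimal NumPy sketch below shows the kind of building blocks involved in calculating covariance statistics: the mean modulus, the mean power, and the complex covariance of two coefficient maps. The function names and signatures are hypothetical and are not the API of s2scat/utility/statistics.py. The maps are assumed to be already-computed wavelet coefficients flattened to 1D. The optional weights stand in for pixel areas, since a plain mean is only appropriate for equal-area pixelisations.

```python
import numpy as np


def weighted_mean(x, weights=None):
    """Spatial average; pass pixel-area or quadrature weights for non-equal-area grids."""
    return np.average(x, weights=weights)


def s1(w, weights=None) -> float:
    """Mean modulus of wavelet coefficients, <|W I|>."""
    return float(weighted_mean(np.abs(w), weights))


def s2(w, weights=None) -> float:
    """Mean squared modulus (power) of wavelet coefficients, <|W I|^2>."""
    return float(weighted_mean(np.abs(w) ** 2, weights))


def cov(a, b, weights=None) -> complex:
    """Complex covariance Cov[a, b] = <a conj(b)> - <a> conj(<b>)."""
    mean_ab = weighted_mean(a * np.conj(b), weights)
    mean_a = weighted_mean(a, weights)
    mean_b = weighted_mean(b, weights)
    return complex(mean_ab - mean_a * np.conj(mean_b))


# Toy coefficient map with zero mean: its covariance with itself equals its power.
w = np.array([1 + 0j, -1 + 0j, 2j, -2j])
print(s1(w), s2(w), cov(w, w))
```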
The Python dependencies for the S2SCAT package are listed in the file requirements/requirements-core.txt and will be installed automatically into the active Python environment by pip when running

```bash
pip install s2scat
```

This installs all core functionality, including full JAX support.
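For a quick check that is lighter than running the unit tests, the standard library can report which relevant distributions are installed. This is a small sketch that relies only on `importlib.metadata`. The helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # S2SCAT itself plus the JAX packages it builds on.
    for name in ("s2scat", "jax", "jaxlib"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```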
Alternatively, the S2SCAT package may be installed directly from GitHub by cloning this repository and then running

```bash
pip install .
```

from the root directory of the repository.
Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest:

```bash
pip install -r requirements/requirements-tests.txt
pytest tests/
```
Documentation for the released version is available here.
Matt Price
mousset
Jason McEwen
Eralys
Should this code be used in any way, we kindly request that the following article is referenced. A BibTeX entry for this reference may look like:
```bibtex
@article{mousset:s2scat,
    author = "Louise Mousset and others",
    title = "TBD",
    journal = "TBD, submitted",
    year = "2024",
    eprint = "TBD"
}
```
You might also like to consider citing our related papers on which this code builds:
```bibtex
@article{price:s2fft,
    author = "Matthew A. Price and Jason D. McEwen",
    title = "Differentiable and accelerated spherical harmonic and {W}igner transforms",
    journal = "Journal of Computational Physics",
    volume = "510",
    pages = "113109",
    year = "2024",
    doi = {10.1016/j.jcp.2024.113109},
    eprint = "arXiv:2311.14670"
}

@article{price:s2wav,
    author = "Matthew A. Price and Alicja Polanska and Jessica Whitney and Jason D. McEwen",
    title = "Differentiable and accelerated directional wavelet transform on the sphere and ball",
    year = "2024",
    eprint = "arXiv:2402.01282"
}
```
We provide this code under an MIT open-source licence with the hope that it will be of use to a wider community.
Copyright 2024 Louise Mousset, Matthew Price, Erwan Allys and Jason McEwen
S2SCAT is free software made available under the MIT License. For details, see the LICENSE file.