
Awesome JAX

JAX - A curated list of resources https://github.com/google/jax


Last Update: Nov. 30, 2021, 11:19 a.m.

Thank you n2cholas & contributors
View Topic on GitHub:
n2cholas/awesome-jax


Libraries

Flax is a neural network library for JAX that is designed for flexibility.
2.26K stars · 270 forks · last commit 27d ago · Apache-2.0

JAX-based neural network library
1.45K stars · 113 forks · last commit 27d ago · Apache-2.0

Trax - Deep Learning with Clear Code and Speed
6.55K stars · 659 forks · last commit 34d ago · Apache-2.0

A Graph Neural Network Library in Jax
719 stars · 38 forks · last commit 41d ago · Apache-2.0

Fast and Easy Infinite Neural Networks in Python
1.58K stars · 179 forks · last commit 32d ago · license n/a

🤗 Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX.
53.47K stars · 12.7K forks · last commit 27d ago · Apache-2.0

Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
133 stars · 9 forks · last commit 28d ago · Apache-2.0
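The "callable PyTrees" idea above builds on JAX's own pytree machinery. As background, a minimal sketch of that machinery in plain JAX (the nested `params` structure is illustrative, not this library's API):

```python
import jax
import jax.numpy as jnp

# Any nested structure of dicts/lists/tuples of arrays is a pytree.
params = {
    "layer1": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
    "layer2": {"w": jnp.ones((2, 1)), "b": jnp.zeros(1)},
}

# tree_map applies a function leaf-by-leaf, preserving the nesting.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Leaves can be enumerated regardless of nesting depth.
num_leaves = len(jax.tree_util.tree_leaves(params))
```

Because whole models are pytrees, the same `tree_map`-style operations that work on parameter dicts work on them too.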

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
1.15K stars · 118 forks · last commit 26d ago · Apache-2.0

Optax is a gradient processing and optimization library for JAX.
511 stars · 45 forks · last commit 33d ago · Apache-2.0
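Gradient-processing libraries like this one package a simple pattern: compute gradients, transform them, apply the result to the parameters. A plain-JAX sketch of that pattern as an SGD loop (the `sgd_update` helper and the linear model are illustrative, not Optax's API):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Simple linear model: predict y = w * x + b, mean squared error.
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

def sgd_update(params, grads, learning_rate=0.1):
    # The core pattern: scale the gradients, then apply them leaf-by-leaf.
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x + 1.0  # true relation: w = 2, b = 1

for _ in range(200):
    grads = jax.grad(loss)(params, x, y)
    params = sgd_update(params, grads)

final_loss = loss(params, x, y)
```

Because parameters and gradients are pytrees of matching structure, the same update function works unchanged for arbitrarily nested models.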

Differentiable, Hardware Accelerated, Molecular Dynamics
571 stars · 90 forks · last commit 27d ago · license n/a

Modular framework for Reinforcement Learning in Python
60 stars · 4 forks · last commit 27d ago · MIT

Documentation:
95 stars · 5 forks · last commit 11m ago · Apache-2.0

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
314 stars · 18 forks · last commit 4m ago · Apache-2.0

Differentiable convex optimization layers
1.12K stars · 83 forks · last commit 7m ago · Apache-2.0

TensorLy: Tensor Learning in Python.
1.11K stars · 220 forks · last commit 27d ago · license n/a

Machine learning algorithms for many-body quantum systems
302 stars · 121 forks · last commit 27d ago · Apache-2.0

New Libraries

FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.
137 stars · 23 forks · last commit 33d ago · Apache-2.0

A library for programmatically generating equivariant layers through constraint solving
153 stars · 9 forks · last commit 27d ago · MIT

Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
50 stars · 2 forks · last commit 111d ago · MIT

Reimplementation of the UniRep protein featurization model.
80 stars · 22 forks · last commit 47d ago · GPL-3.0

Normalizing Flows in JAX 🌊
164 stars · 9 forks · last commit 83d ago · MIT

Composable kernels for scikit-learn implemented in JAX.
26 stars · 2 forks · last commit 1y 35d ago · license n/a

A differentiable cosmology library in JAX
70 stars · 13 forks · last commit 31d ago · MIT

Exponential families for JAX
15 stars · 1 fork · last commit 47d ago · MIT

Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python
103 stars · 6 forks · last commit 90d ago · MIT

Image augmentation library for Jax
16 stars · 1 fork · last commit 6m ago · Apache-2.0

A selection of neural network models ported from torchvision for JAX & Flax.
15 stars · 0 forks · last commit 11m ago · Apache-2.0

Probabilistic reasoning and statistical analysis in TensorFlow
3.52K stars · 944 forks · last commit 26d ago · Apache-2.0

A photovoltaic simulator with automatic differentiation
6 stars · 0 forks · last commit 4m ago · MIT

Lie groups for JAX!
94 stars · 4 forks · last commit 91d ago · MIT

Massively parallel rigidbody physics simulation on accelerator hardware.
1.01K stars · 84 forks · last commit 28d ago · Apache-2.0

Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.
72 stars · 3 forks · last commit 45d ago · license n/a

Functional models and algorithms for sparse signal processing
15 stars · 0 forks · last commit 29d ago · Apache-2.0

Automatic differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
12 stars · 4 forks · last commit 27d ago · MIT

Hardware accelerated, batchable and differentiable optimizers in JAX.
298 stars · 14 forks · last commit 28d ago · Apache-2.0

PIX is an image processing library in JAX, for JAX.
193 stars · 10 forks · last commit 29d ago · Apache-2.0

Bayesian Optimization Python Library powered by JAX
12 stars · 0 forks · last commit 28d ago · MIT

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
3 stars · 0 forks · last commit 14d ago · LGPL-3.0

JAX

Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
650 stars · 52 forks · last commit 10m ago · MIT
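The central trick of that paper is a random Fourier feature mapping, gamma(v) = [cos(2*pi*B*v), sin(2*pi*B*v)], applied to low-dimensional input coordinates before an MLP. A hedged sketch in JAX (the number of frequencies and the bandwidth `sigma` are illustrative choices, not values from the paper's code):

```python
import jax
import jax.numpy as jnp

def fourier_features(v, B):
    # gamma(v) = [cos(2*pi*B v), sin(2*pi*B v)]
    # maps low-dimensional coordinates to high-frequency features.
    proj = 2.0 * jnp.pi * v @ B.T
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)

key = jax.random.PRNGKey(0)
sigma = 10.0  # bandwidth: larger sigma -> higher-frequency features
B = sigma * jax.random.normal(key, (64, 2))  # 64 random frequencies, 2-D input

coords = jnp.array([[0.1, 0.2], [0.5, 0.9]])  # e.g. pixel coordinates in [0, 1]
feats = fourier_features(coords, B)  # shape (2, 128)
```

An MLP trained on `feats` instead of raw `coords` can then represent high-frequency signals (images, shapes) that it would otherwise smooth out.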

Approximate inference for Markov Gaussian processes using iterated Kalman smoothing, in JAX
76 stars · 11 forks · last commit 8m ago · Apache-2.0

A didactic Gaussian process package for researchers in Jax.
24 stars · 5 forks · last commit 94d ago · Apache-2.0

Nested sampling in JAX
43 stars · 3 forks · last commit 29d ago · CC-BY-SA-4.0

Google Research
20.17K stars · 4.59K forks · last commit 27d ago · Apache-2.0

Flax

Google Research
20.17K stars · 4.59K forks · last commit 27d ago · Apache-2.0

Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper.
1.27K stars · 148 forks · last commit 97d ago · Apache-2.0

JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.
230 stars · 12 forks · last commit 29d ago · MIT

Flax implementation of gMLP from "Pay Attention to MLPs" (https://arxiv.org/abs/2105.08050)
2 stars · 0 forks · last commit 5m ago · MIT

Minimal Flax implementation of MLP-Mixer from "MLP-Mixer: An all-MLP Architecture for Vision" (https://arxiv.org/abs/2105.01601)
5 stars · 2 forks · last commit 5m ago · MIT

Aggregating Nested Transformer https://arxiv.org/pdf/2105.12723.pdf
79 stars · 8 forks · last commit 4m ago · Apache-2.0

FID computation in Jax/Flax.
2 stars · 0 forks · last commit 95d ago · Apache-2.0

Haiku

Open source code for AlphaFold.
7.01K stars · 1.03K forks · last commit 28d ago · Apache-2.0

This repository contains implementations and illustrative code to accompany DeepMind publications
8.7K stars · 1.72K forks · last commit 35d ago · Apache-2.0

Normalizing Flows using JAX
60 stars · 3 forks · last commit 39d ago · license n/a

Google Research
20.17K stars · 4.59K forks · last commit 27d ago · Apache-2.0

Trax

Trax - Deep Learning with Clear Code and Speed
6.55K stars · 659 forks · last commit 34d ago · Apache-2.0

Videos

Papers

Tutorials and Blog Posts

How to use the Flax Linen API to build a convolutional neural network model and train it for image classification (using TensorFlow Datasets).
5 stars · 1 fork · last commit 5m ago · license n/a

Extending JAX with custom C++ and CUDA code
165 stars · 7 forks · last commit 10m ago · MIT

Community

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU/CPU, and more
14.84K stars · 1.38K forks · last commit 26d ago · Apache-2.0
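The three transformations named in that description (differentiate, vectorize, JIT-compile) compose freely. A minimal sketch, assuming only `jax` is installed:

```python
import jax
import jax.numpy as jnp

def loss(x):
    # A simple scalar function: f(x) = sum(x**2)
    return jnp.sum(x ** 2)

# grad: differentiate -> df/dx = 2x
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA
fast_grad = jax.jit(grad_loss)

# vmap: vectorize over a batch of inputs without writing a loop
batched_grad = jax.vmap(fast_grad)

x = jnp.array([1.0, 2.0, 3.0])
g = fast_grad(x)                 # gradient of one input: [2., 4., 6.]
batch = jnp.stack([x, 2 * x])
bg = batched_grad(batch)         # gradients for each row of the batch
```

Most libraries on this page are built by layering their own abstractions on top of exactly these transformations.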