{"slug": "n2cholas--awesome-jax", "title": "Jax", "description": "JAX - A curated list of resources https://github.com/google/jax", "github_url": "https://github.com/n2cholas/awesome-jax", "stars": "2K", "tag": "Computer Science", "entry_count": 200, "subcategory_count": 8, "subcategories": [{"name": "General", "parent": "", "entries": [{"name": "Libraries", "url": "#libraries", "description": ""}, {"name": "Models and Projects", "url": "#models-and-projects", "description": ""}, {"name": "Videos", "url": "#videos", "description": ""}, {"name": "Papers", "url": "#papers", "description": ""}, {"name": "Tutorials and Blog Posts", "url": "#tutorials-and-blog-posts", "description": ""}, {"name": "Books", "url": "#books", "description": ""}, {"name": "Community", "url": "#community", "description": ""}, {"name": "Levanter", "url": "https://github.com/stanford-crfm/levanter", "description": "Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX. ", "stars": "690"}, {"name": "EasyLM", "url": "https://github.com/young-geng/EasyLM", "description": "LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax. ", "stars": "2.5k"}, {"name": "NumPyro", "url": "https://github.com/pyro-ppl/numpyro", "description": "Probabilistic programming based on the Pyro library. ", "stars": "2.6k"}, {"name": "Chex", "url": "https://github.com/deepmind/chex", "description": "Utilities to write and test reliable JAX code. ", "stars": "918"}, {"name": "Optax", "url": "https://github.com/deepmind/optax", "description": "Gradient processing and optimization library. ", "stars": "2.2k"}, {"name": "RLax", "url": "https://github.com/deepmind/rlax", "description": "Library for implementing reinforcement learning agents. ", "stars": "1.4k"}, {"name": "JAX, M.D.", "url": "https://github.com/google/jax-md", "description": "Accelerated, differential molecular dynamics. 
", "stars": "1.4k"}, {"name": "Coax", "url": "https://github.com/coax-dev/coax", "description": "Turn RL papers into code, the easy way. ", "stars": "181"}, {"name": "Distrax", "url": "https://github.com/deepmind/distrax", "description": "Reimplementation of TensorFlow Probability, containing probability distributions and bijectors. ", "stars": "615"}, {"name": "cvxpylayers", "url": "https://github.com/cvxgrp/cvxpylayers", "description": "Construct differentiable convex optimization layers. ", "stars": "2k"}, {"name": "TensorLy", "url": "https://github.com/tensorly/tensorly", "description": "Tensor learning made simple. ", "stars": "1.7k"}, {"name": "NetKet", "url": "https://github.com/netket/netket", "description": "Machine Learning toolbox for Quantum Physics. ", "stars": "655"}, {"name": "Fortuna", "url": "https://github.com/awslabs/fortuna", "description": "AWS library for Uncertainty Quantification in Deep Learning. ", "stars": "924"}, {"name": "BlackJAX", "url": "https://github.com/blackjax-devs/blackjax", "description": "Library of samplers for JAX. ", "stars": "1k"}, {"name": "Dynamax", "url": "https://github.com/probml/dynamax", "description": "Probabilistic state space models. ", "stars": "908"}]}, {"name": "New Libraries", "parent": "Libraries", "entries": [{"name": "jax-unirep", "url": "https://github.com/ElArkk/jax-unirep", "description": "Library implementing the [UniRep model](https://www.nature.com/articles/s41592-019-0598-1) for protein machine learning applications. ", "stars": "108"}, {"name": "flowjax", "url": "https://github.com/danielward27/flowjax", "description": "Distributions and normalizing flows built as equinox modules. 
", "stars": "209"}, {"name": "flaxdiff", "url": "https://github.com/AshishKumar4/FlaxDiff", "description": "Framework and Library for building and training Diffusion models in multi-node multi-device distributed settings (TPUs) ", "stars": "40"}, {"name": "jax-flows", "url": "https://github.com/ChrisWaites/jax-flows", "description": "Normalizing flows in JAX. "}, {"name": "sklearn-jax-kernels", "url": "https://github.com/ExpectationMax/sklearn-jax-kernels", "description": "`scikit-learn` kernel matrices using JAX. ", "stars": "45"}, {"name": "jax-cosmo", "url": "https://github.com/DifferentiableUniverseInitiative/jax_cosmo", "description": "Differentiable cosmology library. ", "stars": "221"}, {"name": "efax", "url": "https://github.com/NeilGirdhar/efax", "description": "Exponential Families in JAX. ", "stars": "76"}, {"name": "mpi4jax", "url": "https://github.com/PhilipVinc/mpi4jax", "description": "Combine MPI operations with your Jax code on CPUs and GPUs. ", "stars": "512"}, {"name": "imax", "url": "https://github.com/4rtemi5/imax", "description": "Image augmentations and transformations. ", "stars": "41"}, {"name": "FlaxVision", "url": "https://github.com/rolandgvc/flaxvision", "description": "Flax version of TorchVision. ", "stars": "45"}, {"name": "Oryx", "url": "https://github.com/tensorflow/probability/tree/master/spinoffs/oryx", "description": "Probabilistic programming language based on program transformations.", "stars": "4.4k"}, {"name": "Optimal Transport Tools", "url": "https://github.com/google-research/ott", "description": "Toolbox that bundles utilities to solve optimal transport problems.", "stars": "214"}, {"name": "delta PV", "url": "https://github.com/romanodev/deltapv", "description": "A photovoltaic simulator with automatic differentation. ", "stars": "64"}, {"name": "jaxlie", "url": "https://github.com/brentyi/jaxlie", "description": "Lie theory library for rigid body transformations and optimization. 
", "stars": "317"}, {"name": "BRAX", "url": "https://github.com/google/brax", "description": "Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments. ", "stars": "3k"}, {"name": "flaxmodels", "url": "https://github.com/matthias-wright/flaxmodels", "description": "Pretrained models for Jax/Flax. ", "stars": "265"}, {"name": "CR.Sparse", "url": "https://github.com/carnotresearch/cr-sparse", "description": "XLA accelerated algorithms for sparse representations and compressive sensing. ", "stars": "96"}, {"name": "exojax", "url": "https://github.com/HajimeKawahara/exojax", "description": "Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible to JAX. ", "stars": "66"}, {"name": "PIX", "url": "https://github.com/deepmind/dm_pix", "description": "PIX is an image processing library in JAX, for JAX. ", "stars": "432"}, {"name": "bayex", "url": "https://github.com/alonfnt/bayex", "description": "Bayesian Optimization powered by JAX. ", "stars": "101"}, {"name": "JaxDF", "url": "https://github.com/ucl-bug/jaxdf", "description": "Framework for differentiable simulators with arbitrary discretizations. ", "stars": "131"}, {"name": "tree-math", "url": "https://github.com/google/tree-math", "description": "Convert functions that operate on arrays into functions that operate on PyTrees. ", "stars": "206"}, {"name": "jax-models", "url": "https://github.com/DarshanDeshpande/jax-models", "description": "Implementations of research papers originally without code or code written with frameworks other than JAX. ", "stars": "160"}, {"name": "PGMax", "url": "https://github.com/vicariousinc/PGMax", "description": "A framework for building discrete Probabilistic Graphical Models (PGM's) and running inference inference on them via JAX. 
", "stars": "65"}, {"name": "EvoJAX", "url": "https://github.com/google/evojax", "description": "Hardware-Accelerated Neuroevolution ", "stars": "931"}, {"name": "evosax", "url": "https://github.com/RobertTLange/evosax", "description": "JAX-Based Evolution Strategies ", "stars": "718"}, {"name": "SymJAX", "url": "https://github.com/SymJAX/SymJAX", "description": "Symbolic CPU/GPU/TPU programming. ", "stars": "125"}, {"name": "mcx", "url": "https://github.com/rlouf/mcx", "description": "Express & compile probabilistic programs for performant inference. ", "stars": "329"}, {"name": "Einshape", "url": "https://github.com/deepmind/einshape", "description": "DSL-based reshaping library for JAX and other frameworks. ", "stars": "107"}, {"name": "ALX", "url": "https://github.com/google-research/google-research/tree/master/alx", "description": "Open-source library for distributed matrix factorization using Alternating Least Squares, more info in [*ALX: Large Scale Matrix Factorization on TPUs*](https://arxiv.org/abs/2112.02194).", "stars": "37k"}, {"name": "Diffrax", "url": "https://github.com/patrick-kidger/diffrax", "description": "Numerical differential equation solvers in JAX. ", "stars": "1.9k"}, {"name": "tinygp", "url": "https://github.com/dfm/tinygp", "description": "The *tiniest* of Gaussian process libraries in JAX. ", "stars": "328"}, {"name": "gymnax", "url": "https://github.com/RobertTLange/gymnax", "description": "Reinforcement Learning Environments with the well-known gym API. ", "stars": "847"}, {"name": "Mctx", "url": "https://github.com/deepmind/mctx", "description": "Monte Carlo tree search algorithms in native JAX. ", "stars": "2.6k"}, {"name": "KFAC-JAX", "url": "https://github.com/deepmind/kfac-jax", "description": "Second Order Optimization with Approximate Curvature for NNs. ", "stars": "308"}, {"name": "TF2JAX", "url": "https://github.com/deepmind/tf2jax", "description": "Convert functions/graphs to JAX functions. 
", "stars": "120"}, {"name": "jwave", "url": "https://github.com/ucl-bug/jwave", "description": "A library for differentiable acoustic simulations ", "stars": "190"}, {"name": "GPJax", "url": "https://github.com/thomaspinder/GPJax", "description": "Gaussian processes in JAX.", "stars": "576"}, {"name": "Jumanji", "url": "https://github.com/instadeepai/jumanji", "description": "A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX. ", "stars": "798"}, {"name": "Eqxvision", "url": "https://github.com/paganpasta/eqxvision", "description": "Equinox version of Torchvision. ", "stars": "110"}, {"name": "JAXFit", "url": "https://github.com/dipolar-quantum-gases/jaxfit", "description": "Accelerated curve fitting library for nonlinear least-squares problems (see [arXiv paper](https://arxiv.org/abs/2208.12187)). ", "stars": "59"}, {"name": "econpizza", "url": "https://github.com/gboehl/econpizza", "description": "Solve macroeconomic models with hetereogeneous agents using JAX. ", "stars": "107"}, {"name": "SPU", "url": "https://github.com/secretflow/spu", "description": "A domain-specific compiler and runtime suite to run JAX code with MPC(Secure Multi-Party Computation). ", "stars": "316"}, {"name": "jax-tqdm", "url": "https://github.com/jeremiecoullon/jax-tqdm", "description": "Add a tqdm progress bar to JAX scans and loops. ", "stars": "124"}, {"name": "safejax", "url": "https://github.com/alvarobartt/safejax", "description": "Serialize JAX, Flax, Haiku, or Objax model params with \ud83e\udd17`safetensors`. ", "stars": "47"}, {"name": "Kernex", "url": "https://github.com/ASEM000/kernex", "description": "Differentiable stencil decorators in JAX. ", "stars": "71"}, {"name": "MaxText", "url": "https://github.com/google/maxtext", "description": "A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs. 
", "stars": "2.1k"}, {"name": "Pax", "url": "https://github.com/google/paxml", "description": "A Jax-based machine learning framework for training large scale models. ", "stars": "544"}, {"name": "Praxis", "url": "https://github.com/google/praxis", "description": "The layer library for Pax with a goal to be usable by other JAX-based ML projects. ", "stars": "192"}, {"name": "purejaxrl", "url": "https://github.com/luchris429/purejaxrl", "description": "Vectorisable, end-to-end RL algorithms in JAX. ", "stars": "1k"}, {"name": "Lorax", "url": "https://github.com/davisyoshida/lorax", "description": "Automatically apply LoRA to JAX models (Flax, Haiku, etc.)", "stars": "143"}, {"name": "SCICO", "url": "https://github.com/lanl/scico", "description": "Scientific computational imaging in JAX. ", "stars": "147"}, {"name": "Spyx", "url": "https://github.com/kmheckel/spyx", "description": "Spiking Neural Networks in JAX for machine learning on neuromorphic hardware. ", "stars": "131"}, {"name": "OTT-JAX", "url": "https://github.com/ott-jax/ott", "description": "Optimal transport tools in JAX. ", "stars": "691"}, {"name": "QDax", "url": "https://github.com/adaptive-intelligent-robotics/QDax", "description": "Quality Diversity optimization in Jax. ", "stars": "335"}, {"name": "JAX Toolbox", "url": "https://github.com/NVIDIA/JAX-Toolbox", "description": "Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine. ", "stars": "377"}, {"name": "Pgx", "url": "http://github.com/sotetsuk/pgx", "description": "Vectorized board game environments for RL with an AlphaZero example. 
", "stars": "576"}, {"name": "EasyDeL", "url": "https://github.com/erfanzar/EasyDeL", "description": "EasyDeL \ud83d\udd2e is an OpenSource Library to make your training faster and more Optimized With cool Options for training and serving (Llama, MPT, Mixtral, Falcon, etc) in JAX ", "stars": "330"}, {"name": "XLB", "url": "https://github.com/Autodesk/XLB", "description": "A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning. ", "stars": "430"}, {"name": "dynamiqs", "url": "https://github.com/dynamiqs/dynamiqs", "description": "High-performance and differentiable simulations of quantum systems with JAX. ", "stars": "266"}, {"name": "foragax", "url": "https://github.com/i-m-iron-man/Foragax", "description": "Agent-Based modelling framework in JAX. ", "stars": "5"}, {"name": "tmmax", "url": "https://github.com/bahremsd/tmmax", "description": "Vectorized calculation of optical properties in thin-film structures using JAX. Swiss Army knife tool for thin-film optics research ", "stars": "28"}, {"name": "Coreax", "url": "https://github.com/gchq/coreax", "description": "Algorithms for finding coresets to compress large datasets while retaining their statistical properties. ", "stars": "37"}, {"name": "NAVIX", "url": "https://github.com/epignatelli/navix", "description": "A reimplementation of MiniGrid, a Reinforcement Learning environment, in JAX ", "stars": "154"}, {"name": "FDTDX", "url": "https://github.com/ymahlau/fdtdx", "description": "Finite-Difference Time-Domain Electromagnetic Simulations in JAX ", "stars": "207"}, {"name": "DiffeRT", "url": "https://github.com/jeertmans/DiffeRT", "description": "Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem. 
", "stars": "48"}, {"name": "JAX-in-Cell", "url": "https://github.com/uwplasma/JAX-in-Cell", "description": "Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields ", "stars": "22"}, {"name": "kvax", "url": "https://github.com/nebius/kvax", "description": "A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism. ", "stars": "156"}, {"name": "astronomix", "url": "https://github.com/leo1200/astronomix", "description": "differentiable (magneto)hydrodynamics for astrophysics in JAX ", "stars": "44"}, {"name": "vivsim", "url": "https://github.com/haimingz/vivsim", "description": "Fluid-structure interaction simulations using Immersed Boundary-Lattice Boltzmann Method. ", "stars": "30"}, {"name": "MBIRJAX", "url": "https://github.com/cabouman/mbirjax", "description": "High-performance tomographic reconstruction. ", "stars": "19"}, {"name": "torchax", "url": "https://github.com/google/torchax/", "description": "torchax is a library for Jax to interoperate with model code written in PyTorch.", "stars": "166"}]}, {"name": "JAX", "parent": "Models and Projects", "entries": [{"name": "Fourier Feature Networks", "url": "https://github.com/tancik/fourier-feature-networks", "description": "Official implementation of [*Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains*](https://people.eecs.berkeley.edu/~bmild/fourfeat).", "stars": "1.4k"}, {"name": "kalman-jax", "url": "https://github.com/AaltoML/kalman-jax", "description": "Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.", "stars": "103"}, {"name": "jaxns", "url": "https://github.com/Joshuaalbert/jaxns", "description": "Nested sampling in JAX.", "stars": "218"}, {"name": "Amortized Bayesian Optimization", "url": 
"https://github.com/google-research/google-research/tree/master/amortized_bo", "description": "Code related to [*Amortized Bayesian Optimization over Discrete Spaces*](http://www.auai.org/uai2020/proceedings/329_main_paper.pdf).", "stars": "37k"}, {"name": "Accurate Quantized Training", "url": "https://github.com/google-research/google-research/tree/master/aqt", "description": "Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.", "stars": "37k"}, {"name": "BNN-HMC", "url": "https://github.com/google-research/google-research/tree/master/bnn_hmc", "description": "Implementation for the paper [*What Are Bayesian Neural Network Posteriors Really Like?*](https://arxiv.org/abs/2104.14421).", "stars": "37k"}, {"name": "JAX-DFT", "url": "https://github.com/google-research/google-research/tree/master/jax_dft", "description": "One-dimensional density functional theory (DFT) in JAX, with implementation of [*Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics*](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.126.036401).", "stars": "37k"}, {"name": "Robust Loss", "url": "https://github.com/google-research/google-research/tree/master/robust_loss_jax", "description": "Reference code for the paper [*A General and Adaptive Robust Loss Function*](https://arxiv.org/abs/1701.03077).", "stars": "37k"}, {"name": "Symbolic Functionals", "url": "https://github.com/google-research/google-research/tree/master/symbolic_functionals", "description": "Demonstration from [*Evolving symbolic density functionals*](https://arxiv.org/abs/2203.02540).", "stars": "37k"}, {"name": "TriMap", "url": "https://github.com/google-research/google-research/tree/master/trimap", "description": "Official JAX implementation of [*TriMap: Large-scale Dimensionality Reduction Using Triplets*](https://arxiv.org/abs/1910.00204).", "stars": "37k"}]}, {"name": "Flax", "parent": "Models and Projects", "entries": [{"name": 
"awesome-jax-flax-llms", "url": "https://github.com/your-username/awesome-jax-flax-llms", "description": "Collection of LLMs implemented in **JAX** & **Flax**"}, {"name": "DeepSeek-R1-Flax-1.5B-Distill", "url": "https://github.com/J-Rosser-UK/Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B", "description": "Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.", "stars": "26"}, {"name": "Performer", "url": "https://github.com/google-research/google-research/tree/master/performer/fast_attention/jax", "description": "Flax implementation of the Performer (linear transformer via FAVOR+) architecture.", "stars": "37k"}, {"name": "JaxNeRF", "url": "https://github.com/google-research/google-research/tree/master/jaxnerf", "description": "Implementation of [*NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis*](http://www.matthewtancik.com/nerf) with multi-device GPU/TPU support.", "stars": "37k"}, {"name": "mip-NeRF", "url": "https://github.com/google/mipnerf", "description": "Official implementation of [*Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields*](https://jonbarron.info/mipnerf).", "stars": "937"}, {"name": "RegNeRF", "url": "https://github.com/google-research/google-research/tree/master/regnerf", "description": "Official implementation of [*RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs*](https://m-niemeyer.github.io/regnerf/).", "stars": "37k"}, {"name": "JaxNeuS", "url": "https://github.com/huangjuite/jaxneus", "description": "Implementation of [*NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction*](https://lingjie0206.github.io/papers/NeuS/)", "stars": "1"}, {"name": "Big Transfer (BiT)", "url": "https://github.com/google-research/big_transfer", "description": "Implementation of [*Big Transfer (BiT): General Visual Representation Learning*](https://arxiv.org/abs/1912.11370).", "stars": "1.5k"}, {"name": "JAX RL", "url": 
"https://github.com/ikostrikov/jax-rl", "description": "Implementations of reinforcement learning algorithms.", "stars": "741"}, {"name": "gMLP", "url": "https://github.com/SauravMaheshkar/gMLP", "description": "Implementation of [*Pay Attention to MLPs*](https://arxiv.org/abs/2105.08050)."}, {"name": "MLP Mixer", "url": "https://github.com/SauravMaheshkar/MLP-Mixer", "description": "Minimal implementation of [*MLP-Mixer: An all-MLP Architecture for Vision*](https://arxiv.org/abs/2105.01601)."}, {"name": "Distributed Shampoo", "url": "https://github.com/google-research/google-research/tree/master/scalable_shampoo", "description": "Implementation of [*Second Order Optimization Made Practical*](https://arxiv.org/abs/2002.09018).", "stars": "37k"}, {"name": "NesT", "url": "https://github.com/google-research/nested-transformer", "description": "Official implementation of [*Aggregating Nested Transformers*](https://arxiv.org/abs/2105.12723).", "stars": "201"}, {"name": "XMC-GAN", "url": "https://github.com/google-research/xmcgan_image_generation", "description": "Official implementation of [*Cross-Modal Contrastive Learning for Text-to-Image Generation*](https://arxiv.org/abs/2101.04702).", "stars": "96"}, {"name": "FNet", "url": "https://github.com/google-research/google-research/tree/master/f_net", "description": "Official implementation of [*FNet: Mixing Tokens with Fourier Transforms*](https://arxiv.org/abs/2105.03824).", "stars": "37k"}, {"name": "GFSA", "url": "https://github.com/google-research/google-research/tree/master/gfsa", "description": "Official implementation of [*Learning Graph Structure With A Finite-State Automaton Layer*](https://arxiv.org/abs/2007.04929).", "stars": "37k"}, {"name": "IPA-GNN", "url": "https://github.com/google-research/google-research/tree/master/ipagnn", "description": "Official implementation of [*Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks*](https://arxiv.org/abs/2010.12621).", "stars": 
"37k"}, {"name": "Flax Models", "url": "https://github.com/google-research/google-research/tree/master/flax_models", "description": "Collection of models and methods implemented in Flax.", "stars": "37k"}, {"name": "Protein LM", "url": "https://github.com/google-research/google-research/tree/master/protein_lm", "description": "Implements BERT and autoregressive models for proteins, as described in [*Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences*](https://www.biorxiv.org/content/10.1101/622803v1.full) and [*ProGen: Language Modeling for Protein Generation*](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2).", "stars": "37k"}, {"name": "Slot Attention", "url": "https://github.com/google-research/google-research/tree/master/ptopk_patch_selection", "description": "Reference implementation for [*Differentiable Patch Selection for Image Recognition*](https://arxiv.org/abs/2104.03059).", "stars": "37k"}, {"name": "Vision Transformer", "url": "https://github.com/google-research/vision_transformer", "description": "Official implementation of [*An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*](https://arxiv.org/abs/2010.11929).", "stars": "12k"}, {"name": "FID computation", "url": "https://github.com/matthias-wright/jax-fid", "description": "Port of [mseitzer/pytorch-fid (\u2b503.8k)](https://github.com/mseitzer/pytorch-fid) to Flax.", "stars": "29"}, {"name": "ARDM", "url": "https://github.com/google-research/google-research/tree/master/autoregressive_diffusion", "description": "Official implementation of [*Autoregressive Diffusion Models*](https://arxiv.org/abs/2110.02037).", "stars": "37k"}, {"name": "D3PM", "url": "https://github.com/google-research/google-research/tree/master/d3pm", "description": "Official implementation of [*Structured Denoising Diffusion Models in Discrete State-Spaces*](https://arxiv.org/abs/2107.03006).", "stars": "37k"}, {"name": "Gumbel-max Causal 
Mechanisms", "url": "https://github.com/google-research/google-research/tree/master/gumbel_max_causal_gadgets", "description": "Code for [*Learning Generalized Gumbel-max Causal Mechanisms*](https://arxiv.org/abs/2111.06888), with extra code in [GuyLor/gumbel\\_max\\_causal\\_gadgets\\_part2 (\u2b502)](https://github.com/GuyLor/gumbel_max_causal_gadgets_part2).", "stars": "37k"}, {"name": "Latent Programmer", "url": "https://github.com/google-research/google-research/tree/master/latent_programmer", "description": "Code for the ICML 2021 paper [*Latent Programmer: Discrete Latent Codes for Program Synthesis*](https://arxiv.org/abs/2012.00377).", "stars": "37k"}, {"name": "SNeRG", "url": "https://github.com/google-research/google-research/tree/master/snerg", "description": "Official implementation of [*Baking Neural Radiance Fields for Real-Time View Synthesis*](https://phog.github.io/snerg).", "stars": "37k"}, {"name": "Spin-weighted Spherical CNNs", "url": "https://github.com/google-research/google-research/tree/master/spin_spherical_cnns", "description": "Adaptation of [*Spin-Weighted Spherical CNNs*](https://arxiv.org/abs/2006.10731).", "stars": "37k"}, {"name": "VDVAE", "url": "https://github.com/google-research/google-research/tree/master/vdvae_flax", "description": "Adaptation of [*Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images*](https://arxiv.org/abs/2011.10650), original code at [openai/vdvae (\u2b50449)](https://github.com/openai/vdvae).", "stars": "37k"}, {"name": "MUSIQ", "url": "https://github.com/google-research/google-research/tree/master/musiq", "description": "Checkpoints and model inference code for the ICCV 2021 paper [*MUSIQ: Multi-scale Image Quality Transformer*](https://arxiv.org/abs/2108.05997)", "stars": "37k"}, {"name": "AQuaDem", "url": "https://github.com/google-research/google-research/tree/master/aquadem", "description": "Official implementation of [*Continuous Control with Action Quantization from 
Demonstrations*](https://arxiv.org/abs/2110.10149).", "stars": "37k"}, {"name": "Combiner", "url": "https://github.com/google-research/google-research/tree/master/combiner", "description": "Official implementation of [*Combiner: Full Attention Transformer with Sparse Computation Cost*](https://arxiv.org/abs/2107.05768).", "stars": "37k"}, {"name": "Dreamfields", "url": "https://github.com/google-research/google-research/tree/master/dreamfields", "description": "Official implementation of [*Zero-Shot Text-Guided Object Generation with Dream Fields*](https://ajayj.com/dreamfields).", "stars": "37k"}, {"name": "GIFT", "url": "https://github.com/google-research/google-research/tree/master/gift", "description": "Official implementation of [*Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent*](https://arxiv.org/abs/2106.06080).", "stars": "37k"}, {"name": "Light Field Neural Rendering", "url": "https://github.com/google-research/google-research/tree/master/light_field_neural_rendering", "description": "Official implementation of [*Light Field Neural Rendering*](https://arxiv.org/abs/2112.09687).", "stars": "37k"}, {"name": "Sharpened Cosine Similarity in JAX by Raphael Pisoni", "url": "https://colab.research.google.com/drive/1KUKFEMneQMS3OzPYnWZGkEnry3PdzCfn?usp=sharing", "description": "A JAX/Flax implementation of the Sharpened Cosine Similarity layer."}, {"name": "GNNs for Solving Combinatorial Optimization Problems", "url": "https://github.com/IvanIsCoding/GNN-for-Combinatorial-Optimization", "description": "A JAX + Flax implementation of [Combinatorial Optimization with Physics-Inspired Graph Neural Networks](https://arxiv.org/abs/2107.01188).", "stars": "64"}, {"name": "DETR", "url": "https://github.com/MasterSkepticista/detr", "description": "Flax implementation of [*DETR: End-to-end Object Detection with Transformers*](https://github.com/facebookresearch/detr) using Sinkhorn solver and parallel bipartite 
matching.", "stars": "8"}]}, {"name": "Haiku", "parent": "Models and Projects", "entries": [{"name": "AlphaFold", "url": "https://github.com/deepmind/alphafold", "description": "Implementation of the inference pipeline of AlphaFold v2.0, presented in [*Highly accurate protein structure prediction with AlphaFold*](https://www.nature.com/articles/s41586-021-03819-2).", "stars": "14k"}, {"name": "Adversarial Robustness", "url": "https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness", "description": "Reference code for [*Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples*](https://arxiv.org/abs/2010.03593) and [*Fixing Data Augmentation to Improve Adversarial Robustness*](https://arxiv.org/abs/2103.01946).", "stars": "15k"}, {"name": "Bootstrap Your Own Latent", "url": "https://github.com/deepmind/deepmind-research/tree/master/byol", "description": "Implementation for the paper [*Bootstrap your own latent: A new approach to self-supervised Learning*](https://arxiv.org/abs/2006.07733).", "stars": "15k"}, {"name": "Gated Linear Networks", "url": "https://github.com/deepmind/deepmind-research/tree/master/gated_linear_networks", "description": "GLNs are a family of backpropagation-free neural networks.", "stars": "15k"}, {"name": "Glassy Dynamics", "url": "https://github.com/deepmind/deepmind-research/tree/master/glassy_dynamics", "description": "Open source implementation of the paper [*Unveiling the predictive power of static structure in glassy systems*](https://www.nature.com/articles/s41567-020-0842-8).", "stars": "15k"}, {"name": "MMV", "url": "https://github.com/deepmind/deepmind-research/tree/master/mmv", "description": "Code for the models in [*Self-Supervised MultiModal Versatile Networks*](https://arxiv.org/abs/2006.16228).", "stars": "15k"}, {"name": "Normalizer-Free Networks", "url": "https://github.com/deepmind/deepmind-research/tree/master/nfnets", "description": "Official Haiku implementation 
of [*NFNets*](https://arxiv.org/abs/2102.06171).", "stars": "15k"}, {"name": "NuX", "url": "https://github.com/Information-Fusion-Lab-Umass/NuX", "description": "Normalizing flows with JAX.", "stars": "86"}, {"name": "OGB-LSC", "url": "https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc", "description": "This repository contains DeepMind's entry to the [PCQM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry) and [MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph) tracks.", "stars": "15k"}, {"name": "Persistent Evolution Strategies", "url": "https://github.com/google-research/google-research/tree/master/persistent_es", "description": "Code used for the paper [*Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies*](http://proceedings.mlr.press/v139/vicol21a.html).", "stars": "37k"}, {"name": "Two Player Auction Learning", "url": "https://github.com/degregat/two-player-auctions", "description": "JAX implementation of the paper [*Auction learning as a two-player game*](https://arxiv.org/abs/2006.05684).", "stars": "0"}, {"name": "WikiGraphs", "url": "https://github.com/deepmind/deepmind-research/tree/master/wikigraphs", "description": "Baseline code to reproduce results in [*WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset*](https://aclanthology.org/2021.textgraphs-1.7).", "stars": "15k"}]}, {"name": "Trax", "parent": "Models and Projects", "entries": [{"name": "Reformer", "url": "https://github.com/google/trax/tree/master/trax/models/reformer", "description": "Implementation of the Reformer (efficient transformer) architecture.", "stars": "8.3k"}]}, {"name": "NumPyro", "parent": "Models and Projects", "entries": [{"name": "lqg", "url": "https://github.com/RothkopfLab/lqg", "description": "Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper [*Putting perception into action with inverse optimal control 
for continuous psychophysics*](https://elifesciences.org/articles/76635).", "stars": "30"}]}, {"name": "Equinox", "parent": "Models and Projects", "entries": [{"name": "Sampling Path Candidates with Machine Learning", "url": "https://differt.eertmans.be/icmlcn2025/notebooks/sampling_paths.html", "description": "Official tutorial and implementation from the paper [*Towards Generative Ray Path Sampling for Faster Point-to-Point Ray Tracing*](https://arxiv.org/abs/2410.23773)."}, {"name": "NeurIPS 2020: JAX Ecosystem Meetup", "url": "https://www.youtube.com/watch?v=iDxJxIyzSiM", "description": "JAX, its use at DeepMind, and discussion between engineers, scientists, and the JAX core team."}, {"name": "Introduction to JAX", "url": "https://youtu.be/0mVmRHMaOJ4", "description": "Simple neural network from scratch in JAX."}, {"name": "JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas", "url": "https://youtu.be/z-WSrQDXkuM", "description": "JAX's core design, how it's powering new research, and how you can start using it."}, {"name": "Bayesian Programming with JAX + NumPyro \u2014 Andy Kitchen", "url": "https://youtu.be/CecuWGpoztw", "description": "Introduction to Bayesian modelling using NumPyro."}, {"name": "JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne", "url": "https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python", "description": "JAX intro presentation in the [*Program Transformations for Machine Learning*](https://program-transformations.github.io) workshop."}, {"name": "JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury", "url": "https://drive.google.com/file/d/1jKxefZT1xJDUxMman6qrQVed7vWI0MIn/edit", "description": "Presentation of TPU host access with demo."}, {"name": "Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020", "url": 
"https://slideslive.com/38935810/deep-implicit-layers-neural-odes-equilibrium-models-and-beyond", "description": "Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in [*Deep Implicit Layers*](http://implicit-layers-tutorial.org)."}, {"name": "Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey", "url": "http://matpalm.com/blog/ymxb_pod_slice/", "description": "A four-part YouTube tutorial series with Colab notebooks that starts with Jax fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice."}, {"name": "JAX, Flax & Transformers \ud83e\udd17", "url": "https://github.com/huggingface/transformers/blob/9160d81c98854df44b1d543ce5d65a6aa28444a2/examples/research_projects/jax-projects/README.md#talks", "description": "3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.", "stars": "155k"}, {"name": "**Compiling machine learning programs via high-level tracing**. Roy Frostig, Matthew James Johnson, Chris Leary. *MLSys 2018*.", "url": "https://mlsys.org/Conferences/doc/2018/146.pdf", "description": "White paper describing an early version of JAX, detailing how computation is traced and compiled."}, {"name": "**JAX, M.D.: A Framework for Differentiable Physics**. Samuel S. Schoenholz, Ekin D. Cubuk. *NeurIPS 2020*.", "url": "https://arxiv.org/abs/1912.04232", "description": "Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more."}, {"name": "**Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization**. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. *arXiv 2020*.", "url": "https://arxiv.org/abs/2010.09063", "description": "Uses JAX's JIT and VMAP to achieve faster differentially private machine learning than existing libraries."}, {"name": "**XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python**. 
Mohammadmehdi Ataei, Hesam Salehipour. *arXiv 2023*.", "url": "https://arxiv.org/abs/2311.16080", "description": "White paper describing the XLB library: benchmarks, validations, and more details about the library."}, {"name": "Using JAX to accelerate our research by David Budden and Matteo Hessel", "url": "https://deepmind.com/blog/article/using-jax-to-accelerate-our-research", "description": "Describes the state of JAX and the JAX ecosystem at DeepMind."}, {"name": "Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange", "url": "https://roberttlange.github.io/posts/2020/03/blog-post-10/", "description": "Neural network building blocks from scratch with the basic JAX operators."}, {"name": "Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh", "url": "https://www.kaggle.com/code/truthr/jax-0", "description": "A gentle introduction to JAX that uses it to implement linear regression, logistic regression, and neural network models, applying them to real-world problems."}, {"name": "Tutorial: image classification with JAX and Flax Linen by 8bitmp3", "url": "https://github.com/8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen", "description": "Learn how to create a simple convolutional network with Flax's Linen API and train it to recognize handwritten digits.", "stars": "25"}, {"name": "Plugging Into JAX by Nick Doiron", "url": "https://medium.com/swlh/plugging-into-jax-16c120ec3302", "description": "Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge."}, {"name": "Meta-Learning in 50 Lines of JAX by Eric Jang", "url": "https://blog.evjang.com/2019/02/maml-jax.html", "description": "Introduction to both JAX and Meta-Learning."}, {"name": "Normalizing Flows in 100 Lines of JAX by Eric Jang", "url": "https://blog.evjang.com/2019/07/nf-jax.html", "description": "Concise implementation of [RealNVP](https://arxiv.org/abs/1605.08803)."}, {"name": "Differentiable Path Tracing on the GPU/TPU by Eric Jang", "url": 
"https://blog.evjang.com/2019/11/jaxpt.html", "description": "Tutorial on implementing path tracing."}, {"name": "Ensemble networks by Mat Kelcey", "url": "http://matpalm.com/blog/ensemble_nets", "description": "Ensemble nets are a method of representing an ensemble of models as one single logical model."}, {"name": "Out of distribution (OOD) detection by Mat Kelcey", "url": "http://matpalm.com/blog/ood_using_focal_loss", "description": "Implements different methods for OOD detection."}, {"name": "Understanding Autodiff with JAX by Srihari Radhakrishna", "url": "https://www.radx.in/jax.html", "description": "Understand how autodiff works using JAX."}, {"name": "From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke", "url": "https://sjmielke.com/jax-purify.htm", "description": "Showcases how to go from a PyTorch-like style of coding to a more functional style of coding."}, {"name": "Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey", "url": "https://github.com/dfm/extending-jax", "description": "Tutorial demonstrating the infrastructure required to provide custom ops in JAX.", "stars": "403"}, {"name": "Evolving Neural Networks in JAX by Robert Tjarko Lange", "url": "https://roberttlange.github.io/posts/2021/02/cma-es-jax/", "description": "Explores how JAX can power the next generation of scalable neuroevolution algorithms."}, {"name": "Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz", "url": "http://lukemetz.com/exploring-hyperparameter-meta-loss-landscapes-with-jax/", "description": "Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies."}, {"name": "Deterministic ADVI in JAX by Martin Ingram", "url": "https://martiningram.github.io/deterministic-advi/", "description": "Walk-through of implementing automatic differentiation variational inference (ADVI) 
easily and cleanly with JAX."}, {"name": "Evolved channel selection by Mat Kelcey", "url": "http://matpalm.com/blog/evolved_channel_selection/", "description": "Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss."}, {"name": "Introduction to JAX by Kevin Murphy", "url": "https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb", "description": "Colab that introduces various aspects of the language and applies them to simple ML problems."}, {"name": "Writing an MCMC sampler in JAX by Jeremie Coullon", "url": "https://www.jeremiecoullon.com/2020/11/10/mcmcjax3ways/", "description": "Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks."}, {"name": "How to add a progress bar to JAX scans and loops by Jeremie Coullon", "url": "https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/", "description": "Tutorial on how to add a progress bar to compiled loops in JAX using the `host_callback` module."}, {"name": "Get started with JAX by Aleksa Gordi\u0107", "url": "https://github.com/gordicaleksa/get-started-with-JAX", "description": "A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.", "stars": "775"}, {"name": "Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit", "url": "https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-FLAX--VmlldzoyMzA4ODEy", "description": "A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax."}, {"name": "Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar", "url": "https://wandb.ai/wandb/nerf-jax/reports/Implementing-NeRF-in-JAX--VmlldzoxODA2NDk2?galleryTag=jax", "description": "A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields 
in JAX."}, {"name": "Deep Learning tutorials with JAX+Flax by Phillip Lippe", "url": "https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html", "description": "A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch."}, {"name": "Achieving 4000x Speedups with PureJaxRL", "url": "https://chrislu.page/blog/meta-disco/", "description": "A blog post on how JAX can massively speed up RL training through vectorisation."}, {"name": "Simple PDE solver + Constrained Optimization with JAX by Philip Mocz", "url": "https://levelup.gitconnected.com/create-your-own-automatically-differentiable-simulation-with-python-jax-46951e120fbb?sk=e8b9213dd2c6a5895926b2695d28e4aa", "description": "A simple example of solving the advection-diffusion equations with JAX and using the solver in a constrained optimization problem to find initial conditions that yield a desired result."}, {"name": "Jax in Action", "url": "https://www.manning.com/books/jax-in-action", "description": "A hands-on guide to using JAX for deep learning and other mathematically intensive applications."}, {"name": "JaxLLM (Unofficial) Discord", "url": "https://discord.com/channels/1107832795377713302/1107832795688083561", "description": ""}, {"name": "JAX GitHub Discussions", "url": "https://github.com/google/jax/discussions", "description": "", "stars": "35k"}, {"name": "Reddit", "url": "https://www.reddit.com/r/JAX/", "description": ""}]}], "name": ""}