Another JAX Introduction: A Reinforcement Learning Perspective
Introduction
Using JAX will make writing hard code way easier, at the expense of making simple code a little more difficult to write. However, this expense is generally seen as a high bar to overcome. Hopefully this guide lowers that bar somewhat by providing an overview of the most common JAX intricacies, with examples specifically aimed at reinforcement learning.
This guide only covers what I personally gathered and definitely does not encompass everything. However, even without understanding every concept on this page, it should be fairly doable to get started with JAX. For a more complete picture, JAX has pretty good documentation. In particular, the “Common Gotchas” and “FAQ” pages contain a lot of relevant information that I will not simply repeat here.
What is important to know is that JAX revolves around the idea of function transformations, like @jit, @grad and @vmap. These transformations allow us to level up our regular Python functions by, for example, compiling them for GPUs or vectorizing them.
We'll go over each important transformation and, in particular for @jit, go over the restrictions it places on your code. These restrictions are what is generally seen as the “high bar” to entering JAX.
@jit | Just-in-time Compilation
One of the most important JAX function transformations is @jit, which takes a function and compiles it with XLA when it is first called.
As with every JAX transformation, this is surprisingly easy: we either apply a decorator or apply the transformation directly as a Python function:
import jax

@jax.jit
def decorated_jitted_function(x):
    return x

def some_function(y):
    return y

some_jitted_function = jax.jit(some_function)
JIT compiling your code is generally great. First of all, Python is known to be quite slow, so compiling it to a lower-level representation may alleviate this. More importantly perhaps, jax.jit() can compile code not only for the CPU, but for the GPU and TPU as well. The compiled code then runs on the platform it was compiled for.
While running your custom code on accelerated hardware used to be difficult, JAX allows us to easily run plain Python functions on the GPU.
JIT Implications
Unfortunately, jax.jit() cannot compile just any Python function; it places some restrictions on those functions. Fundamental concepts like conditionals, loops, and state have to be handled somewhat differently. This is likely what most people struggle with when entering the JAX framework, so we go over the most common implications and how to deal with them.
JIT phases
Executing JAX code involves two distinct phases. This is not a limitation on the written code, but it is fundamental to understand when working with JAX. I will refer to these phases as the compilation phase and the runtime phase. In the compilation phase, regular Python is executed, primarily capturing the shapes and dtypes of your data (this is called tracing), after which the code is compiled with XLA. Importantly, some code will not make it into XLA, but will only run during the compilation phase. You can see this happening when printing values in JITted functions:
import jax
import jax.numpy as jnp

@jax.jit
def some_jitted_function(x):
    y = x * 2
    print("y in jit:", y)
    return y

print("y output1:", some_jitted_function(jnp.array([1, 2, 3])))
print("y output2:", some_jitted_function(jnp.array([1, 2, 3])))

y in jit: Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=1/0)>
y output1: [2 4 6]
y output2: [2 4 6]
As you can see, a Traced<ShapedArray> gets printed. This is essentially a dummy value used during the compilation phase, carrying the correct dtype int32 and shape (3,).
Even though we call the function twice, compilation only happens once (the compiled code is cached for inputs with the same shapes and dtypes). As a result, the Traced<ShapedArray> also prints only once.
Note that printing the dummy value does not reveal the actual value, which causes some issues for debugging. We dive deeper into solving those issues later.
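The following small sketch makes the caching behavior concrete. The names double and traced_shapes are my own. It records every trace through a plain Python side effect, and that side effect only runs in the compilation phase. The result shows that a repeated shape reuses the cached compilation, while a new shape triggers a fresh trace:

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # plain Python list, only appended to while JAX traces


@jax.jit
def double(x):
    traced_shapes.append(x.shape)  # side effect: runs in the compilation phase only
    return x * 2


double(jnp.array([1, 2, 3]))     # first call: trace + compile for shape (3,)
double(jnp.array([4, 5, 6]))     # same shape and dtype: cached code is reused
double(jnp.array([1, 2, 3, 4]))  # new shape (4,): trace + compile again
print(traced_shapes)  # [(3,), (4,)]
```

The same mechanism explains why the print statement above fires only once.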
Functional code (Pure functions)
The most glaring requirement of JAX is that any jit-compiled function must be functional (pure), disallowing any side effects. In other words, all inputs should be part of the parameters and all results should be part of the outputs. This requirement takes some getting used to, as most of us frequently rely on an object's internal state:
class MyGridWorld:
    def reset(self):
        self.X_loc = 0  # A single state variable
        observation = self.X_loc  # Fully observable
        return observation

    def step(self, action):
        self.X_loc += action  # only move in X direction
        observation = self.X_loc
        return observation  #, reward, done, info ...

env = MyGridWorld()
jit_step_fn = jax.jit(env.step)
obs = env.reset()
for i in range(10):
    action = 1
    last_obs = jit_step_fn(action)
print(last_obs)

1
1
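While the above runs, our X location has not moved 10 times, but only once. During tracing, self.X_loc was read once (as 0) and baked into the compiled function as a constant, so every call simply returns 0 + action.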
This pattern is not supported because updating self is a side effect. Obviously, this programming pattern of using internal state is used a lot, so we need an alternative. The most obvious one comes in the form of explicit state:
class MyGridWorld:
    def reset(self):
        state = {  # this doesn't have to be a dict
            "loc_x": 0,
            "loc_y": 0
        }
        return generate_observation(state), state

    def step(self, state, action):
        # state as input parameter
        # output should contain the state for the next step
        ...

env = MyGridWorld()
obs, state = jax.jit(env.reset)()
(obs, reward, terminated, truncated, info), state = jax.jit(env.step)(state, action=0)
Here, we just use the state as an input to function calls, and return altered states from those functions.
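Here is a runnable version of the same idea: a minimal sketch of a hypothetical grid world written as plain functions. The names reset and step and the two-element observation are my own choices, not a standard API. Because the state is passed in and returned explicitly, the jitted step now actually moves the agent on every call:

```python
import jax
import jax.numpy as jnp


def reset():
    # All environment state lives in an explicit dict instead of on `self`
    state = {"loc_x": jnp.array(0), "loc_y": jnp.array(0)}
    obs = jnp.stack([state["loc_x"], state["loc_y"]])  # fully observable
    return obs, state


@jax.jit
def step(state, action):
    # Build and return a new state; the input state is never mutated
    new_state = {"loc_x": state["loc_x"] + action, "loc_y": state["loc_y"]}
    obs = jnp.stack([new_state["loc_x"], new_state["loc_y"]])
    return obs, new_state


obs, state = reset()
for _ in range(10):
    obs, state = step(state, 1)
print(obs)  # [10  0]
```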
Random numbers
Random numbers pop up a lot, but creating them in JAX is a little different compared to NumPy. That is again due to the requirement of pure functions: when NumPy runs an np.random function, it secretly updates its global internal state. In JAX, jax.random functions instead require an explicit key on each call:
key = jax.random.PRNGKey(0)
random_int = jax.random.randint(key, shape=(), minval=0, maxval=10)  # integer in [0, 10)
The PRNG key is a special value required by every random function. Ideally, you create one at the start of your program and continuously split it into new keys:
rng = jax.random.PRNGKey(42)
rng, key = jax.random.split(rng)  # two new keys
# use one for the required operation
random_number = jax.random.uniform(key)
# you can also split into more keys
multiple_keys = jax.random.split(rng, 10)  # an array of 10 new keys
While this is a bit more cumbersome, it has a nice added benefit: randomness is more explicit. The user is forced to provide a key for every operation, which improves reproducibility.
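The following minimal sketch shows what this explicitness means in practice. Reusing a key reproduces exactly the same number, while freshly split keys give independent draws. Accidentally reusing a key is therefore a common source of correlated "random" values:

```python
import jax

rng = jax.random.PRNGKey(42)

# Reusing the same key gives exactly the same "random" number.
a = jax.random.uniform(rng)
b = jax.random.uniform(rng)

# Splitting produces fresh keys, so these two draws are independent.
rng, key1, key2 = jax.random.split(rng, 3)
c = jax.random.uniform(key1)
d = jax.random.uniform(key2)

print(bool(a == b), bool(c == d))
```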
JAX types
XLA cannot deal with arbitrary types; it fundamentally operates on numbers (arrays of them). Custom classes and strings are therefore generally not valid inputs to JAX functions. This typically causes an issue with methods:
class MyCustomClass():
    @jax.jit
    def step(self, action):
        ...

c = MyCustomClass()
c.step(1)  # ❌ Error: self is an invalid JAX type
Dealing with methods specifically is explained well in the JAX FAQ, so I won't repeat much of it here. The easiest solutions are to mark self as static, or to register the class as a custom PyTree (for example with Equinox).
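For the first option, here is a minimal sketch following the pattern described in the JAX FAQ. The Scaler class itself is made up. It wraps the method with functools.partial(jax.jit, static_argnums=0), so that self is treated as a compile-time constant rather than a traced input:

```python
from functools import partial

import jax
import jax.numpy as jnp


class Scaler:  # made-up example class
    def __init__(self, factor):
        self.factor = factor

    # static_argnums=0 marks `self` as static: it must be hashable, and its
    # attributes are baked into the compiled code as constants.
    @partial(jax.jit, static_argnums=0)
    def apply(self, x):
        return x * self.factor


s = Scaler(3.0)
print(s.apply(jnp.arange(3.0)))  # [0. 3. 6.]
```

The catch is that self must be hashable. Since the default hash is based on object identity, mutating s.factor after the first call is silently ignored: JAX sees the same static argument and reuses the cached compilation. This is one reason why registering the class as a PyTree is often the more robust choice.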
JAX Primitives (and jnp)
JAX can only compile code that is compatible with XLA. JAX exposes the jax.lax module, which contains a large set of primitive operations that are compatible with XLA. These primitive operations allow you to do most operations when working with data.
However, JAX users are encouraged not to use the primitive module directly (where possible) and to rely on more stable, higher-level APIs instead. The most prominent of these is the JAX NumPy interface (jnp). This interface reimplements almost all of NumPy in terms of JAX primitives, so you can work with a familiar API while using JAX.
While jax.numpy is the largest and most important API, others exist, such as jax.nn and jax.scipy.
Immutable arrays
It is worth noting that, as opposed to mutable NumPy arrays, JAX arrays are immutable. As a direct consequence, you cannot perform familiar operations like x[0] = 5. If you try, a helpful error directs you to a functional way of updating an array: x.at[0].set(5).
Be aware that JAX does not raise an error for out-of-bounds indexing, because raising errors from code running on accelerators is difficult. Instead, JAX clamps retrieval (get) operations to stay in bounds and silently drops out-of-bounds update operations (like .set()). This behavior can be modified, so it is worth checking out the relevant documentation page.
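Here is a short sketch of functional updates and the default out-of-bounds behavior. The index is passed as a JAX array, as it typically would be when computed at runtime. The mode and fill_value arguments of .at[...].get() are one way to change the default:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[0].set(5.0)  # returns a new array; x itself is unchanged
z = y.at[2].add(2.0)  # other updates exist too: .add, .multiply, .min, .max, ...
print(x, y, z)  # [0. 0. 0.] [5. 0. 0.] [5. 0. 2.]

idx = jnp.array(10)  # out of bounds for an array of length 3
print(z[idx])              # 2.0: the get is clamped to the last element
print(z.at[idx].set(9.0))  # [5. 0. 2.]: the update is silently dropped
print(z.at[idx].get(mode="fill", fill_value=jnp.nan))  # nan: explicit fill instead of clamping
```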
When to use Numpy
You will quite often see JAX code that imports both NumPy and JAX:
import jax.numpy as jnp
import numpy as np
This is because NumPy operations are evaluated eagerly during the compilation phase, while jnp operations on traced values are staged out and evaluated at runtime. Some values, like shapes, must be known during compilation, so it may be useful to use NumPy for them. Consider the example below, from the JAX documentation:
x = jnp.ones((2, 3))

@jax.jit
def f(x):
    new_shape = jnp.array(x.shape).prod()
    return x.reshape(new_shape)

f(x)  # ❌ Error: output shape is not known during compilation
In the above example, new_shape will be a Traced<ShapedArray> object with an unknown value during compilation. Because JAX requires all shapes and dtypes to be known at compile time, this results in an error.
In contrast, if we use Numpy, the shape will be evaluated during the compilation phase, and we have no issue:
x = jnp.ones((2, 3))

@jax.jit
def f(x):
    new_shape = np.array(x.shape).prod()
    return x.reshape(new_shape)

f(x)  # ✔️ All good
Conditionals
if-statements are part of regular Python control flow, which does not make it into XLA. They are, however, executed during the compilation phase.
This means that we can use them as long as their condition can be evaluated at compile time. As such, they are extremely helpful for control flow based on environment or algorithm options:
(obs, reward, done, info), state = env.step(...)
if done:  # ❌ Error: done is not known at compile time.
    obs, state = env.reset()
...

exploration_method = "egreedy"  # constant value during runtime.
if exploration_method == "egreedy":  # ✔️ All good
    ...
elif ...
Since if-statements are removed from the compiled code, they are no longer evaluated during the runtime phase. This in itself may provide a healthy performance boost.
Of course, there are times when you do need runtime conditionals. For this, JAX provides jax.lax.select and jax.lax.cond.
select() picks between two already computed values, whereas cond() executes one of two functions based on a boolean:
(obs, reward, done, info), state = env.step(...)

# Using cond:
obs, state = jax.lax.cond(
    done,
    lambda: env.reset(...),
    lambda: (obs, state),
    # Operand <-- Optional argument for the functions
)

# Using select (lax.select operates on arrays, so apply it to each PyTree leaf):
obs_reset, state_reset = env.reset(...)
obs, state = jax.tree_util.tree_map(
    lambda reset_leaf, leaf: jax.lax.select(done, reset_leaf, leaf),
    (obs_reset, state_reset),
    (obs, state)
)
Note that the outputs of the true and false paths in both select() and cond() must match. That is, obs and obs_reset must have exactly the same structure, shapes and dtypes.
Finally, remember that we can often use jnp.where(). It acts as select() under the hood, but conveniently offers the same interface as the familiar np.where() function.
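The following small runnable sketch shows both kinds of runtime conditionals; the function names are hypothetical. lax.cond decides which function to execute, while jnp.where selects element-wise between two arrays that have both been computed:

```python
import jax
import jax.numpy as jnp


@jax.jit
def reset_if_done(x, done):
    # Outside of vmap, lax.cond executes only one of the two branches; both
    # branches must return outputs with identical structure, shapes and dtypes.
    return jax.lax.cond(done, lambda v: jnp.zeros_like(v), lambda v: v, x)


@jax.jit
def relu_where(x):
    # jnp.where selects element-wise between two already computed arrays.
    return jnp.where(x > 0, x, 0.0)


x = jnp.array([-1.0, 2.0, -3.0])
print(reset_if_done(x, True))   # [0. 0. 0.]
print(reset_if_done(x, False))  # [-1.  2. -3.]
print(relu_where(x))            # [0. 2. 0.]
```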
Loops
Loops are another fundamental concept that works quite differently in JAX. Regular for loops behave differently than you might expect, and we have other options to choose from. Let's go over each of these.
Regular for loops
First of all, Python for (and while) loops are still available in your JAX code. However, they are restricted to loops whose number of iterations is known at compile time. Perhaps more importantly, they are unrolled into sequential XLA operations. This means the following two blocks result in equivalent code after compilation:
for layer in nn_layers:
    x = layer(x)

# Equivalent to:
x = nn_layers[0](x)
x = nn_layers[1](x)
x = nn_layers[2](x)
x = nn_layers[3](x)
x = nn_layers[4](x)
Essentially, for loops should be reserved for short loops like the one above, where they simply clean up our code. Unrolling loops can be pretty costly in the compilation phase: the larger the loop, the slower the compilation, and this can add up to minutes or even tens of minutes. Hence, we generally avoid them and opt for one of the JAX-specific loops.
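One way to observe the unrolling yourself is jax.make_jaxpr, which returns the intermediate program JAX traces before handing it to XLA. The function names in this sketch are my own. The Python loop shows up as three separate multiplications, whereas the lax.scan version appears as a single scan operation containing the loop body once:

```python
import jax


def python_loop(x):
    for _ in range(3):  # a plain Python loop is unrolled while tracing
        x = x * 2.0
    return x


def scan_loop(x):
    def body(carry, _):
        return carry * 2.0, None

    x, _ = jax.lax.scan(body, x, None, length=3)  # traced as one scan primitive
    return x


def top_level_ops(closed_jaxpr):
    # Names of the primitives at the top level of a traced program
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


unrolled = jax.make_jaxpr(python_loop)(1.0)
rolled = jax.make_jaxpr(scan_loop)(1.0)
print(unrolled)  # three separate `mul` equations
print(rolled)    # a single `scan` equation wrapping one `mul`
print(top_level_ops(unrolled), top_level_ops(rolled))
```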
Scan (when you can)
The jax.lax.scan function will be your new best friend, as it is a true Swiss Army knife.
Essentially, jax.lax.scan is:
- A regular loop (albeit restricted to fixed length loops),
- which can loop over vector data easily;
- while providing a carry value on each iteration,
- and constructs batches of data easily.
The following Python implementation, copied from the JAX documentation, roughly describes the semantics of scan():
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
You will likely not use all of scan()'s functionality all the time, but only a subset. Let's go over some examples to make things clearer.
Below, we perform a rollout of environment steps and create a batch of transitions.
def do_rollout(carry, _):
    (obs, state) = carry
    action = policy_net_fw(obs, ...)
    value = critic_net_fw(obs, ...)
    (new_obs, reward, done, info), new_state = env.step(
        action, state, ...
    )
    transition = (obs, action, reward, done, value, info)
    new_carry = (new_obs, new_state)
    return new_carry, transition

num_rollout_steps = 100
obs, state = env.reset()
initial_carry = (obs, state)
final_carry, transitions = jax.lax.scan(
    f=do_rollout,
    init=initial_carry,
    xs=None,
    length=num_rollout_steps
)
# transitions is now a tuple of arrays
# Each with shape (num_rollout_steps, ...)
In the above call, we construct a batch of transitions. We do not have to use this functionality. We could simply return None as the second element of the returned tuple, in which case scan() is essentially a regular loop.
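The following self-contained sketch shows the batching behavior in isolation. A hypothetical 1-D toy environment and a hand-written policy stand in for env.step and policy_net_fw. scan stacks every element of the per-step tuple returned by rollout_step along a new leading axis of length 5:

```python
import jax
import jax.numpy as jnp


def env_step(state, action):
    # Hypothetical 1-D toy environment: move by `action`, reward = -|position|
    new_state = state + action
    return new_state, -jnp.abs(new_state)


def policy(obs):
    # Hypothetical hand-written policy: always step towards the origin
    return -jnp.sign(obs)


def rollout_step(state, _):
    action = policy(state)
    new_state, reward = env_step(state, action)
    transition = (state, action, reward)  # scan stacks these along a new axis 0
    return new_state, transition


final_state, (obs_batch, action_batch, reward_batch) = jax.lax.scan(
    rollout_step, jnp.array(3.0), None, length=5
)
print(obs_batch)           # [3. 2. 1. 0. 0.]
print(reward_batch.shape)  # (5,)
```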
You will further notice that the do_rollout() function definition above does not use its second argument. This input corresponds to the xs argument of the scan() call.
xs takes an array (or a PyTree of arrays) and scans over its leading axis.
For example, we can scan over our collected trajectory and update our models:
# Assume we have some agent as a network: `curr_model`
# and some function that updates this model with `update_model()`
def update(model, transition):
    # `transition` is a single row from `transitions`
    updated_model = update_model(model, transition)
    return updated_model, None

trained_model, _ = jax.lax.scan(
    f=update,
    init=curr_model,
    xs=transitions  # <-- scans over the first axis
)
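Here is a more RL-specific use of xs: a sketch that computes discounted returns by scanning backwards over a trajectory with reverse=True. The function and argument names are my own, and the sketch assumes that done_t marks the last step of an episode. The tuple (rewards, dones) is passed as xs, so each iteration receives one time step of both arrays:

```python
import jax
import jax.numpy as jnp


def discounted_returns(rewards, dones, gamma=0.99):
    # Scan backwards in time: G_t = r_t + gamma * (1 - done_t) * G_{t+1}
    def step(next_return, inputs):
        reward, done = inputs  # one row of each array in `xs`
        ret = reward + gamma * (1.0 - done) * next_return
        return ret, ret  # (new carry, per-step output to stack)

    _, returns = jax.lax.scan(step, jnp.zeros(()), (rewards, dones), reverse=True)
    return returns


rewards = jnp.array([1.0, 1.0, 1.0])
dones = jnp.array([0.0, 0.0, 1.0])
print(discounted_returns(rewards, dones, gamma=0.5))  # [1.75 1.5  1.  ]
```

With reverse=True, scan iterates from the last time step to the first, but the stacked outputs are still returned in the original time order.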
While_loop
The next loop operator is jax.lax.while_loop. Unlike scan(), while_loop() allows for dynamic-length loops.
However, it does not support reverse-mode differentiation. So while scan() may be used inside a loss function that you differentiate, while_loop() may not.
A while_loop() works similarly to scan() in that it uses a carry value, but it terminates based on a condition function that returns a boolean. In the example below, we run a full environment episode:
def step_in_env(carry):
    obs, done, env_state = carry
    action = policy_net_fw(obs, ...)
    (new_obs, reward, done, info), new_state = env.step(
        action, env_state, ...
    )
    ...  # e.g. count the rewards
    next_carry = (new_obs, done, new_state)
    return next_carry

def cond_func(carry):  # <-- must take the same carry
    obs, done, env_state = carry
    return ~done

obs, env_state = env.reset(...)
done = False
init_carry = (obs, done, env_state)
final_carry = jax.lax.while_loop(
    cond_func, step_in_env, init_carry
)
Note that using while_loop() inside of jax.vmap() may cause problems, which we will get into later.
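Here is a minimal runnable sketch of a while_loop whose number of iterations depends on a runtime value. steps_until_goal is a made-up example rather than an environment:

```python
import jax
import jax.numpy as jnp


def steps_until_goal(start, goal):
    # carry = (position, number of steps taken); its structure and dtypes
    # must stay identical across iterations.
    def cond_fun(carry):
        pos, _ = carry
        return pos < goal  # keep looping while the goal has not been reached

    def body_fun(carry):
        pos, n_steps = carry
        return pos + 1, n_steps + 1

    _, n_steps = jax.lax.while_loop(cond_fun, body_fun, (start, jnp.int32(0)))
    return n_steps


# `goal` is a traced runtime value, so the number of iterations is dynamic.
print(jax.jit(steps_until_goal)(jnp.int32(2), jnp.int32(7)))  # 5
```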
fori_loop
The fori_loop is more of a helper function. It is converted into a scan() when the number of iterations is static (known at compile time), or into a while_loop() when the number of iterations is dynamic. Its interface may seem a bit more familiar, but I'd personally stick to using scan() and while_loop() to keep your code more explicit.
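For comparison, here is a minimal fori_loop sketch. The body receives the loop index and the carry. Since the bounds here are static Python integers, JAX can lower the loop to a scan():

```python
import jax


def body(i, total):
    # fori_loop passes the loop index and the carry; it returns the new carry
    return total + i


total = jax.lax.fori_loop(0, 5, body, 0)  # 0 + 1 + 2 + 3 + 4
print(total)  # 10
```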
Other Transformations
Besides @jit, JAX contains a few more function transformations that can be applied just as easily. We will go over vmap for vectorization and grad for computing gradients, and briefly address explicit parallelization.
Luckily, these hardly introduce any new restrictions. In fact, a number of the restrictions discussed above do not apply if we do not combine these transformations with @jit.
@vmap
jax.vmap is insanely powerful. It takes any function and vectorizes it.
This is essentially what happens in neural networks when you provide a batch of observations:
import some_neural_network_library as nn  # hypothetical library
import numpy as np
import jax

class some_model(nn.Module):
    def __init__(self, in_shape, out_shape):
        self.layer1 = nn.Linear(in_shape=in_shape, nodes=64, out_shape=64)
        self.layer2 = nn.Linear(in_shape=64, nodes=64, out_shape=64)
        self.layer3 = nn.Linear(in_shape=64, nodes=64, out_shape=out_shape)

    def __call__(self, x):  # forward_fn
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

batch_size, in_shape, out_shape = 50, 10, 1
data_batch = np.random.uniform(size=(batch_size, in_shape))
model = some_model(in_shape, out_shape)

# Model takes in a single observation of size 10;
# we vectorize the operation to process a whole batch
data_out = jax.vmap(model)(data_batch)  # data_out shape: (50, 1)
Of course, batching data through a neural network is usually handled by our neural network libraries (PyTorch, TensorFlow, Flax). However, now we have this power at our direct fingertips and can vectorize anything! For example, we can vectorize our environments easily, without any environment wrappers:
num_envs = 5
seed = jax.random.PRNGKey(42)
env_seeds = jax.random.split(seed, num_envs)  # Different key for each env
obs_v, env_state_v = jax.vmap(env.reset)(env_seeds)
In the above example, obs_v and env_state_v are now observations and states for each of the 5 environments.
Next, we can easily take steps in our 5 environments:
actions_v = jax.vmap(policy_net)(obs_v) # 5 actions
(obs_v, reward_v, done_v, info_v), env_state_v = jax.vmap(
    env.step
)(env_state_v, actions_v)  # will map over both env_state_v and actions_v
Importantly, the mapped inputs to a vmapped function must all have an axis of the same size. In the above example, both actions_v and env_state_v have a leading axis of 5. You can also map over different axes with the in_axes argument of jax.vmap.
Sometimes, you may want to provide inputs that should not be mapped over. For example, let's say that, for whatever reason, we want to apply the same action in all of our environments:
action = policy_net(obs_v[0])  # 1 action
(obs_v, reward_v, done_v, info_v), env_state_v = jax.vmap(
    env.step, in_axes=(0, None)  # <-- map the 0th axis of the first argument and NO axis of the second argument
)(env_state_v, action)  # will map over env_state_v only
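Here is a self-contained sketch of both mapping patterns. It uses a hypothetical toy step function in place of env.step, with 5 environments and 1-D positions:

```python
import jax
import jax.numpy as jnp


def step(state, action):
    # Hypothetical toy env step for a single environment
    new_state = state + action
    reward = -jnp.abs(new_state)
    return new_state, reward


states = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])    # 5 environments
actions = jnp.array([1.0, -1.0, 1.0, -1.0, 1.0])  # one action per environment

# Map over axis 0 of both arguments
new_states, rewards = jax.vmap(step)(states, actions)

# Map over the states only; the single action is shared by all environments
shared_states, _ = jax.vmap(step, in_axes=(0, None))(states, 1.0)

print(new_states)     # [1. 0. 3. 2. 5.]
print(shared_states)  # [1. 2. 3. 4. 5.]
```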
While vmap is powerful, it's important to understand that the operations are vectorized, not parallelized. The XLA compiler will likely take advantage of multiple cores on the system while executing vectorized operations. However, vectorized operations run in lockstep: within a vectorized function, the same sequence of operations has to be applied to every element of the batch.
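For example, consider a while_loop inside our env.step function whose number of iterations depends on the environment state itself. Under vmap this still runs, but the environments cannot leave the loop independently. JAX keeps iterating until the condition is false for every environment, masking out the ones that have already finished, so the slowest environment determines the cost for all of them. Similarly, a lax.cond with a batched predicate is converted into a select, so both branches are executed.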
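The following small sketch demonstrates this masking behavior; count_up is a made-up function. Each element needs a different number of iterations, yet the vmapped loop still returns the correct per-element counts:

```python
import jax
import jax.numpy as jnp


def count_up(goal):
    # The number of iterations depends on the per-environment input `goal`
    def cond_fun(carry):
        i, _ = carry
        return i < goal

    def body_fun(carry):
        i, iters = carry
        return i + 1, iters + 1

    _, iters = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.int32(0)))
    return iters


goals = jnp.array([1, 3, 5], dtype=jnp.int32)
# Under vmap, the loop keeps running until every element is done (5 iterations
# here); finished elements are masked so their carry no longer changes.
print(jax.vmap(count_up)(goals))  # [1 3 5]
```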
@grad | Autograd & Updating models
Next up: @grad. @grad computes the gradient of a function with respect to its first argument (the function must return a scalar).
This first argument would likely be a set of model weights, and the function itself will be a loss function:
@jax.grad  # <- alt.: @jax.value_and_grad to return a (loss, grads) tuple
def ppo_loss_function(weights, minibatch):
    # calculate ppo loss with current parameters (weights)
    ...
    return loss

grads = ppo_loss_function(agent_weights, rollout_minibatch)
While we are using autograd, let’s show how updating model parameters would work:
import optax  # <-- the go-to optimization library in the JAX ecosystem

optimizer = optax.adam(learning_rate=1e-4)
optimizer_state = optimizer.init(agent_weights)

# ... start training, collect rollout, create minibatches
grads = ppo_loss_function(agent_weights, rollout_minibatch)  # get gradients
updates, new_optimizer_state = optimizer.update(
    grads, optimizer_state
)  # get changes to parameters
new_agent_weights = optax.apply_updates(agent_weights, updates)  # add changes
Essentially, that's it. To me, this way of updating a model is extremely explicit while still concise. There is no hidden magic going on (looking at alternatives like model.fit()). Largely because there is no hidden state, I also find this much clearer than the PyTorch combination of optim.zero_grad(), loss.backward() and optim.step().
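The following minimal sketch runs the whole update loop end to end without optax. It swaps in plain SGD, implemented with jax.value_and_grad and a tree map. The linear-regression loss, learning rate and step count are illustrative assumptions, not a recipe:

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a linear model; params is a dict of arrays
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    lr = 0.1
    # Gradients are taken w.r.t. the first argument (params) only
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD: grads has the same structure as params, so update leaf by leaf
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0] + 1.0  # targets generated with w = 2, b = 1
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}
for _ in range(500):
    params, loss = sgd_step(params, x, y)
print(params["w"], params["b"])  # should end up close to [2.] and 1.
```

With optax, the tree-map line would be replaced by optimizer.update followed by optax.apply_updates, as in the snippet above.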
In all these examples, I have been a bit loose about what exactly the model/agent is (i.e. how its class is defined). We will explore this more when discussing neural network libraries.
pmap, xmap and shard_map?
One more prominent kind of JAX transformation is for parallelization. While XLA will already often take advantage of multiple cores, JAX provides methods for easily parallelizing and sharding data over multiple devices, or even over different machines in a cluster. Most of this will be highly relevant when you are dealing with huge networks. The API, however, still changes often and is not always well documented. Jake summarizes the current state pretty nicely in this GitHub issue. I personally haven't needed any of these methods, so I cannot add a lot here.
What else is important
We've discussed the important JAX transformations and their requirements. As may already be clear, all of these transformations are composable: we can construct a jit-compiled, vmapped gradient function if we like. Let's now dive into some other important topics, such as neural network libraries, debugging, and, perhaps most importantly, PyTrees.
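Here is a quick sketch of this composability, using a toy loss: a jit-compiled, vmapped gradient function that computes per-example gradients. grad differentiates with respect to w, and in_axes=(None, 0) shares w while mapping over the batch of x:

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Scalar loss for a single example
    return jnp.sum((w * x) ** 2)


# grad w.r.t. w, vmapped over a batch of x (w is shared), then jit-compiled
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 0.0]])
print(per_example_grads(w, xs))  # one gradient per example: [[2, 4], [8, 0]]
```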
PyTrees
When listening to your favorite colleague talk about JAX (or reading up about JAX yourself), you have probably encountered the term PyTree. PyTrees are pretty fundamental in JAX, so it's useful to have at least some intuition for them. However, you can get decently far in JAX without really understanding PyTrees, hence I introduce them quite late here.
As mentioned, JAX fundamentally operates on arrays, which covers most use cases when we are working with data. However, we often want to work with collections of arrays. For example, a neural network may be viewed as a list of matrices. For this, JAX provides the PyTree abstraction.
You can think of a PyTree as an extensible tree-like structure with nodes and leaves. The nodes are registered container-like classes, which by default include dict, tuple and list. All other classes that are not registered as nodes are PyTree leaves.
The idea here is that PyTree leaves represent the elements we want to operate on, which are usually arrays or scalars. PyTree nodes contain these leaves (or other PyTrees), along with functions to flatten (and reconstruct) them. By flattening each node, JAX can internally operate on all leaves as if they were separate regular function arguments, while we can structure them any way we want. JAX has functions to inspect a PyTree's structure and its leaves:
import jax
import numpy as np

examples = [
    1,
    [1, 2, 3],
    [(1, 2), (3, "hello world")],
    {"a": 1, "b": 4.0, "c": {"d": np.array([1, 2, 3]), "e": [4, 5, 6]}}
]

for pytree in examples:
    print(f"PyTree structure: {jax.tree.structure(pytree)}")
    print(f"PyTree leaves: {jax.tree.leaves(pytree)} \n")
PyTree structure: PyTreeDef(*)
PyTree leaves: [1]
PyTree structure: PyTreeDef([*, *, *])
PyTree leaves: [1, 2, 3]
PyTree structure: PyTreeDef([(*, *), (*, *)])
PyTree leaves: [1, 2, 3, 'hello world']
PyTree structure: PyTreeDef({'a': *, 'b': *, 'c': {'d': *, 'e': [*, *, *]}})
PyTree leaves: [1, 4.0, array([1, 2, 3]), 4, 5, 6]
JAX transformations such as grad, jit and vmap take PyTrees as their input. This requirement is easily met, since essentially anything can be considered a PyTree: when there are no nodes, the PyTree is just a single leaf. While PyTrees may be big, complex structures, in practice you will likely just work with simple trees. We have done so in many examples above, for example when passing a tuple as the carry for scan().
JAX provides functions to work with PyTrees. The most common one is probably jax.tree.map, which applies a function to all the leaves of a PyTree:
# We have a network of three layers, stored in a list
weights = [layer_1_w, layer_2_w, layer_3_w]
updates, optimizer_state = optimizer.update(grads, optimizer_state)
updated_weights = jax.tree.map(
    lambda w, u: w + u, weights, updates
)  # This is essentially the same as `optax.apply_updates`
Note that anything can be a leaf, even types that XLA cannot deal with, like strings:
jax.tree.map(lambda w: w + " World", ("Hello", "Hi"))
('Hello World', 'Hi World')
As such, the concept of PyTrees is not necessarily exclusive to JAX. This means you are free to use jax.tree functionality on any kind of collection. For JAX, the only thing that matters is that the input and output PyTrees of jitted functions do not contain types that XLA does not support.
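Here is a small sketch of the flatten and unflatten idea, using the jax.tree_util module. The nested params dict is a made-up example:

```python
import jax
import jax.numpy as jnp

params = {"layer1": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}, "scale": jnp.array(3.0)}

# Flatten into a flat list of leaves plus a treedef that remembers the structure
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves), treedef)

# Rebuild an identically structured tree from (modified) leaves
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])

# tree_map over two trees with the same structure combines them leaf by leaf
summed = jax.tree_util.tree_map(lambda a, b: a + b, params, doubled)
print(summed["scale"])  # 9.0
```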
Extending PyTrees
As I said, the default PyTree node types include list, tuple and dict. However, this set can be extended with any custom class. Essentially, you just have to tell JAX how to flatten and reconstruct your new collection type. You do this by defining two methods and registering the class with a decorator.
It is all pretty well explained in the documentation, so I won’t go over it much here.
Furthermore, extending PyTrees is made a bit easier with packages such as Equinox, which we will discuss more later.
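For reference, here is a minimal sketch of registering a custom class with jax.tree_util.register_pytree_node_class. This decorator expects a tree_flatten method and a tree_unflatten classmethod. The Transition class is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Transition:  # hypothetical container for a single RL transition
    def __init__(self, obs, reward):
        self.obs = obs
        self.reward = reward

    def tree_flatten(self):
        children = (self.obs, self.reward)  # the leaves JAX should operate on
        aux_data = None  # static (non-array) metadata; none needed here
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def scale_reward(t):
    # Transition can now be passed through jit like any other PyTree
    return Transition(t.obs, t.reward * 10.0)


t = scale_reward(Transition(jnp.ones(3), jnp.array(0.5)))
print(t.reward)  # 5.0
```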
Breaking out of XLA | Callbacks & Debugging
It is often claimed that “debugging in JAX is hard”. I would say it is definitely harder, but still very much doable.
As we have seen, print() statements do not print any values, only tracers (the Traced<ShapedArray> objects). Even though these do not contain values, they do contain shapes and dtypes. In my experience, a significant part of debugging is making sure the shapes of your data are what you expect them to be, so this information is usually enough.
Of course, besides print(), we can also place a regular breakpoint(). This drops you into the Python debugger during the compilation phase, allowing you to freely inspect any variable's shape and dtype.
JAX also offers JAX-specific debugging tools. The primary ones come in some form of jax.debug.callback().
callback() allows us to transport data from within JIT-compiled code out to the host during runtime (at a performance cost). We can simply write a regular Python function and do anything with this data, such as printing or logging:
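import jax
import jax.numpy as jnp
import wandb

def print_and_log(data):
    print("runtime data:", data)
    wandb.log({"z_sum": float(data.sum())})  # wandb.log expects a dict; assumes wandb.init() was called

@jax.jit
def func(x):
    z = x ** 2
    jax.debug.callback(print_and_log, z)
    z = z - 1
    return z

arr = jnp.array([1, 2, 3])
print("final data: ", func(arr))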
runtime data: [1 4 9]
final data: [0 3 8]
JAX also provides a convenient jax.debug.print() function that wraps callback(). Most importantly, we have jax.debug.breakpoint(), which behaves like a regular breakpoint() but breaks during the runtime phase, allowing you to inspect the values of any variables. The JAX debugger is slightly less powerful than the regular Python debugger, though.
Lastly, JAX provides a few debugging flags. For instance, you can disable JIT compilation entirely with jax.config.update("jax_disable_jit", True). Your code will generally run slower, but this lets you use any regular debugging techniques you may already know.
The other noteworthy flag is jax.config.update("jax_debug_nans", True), which is aimed at finding exactly where NaN values first occur.
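Here is a small sketch of the NaN flag in action. It enables the flag only temporarily; restoring the default afterwards is my own precaution, not a requirement:

```python
import jax
import jax.numpy as jnp

caught = None
jax.config.update("jax_debug_nans", True)
try:
    jnp.log(jnp.array(-1.0))  # log of a negative number produces a NaN
except FloatingPointError as err:
    caught = type(err).__name__
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default

print(caught)  # FloatingPointError
```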
NN libraries & Equinox
JAX in itself is not a direct PyTorch or TensorFlow replacement. It does contain autograd functionality, which makes building a neural network significantly easier than in NumPy. However, it lacks convenient tools such as pre-implemented layers.
As such, many neural network libraries compatible with JAX exist (Equinox, Flax and Haiku, to name a few). Of these libraries, Flax looks like the biggest player in town. Personally, though, I much prefer Equinox. Equinox keeps things simple and elegant; essentially, it contains tooling that could reasonably be part of the standard JAX library. Equinox Modules are automatically PyTrees (and this functionality can be used for more than just neural networks), so Equinox networks are just collections of arrays. JAX transformations know how to deal with these collections, which makes the approach very elegant, yet also explicit.
Equinox further provides methods for manipulating PyTrees, such as filtering and tree_at(). You will find these methods generally helpful when working in JAX, so Equinox really feels like a one-stop shop.
Flax does have a bigger community, so more tutorials and code exist online. They also recently introduced a new nnx API, which may improve their workflow. In general, though, I found that Flax introduces more magic than I would like.
RL Pipeline
What hasn't been discussed yet is an optimal pipeline for reinforcement learning training. JAX allows some functions to be JIT-compiled while others are not. This would let us build our agent in JAX while still training on regular Gymnasium environments (of which many exist).
However, to get the biggest benefit, we strive to build the entire pipeline in JAX. As Chris Lu has shown in PureJAXRL, this can yield performance improvements of over 1000 times. In general, a rough sketch of the entire training pipeline looks something like this:
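import equinox as eqx
import gymnax
import jax
import optax

# Setup agent and environment
env, env_params = gymnax.make("CartPole-v1")
actor_critic = EquinoxModel(...)
optimizer = optax.adam(...)
optimizer_state = optimizer.init(eqx.filter(actor_critic, eqx.is_array))

@jax.jit
def train():
    def train_step(carry, _):
        actor_critic, optimizer_state = carry
        batch = do_rollout(...)
        # Maybe process data: e.g. GAE, minibatches ...
        actor_critic, optimizer_state = update_on_batch(
            actor_critic, optimizer_state, batch, ...
        )
        return (actor_critic, optimizer_state), None

    (trained_actor_critic, _), _ = jax.lax.scan(
        train_step,
        (actor_critic, optimizer_state),
        None,
        length=num_training_iterations,
    )
    return trained_actor_critic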
Essentially, by keeping the entire training loop inside a single jit-compiled function, we avoid shuttling data back and forth between the GPU and CPU, which provides a massive speedup.
Environments in JAX & Gym APIs
An excellent resource for JAX-based environments can be found here.