ML Hyperpolyglot / Frameworks

a side-by-side reference sheet

general | core characteristics | distributed training | layers

Contributions welcome on GitHub.

General
PyTorch Keras JAX TensorFlow
First Release 2016 2015 2018 2015
PyPI Package torch keras jax tensorflow
Default Image Format NCHW (channels first) (depends on backend) NHWC (channels last) NHWC (channels last)
Core Characteristics
PyTorch Keras JAX TensorFlow
Default Execution Model Eager (depends on backend) Eager Eager (2.x) or graph (1.x)
Compilation torch.compile (depends on backend) jax.jit tf.function
Automatic Differentiation
loss = compute_loss(model(x), y)
loss.backward()
# gradients in param.grad
(via backend)
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)
with tf.GradientTape() as tape:
  loss = compute_loss(model(x), y)
grads = tape.gradient(loss, model.trainable_variables)
Vectorization
# Manual batching or torch.vmap
batched_fn = torch.vmap(fn)
outputs = batched_fn(inputs)
(via backend)
batched_fn = jax.vmap(fn)
outputs = batched_fn(inputs)
# Manual batching via tf.map_fn
outputs = tf.map_fn(fn, inputs)
Dynamic Shapes
# Native support
x = torch.randn(batch_size, seq_len)
(via backend)
# Shape polymorphism for export
@jax.jit
def f(x):  # works with varying shapes
  return jax.numpy.sum(x)
# Native support in eager mode
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None])])
def f(x):
  return tf.reduce_sum(x)
Distributed Training
PyTorch Keras JAX TensorFlow
DDP (Data Parallel)
from torch.nn.parallel import DistributedDataParallel
model = DistributedDataParallel(model)
(via backend)
# Replicate across devices
sharding = NamedSharding(mesh, PartitionSpec('data'))
@jax.jit
def train_step(state, batch):
  return updated_state
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = create_model()
FSDP (Fully Sharded)
from torch.distributed.fsdp import FullyShardedDataParallel
model = FullyShardedDataParallel(model)
(via backend)
# Shard params across devices
sharding = NamedSharding(mesh, PartitionSpec(None, 'fsdp'))
@jax.jit
def train_step(params, batch):
  return updated_params
strategy = tf.distribute.experimental.ParameterServerStrategy(...)
with strategy.scope():
  model = create_model()
TP (Tensor Parallel)
# Via torch.distributed.tensor or libraries
from torch.distributed.tensor.parallel import parallelize_module
parallelize_module(model, device_mesh, parallelize_plan)
(via backend)
# Partition tensors across devices
sharding = NamedSharding(mesh, PartitionSpec('data', 'fsdp'))
@jax.jit
def forward(weights, x):
  return x @ weights
layout = dtensor.Layout.batch_sharded(mesh, 'model', rank=2)
weights = dtensor.DVariable(initial_value, layout=layout)
Layers
PyTorch Keras JAX TensorFlow
Dense / Linear
nn.Linear(in_features, out_features)
layers.Dense(units)
nnx.Linear(in_features, out_features, rngs=rngs)
tf.keras.layers.Dense(units)
Conv2D
nn.Conv2d(in_channels, out_channels, kernel_size)
layers.Conv2D(filters, kernel_size)
nnx.Conv(in_features, out_features, kernel_size, rngs=rngs)
tf.keras.layers.Conv2D(filters, kernel_size)
Conv1D
nn.Conv1d(in_channels, out_channels, kernel_size)
layers.Conv1D(filters, kernel_size)
nnx.Conv(in_features, out_features, kernel_size=(kernel_size,), rngs=rngs)
tf.keras.layers.Conv1D(filters, kernel_size)
Conv3D
nn.Conv3d(in_channels, out_channels, kernel_size)
layers.Conv3D(filters, kernel_size)
nnx.Conv(in_features, out_features, kernel_size=(k, k, k), rngs=rngs)
tf.keras.layers.Conv3D(filters, kernel_size)
ConvTranspose2D
nn.ConvTranspose2d(in_channels, out_channels, kernel_size)
layers.Conv2DTranspose(filters, kernel_size)
nnx.ConvTranspose(in_features, out_features, kernel_size, rngs=rngs)
tf.keras.layers.Conv2DTranspose(filters, kernel_size)
MaxPool2D
nn.MaxPool2d(kernel_size)
layers.MaxPooling2D(pool_size)
nnx.max_pool(x, window_shape, strides)
tf.keras.layers.MaxPooling2D(pool_size)
AvgPool2D
nn.AvgPool2d(kernel_size)
layers.AveragePooling2D(pool_size)
nnx.avg_pool(x, window_shape, strides)
tf.keras.layers.AveragePooling2D(pool_size)
BatchNorm
nn.BatchNorm2d(num_features)
layers.BatchNormalization()
nnx.BatchNorm(num_features, rngs=rngs)
tf.keras.layers.BatchNormalization()
LayerNorm
nn.LayerNorm(normalized_shape)
layers.LayerNormalization()
nnx.LayerNorm(num_features, rngs=rngs)
tf.keras.layers.LayerNormalization()
Dropout
nn.Dropout(p=0.5)
layers.Dropout(rate=0.5)
nnx.Dropout(rate=0.5, rngs=rngs)
tf.keras.layers.Dropout(rate=0.5)
ReLU
nn.ReLU()
layers.ReLU()
nnx.relu
tf.keras.layers.ReLU()
Softmax
nn.Softmax(dim=-1)
layers.Softmax()
nnx.softmax
tf.keras.layers.Softmax()
Embedding
nn.Embedding(num_embeddings, embedding_dim)
layers.Embedding(input_dim, output_dim)
nnx.Embed(num_embeddings, features, rngs=rngs)
tf.keras.layers.Embedding(input_dim, output_dim)
LSTM
nn.LSTM(input_size, hidden_size)
layers.LSTM(units)
nnx.RNN(nnx.LSTMCell(in_features, hidden_size, rngs=rngs))
tf.keras.layers.LSTM(units)
GRU
nn.GRU(input_size, hidden_size)
layers.GRU(units)
nnx.RNN(nnx.GRUCell(in_features, hidden_size, rngs=rngs))
tf.keras.layers.GRU(units)
MultiHeadAttention
nn.MultiheadAttention(embed_dim, num_heads)
layers.MultiHeadAttention(num_heads, key_dim)
nnx.MultiHeadAttention(num_heads, in_features, rngs=rngs)
tf.keras.layers.MultiHeadAttention(num_heads, key_dim)
Flatten
nn.Flatten()
layers.Flatten()
x.reshape(x.shape[0], -1)
tf.keras.layers.Flatten()
Reshape
x.view(batch_size, -1)
layers.Reshape(target_shape)
x.reshape(new_shape)
tf.keras.layers.Reshape(target_shape)
Concatenate
torch.cat([x1, x2], dim=1)
layers.Concatenate()([x1, x2])
jnp.concatenate([x1, x2], axis=1)
tf.keras.layers.Concatenate()([x1, x2])
Add
x1 + x2
layers.Add()([x1, x2])
x1 + x2
tf.keras.layers.Add()([x1, x2])
GlobalAvgPool2D
nn.AdaptiveAvgPool2d((1, 1))
layers.GlobalAveragePooling2D()
jnp.mean(x, axis=(1, 2))
tf.keras.layers.GlobalAveragePooling2D()
GlobalMaxPool2D
nn.AdaptiveMaxPool2d((1, 1))
layers.GlobalMaxPooling2D()
jnp.max(x, axis=(1, 2))
tf.keras.layers.GlobalMaxPooling2D()