ML Hyperpolyglot / Frameworks
a side-by-side reference sheet
general | core characteristics | distributed training | layers
Contributions welcome on GitHub.
| General | ||||
|---|---|---|---|---|
| | PyTorch | Keras | JAX | TensorFlow |
| First Release | 2016 | 2015 | 2018 | 2015 |
| PyPI Package | torch | keras | jax | tensorflow |
| Default Image Format | NCHW (channels first) | (depends on backend) | NHWC (channels last) | NHWC (channels last) |
| Core Characteristics | ||||
| | PyTorch | Keras | JAX | TensorFlow |
| Default Execution Model | Eager | (depends on backend) | Eager | Eager (2.x) or graph (1.x) |
| Compilation | torch.compile | (depends on backend) | jax.jit | tf.function |
| Automatic Differentiation | | (via backend) | | |
| Vectorization | | (via backend) | | |
| Dynamic Shapes | | (via backend) | | |
| Distributed Training | ||||
| | PyTorch | Keras | JAX | TensorFlow |
| DDP (Data Parallel) | | (via backend) | | |
| FSDP (Fully Sharded) | | (via backend) | | |
| TP (Tensor Parallel) | | (via backend) | | |
| Layers | ||||
| | PyTorch | Keras | JAX | TensorFlow |
| Dense / Linear | | | | |
| Conv2D | | | | |
| Conv1D | | | | |
| Conv3D | | | | |
| ConvTranspose2D | | | | |
| MaxPool2D | | | | |
| AvgPool2D | | | | |
| BatchNorm | | | | |
| LayerNorm | | | | |
| Dropout | | | | |
| ReLU | | | | |
| Softmax | | | | |
| Embedding | | | | |
| LSTM | | | | |
| GRU | | | | |
| MultiHeadAttention | | | | |
| Flatten | | | | |
| Reshape | | | | |
| Concatenate | | | | |
| Add | | | | |
| GlobalAvgPool2D | | | | |
| GlobalMaxPool2D | | | | |
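
The Default Image Format row in the General section distinguishes NCHW (channels first) from NHWC (channels last). As a rough illustration of what moving data between the two layouts involves, the sketch below computes the axis permutation from one layout string to another and applies it to a shape tuple. It uses only the standard library; the helper names `layout_permutation` and `permute_shape` are hypothetical and belong to no framework. The permutation it produces is the kind of argument one would hand to a transpose routine such as `torch.Tensor.permute`, `jax.numpy.transpose`, or `tf.transpose`.

```python
from typing import Tuple


def layout_permutation(src: str, dst: str) -> Tuple[int, ...]:
    """Axes that reorder a `src`-layout array into `dst` layout.

    Follows the usual transpose convention: output axis i is taken
    from input axis perm[i]. (Hypothetical helper, for illustration.)
    """
    if len(set(src)) != len(src) or sorted(src) != sorted(dst):
        raise ValueError(f"incompatible layouts: {src!r} -> {dst!r}")
    return tuple(src.index(axis) for axis in dst)


def permute_shape(shape: Tuple[int, ...], perm: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of an array after transposing it with `perm`."""
    return tuple(shape[i] for i in perm)


# A batch of 8 RGB images, 32 pixels high and 64 wide, channels first.
nchw_shape = (8, 3, 32, 64)

to_nhwc = layout_permutation("NCHW", "NHWC")     # (0, 2, 3, 1)
to_nchw = layout_permutation("NHWC", "NCHW")     # (0, 3, 1, 2)

nhwc_shape = permute_shape(nchw_shape, to_nhwc)  # (8, 32, 64, 3)
round_trip = permute_shape(nhwc_shape, to_nchw)  # (8, 3, 32, 64)
```

The two permutations are inverses of each other, so converting a tensor to the other layout and back restores the original axis order.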