---
title: "Training Foundation Models on Supercomputers"
description: "A deep dive into the challenges and solutions for training large-scale AI models on supercomputing infrastructure."
categories: ["AI", "Machine Learning", "Supercomputing"]
location: "Georgia Institute of Technology"
location-logo: "assets/gatech.png"
location-url:
date: 2025-10-15
date-modified: last-modified
image: ./assets/thumbnail.png
lightbox: auto
editor:
render-on-save: true
twitter-card:
image: ./assets/thumbnail.png
site: "saforem2"
creator: "saforem2"
title: "Training Foundation Models on Supercomputers"
description: "Presented at the Georgia Intitute of Technology"
open-graph:
title: "Training Foundation Models on Supercomputers"
description: "Presented at the Georgia Intitute of Technology"
image: "./assets/thumbnail.png"
citation:
author: Sam Foreman
type: speech
url: https://samforeman.me/talks/2025/10/15/slides.html
format:
html:
image: "assets/thumbnail.png"
revealjs:
image: "assets/thumbnail.png"
shift-heading-level-by: -1
logo: "/assets/anl-black.svg"
slide-url: https://samforeman.me/talks/2025/10/15/slides.html
footer: "[samforeman.me/talks/2025/10/15](https://samforeman.me/talks/2025/10/15/slides)"
template-partials:
- "title-slide.html"
title-slide-attributes:
# data-background-opacity: "0.5"
# scale: 90%
# data-background-color: "#CCFFCC"
data-background-color: "#E8FAFF"
# data-background-color: "oklch(from #CCFFCC calc(l * 1.15) c h/0.05)"
#data-background-color: "oklch(from #CCFFCC l c * 1.15 h / 0.1)"
# data-background-color: "oklch(from #E599F7 l c * 1.15 h / 0.1)"
mermaid-format: "svg"
mermaid:
layout: dagre
useMaxWidth: true
gfm: default
---
## 🌐 Distributed Training {style="width:100%;"}
### 🚀 Scaling: Overview
- ✅ **Goal**:
  - Minimize: [Cost]{.highlight-red} (i.e., the amount of time spent training)
- Maximize: [Performance]{.highlight-blue}
::: {.callout-note icon=false title="📑 Note"}
See
[🤗 Performance and Scalability](https://huggingface.co/docs/transformers/v4.46.0/performance)
for more details
:::
::: notes
In this talk, we will explore the intricacies of training foundation models on
supercomputers. We will discuss the architecture of these models, the
computational requirements, and the strategies employed to optimize training
processes. Attendees will gain insights into the latest advancements in
hardware and software that facilitate efficient model training at scale.
:::
### 🐢 Training on a Single Device
::: {.content-visible unless-format="revealjs"}
::: {#fig-html-single-device}
```{mermaid}
flowchart LR
subgraph G0["`GPU0`"]
subgraph N0["`Network`"]
end
L0("`Loss`")
end
subgraph D["`Data`"]
x("`x0`")
x1("`x1`")
x2("`x2`")
end
x --> N0
N0 --> L0
L0 --> N0
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
class x,L0 red
class x1, green
class x2, blue
class x3, grey
class N0,G0,n0 block
class D eblock
```
**SLOW!** Model size limited by GPU memory
:::
:::
::: {.content-visible when-format="revealjs"}
::: {.r-stack}
::: {.fragment}
```{mermaid}
flowchart LR
subgraph G0["`GPU0`"]
subgraph N0["`Network`"]
end
L0("`Loss`")
end
subgraph D["`Data`"]
x("`x0`")
x1("`x1`")
x2("`x2`")
end
x --> N0
N0 --> L0
L0 --> N0
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
class D eblock
class x,L0 red
class x1, green
class x2, blue
class x3, grey
class N0,G0,n0 block
```
:::
::: {.fragment}
```{mermaid}
flowchart LR
subgraph G0["`GPU0`"]
subgraph N0["`Network`"]
end
L0("`Loss`")
end
subgraph D["`Data`"]
x("`x1`")
x1("`x2`")
x2("`x3`")
end
x --> N0
N0 --> L0
L0 --> N0
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef text fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
class D eblock
class x,L0 green
class x1, blue
class x2, yellow
class x3, grey
class N0,G0,n0 block
```
:::
::: {.fragment}
```{mermaid}
flowchart LR
subgraph G0["`GPU0`"]
subgraph N0["`Network`"]
end
L0("`Loss`")
end
subgraph D["`Data`"]
x("`x2`")
x1("`x3`")
x2("`x4`")
end
x --> N0
N0 --> L0
L0 --> N0
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef text fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
class D eblock
class x,L0 blue
class x1, yellow
class x2, purple
class x3, grey
class N0,G0,n0 block
```
:::
:::
::: {.fragment style="text-align: center; margin-left: auto; margin-right: auto;"}
[SLOW!]{.highlight-red}: Model size limited by GPU memory
:::
:::
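
For reference, a minimal single-device PyTorch training loop corresponding to the picture above; the toy model and random data are stand-ins for a real model and dataloader:

```python
# Minimal single-device training loop (toy model + random data as stand-ins).
# Parameters, activations, gradients, and optimizer state all live on ONE device.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(10):                           # x0, x1, x2, ... processed one batch at a time
    x = torch.randn(32, 1024, device=device)
    y = torch.randint(0, 10, (32,), device=device)
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```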
### 🕸️ Parallelism Strategies
::: {.flex-container style="justify-content: space-around;"}
::: {.column style="width: 45%"}
- **Data Parallelism**
- Split *data* across workers
- Easiest to implement
- *No changes to model*
:::
::: {.column style="width: 45%"}
- **Model Parallelism**
- Split _model_ across workers
- **Hybrid Parallelism**
- Combine data + model parallelism
- More complex to implement
- Requires changes to model
:::
:::
### 👬 Training on Multiple GPUs: Data Parallelism {style="width:100%"}
::: {#fig-ddp-training-mermaid}
```{mermaid}
flowchart LR
subgraph D["`Data`"]
direction TB
x2("`x2`")
x1("`x1`")
x("`x0`")
end
direction LR
subgraph G0["`GPU0`"]
direction LR
subgraph N0["`NN`"]
end
%%y0("`y₀`")
L0["`Loss`"]
end
subgraph G1["`GPU1`"]
direction LR
subgraph N1["`NN`"]
end
L1["`Loss`"]
end
subgraph G2["`GPU2`"]
direction LR
subgraph N2["`NN`"]
end
L2["`Loss`"]
end
x --> N0
x1 --> N1
x2 --> N2
N0 --> L0
N1 --> L1
N2 --> L2
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef text fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
class x,y0,L0 red
class x1,L1 green
class x2,L2 blue
class x3,ar grey
class N0,N1,N2,G0,G1,G2,GU block
class D eblock
class AR block
class bc text
```
Each GPU receives **unique** data at each step
:::
::: {.aside}
- See [🤗 Methods and tools for efficient training on a single GPU](https://huggingface.co/docs/transformers/v4.46.0/perf_train_gpu_one)
:::
### ▶️ Data Parallel: Forward Pass
::: {#fig-ddp-training-mermaid-allreduce}
```{mermaid}
flowchart LR
subgraph D["`Data`"]
direction TB
x("`x0`")
x1("`x1`")
x2("`x2`")
end
direction LR
subgraph G0["`GPU0`"]
direction LR
subgraph N0["`NN`"]
end
L0["`Loss`"]
end
subgraph G1["`GPU1`"]
direction LR
subgraph N1["`NN`"]
end
L1["`Loss`"]
end
subgraph G2["`GPU2`"]
direction LR
subgraph N2["`NN`"]
end
L2["`Loss`"]
end
ar("`Avg. Grads<br>(∑ₙgₙ)/N`")
x --> G0
x1 --> G1
x2 --> G2
N0 --> L0
N1 --> L1
N2 --> L2
L0 -.-> ar
L1 -.-> ar
L2 -.-> ar
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef text fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
class x,y0,L0 red
class x1,L1 green
class x2,L2 blue
class x3,ar grey
class N0,N1,N2,G0,G1,G2,GU block
class D eblock
class AR block
class bc text
```
Average gradients across all GPUs
:::
### ◀️ Data Parallel: Backward Pass
::: {#fig-ddp-backward-mermaid}
```{mermaid}
flowchart RL
subgraph D["`Data`"]
direction TB
x("`x0`")
x1("`x1`")
x2("`x1`")
end
subgraph G0["`GPU0`"]
direction RL
subgraph N0["`NN`"]
end
L0["`Loss`"]
end
subgraph G1["`GPU1`"]
direction RL
subgraph N1["`NN`"]
end
L1["`Loss`"]
end
subgraph G2["`GPU2`"]
direction RL
subgraph N2["`NN`"]
end
L2["`Loss`"]
end
subgraph BC["`Send Updates`"]
direction TB
end
BC -.-> G0
BC -.-> G1
BC -.-> G2
L0 ~~~ N0
L1 ~~~ N1
L2 ~~~ N2
G0 ~~~ x
G1 ~~~ x1
G2 ~~~ x2
classDef grey fill:#cccccc,stroke:#333,stroke-width:1px,color:#000
classDef eblock fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef text fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
class x,y0,L0 red
class x1,L1 green
class x2,L2 blue
class x3,ar grey
class N0,N1,N2,G0,G1,G2,GU block
class BC block
class bc text
class D eblock
```
Send global updates back to each GPU.
See: [PyTorch / Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)
:::
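
In code, the pattern above maps almost directly onto PyTorch DDP; a minimal sketch (toy model, random data standing in for a real `DistributedSampler`), launched with e.g. `torchrun --nproc_per_node=4 train.py`:

```python
# Sketch: PyTorch DistributedDataParallel (DDP).
# Each rank sees different data; gradients are averaged across ranks during backward().
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")          # "ccl" on Intel XPUs, "gloo" on CPU
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

model = nn.Linear(1024, 10).cuda()
model = DDP(model, device_ids=[local_rank])      # registers gradient all-reduce hooks
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(10):
    x = torch.randn(32, 1024).cuda()             # in practice: a unique shard per rank
    y = torch.randint(0, 10, (32,)).cuda()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()                              # gradients averaged across all ranks here
    optimizer.step()
    optimizer.zero_grad()

dist.destroy_process_group()
```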
### 🔄 Collective Communication
- **Broadcast**: Send data from one node to all other nodes
- **Reduce**: Aggregate data from all nodes to one node
- **AllReduce**: Aggregate data from all nodes to all nodes
- **Gather**: Collect data from all nodes to one node
- **AllGather**: Collect data from all nodes to all nodes
- **Scatter**: Distribute data from one node to all other nodes
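
A rough sketch of a few of these collectives with `torch.distributed` (assumes the process group is already initialized, e.g. via `torchrun`); `dist.reduce`, `dist.gather`, and `dist.scatter` follow the same pattern, and the Reduce case is shown on the next slide:

```python
# Collective communication with torch.distributed (process group assumed initialized).
import torch
import torch.distributed as dist

rank = dist.get_rank()
world = dist.get_world_size()

# Broadcast: rank 0's tensor is copied to every rank
t = torch.tensor([float(rank)])
dist.broadcast(t, src=0)

# AllReduce: element-wise sum (or avg, max, ...) made available on every rank
t = torch.tensor([float(rank)])
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# AllGather: every rank ends up with the full list of per-rank tensors
t = torch.tensor([float(rank)])
gathered = [torch.zeros(1) for _ in range(world)]
dist.all_gather(gathered, t)
```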
### Reduce
- Perform a _reduction_ on data across ranks and send the result to an individual rank
::: {#fig-reduce-mermaid}
```{mermaid}
flowchart TD
subgraph R0["`0`"]
x0("`x0`")
end
subgraph R1["`1`"]
x1("`x1`")
end
subgraph R2["`2`"]
x2("`x2`")
end
subgraph R3["`3`"]
x3("`x3`")
end
subgraph AR["`Reduce`"]
xp["`z=reduce(x, 2, SUM)`"]
end
subgraph AR3["`3`"]
end
subgraph AR2["`2`"]
xp2("`z`")
end
subgraph AR1["`1`"]
end
subgraph AR0["`0`"]
end
x0 --> AR
x1 --> AR
x2 --> AR
x3 --> AR
AR --> AR3
AR --> xp2
AR --> AR1
AR --> AR0
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef pink fill:#E599F7,stroke:#333,stroke-width:1px,color:#000
class R0,R1,R2,R3,AR,AR0,AR1,AR2,AR3, block
class xp,xp2 purple
class x0, red
class x1, green
class x2, blue
class x3, yellow
```
Reduce operation: one rank receives the reduction of input values across ranks
:::
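
The figure's `z = reduce(x, 2, SUM)`, roughly, in `torch.distributed` terms (assumes an initialized process group with at least 3 ranks):

```python
# Reduce: every rank contributes x = [rank]; only rank 2 receives z = sum(x over ranks).
import torch
import torch.distributed as dist

x = torch.tensor([float(dist.get_rank())])
dist.reduce(x, dst=2, op=dist.ReduceOp.SUM)      # result lands in-place on rank 2 only

if dist.get_rank() == 2:
    print(f"z = {x.item()}")                     # e.g. 0 + 1 + 2 + 3 = 6.0 with 4 ranks
```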
### 🐣 Getting Started: In Practice
::: {.flex-container justify-content="space-around" style="gap: 5pt;"}
::: {.column width="45%"}
- 📦 **Distributed Training Frameworks**:
- 🍋 [saforem2 / `ezpz`](https://github.com/saforem2/ezpz)
- 🤖 [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
- 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index)
- 🔥 PyTorch
- [DDP](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html) / [FSDP](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- 🚀 [DeepSpeed](https://www.deepspeed.ai/)
- [ZeRO Offloading](https://www.deepspeed.ai/tutorials/zero/)
- [Megatron-DeepSpeed](https://github.com/argonne-lcf/Megatron-DeepSpeed)
:::
::: {.column width="45%"}
- 🧠 **Memory Management**:
- FSDP vs. ZeRO
- Activation Checkpointing
- Mixed Precision Training
- Gradient Accumulation
- Offloading to CPU/NVMe
::: {.callout-important icon=false title="🔄 Keeping things in Sync"}
**Computation stalls during communication!**
Keeping the communication-to-computation ratio small is important for effective scaling.
:::
:::
:::
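
As a concrete illustration of two of the memory levers above (gradient accumulation and mixed precision), a minimal sketch in plain PyTorch; the toy model and batch sizes are purely illustrative:

```python
# Sketch: gradient accumulation + bf16 mixed precision in plain PyTorch.
# Smaller micro-batches and half-precision activations both reduce memory pressure.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
accum_steps = 8                                    # 8 micro-batches per optimizer step

for step in range(64):
    x = torch.randn(4, 1024, device=device)        # small micro-batch
    y = torch.randint(0, 10, (4,), device=device)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        loss = nn.functional.cross_entropy(model(x), y)
    (loss / accum_steps).backward()                 # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```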
### 📝 Plan of Attack
::: {#fig-scaling-strategy-mermaid}
```{mermaid}
flowchart TB
A{"Model Perfect?"}
A -- no --> M{"Available Memory?"}
A -- yes --> AD["Done"]
M -- yes --> MY["Make Model Larger"]
M -- no --> ZMP["<b>Free Up Memory</b>"]
MY --> A
ZMP --> MY
A:::block
M:::block
AD:::block
MY:::block
ZMP:::sblock
classDef text fill:#CCCCCC02,stroke:#838383,stroke-width:0px,color:#838383
classDef sblock fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383,white-space:collapse
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
```
General strategy for scaling model training
:::
### 🚀 Going Beyond Data Parallelism
- ✅ Useful when model fits on single GPU:
- ultimately **limited by GPU memory**
- model performance limited by size
- ⚠️ When model does not fit on a single GPU:
- Offloading (can only get you so far...):
- {{< iconify logos microsoft-icon >}} [DeepSpeed + `ZeRO`](https://www.deepspeed.ai/tutorials/zero/)
- 🔥 [PyTorch + `FSDP`](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/)
- Otherwise, resort to [model parallelism strategies](#additional-parallelism-strategies)
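
For the FSDP route, wrapping a model is (in the simplest case) a one-line change; a hedged sketch, omitting the auto-wrap policy and mixed-precision config a real run would use:

```python
# Sketch: PyTorch FSDP shards parameters, gradients, and optimizer states across ranks.
# Launch with torchrun (one process per GPU).
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
model = FSDP(model)                           # parameters now sharded across ranks
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

x = torch.randn(8, 1024).cuda()
loss = model(x).pow(2).mean()
loss.backward()                               # gradients reduce-scattered back to shards
optimizer.step()
```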
### Going beyond Data Parallelism: {{< iconify logos microsoft-icon >}} DeepSpeed + `ZeRO`
- Depending on the `ZeRO` stage (1, 2, 3), we can offload:
1. **Stage 1**: optimizer states $\left(P_{\mathrm{os}}\right)$
2. **Stage 2**: gradients + opt. states $\left(P_{\mathrm{os}+\mathrm{g}}\right)$
3. **Stage 3**: model params + grads + opt. states $\left(P_{\mathrm{os}+\mathrm{g}+\mathrm{p}}\right)$
![[DeepSpeed](https://www.deepspeed.ai/) + [`ZeRO`](https://www.deepspeed.ai/tutorials/zero-offload/)](./assets/zero.png){#fig-zero width=70%}
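
In DeepSpeed, the stage is selected in the config passed to `deepspeed.initialize`; a sketch with illustrative values (see the DeepSpeed docs for the full schema):

```python
# Sketch: selecting a ZeRO stage (and optional CPU offload) via the DeepSpeed config.
import deepspeed
import torch.nn as nn

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 3e-4}},
    "zero_optimization": {
        "stage": 3,                              # 1: os, 2: os + g, 3: os + g + p
        "offload_optimizer": {"device": "cpu"},  # push optimizer states to CPU memory
    },
}

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)
```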
### 🕸️ Additional Parallelism Strategies
- **Tensor (/ Model) Parallelism** (`TP`):
- 🤗 [Tensor Parallelism](https://huggingface.co/docs/text-generation-inference/en/conceptual/tensor_parallelism)
- 🔥 [Large Scale Transformer model training with Tensor Parallel (TP)](https://pytorch.org/tutorials/intermediate/TP_tutorial.html)
- **Pipeline Parallelism** (`PP`):
- 🔥 [PyTorch](https://pytorch.org/docs/main/distributed.pipelining.html), {{<iconify logos microsoft-icon >}} [DeepSpeed](https://deepspeed.readthedocs.io/en/latest/pipeline.html)
- **Sequence Parallelism** (`SP`):
- {{< iconify logos microsoft-icon >}} [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md)
- [Megatron / Context Parallelism](https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/context_parallel.html)
- [Unified Sequence Parallel (USP)](https://arxiv.org/abs/2405.07719v3)
- {{< fa brands github >}} [feifeibear/`long-context-attention`](https://github.com/feifeibear/long-context-attention)
- [x] {{< fa brands github >}} [argonne-lcf/`Megatron-DeepSpeed`](https://github.com/argonne-lcf/Megatron-DeepSpeed)
- Supports 4D Parallelism (`DP` + `TP` + `PP` + `SP`)
### Pipeline Parallelism (PP)
::: {.flex-container style="place-content: end space-evenly;"}
::: {.column style="width:60%;"}
- Model is split up **vertically** (layer-level) across multiple GPUs
- Each GPU:
- has a portion of the full model
- processes _in parallel_ different stages of the pipeline
(on a small chunk of the batch)
- See:
- 🔥 [PyTorch / Pipeline Parallelism](https://pytorch.org/docs/main/distributed.pipelining.html)
- {{<iconify logos microsoft-icon >}} [DeepSpeed / Pipeline Parallelism](https://deepspeed.readthedocs.io/en/latest/pipeline.html)
:::
::: {.column style="width:40%;"}
::: {#fig-pipeline-parallelism}
```{mermaid}
flowchart TB
subgraph G0["`GPU 0`"]
direction LR
a0("`Layer 0`")
b0("`Layer 1`")
end
subgraph G1["`GPU 1`"]
direction LR
a1("`Layer 2`")
b1("`Layer 3`")
end
a0 -.-> b0
b0 --> a1
a1 -.-> b1
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
class G0,G1, block
class a0, red
class b0, green
class a1, blue
class b1, yellow
```
Pipeline Parallelism
:::
:::
:::
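
To make the layer-level split concrete, a deliberately naive two-stage version in plain PyTorch (assumes two visible GPUs). Without a real schedule, each stage idles while the other works; that bubble is exactly what pipeline frameworks (PyTorch pipelining, DeepSpeed's `PipelineModule`) fill by interleaving micro-batches:

```python
# Naive two-stage split: Layers 0-1 on cuda:0, Layers 2-3 on cuda:1 (assumes 2 GPUs).
import torch
import torch.nn as nn

stage0 = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).to("cuda:0")  # Layers 0, 1
stage1 = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 10)).to("cuda:1")    # Layers 2, 3

batch = torch.randn(32, 1024, device="cuda:0")
outputs = []
for mb in batch.chunk(4):                     # micro-batches: what a real schedule interleaves
    h = stage0(mb)                            # forward on GPU 0
    h = h.to("cuda:1")                        # activation handoff between stages
    outputs.append(stage1(h))                 # forward on GPU 1
y = torch.cat(outputs)
```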
### Tensor Parallel (TP)
::: {layout="[50,50]"}
::: {.column}
- Each tensor is split up into multiple chunks
- Each shard of the tensor resides on its designated GPU
- During processing, each shard is processed separately (and in parallel) on a
  different GPU
  - results are synced at the end of the step
- See: [🤗 Model Parallelism](https://huggingface.co/docs/transformers/v4.15.0/parallelism) for additional details
:::
::: {.column}
::: {#fig-model-parallel-1}
```{mermaid}
flowchart LR
subgraph G0["`GPU0`"]
direction TB
a0("`Layer 0`")
b0("`Layer 1`")
c0("`Layer 2`")
d0("`Layer 3`")
end
subgraph G1["`GPU1`"]
direction TB
a1("`Layer 0`")
b1("`Layer 1`")
c1("`Layer 2`")
d1("`Layer 3`")
end
a0 <-.-> a1
b0 <-.-> b1
c0 <-.-> c1
d0 <-.-> d1
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
class G0,G1, block
class a0,a1 red
class b0,b1 green
class c0,c1 blue
class d0,d1 yellow
```
Tensor Parallel Training
:::
:::
:::
### Tensor Parallel (TP)
::: {layout="[50,50]"}
::: {.column}
- Suitable when the model is too large to fit onto a single device (CPU / GPU)
- Typically **more complicated** to implement than data parallel training
- This is what one may call _horizontal parallelism_
- Communication is required whenever data flows between two subsets
- {{< fa brands github >}} [`argonne-lcf/Megatron-DeepSpeed`](https://github.com/argonne-lcf/Megatron-DeepSpeed)
- 🤗 [`huggingface/nanotron`](https://github.com/huggingface/nanotron)
:::
::: {.column}
::: {#fig-model-parallel-2}
```{mermaid}
flowchart LR
subgraph G0["`GPU0`"]
direction TB
a0("`Layer 0`")
b0("`Layer 1`")
c0("`Layer 2`")
d0("`Layer 3`")
end
subgraph G1["`GPU1`"]
direction TB
a1("`Layer 0`")
b1("`Layer 1`")
c1("`Layer 2`")
d1("`Layer 3`")
end
a0 <-.-> a1
b0 <-.-> b1
c0 <-.-> c1
d0 <-.-> d1
classDef red fill:#ff8181,stroke:#333,stroke-width:1px,color:#000
classDef orange fill:#FFC47F,stroke:#333,stroke-width:1px,color:#000
classDef yellow fill:#FFFF7F,stroke:#333,stroke-width:1px,color:#000
classDef green fill:#98E6A5,stroke:#333,stroke-width:1px,color:#000
classDef blue fill:#7DCAFF,stroke:#333,stroke-width:1px,color:#000
classDef purple fill:#FFCBE6,stroke:#333,stroke-width:1px,color:#000
classDef block fill:#CCCCCC02,stroke:#838383,stroke-width:1px,color:#838383
class G0,G1, block
class a0,a1 red
class b0,b1 green
class c0,c1 blue
class d0,d1 yellow
```
Tensor Parallel Training
:::
:::
:::
::: {.notes}
- Split up network over multiple workers
- Each receives disjoint subset
- All communication associated with the subsets is distributed
- Communication occurs whenever data flows between two subsets
- Typically **more complicated** to implement than data parallel training
- Suitable when the model is too large to fit onto a single device (CPU / GPU)
:::
### Tensor (/ Model) Parallel Training: Example {.smaller style="width:100%;"}
Want to compute: $y = \sum_{i} x_{i} W_{i} = x_0 * W_0 + x_1 * W_1 + x_2 * W_2$
where each GPU has only its portion of the full weights as shown below
1. Compute: $y_{0} = x_{0} * W_{0}\rightarrow$ `GPU1`
2. Compute: $y_{1} = y_{0} + x_{1} * W_{1}\rightarrow$ `GPU2`
3. Compute: $y = y_{1} + x_{2} * W_{2} = \sum_{i} x_{i} W_{i}$ ✅
::: {#fig-tensor-parallel-example style="width:75%; margin-left: auto; margin-right: auto; text-align:center;"}
```{mermaid}
flowchart LR
subgraph X0["`GPU0`"]
direction LR
a("`W0`")
end
subgraph X1["`GPU1`"]
direction LR
b("`W1`")
end
subgraph X2["`GPU2`"]
direction LR
c("`W2`")
end
t0("`x0`")-->X0
X0 -->|"`x0 W0`"|X1
X1 -->|"`x0 W0 <br>+ x1 W1`"|X2
t1("`x1`") --> X1
t2("`x1`") --> X2
```
:::
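
A single-process simulation of the three-step relay above, just to check the arithmetic; each list entry plays the role of one GPU's local shard:

```python
# Simulate the relay: each "GPU" holds only its own W_i; the partial sum y is passed along.
import torch

d = 8
xs = [torch.randn(1, d) for _ in range(3)]    # x0, x1, x2
Ws = [torch.randn(d, d) for _ in range(3)]    # W0 on GPU0, W1 on GPU1, W2 on GPU2

y = torch.zeros(1, d)
for x_i, W_i in zip(xs, Ws):                  # in practice: point-to-point send/recv between ranks
    y = y + x_i @ W_i                         # each rank adds its local contribution x_i @ W_i

y_full = sum(x_i @ W_i for x_i, W_i in zip(xs, Ws))   # unsharded reference
assert torch.allclose(y, y_full)
```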
::: {.content-visible when-format="revealjs"}
## {.smaller background-color="#040406"}
:::
::: {.flex-container style="align-items: center; gap: 5pt;"}
::: {.column style="width:55%; text-align: center;"}
[🔭 AI-for-Science]{style="font-weight: 600; font-size: 1.5em;"}
{{< iconify fa twitter >}} [source](https://x.com/tenderizzation/status/1944591320796090606)
([\@tenderizzation](https://twitter.com/tenderizzation))
<br>
ChatGPT: [explain this image](https://chatgpt.com/share/688ab77e-9ca0-800a-8ab0-a293e06b3cce)
:::
::: {.column}

:::
:::
### 🏗️ Aurora {style="width:100%"}
::: {.flex-container style="align-items: center; gap:10pt;"}
::: {.column #tbl-aurora}
| Property | Value |
| -----------: | :------ |
| Racks | 166 |
| Nodes | 10,624 |
| XPUs[^tiles] | 127,488 |
| CPUs | 21,248 |
| NICs | 84,992 |
| HBM | 8 PB |
| DDR5 | 10 PB |
: Aurora[^aurora-ai] Specs {.responsive .striped .hover}
:::
::: {#fig-aurora .r-stretch}

Aurora: [Fact Sheet](https://www.alcf.anl.gov/sites/default/files/2024-07/Aurora_FactSheet_2024.pdf).
:::
:::
[^tiles]: Each node has 6 Intel Data Center GPU Max 1550
    (code-named "Ponte Vecchio") GPUs, each consisting of 2 tiles (XPUs).
[^aurora-ai]: 🏆 [Aurora Supercomputer Ranks Fastest for AI](https://www.intel.com/content/www/us/en/newsroom/news/intel-powered-aurora-supercomputer-breaks-exascale-barrier.html)
## 🌌 AuroraGPT (2024--)
::: {.flex-container style="justify-content: space-around;"}
::: {.column style="width: 50%"}
::: {.blue-card}
[**AuroraGPT**](https://auroragpt.anl.gov): *General purpose scientific LLM*
Broadly trained on a general corpora plus scientific {papers, texts, data}
:::
- **Explore pathways** towards a "Scientific Assistant" model
- **Build with international partners** (RIKEN, BSC, others)
- **Multilingual**: English, Japanese (日本語), French, German, Spanish
- **Multimodal**: images, tables, equations, proofs, time series, graphs, fields, sequences, etc.
:::
::: {.column style="text-align: center;"}
::: {#fig-awesome-llm}

Image from {{< iconify fa github >}}
[Hannibal046 / `Awesome-LLM`](https://github.com/Hannibal046/Awesome-LLM)
:::
:::
:::
### 🧪 AuroraGPT: Open Science Foundation Model
::: {#fig-aurora-gpt .r-stretch style="vertical-align:center;"}

High-level overview of AuroraGPT project
:::
### 🧰 AuroraGPT: Toolbox
- **Datasets and data pipelines**
(how do we deal with scientific data?)
- **Software infrastructure and workflows**
(scalable, robust, extensible)
- **Evaluation of state-of-the-art LLM Models**
(how do they perform on scientific tasks?)
::: {.flex-container style="gap: 5pt;"}
::: {.callout-note icon=false title="🚂 Training"}
{{< fa brands github >}} [argonne-lcf/Megatron-DeepSpeed](https://github.com/argonne-lcf/Megatron-DeepSpeed)
[Large Model Training: Any Scale, Any Accelerator]{.dim-text}
:::
::: {.callout-important icon=false title="🏃♂️ Running"}
{{< fa brands github >}} [argonne-lcf/inference-endpoints](https://github.com/argonne-lcf/inference-endpoints)
[Inference endpoints for LLMs, hosted @ ALCF]{.dim-text}
:::
:::
### 🏋️ Challenges: In Practice
This is _incredibly_ difficult in practice, due in part to:
- Brand new {hardware, architecture, software}
- Lack of native support in existing frameworks (though getting better!)
- General system stability
\+10k Nodes $\left(\times \frac{12\,\,\mathrm{XPU}}{1\,\,\mathrm{Node}}\right)\Rightarrow$ \+**100k** XPUs
- network performance
  - file system stability (impacted by _other users_!)
- _many_ unexpected difficulties occur at increasingly large scales
- Combinatorial explosion of possible configurations and experiments
- {hyperparameters, architectures, tokenizers, learning rates, ...}
### 💾 AuroraGPT: Training
- To train a fixed model on trillions of tokens requires:
1. **Aggregating** data from multiple different _corpora_
(e.g. ArXiv, Reddit, StackExchange, GitHub, Wikipedia, etc.)
1. **Sampling** _each training batch_ according to a fixed distribution
across corpora
1. **Building** indices that map batches of tokens into these files
(indexing)
::: {.red-card}
The original implementation was _slow_:
- Designed to run _serially_ on a **single device**
- **Major bottleneck** when debugging data pipeline at scale
:::
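
A toy sketch of the **Sampling** step above, drawing each batch's documents according to a fixed distribution across corpora; the corpus names and weights here are made up for illustration:

```python
# Toy sketch: pick which corpus each document in a batch comes from,
# according to a fixed distribution (names and weights are illustrative).
import numpy as np

rng = np.random.default_rng(0)
corpora = {"arxiv": 0.20, "github": 0.15, "wikipedia": 0.10, "web": 0.55}
names = list(corpora)
weights = np.array(list(corpora.values()))

def sample_batch_sources(batch_size: int) -> list[str]:
    """Pick the source corpus for each sample in one training batch."""
    idx = rng.choice(len(names), size=batch_size, p=weights)
    return [names[i] for i in idx]

print(sample_batch_sources(8))                # e.g. ['web', 'arxiv', 'web', ...]
```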
### 🍹 AuroraGPT: Blending Data, Efficiently
::: {.flex-container style="padding: 10pt; justify-content: space-around; align-items: flex-start;"}
::: {.column style="width:25%;"}
- 🐢 Original implementation:
- **Slow** (serial, single device)
- [\~ 1 hr]{.dim-text}/2T tokens
- 🐇 New implementation:
- **Fast!** (distributed, asynchronous)
  - [\~ **2 min**]{style="color:#2296F3;"}/2T tokens
    (**30x** faster!)
:::
::: {.column}
{#fig-data-processing .r-stretch}
:::
:::
### 📉 Loss Curve: Training AuroraGPT-7B on 2T Tokens
::: {.content-visible when-format="html" unless-format="revealjs"}
::: {#fig-loss-curve}
{.width="90%" style="margin-left:auto;margin-right:auto;"}
Loss curve during training on 2T tokens.
:::
:::
::: {.content-visible when-format="revealjs"}
::: {#fig-loss-curve}
{width="90%" style="margin-left:auto;margin-right:auto;"}
Loss curve during training on 2T tokens.
:::
:::
### ✨ Features
- 🕸️ **Parallelism**:
- {data, tensor, pipeline, sequence, ...}
- ♻️ **Checkpoint Converters**:
- Megatron ⇄ 🤗 HF ⇄ ZeRO ⇄ Universal
- 🔀 **DeepSpeed Integration**:
- ZeRO Offloading
- Activation checkpointing
- AutoTP (*WIP*)
- ability to leverage features from DeepSpeed community
### ✨ Features (even more!)
- 🧗 **Optimizers**[^marieme]:
- Support for *many* different optimizers:
- Distributed Shampoo, Muon, Adopt, Sophia, Lamb, GaLORE, ScheduleFree, ...
- See
[full list](https://github.com/argonne-lcf/Megatron-DeepSpeed/blob/e3b0398d2f2d3f8ec543e99373ca14bd18a1e4f8/megatron/arguments.py#L1477-L1502)
- Large batch training
- 📊 **Experiment Tracking**:
- Automatic experiment and metric tracking with Weights \& Biases
[^marieme]: Implemented by Marieme Ngom
## 🧬 MProt-DPO {style="width:100%;"}
- [Finalist]{.highlight-green}: SC'24 [ACM Gordon Bell Prize](https://sc24.supercomputing.org/2024/10/presenting-the-finalists-for-the-2024-gordon-bell-prize/)
- [MProt-DPO: Breaking the ExaFLOPS Barrier for Multimodal Protein Design Workflows with Direct Preference Optimization](https://www.researchgate.net/profile/Carla-Mann-3/publication/387390653_MProt-DPO_Breaking_the_ExaFLOPS_Barrier_for_Multimodal_Protein_Design_Workflows_with_Direct_Preference_Optimization/links/67a0f736645ef274a46243f1/MProt-DPO-Breaking-the-ExaFLOPS-Barrier-for-Multimodal-Protein-Design-Workflows-with-Direct-Preference-Optimization.pdf)
(@mprot-dpo2024)
- One of the first protein design toolkits that integrates:
- Text, (protein/gene) sequence, structure/conformational sampling modalities
to build aligned representations for protein sequence-function mapping
### 🧬 Scaling Results (2024) {.smaller}
::: {.columns}
::: {.column style="width:70%;"}
::: {.flex-container style="align-items: center; text-align: center; margin-left: auto; margin-right: auto;"}
::: {#fig-mprot-3p5B-scaling0}
{width=100% style="margin:0; padding-unset;"}
Scaling results for `3.5B` model across ~38,400 GPUs
:::
:::
:::
::: {.column style="width:30%;"}
- ~ [4 EFLOPS]{.highlight-blue} @ Aurora
- 38,400 XPUs
= 3200 \[node\] x 12 \[XPU / node\]
- 🎖️ [Gordon Bell Finalist](https://sc24.supercomputing.org/2024/10/presenting-the-finalists-for-the-2024-gordon-bell-prize/):
- [MProt-DPO: Breaking the ExaFLOPS Barrier for Multimodal Protein Design Workflows](https://dl.acm.org/doi/10.1109/SC41406.2024.00013) (@mprot-dpo2024)
:::
:::
::: notes
This novel work presents a scalable, multimodal workflow for protein design
that trains an LLM to generate protein sequences, computationally evaluates the
generated sequences, and then exploits them to fine-tune the model.
Direct Preference Optimization steers the LLM toward the generation of
preferred sequences, and enhanced workflow technology enables its efficient
execution. A 3.5B and a 7B model demonstrate scalability and exceptional mixed
precision performance of the full workflow on ALPS, Aurora, Frontier, Leonardo
and PDX.
:::
### 🧬 MProt-DPO: Scaling Results {.smaller}
::: {.flex-container}
::: {.column #fig-mprot-3p5B-scaling}

`3.5B` model
:::
::: {.column #fig-mprot-7B-scaling}

`7B` model
:::
:::
### 🚂 Loooooooooong Sequence Lengths {.smaller style="width: 100%;"}
::: {.flex-container style="align-items: center; justify-content: center;"}
{style="height:50pt; margin: unset; padding: 0"}
[{{< iconify ic baseline-plus >}}]{.dim-text style="font-size: 2.0em;"}
{style="height:50pt; margin: unset; padding: 0;"}
:::
- Working with
[{{< fa brands microsoft >}} Microsoft/DeepSpeed](https://github.com/microsoft/DeepSpeed)
team to enable longer sequence lengths (context windows) for LLMs
- See my [blog post](https://samforeman.me/posts/auroragpt/long-sequences/) for additional details
::: {#fig-long-seq}
::: {.flex-container}


:::
Maximum (achievable) `SEQ_LEN` for both `25B` and `33B` models (See: @song2023ds4sci)
:::
::: aside
[{{< fa brands github >}} `scaling4science`](https://github.com/saforem2/scaling4science)
[{{< fa brands github >}} `Megatron-DS-Benchmarking`](https://github.com/saforem2/Megatron-DS-Benchmarking)
:::
## 🌎 AERIS (2025)
::: {.content-visible unless-format="revealjs"}
::: {.flex-container}
::: {.flex-child style="width:50%;"}
![](./assets/team.png){#fig-arxiv}
:::
::: {.flex-child style="width:43.6%;"}

:::
:::
:::
::: {.content-visible when-format="revealjs"}
::: {.flex-container}
::: {.column style="width:50%;"}
![](./assets/team.png){#fig-arxiv}
:::
::: {.column style="width:43.6%;"}
![Pixel-level Swin diffusion transformer in sizes from \[1--80\]B](./assets/aeris.svg){#fig-aeris-cover}
:::
:::
:::
::: notes
> We demonstrate a significant advancement in AI weather
> and climate modeling with AERIS by efficient scaling of
> window-based transformer models. We have performed global
> medium-range forecasts with performance competitive with
> GenCast and surpassing the IFS ENS model, with longer, 90-
> day rollouts showing our ability to learn atmospheric dynamics
> on seasonal scales without collapsing, becoming the first
> diffusion-based model that can work across forecast scales
> from 6 hours all the way to 3 months with remarkably accurate
> out of distribution predictions of extreme events.
:::
### 👀 High-Level Overview of AERIS {.smaller}
::: {.flex-container}
::: {#fig-rollout}

Rollout of AERIS model, specific humidity at 700m.
:::
::: {#tbl-aeris}
| Property | Description |
| -----------------: | :---------------- |
| Domain | Global |
| Resolution | 0.25° \& 1.4° |
| Training Data | ERA5 (1979--2018) |
| Model Architecture | Swin Transformer |
| Speedup[^pde] | O(10k--100k) |
: Overview of AERIS model and training setup {.responsive .striped .hover}
:::
:::
[^pde]: Relative to PDE-based models, e.g.: [GFS](https://www.ncdc.noaa.gov/data-access/model-data/model-datasets/global-forcast-system-gfs)
### ➕ Contributions
::: {.flex-container}
::: {.callout-caution icon=false title="☔ AERIS"}
[_First billion-parameter diffusion model for weather \+ climate_]{style="color:var(--callout-color-caution)!important;"}
- Operates at the pixel level (1 × 1 patch size), guided by physical priors
- Medium-range forecast skill:
- **Surpasses IFS ENS, competitive with GenCast[^gen-cast]**
- Uniquely stable on seasonal scales to 90 days
:::
::: {.callout-note icon=false title="🌀 SWiPe"}
[_A novel 3D (sequence-window-pipeline) parallelism strategy for training transformers across high-resolution inputs_]{style="color:var(--callout-color-note)!important;"}
- Enables scalable small-batch training on large supercomputers[^aurora-scale]
- **10.21 ExaFLOPS**
    - @ 120,960 Intel XPUs (Aurora)
:::
:::
[^gen-cast]: [GenCast: A Generative Model for Medium-Range Global Weather Forecasting](https://arxiv.org/html/2312.15796v1) (@price2024gencast)
[^aurora-scale]: Demonstrated on up to 120,960 GPUs on Aurora and 8,064 GPUs on LUMI.
### ⚠️ Issues with the Deterministic Approach
::: {.flex-container}
::: {.flex-child}
- [{{< iconify material-symbols close>}}]{.red-text} [**Transformers**]{.highlight-red}:
- *Deterministic*
- Single input → single forecast
:::
::: {.flex-child}
- [{{<iconify material-symbols check>}}]{.green-text} [**Diffusion**]{.highlight-green}:
- *Probabilistic*
- Single input → _**ensemble of forecasts**_
- Captures uncertainty and variability in weather predictions
- Enables ensemble forecasting for better risk assessment
:::
:::
### 🎲 Transitioning to a Probabilistic Model
::: {#fig-forward-pass}

Reverse diffusion with the [input]{style="color:#228be6"} condition, individual
sampling steps $t_{0} \rightarrow t_{64}$, the next time step
[estimate]{style="color:#40c057"} and the [target]{style="color:#fa5252"}
output.
:::
::: {.flex-container}

{width="89.6%"}
:::
### 🌀 Sequence-Window-Pipeline Parallelism `SWiPe` {.smaller}
::: {.content-visible unless-format="revealjs"}
::: {.flex-container}
::: {.column style="width:33%;"}
- `SWiPe` is a **novel parallelism strategy** for Swin-based Transformers
- Hybrid 3D Parallelism strategy, combining:
- Sequence parallelism (`SP`)
- Window parallelism (`WP`)
- Pipeline parallelism (`PP`)
:::
::: {#fig-swipe-layer style="width:66%;"}

:::
:::
::: {#fig-comms style="width:80%; text-align: center; margin-left: auto; margin-right: auto; "}

`SWiPe` Communication Patterns
:::
:::
::: {.content-visible when-format="revealjs"}
::: {.flex-container}
::: {.column style="width:33%;"}
- `SWiPe` is a **novel parallelism strategy** for Swin-based Transformers
- Hybrid 3D Parallelism strategy, combining:
- Sequence parallelism (`SP`)
- Window parallelism (`WP`)
- Pipeline parallelism (`PP`)
:::
::: {#fig-swipe-layer style="width:66%;"}

:::
:::
::: {#fig-comms style="width:60%; text-align: center; margin-left: auto; margin-right: auto;"}

`SWiPe` Communication Patterns
:::
:::
### 🚀 AERIS: Scaling Results
::: {.flex-container}
::: {.column #fig-aeris-scaling style="width:70%;"}

AERIS: Scaling Results
:::
::: {.column style="width:30%;"}
- [**10 EFLOPs**]{.highlight-blue} (sustained) @ **120,960 GPUs**
- See (@stock2025aeris) for additional details
- [arXiv:2509.13523](https://arxiv.org/abs/2509.13523)
:::
:::
### 🌪️ Hurricane Laura
::: {#fig-hurricane-laura}

Hurricane Laura tracks (top) and intensity (bottom). Initialized 7(a), 5(b) and
3(c) days prior to 2020-08-28T00z.
:::
## 📓 References
::: {#refs}
:::
## ❤️ Acknowledgements
> This research used resources of the Argonne Leadership Computing
> Facility, which is a DOE Office of Science User Facility supported
> under Contract DE-AC02-06CH11357.