Training Foundation Models on Supercomputers

@ Georgia Institute of Technology

2025-10-15

🌐 Distributed Training

🚀 Scaling: Overview

🐢 Training on a Single Device

Figure: Training on a single device: at each step, the next batch (x0, x1, x2, …) flows from the data loader through the network on GPU0 to compute the loss, one step at a time.

SLOW! And the model size is limited by the memory of a single GPU.

🕸️ Parallelism Strategies

  • Data Parallelism
    • Split data across workers
    • Easiest to implement
    • No changes to model
  • Model Parallelism
    • Split model across workers
  • Hybrid Parallelism
    • Combine data + model parallelism
    • More complex to implement
    • Requires changes to model

👬 Training on Multiple GPUs: Data Parallelism

Figure 1: Data parallelism: each of GPU0, GPU1, GPU2 holds a full copy of the network (NN) and receives a unique batch (x0, x1, x2) at each step, computing its own loss.

▶️ Data Parallel: Forward Pass

Figure 2: Forward pass: each GPU computes the loss on its own batch; the resulting gradients are then averaged across all GPUs, (∑ₙgₙ)/N.

◀️ Data Parallel: Backward Pass

Figure 3: Backward pass: the globally averaged updates are sent back to each GPU, keeping every replica of the model in sync. See: PyTorch / Distributed Data Parallel.
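
Below is a minimal, hedged sketch of what this looks like with PyTorch DDP (linked above); the model, loss, and hyperparameters are toy placeholders, not the configuration used in this talk.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for each process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Toy model; in practice, the full network goes here.
    model = torch.nn.Linear(1024, 1024).cuda()
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for _ in range(10):
        x = torch.randn(32, 1024, device="cuda")   # each rank sees a different batch
        loss = model(x).square().mean()
        loss.backward()                             # DDP averages gradients across ranks here
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launched with, e.g., `torchrun --standalone --nproc_per_node=4 ddp_example.py`, each process drives one GPU and the gradient averaging happens automatically during `backward()`.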

🔄 Collective Communication

  • Broadcast: Send data from one node to all other nodes
  • Reduce: Aggregate data from all nodes to one node
    • AllReduce: Aggregate data from all nodes to all nodes
  • Gather: Collect data from all nodes to one node
    • AllGather: Collect data from all nodes to all nodes
  • Scatter: Distribute data from one node to all other nodes

Reduce

  • Perform a reduction on data across ranks and send the result to a single rank

Figure 4: Reduce operation: z = reduce(x, 2, SUM). Each rank i (0–3) contributes its value xi; rank 2 receives the reduction z = x0 + x1 + x2 + x3.
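
A small sketch of these collectives with torch.distributed (this assumes an already-initialized process group with at least 3 ranks, each bound to its own GPU, e.g. via torchrun):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has been called and world_size >= 3,
# with torch.cuda.set_device(local_rank) already done on each rank.
rank = dist.get_rank()

# Each rank contributes its own value x_i (here, simply its rank id).
x = torch.tensor([float(rank)], device="cuda")

# Reduce: only the destination rank (dst=2) ends up holding z = sum_i x_i.
dist.reduce(x, dst=2, op=dist.ReduceOp.SUM)

# AllReduce: every rank ends up holding the sum.
y = torch.tensor([float(rank)], device="cuda")
dist.all_reduce(y, op=dist.ReduceOp.SUM)
```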

🐣 Getting Started: In Practice

  • 🧠 Memory Management:
    • FSDP vs. ZeRO
    • Activation Checkpointing
    • Mixed Precision Training
    • Gradient Accumulation (a sketch combining both follows this list)
    • Offloading to CPU/NVMe
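
As a rough illustration (not the exact AuroraGPT setup), mixed precision and gradient accumulation can be combined in a plain PyTorch loop; the model and data below are toy placeholders:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()              # toy placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                     # rescales the loss to avoid fp16 underflow
accum_steps = 8                                          # gradient accumulation factor

for step, x in enumerate(torch.randn(64, 32, 1024)):     # 64 toy batches of shape (32, 1024)
    x = x.cuda()
    with torch.cuda.amp.autocast():                      # forward pass in mixed precision
        loss = model(x).square().mean() / accum_steps    # scale loss by the accumulation factor
    scaler.scale(loss).backward()                        # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:                    # update only every accum_steps micro-batches
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```

Gradient accumulation gives a larger effective batch size without increasing peak activation memory, at the cost of more micro-steps per update.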

🔄 Keeping things in Sync

Computation stalls during communication!

Keeping the communication to computation ratio small is important for effective scaling.
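
One standard way to keep that ratio small (a sketch assuming an already-initialized process group, not the author's exact setup) is to launch collectives asynchronously and overlap them with independent computation; DDP does this automatically by all-reducing gradient buckets while the backward pass is still running.

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
grads = torch.randn(10_000_000, device="cuda")            # stand-in for a gradient bucket

# Launch the all-reduce without blocking...
work = dist.all_reduce(grads, op=dist.ReduceOp.SUM, async_op=True)

# ...do independent computation while the collective is in flight...
a = torch.randn(4096, 4096, device="cuda")
b = a @ a

# ...and only wait for the communication right before the result is needed.
work.wait()
grads /= dist.get_world_size()                             # turn the sum into an average
```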

📝 Plan of Attack

Figure 5: General strategy for scaling model training: if the model is already perfect, we are done; otherwise, if memory is available, make the model larger, and if not, free up memory first, then repeat.

🚀 Going Beyond Data Parallelism

Going beyond Data Parallelism: DeepSpeed + ZeRO

  • Depending on the ZeRO stage (1, 2, 3), we can partition across data-parallel workers (see the configuration sketch below):
    1. Stage 1: optimizer states \left(P_{\mathrm{os}}\right)
    2. Stage 2: gradients + opt. states \left(P_{\mathrm{os}+\mathrm{g}}\right)
    3. Stage 3: model params + grads + opt. states \left(P_{\mathrm{os}+\mathrm{g}+\mathrm{p}}\right)
Figure 6: DeepSpeed + ZeRO
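
A minimal sketch of enabling ZeRO through a DeepSpeed config dict; the model, batch size, and learning rate here are placeholders, and only a handful of the available config fields are shown:

```python
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)          # placeholder; in practice, the full transformer

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},       # 1: P_os, 2: P_os+g, 3: P_os+g+p
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-4}},
}

# deepspeed.initialize wraps the model and optimizer in the ZeRO engine.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```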

🕸️ Additional Parallelism Strategies

Pipeline Parallelism (PP)

Figure 7: Pipeline parallelism: the model's layers (Layer 0–3) are split into consecutive stages placed on GPU 0 and GPU 1.
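
A minimal sketch of placing consecutive layers on different GPUs (naive layer splitting that assumes two visible GPUs; real pipeline parallelism additionally slices each batch into micro-batches, e.g. GPipe or 1F1B schedules, so that both stages stay busy):

```python
import torch
import torch.nn as nn

class TwoStageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Stage 0 (e.g. Layers 0-1) lives on GPU 0, stage 1 (Layers 2-3) on GPU 1.
        self.stage0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:0")
        self.stage1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:1")

    def forward(self, x):
        h = self.stage0(x.to("cuda:0"))
        return self.stage1(h.to("cuda:1"))    # activations cross the GPU boundary here

model = TwoStageModel()
y = model(torch.randn(8, 1024))
```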

Tensor Parallel (TP)

  • Each tensor is split up into multiple chunks
  • Each shard of the tensor resides on its designated GPU
  • During processing each shard gets processed separately (and in parallel) on different GPUs
    • synced at the end of the step
  • See: 🤗 Model Parallelism for additional details

Figure 8: Tensor parallel training: every layer (Layer 0–3) appears on both GPU0 and GPU1, with each GPU holding its own shard of each layer's tensors.

Tensor Parallel (TP)

  • Suitable when the model is too large to fit onto a single device (CPU / GPU)
  • Typically more complicated to implement than data parallel training
    • This is what one may call horizontal parallelism
    • Communication is required whenever data flows between the two subsets
  • argonne-lcf/Megatron-DeepSpeed
  • 🤗 huggingface/nanotron

Figure 9: Tensor parallel training: each GPU holds a shard of every layer (Layer 0–3).

Tensor (/ Model) Parallel Training: Example

Want to compute: y = \sum_{i} x_{i} W_{i} = x_0 * W_0 + x_1 * W_1 + x_2 * W_2
where each GPU has only its portion of the full weights as shown below

  1. Compute: y_{0} = x_{0} * W_{0}\rightarrow GPU1
  2. Compute: y_{1} = y_{0} + x_{1} * W_{1}\rightarrow GPU2
  3. Compute: y = y_{1} + x_{2} * W_{2} = \sum_{i} x_{i} W_{i}

Figure 10: Each GPU holds only its own weight shard Wi; the partial sums (x0 W0, then x0 W0 + x1 W1) are passed from GPU0 to GPU1 to GPU2 to accumulate the full result y.
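
The chain above can be sketched with point-to-point sends (ranks, shapes, and data here are illustrative; Megatron-style tensor parallelism instead splits individual weight matrices and synchronizes with all-reduce/all-gather):

```python
import torch
import torch.distributed as dist

# Assumes an initialized 3-rank process group; rank i holds only x_i and W_i.
rank, world_size = dist.get_rank(), dist.get_world_size()
x_i = torch.randn(16, 256, device="cuda")
W_i = torch.randn(256, 256, device="cuda")

partial = torch.zeros(16, 256, device="cuda")
if rank > 0:
    dist.recv(partial, src=rank - 1)          # receive the running partial sum

partial = partial + x_i @ W_i                 # add this rank's contribution x_i W_i

if rank < world_size - 1:
    dist.send(partial, dst=rank + 1)          # pass the partial sum to the next GPU
else:
    y = partial                               # the last rank holds y = sum_i x_i W_i
```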

🔭 AI-for-Science
Figure: source (@tenderizzation); ChatGPT: explain this image.

🏗️ Aurora

Table 1: Aurora Specs

| Property | Value   |
|----------|---------|
| Racks    | 166     |
| Nodes    | 10,624  |
| XPUs     | 127,488 |
| CPUs     | 21,248  |
| NICs     | 84,992  |
| HBM      | 8 PB    |
| DDR5     | 10 PB   |

Figure 11: Aurora: Fact Sheet.

🌌 AuroraGPT (2024–)

AuroraGPT: a general-purpose scientific LLM, broadly trained on general corpora plus scientific {papers, texts, data}

  • Explore pathways towards a “Scientific Assistant” model
  • Build with international partners (RIKEN, BSC, others)
  • Multilingual: English, Japanese, French, German, Spanish
  • Multimodal: images, tables, equations, proofs, time series, graphs, fields, sequences, etc.
Figure 12: Image from Hannibal046 / Awesome-LLM

🧪 AuroraGPT: Open Science Foundation Model

Figure 13: High-level overview of AuroraGPT project

🧰 AuroraGPT: Toolbox

  • Datasets and data pipelines (how do we deal with scientific data?)
  • Software infrastructure and workflows (scalable, robust, extensible)
  • Evaluation of state-of-the-art LLMs (how do they perform on scientific tasks?)

🚂 Training

argonne-lcf/Megatron-DeepSpeed
Large Model Training: Any Scale, Any Accelerator

🏃‍♂️ Running

argonne-lcf/inference-endpoints
Inference endpoints for LLMs, hosted @ ALCF

🏋️ Challenges: In Practice

This is incredibly difficult in practice, due in part to:

  • Brand new {hardware, architecture, software}
  • Lack of native support in existing frameworks (though getting better!)
  • General system stability
    +10k Nodes \left(\times \frac{12\,\,\mathrm{XPU}}{1\,\,\mathrm{Node}}\right)\Rightarrow +100k XPUs
    • network performance
    • file system stability (impacted by other users!)
    • many unexpected difficulties occur at increasingly large scales
  • Combinatorial explosion of possible configurations and experiments
    • {hyperparameters, architectures, tokenizers, learning rates, …}

💾 AuroraGPT: Training

  • Training a fixed model on trillions of tokens requires:
    1. Aggregating data from multiple different corpora
      (e.g. ArXiv, Reddit, StackExchange, GitHub, Wikipedia, etc.)
    2. Sampling each training batch according to a fixed distribution across corpora
    3. Building indices that map batches of tokens into these files (indexing)

    The original implementation was slow:

    • Designed to run serially on a single device
    • Major bottleneck when debugging data pipeline at scale
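
As a rough illustration of step 2 above (not the Megatron-DeepSpeed indexing code itself), sampling each batch from several corpora according to fixed weights might look like this; the corpus names and weights are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical corpora (stand-ins for document/token indices) and fixed mixing weights.
corpora = {
    "arxiv":     np.arange(0, 1_000),
    "github":    np.arange(1_000, 2_000),
    "wikipedia": np.arange(2_000, 3_000),
}
weights = {"arxiv": 0.5, "github": 0.3, "wikipedia": 0.2}   # must sum to 1

def sample_batch(batch_size: int):
    """Draw one batch whose corpus composition follows the fixed distribution."""
    names = list(corpora)
    probs = np.array([weights[n] for n in names])
    picks = rng.choice(len(names), size=batch_size, p=probs)
    return [int(rng.choice(corpora[names[i]])) for i in picks]

batch = sample_batch(8)
```

In the actual pipeline, this mapping is realized through the pre-built indices of step 3 rather than computed on the fly.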

🍹 AuroraGPT: Blending Data, Efficiently

  • 🐢 Original implementation:
    • Slow (serial, single device)
    • ~ 1 hr/2T tokens
  • 🐇 New implementation:
    • Fast! (distributed, asynchronous)
    • ~ 2 min/2T tokens
      (30x faster!)
Figure 14: Time spent preparing 2T tokens

📉 Loss Curve: Training AuroraGPT-7B on 2T Tokens

Figure 15: Loss curve during training on 2T tokens.

✨ Features

  • 🕸️ Parallelism:
    • {data, tensor, pipeline, sequence, …}
  • ♻️ Checkpoint Converters:
    • Megatron ⇄ 🤗 HF ⇄ ZeRO ⇄ Universal
  • 🔀 DeepSpeed Integration:
    • ZeRO Offloading
    • Activation checkpointing
    • AutoTP (WIP)
    • Ability to leverage features from the DeepSpeed community

✨ Features (even more!)

  • 🧗 Optimizers:
    • Support for many different optimizers:
      • Distributed Shampoo, Muon, Adopt, Sophia, Lamb, GaLore, ScheduleFree, …
    • See full list
    • Large batch training
  • 📊 Experiment Tracking:
    • Automatic experiment and metric tracking with Weights & Biases
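
A minimal sketch of what the automatic tracking boils down to (the project name, config, and metric below are placeholders):

```python
import wandb

run = wandb.init(project="megatron-deepspeed-demo", config={"lr": 2e-4, "num_layers": 32})

for step in range(100):
    loss = 1.0 / (step + 1)                       # placeholder metric
    wandb.log({"train/loss": loss, "step": step})

run.finish()
```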

🧬 MProt-DPO

🧬 Scaling Results (2024)

Figure 16: Scaling results for 3.5B model across ~38,400 GPUs

🧬 MProt-DPO: Scaling Results

Figure 17: 3.5B model
Figure 18: 7B model

🚂 Loooooooooong Sequence Lengths

Figure 19: Maximum (achievable) SEQ_LEN for both the 25B and 33B models (See: Song et al. (2023))

🌎 AERIS (2025)

Figure 20: arXiv:2509.13523
Figure 21: Pixel-level Swin diffusion transformer in sizes from [1–80]B

👀 High-Level Overview of AERIS

Figure 22: Rollout of AERIS model, specific humidity at 700m.
Table 2: Overview of AERIS model and training setup

| Property           | Description      |
|--------------------|------------------|
| Domain             | Global           |
| Resolution         | 0.25° & 1.4°     |
| Training Data      | ERA5 (1979–2018) |
| Model Architecture | Swin Transformer |
| Speedup            | O(10k–100k)      |

➕ Contributions

☔ AERIS

First billion-parameter diffusion model for weather + climate

  • Operates at the pixel level (1 × 1 patch size), guided by physical priors
  • Medium-range forecast skill:
    • Surpasses IFS ENS, competitive with GenCast (Price et al. 2024)
    • Uniquely stable on seasonal scales to 90 days

🌀 SWiPe

A novel 3D (sequence-window-pipeline) parallelism strategy for training transformers across high-resolution inputs

  • Enables scalable small-batch training on large supercomputers
    • 10.21 ExaFLOPS
    • @ 121,000 Intel XPUs (Aurora)

⚠️ Issues with the Deterministic Approach

  • Transformers:
    • Deterministic
    • Single input → single forecast
  • Diffusion:
    • Probabilistic
    • Single input → ensemble of forecasts
    • Captures uncertainty and variability in weather predictions
    • Enables ensemble forecasting for better risk assessment

🎲 Transitioning to a Probabilistic Model

Figure 23: Reverse diffusion with the input condition, individual sampling steps t_{0} \rightarrow t_{64}, the next time step estimate and the target output.

Reverse Diffusion Process (\mathcal{N}\rightarrow \pi)

Forward Diffusion Process (\pi\rightarrow \mathcal{N})
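
For reference, the two processes are commonly written in the standard DDPM form below (the exact parameterization used by AERIS may differ):

Forward: q\left(x_{t} \mid x_{t-1}\right) = \mathcal{N}\left(x_{t};\ \sqrt{1-\beta_{t}}\, x_{t-1},\ \beta_{t}\mathbf{I}\right)

Reverse: p_{\theta}\left(x_{t-1} \mid x_{t}\right) = \mathcal{N}\left(x_{t-1};\ \mu_{\theta}(x_{t}, t),\ \Sigma_{\theta}(x_{t}, t)\right)

where the forward process gradually corrupts the data distribution \pi toward \mathcal{N}, and the learned reverse process maps noise back to data.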

🌀 Sequence-Window-Pipeline Parallelism SWiPe

  • SWiPe is a novel parallelism strategy for Swin-based Transformers
  • Hybrid 3D Parallelism strategy, combining:
    • Sequence parallelism (SP)
    • Window parallelism (WP)
    • Pipeline parallelism (PP)
Figure 24
Figure 25: SWiPe Communication Patterns

🚀 AERIS: Scaling Results

Figure 26: AERIS: Scaling Results
  • 10 EFLOPs (sustained) @ 120,960 GPUs
  • See Hatanpää et al. (2025) for additional details
  • arXiv:2509.13523

🌪️ Hurricane Laura

Figure 27: Hurricane Laura tracks (top) and intensity (bottom). Initialized 7(a), 5(b) and 3(c) days prior to 2020-08-28T00z.

📓 References

Hatanpää, Väinö, Eugene Ku, Jason Stock, Murali Emani, Sam Foreman, Chunyong Jung, Sandeep Madireddy, et al. 2025. “AERIS: Argonne Earth Systems Model for Reliable and Skillful Predictions.” https://arxiv.org/abs/2509.13523.
Price, Ilan, Alvaro Sanchez-Gonzalez, Ferran Alet, Tom R. Andersson, Andrew El-Kadi, Dominic Masters, Timo Ewalds, et al. 2024. “GenCast: Diffusion-Based Ensemble Forecasting for Medium-Range Weather.” https://arxiv.org/abs/2312.15796.
Song, Shuaiwen Leon, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, et al. 2023. “DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery Through Sophisticated AI System Technologies.” https://arxiv.org/abs/2310.04610.

❤️ Acknowledgements

This research used resources of the Argonne Leadership Computing Facility, which is a DOE Office of Science User Facility supported under Contract DE-AC02-06CH11357.