Training Foundation Models on Supercomputers

Sam Foreman
[email protected]

ALCF

2025-09-24

👀 Scaling: Overview

  • ✅ Goal:
    • Minimize: Cost (i.e. amount of time spent training)
    • Maximize: Performance

    📑 Note

    See 🤗 Performance and Scalability for more details

🐢 Training on a Single Device

  • See also:
    • Scientific AI at Scale: Distributed Training
    • 🤗 Methods and tools for efficient training on a single GPU

Figure 1: SLOW!! Model size limited by GPU memory

👬 Training on Multiple GPUs: Data Parallelism

Figure 2: Each GPU receives unique data at each step

➡️ Data Parallel: Forward Pass

Figure 3: Average gradients across all GPUs

⬅️ Data Parallel: Backward Pass

Figure 4: Send global updates back to each GPU. See: PyTorch / Distributed Data Parallel

🔄 Data Parallel: Training

  • Each GPU:
    • has identical copy of model
    • works on a unique subset of data
  • Easy to get started (minor modifications to code; see the sketch after this list):
    • saforem2/ezpz
    • 🔥 PyTorch / DDP
    • 🤗 HF / Accelerate
    • Microsoft / DeepSpeed
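
As a rough sketch of what those "minor modifications" look like with plain PyTorch DDP (assuming a CUDA machine and a torchrun-style launcher that sets RANK / WORLD_SIZE / LOCAL_RANK; ezpz automates this setup for you):

    # Minimal DDP sketch: every rank holds a full model copy and trains on its own data
    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend="nccl")      # "ccl" / "gloo" on other hardware
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    model = DDP(torch.nn.Linear(128, 128).to(f"cuda:{local_rank}"), device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters())

    x = torch.rand((64, 128), device=f"cuda:{local_rank}")  # this rank's (unique) batch
    loss = ((model(x) - x) ** 2).sum()
    loss.backward()                               # gradients are all-reduced automatically
    optimizer.step()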

📡 Communication

  • Requires global communication
    • every rank must participate (collective communication) !!
  • Need mechanism(s) for communicating across GPUs:
    • mpi4py
    • torch.distributed
  • Collective Communication:
    • Nvidia: Collective Communications Library (NCCL)
    • Intel: oneAPI Collective Communications Library (oneCCL)
    • AMD: ROCm Communication Collectives Library (RCCL)
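
For example, averaging a tensor across every rank with torch.distributed is a single collective call; a minimal sketch, assuming the process group is already initialized and the tensor lives on a device supported by the chosen backend:

    import torch
    import torch.distributed as dist

    # every rank starts with a different value ...
    t = torch.ones(1) * dist.get_rank()
    # ... and every rank must make the same collective call
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t /= dist.get_world_size()   # now every rank holds the global average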

⌛ Timeouts

  • Every rank must make the same collective call in order to form a complete collective operation.
    • If any rank fails to participate, the remaining ranks will block and wait indefinitely (until the timeout fires), as in the sketch below
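
A deliberately broken sketch of how this happens (assuming an initialized process group): if any rank skips the collective, every other rank blocks until the backend timeout fires.

    import torch
    import torch.distributed as dist

    t = torch.ones(1)

    # BUG: rank 0 never calls the collective, so all other ranks
    # sit inside all_reduce until the (NCCL / oneCCL) timeout expires
    if dist.get_rank() != 0:
        dist.all_reduce(t)

    # CORRECT: every rank makes the same call
    dist.all_reduce(t)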

🚧 Common Pitfalls

  • Each worker needs to be fed a unique batch of data at each step
  • Only perform File I/O on one worker (i.e. rank==0)
    • When loading from a checkpoint, read in on one worker and broadcast to others
  • Collective operations must be called by all workers
    • Ensure that all workers are using the same version of code / libraries

Figure 5: To ensure all workers have the same copies, we load on RANK==0 and broadcast
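
A minimal sketch of this pattern with torch.distributed (the checkpoint path is hypothetical, and the model is assumed to be constructed identically on every rank before loading):

    import torch
    import torch.distributed as dist

    model = torch.nn.Linear(128, 128)      # stand-in for your real model
    ckpt_path = "checkpoint.pt"            # hypothetical path

    # only rank 0 touches the filesystem ...
    state = torch.load(ckpt_path, map_location="cpu") if dist.get_rank() == 0 else None

    # ... then the (pickled) state dict is broadcast to every other rank
    obj = [state]
    dist.broadcast_object_list(obj, src=0)
    model.load_state_dict(obj[0])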

🎀 Best Practices

  • Use parallel IO whenever possible
    • Feed each rank from different files
    • Use MPI IO to have each rank read its own batch from a file
    • Use several ranks to read data, MPI to scatter to remaining ranks
      • Most practical for large, at-scale training runs
  • Take advantage of data storage
    • Use striping on Lustre
  • Use the right optimizations for Aurora, Polaris, etc.
  • Preload data when possible
    • Offloading to a GPU frees CPU cycles for loading the next batch of data
      • this minimizes I/O latency (see the data-loading sketch below)
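
One simple way to feed each rank a unique shard while overlapping host-side loading with device compute is torch.utils.data.DistributedSampler; a sketch, assuming an initialized process group (an at-scale run would swap the toy dataset for parallel / MPI-IO-backed readers):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.rand(10_000, 128))   # toy stand-in for a real dataset

    # DistributedSampler partitions the indices so each rank sees a disjoint shard
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=64,
        sampler=sampler,
        num_workers=4,      # overlap loading of the next batch with GPU compute
        pin_memory=True,    # faster host-to-device copies
    )

    for epoch in range(10):
        sampler.set_epoch(epoch)   # reshuffle the shards each epoch
        for (batch,) in loader:
            pass                   # forward / backward / step as usual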

⏰ Keeping things in Sync

Computation stalls during communication !!

Keeping the communication to computation ratio small is important for effective scaling.
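
One common way to shrink that ratio with DDP is gradient accumulation: skip the gradient all-reduce on most micro-steps via no_sync() and synchronize only every N steps. A sketch, where model is assumed to already be wrapped in DistributedDataParallel and loader / loss_fn / optimizer are placeholders:

    import contextlib

    def train_with_accumulation(model, loader, loss_fn, optimizer, accum_steps=4):
        """Accumulate gradients locally; all-reduce only every `accum_steps` steps."""
        for step, (x, y) in enumerate(loader):
            sync_now = (step + 1) % accum_steps == 0
            # DDP's no_sync() skips the gradient all-reduce for this backward pass
            ctx = contextlib.nullcontext() if sync_now else model.no_sync()
            with ctx:
                loss = loss_fn(model(x), y) / accum_steps
                loss.backward()
            if sync_now:
                optimizer.step()
                optimizer.zero_grad()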

🤔 Plan of Attack

Figure 6: General strategy for scaling model training

🚀 Going Beyond Data Parallelism

  • ✅ Useful when model fits on single GPU:
    • ultimately limited by GPU memory
    • model performance limited by size
  • ⚠️ When model does not fit on a single GPU:
    • Offloading (can only get you so far…):
      • DeepSpeed + ZeRO
      • 🔥 PyTorch + FSDP
    • Otherwise, resort to model parallelism strategies
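
A minimal sketch of sharding a model with PyTorch FSDP (assuming an initialized process group on a CUDA system; the wrapping policy and device handling will differ for real models and for Aurora's XPUs):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = torch.nn.Sequential(           # toy stand-in for a model too large for one GPU
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    )

    # FSDP shards parameters, gradients, and optimizer state across ranks,
    # gathering parameters on the fly for each forward / backward pass
    model = FSDP(model, device_id=torch.cuda.current_device())
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)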

Going beyond Data Parallelism: DeepSpeed + ZeRO

  • Depending on the ZeRO stage (1, 2, 3), we can partition (and optionally offload) more and more of the training state, as sketched in the config below:
    1. Stage 1: optimizer states $\left(P_{\mathrm{os}}\right)$
    2. Stage 2: gradients + optimizer states $\left(P_{\mathrm{os+g}}\right)$
    3. Stage 3: model params + gradients + optimizer states $\left(P_{\mathrm{os+g+p}}\right)$
Figure 7: DeepSpeed + ZeRO
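
The stage is selected in the DeepSpeed config; a sketch of the relevant fragment passed to deepspeed.initialize (the batch size, dtype, and offload settings are illustrative, and the toy model stands in for your real one):

    import torch
    import deepspeed

    model = torch.nn.Linear(1024, 1024)    # stand-in for your real model

    ds_config = {
        "train_micro_batch_size_per_gpu": 8,
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 2,                              # 1: P_os, 2: P_os+g, 3: P_os+g+p
            "offload_optimizer": {"device": "cpu"},  # optional: push optimizer state to CPU
        },
    }

    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )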

🕸️ Additional Parallelism Strategies

  • Tensor (/ Model) Parallelism (TP):
    • 🤗 Tensor Parallelism
    • 🔥 Large Scale Transformer model training with Tensor Parallel (TP)
  • Pipeline Parallelism (PP):
    • 🔥 PyTorch, DeepSpeed
  • Sequence Parallelism (SP):
    • DeepSpeed Ulysses
    • Megatron / Context Parallelism
    • Unified Sequence Parallel (USP)
      • feifeibear/long-context-attention
    • Supports 4D Parallelism (DP + TP + PP + SP)

Pipeline Parallelism (PP)

  • Model is split up vertically (layer-level) across multiple GPUs
  • Each GPU:
    • has a portion of the full model
    • processes its stage of the pipeline in parallel with the other GPUs (on a small chunk of the batch)
  • See:
    • 🔥 PyTorch / Pipeline Parallelism
    • DeepSpeed / Pipeline Parallelism

Figure 8: Pipeline Parallelism
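
Conceptually, a (naive) two-stage pipeline just places different groups of layers on different devices and hands activations across; real schedules (e.g. GPipe, 1F1B) additionally split each batch into micro-batches so that all stages stay busy. A toy sketch, assuming two CUDA devices:

    import torch

    # stage 0: first half of the layers on GPU 0
    stage0 = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.ReLU()).to("cuda:0")
    # stage 1: second half of the layers on GPU 1
    stage1 = torch.nn.Sequential(torch.nn.Linear(4096, 1024)).to("cuda:1")

    x = torch.rand(64, 1024, device="cuda:0")
    h = stage0(x)                 # stage 0 runs on GPU 0
    y = stage1(h.to("cuda:1"))    # activations are handed to stage 1 on GPU 1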

Tensor Parallel (TP)

  • Each tensor is split up into multiple chunks
  • Each shard of the tensor resides on its designated GPU
  • Each shard is processed separately (and in parallel) on its designated GPU
    • results are synced at the end of the step
  • See: 🤗 Model Parallelism for additional details

Figure 9: Tensor Parallel Training

Tensor Parallel (TP)

  • Suitable when the model is too large to fit onto a single device (CPU / GPU)
  • Typically more complicated to implement than data parallel training
    • This is what one may call horizontal parallelism
    • Communication is required whenever data flows between two subsets
  • argonne-lcf/Megatron-DeepSpeed
  • 🤗 huggingface/nanotron

Figure 10: Tensor Parallel Training
  • Split up the network over multiple workers
  • Each worker receives a disjoint subset
  • All communication associated with the subsets is distributed across workers

Tensor (/ Model) Parallel Training: Example

Want to compute: $y = \sum_{i} x_{i} W_{i} = x_0 * W_0 + x_1 * W_1 + x_2 * W_2$,
where each GPU has only its own portion of the full weights, as shown below

  1. Compute: $y_{0} = x_{0} * W_{0}$ → GPU1
  2. Compute: $y_{1} = y_{0} + x_{1} * W_{1}$ → GPU2
  3. Compute: $y = y_{1} + x_{2} * W_{2} = \sum_{i} x_{i} W_{i}$ ✅

Figure 11
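
The same sum can also be formed with a single collective instead of passing partial results from GPU to GPU: each rank computes its local x_i * W_i and an all-reduce adds them. A sketch, assuming an initialized process group with one weight shard per rank (shapes are illustrative):

    import torch
    import torch.distributed as dist

    # each rank holds only its own slice of the input and its own weight shard
    x_i = torch.rand(64, 256)
    w_i = torch.rand(256, 128)

    y = x_i @ w_i                             # local partial result: x_i * W_i
    dist.all_reduce(y, op=dist.ReduceOp.SUM)  # y = sum_i x_i * W_i on every rank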

🧬 MProt-DPO: Scaling Results

Figure 12: Scaling results for 3.5B model across ~38,400 GPUs
  • ~4 EFLOPS @ Aurora
  • 38,400 XPUs = 3,200 [nodes] x 12 [XPU / node]

  • 🔔 2024 ACM Gordon Bell Finalist (Dharuman et al. 2024): MProt-DPO: Breaking the ExaFLOPS Barrier for Multimodal Protein Design Workflows

🌎 AERIS: Scaling Results

Figure 13: AERIS: Scaling Results
  • 10 EFLOPS (sustained) @ 120,960 GPUs
  • See Hatanpää et al. (2025), arXiv:2509.13523, for additional details

🍋 ezpz

Write once, run anywhere

  • Setup (optional1):

    source <(curl -L https://bit.ly/ezpz-utils)
    ezpz_setup_env
  • Install:

    uv pip install "git+https://github.com/saforem2/ezpz" --no-cache --link-mode=copy
  • See also:

    • saforem2/ezpz
    • LLMS on Aurora: Hands-On
    • ezpz docs
  1. Can be skipped if you already have an environment with torch + mpi4py

🍋 ezpz @ ALCF

  • Polaris:

    uv venv --python=3.12
    source .venv/bin/activate
    module use /soft/modulefiles
    module load gcc-native cudatoolkit/12.8.1
    uv pip install --no-cache --link-mode=copy torch torchvision torchaudio transformers deepspeed datasets accelerate torchinfo
    CC=mpicc CXX=mpicxx uv pip install --no-cache --link-mode=copy --no-binary=mpi4py mpi4py
    uv run --with "git+https://github.com/saforem2/ezpz@saforem2/tests" --with "numpy<2" ezpz-test

🐣 Getting Started

  1. Submit interactive job:

    qsub -I -l select=2 -l walltime=01:00:00 \
        -l filesystems=home:flare \
        -A gpu_hack \
        -q gpu_hack_prio
  2. Source1 the ezpz/bin/utils.sh script (using curl to download it2):

    source <(curl -L https://bit.ly/ezpz-utils)
  1. In general, you should be wary of running random scripts from the internet.

  2. https://bit.ly/ezpz-utils, since https://raw.githubusercontent.com/saforem2/ezpz/main/bin/utils.sh is a bit of a pain

🏖️ Shell Environment

  1. Setup environment:

    ezpz_setup_env

🔍 Environment Setup with ezpz_setup_env

  • Wrapper around ezpz_setup_job && ezpz_setup_python
  1. ezpz_setup_job: Determine the specifics of our active (PBS, SLURM) job1

  2. ezpz_setup_python:

    • if @ ALCF:
      • Load the appropriate modules and activate base conda env
    • else:
      • Look for an active conda environment
        • If found, use it to build a new virtual environment
    • Activate the newly created venvs/$(basename ${CONDA_PREFIX}) environment
  1. e.g. ${NHOSTS}, ${NGPU_PER_HOST}, ${NGPUS}, …

⏱️ Working with Job Scheduler(s)

  • ezpz integrates directly with your favorite job scheduler (PBS, slurm)
    • has mechanisms for getting information about our currently running jobs
  • 🪄 Automagically:
    • Determine the specifics of our active (PBS, SLURM) job
      (e.g. ${NHOSTS}, ${NGPU_PER_HOST}, ${NGPUS}, …)
    • Load the appropriate modules1
    • Create (or activate) a virtual environment on top of a base conda environment
  1. On any of the ALCF systems, including: Aurora, Polaris, …, etc.

🐍 Python Environments

  • ALWAYS work inside a virtual environment
    • best practice is to maintain separate virtual environments for:
      • each project you work on
      • different versions of a specific package you’re working with
        e.g. you would want different envs for torch==2.X vs. torch==2.Y
    • Mangled python environments are one of the most common issues faced by users

🧪 Simple Distributed Test

  1. Run distributed test:

    ezpz-test
  2. Launch arbitrary Python (a module or a string) with ezpz-launch

    • Launch a module:

      ezpz-launch -m ezpz.test_dist
    • Launch a python string:

      ezpz-launch -c "'import ezpz; ezpz.setup_torch()'"

➕ How to Modify Existing Code

+ import ezpz
+ _ = ezpz.setup_torch()

- model.to('cuda')
+ model.to(ezpz.get_torch_device_type())
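
Putting those two changes together, a single-device training script becomes device-agnostic and distributed-ready; a sketch using only the ezpz calls shown on these slides:

    import ezpz
    import torch

    _ = ezpz.setup_torch()                  # initialize torch.distributed across all ranks
    device = ezpz.get_torch_device_type()   # "xpu", "cuda", "mps", or "cpu"

    model = torch.nn.Linear(128, 128)
    model.to(device)                        # instead of model.to('cuda')

    if ezpz.get_world_size() > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=[ezpz.get_local_rank()])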

✨ Features

  • Initializing PyTorch across multiple processes

    import ezpz
    _ = ezpz.setup_torch()
    rank = ezpz.get_rank()
    world_size = ezpz.get_world_size()
    local_rank = ezpz.get_local_rank()
  • Automatic device detection (xpu, cuda, mps, cpu, …)

    x = torch.rand((10, 10)).to(ezpz.get_torch_device_type())
  • Automatic (single-process) logging

    logger = ezpz.get_logger(__name__)
  • Distributed debugger:

    try:
        buggy_code()
    except Exception:
        ezpz.breakpoint(0)

🧪 Experiment Tracking

import ezpz
rank = ezpz.setup_torch()
logger = ezpz.get_logger(__name__)
if rank == 0:                   # -- [1.] --
    try:
        _ = ezpz.setup_wandb(
            "ezpz.examples.minimal"
        )
    except Exception:
        logger.exception(
            "Failed to initialize wandb, continuing without it"
        )

# ...build {model, optimizer}, etc...

for i in range(train_iters):
    metrics = train_step(...)
    logger.info(                 # -- [2.] --
        history.update(metrics)  # -- [3.] --
    )

if rank == 0:
    history.finalize()
  1. Initialize W&B (if WANDB_DISABLED is not set)
  2. Log summary of metrics to stdout
  3. Update history.history with metrics1
  1. Will automatically be reported to W&B if a run is detected

🤏 Minimal Example

  • See ezpz/examples/minimal.py
import os
import time
import ezpz
import torch

logger = ezpz.get_logger(__name__)


class Network(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        sizes: list[int] | None,
    ):
        super(Network, self).__init__()
        nh = output_dim if sizes is None else sizes[0]
        layers = [torch.nn.Linear(input_dim, nh), torch.nn.ReLU()]
        if sizes is not None and len(sizes) > 1:
            for idx, size in enumerate(sizes[1:]):
                layers.extend(
                    [torch.nn.Linear(sizes[idx], size), torch.nn.ReLU()]
                )
            layers.append(torch.nn.Linear(sizes[-1], output_dim))
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer
) -> ezpz.History:
    unwrapped_model = (
        model.module
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model
    )
    history = ezpz.History()
    device_type = ezpz.get_torch_device_type()
    dtype = unwrapped_model.layers[0].weight.dtype
    bsize = int(os.environ.get("BATCH_SIZE", 64))
    isize = unwrapped_model.layers[0].in_features
    warmup = int(os.environ.get("WARMUP_ITERS", 10))
    log_freq = int(os.environ.get("LOG_FREQ", 1))
    model.train()
    for step in range(int(os.environ.get("TRAIN_ITERS", 500))):
        with torch.autocast(
            device_type=device_type,
            dtype=dtype,
        ):
            t0 = time.perf_counter()
            x = torch.rand((bsize, isize), dtype=dtype).to(device_type)
            y = model(x)
            loss = ((y - x) ** 2).sum()
            dtf = (t1 := time.perf_counter()) - t0
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            dtb = time.perf_counter() - t1
            if step % log_freq == 0 and step > warmup:
                logger.info(
                    history.update(
                        {
                            "iter": step,
                            "loss": loss.item(),
                            "dt": dtf + dtb,
                            "dtf": dtf,
                            "dtb": dtb,
                        }
                    )
                )
    return history


@ezpz.timeitlogit(rank=ezpz.get_rank())
def setup():
    rank = ezpz.setup_torch()
    if os.environ.get("WANDB_DISABLED", False):
        logger.info("WANDB_DISABLED is set, not initializing wandb")
    elif rank == 0:
        try:
            _ = ezpz.setup_wandb(
                project_name=os.environ.get(
                    "PROJECT_NAME", "ezpz.examples.minimal"
                )
            )
        except Exception:
            logger.exception(
                "Failed to initialize wandb, continuing without it"
            )
    device_type = ezpz.get_torch_device_type()
    model = Network(
        input_dim=int((os.environ.get("INPUT_SIZE", 128))),
        output_dim=int(os.environ.get("OUTPUT_SIZE", 128)),
        sizes=[
            int(x)
            for x in os.environ.get("LAYER_SIZES", "1024,512,256,128").split(
                ","
            )
        ],
    )
    model.to(device_type)
    model.to(getattr(torch, os.environ.get("DTYPE", "bfloat16")))
    logger.info(f"{model=}")
    optimizer = torch.optim.Adam(model.parameters())
    if ezpz.get_world_size() > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP

        model = DDP(model, device_ids=[ezpz.get_local_rank()])

    return model, optimizer


def main():
    model, optimizer = setup()
    history = train(model, optimizer)
    if ezpz.get_rank() == 0:
        dataset = history.finalize()
        logger.info(f"{dataset=}")


if __name__ == "__main__":
    main()

🏃‍♂️ Running the Minimal Example

To run the previous example we:

  1. Source the ezpz utils script:

    source <(curl -L https://bit.ly/ezpz-utils)
  2. Setup our environment:

    ezpz_setup_env
  3. Run the example:

    ezpz-launch -m ezpz.examples.minimal

📝 ezpz-test

  • ezpz-test is a simple test script that trains a small model using DDP across all available GPUs

    • It will automatically detect the number of GPUs and launch an appropriate mpiexec command to run the training script across all GPUs
  • See: ezpz/test.py

  • Command:

    ezpz-test

🦜 Generate Text

  • See: ezpz/generate.py

  • Command:

    python3 -m ezpz.generate --model_name meta-llama/Llama-3.1-8B

🤗 Huggingface Trainer

  • See ezpz/hf_trainer.py

  • Command:

    ezpz-launch -m ezpz.hf_trainer \
        --dataset_name=eliplutchok/fineweb-small-sample \
        --streaming \
        --model_name_or_path=meta-llama/Llama-3.2-1B \
        --bf16=true \
        --do_train=true \
        --do_eval=true \
        --report-to=wandb \
        --logging-steps=1 \
        --include-tokens-per-second=true \
        --block-size=128 \
        --max-steps=10 \
        --include-num-input-tokens-seen=true \
        --auto_find_batch_size=true \
        --gradient_checkpointing=true \
        --optim=adamw_torch \
        --overwrite-output-dir=true \
        --logging-first-step \
        --include-for-metrics='inputs,loss' \
        --max-eval-samples=50 \
        --ddp-backend=ccl

🏎️ Megatron-DeepSpeed

git clone https://github.com/argonne-lcf/Megatron-DeepSpeed
cd Megatron-DeepSpeed
source <(curl -L https://bit.ly/ezpz-utils)
python3 -m pip install -e \
    deepspeed \
    "git+https://github.com/saforem2/ezpz"
bash train_alcf.sh

🙌 Acknowledgements

This research used resources of the Argonne Leadership Computing Facility, which is a DOE Office of Science User Facility supported under Contract DE-AC02-06CH11357.

📓 References

Dharuman, Gautham, Kyle Hippe, Alexander Brace, Sam Foreman, Väinö Hatanpää, Varuni K. Sastry, Huihuo Zheng, et al. 2024. “MProt-DPO: Breaking the ExaFLOPS Barrier for Multimodal Protein Design Workflows with Direct Preference Optimization.” In Proceedings of the International Conference for High Performance Computing, Networking, Storage, and Analysis. SC ’24. Atlanta, GA, USA: IEEE Press. https://doi.org/10.1109/SC41406.2024.00013.
Hatanpää, Väinö, Eugene Ku, Jason Stock, Murali Emani, Sam Foreman, Chunyong Jung, Sandeep Madireddy, et al. 2025. “AERIS: Argonne Earth Systems Model for Reliable and Skillful Predictions.” https://arxiv.org/abs/2509.13523.

samforeman.me/talks/2025/09/24/slides
