Training Foundation Models on Supercomputers

👀 Scaling: Overview

🐢 Training on a Single Device

Figure 1: Training on a single device is slow, and the model size is limited by the memory of a single GPU.

👬 Training on Multiple GPUs: Data Parallelism

Figure 2: Data parallelism: each GPU holds a full copy of the network and receives a unique batch of data at each step.

➡️ Data Parallel: Forward Pass

Figure 3: Each GPU computes its own loss; gradients are then averaged across all GPUs, $(\sum_{n} g_{n}) / N$.

⬅️ Data Parallel: Backward Pass

Figure 4: Send the globally averaged updates back to each GPU. See: PyTorch Distributed Data Parallel (DDP).
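
A minimal DDP sketch of Figures 2–4, assuming the script is launched with torchrun (or an equivalent launcher) so that RANK, WORLD_SIZE, and LOCAL_RANK are set in the environment:

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # One process per GPU; the process group handles the collective communication
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = DDP(torch.nn.Linear(128, 128).cuda(local_rank), device_ids=[local_rank])

    x = torch.rand(64, 128).cuda(local_rank)  # each rank gets a *different* batch
    loss = ((model(x) - x) ** 2).sum()
    loss.backward()  # DDP all-reduces (averages) gradients across ranks here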

🔄 Data Parallel: Training

📡 Communication

  • Requires global communication
    • every rank must participate in each collective operation !!
  • Need a mechanism for communicating across GPUs (e.g. NCCL, MPI, oneCCL)

⌛ Timeouts

  • Collective operations must be called by every rank in order to form a complete collective operation.
    • If any rank fails to participate, the other ranks will wait indefinitely (or until the process-group timeout expires), as sketched below
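
A rough sketch (assuming a launcher has already set the usual MASTER_ADDR / RANK environment variables); the process-group timeout bounds how long the other ranks will wait:

    from datetime import timedelta

    import torch
    import torch.distributed as dist

    # Ranks that never reach a matching collective call error out after this timeout
    dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30))

    x = torch.ones(1)
    dist.all_reduce(x, op=dist.ReduceOp.SUM)  # OK: called by every rank

    # Would hang (until the timeout) if guarded by a single rank, e.g.:
    # if dist.get_rank() == 0:
    #     dist.all_reduce(x)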

🚧 Common Pitfalls

  • Each worker needs to be fed a unique batch of data at each step (e.g. via a DistributedSampler; see the sketch below)
  • Only perform file I/O on one worker (i.e. rank==0)
    • When loading from a checkpoint, read it in on one worker and broadcast to the others (see the sketch after Figure 5)
  • Collective operations must be called by all workers
    • Ensure that all workers are using the same version of code / libraries
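
For the first pitfall, PyTorch's DistributedSampler is one way to guarantee unique, non-overlapping batches per rank; this sketch assumes the process group is already initialized:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.rand(1024, 128))
    sampler = DistributedSampler(dataset)  # shards the dataset by RANK / WORLD_SIZE
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    for epoch in range(10):
        sampler.set_epoch(epoch)  # same shuffle across ranks, different each epoch
        for (batch,) in loader:
            ...  # each rank sees a disjoint subset of the data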

Figure 5: To ensure all workers have the same copy of the model and optimizer state, load the checkpoint (CKPT) on RANK==0 and broadcast it to the other workers.
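
A sketch of the pattern in Figure 5, assuming an initialized process group, a hypothetical checkpoint file ckpt.pt, and a model defined elsewhere:

    import torch
    import torch.distributed as dist

    # Only RANK == 0 touches the filesystem; everyone else receives a copy
    obj = [torch.load("ckpt.pt", map_location="cpu")] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(obj, src=0)  # collective: must be called by *all* ranks
    model.load_state_dict(obj[0])           # every worker now holds identical weights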

🎀 Best Practices

  • Use parallel I/O whenever possible
    • Feed each rank from different files
    • Use MPI I/O to have each rank read its own batch from a file
    • Use several ranks to read data, then scatter via MPI to the remaining ranks (see the sketch below)
      • Most practical for big at-scale training
  • Take advantage of data storage
  • Use the right optimizations for Aurora, Polaris, etc.
  • Preload data when possible
    • Offloading compute to the GPU frees CPU cycles for loading the next batch of data, minimizing I/O latency
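
A minimal sketch of the "few readers, scatter to the rest" pattern with mpi4py; the stand-in list below takes the place of reading one shard per rank from disk:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    batches = None
    if rank == 0:
        # The reader rank does the file I/O; a stand-in list replaces real shards here
        batches = [list(range(i * 4, (i + 1) * 4)) for i in range(size)]

    my_batch = comm.scatter(batches, root=0)  # each rank receives only its own shard
    print(f"rank {rank}: {my_batch}")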

⏰ Keeping things in Sync

Computation stalls during communication !!

Keeping the communication to computation ratio small is important for effective scaling.

🤔 Plan of Attack

Figure 6: General strategy for scaling model training: if the model is not yet perfect, check whether memory is still available; if so, make the model larger; if not, free up memory (via TP, ZeRO, or activation checkpointing) and continue.

🚀 Going Beyond Data Parallelism

Going beyond Data Parallelism: DeepSpeed + ZeRO

  • Depending on the ZeRO stage (1, 2, 3), we can partition (shard) across data-parallel workers (a configuration sketch follows Figure 7):
    1. Stage 1: optimizer states ($P_{\mathrm{os}}$)
    2. Stage 2: gradients + optimizer states ($P_{\mathrm{os+g}}$)
    3. Stage 3: model parameters + gradients + optimizer states ($P_{\mathrm{os+g+p}}$)
Figure 7: DeepSpeed + ZeRO
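
A hedged configuration sketch for enabling ZeRO through a DeepSpeed config dict; the exact fields depend on your DeepSpeed version, and the tiny stand-in model is only for illustration:

    import torch
    import deepspeed

    model = torch.nn.Linear(128, 128)  # stand-in model

    ds_config = {
        "train_micro_batch_size_per_gpu": 8,
        "bf16": {"enabled": True},
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "zero_optimization": {
            "stage": 2,  # 1: P_os, 2: P_os+g, 3: P_os+g+p
        },
    }

    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )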

🕸️ Additional Parallelism Strategies

Pipeline Parallelism (PP)

Figure 8: Pipeline parallelism: consecutive layers are split across GPUs (e.g. layers 0–1 on GPU 0 and layers 2–3 on GPU 1)
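
A naive two-GPU sketch of Figure 8 (no micro-batching; real pipeline schedules such as GPipe or 1F1B interleave micro-batches to keep both GPUs busy):

    import torch

    # Layers 0-1 live on GPU 0, layers 2-3 on GPU 1
    stage0 = torch.nn.Sequential(
        torch.nn.Linear(128, 128), torch.nn.ReLU(),
        torch.nn.Linear(128, 128), torch.nn.ReLU(),
    ).to("cuda:0")
    stage1 = torch.nn.Sequential(
        torch.nn.Linear(128, 128), torch.nn.ReLU(),
        torch.nn.Linear(128, 128),
    ).to("cuda:1")

    x = torch.rand(64, 128, device="cuda:0")
    h = stage0(x)                # forward through the first pipeline stage
    y = stage1(h.to("cuda:1"))   # move activations, then run the second stage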

Tensor Parallel (TP)

  • Each tensor is split up into multiple chunks
  • Each shard of the tensor resides on its designated GPU
  • During processing, each shard is processed separately (and in parallel) on its own GPU
    • synced at the end of the step
  • See: 🤗 Model Parallelism for additional details

Figure 9: Tensor parallel training: each of layers 0–3 is split into shards, with one shard of every layer residing on each of GPU0 and GPU1

Tensor Parallel (TP)

  • Suitable when the model is too large to fit onto a single device (CPU / GPU)
  • Typically more complicated to implement than data parallel training
    • This is what one may call horizontal parallelism
    • Communication is required whenever data flows between the shards
  • argonne-lcf/Megatron-DeepSpeed
  • 🤗 huggingface/nanotron

Figure 10: Tensor parallel training: each layer is sharded across GPU0 and GPU1

Tensor (/ Model) Parallel Training: Example

Want to compute $y = \sum_{i} x_{i} W_{i} = x_0 W_0 + x_1 W_1 + x_2 W_2$,
where each GPU holds only its portion of the full weights, as shown below (a code sketch follows Figure 11).

  1. Compute $y_{0} = x_{0} W_{0}$ on GPU0 $\rightarrow$ send to GPU1
  2. Compute $y_{1} = y_{0} + x_{1} W_{1}$ on GPU1 $\rightarrow$ send to GPU2
  3. Compute $y = y_{1} + x_{2} W_{2} = \sum_{i} x_{i} W_{i}$ on GPU2

Figure 11: GPU0, GPU1, and GPU2 hold W₀, W₁, and W₂, respectively; the partial sums x₀W₀ and x₀W₀ + x₁W₁ are passed from one GPU to the next.
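
A sketch of the three-step example above using point-to-point operations in torch.distributed, assuming it is launched with three ranks (e.g. via torchrun or mpiexec), each holding only its own weight shard:

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # or "nccl" on GPUs
    rank, world = dist.get_rank(), dist.get_world_size()

    x_r = torch.rand(1, 16)   # this rank's input slice  x_r
    W_r = torch.rand(16, 16)  # this rank's weight shard W_r
    partial = x_r @ W_r       # local contribution x_r * W_r

    if rank > 0:                          # receive the running sum y_{r-1}
        prev = torch.empty_like(partial)
        dist.recv(prev, src=rank - 1)
        partial += prev
    if rank < world - 1:
        dist.send(partial, dst=rank + 1)  # pass y_r on to the next rank
    else:
        print(f"rank {rank}: y = sum_i x_i W_i, shape {tuple(partial.shape)}")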

🧬 MProt-DPO: Scaling Results

Figure 12: Scaling results for the 3.5B model across ~38,400 GPUs (Dharuman et al. 2024)

🌎 AERIS: Scaling Results

Figure 13: AERIS scaling results

AERIS: Scaling Results

  • 10 EFLOPs (sustained) @ 120,960 GPUs
  • See Hatanpää et al. (2025) for additional details
  • arXiv:2509.13523

🍋 ezpz

Write once, run anywhere

🍋 ezpz @ ALCF

  • Polaris:

    uv venv --python=3.12
    source .venv/bin/activate
    module use /soft/modulefiles
    module load gcc-native cudatoolkit/12.8.1
    uv pip install --no-cache --link-mode=copy torch torchvision torchaudio transformers deepspeed datasets accelerate torchinfo
    CC=mpicc CXX=mpicxx uv pip install --no-cache --link-mode=copy --no-binary=mpi4py mpi4py
    uv run --with "git+https://github.com/saforem2/ezpz@saforem2/tests" --with "numpy<2" ezpz-test

🐣 Getting Started

  1. Submit interactive job:

    qsub -I -l select=2 -l walltime=01:00:00 \
        -l filesystems=home:flare \
        -A gpu_hack \
        -q gpu_hack_prio
  2. Source the ezpz/bin/utils.sh script (using curl to download it):

    source <(curl -L https://bit.ly/ezpz-utils)

🏖️ Shell Environment

  1. Setup environment:

    ezpz_setup_env

🔍 Environment Setup with ezpz_setup_env

  • Wrapper around ezpz_setup_job && ezpz_setup_python
  1. ezpz_setup_job: Determine the specifics of our active (PBS, SLURM) job

  2. ezpz_setup_python:

    • if @ ALCF:
      • Load the appropriate modules and activate base conda env
    • else:
      • Look for an active conda environment
        • If found, use it to build a new virtual environment
    • Activate the newly created venvs/$(basename ${CONDA_PREFIX}) environment

⏱️ Working with Job Scheduler(s)

  • ezpz integrates directly with your favorite job scheduler (PBS, SLURM)
    • has mechanisms for getting information about our currently running jobs
  • 🪄 Automagically:
    • Determine the specifics of our active (PBS, SLURM) job
      (e.g. ${NHOSTS}, ${NGPU_PER_HOST}, ${NGPUS}, …)
    • Load the appropriate modules
    • Create (or activate) a virtual environment on top of a base conda environment

🐍 Python Environments

  • ALWAYS work inside a virtual environment
    • best practice is to maintain separate virtual environments for:
      • each project you work on
      • different versions of a specific package you’re working with
        e.g. you would want different envs for torch==2.X vs. torch==2.Y
    • Mangled python environments are one of the most common issues faced by users

🧪 Simple Distributed Test

  1. Run distributed test:

    ezpz-test
  2. Launch arbitrary Python code with ezpz-launch:

    • Launch a module:

      ezpz-launch -m ezpz.test_dist
    • Launch a python string:

      ezpz-launch -c "'import ezpz; ezpz.setup_torch()'"

➕ How to Modify Existing Code

+ import ezpz
+ _ = ezpz.setup_torch()

- model.to('cuda')
+ model.to(ezpz.get_torch_device_type())

✨ Features

  • Initializing PyTorch across multiple processes

    import ezpz
    _ = ezpz.setup_torch()
    rank = ezpz.get_rank()
    world_size = ezpz.get_world_size()
    local_rank = ezpz.get_local_rank()
  • Automatic device detection (xpu, cuda, mps, cpu, …)

    x = torch.rand((10, 10)).to(ezpz.get_torch_device_type())
  • Automatic (single-process) logging

    logger = ezpz.get_logger(__name__)
  • Distributed debugger:

    try:
        buggy_code()
    except Exception:
        ezpz.breakpoint(0)

🧪 Experiment Tracking

import ezpz
rank = ezpz.setup_torch()
logger = ezpz.get_logger(__name__)
if rank == 0:                   # -- [1.] --
    try:
        _ = ezpz.setup_wandb(
            "ezpz.examples.minimal"
        )
    except Exception:
        logger.exception(
            "Failed to initialize wandb, continuing without it"
        )

# ...build {model, optimizer}, etc...
history = ezpz.History()       # track metrics across training

for i in range(train_iters):
    metrics = train_step(...)
    logger.info(                 # -- [2.] --
        history.update(metrics)  # -- [3.] --
    )

if rank == 0:
    history.finalize()
  1. Initialize W&B (if WANDB_DISABLED is not set)
  2. Log summary of metrics to stdout
  3. Update history.history with metrics

🤏 Minimal Example

import os
import time
import ezpz
import torch

logger = ezpz.get_logger(__name__)


class Network(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        sizes: list[int] | None,
    ):
        super(Network, self).__init__()
        nh = output_dim if sizes is None else sizes[0]
        layers = [torch.nn.Linear(input_dim, nh), torch.nn.ReLU()]
        if sizes is not None and len(sizes) > 1:
            for idx, size in enumerate(sizes[1:]):
                layers.extend(
                    [torch.nn.Linear(sizes[idx], size), torch.nn.ReLU()]
                )
            layers.append(torch.nn.Linear(sizes[-1], output_dim))
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer
) -> ezpz.History:
    unwrapped_model = (
        model.module
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model
    )
    history = ezpz.History()
    device_type = ezpz.get_torch_device_type()
    dtype = unwrapped_model.layers[0].weight.dtype
    bsize = int(os.environ.get("BATCH_SIZE", 64))
    isize = unwrapped_model.layers[0].in_features
    warmup = int(os.environ.get("WARMUP_ITERS", 10))
    log_freq = int(os.environ.get("LOG_FREQ", 1))
    model.train()
    for step in range(int(os.environ.get("TRAIN_ITERS", 500))):
        with torch.autocast(
            device_type=device_type,
            dtype=dtype,
        ):
            t0 = time.perf_counter()
            x = torch.rand((bsize, isize), dtype=dtype).to(device_type)
            y = model(x)
            loss = ((y - x) ** 2).sum()
            dtf = (t1 := time.perf_counter()) - t0
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            dtb = time.perf_counter() - t1
            if step % log_freq == 0 and step > warmup:
                logger.info(
                    history.update(
                        {
                            "iter": step,
                            "loss": loss.item(),
                            "dt": dtf + dtb,
                            "dtf": dtf,
                            "dtb": dtb,
                        }
                    )
                )
    return history


@ezpz.timeitlogit(rank=ezpz.get_rank())
def setup():
    rank = ezpz.setup_torch()
    if os.environ.get("WANDB_DISABLED", False):
        logger.info("WANDB_DISABLED is set, not initializing wandb")
    elif rank == 0:
        try:
            _ = ezpz.setup_wandb(
                project_name=os.environ.get(
                    "PROJECT_NAME", "ezpz.examples.minimal"
                )
            )
        except Exception:
            logger.exception(
                "Failed to initialize wandb, continuing without it"
            )
    device_type = ezpz.get_torch_device_type()
    model = Network(
        input_dim=int((os.environ.get("INPUT_SIZE", 128))),
        output_dim=int(os.environ.get("OUTPUT_SIZE", 128)),
        sizes=[
            int(x)
            for x in os.environ.get("LAYER_SIZES", "1024,512,256,128").split(
                ","
            )
        ],
    )
    model.to(device_type)
    model.to(getattr(torch, os.environ.get("DTYPE", "bfloat16")))
    logger.info(f"{model=}")
    optimizer = torch.optim.Adam(model.parameters())
    if ezpz.get_world_size() > 1:
        from torch.nn.parallel import DistributedDataParallel as DDP

        model = DDP(model, device_ids=[ezpz.get_local_rank()])

    return model, optimizer


def main():
    model, optimizer = setup()
    history = train(model, optimizer)
    if ezpz.get_rank() == 0:
        dataset = history.finalize()
        logger.info(f"{dataset=}")


if __name__ == "__main__":
    main()

🏃‍♂️ Running the Minimal Example

To run the previous example we:

  1. Source the ezpz utils script:

    source <(curl -L https://bit.ly/ezpz-utils)
  2. Setup our environment:

    ezpz_setup_env
  3. Run the example:

    ezpz-launch -m ezpz.examples.minimal

📝 ezpz-test

  • ezpz-test is a simple test script that trains a small model using DDP across all available GPUs

    • It will automatically detect the number of GPUs and launch an appropriate mpiexec command to run the training script across all GPUs
  • See: ezpz/test.py

  • Command:

    ezpz-test

🦜 Generate Text

  • See: ezpz/generate.py

  • Command:

    python3 -m ezpz.generate --model_name meta-llama/Llama-3.1-8B

🤗 Huggingface Trainer

  • See ezpz/hf_trainer.py

  • Command:

    ezpz-launch -m ezpz.hf_trainer \
        --dataset_name=eliplutchok/fineweb-small-sample \
        --streaming \
        --model_name_or_path=meta-llama/Llama-3.2-1B \
        --bf16=true \
        --do_train=true \
        --do_eval=true \
        --report-to=wandb \
        --logging-steps=1 \
        --include-tokens-per-second=true \
        --block-size=128 \
        --max-steps=10 \
        --include-num-input-tokens-seen=true \
        --auto_find_batch_size=true \
        --gradient_checkpointing=true \
        --optim=adamw_torch \
        --overwrite-output-dir=true \
        --logging-first-step \
        --include-for-metrics='inputs,loss' \
        --max-eval-samples=50 \
        --ddp-backend=ccl

🏎️ Megatron-DeepSpeed

git clone https://github.com/argonne-lcf/Megatron-DeepSpeed
cd Megatron-DeepSpeed
source <(curl -L https://bit.ly/ezpz-utils)
python3 -m pip install -e \
    deepspeed \
    "git+https://github.com/saforem2/ezpz"
bash train_alcf.sh

🙌 Acknowledgements

This research used resources of the Argonne Leadership Computing Facility, which is a DOE Office of Science User Facility supported under Contract DE-AC02-06CH11357.

📓 References

Dharuman, Gautham, Kyle Hippe, Alexander Brace, Sam Foreman, Väinö Hatanpää, Varuni K. Sastry, Huihuo Zheng, et al. 2024. “MProt-DPO: Breaking the ExaFLOPS Barrier for Multimodal Protein Design Workflows with Direct Preference Optimization.” In Proceedings of the International Conference for High Performance Computing, Networking, Storage, and Analysis. SC ’24. Atlanta, GA, USA: IEEE Press. https://doi.org/10.1109/SC41406.2024.00013.
Hatanpää, Väinö, Eugene Ku, Jason Stock, Murali Emani, Sam Foreman, Chunyong Jung, Sandeep Madireddy, et al. 2025. “AERIS: Argonne Earth Systems Model for Reliable and Skillful Predictions.” https://arxiv.org/abs/2509.13523.