🚑 Torchtune Patch on Aurora

Aurora · pytorch · bug

Author: Sam Foreman
Published: March 23, 2025
Modified: April 7, 2025

Patch to get torchtune working on Aurora. The PyTorch build on Aurora does not re-export fully_shard or CPUOffloadPolicy from torch.distributed._composable.fsdp, so torchtune's top-level imports fail; the patch below wraps each import in a try/except that falls back to the internal submodule:

diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index ff959c5f..c3966290 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -14,7 +14,11 @@ import torch
 import torch.distributed as dist
 from torch import nn

-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+try:
+    from torch.distributed._composable.fsdp import fully_shard
+except (ImportError, ModuleNotFoundError):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
 from torch.distributed._tensor import distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.checkpoint.state_dict import (
@@ -532,6 +536,10 @@ def shard_model(
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
+        try:
+            from torch.distributed._composable.fsdp import CPUOffloadPolicy
+        except (ImportError, ModuleNotFoundError):
+            from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
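To apply the patch, save the diff above as torchtune-aurora.patch (any filename works) at the root of a torchtune checkout and run git apply torchtune-aurora.patch.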

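As a sanity check, the minimal probe below (not part of the patch, but mirroring its fallback logic) reports which import paths your local PyTorch build provides:

# Probe which module paths expose the FSDP2 APIs on this PyTorch build.
import torch

print(f"PyTorch {torch.__version__}")

try:
    from torch.distributed._composable.fsdp import fully_shard  # noqa: F401
    print("fully_shard: top-level re-export available")
except (ImportError, ModuleNotFoundError):
    from torch.distributed._composable.fsdp.fully_shard import fully_shard  # noqa: F401
    print("fully_shard: using internal submodule (patched code path)")

try:
    from torch.distributed._composable.fsdp import CPUOffloadPolicy  # noqa: F401
    print("CPUOffloadPolicy: top-level re-export available")
except (ImportError, ModuleNotFoundError):
    from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy  # noqa: F401
    print("CPUOffloadPolicy: using internal submodule (patched code path)")

Note that the patch only imports CPUOffloadPolicy inside the if cpu_offload: branch, so builds where even the fallback is missing will not fail unless CPU offloading is actually requested.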
Citation

BibTeX citation:
@online{foreman2025,
  author = {Foreman, Sam},
  title = {🚑 {Torchtune} {Patch} on {Aurora}},
  date = {2025-03-23},
  url = {https://samforeman.me/posts/torchtune-patch-aurora/},
  langid = {en}
}
For attribution, please cite this work as:
Foreman, Sam. 2025. “🚑 Torchtune Patch on Aurora.” March 23, 2025. https://samforeman.me/posts/torchtune-patch-aurora/.