Torchtune Patch on Aurora
Aurora · pytorch · bug
Patch to get torchtune working on Aurora
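The PyTorch build on Aurora does not expose CPUOffloadPolicy (and, depending on the build, fully_shard) directly from torch.distributed._composable.fsdp, so torchtune/training/_distributed.py fails at import time. The patch below guards the fully_shard import with a try/except fallback and defers the CPUOffloadPolicy import, with its own fallback to _fsdp_api, until CPU offloading is actually requested: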
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index ff959c5f..c3966290 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -14,7 +14,11 @@ import torch
import torch.distributed as dist
from torch import nn
-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+try:
+ from torch.distributed._composable.fsdp import fully_shard
+except (ImportError, ModuleNotFoundError):
+ from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.checkpoint.state_dict import (
@@ -532,6 +536,11 @@ def shard_model(
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
if cpu_offload:
+ try:
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy
+ except (ImportError, ModuleNotFoundError):
+ from torch.distributed._composable.fsdp._fsdp_api import MixedPrecisionPolicy, CPUOffloadPolicy
+ # from torch.distributed._composable import CPUOffloadPolicy
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# Shard the model with FSDP, iterating in reverse to start with
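To apply it, save the diff to a file (e.g. torchtune-aurora.patch; the filename is just an example) and run git apply torchtune-aurora.patch from the root of a torchtune checkout (git apply --check first for a dry run). The snippet below is a minimal diagnostic sketch, mirroring the fallback logic in the patch, that probes which import paths your PyTorch build actually exposes; it is not part of torchtune itself.

# Minimal sketch: probe which FSDP2 import paths this PyTorch build exposes.
# Mirrors the fallbacks in the patch above; purely diagnostic, not part of torchtune.
try:
    from torch.distributed._composable.fsdp import fully_shard  # newer layout
    print("fully_shard: top-level import works")
except (ImportError, ModuleNotFoundError):
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
    print("fully_shard: falling back to the fully_shard submodule")

try:
    from torch.distributed._composable.fsdp import CPUOffloadPolicy  # newer layout
    print("CPUOffloadPolicy: top-level import works")
except (ImportError, ModuleNotFoundError):
    from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy
    print("CPUOffloadPolicy: falling back to _fsdp_api")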
Citation
BibTeX citation:
@online{foreman2025,
author = {Foreman, Sam},
  title = {{Torchtune} {Patch} on {Aurora}},
date = {2025-03-23},
url = {https://samforeman.me/posts/torchtune-patch-aurora/},
langid = {en}
}
For attribution, please cite this work as:
Foreman, Sam. 2025. “Torchtune Patch on Aurora.” March 23, 2025. https://samforeman.me/posts/torchtune-patch-aurora/.