# 🚑 Torchtune Patch on Aurora
Sam Foreman
2025-03-23

- [Patch to get `torchtune` working on
  Aurora](#patch-to-get-torchtune-working-on-aurora)

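## Patch to get `torchtune` working on Aurora

The patch below touches only `torchtune/training/_distributed.py`. It makes the
FSDP imports tolerant of PyTorch builds in which `fully_shard` and
`CPUOffloadPolicy` cannot be imported directly from
`torch.distributed._composable.fsdp`:

- `fully_shard` is imported inside a `try`/`except`, falling back to
  `torch.distributed._composable.fsdp.fully_shard`.
- `CPUOffloadPolicy` is no longer imported at module level; it is imported
  inside `shard_model`, only when `cpu_offload` is set, falling back to
  `torch.distributed._composable.fsdp._fsdp_api`.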

``` diff
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index ff959c5f..c3966290 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -14,7 +14,11 @@ import torch
 import torch.distributed as dist
 from torch import nn

-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+try:
+    from torch.distributed._composable.fsdp import fully_shard
+except (ImportError, ModuleNotFoundError):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
 from torch.distributed._tensor import distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.checkpoint.state_dict import (
@@ -532,6 +536,11 @@ def shard_model(
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
+        try:
+            from torch.distributed._composable.fsdp import CPUOffloadPolicy
+        except (ImportError, ModuleNotFoundError):
+            from torch.distributed._composable.fsdp._fsdp_api import MixedPrecisionPolicy, CPUOffloadPolicy
+            # from torch.distributed._composable import CPUOffloadPolicy
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
```
