distributed_shampoo/utils — 1 file changed, +3 −3 lines

@@ -140,9 +140,9 @@ def _get_params_or_grads(
     def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:
         """Helper function that gets params or grads from the parameter group.
 
-        NOTE: The purpose of this function is for FullyShardShampooDistributor (supporting
-        Shampoo on per-parameter FSDP, a.k.a. FSDP2 or FullyShard) to override, in order to
-        get the local params/grads from DTensors.
+        NOTE: The purpose of this function is for FullyShardDistributor (supporting Shampoo on
+        per-parameter FSDP, a.k.a. FSDP2 or FullyShard) to override, in order to get the local
+        params/grads from DTensors.
 
         By default, we just return the original params/grads.
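To illustrate the override pattern the docstring describes, here is a minimal, self-contained sketch. It is not the actual distributed_shampoo implementation: `BaseDistributor`, `FullyShardDistributorSketch`, and `FakeDTensor` are hypothetical names, and `FakeDTensor` only mimics the real `torch.distributed.tensor.DTensor.to_local()` method so the example runs without a distributed setup.

```python
from typing import Iterable


class FakeDTensor:
    """Stand-in for torch.distributed.tensor.DTensor: wraps this rank's local shard."""

    def __init__(self, local_value):
        self._local_value = local_value

    def to_local(self):
        # Real DTensors expose .to_local() to return the local shard.
        return self._local_value


class BaseDistributor:
    """Hypothetical base class mirroring the docstring's default behavior."""

    def __init__(self, params, grads):
        self._params = params
        self._grads = grads

    def _get_params_or_grads(self, get_grad: bool = False) -> Iterable:
        # By default, just return the original params/grads.
        return self._grads if get_grad else self._params


class FullyShardDistributorSketch(BaseDistributor):
    """Sketch of an FSDP2-style subclass that unwraps DTensors to local tensors."""

    def _get_params_or_grads(self, get_grad: bool = False) -> Iterable:
        # Override: unwrap each DTensor-like object to its local shard,
        # passing plain tensors (and None grads) through unchanged.
        return (
            t.to_local() if isinstance(t, FakeDTensor) else t
            for t in super()._get_params_or_grads(get_grad)
        )


params = [FakeDTensor([1.0, 2.0]), [3.0]]
dist = FullyShardDistributorSketch(params, grads=[None, None])
print(list(dist._get_params_or_grads()))  # [[1.0, 2.0], [3.0]]
```

The design point, as the diff's docstring explains, is that the base class stays agnostic to sharding: only the FSDP2/FullyShard subclass knows that its tensors are DTensors and need unwrapping.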