Skip to content

Commit b6f1bd6

Browse files
tsunghsienlee authored and facebook-github-bot committed
Add grafting identicalness test between HSDP2 and DDP
Summary: This diff adds a test verifying that the grafting method used in HSDP2 behaves identically to the grafting method used in DDP.
Reviewed By: runame
Differential Revision: D77082136
fbshipit-source-id: 0d03a4c0c7549c7425c19f0634a49595c98e2a29
1 parent a89d205 commit b6f1bd6

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

distributed_shampoo/utils/gpu_tests/shampoo_hybrid_shard_distributor_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#!/usr/bin/env python3
1111

12+
import math
1213
import re
1314
import unittest
1415
from collections.abc import Callable
@@ -123,6 +124,7 @@ def _shampoo_optim_factory(
123124
| FullyShardShampooConfig
124125
| HybridShardShampooConfig
125126
| None,
127+
start_preconditioning_step: int = 2,
126128
) -> Callable[[ParamsT], torch.optim.Optimizer]:
127129
return partial(
128130
DistributedShampoo,
@@ -133,7 +135,7 @@ def _shampoo_optim_factory(
133135
weight_decay=0.0,
134136
max_preconditioner_dim=PRECONDITIONER_DIM,
135137
precondition_frequency=1,
136-
start_preconditioning_step=2,
138+
start_preconditioning_step=start_preconditioning_step,
137139
use_decoupled_weight_decay=True,
138140
grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
139141
distributed_config=distributed_config,
@@ -232,6 +234,9 @@ def test_hybrid_shampoo_n_by_one_mesh_against_default_shampoo(
232234

233235
@with_comms
234236
@skip_if_lt_x_gpu(4)
237+
@parametrize(
238+
"start_preconditioning_step", (2, math.inf)
239+
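) # math.inf is included to check that grafting behaves identically under HSDP2 and DDP (presumably preconditioning never starts at this value, so only grafting is exercised; confirm)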
235240
@parametrize(
236241
"communication_dtype, communicate_params",
237242
(
@@ -247,6 +252,7 @@ def test_hybrid_shampoo_n_by_one_mesh_against_ddp_shampoo(
247252
num_trainers_per_group: int,
248253
communication_dtype: torch.dtype,
249254
communicate_params: bool,
255+
start_preconditioning_step: int,
250256
) -> None:
251257
"""
252258
Testing the correctness of hybrid shard Shampoo distributor of (n, 1) mesh
@@ -269,10 +275,12 @@ def test_hybrid_shampoo_n_by_one_mesh_against_ddp_shampoo(
269275
compare_two_optimizers_models_devices_on_weight_and_loss(
270276
control_optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
271277
distributed_config=ddp_config,
278+
start_preconditioning_step=start_preconditioning_step,
272279
),
273280
control_model_factory=ShampooHybridShardDistributorTest._construct_model,
274281
experimental_optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
275282
distributed_config=hybrid_shard_config,
283+
start_preconditioning_step=start_preconditioning_step,
276284
),
277285
experimental_model_factory=partial(
278286
ShampooHybridShardDistributorTest._construct_model,

0 commit comments

Comments
 (0)