@@ -9,6 +9,7 @@

#!/usr/bin/env python3

+import math
import re
import unittest
from collections.abc import Callable
@@ -123,6 +124,7 @@ def _shampoo_optim_factory(
        | FullyShardShampooConfig
        | HybridShardShampooConfig
        | None,
+        start_preconditioning_step: int = 2,
    ) -> Callable[[ParamsT], torch.optim.Optimizer]:
        return partial(
            DistributedShampoo,
@@ -133,7 +135,7 @@ def _shampoo_optim_factory(
            weight_decay=0.0,
            max_preconditioner_dim=PRECONDITIONER_DIM,
            precondition_frequency=1,
-            start_preconditioning_step=2,
+            start_preconditioning_step=start_preconditioning_step,
            use_decoupled_weight_decay=True,
            grafting_config=AdaGradGraftingConfig(epsilon=1e-8),
            distributed_config=distributed_config,
@@ -232,6 +234,9 @@ def test_hybrid_shampoo_n_by_one_mesh_against_default_shampoo(

    @with_comms
    @skip_if_lt_x_gpu(4)
+    @parametrize(
+        "start_preconditioning_step", (2, math.inf)
+    )  # math.inf here is to test the grafting similarities between HSDP2 and DDP
    @parametrize(
        "communication_dtype, communicate_params",
        (
@@ -247,6 +252,7 @@ def test_hybrid_shampoo_n_by_one_mesh_against_ddp_shampoo(
        num_trainers_per_group: int,
        communication_dtype: torch.dtype,
        communicate_params: bool,
+        start_preconditioning_step: int,
    ) -> None:
        """
        Testing the correctness of hybrid shard Shampoo distributor of (n, 1) mesh
@@ -269,10 +275,12 @@ def test_hybrid_shampoo_n_by_one_mesh_against_ddp_shampoo(
        compare_two_optimizers_models_devices_on_weight_and_loss(
            control_optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
                distributed_config=ddp_config,
+                start_preconditioning_step=start_preconditioning_step,
            ),
            control_model_factory=ShampooHybridShardDistributorTest._construct_model,
            experimental_optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory(
                distributed_config=hybrid_shard_config,
+                start_preconditioning_step=start_preconditioning_step,
            ),
            experimental_model_factory=partial(
                ShampooHybridShardDistributorTest._construct_model,
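
What the new `start_preconditioning_step` parametrization exercises, in a minimal sketch (assuming `distributed_shampoo` is installed and the test class above is importable; the single-process call below is illustrative only, not how the multi-GPU test actually runs): with `start_preconditioning_step=math.inf`, Shampoo never begins preconditioning, so the optimizer applies only the AdaGrad grafting update, which is what makes the HSDP2 and DDP runs directly comparable.

```python
# Illustrative sketch only, not part of the diff. Assumes distributed_shampoo is
# installed and ShampooHybridShardDistributorTest is importable; a plain Linear
# layer stands in for the test model.
import math

import torch

model = torch.nn.Linear(8, 4)

for start_preconditioning_step in (2, math.inf):
    optim_factory = ShampooHybridShardDistributorTest._shampoo_optim_factory(
        distributed_config=None,  # no distributed config in this local sketch
        start_preconditioning_step=start_preconditioning_step,
    )
    # With math.inf the preconditioners are never applied, so the update reduces
    # to the AdaGrad grafting step.
    optimizer = optim_factory(model.parameters())
```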