Skip to content

Commit 35bd72b

Browse files
haoyuzfacebook-github-bot
authored andcommitted
Fix device mesh setup in hybrid Cifar10 example (#192)
Summary: Pull Request resolved: #192 Device mesh shape should be `(replica, WORLD_SIZE // replica)` for consistency across ranks. Reviewed By: tsunghsienlee Differential Revision: D77030402 fbshipit-source-id: bba8555af80864d5333a98c6b06b46bb1ea6c53e
1 parent cedb8d3 commit 35bd72b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

distributed_shampoo/examples/hybrid_shard_cifar10_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def create_model_and_optimizer_and_loss_fn(
199199
# initialize device_mesh for hybrid shard data parallel
200200
device_mesh: DeviceMesh = init_device_mesh(
201201
"cuda",
202-
(args.dp_replicate_degree, WORLD_RANK // args.dp_replicate_degree),
202+
(args.dp_replicate_degree, WORLD_SIZE // args.dp_replicate_degree),
203203
mesh_dim_names=("dp_replicate", "dp_shard"),
204204
)
205205

0 commit comments

Comments
 (0)