Skip to content

Support sharding for auto_trainer #8164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

zhangbo9674
Copy link
Contributor

@zhangbo9674 zhangbo9674 commented Mar 21, 2024

PR types

New features

PR changes

Others

Description

为自动并行动静统一组网适配 Sharding 策略:
测试:以 A100-40G 机器,对 Llama2 PP4-VPP2-MP1-Sharding_degree2 模型:分别验证 sharding_stage1、2、3下的精度、性能(缩小了 num_hidden_layers=8)
image
image

Copy link

paddle-bot bot commented Mar 21, 2024

Thanks for your contribution!

Copy link

codecov bot commented Mar 21, 2024

Codecov Report

Attention: Patch coverage is 7.14286% with 13 lines in your changes are missing coverage. Please review.

Project coverage is 55.15%. Comparing base (6b8f7f9) to head (c57e9a0).
Report is 2 commits behind head on develop.

Files Patch % Lines
paddlenlp/trainer/training_args.py 0.00% 7 Missing ⚠️
paddlenlp/trainer/auto_trainer.py 0.00% 6 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8164      +/-   ##
===========================================
- Coverage    55.15%   55.15%   -0.01%     
===========================================
  Files          601      601              
  Lines        91764    91764              
===========================================
- Hits         50614    50611       -3     
- Misses       41150    41153       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:
self.sharding_parallel_degree = self.data_parallel_degree
self.sharding_parallel_degree = world_size // (
self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sep_parallel_degree is not supported now, raise error if it is set in auto_parallel mode

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be fixed in the next pr

zhiqiu
zhiqiu previously approved these changes Mar 29, 2024
Copy link
Collaborator

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

if (
ShardingOption.SHARD_OP in self.args.sharding
and not is_new_version_sharding_stage1_optimizer()
and not self.args.enable_auto_parallel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 is_new_version_sharding_stage1_optimizer 对于当前版本的paddle应该都是True吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个暂时不确定,这里的改动是针对在动静统一组网的情况,也就是开启 enable_auto_parallel 的情况

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你是测试在这里遇到了什么问题吗?

if ShardingOption.OFFLOAD in self.sharding:
warnings.warn("`offload` is not supported NOW!")

strategy = fleet.auto.Strategy()
if self.data_parallel_degree > 1:
if self.dataset_world_size > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? 这里 sharding 和 dp 都当成dp吗?

Copy link
Contributor Author

@zhangbo9674 zhangbo9674 Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本 PR 改动之前,data_parallel_degree = world_size // (tensor_parallel_degree*pipeline_parallel_degree),本 PR 适配后,data_parallel_degree = world_size // (tensor_parallel_degree*pipeline_parallel_degree*sharding_parallel_degree),因此为了和之前的逻辑保持一致,这里改为了 dataset_world_size(data_parallel_degree*sharding_parallel_degree)
此外我理解,开启 sharding 之后,也相当于多了对应的 data_parallel,因此这里也应该使用 dataset_world_size

ZHUI
ZHUI previously approved these changes Mar 29, 2024
if (
ShardingOption.SHARD_OP in self.args.sharding
and not is_new_version_sharding_stage1_optimizer()
and not self.args.enable_auto_parallel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你是测试在这里遇到了什么问题吗?

@zhangbo9674 zhangbo9674 dismissed stale reviews from ZHUI and zhiqiu via c57e9a0 March 29, 2024 08:06
Copy link
Collaborator

@ZHUI ZHUI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@wawltor wawltor merged commit 7b493a8 into PaddlePaddle:develop Apr 1, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants