[Unified Checkpoint] Support peft model #7691
Thanks for your contribution!
Codecov Report

Attention: Patch coverage is …

Coverage Diff:

    @@            Coverage Diff            @@
    ##           develop    #7691    +/-  ##
    ===========================================
    - Coverage    56.56%   56.42%   -0.15%
    ===========================================
      Files          589      589
      Lines        89964    90252    +288
    ===========================================
    + Hits         50889    50921     +32
    - Misses       39075    39331    +256
We need to confirm that the peft model can correctly load the safetensors format in from_pretrained.
    if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config:
        raise ValueError(
            "We do not support skip_save_model_weight in peft model when using unified checkpoint."
        )
Let's put some restrictions on the peft features we support here. For example, things like dynamic scaling of parallelism — there is no need for us to support those for now.

Looking at this spot: how about removing the skip_save_model_weight field directly and issuing a warning instead?
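A minimal sketch of the warning-instead-of-raising approach the reviewer suggests. The function name and the list-of-strings shape of the config are illustrative assumptions, not the actual PaddleNLP API:

```python
import warnings


def sanitize_unified_checkpoint_config(config):
    """Drop options unsupported for peft models, warning instead of raising.

    `config` is assumed to be a list of option strings such as
    ["skip_save_model_weight", ...]; this helper is hypothetical.
    """
    if "skip_save_model_weight" in config:
        warnings.warn(
            "skip_save_model_weight is not supported for peft models with "
            "unified checkpoint; ignoring this option."
        )
        # Filter the unsupported option out instead of failing hard.
        config = [opt for opt in config if opt != "skip_save_model_weight"]
    return config
```

The upside of this design is that existing launch scripts keep working; the downside is that a silently ignored option can surprise users, which is why the original code raises.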
    @@ -2401,6 +2421,8 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                opt_state_dict = tmp

            # broadcast optimizer state in dp group
            if self.args.local_rank != -1:
                dist.barrier()
why?
paddlenlp/trainer/trainer.py
    weights_index_name = PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
    if isinstance(self.model, LoRAModel):
        weights_index_name = (
            PADDLE_LORA_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_LORA_WEIGHTS_INDEX_NAME
        )
As discussed earlier, I suggest saving PEFT weights under a single name.
I don't recommend that anymore: in the original code, lora and ptuning deliberately use distinct model file names, and we should stay consistent with that convention.
> I don't recommend that anymore: in the original code, lora and ptuning deliberately use distinct model file names, and we should stay consistent with that convention.

If we unify the names, how large is the impact, and which cases become incompatible?
Then let's do it this way: keep the original pdparams file names unchanged, and on the unified checkpoint side unify everything under the same name, peft_model.xxx.
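A small illustration of the naming decision agreed on above: legacy pdparams names stay per-model-type, while unified checkpoint files share a single `peft_model` prefix regardless of lora vs. ptuning. The helper function is hypothetical; only the `peft_model.*` file names come from the thread:

```python
# Unified checkpoint names shared by all peft model types (from the thread).
SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json"
PADDLE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.pdparams.index.json"


def peft_weights_index_name(safe_serialization: bool) -> str:
    """Return the unified-checkpoint index file name for any peft model.

    Hypothetical helper: one name for all peft variants, switching only
    on the serialization format.
    """
    return SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME
```

This keeps backward compatibility for old pdparams checkpoints while the new unified format avoids per-variant branching.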
paddlenlp/trainer/trainer.py
    if distributed_isfile(weights_index_file) or distributed_isfile(master_weights_index_file):
        is_unified_checkpoint_type = True
The previous distributed_isfile doesn't support single-card runs, right? Just modify that function directly so it supports single-card as well.
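A sketch of what a single-card-friendly `distributed_isfile` could look like. The `world_size` parameter and the overall signature are assumptions for illustration; the real PaddleNLP version would read the world size from the distributed environment and all-reduce the flag across ranks:

```python
import os


def distributed_isfile(path, world_size=1):
    """Check file existence, degrading gracefully to a single card.

    With world_size == 1 there are no peer ranks to consult, so a plain
    os.path.isfile suffices. In a real multi-rank run the boolean would
    be combined across ranks with a collective (e.g. an all-reduce), so
    every rank agrees on the result even if the file lives on one node.
    """
    found = os.path.isfile(path)
    if world_size == 1:
        return found
    # Multi-rank path (not exercised here): all-reduce `found` across ranks.
    raise NotImplementedError("multi-rank path requires a collective all-reduce")
```

The key point is that the single-card branch short-circuits before any collective call, so the function no longer requires an initialized process group.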
lgtm
    SAFE_PEFT_WEIGHTS_NAME = "peft_model.safetensors"
    SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json"
Suggested change — also add the Paddle-format counterparts:

    SAFE_PEFT_WEIGHTS_NAME = "peft_model.safetensors"
    SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json"
    PADDLE_PEFT_WEIGHTS_NAME = "peft_model.pdparams"
    PADDLE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.pdparams.index.json"
done
LGTM
LGTM
PR types: Function optimization

PR changes: Others

Description: Support peft model save and load in unified checkpoint.