PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: 11.8
GPU: A100 PCIe * 4
transformers: 4.45.2
Information
The official example scripts
My own modified scripts
🐛 Describe the bug
Hi, I am fine-tuning Llama2-7B with FSDP, and I found that there are two forward recomputations in a single backward pass when using FSDP with activation checkpointing. Shouldn't there be only one recomputation? This slows down the model's training throughput.
As you can see in the diagram, there are two flash_attention operations
When I delete model.gradient_checkpointing_enable(), the backward pass behaves normally: train_epoch_time is 248.88s before this change and 185.07s after it.
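To put the two timings in perspective, here is a small arithmetic sketch in Python. The function name checkpointing_overhead is made up for illustration; only the two epoch times come from the measurements above.

```python
def checkpointing_overhead(slow_s: float, fast_s: float) -> dict:
    """Compare two epoch times (seconds): the slow run and the fast run."""
    saved = slow_s - fast_s
    return {
        "saved_s": saved,
        # How much shorter the fast epoch is, relative to the slow epoch.
        "time_reduction_pct": 100.0 * saved / slow_s,
        # How much longer the slow epoch is, relative to the fast epoch.
        "slowdown_pct": 100.0 * saved / fast_s,
    }


# Epoch times reported above: with and without the extra
# model.gradient_checkpointing_enable() call.
stats = checkpointing_overhead(248.88, 185.07)
print(f"saved per epoch: {stats['saved_s']:.2f}s")               # 63.81s
print(f"epoch time reduction: {stats['time_reduction_pct']:.1f}%")  # 25.6%
print(f"slowdown from extra call: {stats['slowdown_pct']:.1f}%")    # 34.5%
```

In other words, the extra recomputation makes each epoch roughly a third longer than the run without it.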
Error logs
Please check the attachment above
Expected behavior
One backward operation corresponds to one recomputation
Hi @mingyuanw-mt, thanks for flagging this. I think they are using reentrant as well, but in any case this wraps the layers with torch.utils.checkpoint twice, which leads to the trace you are seeing. I'll be creating a related PR this week and can disable the second call in it.
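The effect of double wrapping can be illustrated with a toy model in plain Python. This is not torch.utils.checkpoint or the FSDP wrapper: Layer, Checkpoint, and count_forward_calls are hypothetical names, and the sketch assumes each wrapper simply re-runs whatever it wraps before backpropagating through it.

```python
class Layer:
    """Toy layer that only counts how often its forward runs."""

    def __init__(self):
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        return x * 2.0

    def backward(self):
        pass  # gradient math omitted; only forward calls matter here


class Checkpoint:
    """Toy checkpoint wrapper: keeps only the input, recomputes in backward."""

    def __init__(self, inner):
        self.inner = inner
        self.saved_input = None

    def forward(self, x):
        self.saved_input = x  # activations are discarded, input is kept
        return self.inner.forward(x)

    def backward(self):
        self.inner.forward(self.saved_input)  # recomputation
        self.inner.backward()


def count_forward_calls(num_wrappers: int) -> int:
    """Run one forward and one backward through `num_wrappers` nested wrappers."""
    layer = Layer()
    module = layer
    for _ in range(num_wrappers):
        module = Checkpoint(module)
    module.forward(1.0)
    module.backward()
    return layer.forward_calls


for n in (0, 1, 2):
    print(n, "wrapper(s):", count_forward_calls(n), "forward call(s)")
```

Under these assumptions, one wrapper gives two forward calls (the original forward plus one recomputation), while two nested wrappers give three, i.e. two recomputations per backward, which would be consistent with the two flash_attention operations in the trace.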
Here is my tracing file:
fsdp_ac_cuda_profile.zip
It seems that enabling both FSDP's and transformers' activation checkpointing causes this problem:
https://github.com/meta-llama/llama-recipes/blob/799e90eb959aacd13b084fd81b5447152ba11d41/src/llama_recipes/finetuning.py#L220-L221
FSDP's activation checkpointing uses the non-reentrant version of torch.utils.checkpoint, and transformers does the opposite. Please correct me if my understanding is incorrect.