Cannot restore FSDP checkpoint with LOCAL_STATE_DICT #30811

Open
helloworld1 opened this issue May 14, 2024 · 0 comments

helloworld1 commented May 14, 2024

System Info

  • transformers version: 4.40.1
  • Platform: Linux-5.15.148.2-2.cm2-x86_64-with-glibc2.35
  • Python version: 3.10.2
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2.1+gita8e7c98 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: FSDP

Who can help?

@pacman100 @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I used FSDP with fsdp_state_dict_type set to LOCAL_STATE_DICT.
The Accelerate config is as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: LOCAL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
rdzv_backend: c10d
same_network: true
num_machines: 1
num_processes: 1
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The resulting checkpoint directory has the following structure:

./optimizer_0
./optimizer_0/.metadata
./optimizer_0/__0_0.distcp
./optimizer_0/__1_0.distcp
./optimizer_0/__2_0.distcp
./optimizer_0/__3_0.distcp
./optimizer_0/__4_0.distcp
./optimizer_0/__5_0.distcp
./optimizer_0/__6_0.distcp
./optimizer_0/__7_0.distcp
./pytorch_model_fsdp_rank0.bin
./pytorch_model_fsdp_rank1.bin
./pytorch_model_fsdp_rank2.bin
./pytorch_model_fsdp_rank3.bin
./pytorch_model_fsdp_rank4.bin
./pytorch_model_fsdp_rank5.bin
./pytorch_model_fsdp_rank6.bin
./pytorch_model_fsdp_rank7.bin
./rng_state_0.pth
./rng_state_1.pth
./rng_state_2.pth
./rng_state_3.pth
./rng_state_4.pth
./rng_state_5.pth
./rng_state_6.pth
./rng_state_7.pth
./scheduler.pt
./trainer_state.json
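
The listing contains one model shard, one RNG state, and one optimizer shard for each of ranks 0 to 7, so the checkpoint does not look incomplete on disk. A small sketch to confirm that programmatically (for example, when comparing against another checkpoint) follows. It assumes only the file-name patterns visible in the listing; the helpers list_checkpoint and ranks_by_artifact are hypothetical, not part of transformers or Accelerate.

```python
import os
import re
from collections import defaultdict

# Per-rank file-name patterns, inferred from the checkpoint listing (an assumption, not a documented format).
PATTERNS = {
    "model": re.compile(r"^pytorch_model_fsdp_rank(\d+)\.bin$"),
    "rng": re.compile(r"^rng_state_(\d+)\.pth$"),
    "optimizer": re.compile(r"^optimizer_0/__(\d+)_0\.distcp$"),
}


def list_checkpoint(checkpoint_dir):
    """Hypothetical helper: every file path relative to checkpoint_dir, with '/' separators."""
    paths = []
    for root, _dirs, files in os.walk(checkpoint_dir):
        for name in files:
            rel = os.path.relpath(os.path.join(root, name), checkpoint_dir)
            paths.append(rel.replace(os.sep, "/"))
    return sorted(paths)


def ranks_by_artifact(paths):
    """Hypothetical helper: map each artifact type to the sorted list of ranks that have it."""
    found = defaultdict(set)
    for path in paths:
        for kind, pattern in PATTERNS.items():
            match = pattern.match(path)
            if match:
                found[kind].add(int(match.group(1)))
    return {kind: sorted(found[kind]) for kind in PATTERNS}
```

For example, ranks_by_artifact(list_checkpoint("/home/user/checkpoint-10")) should report ranks 0 through 7 for all three artifact types if the directory matches the listing above.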

When I try to resume training from the checkpoint with

trainer.train(resume_from_checkpoint="/home/user/checkpoint-10")

I get the following error:

training.py 146 <module>
main()

training.py 125 main
train_results = trainer.train(resume_from_checkpoint=checkpoint)

sft_trainer.py 360 train
output = super().train(*args, **kwargs)

trainer.py 1859 train
return inner_training_loop(

trainer.py 2037 _inner_training_loop
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

trainer.py 2431 _load_from_checkpoint
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

ValueError: Can't find a valid checkpoint at /home/user/checkpoint-10
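
The traceback shows the ValueError being raised in Trainer._load_from_checkpoint, which suggests the directory is not recognized as a resumable checkpoint at all. The sketch below illustrates one detection rule that would produce exactly this symptom. It is an assumption, not a copy of the transformers code: it treats a directory as an FSDP checkpoint only if it contains a subdirectory whose name includes pytorch_model_fsdp (assumed SHARDED_STATE_DICT layout, e.g. pytorch_model_fsdp_0/) or a single pytorch_model_fsdp.bin file (assumed FULL_STATE_DICT layout). The per-rank pytorch_model_fsdp_rank{N}.bin files written with LOCAL_STATE_DICT match neither rule. The names looks_like_fsdp_checkpoint and make_layout are hypothetical.

```python
import os
import tempfile

# Assumed model-file prefix for FSDP checkpoints (the per-rank files above start with it).
FSDP_MODEL_NAME = "pytorch_model_fsdp"


def looks_like_fsdp_checkpoint(checkpoint_dir):
    """Hypothetical re-creation of an FSDP checkpoint check; an assumption, not library code.

    Accepts a subdirectory whose name contains FSDP_MODEL_NAME (assumed sharded layout)
    or a single FSDP_MODEL_NAME + ".bin" file (assumed full layout). Per-rank files such as
    pytorch_model_fsdp_rank0.bin satisfy neither rule.
    """
    if not os.path.isdir(checkpoint_dir):
        return False
    has_sharded_dir = any(
        FSDP_MODEL_NAME in name
        for name in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, name))
    )
    has_full_file = os.path.isfile(os.path.join(checkpoint_dir, f"{FSDP_MODEL_NAME}.bin"))
    return has_sharded_dir or has_full_file


def make_layout(root, files=(), dirs=()):
    """Create empty placeholder files and directories under root, for experimentation."""
    os.makedirs(root, exist_ok=True)
    for d in dirs:
        os.makedirs(os.path.join(root, d), exist_ok=True)
    for f in files:
        open(os.path.join(root, f), "wb").close()
    return root


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        local = make_layout(
            os.path.join(tmp, "local"),
            files=[f"{FSDP_MODEL_NAME}_rank{r}.bin" for r in range(8)],
            dirs=["optimizer_0"],
        )
        sharded = make_layout(os.path.join(tmp, "sharded"), dirs=[f"{FSDP_MODEL_NAME}_0", "optimizer_0"])
        print("LOCAL_STATE_DICT layout recognized:", looks_like_fsdp_checkpoint(local))
        print("SHARDED_STATE_DICT layout recognized:", looks_like_fsdp_checkpoint(sharded))
```

Run as a script, this prints False for the per-rank layout and True for the sharded one, consistent with the behavior reported here. Whether transformers 4.40.1 applies exactly these rules should be confirmed against Trainer._load_from_checkpoint in the installed version.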

If I use SHARDED_STATE_DICT instead, resuming works and this error does not occur.

Expected behavior

I expect training to resume from the LOCAL_STATE_DICT checkpoint, as it does with SHARDED_STATE_DICT.
