System Info
Latest version of PEFT
Who can help?
No response
Tasks
An officially supported task in the examples folder
My own task or dataset (give details below)
Reproduction
model_id="microsoft/resnet-18"@pytest.fixturedefimage_processor():
image_processor=AutoImageProcessor.from_pretrained(model_id)
returnimage_processor@pytest.fixturedefdata(image_processor):
dataset=load_dataset("huggingface/cats-image")
image=dataset["test"]["image"][0]
returnimage_processor(image, return_tensors="pt")
deftest_model_with_batchnorm(tmp_path, data):
torch.manual_seed(0)
model=AutoModelForImageClassification.from_pretrained(model_id)
config=LoraConfig(target_modules=["convolution"], modules_to_save=["classifier"])
model=get_peft_model(model, config)
# record outputs before trainingmodel.eval()
withtorch.inference_mode():
output_before=model(**data)
model.train()
optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size=4max_steps=5*batch_sizelabels=torch.zeros(1, 1000)
labels[0, 283] =1foriinrange(0, max_steps, batch_size):
optimizer.zero_grad()
outputs=model(**data, labels=labels)
loss=outputs.lossloss.backward()
optimizer.step()
model.eval()
withtorch.inference_mode():
output_after=model(**data)
asserttorch.isfinite(output_after.logits).all()
atol, rtol=1e-4, 1e-4# sanity check: model was updatedassertnottorch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
# check saving the model and loading itmodel.save_pretrained(tmp_path)
delmodeltorch.manual_seed(0)
model=AutoModelForImageClassification.from_pretrained(model_id)
model=PeftModel.from_pretrained(model, tmp_path).eval()
withtorch.inference_mode():
output_loaded=model(**data)
# THIS FAILSasserttorch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
Expected behavior
After loading a model that was trained with PEFT on a base model with some kind of batch norm layer, the loaded model should produce the same output. Right now, this does not happen.
The reason is that during training, the buffers for the running mean and variance are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that the base model parameters stay constant during training, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all of its parameters are restored exactly. As a result, the information in the buffers is lost completely.
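To make this concrete, here is a minimal sketch in plain PyTorch (no PEFT involved) showing that the running statistics are buffers rather than parameters, and that they change on a forward pass in train mode even when every parameter is frozen:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
for param in bn.parameters():
    param.requires_grad = False  # the "frozen base model" assumption

before = bn.running_mean.clone()
bn.train()
bn(torch.randn(8, 3, 4, 4))  # one forward pass in train mode

print(torch.allclose(before, bn.running_mean))      # False: the buffer changed
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']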
One possible solution would be to include the buffers in the PEFT adapter, which is not very pretty. For this to work, we would need a way to identify which buffers were updated vs. those that are static. If someone knows a way to achieve this, or has a better idea for how to fix this, please let us know.
Edit: Best suggestion so far by @kashif: Check for the track_running_stats attribute and, if it's True, save the module's buffers. This will not cover all possible corner cases, but hopefully most.
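A rough sketch of what that suggestion could look like, assuming we simply walk the model and collect the buffers of every module whose track_running_stats flag is set (the helper name is made up for illustration):

def collect_running_stats(model):
    # Gather the buffers of all modules that track running statistics,
    # e.g. BatchNorm1d/2d/3d and InstanceNorm with track_running_stats=True.
    buffers = {}
    for module_name, module in model.named_modules():
        if getattr(module, "track_running_stats", False):
            for buffer_name, buffer in module.named_buffers(recurse=False):
                buffers[f"{module_name}.{buffer_name}"] = buffer
    return buffers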
Fixes huggingface#1732
After loading a model that was trained with PEFT on a base model with
some kind of batch norm layer, the loaded model should produce the same
output. Right now, this does not happen.
The reason is that during training, buffers for running mean etc. are
updated, but they are not saved when calling save_pretrained on the
PeftModel instance. Normally in PEFT, we assume that during training,
the base model parameters are kept constant, which is not the case with
batch norm. We only save the PEFT parameters and assume that when the
user loads the base model, all parameters are restored exactly. That
way, the information in the buffers is lost completely.
This PR fixes this issue by saving the buffers of the batch norm layers.
They are identified by checking for the presence of the
track_running_stats attribute.
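For illustration, that identification rule could look like the hypothetical helper below; note that it checks for the mere presence of the attribute, not its value:

import torch.nn as nn

def find_batchnorm_modules(model: nn.Module):
    # The presence of track_running_stats marks batch-norm-like layers
    # (BatchNorm1d/2d/3d, InstanceNorm), regardless of the flag's value.
    return [name for name, module in model.named_modules() if hasattr(module, "track_running_stats")]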
Note: One test for BOFT is currently failing, see the comment in the
test file.
Fixes #1732
After loading a model that was trained with PEFT on a base model with
some kind of batch norm layer, the loaded model should produce the same
output. Right now, this does not happen.
The reason is that during training, buffers for running mean etc. are
updated, but they are not saved when calling save_pretrained on the
PeftModel instance. Normally in PEFT, we assume that during training,
the base model parameters are kept constant, which is not the case with
batch norm. We only save the PEFT parameters and assume that when the
user loads the base model, all parameters are restored exactly. That
way, the information in the buffers is lost completely.
The fix is to add the batch norm layers to modules_to_save. This fix is
now documented and tested.
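Applied to the reproduction above, the fix amounts to listing the batch norm modules in modules_to_save alongside the classifier. In the Transformers ResNet implementation these submodules are named "normalization"; other architectures may use different names:

from peft import LoraConfig

config = LoraConfig(
    target_modules=["convolution"],
    # "normalization" matches the BatchNorm2d layers in microsoft/resnet-18;
    # adjust the name for other architectures.
    modules_to_save=["classifier", "normalization"],
)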