From 9b80f3d50ccb98ceee94bab4145a36e7e58aa4eb Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 21 Sep 2024 00:11:34 -0700 Subject: [PATCH] fix: device could be in meta, transformers#33154 (#2089) Signed-off-by: Yu Chin Fabian Lim --- trl/trainer/sft_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index c645734781..49bea851e7 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -232,7 +232,7 @@ def __init__( if getattr(model, "is_loaded_in_4bit", False): for _, param in model.named_parameters(): if param.__class__.__name__ == "Params4bit": - is_sharded_qlora = param.data.device.type == "cpu" + is_sharded_qlora = param.data.device.type in {"cpu", "meta"} break if getattr(model, "is_loaded_in_8bit", False) or ( getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora