From 74a79cc0599b047a691c427d16344a824b21e0f3 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 22 Oct 2024 12:35:36 +0800 Subject: [PATCH] fix test --- src/llamafactory/train/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py index e5579d2db1..a5560b49c2 100644 --- a/src/llamafactory/train/test_utils.py +++ b/src/llamafactory/train/test_utils.py @@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k assert set(state_dict_a.keys()) == set(state_dict_b.keys()) for name in state_dict_a.keys(): if any(key in name for key in diff_keys): - assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False else: - assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: