Skip to content

Commit 380318d

Browse files
【gpt-oss】change less weight (#2834)
1 parent fc1be25 commit 380318d

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

tests/transformers/gpt_oss/test_fp4_to_bf16.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def setUp(self):
7878
self.tempdir = "./models/gpt-oss"
7979

8080
def fp4_to_bf16(self):
81-
load_path = os.path.join(self.tempdir, "gpt-oss-test-fp4")
82-
save_path = os.path.join(self.tempdir, "gpt-oss-test-new-bf16")
81+
load_path = os.path.join(self.tempdir, "tiny-random-gpt-oss-fp4")
82+
save_path = os.path.join(self.tempdir, "tiny-random-gpt-oss-new-bf16")
8383

8484
safetensor_prefix = "model"
8585
save_index_file = os.path.join(save_path, safetensor_prefix + ".safetensors.index.json")
@@ -103,8 +103,8 @@ def fp4_to_bf16(self):
103103
logger.info(f"Model index file saved in {save_index_file}.")
104104

105105
def bf16_to_fp4(self):
106-
load_path = os.path.join(self.tempdir, "gpt-oss-test-bf16")
107-
save_path = os.path.join(self.tempdir, "gpt-oss-test-new-fp4")
106+
load_path = os.path.join(self.tempdir, "tiny-random-gpt-oss-bf16")
107+
save_path = os.path.join(self.tempdir, "tiny-random-gpt-oss-new-fp4")
108108
safetensor_prefix = "model"
109109
save_index_file = os.path.join(save_path, safetensor_prefix + ".safetensors.index.json")
110110
index = {"metadata": {"total_size": 0}, "weight_map": {}}
@@ -127,7 +127,7 @@ def bf16_to_fp4(self):
127127
logger.info(f"Model index file saved in {save_index_file}.")
128128

129129
def check_weight(self, origin_path, new_path, atol):
130-
origin_file_name = "model-00008-of-00009.safetensors"
130+
origin_file_name = "model.safetensors"
131131
new_file_name = "model-00001-of-00001.safetensors"
132132

133133
origin_dict = load_file(os.path.join(origin_path, origin_file_name))
@@ -151,23 +151,25 @@ def check_weight(self, origin_path, new_path, atol):
151151
@slow
152152
def test_change_weight(self):
153153

154-
repo_id = "PaddleFormers/gpt-oss-test-fp4"
155-
filename = "model-00008-of-00009.safetensors"
156-
aistudio_download(repo_id, filename, None, False, os.path.join(self.tempdir, "gpt-oss-test-fp4/"))
154+
repo_id = "PaddleFormers/tiny-random-gpt-oss-fp4"
155+
filename = "model.safetensors"
156+
aistudio_download(repo_id, filename, None, False, os.path.join(self.tempdir, "tiny-random-gpt-oss-fp4/"))
157157

158-
repo_id = "PaddleFormers/gpt-oss-test-bf16"
159-
filename = "model-00008-of-00009.safetensors"
160-
aistudio_download(repo_id, filename, None, False, os.path.join(self.tempdir, "gpt-oss-test-bf16/"))
158+
repo_id = "PaddleFormers/tiny-random-gpt-oss-bf16"
159+
filename = "model.safetensors"
160+
aistudio_download(repo_id, filename, None, False, os.path.join(self.tempdir, "tiny-random-gpt-oss-bf16/"))
161161

162-
self.fp4_to_bf16()
163162
self.bf16_to_fp4()
163+
self.fp4_to_bf16()
164164

165165
self.check_weight(
166-
os.path.join(self.tempdir, "gpt-oss-test-fp4/"), os.path.join(self.tempdir, "gpt-oss-test-new-fp4/"), 1e-2
166+
os.path.join(self.tempdir, "tiny-random-gpt-oss-fp4/"),
167+
os.path.join(self.tempdir, "tiny-random-gpt-oss-new-fp4/"),
168+
1e-2,
167169
)
168170
self.check_weight(
169-
os.path.join(self.tempdir, "gpt-oss-test-bf16/"),
170-
os.path.join(self.tempdir, "gpt-oss-test-new-bf16/"),
171+
os.path.join(self.tempdir, "tiny-random-gpt-oss-bf16/"),
172+
os.path.join(self.tempdir, "tiny-random-gpt-oss-new-bf16/"),
171173
1e-2,
172174
)
173175

0 commit comments

Comments
 (0)