
Commit d27a8cd

Fix Pipeline Parallel dtype (#11623)
1 parent d020ad6

2 files changed, 5 insertions(+), 3 deletions(-)


python/llm/src/ipex_llm/transformers/model.py

Lines changed: 2 additions & 2 deletions
@@ -374,7 +374,7 @@ def from_pretrained(cls,
                 "Please make sure you've called `init_pipeline_parallel()` "
                 "and world size is the same as `pipeline_parallel_stages`")
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, kwargs["torch_dtype"])
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
@@ -788,7 +788,7 @@ def load_low_bit(cls,
 
         if pipeline_parallel_stages > 1:
             from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
-            model = pipeline_parallel(model, pipeline_parallel_stages)
+            model = pipeline_parallel(model, pipeline_parallel_stages, torch_dtype)
             import types
             # add pipeline_parallel_generate to pretrained model dynamically
             model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
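For context, a minimal usage sketch of the path this file changes, loosely modeled on the pipeline-parallel inference examples; the model id and stage count below are placeholders, not part of this commit:

# Illustrative only: load a model in fp16 across two pipeline-parallel stages.
# With this commit, the torch_dtype passed to from_pretrained is forwarded to
# pipeline_parallel(), so each stage is cast to half precision before it is
# moved to its XPU device.
import torch
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel

init_pipeline_parallel()  # sets up torch.distributed for pipeline parallelism
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    load_in_4bit=True,
    torch_dtype=torch.float16,        # now reaches pipeline_parallel() as kwargs["torch_dtype"]
    pipeline_parallel_stages=2,
)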

python/llm/src/ipex_llm/transformers/pipeline_parallel.py

Lines changed: 3 additions & 1 deletion
@@ -162,7 +162,7 @@ def _check_quantize_kv_cache(model, idx, batch_size):
         os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"
 
 
-def pipeline_parallel(model, pipeline_parallel_stages):
+def pipeline_parallel(model, pipeline_parallel_stages, torch_dtype=torch.float32):
     global num_layers
     if hasattr(model.config, 'num_hidden_layers'):
         num_layers = model.config.num_hidden_layers
@@ -227,6 +227,8 @@ def pipeline_parallel(model, pipeline_parallel_stages):
     model.layer_start = layer_start
     model.layer_end = layer_end
     model.num_layers = num_layers
+    if torch_dtype == torch.float16:
+        model = model.half()
     model = model.to(f'xpu:{local_rank}')
     return model
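A simplified, standalone sketch of the behaviour the second hunk adds (the layer partitioning and distributed setup from the surrounding function are omitted, and the helper name is hypothetical; `local_rank` is assumed to come from the distributed environment as elsewhere in this file):

import torch

def _cast_and_place_stage(model, torch_dtype=torch.float32, local_rank=0):
    # Hypothetical helper mirroring the added lines: cast the already
    # partitioned stage to fp16 when requested, then move it to its XPU device.
    if torch_dtype == torch.float16:
        model = model.half()
    model = model.to(f'xpu:{local_rank}')
    return model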
