
Commit 520c35c: temp

gc-fu committed Jan 16, 2025
1 parent e42d185
Showing 3 changed files with 156 additions and 11 deletions.
88 changes: 79 additions & 9 deletions python/llm/src/ipex_llm/vllm/xpu/engine/engine.py
@@ -14,7 +14,7 @@
# limitations under the License.
#
from vllm.logger import init_logger
from typing import Dict, Optional
from typing import Dict, Optional, Any, Union, Type
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
@@ -26,6 +26,11 @@
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.multiprocessing.engine import MQLLMEngine
import signal
from vllm.engine.arg_utils import HfOverrides, PoolerConfig, TaskOption
from vllm.config import CompilationConfig
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
from vllm import envs

logger = init_logger(__name__)

@@ -59,34 +64,56 @@ def __init__(
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = True,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
load_in_low_bit: str = "sym_int4",
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''

if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size",
"image_input_shape", "image_input_type")
if any(k in kwargs for k in removed_vision_keys):
raise TypeError( # noqa
"There is no need to pass vision-related arguments anymore.")

if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = None

engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
@@ -99,14 +126,55 @@
enforce_eager=enforce_eager,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
**kwargs,
)
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
# TODO(gc): we will need to override this function
self.engine_class = self.get_engine_class()
self.llm_engine = IPEXLLMLLMEngine.from_engine_args(
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS,
load_in_low_bit=load_in_low_bit)

self.request_counter = Counter()

@staticmethod
def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1:
# V1LLMEngine is already imported at module top as V1LLMEngine.
return IPEXLLMLLMV1Engine  # type: ignore
return IPEXLLMLLMEngine


# TODO(gc): implement this later...
class IPEXLLMLLMV1Engine(V1LLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
enable_multiprocessing: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.

# TODO(gc): delete this later
print("IPEXLLM V1 Engine")
# This does not work as it runs in a separate process...
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers, enable_multiprocessing)


class IPEXLLMLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs):
Expand All @@ -122,6 +190,8 @@ def from_engine_args(
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
# TODO(gc): Delete
print("Use vLLM v0 engine")
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers)

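For context, a minimal usage sketch of the engine-selection path added above. The class name and import path below are assumptions for illustration (the diff shows only the constructor body, not the class or its export name); load_in_low_bit is the IPEX-LLM-specific argument introduced by this commit, and VLLM_USE_V1 is the standard vLLM toggle that get_engine_class() reads via vllm.envs.

import os
os.environ["VLLM_USE_V1"] = "0"  # v0 path -> IPEXLLMLLMEngine; "1" -> IPEXLLMLLMV1Engine

# Hypothetical export name; adjust to however the package exposes this class.
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

llm = LLM(model="facebook/opt-125m",
          load_in_low_bit="sym_int4",  # new argument from this commit
          enforce_eager=True)
outputs = llm.generate("Hello, my name is")
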
24 changes: 24 additions & 0 deletions python/llm/src/ipex_llm/vllm/xpu/ipex_llm_v1_wrapper.py
@@ -0,0 +1,24 @@
from vllm.logger import init_logger
from vllm.v1.executor.ray_utils import RayWorkerWrapper


logger = init_logger(__name__)


class IPEXLLMV1Wrapper(RayWorkerWrapper):
def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
_ipex_llm_convert(load_in_low_bit=load_in_low_bit)
self.compiled_dag_cuda_device_set = False


def get_ipex_llm_v1_wrapper(load_in_low_bit):
# The reason we are not using functools.partial is that
# Ray does not seem to work well with it.
class WrapperWithLoadBit(IPEXLLMV1Wrapper):
def __init__(self, *args, **kwargs) -> None:
super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)

# a = functools.partial(IPEXLLMWrapper, load_in_low_bit=load_in_low_bit)
return WrapperWithLoadBit
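
The comment above records the design choice: Ray needs an actual class it can construct remotely, and functools.partial produces a callable object rather than a class. A standalone sketch of the same closure-subclass pattern, with no Ray or vLLM dependency (all names below are illustrative):

class Base:
    def __init__(self, load_in_low_bit="sym_int4"):
        self.load_in_low_bit = load_in_low_bit

def make_wrapper(load_in_low_bit):
    # Subclass that bakes the chosen low-bit setting into __init__.
    class WrapperWithLoadBit(Base):
        def __init__(self, *args, **kwargs):
            super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)
    return WrapperWithLoadBit

FP8Wrapper = make_wrapper("fp8")
assert FP8Wrapper().load_in_low_bit == "fp8"
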
55 changes: 53 additions & 2 deletions python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -65,10 +65,61 @@ def _model_sample_convert():
def _ipex_llm_convert(load_in_low_bit):
from vllm.worker.xpu_model_runner import XPUModelRunner
from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
import vllm.executor.ray_utils as ray_utils
from ipex_llm.vllm.xpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper
import vllm.executor.ray_utils as ray_utils_v0
import vllm.v1.executor.ray_utils as ray_utils_v1
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
setattr(XPUModelRunner, "load_model", get_load_function(load_in_low_bit))
setattr(ray_utils, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit))
setattr(GPUModelRunner, "load_model", get_load_function(load_in_low_bit))
setattr(ray_utils_v0, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit))
setattr(ray_utils_v1, "RayWorkerWrapper", get_ipex_llm_v1_wrapper(load_in_low_bit))
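
A short sketch of why the setattr patching above takes effect: rebinding a module attribute redirects every call site that resolves the name through the module object at call time, while names bound earlier via from-imports keep the original class. The alias names below match the imports in this diff:

import vllm.executor.ray_utils as ray_utils_v0
from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper

ray_utils_v0.RayWorkerWrapper = get_ipex_llm_wrapper("sym_int4")

# Lookups through the module now yield the patched wrapper:
worker_cls = ray_utils_v0.RayWorkerWrapper
# ...but a `from vllm.executor.ray_utils import RayWorkerWrapper` executed
# before the patch still refers to the original class.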

# def get_load_function_v1(low_bit):
# def _ipex_llm_load_model(self) -> None:
# _model_sample_convert()

# # from vllm.utils import measure_device_memory
# logger.info("Starting to load model %s...", self.model_config.model)
# from vllm.utils import DeviceMemoryProfiler
# with DeviceMemoryProfiler() as m:
# from dataclasses import replace
# new_device_config = DeviceConfig("cpu")
# new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
# self.model = get_model(
# vllm_config=new_vllm_config
# )
# if "qwen" in self.vllm_config.model_config.model.lower() or \
# "baichuan" in self.vllm_config.model_config.model.lower() or \
# "codegeex4-all" in self.vllm_config.model_config.model.lower() or \
# "chatglm" in self.vllm_config.model_config.model.lower():
# self.model.apply(padding_mlp)
# from ipex_llm import optimize_model
# import os
# not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
# if not_convert_last_mlp is not None:
# # only use to avoid nan value in last mlp forward running glm4-9b-chat
# modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
# else:
# modules = None
# if "minicpm" in self.vllm_config.model_config.model.lower():
# modules = ["vpm", "resampler"]
# # only for minicpm_2_6
# if "minicpm-v" in self.vllm_config.model_config.model.lower():
# from ipex_llm.transformers.models.minicpmv import merge_qkv
# self.model.vpm.apply(merge_qkv)
# if "internvl2" in self.vllm_config.model_config.model.lower():
# modules = ["vision_model", "mlp1"]
# optimize_model(self.model, low_bit=low_bit, torch_dtype=self.vllm_config.model_config.dtype,
# modules_to_not_convert=modules)
# self.model = self.model.to(device=self.vllm_config.device_config.device,
# dtype=self.vllm_config.model_config.dtype)

# self.model_memory_usage = m.consumed_memory
# logger = init_logger(__name__)
# logger.info("Loading model weights took %.4f GB",
# self.model_memory_usage / float(2**30))

# return _ipex_llm_load_model

def get_load_function(low_bit):
def _ipex_llm_load_model(self) -> None:
