Skip to content

Commit

Permalink
[Delivery] Update model delivery script (#2565)
Browse files Browse the repository at this point in the history
Some improvements to the delivery script:

- provide different overrides for different quantizations, e.g. we can change
the prefill chunk size for q0/q3/q4
- rerun only gen config (skipping weight conversion) when only conv_template changes
- do NOT recreate the HF repo when the repo already exists. This preserves
the commit history
- dry-run validation
  • Loading branch information
rickzx committed Jun 11, 2024
1 parent 42f146d commit a231ae1
Showing 1 changed file with 120 additions and 33 deletions.
153 changes: 120 additions & 33 deletions python/mlc_llm/cli/delivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@
T = TypeVar("T", bound="BaseModel")


class OverrideConfigs(BaseModel):
    """
    Optional settings that can be overridden for a model delivery.

    Instances are used as the values of the ``overrides`` / ``default_overrides``
    mappings, which are keyed by quantization name. Every field defaults to ``None``.
    """

    # NOTE(review): ``None`` presumably means "leave this setting at its regular
    # value" -- confirm against the code that consumes these overrides.
    context_window_size: Optional[int] = None
    sliding_window_size: Optional[int] = None
    prefill_chunk_size: Optional[int] = None
    attention_sink_size: Optional[int] = None
    tensor_parallel_shards: Optional[int] = None


class ModelDeliveryTask(BaseModel):
"""
Example:
Expand All @@ -38,21 +50,21 @@ class ModelDeliveryTask(BaseModel):
"model": "HF://microsoft/Phi-3-mini-128k-instruct",
"conv_template": "phi-3",
"quantization": ["q3f16_1"],
"context_window_size": 4096
"overrides": {
"q3f16_1": {
"context_window_size": 512
}
}
}
"""

model_id: str
model: str
conv_template: str
quantization: Optional[Union[List[str], str]] = Field(default_factory=list)
quantization: Union[List[str], str] = Field(default_factory=list)
overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)
destination: Optional[str] = None

context_window_size: Optional[int] = None
sliding_window_size: Optional[int] = None
prefill_chunk_size: Optional[int] = None
attention_sink_size: Optional[int] = None
tensor_parallel_shards: Optional[int] = None
gen_config_only: Optional[bool] = False


class ModelDeliveryList(BaseModel):
Expand All @@ -63,7 +75,8 @@ class ModelDeliveryList(BaseModel):
tasks: List[ModelDeliveryTask]
# For delivered log, the default destination and quantization fields are optional
default_destination: Optional[str] = None
default_quantization: Optional[List[str]] = None
default_quantization: List[str] = Field(default_factory=list)
default_overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)

@classmethod
def from_json(cls: Type[T], json_dict: Dict[str, Any]) -> T:
Expand Down Expand Up @@ -115,10 +128,7 @@ def _run_quantization(
except HfHubHTTPError as error:
if error.response.status_code != 409:
raise
logger.info("[HF] Repo already exists. Recreating...")
api.delete_repo(repo_id=repo)
api.create_repo(repo_id=repo, private=False)
logger.info("[HF] Repo recreated")
logger.info("[HF] Repo already exists. Skipping creation.")
succeeded = True
log_path = Path(output_dir) / "logs.txt"
with log_path.open("a", encoding="utf-8") as log_file:
Expand Down Expand Up @@ -147,21 +157,24 @@ def _run_quantization(

print(" ".join(cmd), file=log_file, flush=True)
subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)
cmd = [
sys.executable,
"-m",
"mlc_llm",
"convert_weight",
str(model_info.model),
"--quantization",
model_info.quantization,
"--output",
output_dir,
]
print(" ".join(cmd), file=log_file, flush=True)
subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)
if not model_info.gen_config_only:
cmd = [
sys.executable,
"-m",
"mlc_llm",
"convert_weight",
str(model_info.model),
"--quantization",
model_info.quantization,
"--output",
output_dir,
]
print(" ".join(cmd), file=log_file, flush=True)
subprocess.run(
cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ
)
logger.info("[MLC] Complete!")
if not (Path(output_dir) / "ndarray-cache.json").exists():
if not (Path(output_dir) / "ndarray-cache.json").exists() and not model_info.gen_config_only:
logger.error(
"[%s] Model %s. Quantization %s. No weights metadata found.",
red("FAILED"),
Expand All @@ -175,7 +188,7 @@ def _run_quantization(
api.upload_folder(
folder_path=output_dir,
repo_id=repo,
commit_message="Initial commit",
ignore_patterns=["logs.txt"],
)
except Exception as exc: # pylint: disable=broad-except
logger.error("[%s] %s. Retrying...", red("FAILED"), exc)
Expand All @@ -198,38 +211,99 @@ def _get_current_log(log: str) -> ModelDeliveryList:
return current_log


def _generate_model_delivery_diff(  # pylint: disable=too-many-locals
    spec: ModelDeliveryList, log: ModelDeliveryList
) -> ModelDeliveryList:
    """Compute the delivery tasks in ``spec`` that the delivered ``log`` does not cover.

    For every task in ``spec`` the requested quantizations are the union of
    ``spec.default_quantization`` and the task's own ``quantization``. A requested
    quantization counts as already delivered when ``log`` holds an entry with the
    same ``model_id`` and quantization whose override for that quantization and
    whose ``conv_template`` both match. If the override matches but the
    ``conv_template`` differs, the quantization is scheduled again with
    ``gen_config_only`` set.

    Parameters
    ----------
    spec : ModelDeliveryList
        The requested deliveries. It is not modified; the result is built from copies.
    log : ModelDeliveryList
        The deliveries made so far. Each entry's ``quantization`` must be a single
        string.

    Returns
    -------
    ModelDeliveryList
        A copy of ``spec`` with empty ``default_quantization`` and
        ``default_overrides`` and with ``tasks`` replaced by one task per missing
        (model, quantization) pair. Each such task carries exactly one quantization
        and only the override that applies to it.
    """
    diff_tasks = []
    default_quantization = spec.default_quantization
    default_overrides = spec.default_overrides

    for task in spec.tasks:
        model_id = task.model_id
        conv_template = task.conv_template
        quantization = task.quantization
        # ``quantization`` may be a bare string or a list of names. Wrap a bare string
        # so that building a set from it below does not split it into characters.
        if isinstance(quantization, str):
            task_quantizations = [quantization]
        else:
            task_quantizations = list(quantization)
        # Task-level overrides take precedence over the spec-wide defaults.
        overrides = {**default_overrides, **task.overrides}

        logger.info("Checking task: %s %s %s %s", model_id, conv_template, quantization, overrides)
        log_tasks = [t for t in log.tasks if t.model_id == model_id]
        delivered_quantizations = set()
        gen_config_only = set()

        for log_task in log_tasks:
            log_quantization = log_task.quantization
            assert isinstance(log_quantization, str)
            log_override = log_task.overrides.get(log_quantization, OverrideConfigs())
            override = overrides.get(log_quantization, OverrideConfigs())
            if log_override == override:
                if log_task.conv_template == conv_template:
                    delivered_quantizations.add(log_quantization)
                else:
                    # Same overrides but a different conversation template.
                    gen_config_only.add(log_quantization)

        all_quantizations = set(default_quantization) | set(task_quantizations)
        quantization_diff = all_quantizations - delivered_quantizations

        if quantization_diff:
            # Sorted so that the order of the generated tasks is deterministic.
            for q in sorted(quantization_diff):
                logger.info("Adding task %s %s %s to the diff.", model_id, conv_template, q)
                task_copy = task.model_copy()
                task_copy.quantization = [q]
                task_copy.overrides = {q: overrides.get(q, OverrideConfigs())}
                task_copy.gen_config_only = task_copy.gen_config_only or q in gen_config_only
                diff_tasks.append(task_copy)
        else:
            logger.info("Task %s %s %s is up-to-date.", model_id, conv_template, quantization)

    # The defaults are already folded into each diff task, so clear them on the result.
    diff_config = spec.model_copy()
    diff_config.default_quantization = []
    diff_config.default_overrides = {}
    diff_config.tasks = diff_tasks

    logger.info("Model delivery diff: %s", diff_config.model_dump_json(indent=4, exclude_none=True))

    return diff_config


def _main( # pylint: disable=too-many-locals, too-many-arguments
username: str,
api: HfApi,
spec: ModelDeliveryList,
log: str,
hf_local_dir: Optional[str],
output: str,
dry_run: bool,
):
delivery_diff = _generate_model_delivery_diff(spec, _get_current_log(log))
if dry_run:
logger.info("Dry run. No actual delivery.")
return

failed_cases: List[Tuple[str, str]] = []
delivered_log = _get_current_log(log)
for task_index, task in enumerate(spec.tasks, 1):
for task_index, task in enumerate(delivery_diff.tasks, 1):
logger.info(
bold("[{task_index}/{total_tasks}] Processing model: ").format(
task_index=task_index,
total_tasks=len(spec.tasks),
total_tasks=len(delivery_diff.tasks),
)
+ green(task.model_id)
)
model = _clone_repo(task.model, hf_local_dir)

quantizations = []

if spec.default_quantization:
quantizations += spec.default_quantization
if delivery_diff.default_quantization:
quantizations += delivery_diff.default_quantization

if task.quantization:
if isinstance(task.quantization, str):
quantizations.append(task.quantization)
else:
quantizations += task.quantization

default_destination = spec.default_destination or "{username}/{model_id}-{quantization}-MLC"
default_destination = (
delivery_diff.default_destination or "{username}/{model_id}-{quantization}-MLC"
)
for quantization in quantizations:
repo = default_destination.format(
username=username,
Expand Down Expand Up @@ -260,12 +334,19 @@ def _main( # pylint: disable=too-many-locals, too-many-arguments
(task.model_id, quantization),
)
else:
delivered_log.tasks = [
task
for task in delivered_log.tasks
if task.model_id != model_info.model_id
or task.quantization != model_info.quantization
]
delivered_log.tasks.append(model_info)
if failed_cases:
logger.info("Total %s %s:", len(failed_cases), red("failures"))
for model_id, quantization in failed_cases:
logger.info(" Model %s. Quantization %s.", model_id, quantization)

delivered_log.tasks.sort(key=lambda task: task.model_id)
logger.info("Writing log to %s", log)
with open(log, "w", encoding="utf-8") as o_f:
json.dump(delivered_log.to_json(), o_f, indent=4)
Expand Down Expand Up @@ -336,6 +417,11 @@ def _get_default_hf_token() -> str:
required=False,
help="Local directory to store the downloaded HuggingFace model",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Dry run without uploading to HuggingFace Hub",
)
parsed = parser.parse_args()
_main(
parsed.username,
Expand All @@ -344,6 +430,7 @@ def _get_default_hf_token() -> str:
api=HfApi(token=parsed.token),
hf_local_dir=parsed.hf_local_dir,
output=parsed.output,
dry_run=parsed.dry_run,
)


Expand Down

0 comments on commit a231ae1

Please sign in to comment.