Merge pull request #23 from decodingml/feat/inference-pipeline-tweaks
Feat/inference pipeline tweaks
Showing 13 changed files with 199 additions and 28 deletions.
New file: finetuning/logger_utils.py
@@ -0,0 +1,5 @@
import structlog


def get_logger(cls: str):
    return structlog.get_logger().bind(cls=cls)
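For context, a minimal usage sketch of this helper (the bound class name and event fields below are illustrative, not from the diff):

from finetuning.logger_utils import get_logger

# Illustrative: every event from this logger carries cls="InferenceService".
logger = get_logger("InferenceService")
logger.info("model_loaded", model_id="llm_twin")  # structlog emits the event plus bound key-value pairs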
Requirements file:
@@ -10,3 +10,4 @@ bitsandbytes==0.42.0
 pydantic_settings==2.2.1
 scikit-learn==1.4.2
 qwak-sdk==0.5.68
+structlog==24.2.0
New file in the finetuning package (QLoRA model utilities):
@@ -0,0 +1,131 @@
import os
from pathlib import Path
from typing import Optional, Tuple

import torch
from comet_ml import API
from peft import LoraConfig, PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from finetuning import logger_utils
from finetuning.settings import settings


logger = logger_utils.get_logger(__name__)


def build_qlora_model(
    pretrained_model_name_or_path: str,
    peft_pretrained_model_name_or_path: Optional[str] = None,
    bnb_config: Optional[BitsAndBytesConfig] = None,
    lora_config: Optional[LoraConfig] = None,
    cache_dir: Optional[Path] = None,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
    """
    Build a QLoRA LLM based on the given HuggingFace model name:
    1. Create and prepare the bitsandbytes configuration for QLoRA's quantization.
    2. Download, load, and quantize Falcon-7B on the fly.
    3. Create and prepare the LoRA configuration.
    4. Load and configure Falcon-7B's tokenizer.
    """

    if bnb_config is None:
        # Default to 4-bit NF4 quantization with double quantization and bfloat16 compute.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        token=settings.HUGGINGFACE_ACCESS_TOKEN,
        device_map=torch.cuda.current_device(),
        quantization_config=bnb_config,
        use_cache=False,
        torchscript=True,
        cache_dir=str(cache_dir) if cache_dir else None,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        token=settings.HUGGINGFACE_ACCESS_TOKEN,
        cache_dir=str(cache_dir) if cache_dir else None,
    )
    # The model has no dedicated pad token, so reuse EOS and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    if peft_pretrained_model_name_or_path:
        # A registry ID (rather than a local directory) must be downloaded first.
        is_model_name = not os.path.isdir(peft_pretrained_model_name_or_path)
        if is_model_name:
            logger.info(
                f"Downloading {peft_pretrained_model_name_or_path} from CometML Model Registry:"
            )
            peft_pretrained_model_name_or_path = download_from_model_registry(
                model_id=peft_pretrained_model_name_or_path,
                cache_dir=cache_dir,
            )

        logger.info(f"Loading LoRA config from: {peft_pretrained_model_name_or_path}")
        lora_config = LoraConfig.from_pretrained(peft_pretrained_model_name_or_path)
        assert (
            lora_config.base_model_name_or_path == pretrained_model_name_or_path
        ), f"The LoRA adapter was trained on a different base model than the one requested: \
            {lora_config.base_model_name_or_path} != {pretrained_model_name_or_path}"

        logger.info(f"Loading PEFT model from: {peft_pretrained_model_name_or_path}")
        model = PeftModel.from_pretrained(model, peft_pretrained_model_name_or_path)
    else:
        if lora_config is None:
            lora_config = LoraConfig(
                lora_alpha=16,
                lora_dropout=0.1,
                r=64,
                bias="none",
                task_type="CAUSAL_LM",
            )

    return model, tokenizer, lora_config
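A minimal sketch of calling build_qlora_model (Falcon-7B matches the docstring; the registry ID and cache path are hypothetical):

model, tokenizer, lora_config = build_qlora_model(
    pretrained_model_name_or_path="tiiuae/falcon-7b",
    peft_pretrained_model_name_or_path="my-workspace/llm-twin:1.0.0",  # hypothetical "workspace/model_name:version" registry ID
    cache_dir=Path("./model_cache"),  # hypothetical cache location
)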
def download_from_model_registry(
    model_id: str, cache_dir: Optional[Path] = None
) -> Path:
    """
    Download a model from the Comet ML model registry.

    Args:
        model_id (str): The ID of the model to download, in the format "workspace/model_name:version".
        cache_dir (Optional[Path]): The directory to cache the downloaded model in. Defaults to
            `settings.CACHE_DIR`.

    Returns:
        Path: The path to the downloaded model directory.
    """

    if cache_dir is None:
        cache_dir = settings.CACHE_DIR
    output_folder = cache_dir / "models" / model_id

    already_downloaded = output_folder.exists()
    if not already_downloaded:
        # Split "workspace/model_name:version" into its registry coordinates.
        workspace, model_id = model_id.split("/")
        model_name, version = model_id.split(":")

        api = API()
        model = api.get_model(workspace=workspace, model_name=model_name)
        model.download(version=version, output_folder=output_folder, expand=True)
    else:
        logger.info(f"Model {model_id=} already downloaded to: {output_folder}")

    # The expanded archive should contain exactly one directory: the model itself.
    subdirs = [d for d in output_folder.iterdir() if d.is_dir()]
    if len(subdirs) == 1:
        model_dir = subdirs[0]
    else:
        raise RuntimeError(
            f"There should be only one directory inside the model folder. \
            Check the downloaded model at: {output_folder}"
        )

    logger.info(f"Model {model_id=} downloaded from the registry to: {model_dir}")

    return model_dir
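The registry helper can also be used on its own; a minimal sketch with a hypothetical model ID:

model_dir = download_from_model_registry(
    model_id="my-workspace/llm-twin:1.0.0",  # hypothetical "workspace/model_name:version" ID
)
# model_dir points at the single expanded directory under <cache_dir>/models/.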
Makefile:
@@ -1,6 +1,12 @@
 help:
 	@grep -E '^[a-zA-Z0-9 -]+:.*#' Makefile | sort | while read -r l; do printf "\033[1;32m$$(echo $$l | cut -f 1 -d':')\033[00m:$$(echo $$l | cut -f 2- -d'#')\n"; done

-call-inference-pipeline: # Test the inference pipeline.
+deploy-llm-microservice: # Deploy the Qwak model.
+	qwak models deploy realtime --model-id "llm_twin" --instance "gpu.a10.2xl" --timeout 50000 --replicas 2 --server-workers 2
+
+undeploy-llm-microservice: # Undeploy the Qwak model.
+	qwak models undeploy --model-id "llm_twin"
+
+call-inference-pipeline: # Call the inference pipeline.
 	poetry run python main.py
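Assuming the qwak CLI is installed and authenticated, the new targets are invoked as usual:

make deploy-llm-microservice    # deploy llm_twin on two gpu.a10.2xl replicas
make call-inference-pipeline    # run main.py against the deployed model
make undeploy-llm-microservice  # tear the deployment down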