Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kyriediculous committed Dec 28, 2024
1 parent 7987601 commit be5f890
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
27 changes: 23 additions & 4 deletions runner/app/pipelines/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# app/pipelines/embeddings.py
import logging
import os
from typing import List, Union, Dict, Any, Optional
Expand All @@ -9,6 +8,7 @@
from InstructorEmbedding import INSTRUCTOR
from dataclasses import dataclass
from enum import Enum
from huggingface_hub import file_download

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -39,6 +39,16 @@ def __init__(self, model_id: str):
logger.info("Initializing embedding pipeline")

self.model_id = model_id
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model")
base_path = os.path.join(get_model_dir(), folder_name)

# Find the actual model path
self.local_model_path = self._find_model_path(base_path)

if not self.local_model_path:
raise ValueError(f"Could not find model files for {model_id}")

self.device = "cuda" if torch.cuda.is_available() else "cpu"

# Get configuration from environment
Expand All @@ -50,16 +60,15 @@ def __init__(self, model_id: str):

try:
if self.model_type == EmbeddingModelType.SENTENCE_TRANSFORMER:
self.model = SentenceTransformer(model_id).to(self.device)
self.model = SentenceTransformer(self.local_model_path).to(self.device)
elif self.model_type == EmbeddingModelType.INSTRUCTOR:
self.model = INSTRUCTOR(model_id).to(self.device)
self.model = INSTRUCTOR(self.local_model_path).to(self.device)

logger.info(f"Model loaded successfully on {self.device}")

except Exception as e:
logger.error(f"Error loading model: {e}")
raise

async def generate(
self,
texts: Union[str, List[str]],
Expand Down Expand Up @@ -123,3 +132,13 @@ async def __call__(

def __str__(self):
    """Return a short human-readable label that embeds this pipeline's model id."""
    label = f"EmbeddingPipeline(model_id={self.model_id})"
    return label

def _find_model_path(self, base_path: str) -> Optional[str]:
    """Locate the directory under ``base_path`` that holds model weight files.

    A directory qualifies when it directly contains at least one file whose
    name ends in ``.bin`` or ``.safetensors``. ``base_path`` itself is checked
    first, then its subdirectories.

    Args:
        base_path: Root folder to search.

    Returns:
        The path of the first qualifying directory, or ``None`` when
        ``base_path`` is not an existing directory or no weight files are
        found anywhere beneath it.
    """
    # A missing folder must yield None (not FileNotFoundError from listdir)
    # so the caller can raise its own "could not find model files" error.
    if not base_path or not os.path.isdir(base_path):
        return None

    weight_suffixes = (".bin", ".safetensors")

    # Top-down os.walk yields base_path itself first, so a separate
    # os.listdir pre-check of the root is unnecessary.
    for root, dirs, files in os.walk(base_path):
        # Sort in place so the traversal order (and therefore which directory
        # wins when several contain weights) does not depend on the OS.
        dirs.sort()
        if any(name.endswith(weight_suffixes) for name in files):
            return root

    return None
2 changes: 0 additions & 2 deletions runner/app/routes/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# app/routes/embeddings.py
import logging
import os
from typing import Union, List
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down

0 comments on commit be5f890

Please sign in to comment.