Skip to content

Commit 825f7df

Browse files
committed
wip
1 parent 1dff7b0 commit 825f7df

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

runner/app/pipelines/embeddings.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# app/pipelines/embeddings.py
21
import logging
32
import os
43
from typing import List, Union, Dict, Any, Optional
@@ -9,6 +8,7 @@
98
from InstructorEmbedding import INSTRUCTOR
109
from dataclasses import dataclass
1110
from enum import Enum
11+
from huggingface_hub import file_download
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -39,6 +39,16 @@ def __init__(self, model_id: str):
3939
logger.info("Initializing embedding pipeline")
4040

4141
self.model_id = model_id
42+
folder_name = file_download.repo_folder_name(
43+
repo_id=model_id, repo_type="model")
44+
base_path = os.path.join(get_model_dir(), folder_name)
45+
46+
# Find the actual model path
47+
self.local_model_path = self._find_model_path(base_path)
48+
49+
if not self.local_model_path:
50+
raise ValueError(f"Could not find model files for {model_id}")
51+
4252
self.device = "cuda" if torch.cuda.is_available() else "cpu"
4353

4454
# Get configuration from environment
@@ -50,16 +60,15 @@ def __init__(self, model_id: str):
5060

5161
try:
5262
if self.model_type == EmbeddingModelType.SENTENCE_TRANSFORMER:
53-
self.model = SentenceTransformer(model_id).to(self.device)
63+
self.model = SentenceTransformer(self.local_model_path).to(self.device)
5464
elif self.model_type == EmbeddingModelType.INSTRUCTOR:
55-
self.model = INSTRUCTOR(model_id).to(self.device)
65+
self.model = INSTRUCTOR(self.local_model_path).to(self.device)
5666

5767
logger.info(f"Model loaded successfully on {self.device}")
5868

5969
except Exception as e:
6070
logger.error(f"Error loading model: {e}")
6171
raise
62-
6372
async def generate(
6473
self,
6574
texts: Union[str, List[str]],
@@ -123,3 +132,13 @@ async def __call__(
123132

124133
def __str__(self):
125134
return f"EmbeddingPipeline(model_id={self.model_id})"
135+
136+
def _find_model_path(self, base_path):
137+
# Check if the model files are directly in the base path
138+
if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
139+
return base_path
140+
141+
# If not, look in subdirectories
142+
for root, dirs, files in os.walk(base_path):
143+
if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
144+
return root

runner/app/routes/embeddings.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
# app/routes/embeddings.py
21
import logging
32
import os
4-
from typing import Union, List
53
from fastapi import APIRouter, Depends, status
64
from fastapi.responses import JSONResponse
75
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

0 commit comments

Comments
 (0)