1
- # app/pipelines/embeddings.py
2
1
import logging
3
2
import os
4
3
from typing import List , Union , Dict , Any , Optional
9
8
from InstructorEmbedding import INSTRUCTOR
10
9
from dataclasses import dataclass
11
10
from enum import Enum
11
+ from huggingface_hub import file_download
12
12
13
13
logger = logging .getLogger (__name__ )
14
14
@@ -39,6 +39,16 @@ def __init__(self, model_id: str):
39
39
logger .info ("Initializing embedding pipeline" )
40
40
41
41
self .model_id = model_id
42
+ folder_name = file_download .repo_folder_name (
43
+ repo_id = model_id , repo_type = "model" )
44
+ base_path = os .path .join (get_model_dir (), folder_name )
45
+
46
+ # Find the actual model path
47
+ self .local_model_path = self ._find_model_path (base_path )
48
+
49
+ if not self .local_model_path :
50
+ raise ValueError (f"Could not find model files for { model_id } " )
51
+
42
52
self .device = "cuda" if torch .cuda .is_available () else "cpu"
43
53
44
54
# Get configuration from environment
@@ -50,16 +60,15 @@ def __init__(self, model_id: str):
50
60
51
61
try :
52
62
if self .model_type == EmbeddingModelType .SENTENCE_TRANSFORMER :
53
- self .model = SentenceTransformer (model_id ).to (self .device )
63
+ self .model = SentenceTransformer (self . local_model_path ).to (self .device )
54
64
elif self .model_type == EmbeddingModelType .INSTRUCTOR :
55
- self .model = INSTRUCTOR (model_id ).to (self .device )
65
+ self .model = INSTRUCTOR (self . local_model_path ).to (self .device )
56
66
57
67
logger .info (f"Model loaded successfully on { self .device } " )
58
68
59
69
except Exception as e :
60
70
logger .error (f"Error loading model: { e } " )
61
71
raise
62
-
63
72
async def generate (
64
73
self ,
65
74
texts : Union [str , List [str ]],
@@ -123,3 +132,13 @@ async def __call__(
123
132
124
133
def __str__(self):
    """Return a short human-readable label that identifies the pipeline by its model id."""
    model_id = self.model_id
    return f"EmbeddingPipeline(model_id={model_id})"
135
+
136
def _find_model_path(self, base_path):
    """Locate the directory under ``base_path`` that contains model weight files.

    A directory qualifies when it directly holds at least one file whose name
    ends in ``.bin`` or ``.safetensors``. ``base_path`` itself is checked
    first, then its subdirectories are searched top-down.

    Args:
        base_path: Root folder expected to hold the downloaded model.

    Returns:
        The path of the first qualifying directory, or ``None`` when
        ``base_path`` is not an existing directory or no weight file is found
        anywhere beneath it.
    """
    weight_suffixes = (".bin", ".safetensors")

    # Return None instead of letting the filesystem call raise: the caller
    # treats a falsy result as "model not found" and raises its own ValueError.
    if not os.path.isdir(base_path):
        return None

    # os.walk yields base_path itself first, so the "files directly in the
    # base path" case and the "look in subdirectories" case share one loop.
    for root, dirs, files in os.walk(base_path):
        # Sort in place so os.walk descends in a stable order; otherwise the
        # directory returned when several match depends on filesystem order.
        # NOTE(review): with several matching subdirectories (e.g. multiple
        # snapshots) this picks the lexicographically first, not the newest.
        dirs.sort()
        if any(name.endswith(weight_suffixes) for name in files):
            return root

    return None
0 commit comments