-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbge-m3-api.py
28 lines (22 loc) · 894 Bytes
/
bge-m3-api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import asyncio
import logging
import threading

import numpy as np
from fastapi import FastAPI, HTTPException
from FlagEmbedding import BGEM3FlagModel
from pydantic import BaseModel
# The ASGI application that serves the embeddings route.
app = FastAPI()

# Load the BGE-M3 model once, at import time; request handlers then reuse this
# single module-level instance instead of loading it per call.
# NOTE(review): use_fp16=True presumably runs the model in half precision —
# confirm against the FlagEmbedding documentation.
model = BGEM3FlagModel(
    'BAAI/bge-m3',
    use_fp16=True,
)
class EmbeddingRequest(BaseModel):
    """JSON request body accepted by ``POST /v1/embeddings``."""

    # The texts to embed, supplied as a JSON array; items are not type-checked
    # by this model (the annotation is a bare ``list``).
    input: list
class EmbeddingResponse(BaseModel):
    """JSON response body returned by ``POST /v1/embeddings``."""

    # Embedding objects, one per input text; only this field is serialized
    # because the route declares this class as its response_model.
    data: list
# Settings passed to model.encode on every request.
_BATCH_SIZE = 12
_MAX_LENGTH = 8192

# All requests share the single module-level `model`. Encoding now runs in
# worker threads, so this lock keeps model calls one-at-a-time; the shared
# instance is presumably not safe for concurrent use — TODO confirm before
# removing the lock.
_encode_lock = threading.Lock()

logger = logging.getLogger(__name__)


def _encode_dense(texts):
    """Return the dense embedding vectors for *texts*, one per input string."""
    with _encode_lock:
        return model.encode(texts, batch_size=_BATCH_SIZE, max_length=_MAX_LENGTH)['dense_vecs']


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Embed every text in ``request.input`` and return OpenAI-style embedding objects.

    Returns:
        ``{"data": [...]}`` where each item is
        ``{"object": "embedding", "embedding": [...], "index": i}`` and ``i`` is the
        position of the corresponding text in ``request.input``. An empty input
        list yields ``{"data": []}``.

    Raises:
        HTTPException: 422 if any input item is not a string; 500 (with the
            error message as detail) if encoding or serialization fails.
    """
    texts = request.input
    # Nothing to embed: answer directly rather than handing an empty batch to the model.
    if not texts:
        return {"data": []}
    # Non-string items cannot be embedded as text. Reject them as a client error
    # here, before the try block below would turn the failure into a 500.
    if not all(isinstance(text, str) for text in texts):
        raise HTTPException(status_code=422, detail="Every item in 'input' must be a string.")
    try:
        # model.encode is synchronous; run it in a worker thread so the event
        # loop keeps serving other requests while inference is in progress.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(None, _encode_dense, texts)
        response_data = [
            {"object": "embedding", "embedding": emb.tolist(), "index": idx}
            for idx, emb in enumerate(embeddings)
        ]
    except Exception as e:
        # Top-level boundary: log the full traceback server-side, then report
        # the error to the client while keeping the original cause chained.
        logger.exception("Embedding request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"data": response_data}
if __name__ == "__main__":
    # uvicorn is imported only on this script path, so importing the module
    # (e.g. from an external ASGI server) does not require it here.
    import uvicorn as _uvicorn

    # Serve on all interfaces, port 7005.
    _uvicorn.run(app, port=7005, host="0.0.0.0")