summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- python/fastembed-server.py 12
1 files changed, 7 insertions, 5 deletions
diff --git a/python/fastembed-server.py b/python/fastembed-server.py
index fa3f7c82b..dd4a7a9c8 100644
--- a/python/fastembed-server.py
+++ b/python/fastembed-server.py
@@ -2,18 +2,20 @@ from fastembed import TextEmbedding
from fastapi import FastAPI
from pydantic import BaseModel
-model = TextEmbedding("snowflake/snowflake-arctic-embed-xs")
+models = {}
app = FastAPI()
class EmbeddingRequest(BaseModel):
model: str
- prompt: str
+ input: str
-@app.post("/api/embeddings")
+@app.post("/v1/embeddings")
def embeddings(request: EmbeddingRequest):
- embeddings = next(model.embed(request.prompt)).tolist()
- return {"embedding": embeddings}
+ model = models.get(request.model) or TextEmbedding(request.model)
+ models[request.model] = model
+ embeddings = next(model.embed(request.input)).tolist()
+ return {"data": [{"embedding": embeddings}]}
if __name__ == "__main__":
import uvicorn