Ishaan Shah
commited on
Commit
·
0e57bfb
1
Parent(s):
0b341dc
async api
Browse files
main.py
CHANGED
@@ -1,25 +1,65 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
import joblib
|
|
|
|
|
|
|
|
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def show_recommendations(product):
|
5 |
Y = vectorizer.transform([product])
|
6 |
prediction = model.predict(Y)
|
7 |
-
return prediction
|
8 |
|
|
|
9 |
def get_cluster_terms(cluster_index):
|
10 |
cluster_terms = [terms[ind] for ind in order_centroids[cluster_index, :10]]
|
11 |
return cluster_terms
|
12 |
|
13 |
-
|
14 |
-
vectorizer = joblib.load("./vectorizer.pkl")
|
15 |
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
app = FastAPI()
|
20 |
|
21 |
-
@app.post("/inference")
|
22 |
-
def get_recommendations(product: str):
|
23 |
-
cluster_index = int(show_recommendations(product)[0])
|
24 |
-
cluster_terms = get_cluster_terms(cluster_index)
|
25 |
-
return {"cluster": cluster_index, "top_terms": cluster_terms}
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
2 |
import joblib
|
3 |
+
import uuid
|
4 |
+
import asyncio
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from typing import Dict, List
|
7 |
|
8 |
+
# Load your model and vectorizer
|
9 |
+
model = joblib.load("./model.pkl")
|
10 |
+
vectorizer = joblib.load("./vectorizer.pkl")
|
11 |
+
|
12 |
+
order_centroids = model.cluster_centers_.argsort()[:, ::-1]
|
13 |
+
terms = vectorizer.get_feature_names_out()
|
14 |
+
|
15 |
+
# Simulate function to show recommendations
|
16 |
def show_recommendations(product):
|
17 |
Y = vectorizer.transform([product])
|
18 |
prediction = model.predict(Y)
|
19 |
+
return int(prediction[0]) # Ensure the prediction is a native Python int
|
20 |
|
21 |
+
# Get terms associated with a cluster
|
22 |
def get_cluster_terms(cluster_index):
|
23 |
cluster_terms = [terms[ind] for ind in order_centroids[cluster_index, :10]]
|
24 |
return cluster_terms
|
25 |
|
26 |
+
app = FastAPI()
|
|
|
27 |
|
28 |
+
# In-memory store for inference batches
|
29 |
+
inferences: Dict[str, Dict] = {}
|
30 |
+
|
31 |
+
class BatchRequest(BaseModel):
|
32 |
+
products: List[str]
|
33 |
+
|
34 |
+
class BatchResponse(BaseModel):
|
35 |
+
inferenceId: str
|
36 |
+
|
37 |
+
class InferenceResponse(BaseModel):
|
38 |
+
inferenceId: str
|
39 |
+
status: str
|
40 |
+
results: List[Dict]
|
41 |
+
|
42 |
+
def process_batch(inferenceId, products):
|
43 |
+
results = []
|
44 |
+
for product in products:
|
45 |
+
cluster_index = show_recommendations(product)
|
46 |
+
cluster_terms = get_cluster_terms(cluster_index)
|
47 |
+
results.append({"product": product, "cluster": cluster_index, "top_terms": cluster_terms})
|
48 |
+
inferences[inferenceId]["status"] = "completed"
|
49 |
+
inferences[inferenceId]["result"] = results
|
50 |
+
|
51 |
+
@app.post("/inference/batch", response_model=BatchResponse)
|
52 |
+
async def start_get_recommendations_batch(batch_request: BatchRequest, background_tasks: BackgroundTasks):
|
53 |
+
inferenceId = str(uuid.uuid4())
|
54 |
+
inferences[inferenceId] = {"status": "in_progress", "result": []}
|
55 |
+
background_tasks.add_task(process_batch, inferenceId, batch_request.products)
|
56 |
+
return BatchResponse(inferenceId=inferenceId)
|
57 |
+
|
58 |
+
@app.get("/inference/batch/{inferenceId}", response_model=InferenceResponse)
|
59 |
+
async def get_recommendations_batch(inferenceId: str):
|
60 |
+
if inferenceId not in inferences:
|
61 |
+
raise HTTPException(status_code=404, detail="Inference ID not found")
|
62 |
+
inference = inferences[inferenceId]
|
63 |
+
return InferenceResponse(inferenceId=inferenceId, status=inference["status"], results=inference["result"])
|
64 |
|
|
|
65 |
|
|
|
|
|
|
|
|
|
|