Ishaan Shah commited on
Commit
0e57bfb
·
1 Parent(s): 0b341dc
Files changed (1) hide show
  1. main.py +52 -12
main.py CHANGED
@@ -1,25 +1,65 @@
1
- from fastapi import FastAPI
2
  import joblib
 
 
 
 
3
 
 
 
 
 
 
 
 
 
4
  def show_recommendations(product):
5
  Y = vectorizer.transform([product])
6
  prediction = model.predict(Y)
7
- return prediction,
8
 
 
9
  def get_cluster_terms(cluster_index):
10
  cluster_terms = [terms[ind] for ind in order_centroids[cluster_index, :10]]
11
  return cluster_terms
12
 
13
- model = joblib.load("./model.pkl")
14
- vectorizer = joblib.load("./vectorizer.pkl")
15
 
16
- order_centroids = model.cluster_centers_.argsort()[:, ::-1]
17
- terms = vectorizer.get_feature_names_out()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- app = FastAPI()
20
 
21
- @app.post("/inference")
22
- def get_recommendations(product: str):
23
- cluster_index = int(show_recommendations(product)[0])
24
- cluster_terms = get_cluster_terms(cluster_index)
25
- return {"cluster": cluster_index, "top_terms": cluster_terms}
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
2
  import joblib
3
+ import uuid
4
+ import asyncio
5
+ from pydantic import BaseModel
6
+ from typing import Dict, List
7
 
8
+ # Load your model and vectorizer
9
+ model = joblib.load("./model.pkl")
10
+ vectorizer = joblib.load("./vectorizer.pkl")
11
+
12
+ order_centroids = model.cluster_centers_.argsort()[:, ::-1]
13
+ terms = vectorizer.get_feature_names_out()
14
+
15
+ # Simulate function to show recommendations
16
  def show_recommendations(product):
17
  Y = vectorizer.transform([product])
18
  prediction = model.predict(Y)
19
+ return int(prediction[0]) # Ensure the prediction is a native Python int
20
 
21
+ # Get terms associated with a cluster
22
  def get_cluster_terms(cluster_index):
23
  cluster_terms = [terms[ind] for ind in order_centroids[cluster_index, :10]]
24
  return cluster_terms
25
 
26
+ app = FastAPI()
 
27
 
28
+ # In-memory store for inference batches
29
+ inferences: Dict[str, Dict] = {}
30
+
31
+ class BatchRequest(BaseModel):
32
+ products: List[str]
33
+
34
+ class BatchResponse(BaseModel):
35
+ inferenceId: str
36
+
37
+ class InferenceResponse(BaseModel):
38
+ inferenceId: str
39
+ status: str
40
+ results: List[Dict]
41
+
42
+ def process_batch(inferenceId, products):
43
+ results = []
44
+ for product in products:
45
+ cluster_index = show_recommendations(product)
46
+ cluster_terms = get_cluster_terms(cluster_index)
47
+ results.append({"product": product, "cluster": cluster_index, "top_terms": cluster_terms})
48
+ inferences[inferenceId]["status"] = "completed"
49
+ inferences[inferenceId]["result"] = results
50
+
51
+ @app.post("/inference/batch", response_model=BatchResponse)
52
+ async def start_get_recommendations_batch(batch_request: BatchRequest, background_tasks: BackgroundTasks):
53
+ inferenceId = str(uuid.uuid4())
54
+ inferences[inferenceId] = {"status": "in_progress", "result": []}
55
+ background_tasks.add_task(process_batch, inferenceId, batch_request.products)
56
+ return BatchResponse(inferenceId=inferenceId)
57
+
58
+ @app.get("/inference/batch/{inferenceId}", response_model=InferenceResponse)
59
+ async def get_recommendations_batch(inferenceId: str):
60
+ if inferenceId not in inferences:
61
+ raise HTTPException(status_code=404, detail="Inference ID not found")
62
+ inference = inferences[inferenceId]
63
+ return InferenceResponse(inferenceId=inferenceId, status=inference["status"], results=inference["result"])
64
 
 
65