kauabarros-24 commited on
Commit
3e98b7b
·
1 Parent(s): 7effd1e

Retest models

Browse files
Files changed (1) hide show
  1. src/recommend/main.py +36 -2
src/recommend/main.py CHANGED
@@ -3,5 +3,39 @@ from src.recommend.routes.ai_routes import router as ai_router
3
  from src.recommend.routes.system import router as system_router
4
  app = FastAPI()
5
 
6
- app.include_router(ai_router)
7
- app.include_router(system_router)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from src.recommend.routes.system import router as system_router
4
  app = FastAPI()
5
 
6
+ from src.recommend.models.ai_request import Request
7
+ import torch
8
+ from transformers import BertForSequenceClassification, BertTokenizer
9
+
10
+ model_name = "src/recommend/kalium_recommend"
11
+ model = BertForSequenceClassification.from_pretrained(model_name)
12
+ tokenizer = BertTokenizer.from_pretrained(model_name)
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
+
16
+
17
+ @app.post("/ai")
18
+ async def ai(request: Request):
19
+ text = (
20
+ f"This ads is focused in: {request.audience}"
21
+ f"This category is: {request.category} "
22
+ f"This area is: {request.area}"
23
+ f"This sub area of ads: {request.sub_area}"
24
+ )
25
+
26
+ inputs = tokenizer.encode_plus(
27
+ text,
28
+ add_special_tokens=True,
29
+ return_tensors="pt",
30
+ padding='max_length',
31
+ truncation=True,
32
+ max_length=255
33
+ )
34
+ input_ids = inputs['input_ids'].to(device)
35
+ attention_mask = inputs['attention_mask'].to(device)
36
+
37
+ with torch.no_grad():
38
+ outputs = model(input_ids, attention_mask=attention_mask)
39
+ logits = outputs.logits
40
+
41
+ return {"similarity": logits.tolist()}