ZealPyae commited on
Commit
46122a7
·
verified ·
1 Parent(s): 433bad5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -3
app.py CHANGED
@@ -1,7 +1,56 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
1
+ # from fastapi import FastAPI
2
+
3
+ # app = FastAPI()
4
+
5
+ # @app.get("/")
6
+ # def greet_json():
7
+ # return {"Hello": "World!"}
8
+
9
+ from fastapi import FastAPI, HTTPException
10
+ from pydantic import BaseModel
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ import torch
13
 
14
  app = FastAPI()
15
 
16
+ # Check if CUDA is available
17
+ if torch.cuda.is_available():
18
+ device = torch.device("cuda:0")
19
+ else:
20
+ device = torch.device("cpu")
21
+
22
+ # Load the tokenizer and model
23
+ tokenizer = AutoTokenizer.from_pretrained("kmack/malicious-url-detection")
24
+ model = AutoModelForSequenceClassification.from_pretrained("kmack/malicious-url-detection")
25
+ model = model.to(device)
26
+
27
+ # Define the request model
28
+ class URLRequest(BaseModel):
29
+ url: str
30
+
31
+ # Prediction function
32
+ def get_prediction(input_text: str) -> dict:
33
+ label2id = model.config.label2id
34
+ inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
35
+ inputs = inputs.to(device)
36
+ outputs = model(**inputs)
37
+ logits = outputs.logits
38
+ sigmoid = torch.nn.Sigmoid()
39
+ probs = sigmoid(logits.squeeze().cpu())
40
+ probs = probs.detach().numpy()
41
+ for i, k in enumerate(label2id.keys()):
42
+ label2id[k] = probs[i]
43
+ label2id = {k: float(v) for k, v in sorted(label2id.items(), key=lambda item: item[1].item(), reverse=True)}
44
+ return label2id
45
+
46
+ # Define the API endpoint for URL prediction
47
+ @app.post("/predict")
48
+ async def predict(url_request: URLRequest):
49
+ url_to_check = url_request.url
50
+ result = get_prediction(url_to_check)
51
+ return {"prediction": result}
52
+
53
+ # Health check endpoint
54
  @app.get("/")
55
+ async def read_root():
56
+ return {"message": "API is up and running"}