Canstralian committed
Commit ed06dec · verified · 1 Parent(s): 787a425

Update app.py

Files changed (1)
  1. app.py +53 -26
app.py CHANGED
@@ -1,53 +1,80 @@
-import os
-import logging
+import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoAdapterModel, AutoTokenizer
+import logging
+import json
+import os
 
-# Initialize the app
-app = FastAPI()
+# Set up logging configuration
 logging.basicConfig(level=logging.INFO)
 
-# Load model and tokenizer once on startup
-MODEL_NAME = os.getenv("MODEL_NAME", "bert-base-uncased")  # Set default model
-ADAPTER_NAME = os.getenv("ADAPTER_NAME", "Canstralian/RabbitRedux")  # Adapter name
+# Initialize the FastAPI app
+app = FastAPI()
 
-try:
-    logging.info("Loading model and adapter...")
-    model = AutoAdapterModel.from_pretrained(MODEL_NAME)
-    model.load_adapter(ADAPTER_NAME, set_active=True)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    logging.info("Model and adapter loaded successfully.")
-except Exception as e:
-    logging.error("Error loading model or adapter:", exc_info=True)
-    raise RuntimeError("Model or adapter loading failed.") from e
+# Load the trained model (adjust the path to your saved model)
+model = torch.load("path/to/your/model.pth", map_location=torch.device("cpu"))  # Replace with your actual model path
+model.eval()
 
-# Define request and response data structures
+# Define the input and output format for prediction requests
 class PredictionRequest(BaseModel):
+    """
+    Data model for the prediction request.
+
+    Attributes:
+        text (str): Input text for model inference.
+    """
     text: str
 
 class PredictionResponse(BaseModel):
+    """
+    Data model for the prediction response.
+
+    Attributes:
+        text (str): The original input text.
+        prediction (str): The predicted result from the model.
+    """
     text: str
     prediction: str
 
-# Endpoint for inference
+# Define prediction endpoint
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
+    """
+    Endpoint for generating a prediction based on input text.
+
+    Args:
+        request (PredictionRequest): The request body containing the input text.
+
+    Returns:
+        PredictionResponse: The response body containing the original text and prediction.
+
+    Raises:
+        HTTPException: If any error occurs during the prediction process.
+    """
     try:
-        # Tokenize input text
+        # Tokenize the input text (assuming you're using a tokenizer for text inputs)
         inputs = tokenizer(request.text, return_tensors="pt")
-        # Perform inference
+
+        # Perform inference with the model
         outputs = model(**inputs)
-        # Generate predicted text or classification (customize as needed)
+
+        # Get the predicted token and decode it back to text
         prediction = tokenizer.decode(outputs.logits.argmax(-1)[0], skip_special_tokens=True)
-
+
+        # Return the prediction response
         return PredictionResponse(text=request.text, prediction=prediction)
     except Exception as e:
-        logging.error("Error during prediction:", exc_info=True)
+        logging.error("Error during prediction", exc_info=True)
         raise HTTPException(status_code=500, detail="Prediction failed")
 
-# Health check endpoint
+# Define health check endpoint
 @app.get("/health")
 async def health_check():
+    """
+    Health check endpoint to verify if the service is up and running.
+
+    Returns:
+        dict: A dictionary containing the status of the service.
+    """
+    logging.info("Health check requested.")
     return {"status": "healthy"}
-
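
Reviewer note: after this change, the /predict handler still calls tokenizer(...), but the new version of app.py never defines a tokenizer, and torch.load on a fully pickled model only deserializes if the model's class is importable at load time. Below is a minimal sketch of the missing setup, assuming a Hugging Face tokenizer is intended; "bert-base-uncased" is carried over from the removed MODEL_NAME default and the .pth path is the commit's own placeholder, so neither is confirmed by this commit.

import logging

import torch
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

MODEL_PATH = "path/to/your/model.pth"  # placeholder path from this commit
TOKENIZER_NAME = "bert-base-uncased"   # assumption: reuse the old default checkpoint

try:
    logging.info("Loading model and tokenizer...")
    # The commit pickles the whole model object, so its class definition must be
    # importable here; on PyTorch >= 2.6, torch.load also needs weights_only=False
    # to unpickle arbitrary objects.
    model = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    logging.info("Model and tokenizer loaded successfully.")
except Exception as e:
    logging.error("Error loading model or tokenizer", exc_info=True)
    raise RuntimeError("Model or tokenizer loading failed.") from e

Wrapping the outputs = model(**inputs) call in torch.no_grad() would also avoid building gradient buffers during inference.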
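
For completeness, a quick smoke test of both endpoints, assuming the service is importable as app from app.py (the module name is an assumption based on the file name) and that the tokenizer fix sketched above is in place:

# Hypothetical smoke test; FastAPI's TestClient requires the httpx package.
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)

# The health endpoint should answer without touching the model.
assert client.get("/health").json() == {"status": "healthy"}

# /predict echoes the input text alongside the decoded model output.
resp = client.post("/predict", json={"text": "Hello, world!"})
assert resp.status_code == 200
print(resp.json())  # {"text": "Hello, world!", "prediction": "..."}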