Canstralian committed on
Commit
3e66b84
·
verified ·
1 Parent(s): d84922a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoAdapterModel, AutoTokenizer
6
+
7
# Application object; the route handlers below register themselves on it.
app = FastAPI()
logging.basicConfig(level=logging.INFO)

# Base checkpoint and adapter identifiers. Either can be overridden through
# the environment; the literals are the fallbacks.
MODEL_NAME = os.environ.get("MODEL_NAME", "bert-base-uncased")
ADAPTER_NAME = os.environ.get("ADAPTER_NAME", "Canstralian/RabbitRedux")

# Load the model, its adapter and the tokenizer once at import time so every
# request reuses the same objects. Any failure aborts startup.
# NOTE(review): AutoAdapterModel is imported from `transformers` in this file;
# presumably that relies on an adapter-enabled distribution being installed —
# confirm against the deployment's requirements.
try:
    logging.info("Loading model and adapter...")
    model = AutoAdapterModel.from_pretrained(MODEL_NAME)
    model.load_adapter(ADAPTER_NAME, set_active=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
except Exception as load_error:
    # logging.exception records the message together with the traceback.
    logging.exception("Error loading model or adapter:")
    raise RuntimeError("Model or adapter loading failed.") from load_error
else:
    logging.info("Model and adapter loaded successfully.")
24
+
25
# Request and response schemas for the inference endpoint
class PredictionRequest(BaseModel):
    """Body accepted by ``POST /predict``."""

    # Raw input text that is handed to the tokenizer.
    text: str
28
+
29
class PredictionResponse(BaseModel):
    """Body returned by ``POST /predict``."""

    # The input text, echoed back unchanged.
    text: str
    # String decoded from the model output for that text.
    prediction: str
32
+
33
# Endpoint for inference
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Run the loaded model on ``request.text`` and return the decoded result.

    The text is tokenized, passed through the module-level ``model``, and the
    arg-max of the logits over the last axis is decoded with the tokenizer.

    Args:
        request: Validated request body carrying the input text.

    Returns:
        A ``PredictionResponse`` echoing the input text alongside the decoded
        prediction string.

    Raises:
        HTTPException: Status 500 if tokenization, inference or decoding
            fails; the original exception is chained as the cause.
    """
    # NOTE(review): the model call below is synchronous, so this coroutine
    # holds the event loop for the duration of inference. Consider offloading
    # it if concurrent requests matter.
    try:
        # truncation=True caps the input at the tokenizer's maximum length so
        # over-long text does not fail inside the model.
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True)
        # Gradients are never needed when serving predictions; disabling
        # autograd avoids building a graph on every request.
        with torch.no_grad():
            outputs = model(**inputs)
        # Generate predicted text or classification (customize as needed)
        token_ids = outputs.logits.argmax(-1)[0]
        prediction = tokenizer.decode(token_ids, skip_special_tokens=True)

        return PredictionResponse(text=request.text, prediction=prediction)
    except Exception as e:
        # Log the traceback server-side; the client only gets a generic error.
        logging.exception("Error during prediction:")
        raise HTTPException(status_code=500, detail="Prediction failed") from e
48
+
49
# Liveness probe endpoint
@app.get("/health")
async def health_check():
    """Return a fixed payload indicating the service is responding."""
    payload = {"status": "healthy"}
    return payload
53
+