imalexianne commited on
Commit
3546eea
1 Parent(s): 0be50b6

Add main.py

Browse files
Files changed (1) hide show
  1. main.py +53 -0
main.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
+
5
+ app = FastAPI()
6
+
7
+ # Load the pre-trained model and tokenizer
8
+ model_name = "imalexianne/Movie_Review_Roberta"
9
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ # tokenizer = AutoTokenizer.from_pretrained("username/model_name")
12
+
13
+ # Create a sentiment analysis pipeline
14
+ sentiment = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
15
+
16
+ # Create a dictionary to map sentiment labels to positive and negative strings
17
+ sentiment_label_mapping = {
18
+ "LABEL_1": "positive",
19
+ "LABEL_0": "negative",
20
+ }
21
+
22
+ # Define a request body model
23
+ class SentimentRequest(BaseModel):
24
+ text: str
25
+
26
+ # Define a response model
27
+ class SentimentResponse(BaseModel):
28
+ sentiment: str # 1 for positive, 0 for negative
29
+ score: float
30
+ @app.get("/")
31
+ def read_root():
32
+ explanation = {
33
+ 'message': "Welcome to the Movie Review Sentiment Prediction App",
34
+ 'description': "This API allows you to predict Movie Review Sentiment based on a given text",
35
+ 'usage': "Submit a POST request to /predict with text to make predictions.",
36
+
37
+ }
38
+ return explanation
39
+ # Create an endpoint for sentiment analysis with query parameter
40
+ @app.get("/sentiment/")
41
+ async def analyze_sentiment(text: str = Query(..., description="Input text for sentiment analysis")):
42
+ result = sentiment(text)
43
+ sentiment_label = result[0]["label"]
44
+ sentiment_score = result[0]["score"]
45
+
46
+ sentiment_value = sentiment_label_mapping.get(sentiment_label, -1) # Default to -1 for unknown labels
47
+
48
+ return SentimentResponse(sentiment=sentiment_value, score=sentiment_score)
49
+
50
+ if __name__ == "__main__":
51
+ import uvicorn
52
+ uvicorn.run(app, host="0.0.0.0", port=8000)
53
+