rahim-khan-iitg commited on
Commit
02cb283
·
verified ·
1 Parent(s): fbc3f56

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +168 -0
  2. sentiment.py +14 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
+ from pydantic import BaseModel,EmailStr, field_validator
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import datetime
5
+ import jwt
6
+ import pandas as pd
7
+ from io import StringIO
8
+ import re
9
+ import os
10
+ import psycopg2
11
+ from bcrypt import hashpw,checkpw, gensalt
12
+ from sentiment import predict_sentiment
13
+ from dotenv import load_dotenv,find_dotenv
14
+
15
+ load_dotenv(find_dotenv(raise_error_if_not_found=True))
16
+
17
+ DATABASE_CONFIG = {
18
+ 'dbname': os.getenv("pg_db_name"),
19
+ 'user': os.getenv("pg_user"),
20
+ 'password': os.getenv("pg_password"),
21
+ 'host': os.getenv("pg_host"),
22
+ 'port': os.getenv("pg_port")
23
+ }
24
+
25
+ app = FastAPI()
26
+
27
+ # Add CORS Middleware
28
+ app.add_middleware(
29
+ CORSMiddleware,
30
+ allow_origins=["*"], # Allow requests from any origin (can be restricted)
31
+ allow_credentials=True,
32
+ allow_methods=["*"],
33
+ allow_headers=["*"],
34
+ )
35
+
36
+ # Secret key for JWT
37
+ SECRET_KEY = "your-secret-key"
38
+ ALGORITHM = "HS256" # Hashing algorithm for JWT
39
+
40
+
41
+
42
+ class LoginRequest(BaseModel):
43
+ email: EmailStr
44
+ password: str
45
+
46
+ @field_validator("password")
47
+ def validate_password(cls, value):
48
+ if len(value) < 8:
49
+ raise ValueError("Password must be at least 8 characters long.")
50
+ if not re.search(r"[A-Za-z]", value):
51
+ raise ValueError("Password must contain at least one letter.")
52
+ if not re.search(r"[0-9]", value):
53
+ raise ValueError("Password must contain at least one number.")
54
+ return value
55
+
56
+
57
+ def create_jwt_token(email: str):
58
+ """
59
+ Create a JWT token with an expiration time.
60
+ """
61
+ payload = {
62
+ "sub": email, # Subject (user's email)
63
+ "exp": datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1), # Token expiry
64
+ "iat": datetime.datetime.now(datetime.timezone.utc), # Issued at time
65
+ }
66
+ return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
67
+
68
+
69
+ def verify_jwt_token(token: str):
70
+ """
71
+ Verify and decode the JWT token.
72
+ """
73
+ try:
74
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
75
+ return payload["sub"] # Return the email (or user identifier)
76
+ except jwt.ExpiredSignatureError:
77
+ raise HTTPException(status_code=401, detail="Token has expired")
78
+ except jwt.InvalidTokenError:
79
+ raise HTTPException(status_code=401, detail="Invalid token")
80
+
81
+ def create_user(email:str,password:str)->bool:
82
+ try:
83
+ hashed_password = hashpw(password.encode('utf-8'), gensalt())
84
+ conn = psycopg2.connect(**DATABASE_CONFIG)
85
+ print("Connection successful!")
86
+ cursor = conn.cursor()
87
+ cursor.execute("""
88
+ INSERT INTO hanabi_user (email,password)
89
+ VALUES (%s, %s)
90
+ """, (
91
+ email,
92
+ hashed_password.decode("utf-8")
93
+ ))
94
+ conn.commit()
95
+ cursor.close()
96
+ conn.close()
97
+ return True
98
+ except psycopg2.Error as _:
99
+ print(_)
100
+ return False
101
+
102
+ def validate_user(email: str, password: str) -> bool:
103
+ try:
104
+ # Fetch the stored hashed password from the database
105
+ conn = psycopg2.connect(**DATABASE_CONFIG)
106
+ cursor = conn.cursor()
107
+ query = "SELECT password FROM hanabi_user WHERE email=%s;"
108
+ cursor.execute(query, (email,))
109
+ row = cursor.fetchone() # Fetch one row
110
+ cursor.close()
111
+ conn.close()
112
+ # print(row)
113
+ if row:
114
+ stored_hashed_password = row[0] # The hashed password from the DB
115
+ # Compare the entered password with the stored hashed password
116
+ if checkpw(password.encode('utf-8'), stored_hashed_password.encode('utf-8')):
117
+ return True # Password is correct
118
+ else:
119
+ return False # Password is incorrect
120
+ else:
121
+ return False # User not found
122
+ except Exception as _:
123
+ print(_)
124
+ return False
125
+
126
+ @app.post("/login/")
127
+ def login(data: LoginRequest):
128
+ if validate_user(data.email,data.password):
129
+ # Generate a JWT token
130
+ token = create_jwt_token(data.email)
131
+ return {"token": token}
132
+ else:
133
+ raise HTTPException(status_code=401, detail="Invalid email or password")
134
+
135
+ @app.post("/signup/")
136
+ def signup(data: LoginRequest):
137
+ if create_user(data.email,data.password):
138
+ return {"response":"successful"}
139
+ else:
140
+ raise HTTPException(status_code=401, detail="error")
141
+
142
+
143
+ @app.post("/upload-csv/")
144
+ async def upload_csv(
145
+ file: UploadFile = File(...),
146
+ authorization: str = Header(None), # Get the Authorization header
147
+ ):
148
+ # Verify the token
149
+ if not authorization or not authorization.startswith("Bearer "):
150
+ raise HTTPException(status_code=401, detail="Authorization token missing or invalid")
151
+ token = authorization.split(" ")[1] # Extract the token
152
+ email = verify_jwt_token(token) # Verify token and get the email
153
+
154
+ # Process the file
155
+ content = await file.read()
156
+ df = pd.read_csv(StringIO(content.decode("utf-8")))
157
+ texts = df["text"].tolist()
158
+ sentiments = predict_sentiment(texts)
159
+ df["sentiment"] = sentiments
160
+ sentiment_counts = df["sentiment"].value_counts().to_dict()
161
+ sentiment_records = df.to_dict(orient="index")
162
+
163
+ # Return results
164
+ return {
165
+ "user": email, # Include user info (from token)
166
+ "sentiment_counts": sentiment_counts,
167
+ "sentiments": list(sentiment_records.values()),
168
+ }
sentiment.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+
4
+ model_name = "tabularisai/multilingual-sentiment-analysis"
5
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
7
+
8
+ def predict_sentiment(texts):
9
+ inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=512)
10
+ with torch.no_grad():
11
+ outputs = model(**inputs)
12
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
13
+ sentiment_map = {0: "Very Negative", 1: "Negative", 2: "Neutral", 3: "Positive", 4: "Very Positive"}
14
+ return [sentiment_map[p] for p in torch.argmax(probabilities, dim=-1).tolist()]