ia_back / main.py
Ilyas KHIAT
whatif
060b78c
raw
history blame
4.4 kB
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    File,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field
from typing import Optional, List
from uuid import uuid4
import os
import secrets
from dotenv import load_dotenv
# NOTE: the star imports below may shadow names imported above; keep this order.
from rag import *
from fastapi.responses import StreamingResponse
import json
from prompt import *
from fastapi.middleware.cors import CORSMiddleware
import requests
import pandas as pd
# Load environment variables from a .env file (python-dotenv).
load_dotenv()

# --- Authorization setup ----------------------------------------------------
# Accepted API keys. A single key is configured, read from FASTAPI_API_KEY;
# the entry is None when that variable is unset.
api_keys = [os.getenv("FASTAPI_API_KEY")]
# Extracts the token from an "Authorization: Bearer <token>" header.
# tokenUrl is only advertised in the generated OpenAPI docs.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject requests whose bearer token is not one of the configured API keys.

    Installed as an app-wide dependency unless the DEV environment variable
    is exactly "True".

    Args:
        api_key: Token extracted from the ``Authorization: Bearer <token>``
            header by ``oauth2_scheme``.

    Raises:
        HTTPException: 401 (detail "Forbidden") when the token matches no
            configured key. The response carries ``WWW-Authenticate: Bearer``,
            as required for 401 responses (RFC 7235).
    """
    # Compare UTF-8 bytes: secrets.compare_digest only accepts ASCII-only str,
    # and a client can send arbitrary header text. compare_digest avoids
    # leaking how much of the key matched through response timing.
    # Unset keys (os.environ.get returned None) never match anything.
    presented = api_key.encode("utf-8")
    authorized = any(
        secrets.compare_digest(presented, valid_key.encode("utf-8"))
        for valid_key in api_keys
        if valid_key is not None
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden",
            headers={"WWW-Authenticate": "Bearer"},
        )
# DEV=True leaves the API open; any other value (or no value) protects every
# route with the api_key_auth dependency.
dev_mode = os.getenv("DEV")
app = (
    FastAPI()
    if dev_mode == "True"
    else FastAPI(dependencies=[Depends(api_key_auth)])
)
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic data models (JSON request bodies and artwork records)
class verify_response_model(BaseModel):
    """Request body of /verify_sphinx: a user's reply plus the question it answers."""

    response: str = Field(
        description="The response from the user to the question"
    )
    answers: List[str] = Field(
        description="The possible answers to the question to test if the user read the entire book"
    )
    question: str = Field(
        description="The question asked to the user to test if they read the entire book"
    )
class UserInput(BaseModel):
    """Request body of /generate: a query, a streaming flag and chat history."""

    query: str
    # Non-streamed answer unless the client asks otherwise.
    stream: Optional[bool] = False
    # Prior chat messages. Pydantic copies this mutable default per instance,
    # so the [] literal is not shared between requests.
    messages: Optional[List[dict]] = []
class Artwork(BaseModel):
    """One artwork record, with the same keys that /artworks/{artist_name} returns."""
    name: str
    artist: str
    image_url: str
    # Stored as text; the endpoint stringifies the spreadsheet's Date cell.
    date: str
    # The endpoint fills this from the spreadsheet's "Media" column.
    description: str
class WhatifInput(BaseModel):
    """Request body of the /whatif endpoint: a question and a response."""
    question: str
    response: str
# Module-level cache of published artworks (one dict per spreadsheet row),
# filled by load_data() when this module is imported.
artworks_data = []


def load_data(spreadsheet_path="data.xlsx", sheet_name="Sheet1"):
    """Load the published artworks from a local spreadsheet into ``artworks_data``.

    Missing cells are replaced with ``False``. Only rows whose ``Publication``
    column equals ``True`` are kept. The kept rows are stored as a list of
    dicts (column name -> value) in the module-level ``artworks_data``.

    Args:
        spreadsheet_path: Excel file to read; a relative path is resolved
            against the current working directory. Defaults to "data.xlsx".
        sheet_name: Worksheet to read. Defaults to "Sheet1".

    Raises:
        Whatever ``pd.read_excel`` raises (e.g. FileNotFoundError when the
        file is missing), and KeyError if there is no "Publication" column.
    """
    global artworks_data
    df = pd.read_excel(spreadsheet_path, sheet_name=sheet_name)
    # Empty cells become False, so an empty Publication cell counts as
    # "not published". Note that every other column gets False there too.
    df = df.fillna(False)
    # Element-wise comparison on the Series; `is True` would not work here.
    published = df[df["Publication"] == True]  # noqa: E712
    artworks_data = published.to_dict(orient="records")
    print("Data loaded successfully")


load_data()
# Endpoints
@app.get("/artworks/{artist_name}")
async def get_artworks_by_artist(artist_name: str):
    """Return the published artworks whose artist name contains ``artist_name``.

    Matching is a case-insensitive substring test against the ``Artiste``
    column of ``artworks_data``. Rows whose ``Artiste`` value is not a string
    (e.g. an empty cell that was filled with False) are skipped.

    Args:
        artist_name: Full or partial artist name taken from the URL path.

    Returns:
        A list of dicts with keys ``name``, ``artist``, ``image_url``,
        ``date`` (stringified, so a missing date becomes "False") and
        ``description`` (taken from the ``Media`` column).
        NOTE(review): image_url / description may be False for empty cells —
        confirm clients handle that.

    Raises:
        HTTPException: 404 when no artwork matches.
    """
    needle = artist_name.lower()
    results = [
        {
            "name": artwork["Titre français"],
            "artist": artwork["Artiste"],
            "image_url": artwork["Image_URL"],
            "date": str(artwork["Date"]),  # Ensure date is a string
            "description": artwork["Media"],
        }
        for artwork in artworks_data
        # Empty cells are filled with False when the data is loaded. Calling
        # .lower() on such a row raised AttributeError and made every request
        # fail with a 500, so skip non-string artist values instead.
        if isinstance(artwork["Artiste"], str)
        and needle in artwork["Artiste"].lower()
    ]
    if not results:
        raise HTTPException(status_code=404, detail="Artist not found")
    return results
@app.post("/generate_sphinx")
async def generate_sphinx():
    """Generate a sphinx question together with its possible answers.

    Returns:
        ``{"question": ..., "answers": ...}``, read from the object returned by
        ``generate_sphinx_response()``.

    Raises:
        HTTPException: 500 carrying the error text if generation fails.
    """
    try:
        riddle: sphinx_output = generate_sphinx_response()
        question, answers = riddle.question, riddle.answers
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"question": question, "answers": answers}
@app.post("/verify_sphinx")
async def verify_sphinx(response: verify_response_model):
    """Check a user's reply to a sphinx question.

    Args:
        response: The user's reply, the possible answers and the question.

    Returns:
        ``{"score": ...}`` holding the result of ``verify_response``.

    Raises:
        HTTPException: 500 carrying the error text if verification fails.
    """
    try:
        verdict: bool = verify_response(
            response.response, response.answers, response.question
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"score": verdict}
@app.post("/generate")
async def generate(user_input: UserInput):
    """Answer ``user_input.query`` using the chat history in ``user_input.messages``.

    When ``stream`` is truthy, the result of ``generate_stream(..., stream=True)``
    is wrapped in a StreamingResponse with media type ``application/json``.
    Otherwise the result of ``generate_stream(..., stream=False)`` is returned
    as-is. An exception raised here is returned as ``{"message": <error text>}``
    rather than as an HTTP error status.
    """
    try:
        wants_stream = user_input.stream
        query, history = user_input.query, user_input.messages
        print(wants_stream, query)
        if not wants_stream:
            return generate_stream(query, history, stream=False)
        # NOTE(review): if generate_stream returns a lazy generator, errors
        # raised while the client consumes the stream are not caught below.
        chunks = generate_stream(query, history, stream=True)
        return StreamingResponse(chunks, media_type="application/json")
    except Exception as exc:
        return {"message": str(exc)}
@app.post("/whatif")
async def generate_whatif(whatif_input: WhatifInput):
    """Produce a "what if" result from a question and a response.

    Args:
        whatif_input: Request body carrying ``question`` and ``response``.

    Returns:
        Whatever ``generate_whatif_stream(question, response)`` returns, or
        ``{"message": <error text>}`` if it raises (the same convention as
        /generate).
    """
    try:
        print(whatif_input)
        # WhatifInput is a pydantic model, not a dict. The previous
        # whatif_input["question"] raised TypeError on every request, so read
        # the fields as attributes instead.
        return generate_whatif_stream(whatif_input.question, whatif_input.response)
    except Exception as e:
        return {"message": str(e)}