Spaces:
Runtime error
Runtime error
import ast
import json
import os
import shutil
import subprocess
from functools import partial

import gdown
import gradio as gr
import nltk
import pandas as pd

from config import file_id_htl_biotech, file_id_kamera_express, file_id_smart_sd, file_id_sunday_naturals
from evaluator import eval_answer
from hogwarts import get_answer
from hogwats_gemini import get_answer as get_answer_gemini
nltk.download('punkt')

# Install poppler-utils at startup (the original ran apt on every boot).
# NOTE(review): presumably required by PDF handling in an imported module — confirm.
# poppler-utils ships the `pdftoppm` binary, so skip the install when it is
# already on PATH. Commands run as argument lists, without a shell.
if shutil.which("pdftoppm") is None:
    for apt_command in (
        ["apt-get", "update"],
        ["apt-get", "install", "-y", "poppler-utils"],
    ):
        try:
            # check=False keeps the original best-effort behaviour: a failed
            # install (e.g. no root) must not stop the app from starting.
            subprocess.run(apt_command, check=False)
        except OSError as err:  # e.g. apt-get not present on this image
            print(f"Could not run {' '.join(apt_command)}: {err}")
            break
# (Google Drive file id, company name) for every company offered in the app.
_COMPANIES = (
    (file_id_htl_biotech, "htl-biotechnology"),
    (file_id_smart_sd, "smart-sd"),
    (file_id_kamera_express, "kamera-express"),
    (file_id_sunday_naturals, "sunday-naturals"),
)

# Company registry keyed by file id; "data" starts as None and is filled
# with the company's passages DataFrame once its CSV has been downloaded.
dico = {file_id: {"name": label, "data": None} for file_id, label in _COMPANIES}

# Options for the company dropdown, in registration order.
choices = [label for _, label in _COMPANIES]

title = "AI4PE - Olivier and Adam \n contact: adamrida.ra@gmail.com or sp.olivier@hotmail.com"
# Download each company's passages CSV from Google Drive and attach it to dico.
for file_id, entry in dico.items():
    company = entry["name"]
    print("GOING FOR ", company)
    download_url = f'https://drive.google.com/uc?id={file_id}'
    # One file per company. With a single shared file name, a failed download
    # would silently re-read the previous company's CSV.
    output = f'downloaded_file_{company}.csv'
    # Older gdown versions return None on failure instead of raising.
    if gdown.download(download_url, output, quiet=False) is None:
        raise RuntimeError(f"Could not download data for {company} (file id {file_id})")
    # NOTE(review): DataFrame.replace with plain strings only replaces cells whose
    # *entire* value equals "transcript_", so this rename is effectively a no-op on
    # passage text. It is kept as-is because get_list_files parses "transcript"
    # markers in the content. Pass regex=True if a substring rename is intended.
    df = pd.read_csv(output, sep=";")[["content", "embeddings"]].replace("transcript_", "expert_meeting_notes_")
    entry["data"] = df
# Bidirectional lookups between Google Drive file ids and company names,
# derived from the registry so the two can never drift apart.
id_to_name_mapper = {file_id: entry["name"] for file_id, entry in dico.items()}
name_to_id_mapper = {entry["name"]: file_id for file_id, entry in dico.items()}
def get_list_files(company, dico=dico, name_to_id_mapper=name_to_id_mapper):
    """Summarise where a company's data-room passages come from.

    Each passage in the company's ``content`` column is classified by the
    source marker it contains ("SOURCE: COMPANY WEBSITE", "SOURCE: PDF FILE",
    "SOURCE: NOTES FROM EXPERT CALL") and reduced to a display name.

    Args:
        company: Company name; must be a key of ``name_to_id_mapper``.
        dico: Mapping of file id to ``{"name": ..., "data": DataFrame}``; the
            DataFrame is read through its ``content`` column.
        name_to_id_mapper: Mapping of company name to file id.

    Returns:
        ``(web_pages, pdfs, transcripts)`` markdown strings. Each is a heading
        followed by the unique source names, sorted so the listing is stable
        across runs (set iteration order of strings is not).

    Raises:
        KeyError: if ``company`` is not in ``name_to_id_mapper``.
    """
    web_pages = set()
    pdfs = set()
    transcripts = set()
    for passage in dico[name_to_id_mapper[company]]["data"].content.values:
        first_line = passage.split("\n")[0]
        if "SOURCE: COMPANY WEBSITE" in passage:
            # Turn the stored URL / file name into a readable "site Page: path" label.
            web_pages.add(
                first_line.replace("https::", "")
                .replace("https:", "")
                .replace(".txt", "")
                .replace(".com", " ")
                .replace(".", " Page: ")
            )
        if "SOURCE: PDF FILE" in passage:
            try:
                # The PDF name sits between "/pdfs/" and "/png" in the path
                # that follows "PATH_FILE =".
                pdf_name = (
                    passage.split("PATH_FILE =")[1]
                    .split("'}\"")[0]
                    .split("/pdfs/")[1]
                    .split("/png")[0]
                )
            except IndexError:
                # Unexpected layout: fall back to the passage header rather
                # than failing the whole listing.
                pdf_name = first_line
            pdfs.add("SOURCE: UPLOADED PDF - " + pdf_name + ".pdf")
        if "SOURCE: NOTES FROM EXPERT CALL" in passage:
            # Relabel transcript file names as numbered notes; the order of the
            # replacements matters ("transcript_1" must be handled before "transcript").
            transcripts.add(
                first_line.replace("_1 copy", "")
                .replace("transcript ", "Note #")
                .replace("transcript_1", "Note #2")
                .replace("transcript", "Note #1")
                .replace(".txt", "")
            )
    web_pages_md = "## Enriched from the web: \n" + "\n\n".join(sorted(web_pages))
    pdfs_md = "## Uploaded PDF files: \n" + "\n\n".join(sorted(pdfs))
    transcripts_md = "## Uploaded notes from expert calls: \n" + "\n\n".join(sorted(transcripts))
    return web_pages_md, pdfs_md, transcripts_md
def get_data_room_overview(company, dico=dico, name_to_id_mapper=name_to_id_mapper):
    """Build the data-room summary for *company* plus every company's source listings.

    Args:
        company: Company name selected in the UI; key of ``name_to_id_mapper``.
        dico: Mapping of file id to ``{"name": ..., "data": DataFrame}``.
        name_to_id_mapper: Mapping of company name to file id.

    Returns:
        A 13-tuple: the overview markdown for *company*, followed by
        ``get_list_files``'s (web, pdfs, expert) triple for sunday-naturals,
        smart-sd, htl-biotechnology and kamera-express, in that order.
    """
    passages = dico[name_to_id_mapper[company]]["data"].content.values
    # Booleans sum as 0/1, so each total is the number of matching passages.
    nb_pdfs = sum("SOURCE: PDF FILE" in text for text in passages)
    nb_web = sum("SOURCE: COMPANY WEBSITE" in text for text in passages)
    nb_expert_transcripts = sum("SOURCE: NOTES FROM EXPERT CALL" in text for text in passages)

    disp = f"""---
### Overview of the data room
Enriched data room with: Linkedin profile and company website
Volumetry:
- {nb_pdfs} passages from PDF files
- {nb_web} passages from company website
- {nb_expert_transcripts} passages from notes of expert calls
"""

    # The order must match the output components wired to this function in the UI.
    listings = []
    for listed_company in ("sunday-naturals", "smart-sd", "htl-biotechnology", "kamera-express"):
        listings.extend(get_list_files(listed_company, dico, name_to_id_mapper))
    return (disp, *listings)
def generate_chat_answer(company_name, query): | |
df = dico[name_to_id_mapper[company_name]]["data"] | |
response = get_answer(df, 15, query) | |
print("=====> Evaluating answer quality...") | |
eval_score = eval(eval_answer(query, response)) | |
eval_md = f""" | |
### Evalation of how well the response answer the intial question | |
Score of **{eval_score["score"]}/5** | |
Rationale: | |
{eval_score["rationale_based_on_scoring_rules"]} | |
""" | |
return response, eval_md | |
def generate_chat_answer_gemini(company_name, query): | |
df = dico[name_to_id_mapper[company_name]]["data"] | |
content = df["content"].values | |
response = get_answer_gemini(query, company_name, content) | |
print("=====> Evaluating answer quality...") | |
eval_score = eval(eval_answer(query, response)) | |
eval_md = f""" | |
### Evalation of how well the response answer the intial question | |
Score of **{eval_score["score"]}/5** | |
Rationale: | |
{eval_score["rationale_based_on_scoring_rules"]} | |
""" | |
return response, eval_md | |
# Gradio UI: a company picker plus data-room overview on the left, and chat,
# data-source and benchmark tabs on the right.
with gr.Blocks(title=title, theme='nota-ai/theme') as demo:
    gr.Markdown(f"## {title}")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            company_name = gr.Dropdown(choices=choices, label="Select company")
            submit_button = gr.Button(value="Load workspace")
            data_room_overview = gr.Markdown("---\n### Overview of the data room")
        with gr.Column(scale=6):
            with gr.Tab("Chat - Baseline"):
                with gr.Row():
                    with gr.Column(scale=5):
                        chat_input = gr.Textbox(placeholder="Chat input", lines=2, label="Retrieve anything from the dataroom")
                    with gr.Column(scale=1):
                        chat_submit_button = gr.Button(value="Submit")
                with gr.Accordion("Accuracy score", open=False):
                    evaluator = gr.Markdown("Waiting for answer to evaluate...")
                chat_output = gr.Markdown("Waiting for question...")
            with gr.Tab("Chat - ICL", interactive=True):
                with gr.Row():
                    with gr.Column(scale=5):
                        chat_input_gemini = gr.Textbox(placeholder="Chat input", lines=2, label="Retrieve anything from the dataroom")
                    with gr.Column(scale=1):
                        chat_submit_button_gemini = gr.Button(value="Submit")
                with gr.Accordion("Accuracy score", open=False):
                    evaluator_gemini = gr.Markdown("Waiting for answer to evaluate...")
                chat_output_gemini = gr.Markdown("Waiting for question...")
            with gr.Tab("Data", interactive=True):
                # One sub-tab per company. The order must match the (web, pdfs,
                # expert) triples that get_data_room_overview returns after the overview.
                source_listings = []
                for tab_label in ("Sunday Naturals", "Smart SD", "HTL Biotech", "Kamera Express"):
                    with gr.Tab(tab_label):
                        with gr.Row():
                            with gr.Column():
                                web_md = gr.Markdown("Sources obtained from website")
                            with gr.Column():
                                # Expert-call notes share the second column with the PDFs.
                                pdfs_md = gr.Markdown("Sources obtained from uploaded pdfs")
                                expert_md = gr.Markdown("Sources obtained from expert call notes")
                    source_listings.extend([web_md, pdfs_md, expert_md])
            with gr.Tab("Benchmark", interactive=False):
                pass

    # Event wiring must happen inside the Blocks context. The handlers are
    # passed directly (wrapping them in a bare functools.partial added nothing).
    submit_button.click(
        fn=get_data_room_overview,
        inputs=[company_name],
        outputs=[data_room_overview, *source_listings],
    )
    chat_submit_button.click(
        fn=generate_chat_answer,
        inputs=[company_name, chat_input],
        outputs=[chat_output, evaluator],
    )
    chat_submit_button_gemini.click(
        fn=generate_chat_answer_gemini,
        inputs=[company_name, chat_input_gemini],
        outputs=[chat_output_gemini, evaluator_gemini],
    )
# Basic-auth credentials come from the Space secrets "login" and "pwd".
login = os.environ.get("login")
pwd = os.environ.get("pwd")
if not login or not pwd:
    # Without both values the app would launch with auth=(None, None), which
    # nobody can log into. Fail fast with an actionable message instead.
    raise RuntimeError("Set the 'login' and 'pwd' environment variables to enable app authentication.")
demo.launch(max_threads=40, max_file_size="100mb", auth=(login, pwd))