"""Gradio app that classifies patent abstracts into subsectors with an LLM.

The module loads subsector definitions from ``subsectors.csv``, builds a
prompt from them, sends the abstract to the selected chat model and shows the
returned match scores, the model's reasoning and a running log of requests.
"""

import ast
import os
import re

from dotenv import load_dotenv
import gradio as gr
import pandas as pd
from pandas import DataFrame as PandasDataFrame

from llm import MessageChatCompletion
from customization import css, js
from examples import example_1, example_2, example_3, example_4
from prompt_template import system_message_template, user_message_template

load_dotenv()
API_KEY = os.getenv("API_KEY")

# Subsector reference table; the code below reads the columns 'Subsector',
# 'Definition', 'Keywords', 'Does include' and 'Does not include'.
df = pd.read_csv('subsectors.csv')

# In-memory log of every classification request made in this process.
logs_columns = ['Abstract', 'Model', 'Results']
logs_df = PandasDataFrame(columns=logs_columns)

# First (non-greedy) "{...}" span in the model reply; DOTALL lets it span lines.
_DICT_PATTERN = re.compile(r'\{.*?\}', re.DOTALL)


def download_logs() -> None:
    """Write the classification log to ``classification_logs.csv``.

    The file is saved in the current user's ``Desktop`` folder. When that
    folder does not exist (e.g. on a headless machine) the user's home
    directory is used instead, so the export does not fail on a missing path.
    """
    # expanduser('~') resolves the home directory on Windows, macOS and Linux,
    # and does not raise when USERPROFILE is unset.
    home = os.path.expanduser('~')
    desktop = os.path.join(home, 'Desktop')
    target_dir = desktop if os.path.isdir(desktop) else home

    file_path = os.path.join(target_dir, 'classification_logs.csv')
    logs_df.to_csv(file_path)


def build_context(row) -> str:
    """Return the prompt context paragraph describing one subsector.

    Args:
        row: A row of the subsector table providing 'Subsector', 'Definition',
            'Keywords', 'Does include' and 'Does not include'.

    Returns:
        A single newline-terminated string describing the subsector.
    """
    subsector_name = row['Subsector']
    context = f"Subsector name: {subsector_name}. "
    context += f"{subsector_name} Definition: {row['Definition']}. "
    context += f"{subsector_name} keywords: {row['Keywords']}. "
    context += f"{subsector_name} Does include: {row['Does include']}. "
    context += f"{subsector_name} Does not include: {row['Does not include']}.\n"
    return context


def _parse_match_scores(text) -> dict:
    """Extract the first dict literal from the model reply.

    Returns an empty dict when *text* is not a string, contains no ``{...}``
    span, or the span is not a valid Python dict literal.
    """
    if not isinstance(text, str):
        return {}

    match = _DICT_PATTERN.search(text)
    if match is None:
        return {}

    # The reply is untrusted model output: literal_eval only accepts literals,
    # whereas eval() would execute arbitrary expressions embedded in the reply.
    try:
        parsed = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return {}

    # "{1, 2}" is a valid literal too (a set); only a dict is usable as scores.
    return parsed if isinstance(parsed, dict) else {}


def click_button(model, api_key, abstract):
    """Classify *abstract* with the chosen model and record the request.

    Args:
        model: Name of the chat model to use.
        api_key: API key passed to the chat client.
        abstract: Patent abstract text to classify.

    Returns:
        A tuple ``(match_score_dict, response_reasoning, logs_df)``: the
        parsed scores (empty dict on error or unparsable reply), the raw
        model reply, and the updated log DataFrame.
    """
    global logs_df

    labels = df['Subsector'].tolist()
    # NOTE(review): prompt_context is a list, so the template receives its
    # Python list repr (brackets, quotes, escaped newlines). Left unchanged to
    # keep the prompt identical; confirm whether a joined string was intended.
    prompt_context = [build_context(row) for _, row in df.iterrows()]

    language_model = MessageChatCompletion(model=model, api_key=api_key)
    system_message = system_message_template.format(prompt_context=prompt_context)
    user_message = user_message_template.format(labels=labels, abstract=abstract)

    language_model.new_system_message(content=system_message)
    language_model.new_user_message(content=user_message)
    language_model.send_message()

    response_reasoning = language_model.get_last_message()

    # Only trust the reply when the client reports no error.
    if language_model.error is False:
        match_score_dict = _parse_match_scores(response_reasoning)
    else:
        match_score_dict = {}

    # Update Logs
    new_log_entry = pd.DataFrame({
        'Abstract': [abstract],
        'Model': [model],
        'Results': [str(match_score_dict)],
    })
    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True)

    return match_score_dict, response_reasoning, logs_df


def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Show the details of the subsector row selected in the table.

    Args:
        evt: Selection event; ``evt.index[0]`` is used as the row position
            in the subsector table.

    Returns:
        A tuple of an accordion labelled with the subsector name, followed by
        its definition, keywords, "does include" and "does not include" text.
    """
    selected = df.iloc[evt.index[0]]
    name_accordion = gr.Accordion(label=selected['Subsector'])
    return (
        name_accordion,
        selected['Definition'],
        selected['Keywords'],
        selected['Does include'],
        selected['Does not include'],
    )


# --- GRADIO INTERFACE --- #
with gr.Blocks(css=css, js=js) as demo:
    state_lotto = gr.State()
    selected_x_labels = gr.State()

    with gr.Tab("Patent Discovery"):
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"],
                    value="gpt-3.5-turbo-0125",
                    multiselect=False,
                    interactive=True
                )
            with gr.Column(scale=5):
                api_key = gr.Textbox(
                    label="API Key",
                    interactive=True,
                    lines=1,
                    max_lines=1,
                    type="password",
                    value=API_KEY
                )
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=5,
                max_lines=10000,
                interactive=True,
                placeholder="Input a patent abstract"
            )
        with gr.Row():
            with gr.Accordion(label="Example Abstracts", open=False):
                # NOTE(review): click_button takes three arguments but only one
                # input is wired here; keep example caching disabled unless the
                # inputs are extended to match.
                gr.Examples(
                    examples=[example_1, example_2, example_3, example_4],
                    inputs=abstract_description,
                    fn=click_button,
                    label="",
                    # cache_examples=True,
                )
        with gr.Row():
            btn_get_result = gr.Button("Classify")
        with gr.Row(elem_classes=['all_results']):
            with gr.Column(scale=4):
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results'])

    with gr.Tab("Subsector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                with gr.Accordion(label='Artificial Intelligence, Big Data and Analytics') as subsector_name:
                    s1_definition = gr.Textbox(label="Definition", lines=5, max_lines=100, value="Virtual reality (VR) is an artificial, computer-generated simulation or recreation of a real life environment or situation. Augmented reality (AR) is a technology that layers computer-generated enhancements atop an existing reality in order to make it more meaningful through the ability to interact with it. ")
                    s1_keywords = gr.Textbox(label="Keywords", lines=5, max_lines=100, value="Mixed Reality, 360 video, frame rate, metaverse, virtual world, cross reality, Artificial intelligence, computer vision")
                    does_include = gr.Textbox(label="Does include", lines=4)
                    does_not_include = gr.Textbox(label="Does not include", lines=3)

    with gr.Tab("Logs"):
        output_dataframe = gr.Dataframe(
            value=logs_df,
            type="pandas",
            height=500,
            headers=['Abstract', 'Model', 'Results'],
            interactive=False,
            column_widths=["45%", "10%", "45%"],
        )
        btn_export = gr.Button(
            value="Export to CSV",
            size="sm",
        )

    # Event wiring: classify, export the log, and show subsector details.
    btn_get_result.click(
        fn=click_button,
        inputs=[dropdown_model, api_key, abstract_description],
        outputs=[label_result, reasoning, output_dataframe])

    btn_export.click(
        fn=download_logs,
    )

    df_subsectors.select(
        fn=on_select,
        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include]
    )


if __name__ == "__main__":
    # demo.queue()
    demo.launch()