Spaces:
Sleeping
Sleeping
import re | |
import gradio as gr | |
import pandas as pd | |
from chatgpt import MessageChatCompletion | |
# theme = gr.themes.Monochrome(spacing_size=gr.themes.sizes.spacing_md, | |
# radius_size=gr.themes.sizes.radius_sm, | |
# text_size=gr.themes.sizes.text_lg).set( | |
# loader_color="#FF0000", | |
# slider_color="#FF0000" | |
# ) | |
# df = pd.read_json("description_subsector.json", orient='index') | |
# df = df.reset_index().rename(columns={'index': 'Subsector'}) | |
# df = df.sort_values(by=['Subsector']) | |
# df.reset_index(drop=True, inplace=True) | |
# df.to_csv('subsectors.csv', index=False) | |
# Subsector reference table (one row per subsector). Missing cells are
# replaced by empty strings so NaN never reaches the prompt or the UI.
df = pd.read_csv('subsectors.csv').fillna('')
def build_context(row):
    """Render one subsector record as the text snippet used in the prompt.

    Args:
        row: A mapping-like record (e.g. a row yielded by ``df.iterrows()``)
            with the keys 'Subsector', 'Definition', 'Keywords',
            'Does include' and 'Does not include'.

    Returns:
        A single line of the form
        "Subsector name: N. N Definition: ... N Does not include: ...."
        terminated by a newline.
    """
    name = row['Subsector']
    labelled_fields = (
        ('Definition', row['Definition']),
        ('keywords', row['Keywords']),
        ('Does include', row['Does include']),
        ('Does not include', row['Does not include']),
    )
    # Each field becomes "<name> <label>: <value>." and the sentences are
    # separated by single spaces.
    details = " ".join(f"{name} {label}: {value}." for label, value in labelled_fields)
    return f"Subsector name: {name}. {details}\n"
def _extract_probabilities(text):
    """Return the first dict literal found in *text* as a label -> float map.

    Scans successive non-greedy ``{...}`` spans and parses each with
    ``ast.literal_eval``. The first span that evaluates to a dict is
    converted; entries whose value cannot be turned into a float are dropped.
    Returns an empty dict when *text* is not a string or contains no dict
    literal.
    """
    import ast

    if not isinstance(text, str):
        return {}
    for match in re.finditer(r'\{.*?\}', text, re.DOTALL):
        try:
            # Model output is untrusted (the user-supplied abstract can steer
            # it), so never eval() it: literal_eval only accepts literals.
            parsed = ast.literal_eval(match.group(0))
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            continue
        if not isinstance(parsed, dict):
            continue
        probabilities = {}
        for label, value in parsed.items():
            try:
                # gr.Label expects numeric confidences; models sometimes emit
                # them as strings such as "0.85".
                probabilities[str(label)] = float(value)
            except (TypeError, ValueError):
                continue
        return probabilities
    return {}


def click_button(model, api_key, abstract):
    """Classify a patent abstract into subsectors using a chat model.

    Builds a system prompt describing every subsector in the module-level
    ``df`` table, sends it with the abstract through ``MessageChatCompletion``
    and extracts the probability dictionary the prompt asks the model for.

    Args:
        model: Chat model name (one of the "Model" dropdown choices).
        api_key: API key forwarded to ``MessageChatCompletion``.
        abstract: Patent abstract text to classify.

    Returns:
        ``(probabilities, reasoning)``: ``probabilities`` maps subsector label
        to float and is empty when the request reported an error or no dict
        literal could be parsed; ``reasoning`` is whatever
        ``get_last_message()`` returned.
    """
    labels = df['Subsector'].tolist()
    # Every context already ends with "\n"; join them so the prompt contains
    # real line breaks rather than the repr of a Python list (brackets,
    # quotes and literal "\n" escape sequences).
    contexts = "".join(build_context(row) for _, row in df.iterrows())
    my_chatgpt = MessageChatCompletion(model=model, api_key=api_key)
    system_message = (
        "You are a system designed to classify patent abstracts into one or more subsectors based on their content. "
        "Each subsector is defined by a unique set of characteristics: "
        "\n- Name: The name of the subsector."
        "\n- Definition: A brief description of the subsector."
        "\n- Keywords: Important words associated with the subsector."
        "\n- Does include: Elements typically found within the subsector."
        "\n- Does not include: Elements typically not found within the subsector."
        "\nConsider 'nan' values as 'not available' or 'not applicable'. "
        "When classifying an abstract, provide the following: "
        "\n1. Subsector(s): Name(s) of the subsector(s) you believe the abstract belongs to."
        "\n2. Reasoning: "
        "\n\t- Conclusion: Explain why the abstract was classified in this subsector(s), based on its alignment with the subsector's definition, keywords, and includes/excludes criteria."
        "\n\t- Keywords found: Specify any 'Keywords' from the subsector that are present in the abstract."
        "\n\t- Does include found: Specify any 'Includes' criteria from the subsector that are present in the abstract."
        "\n\t- If no specific 'Keywords' or 'Includes' are found, state that none were directly identified, but the classification was made based on the overall relevance to the subsector."
        "\n3. Non-selected Subsectors: "
        "\n\t- If a subsector had a high probability of being a match but was ultimately not chosen because the abstract contained terms from the 'Does not include' list, provide a brief explanation. Highlight the specific 'Does not include' terms found and why this led to the subsector's exclusion."
        f"\n4. Probability: Provide a dictionary containing the subsectors ({labels}) and their corresponding probabilities of being a match. Each probability should be formatted to show two decimal places."
        "\n5. Suggested Subsector: Based on the primary classification in item 1, suggest a more specialized area or a closely related subsector within it. This suggestion should delve deeper into the nuances of the primary subsector, targeting areas of emerging interest, innovation, or specialization that align with the abstract's content and the initial subsector's broader context."
        "\nYour task is to classify the following patent abstract into the appropriate subsector(s), taking into account the details of each subsector as described above. Here are the subsectors and their definitions for your reference:\n"
        f"{contexts}"
    )
    user_message = f'Classify this patent abstract: {abstract}'
    my_chatgpt.new_system_message(content=system_message)
    my_chatgpt.new_user_message(content=user_message)
    my_chatgpt.send_message()
    reasoning = my_chatgpt.get_last_message()
    # Only trust the reply for probabilities when the client reported no error.
    if my_chatgpt.error is False:
        probabilities_dict = _extract_probabilities(reasoning)
    else:
        probabilities_dict = {}
    return probabilities_dict, reasoning
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Populate the detail panel for the subsector row that was clicked.

    The ``gr.SelectData`` annotation is what makes Gradio pass the selection
    event. ``evt.index[0]`` is used as the positional row index into ``df``.

    Returns:
        An Accordion titled with the subsector name, followed by that row's
        Definition, Keywords, "Does include" and "Does not include" values,
        matching the outputs wired to ``df_subsectors.select``.
    """
    record = df.iloc[evt.index[0]]
    detail_columns = ('Definition', 'Keywords', 'Does include', 'Does not include')
    details = tuple(record[column] for column in detail_columns)
    return (gr.Accordion(label=record['Subsector']),) + details
# with gr.Blocks(theme=theme) as startup_genome_demo: | |
# Initial contents of the "Sector definitions" detail panel, taken from the
# first row of the subsector table so its title, definition and keywords all
# describe the same subsector (the former hard-coded defaults paired an
# "Artificial Intelligence, Big Data and Analytics" title with a VR/AR
# definition). An empty table yields empty strings.
_initial_subsector = df.iloc[0] if len(df) else {}

with gr.Blocks() as startup_genome_demo:
    # NOTE(review): neither State is referenced by the visible event handlers;
    # kept in case other code relies on them.
    state_lotto = gr.State()
    selected_x_labels = gr.State()
    with gr.Tab("Patent Discovery"):
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"],
                    value="gpt-3.5-turbo",
                    multiselect=False,
                    interactive=True,
                )
            with gr.Column(scale=5):
                # Masked input so the secret is not displayed in clear text.
                api_key = gr.Textbox(label="API KEY", type="password", interactive=True, lines=1, max_lines=1)
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=10,
                max_lines=10000,
                interactive=True,
                value="A holographic optical accessing system includes a light source for emitting a light beam; an optical assembly module for receiving the light beam and generating a signal beam and a reference beam that are parallel to each other rather than overlap with each other, and have the same first polarization state; a lens module for focusing the signal beam and the reference beam on a focal point at the same time; and a storage medium for recording the focal point. The optical assembly module includes at least a data plane for displaying image information so that the signal beam contains the image information.",
            )
        with gr.Row():
            btn_get_result = gr.Button("Show classification")
        with gr.Row():
            with gr.Column(scale=4):
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                reasoning = gr.Textbox(label="Reasoning", lines=34)
    with gr.Tab("Sector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                with gr.Accordion(label=_initial_subsector.get('Subsector', '')) as subsector_name:
                    s1_definition = gr.Textbox(label="Definition", lines=5, max_lines=100,
                                               value=_initial_subsector.get('Definition', ''))
                    s1_keywords = gr.Textbox(label="Keywords", lines=5, max_lines=100,
                                             value=_initial_subsector.get('Keywords', ''))
                    does_include = gr.Textbox(label="Does include", lines=4,
                                              value=_initial_subsector.get('Does include', ''))
                    does_not_include = gr.Textbox(label="Does not include", lines=3,
                                                  value=_initial_subsector.get('Does not include', ''))
    # Event wiring must happen inside the Blocks context.
    btn_get_result.click(fn=click_button, inputs=[dropdown_model, api_key, abstract_description], outputs=[label_result, reasoning])
    df_subsectors.select(fn=on_select, outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include])
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving the app.
    startup_genome_demo.queue().launch()