"""Gradio demo: multi-label article classification with LaBSE embeddings.

The article is split into paragraphs, each paragraph is embedded with a
SentenceTransformer model, the paragraph embeddings are averaged, and a
pickled classifier predicts which entries of ``categories`` apply.
"""

import pickle

import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer

#css_code='body {background-image:url("https://picsum.photos/seed/picsum/200/300");} div.gradio-container {background: white;}'

# Label names; position i corresponds to element i of the classifier's
# prediction vector (see get_categories).
categories = [
    "Censorship",
    "Development",
    "Digital Activism",
    "Disaster",
    "Economics & Business",
    "Education",
    "Environment",
    "Governance",
    "Health",
    "History",
    "Humanitarian Response",
    "International Relations",
    "Law",
    "Media & Journalism",
    "Migration & Immigration",
    "Politics",
    "Protest",
    "Religion",
    "Sport",
    "Travel",
    "War & Conflict",
    "Technology_Science",
    "Women&Gender_LGBTQ+_Youth",
    "Freedom_of_Speech_Human_Rights",
    "Literature_Arts&Culture",
]

model = SentenceTransformer('sentence-transformers/LaBSE')

# NOTE(security): pickle.load can execute arbitrary code while
# deserializing. Only ship a classifier file from a trusted source.
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
    classifier = pickle.load(f)


def get_embedding(text):
    """Return the sentence embedding of *text*.

    Args:
        text: The string to embed. ``None`` is treated as the empty string.

    Returns:
        Whatever ``model.encode`` returns for a single string
        (presumably a 1-D numpy vector -- TODO confirm).
    """
    if text is None:
        text = ""
    return model.encode(text)


def get_categories(y_pred):
    """Map a prediction vector to the names of the selected categories.

    Args:
        y_pred: Iterable of per-category values, aligned by position with
            ``categories``.

    Returns:
        A list of the category names whose value in *y_pred* equals 1,
        in ``categories`` order.
    """
    return [categories[idx] for idx, value in enumerate(y_pred) if value == 1]


def generate_output(article):
    """Classify an article and return the values for the two output boxes.

    Args:
        article: Full article text; paragraphs are separated by ``"\\n"``.
            ``None`` is treated as an empty article.

    Returns:
        A tuple ``(classes, topic_text)`` where ``classes`` is the list of
        predicted category names and ``topic_text`` is currently the
        placeholder string ``"clustering tbd"``.
    """
    # Guard here as well as in get_embedding: without it a None article
    # fails on .split before get_embedding's own None check is reached.
    if article is None:
        article = ""

    paragraphs = article.split("\n")
    paragraph_embeddings = [get_embedding(par) for par in paragraphs]

    # One vector for the whole article: element-wise mean over paragraphs.
    embedding = np.average(paragraph_embeddings, axis=0)

    # reshape(1, -1) builds the single-row batch without hard-coding the
    # embedding width (768 for LaBSE).
    # To get probabilities instead of hard labels, classifier.predict_proba
    # could be called here with the same argument.
    y_pred = classifier.predict(embedding.reshape(1, -1))
    y_pred = y_pred.flatten()

    classes = get_categories(y_pred)
    return (classes, "clustering tbd")


# Alternative Blocks-based layout (currently disabled).
# NOTE(review): the trailing comma after the input Textbox below would make
# input_text a tuple; drop it before re-enabling this layout.
# with gr.Blocks() as demo:
#     with gr.Row():
#         # column for input
#         with gr.Column():
#             input_text = gr.Textbox(lines=6, placeholder="Insert text of the article here...", label="Article"),
#             submit_button = gr.Button("Submit")
#             clear_button = gr.Button("Clear")
#         # column for output
#         with gr.Column():
#             output_classification = gr.Textbox(lines=1, label="Article category")
#             output_topic_discovery = gr.Textbox(lines=5, label="Topic discovery")
#     submit_button.click(generate_output, inputs=input_text, outputs=[output_classification, output_topic_discovery])

demo = gr.Interface(
    fn=generate_output,
    inputs=gr.Textbox(lines=6, placeholder="Insert text of the article here...", label="Article"),
    outputs=[
        gr.Textbox(lines=1, label="Category"),
        gr.Textbox(lines=5, label="Topic discovery"),
    ],
    title="Article classification & topic discovery demo",
    flagging_options=["Incorrect"],
    theme=gr.themes.Base(),
)  #css=css_code)
demo.launch()