"""Gradio demo that classifies a website with a selectable Homepage2Vec model."""

import functools
import os
from typing import Dict

import gradio as gr
from homepage2vec.model import WebsiteClassifier as Homepage2Vec

# Allow-list of selectable models mapped to the directory each one lives in.
# Looking the choice up here (instead of splitting the string and joining the
# pieces into a path) keeps arbitrary client-supplied values out of the
# filesystem path. Insertion order is the order shown in the dropdown.
MODEL_DIRS = {
    "original": os.path.join("models", "homepage2vec"),
    "finetuned-gpt3.5": os.path.join("models", "finetuned", "gpt3.5"),
    "finetuned-gpt4": os.path.join("models", "finetuned", "gpt4"),
}

# Model preselected in the dropdown.
DEFAULT_MODEL = "finetuned-gpt4"

# Only categories scoring strictly above this value are reported.
SCORE_THRESHOLD = 0.5

# [model choice, website] pairs offered as one-click examples in the UI.
EXAMPLES = [
    # Personal site
    ["original", "tanjasenghaasdesigns.de"],
    ["finetuned-gpt4", "tanjasenghaasdesigns.de"],
    # EPFL
    ["finetuned-gpt3.5", "epfl.ch"],
    ["finetuned-gpt4", "epfl.ch"],
    # Czech Crunch - czech tech news
    ["original", "cc.cz"],
    ["finetuned-gpt4", "cc.cz"],
    # Promaminky - czech site for moms
    ["original", "promaminky.cz"],
    ["finetuned-gpt3.5", "promaminky.cz"],
]

DESCRIPTION = (
    "Select a version of the Homepage2Vec model and enter a website's URL or "
    "domain to predict its categories. The original model was trained on 886K "
    "websites from Curlie directory. The finetuned models, in addition, were "
    "trained on GPT annotated websites. "
    "On average, the finetuned models should predict more labels than the "
    "original model while maintaining high accuracy."
)


@functools.lru_cache(maxsize=len(MODEL_DIRS))
def _load_model(model_dir: str) -> Homepage2Vec:
    """Construct the classifier for ``model_dir`` once and reuse it afterwards.

    NOTE(review): this shares one classifier instance across requests, which
    assumes the instance can safely be reused for several websites -- confirm
    against the homepage2vec version in use.
    """
    return Homepage2Vec(model_dir=model_dir)


def predict(model_choice: str, url: str) -> Dict[str, float]:
    """
    Predict the categories of a website using the Homepage2Vec model.

    Args:
        model_choice (str): The model to use for prediction. Must be one of
            the keys of ``MODEL_DIRS``.
        url (str): The url of the website to predict.

    Returns:
        Dict[str, float]: The categories whose score is greater than
            ``SCORE_THRESHOLD``, mapped to their scores.

    Raises:
        gr.Error: If ``model_choice`` is not a known model or ``url`` is empty.
    """
    model_dir = MODEL_DIRS.get(model_choice)
    if model_dir is None:
        raise gr.Error(f"Unknown model: {model_choice!r}")

    # The textbox can be submitted empty or with stray whitespace.
    url = (url or "").strip()
    if not url:
        raise gr.Error("Please enter a website's URL or domain.")

    model = _load_model(model_dir)

    # Website to predict
    website = model.fetch_website(url)

    # Obtain scores; the second returned value (embeddings) is not needed here.
    scores, _ = model.predict(website)

    return {
        label: score
        for label, score in scores.items()
        if score > SCORE_THRESHOLD
    }


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_DIRS),
            label="Select Model",
            show_label=True,
            value=DEFAULT_MODEL,
        ),
        gr.Textbox(
            label="Enter Website's URL or domain",
            placeholder="e.g. ikea.com",
        ),
    ],
    outputs=gr.Label(
        num_top_classes=14,
        label="Predicted Labels",
        show_label=True,
    ),
    title="Homepage2Vec",
    description=DESCRIPTION,
    examples=EXAMPLES,
    live=False,
    allow_flagging="never",
)

# Launched at module level, as in the original script.
iface.launch()