Spaces:
Runtime error
Runtime error
import functools
import os
from typing import Dict

import gradio as gr
from homepage2vec.model import WebsiteClassifier as Homepage2Vec
# Demo inputs shown under the form, as [model_choice, url] pairs.
# Each site is listed once together with the model versions to demo it with.
_EXAMPLE_SITES = (
    # Personal site
    ("tanjasenghaasdesigns.de", ("original", "finetuned-gpt4")),
    # EPFL
    ("epfl.ch", ("finetuned-gpt3.5", "finetuned-gpt4")),
    # Czech Crunch - czech tech news
    ("cc.cz", ("original", "finetuned-gpt4")),
    # Promaminky - czech site for moms
    ("promaminky.cz", ("original", "finetuned-gpt3.5")),
)
EXAMPLES = [
    [model_name, site]
    for site, model_names in _EXAMPLE_SITES
    for model_name in model_names
]
@functools.lru_cache(maxsize=4)
def _load_model(model_dir: str) -> "Homepage2Vec":
    """Return a Homepage2Vec classifier for *model_dir*, constructing it at most once.

    Previously a new classifier was built on every request. Caching keeps one
    instance per directory (bounded by ``maxsize``), so repeated predictions
    reuse it. A constructor that raises is not cached and will be retried.
    """
    return Homepage2Vec(model_dir=model_dir)


def _model_dir_for(model_choice: str) -> str:
    """Map a UI model choice to the directory holding that model.

    "original" maps to ``models/homepage2vec``. Any "<prefix>-<variant>"
    choice (e.g. "finetuned-gpt4") maps to ``models/finetuned/<variant>``.

    Raises:
        ValueError: if there is no variant after the first "-", or the
            variant is not a plain directory name. The value comes from the
            web request, so it must not be able to point outside
            ``models/finetuned``.
    """
    if model_choice == "original":
        return os.path.join("models", "homepage2vec")
    _, sep, variant = model_choice.partition("-")
    # Split only on the first "-" so a variant such as "gpt-4" stays intact.
    if not sep or variant in ("", ".", "..") or os.path.basename(variant) != variant:
        raise ValueError(f"Unknown model choice: {model_choice!r}")
    return os.path.join("models", "finetuned", variant)


def predict(model_choice: str, url: str, threshold: float = 0.5) -> Dict[str, float]:
    """
    Predict the categories of a website using the Homepage2Vec model.

    Args:
        model_choice (str): The model to use for prediction: "original" or
            "finetuned-<variant>" (e.g. "finetuned-gpt4").
        url (str): The URL or domain of the website to predict. Surrounding
            whitespace is ignored.
        threshold (float): Only categories whose score is strictly greater
            than this value are returned. Defaults to 0.5.

    Returns:
        Dict[str, float]: The categories above ``threshold`` and their scores.

    Raises:
        ValueError: if ``model_choice`` is not recognised or ``url`` is blank.
    """
    # Validate all inputs before touching the (expensive) model.
    model_dir = _model_dir_for(model_choice)
    url = (url or "").strip()
    if not url:
        raise ValueError("Please enter a website URL or domain.")

    model = _load_model(model_dir)
    website = model.fetch_website(url)
    # predict() returns (scores, embeddings); only the scores are needed here.
    scores, _ = model.predict(website)
    return {label: score for label, score in scores.items() if score > threshold}
# Gradio UI: choose a model version and a URL; the predicted categories are
# shown as a label widget.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=["original", "finetuned-gpt3.5", "finetuned-gpt4"],
            label="Select Model",
            show_label=True,
            value="finetuned-gpt4",
        ),
        gr.Textbox(label="Enter Website's URL or domain", placeholder="e.g. ikea.com"),
    ],
    # NOTE(review): 14 presumably matches the number of Homepage2Vec categories — confirm.
    outputs=gr.Label(num_top_classes=14, label="Predicted Labels", show_label=True),
    title="Homepage2Vec",
    description=(
        "Select a version of the Homepage2Vec model and enter a website's URL or "
        "domain to predict its categories. The original model was trained on 886K "
        "websites from the Curlie directory. The finetuned models were additionally "
        "trained on GPT-annotated websites. On average, the finetuned models should "
        "predict more labels than the original model while maintaining high accuracy."
    ),
    examples=EXAMPLES,
    live=False,
    allow_flagging="never",
)
# Launched at import time, as in the original script.
iface.launch()