# Spaces status: Runtime error
import gradio as gr | |
from typing import Dict | |
import os | |
from homepage2vec.model import WebsiteClassifier as Homepage2Vec | |
# Example inputs for the Gradio interface; each row is
# [model name, website URL], matching the order of the interface inputs.
EXAMPLES = [
    ["gpt3.5", site]
    for site in (
        "tanjasenghaasdesigns.de",
        "epfl.ch",
        "cc.cz",
        "promaminky.cz",
    )
]
# Categories whose score is strictly greater than this are reported.
SCORE_THRESHOLD = 0.5

# Model used when no model is selected (e.g. the dropdown was left empty).
DEFAULT_MODEL = "gpt3.5"

# Classifiers already loaded, keyed by model directory, so the model is read
# from disk once per directory instead of on every request.
_MODELS: Dict[str, "Homepage2Vec"] = {}


def _get_model(model_dir: str) -> "Homepage2Vec":
    """Return the classifier for ``model_dir``, creating it on first use."""
    model = _MODELS.get(model_dir)
    if model is None:
        model = Homepage2Vec(model_dir=model_dir)
        _MODELS[model_dir] = model
    return model


def predict(model_choice: str, url: str) -> Dict[str, float]:
    """
    Predict the categories of a website using the Homepage2Vec model.

    Args:
        model_choice (str): Name of the model sub-directory under ``models/``
            to use for prediction. Falls back to ``DEFAULT_MODEL`` when empty
            or ``None``.
        url (str): The URL of the website to classify.

    Returns:
        Dict[str, float]: The categories whose score is greater than
        ``SCORE_THRESHOLD``, mapped to their scores.

    Raises:
        gr.Error: If the URL is empty, or the model name is not a bare
            directory name.
    """
    model_choice = model_choice or DEFAULT_MODEL
    # Accept only a bare directory name so a crafted request cannot make the
    # loader read from outside ``models/`` (e.g. "../elsewhere").
    if os.path.basename(model_choice) != model_choice or model_choice in (".", ".."):
        raise gr.Error(f"Invalid model: {model_choice!r}")

    url = (url or "").strip()
    if not url:
        raise gr.Error("Please enter a website URL.")

    model_dir = os.path.join("models", model_choice)
    model = _get_model(model_dir)
    website = model.fetch_website(url)
    # The second return value (the embedding) is not needed here.
    scores, _ = model.predict(website)
    return {k: v for k, v in scores.items() if v > SCORE_THRESHOLD}
# Web UI: choose a model, enter a URL, and display the categories returned by
# ``predict`` (the label component shows at most the top 14).
iface = gr.Interface(
    fn=predict,
    inputs=[
        # Pre-select a model so submitting without touching the dropdown
        # does not pass ``None`` as ``model_choice`` to ``predict``.
        gr.Dropdown(choices=["gpt3.5", "gpt4"], value="gpt3.5", label="Select Model"),
        gr.Textbox(label="Enter Website URL", placeholder="www.mikasenghaas.de"),
    ],
    outputs=gr.Label(num_top_classes=14, label="Predicted Labels", show_label=True),
    title="Homepage2Vec",
    description="Use Homepage2Vec to predict the categories of any website you wish.",
    examples=EXAMPLES,
    live=False,
    # NOTE(review): newer Gradio releases replace ``allow_flagging`` with
    # ``flagging_mode``; confirm against the pinned Gradio version.
    allow_flagging="never",
)
iface.launch()