# Spaces: Runtime error
import functools
import json
import os
import shutil

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub repo and subfolders holding the checkpoint and its tokenizer.
_CHECKPOINT_REPO = "bs-modeling-metadata/checkpoints_all_04_23"
_MODEL_SUBFOLDER = "checkpoint-30000step"
_TOKENIZER_SUBFOLDER = "tokenizer"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the model, tokenizer and banned-token ids once and reuse them.

    Previously the checkpoint was loaded on every request; caching keeps a
    single copy in memory for the lifetime of the process.

    Returns:
        A ``(model, tokenizer, bad_words_ids)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        _CHECKPOINT_REPO, subfolder=_MODEL_SUBFOLDER
    )
    tokenizer = AutoTokenizer.from_pretrained(
        _CHECKPOINT_REPO, subfolder=_TOKENIZER_SUBFOLDER, add_prefix_space=True
    )
    # Token-id sequences passed to generate() as bad_words_ids, so the model
    # never emits the entity-chain markers itself.
    bad_words_ids = tokenizer(["<ENTITY_CHAIN>", " </ENTITY_CHAIN> "]).input_ids
    return model, tokenizer, bad_words_ids


def _format_field(label, value):
    """Return ``"<label>: <value> | "``, or ``""`` when *value* is empty or None."""
    if not value:
        return ""
    return f"{label}: {value} | "


def _format_entities(entity):
    """Build the entity-chain segment from a comma-separated entity string.

    Blank entries (e.g. from ``"a,,b"`` or a whitespace-only field) are dropped.
    When no entities remain, only the ``"||| "`` marker is emitted.
    """
    entities = [ent.strip() for ent in (entity or "").split(",")]
    entities = [ent for ent in entities if ent]
    if not entities:
        return "||| "
    chain = "".join(f" |{ent}|" for ent in entities)
    return f"entity ||| <ENTITY_CHAIN>{chain} </ENTITY_CHAIN> "


def generate(html, entity, website_desc, datasource, year, month, title, prompt):
    """Generate a continuation of *prompt* conditioned on metadata fields.

    Args:
        html: ``"on"`` prepends the ``"html | "`` marker; any other value omits it.
        entity: comma-separated entities/keywords; blank entries are ignored.
        website_desc, datasource, year, month, title: optional free-text fields;
            each non-empty one is prepended as ``"<Label>: <value> | "``.
        prompt: the text the model continues.

    Returns:
        The list produced by ``tokenizer.batch_decode`` — one decoded string
        per generated sequence, with special tokens removed and at most 128
        new tokens generated.
    """
    # Field order matches the original prompt layout: html, year, month,
    # website description, title, datasource, entities, then the prompt.
    final_prompt = (
        ("html | " if html == "on" else "")
        + _format_field("Year", year)
        + _format_field("Month", month)
        + _format_field("Website Description", website_desc)
        + _format_field("Title", title)
        + _format_field("Datasource", datasource)
        + _format_entities(entity)
        + prompt
    )
    model, tokenizer, bad_words_ids = _load_model_and_tokenizer()
    inputs = tokenizer(final_prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs, max_new_tokens=128, bad_words_ids=bad_words_ids
    )
    # NOTE(review): this returns a list, but the Interface output is "text";
    # the UI may show the list's repr. Consider returning the first element —
    # kept as a list here to preserve the function's return type.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
def _textbox(placeholder, label):
    """Create a Gradio textbox with the given placeholder text and label."""
    return gr.Textbox(placeholder=placeholder, label=label)


# Input widgets, declared in the same order as generate()'s parameters.
html = gr.Radio(["on", "off"], label="html", info="turn html as on or off")
entity = _textbox("enter a list of comma separated entities or keywords", "list of entities")
website_desc = _textbox("enter a website description", "website description")
datasource = _textbox("enter a datasource", "datasource")
year = _textbox("enter a year", "year")
month = _textbox("enter a month", "month")
title = _textbox("enter a website title", "website title")
prompt = _textbox("enter a prompt", "prompt")

# Wire the widgets to generate() and start the app immediately (module-level launch).
input_components = [html, entity, website_desc, datasource, year, month, title, prompt]
demo = gr.Interface(fn=generate, inputs=input_components, outputs="text")
demo.launch()