FLAN-T5_RSiCS / app.py
armandstrickernlp
app + checkpoint
fb373b9
import gradio as gr
from transformers import (T5Tokenizer,
T5ForConditionalGeneration,
AddedToken,
)
# Tokenizer: start from the public FLAN-T5 base vocabulary and register the
# newline character as an additional special token.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})

# Model: weights are loaded from the "model_checkpoint" location.
# NOTE(review): presumably this checkpoint was fine-tuned with the same extra
# "\n" token, so its embedding size should match the tokenizer -- confirm.
model_cktp = "model_checkpoint"
model = T5ForConditionalGeneration.from_pretrained(model_cktp)
def predict(input):
    """Run the model on one prompt and return the generated text.

    The prompt is encoded with the module-level ``tokenizer``, passed to
    ``model.generate`` (``max_length=200``, ``early_stopping=True``), and the
    first generated sequence is decoded.

    Args:
        input: The prompt text to encode and feed to the model.

    Returns:
        The decoded text of the first generated sequence, with every
        occurrence of ``<pad>`` and ``</s>`` removed.
    """
    prompt_ids = tokenizer.encode(input, return_tensors="pt")
    generated = model.generate(prompt_ids, max_length=200, early_stopping=True)

    # Special tokens are kept while decoding and only the two markers below
    # are removed by hand (presumably so the added "\n" token survives in the
    # output -- confirm).
    text = tokenizer.decode(generated[0], skip_special_tokens=False)
    for marker in ("<pad>", "</s>"):
        text = text.replace(marker, "")
    return text
# Sample prompts offered in the demo UI. Each inner list is one example row
# holding a single text value (the interface has one text input). Every prompt
# is three lines: the service, the quoted customer utterance, and the
# extraction/classification instruction.
examples = [
    [
        """Service: telecom customer service.
Customer utterance : "I'm trying to find out when my tv service will be turn back on??????"|
Extract all strictly unnecessary sequences for the service provider to process the request/issue and then classify them using relational tags."""
    ],
    [
        """Service: airline customer service.
Customer utterance : "I need a ticket to Boston this Saturday, my son is graduating!"|
Extract all strictly unnecessary sequences for the service provider to process the request/issue and classify them using relational tags."""
    ],
]
# Short blurb handed to the interface as its description.
description = """This model detects and classifies relational strategies in customer service requests, using an instruction-based approach."""

# Text-in / text-out Gradio interface around `predict`, pre-populated with
# the sample prompts in `examples`.
demo = gr.Interface(
    title="FLAN-T5: Detect and Classify Relational Strategies",
    description=description,
    fn=predict,
    inputs="text",
    outputs="text",
    examples=examples,
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()