import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re


def split_into_sentences(text):
    # Split at whitespace that follows '.', '?' or '!', while skipping common
    # abbreviation patterns such as "e.g." or "Mr." via the negative lookbehinds.
    sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
    sentences = sentence_endings.split(text)
    return [sentence.strip() for sentence in sentences if sentence]
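
# Illustrative example (not from the original file):
#   split_into_sentences("We may change these terms. You agree to arbitration.")
#   -> ["We may change these terms.", "You agree to arbitration."]
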
def process_paragraph(paragraph, progress=gr.Progress()):
    sentences = split_into_sentences(paragraph)
    results = []
    total_sentences = len(sentences)
    for i, sentence in enumerate(sentences):
        progress((i + 1) / total_sentences)
        messages.append({"role": "user", "content": sentence})
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, top_k=50)
        # Decode only the newly generated tokens; decoding output[0] as a whole would
        # include the prompt, so the category string could never match a label.
        generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
        sentence_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        # Normalize the model's answer into a label key,
        # e.g. "Limitation of liability" -> "limitation_of_liability".
        category = sentence_response.strip().lower().replace(' ', '_')
        results.append((sentence, category))
        messages.append({"role": "assistant", "content": sentence_response})
        torch.cuda.empty_cache()
    return results
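
# Illustrative result shape (hypothetical labels, not from the original file):
# process_paragraph returns (sentence, label) pairs that gr.HighlightedText can
# render directly, e.g.
#   [("We may change these terms.", "unilateral_change"),
#    ("You agree to arbitration.", "arbitration")]
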
# Load model and tokenizer
model_name = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
# device_map="auto" already dispatches the model, so calling model.to(device) is not
# needed; keep a device handle only for moving tokenized inputs.
device = model.device
messages = []
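
# NOTE: this file never shows the instruction that tells the model to answer with one
# of the category labels defined below; without it, a general instruct model is
# unlikely to reply with a bare label. A hypothetical system prompt (an assumption,
# not from the original code) could be seeded like this:
# messages.append({
#     "role": "system",
#     "content": "Classify the following Terms-of-Service sentence as one of: fair, "
#                "limitation_of_liability, unilateral_termination, unilateral_change, "
#                "content_removal, contract_by_using, choice_of_law, jurisdiction, "
#                "arbitration. Answer with the label only.",
# })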
# Define label to color mapping
label_to_color = {
    "fair": "green",
    "limitation_of_liability": "red",
    "unilateral_termination": "orange",
    "unilateral_change": "yellow",
    "content_removal": "purple",
    "contract_by_using": "blue",
    "choice_of_law": "cyan",
    "jurisdiction": "magenta",
    "arbitration": "brown",
}
# Gradio Interface
with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
            btn = gr.Button("Process")
        with gr.Column():
            output = gr.HighlightedText(label="Processed Paragraph", color_map=label_to_color)

    # Gradio injects the progress tracker through the gr.Progress() default argument
    # of the event handler, so no separate Progress instance is needed.
    def on_click(paragraph, progress=gr.Progress()):
        return process_paragraph(paragraph, progress=progress)

    btn.click(on_click, inputs=input_text, outputs=[output])

# share=True exposes a temporary public link in addition to the local server
demo.launch(share=True)