Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,631 Bytes
17b23e6 79c5347 17b23e6 ab96ee1 17b23e6 94f56db 17b23e6 5f91d0f 17b23e6 5f91d0f 94f56db 79c5347 5f91d0f 94f56db 79c5347 8f48aeb f9b65c3 79c5347 e0d43a9 79c5347 94f56db 17b23e6 79c5347 17b23e6 79c5347 17b23e6 5f91d0f 17b23e6 94f56db 17b23e6 fa6db21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import re
# Whitespace that follows ".", "?" or "!", except after dotted abbreviations
# such as "e.g." or capitalised two-letter ones such as "Mr.".
_SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')


def split_into_sentences(text):
    """Split *text* into sentences on whitespace after ``.``, ``?`` or ``!``.

    The negative look-behinds keep abbreviations like ``e.g.`` and ``Mr.``
    inside their sentence. Each sentence is stripped, and whitespace-only
    fragments (e.g. from trailing spaces or newlines) are dropped, so the
    result never contains empty strings.

    Args:
        text: The paragraph to split.

    Returns:
        The non-empty, stripped sentences in their original order.
    """
    # Filter on the *stripped* fragment: a fragment like "\n" is truthy but
    # would otherwise become an empty sentence.
    return [piece.strip() for piece in _SENTENCE_BOUNDARY.split(text) if piece.strip()]


def _normalize_label(text):
    """Reduce raw model output to a snake_case label.

    Only the first non-blank line is used, because the model may keep
    talking after the label. Surrounding quotes, markdown emphasis and
    punctuation are stripped and the text is lower-cased. Runs of spaces or
    hyphens become single underscores, e.g. ``'"Choice of law."'`` ->
    ``"choice_of_law"``.
    """
    first_line = next((line for line in text.splitlines() if line.strip()), "")
    cleaned = first_line.strip().strip("\"'`*.,:;!? ").lower()
    return "_".join(cleaned.replace("-", " ").split())


def _classify_sentence(sentence):
    """Ask the model for a single label for *sentence* and return it normalised.

    The candidate labels are the keys of the module-level ``label_to_color``
    map. Only the newly generated tokens are decoded, so the prompt is not
    echoed back into the label.
    """
    system_prompt = (
        "You label individual sentences from legal terms. Reply with exactly one "
        "label from this list and nothing else: " + ", ".join(label_to_color) + "."
    )
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sentence},
    ]
    inputs = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=20,  # a label is only a few tokens; the rest is discarded anyway
            do_sample=False,  # greedy, so the same sentence always gets the same label
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt: decoding output_ids[0] whole would include the input text.
    new_tokens = output_ids[0, prompt_length:]
    return _normalize_label(tokenizer.decode(new_tokens, skip_special_tokens=True))


def process_paragraph(paragraph, progress=gr.Progress()):
    """Split *paragraph* into sentences and label each one with the model.

    Args:
        paragraph: Free text entered by the user.
        progress: Gradio progress tracker, advanced once per sentence.

    Returns:
        A list of ``(sentence, label)`` tuples in input order, the value fed to
        the ``gr.HighlightedText`` output.
    """
    sentences = split_into_sentences(paragraph)
    total = len(sentences)
    results = []
    try:
        for done, sentence in enumerate(sentences, start=1):
            progress(done / total)
            results.append((sentence, _classify_sentence(sentence)))
    finally:
        # Release cached GPU memory once per request, even if generation fails.
        torch.cuda.empty_cache()
    return results


# Load model and tokenizer
model_name = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" already places the weights, possibly across several
# devices. A follow-up model.to(...) is redundant: accelerate warns against
# moving a dispatched model, and the move fails if layers were offloaded.
# NOTE(review): float16 assumes a GPU host; on CPU-only hardware bfloat16 or
# float32 may be needed. This Space appears to run on ZeroGPU, which
# presumably also needs its GPU decorator on the inference function.
# TODO confirm.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
# Inputs must go to the device holding the model's (first) parameters.
device = model.device
# NOTE(review): nothing visible in this file reads this list; if it is
# appended to per request it grows without bound.
messages = []

# Define label to color mapping
# Maps each classification label to the colour gr.HighlightedText uses for it.
label_to_color = {
    "fair": "green",
    "limitation_of_liability": "red",
    "unilateral_termination": "orange",
    "unilateral_change": "yellow",
    "content_removal": "purple",
    "contract_by_using": "blue",
    "choice_of_law": "cyan",
    "jurisdiction": "magenta",
    "arbitration": "brown",
}


# Gradio Interface
with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
            btn = gr.Button("Process")
        with gr.Column():
            output = gr.HighlightedText(label="Processed Paragraph", color_map=label_to_color)

    def on_click(paragraph, progress=gr.Progress()):
        """Label *paragraph* sentence by sentence for the HighlightedText output.

        Gradio only supplies a live progress tracker when ``gr.Progress()`` is
        the default of the event handler's own parameter. A ``gr.Progress()``
        created while building the layout is not tracked, so it lives here.
        """
        return process_paragraph(paragraph, progress=progress)

    btn.click(on_click, inputs=input_text, outputs=[output])

# NOTE(review): share links are not created when running on Hugging Face Spaces;
# share=True only matters when run locally.
demo.launch(share=True)
|