File size: 2,631 Bytes
17b23e6
79c5347
 
17b23e6
ab96ee1
17b23e6
 
 
 
 
94f56db
17b23e6
5f91d0f
17b23e6
 
5f91d0f
94f56db
 
79c5347
 
 
 
5f91d0f
 
 
 
 
94f56db
79c5347
8f48aeb
f9b65c3
79c5347
 
 
e0d43a9
79c5347
 
94f56db
 
17b23e6
79c5347
17b23e6
 
 
 
 
 
 
 
 
 
 
 
79c5347
17b23e6
5f91d0f
17b23e6
 
 
 
 
 
 
 
 
94f56db
17b23e6
 
 
 
fa6db21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import re

# Matches one whitespace character that directly follows ".", "?" or "!",
# except when the preceding text looks like a dotted abbreviation ("e.g.")
# or a capitalised two-letter title ("Dr."). Compiled once at import time.
_SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')


def split_into_sentences(text):
    """Split *text* into sentences at whitespace following ``.``, ``?`` or ``!``.

    Args:
        text: The paragraph to split.

    Returns:
        A list of sentences with surrounding whitespace removed. Fragments
        that are empty or contain only whitespace are dropped, so the result
        never contains an empty string.
    """
    # Strip first, then filter: a fragment such as "\n" is truthy before
    # stripping but empty afterwards.
    stripped = (piece.strip() for piece in _SENTENCE_BOUNDARY.split(text))
    return [sentence for sentence in stripped if sentence]

def process_paragraph(paragraph, progress=gr.Progress()):
    """Label every sentence of *paragraph* with the model's generated category.

    The paragraph is split with ``split_into_sentences``; each sentence is
    tokenized, passed to ``model.generate`` and the generated text is
    normalised (stripped, lower-cased, spaces replaced by underscores) to
    form the label.

    Args:
        paragraph: Text to analyse.
        progress: Callable invoked with the completed fraction after each
            sentence is started; defaults to a ``gr.Progress`` instance.

    Returns:
        A list of ``(sentence, category)`` tuples in input order. Empty when
        the paragraph contains no sentences.
    """
    sentences = split_into_sentences(paragraph)
    total_sentences = len(sentences)
    results = []
    for i, sentence in enumerate(sentences):
        progress((i + 1) / total_sentences)
        # NOTE(review): `messages` is only appended to, never read here, and
        # grows for the lifetime of the process — confirm it is still needed.
        messages.append({"role": "user", "content": sentence})
        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, top_k=50)
        # For a causal LM, generate() returns the prompt tokens followed by
        # the new tokens. Decode only the continuation so the label is not
        # prefixed with the input sentence.
        prompt_length = inputs["input_ids"].shape[-1]
        sentence_response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
        category = sentence_response.strip().lower().replace(' ', '_')
        results.append((sentence, category))
        messages.append({"role": "assistant", "content": sentence_response})
        # Release cached GPU blocks between sentences to limit peak memory.
        torch.cuda.empty_cache()
    return results

# Load model and tokenizer
model_name = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" already places the weights on the available device(s),
# so the model must not be moved again with model.to(...).
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
# Inputs are sent to the device the model was actually placed on.
device = model.device

# Log of the sentences sent to the model and the text it generated.
messages = []

# Highlight color for each label, passed as the color map of the
# HighlightedText output component.
label_to_color = dict(
    [
        ("fair", "green"),
        ("limitation_of_liability", "red"),
        ("unilateral_termination", "orange"),
        ("unilateral_change", "yellow"),
        ("content_removal", "purple"),
        ("contract_by_using", "blue"),
        ("choice_of_law", "cyan"),
        ("jurisdiction", "magenta"),
        ("arbitration", "brown"),
    ]
)

# Gradio Interface
with gr.Blocks() as demo:

    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
            btn = gr.Button("Process")
        with gr.Column():
            output = gr.HighlightedText(label="Processed Paragraph", color_map=label_to_color)

    def on_click(paragraph, progress=gr.Progress()):
        """Run the paragraph through the classifier and return the labelled sentences.

        `progress` is declared as a default-valued parameter because that is
        how Gradio attaches a progress tracker to an event handler; an
        instance created in the layout is not connected to any event.
        """
        return process_paragraph(paragraph, progress=progress)

    btn.click(on_click, inputs=input_text, outputs=[output])

demo.launch(share=True)