StevenChen16 committed on
Commit
f9b65c3
·
verified ·
1 Parent(s): 372f45f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -43
app.py CHANGED
@@ -1,58 +1,44 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
- from peft import PeftModel
4
- from threading import Thread
5
  import re
6
- import torch
7
- import spaces
8
 
9
- # 分割段落为句子
10
  def split_into_sentences(text):
11
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
12
  sentences = sentence_endings.split(text)
13
  return [sentence.strip() for sentence in sentences if sentence]
14
 
15
- # 处理段落
16
- @spaces.GPU(duration=120)
17
  def process_paragraph(paragraph, progress=gr.Progress()):
18
- print("Process_Paragraph Function has been called")
19
  sentences = split_into_sentences(paragraph)
20
  results = []
21
  total_sentences = len(sentences)
22
- print("sentences: ", sentences)
23
-
24
  for i, sentence in enumerate(sentences):
25
  progress((i + 1) / total_sentences)
26
-
27
- input_ids = tokenizer.encode(sentence, return_tensors='pt').to(device)
28
- output = model.generate(input_ids,
29
- max_new_tokens=50,
30
- eos_token_id=terminators,
31
- temperature=0.9,
32
- do_sample=True,
33
- )
34
- sentence_response = tokenizer.decode(output[0], skip_special_tokens=True).strip()
35
-
36
- category = sentence_response.lower().replace(' ', '_')
37
- print("Single Sentence: ", sentence)
38
- print("category: ", category)
39
  if category != "fair":
40
  results.append((sentence, category))
41
  else:
42
  results.append((sentence, "fair"))
43
-
 
44
  return results
45
 
46
- # 模型和分词器
47
- MODEL_NAME_OR_PATH = "StevenChen16/llama3-8b-compliance-review"
48
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
49
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_OR_PATH)
50
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
- terminators = [
52
- tokenizer.eos_token_id,
53
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
54
- ]
55
- model.to(device)
 
 
56
 
57
  # Define the mapping from label type to color
58
  label_to_color = {
@@ -67,14 +53,7 @@ label_to_color = {
67
  "arbitration": "brown",
68
  }
69
 
70
- css = """
71
- <style>
72
- .gradio-container { height: auto; max-height: 500px; overflow-y: scroll; }
73
- </style>
74
- """
75
-
76
  with gr.Blocks() as demo:
77
- gr.Markdown(css)
78
 
79
  with gr.Row(equal_height=True):
80
  with gr.Column():
@@ -90,4 +69,4 @@ with gr.Blocks() as demo:
90
 
91
  btn.click(on_click, inputs=input_text, outputs=[output])
92
 
93
- demo.launch()
 
1
  import gradio as gr
2
+ from llamafactory.chat import ChatModel
3
+ from llamafactory.extras.misc import torch_gc
 
4
  import re
 
 
5
 
 
6
  def split_into_sentences(text):
7
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
8
  sentences = sentence_endings.split(text)
9
  return [sentence.strip() for sentence in sentences if sentence]
10
 
 
 
11
  def process_paragraph(paragraph, progress=gr.Progress()):
 
12
  sentences = split_into_sentences(paragraph)
13
  results = []
14
  total_sentences = len(sentences)
 
 
15
  for i, sentence in enumerate(sentences):
16
  progress((i + 1) / total_sentences)
17
+ messages.append({"role": "user", "content": sentence})
18
+ sentence_response = ""
19
+ for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
20
+ sentence_response += new_text.strip()
21
+ category = sentence_response.strip().lower().replace(' ', '_')
 
 
 
 
 
 
 
 
22
  if category != "fair":
23
  results.append((sentence, category))
24
  else:
25
  results.append((sentence, "fair"))
26
+ messages.append({"role": "assistant", "content": sentence_response})
27
+ torch_gc()
28
  return results
29
 
30
import os

# The original line here was the IPython magic `%cd`, which is a syntax error
# in a plain Python script. Change directory with the standard library
# instead, and only when the checkout exists on this machine.
_LLAMA_FACTORY_DIR = "/root/autodl-tmp/LLaMA-Factory/"
if os.path.isdir(_LLAMA_FACTORY_DIR):
    os.chdir(_LLAMA_FACTORY_DIR)

args = dict(
    model_name_or_path="StevenChen16/llama3-8b-compliance-review",  # model to load
    # adapter_name_or_path="llama3_cr_sft_5",  # load the saved LoRA adapter
    template="llama3",  # same template as used for training
    finetuning_type="lora",  # same fine-tuning type as used for training
    quantization_bit=8,  # load the model with 8-bit quantization
    use_unsloth=True,  # use UnslothAI's LoRA optimizations to speed up generation
)
chat_model = ChatModel(args)
# Module-level chat history list.
messages = []
42
 
43
  # Define the mapping from label type to color
44
  label_to_color = {
 
53
  "arbitration": "brown",
54
  }
55
 
 
 
 
 
 
 
56
  with gr.Blocks() as demo:
 
57
 
58
  with gr.Row(equal_height=True):
59
  with gr.Column():
 
69
 
70
  btn.click(on_click, inputs=input_text, outputs=[output])
71
 
72
# Start the Gradio app; share=True additionally requests a public share link.
demo.launch(share=True)