StevenChen16 committed on
Commit
5f91d0f
1 Parent(s): ab96ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -40
app.py CHANGED
@@ -5,8 +5,6 @@ subprocess.run(["git", "clone", "https://github.com/hiyouga/LLaMA-Factory.git"],
5
  # 切换到仓库目录
6
  import os
7
  os.chdir("LLaMA-Factory")
8
- # 列出目录内容
9
- subprocess.run(["ls"], check=True)
10
  # 安装unsloth
11
  subprocess.run(["pip", "install", "unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git"], check=True)
12
  # 安装xformers
@@ -20,56 +18,41 @@ import gradio as gr
20
  from llamafactory.chat import ChatModel
21
  from llamafactory.extras.misc import torch_gc
22
  import re
23
- import spaces
24
- from threading import Thread
25
-
26
-
27
 
28
  def split_into_sentences(text):
29
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
30
  sentences = sentence_endings.split(text)
31
  return [sentence.strip() for sentence in sentences if sentence]
32
 
33
- @spaces.GPU(duration=120)
34
- def process_sentence(sentence, index, results, messages, progress, total_sentences):
35
- messages.append({"role": "user", "content": sentence})
36
- sentence_response = ""
37
- for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
38
- sentence_response += new_text.strip()
39
- category = sentence_response.strip().lower().replace(' ', '_')
40
- if category != "fair":
41
- results[index] = (sentence, category)
42
- else:
43
- results[index] = (sentence, "fair")
44
- messages.append({"role": "assistant", "content": sentence_response})
45
- torch_gc()
46
- progress((index + 1) / total_sentences)
47
-
48
- @spaces.GPU(duration=120)
49
  def process_paragraph(paragraph, progress=gr.Progress()):
50
  sentences = split_into_sentences(paragraph)
51
- results = [None] * len(sentences)
52
  total_sentences = len(sentences)
53
- threads = []
54
-
55
  for i, sentence in enumerate(sentences):
56
- thread = Thread(target=process_sentence, args=(sentence, i, results, messages.copy(), progress, total_sentences))
57
- threads.append(thread)
58
- thread.start()
59
-
60
- for thread in threads:
61
- thread.join()
62
-
 
 
 
 
 
63
  return results
64
 
 
 
65
  args = dict(
66
- model_name_or_path="princeton-nlp/Llama-3-Instruct-8B-SimPO", # 使用量化的 Llama-3-8B-Instruct 模型
67
- # model_name_or_path="StevenChen16/llama3-8b-compliance-review",
68
- # adapter_name_or_path="StevenChen16/llama3-8b-compliance-review-adapter", # 加载保存的 LoRA 适配器
69
- template="llama3", # 与训练时使用的模板相同
70
- finetuning_type="lora", # 与训练时使用的微调类型相同
71
- quantization_bit=8, # 加载 8-bit 量化模型
72
- use_unsloth=True, # 使用 UnslothAI 的 LoRA 优化以加速生成
73
  )
74
  chat_model = ChatModel(args)
75
  messages = []
@@ -88,6 +71,7 @@ label_to_color = {
88
  }
89
 
90
  with gr.Blocks() as demo:
 
91
  with gr.Row(equal_height=True):
92
  with gr.Column():
93
  input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
@@ -102,4 +86,4 @@ with gr.Blocks() as demo:
102
 
103
  btn.click(on_click, inputs=input_text, outputs=[output])
104
 
105
- demo.launch(share=True)
 
5
  # 切换到仓库目录
6
  import os
7
  os.chdir("LLaMA-Factory")
 
 
8
  # 安装unsloth
9
  subprocess.run(["pip", "install", "unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git"], check=True)
10
  # 安装xformers
 
18
  from llamafactory.chat import ChatModel
19
  from llamafactory.extras.misc import torch_gc
20
  import re
 
 
 
 
21
 
22
  def split_into_sentences(text):
23
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
24
  sentences = sentence_endings.split(text)
25
  return [sentence.strip() for sentence in sentences if sentence]
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def process_paragraph(paragraph, progress=gr.Progress()):
28
  sentences = split_into_sentences(paragraph)
29
+ results = []
30
  total_sentences = len(sentences)
 
 
31
  for i, sentence in enumerate(sentences):
32
+ progress((i + 1) / total_sentences)
33
+ messages.append({"role": "user", "content": sentence})
34
+ sentence_response = ""
35
+ for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
36
+ sentence_response += new_text.strip()
37
+ category = sentence_response.strip().lower().replace(' ', '_')
38
+ if category != "fair":
39
+ results.append((sentence, category))
40
+ else:
41
+ results.append((sentence, "fair"))
42
+ messages.append({"role": "assistant", "content": sentence_response})
43
+ torch_gc()
44
  return results
45
 
46
# `%cd` is an IPython magic and a syntax error in a plain Python script.
# Change directory with os.chdir instead, and only when the notebook-era
# path actually exists on this machine.
import os

_LLAMA_FACTORY_DIR = "/root/autodl-tmp/LLaMA-Factory/"
if os.path.isdir(_LLAMA_FACTORY_DIR):
    os.chdir(_LLAMA_FACTORY_DIR)

args = dict(
    model_name_or_path="princeton-nlp/Llama-3-Instruct-8B-SimPO",  # base Llama-3-8B-Instruct model, loaded quantized
    # Alternative base model:
    # model_name_or_path="StevenChen16/llama3-8b-compliance-review",
    adapter_name_or_path="StevenChen16/llama3-8b-compliance-review-adapter",  # load the saved LoRA adapter
    template="llama3",  # same template that was used for training
    finetuning_type="lora",  # same fine-tuning type that was used for training
    quantization_bit=8,  # load the model with 8-bit quantization
    use_unsloth=True,  # use UnslothAI's LoRA optimization to speed up generation
)
chat_model = ChatModel(args)
# Base conversation history handed to the chat model (starts empty).
messages = []
 
71
  }
72
 
73
  with gr.Blocks() as demo:
74
+
75
  with gr.Row(equal_height=True):
76
  with gr.Column():
77
  input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
 
86
 
87
  btn.click(on_click, inputs=input_text, outputs=[output])
88
 
89
+ demo.launch(share=True)