StevenChen16 commited on
Commit
a9547b0
1 Parent(s): 6226cf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -19
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import gradio as gr
2
  import os
3
  from threading import Thread
4
- from llamafactory.chat import ChatModel
5
- from llamafactory.extras.misc import torch_gc
6
-
7
 
8
  DESCRIPTION = '''
9
  <div>
@@ -37,14 +35,9 @@ h1 {
37
  }
38
  """
39
 
40
- args = dict(
41
- model_name_or_path="StevenChen16/llama3-8b-Lawyer",
42
- template="llama3",
43
- finetuning_type="lora",
44
- quantization_bit=8,
45
- use_unsloth=True,
46
- )
47
- chat_model = ChatModel(args)
48
 
49
  background_prompt = """
50
  As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law. Your purpose is to provide accurate, comprehensive, and professional legal information to assist users with a wide range of legal questions. When answering questions, you should actively ask questions to obtain more information, analyze from different perspectives, and explain your reasoning process to the user. Please adhere to the following guidelines:
@@ -81,15 +74,44 @@ As an AI legal assistant, you are a highly trained expert in U.S. and Canadian l
81
  Please remember that your role is to provide general legal information and analysis, but also to actively guide and interact with the user during the conversation in a personalized and professional manner. If you feel that necessary information is missing to provide targeted analysis and advice, take the initiative to ask until you believe you have sufficient details. However, also be mindful to avoid over-inquiring or disregarding the user's needs and concerns. Now, please guide me step by step to describe the legal issues I am facing, according to the above requirements.
82
  """
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def query_model(user_input, history):
85
  combined_query = background_prompt + user_input
86
- messages = [{"role": "user", "content": combined_query}]
87
-
88
- response = ""
89
- for new_text in chat_model.stream_chat(messages, temperature=0.9):
90
- response += new_text
91
- yield response
92
-
93
  # Gradio block
94
  chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
95
 
@@ -110,4 +132,4 @@ with gr.Blocks(css=css) as demo:
110
  gr.Markdown(LICENSE)
111
 
112
  if __name__ == "__main__":
113
- demo.launch()
 
1
  import gradio as gr
2
  import os
3
  from threading import Thread
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
5
 
6
  DESCRIPTION = '''
7
  <div>
 
35
  }
36
  """
37
 
38
# Load the tokenizer and model once at import time.
# device_map="auto" lets transformers/Accelerate place the weights on the
# available device(s) — replaces the earlier explicit .to("cuda:0").
tokenizer = AutoTokenizer.from_pretrained("StevenChen16/llama3-8b-Lawyer")
model = AutoModelForCausalLM.from_pretrained("StevenChen16/llama3-8b-Lawyer", device_map="auto")
 
 
 
 
 
41
 
42
  background_prompt = """
43
  As an AI legal assistant, you are a highly trained expert in U.S. and Canadian law. Your purpose is to provide accurate, comprehensive, and professional legal information to assist users with a wide range of legal questions. When answering questions, you should actively ask questions to obtain more information, analyze from different perspectives, and explain your reasoning process to the user. Please adhere to the following guidelines:
 
74
  Please remember that your role is to provide general legal information and analysis, but also to actively guide and interact with the user during the conversation in a personalized and professional manner. If you feel that necessary information is missing to provide targeted analysis and advice, take the initiative to ask until you believe you have sufficient details. However, also be mindful to avoid over-inquiring or disregarding the user's needs and concerns. Now, please guide me step by step to describe the legal issues I am facing, according to the above requirements.
75
  """
76
 
77
# Token ids that stop generation. Llama-3's chat template ends each turn
# with the special "<|eot_id|>" token, so it must be a terminator alongside
# the tokenizer's generic EOS id.
# NOTE(review): the original had convert_tokens_to_ids("") — an empty string,
# which resolves to the unknown-token id and would never stop generation.
# Presumably "<|eot_id|>" was stripped by markup extraction; confirm against
# the deployed app.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
82
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
    """Stream a chat completion from the local Llama-3 model.

    Args:
        message: The latest user message.
        history: Prior turns as (user_text, assistant_text) pairs.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The accumulated response text after each newly streamed fragment,
        so Gradio can render partial output live.
    """
    # Rebuild the full conversation in the chat-template message format.
    conversation = []
    for user_turn, assistant_turn in history:
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": assistant_turn})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header so the model
    # starts a fresh reply instead of continuing the user's turn (the
    # original omitted it, which degrades chat-template models).
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # skip_prompt avoids echoing the input back into the output stream.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    if temperature == 0:
        # temperature=0 is invalid for sampling; fall back to greedy decoding.
        generate_kwargs["do_sample"] = False

    # Run generation on a worker thread so this generator can consume the
    # streamer without blocking.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
110
+
111
def query_model(user_input, history):
    """Prepend the legal-assistant background prompt and stream the reply.

    Args:
        user_input: The user's latest message.
        history: Prior (user, assistant) turns supplied by Gradio.

    Yields:
        Progressively longer response text for live chatbot display.
    """
    combined_query = background_prompt + user_input
    # `yield from` (rather than `return`) makes this a generator function;
    # Gradio only streams partial output from generator functions, so the
    # original `return chat_llama3_8b(...)` would hand Gradio an opaque
    # generator object instead of streamed text.
    yield from chat_llama3_8b(combined_query, history, temperature=0.9, max_new_tokens=512)
 
 
 
 
 
115
  # Gradio block
116
  chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
117
 
 
132
  gr.Markdown(LICENSE)
133
 
134
  if __name__ == "__main__":
135
+ demo.launch(share=True)