StevenChen16 committed
Commit
cba7e97
1 Parent(s): bd5c82b

Update app.py

Files changed (1)
  1. app.py +34 -75
app.py CHANGED
@@ -1,8 +1,8 @@
 import gradio as gr
 import os
-import spaces
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from threading import Thread
+from llamafactory.chat import ChatModel
+from llamafactory.extras.misc import torch_gc

 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -41,98 +41,57 @@ h1 {
 }
 """

-# Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("StevenChen16/llama3-8b-Lawyer")
-model = AutoModelForCausalLM.from_pretrained("StevenChen16/llama3-8b-Lawyer", device_map="auto")
-
-terminators = [
-    tokenizer.eos_token_id,
-    tokenizer.convert_tokens_to_ids("<|eot_id|>")
-]
-
-@spaces.GPU(duration=120)
-def chat_llama3_8b(message: str,
-                   history: list,
-                   temperature: float,
-                   max_new_tokens: int
-                   ) -> str:
-    """
-    Generate a streaming response using the llama3-8b model.
-    Args:
-        message (str): The input message.
-        history (list): The conversation history used by ChatInterface.
-        temperature (float): The temperature for generating the response.
-        max_new_tokens (int): The maximum number of new tokens to generate.
-    Returns:
-        str: The generated response.
-    """
-    conversation = []
-    for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-    if temperature == 0:
-        generate_kwargs['do_sample'] = False
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
+args = dict(
+    model_name_or_path="StevenChen16/llama3-8b-Lawyer",
+    template="llama3",
+    finetuning_type="lora",
+    quantization_bit=8,
+    use_unsloth=True,
+)
+chat_model = ChatModel(args)
+
+background_prompt = """
+You are an advanced AI legal assistant trained to assist with a wide range of legal questions and issues. Your primary function is to provide accurate, comprehensive, and professional legal information based on U.S. and Canada law. Follow these guidelines when formulating responses:
+
+1. **Clarity and Precision**: Ensure that your responses are clear and precise. Use professional legal terminology, but explain complex legal concepts in a way that is understandable to individuals without a legal background.
+2. **Comprehensive Coverage**: Provide thorough answers that cover all relevant aspects of the question. Include explanations of legal principles, relevant statutes, case law, and their implications.
+3. **Contextual Relevance**: Tailor your responses to the specific context of the question asked. Provide examples or analogies where appropriate to illustrate legal concepts.
+4. **Statutory and Case Law References**: When mentioning statutes, include their significance and application. When citing case law, summarize the facts, legal issues, court decisions, and their broader implications.
+5. **Professional Tone**: Maintain a professional and respectful tone in all responses. Ensure that your advice is legally sound and adheres to ethical standards.
+"""
+
+def query_model(user_input, history, temperature, max_new_tokens):
+    combined_query = background_prompt + user_input
+    messages = [{"role": "user", "content": combined_query}]
+
+    response = ""
+    for new_text in chat_model.stream_chat(messages):
+        response += new_text
+        yield response

 # Gradio block
 chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

-with gr.Blocks(fill_height=True, css=css) as demo:
-
+with gr.Blocks(css=css) as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
     gr.ChatInterface(
-        fn=chat_llama3_8b,
+        fn=query_model,
         chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Slider(minimum=0,
-                      maximum=1,
-                      step=0.1,
-                      value=0.95,
-                      label="Temperature",
-                      render=False),
-            gr.Slider(minimum=128,
-                      maximum=4096,
-                      step=1,
-                      value=512,
-                      label="Max new tokens",
-                      render=False),
-        ],
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.95, label="Temperature"),
+            gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens"),
+        ],
         examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['What is 9,000 * 9,000?'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
-        ],
+        ],
         cache_examples=False,
-    )
-
+    )
     gr.Markdown(LICENSE)
-
+
 if __name__ == "__main__":
     demo.launch()
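
For readers who want to try the new inference path outside the Space, here is a minimal sketch of the LLaMA-Factory API this commit switches to. The ChatModel arguments are copied from the diff above; the sample question and the generation keyword arguments passed to stream_chat (temperature, max_new_tokens) are illustrative assumptions, not something this commit pins down.

# Minimal sketch of the llamafactory.chat.ChatModel streaming API, as used in app.py.
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc

# Same arguments as in the diff: LoRA-finetuned Llama-3 8B, 8-bit quantized, unsloth-accelerated.
chat_model = ChatModel(dict(
    model_name_or_path="StevenChen16/llama3-8b-Lawyer",
    template="llama3",
    finetuning_type="lora",
    quantization_bit=8,
    use_unsloth=True,
))

# Hypothetical sample query, for illustration only.
messages = [{"role": "user", "content": "What is consideration in contract law?"}]

# stream_chat yields incremental text chunks. Passing temperature/max_new_tokens
# as keyword arguments here is an assumption about LLaMA-Factory's generation
# kwargs; the committed query_model does not forward its slider values.
for chunk in chat_model.stream_chat(messages, temperature=0.95, max_new_tokens=512):
    print(chunk, end="", flush=True)
print()

torch_gc()  # release cached GPU memory once generation is done

Because stream_chat yields text incrementally, query_model can accumulate the chunks and re-yield the growing string, which is the streaming contract Gradio's ChatInterface expects from its fn.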