ShravanHN committed
Commit: 4efef34
Parent(s): 557609c

added chunks if tokens are more
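In short, the commit splits an oversized prompt into word-based chunks (up to 4,000 words each via the new chunk_text helper), generates a response per chunk, and joins the per-chunk responses. A minimal illustration of that flow (a sketch based on the helpers added in the diff below, not code from the commit itself; the 9,000-word input is a made-up example):

# Sketch of the chunking flow this commit introduces (illustration only):
# split a long message into <=4000-word pieces, generate per piece, join results.
def chunk_text(text, chunk_size=4000):
    words = text.split()
    return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

long_message = "word " * 9000                 # hypothetical oversized input
pieces = chunk_text(long_message)
print(len(pieces))                            # -> 3 (4000 + 4000 + 1000 words)
per_chunk_responses = ['{"a": 1}', '{"b": 2}', '{"c": 3}']  # stand-ins for per-chunk model output
print(" ".join(per_chunk_responses))          # combine_responses() simply joins them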

Files changed (1)
  1. app.py +88 -25
app.py CHANGED
@@ -1,14 +1,18 @@
-import spaces
 import gradio as gr
 import os
+import time
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import torch
 from threading import Thread
+import logging
+import spaces
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-
 DESCRIPTION = '''
 <div>
 <h1 style="text-align: center;">ContenteaseAI custom trained model</h1>
@@ -17,7 +21,6 @@ DESCRIPTION = '''
 
 LICENSE = """
 <p/>
-
 ---
 For more information, visit our [website](https://contentease.ai).
 """
@@ -29,14 +32,13 @@ PLACEHOLDER = """
 </div>
 """
 
-
 css = """
 h1 {
 text-align: center;
 display: block;
 }
-
 """
+
 # Load the tokenizer and model with quantization
 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 bnb_config = BitsAndBytesConfig(
@@ -46,14 +48,21 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    quantization_config=bnb_config,
-    torch_dtype=torch.bfloat16
-)
-model.generation_config.pad_token_id = tokenizer.pad_token_id
+try:
+    logger.info("Loading tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    logger.info("Loading model...")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        quantization_config=bnb_config,
+        torch_dtype=torch.bfloat16
+    )
+    model.generation_config.pad_token_id = tokenizer.pad_token_id
+    logger.info("Model and tokenizer loaded successfully.")
+except Exception as e:
+    logger.error(f"Error loading model or tokenizer: {e}")
+    raise
 
 terminators = [
     tokenizer.eos_token_id,
@@ -67,25 +76,41 @@ Bad JSON example: {'lobby': { 'frcm': { 'replace': [ 'carpet', 'carpet_pad', 'ba
 Make sure to fetch details from the provided text and ignore unnecessary information. The response should be in JSON format only, without any additional comments.
 """
 
-@spaces.GPU(duration=120)
-def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
+def chunk_text(text, chunk_size=4000):
     """
-    Generate a streaming response using the llama3-8b model.
+    Splits the input text into chunks of specified size.
 
     Args:
-        message (str): The input message.
-        history (list): The conversation history used by ChatInterface.
-        temperature (float): The temperature for generating the response.
-        max_new_tokens (int): The maximum number of new tokens to generate.
+        text (str): The input text to be chunked.
+        chunk_size (int): The size of each chunk in tokens.
 
     Returns:
-        str: The generated response.
+        list: A list of text chunks.
     """
-    conversation = [{"role": "system", "content": SYS_PROMPT}]
+    words = text.split()
+    chunks = [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
+    return chunks
+
+def combine_responses(responses):
+    """
+    Combines the responses from all chunks into a final output string.
+
+    Args:
+        responses (list): A list of responses from each chunk.
+
+    Returns:
+        str: The combined output string.
+    """
+    combined_output = " ".join(responses)
+    return combined_output
+
+def generate_response_for_chunk(chunk, history, temperature, max_new_tokens):
+    start_time = time.time()
 
+    conversation = [{"role": "system", "content": SYS_PROMPT}]
     for user, assistant in history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
+    conversation.append({"role": "user", "content": chunk})
 
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
 
@@ -109,8 +134,43 @@ def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
     outputs = []
     for text in streamer:
         outputs.append(text)
-        yield "".join(outputs)
+
+    end_time = time.time()
+    logger.info(f"Time taken for generating response for a chunk: {end_time - start_time} seconds")
+
+    return "".join(outputs)
+
+@spaces.GPU(duration=120)
+def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
+    """
+    Generate a streaming response using the llama3-8b model with chunking.
+
+    Args:
+        message (str): The input message.
+        history (list): The conversation history used by ChatInterface.
+        temperature (float): The temperature for generating the response.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+
+    Returns:
+        str: The generated response.
+    """
+    try:
+        start_time = time.time()
+
+        chunks = chunk_text(message)
+        responses = []
+        for chunk in chunks:
+            response = generate_response_for_chunk(chunk, history, temperature, max_new_tokens)
+            responses.append(response)
+        final_output = combine_responses(responses)
+
+        end_time = time.time()
+        logger.info(f"Total time taken for generating response: {end_time - start_time} seconds")
 
+        yield final_output
+    except Exception as e:
+        logger.error(f"Error generating response: {e}")
+        yield "An error occurred while generating the response. Please try again."
 
 # Gradio block
 chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
@@ -132,4 +192,7 @@ with gr.Blocks(fill_height=True, css=css) as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.launch(show_error=True, debug=True)
+    try:
+        demo.launch(show_error=True)
+    except Exception as e:
+        logger.error(f"Error launching Gradio demo: {e}")