Leri777 committed
Commit 2f7b0a4
1 parent: 4aae838

Update app.py

Files changed (1): app.py (+98 -54)

app.py CHANGED
@@ -1,81 +1,123 @@
 import os
 from threading import Thread
 from typing import Iterator
-
+import logging
+from logging.handlers import RotatingFileHandler
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, pipeline
+from langchain_huggingface import HuggingFacePipeline
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain
+
+# Logging setup
+log_file = '/tmp/app_debug.log'
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
+file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
+logger.addHandler(file_handler)
 
-DESCRIPTION = """\
-# Gemma 2 9B IT
+logger.debug("Application started")
+
+DESCRIPTION = """
+# Gemma 2 9B IT with LangChain Integration
 
 Gemma 2 is Google's latest iteration of open LLMs.
 This is a demo of [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it), fine-tuned for instruction following.
-For more details, please check [our post](https://huggingface.co/blog/gemma2).
-
-👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it).
+Now integrated with LangChain for enhanced interaction capabilities.
 """
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
 model_id = "google/gemma-2-9b-it"
 tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.config.sliding_window = 4096
-model.eval()
-
-
-@spaces.GPU(duration=90)
-def generate(
-    message: str,
-    chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = chat_history.copy()
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+
+# Load model with GPU availability check
+if torch.cuda.is_available():
+    logger.debug("GPU is available. Proceeding with GPU setup.")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    )
+else:
+    logger.warning("GPU is not available. Proceeding with CPU setup.")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        low_cpu_mem_usage=True,
+    )
+
+model.config.sliding_window = 4096
+model.eval()
+
+# Create Hugging Face pipeline
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_length=MAX_MAX_NEW_TOKENS,
+    temperature=0.7,
+    top_k=50,
+    top_p=0.9,
+    repetition_penalty=1.2,
+)
+
+# Initialize HuggingFacePipeline model for LangChain
+chat_model = HuggingFacePipeline(pipeline=pipe)
+logger.debug("Model and tokenizer loaded successfully")
+
+# Define the conversation template for LangChain
+template = """<|im_start|>system
+{system_prompt}
+<|im_end|>
+{history}
+<|im_start|>user
+{human_input}
+<|im_end|>
+<|im_start|>assistant"""
+
+# Create LangChain prompt and chain
+prompt = PromptTemplate(
+    template=template, input_variables=["system_prompt", "history", "human_input"]
+)
+chain = LLMChain(llm=chat_model, prompt=prompt)
+
+# Prediction function using LangChain and model
+def predict(
+    message,
+    chat_history,
+    max_new_tokens,
+    temperature,
+    top_p,
+    top_k,
+    repetition_penalty,
+):
+    formatted_history = "\n".join(
+        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
+    )
+    system_prompt = "You are a helpful coding assistant."
+
+    try:
+        result = chain.run(
+            {
+                "system_prompt": system_prompt,
+                "history": formatted_history,
+                "human_input": message,
+            }
+        )
+        return result
+    except Exception as e:
+        logger.exception(f"Error during prediction: {e}")
+        return "An error occurred."
+
+# Gradio UI
 chat_interface = gr.ChatInterface(
-    fn=generate,
+    fn=predict,
     additional_inputs=[
         gr.Slider(
             label="Max new tokens",
@@ -132,3 +174,5 @@ with gr.Blocks(css="style.css", fill_height=True) as demo:
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
+
+logger.debug("Chat interface initialized and launched")