Leri777 committed on
Commit b78e9ba · verified · 1 Parent(s): 4df36c7

Update app.py

Files changed (1)
  1. app.py +11 -11
app.py CHANGED
@@ -3,7 +3,7 @@ import logging
 from logging.handlers import RotatingFileHandler
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
 from langchain_huggingface import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
@@ -19,22 +19,23 @@ logger.addHandler(file_handler)
 logger.debug("Application started")
 
 model_id = "google/gemma-2-9b-it"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
 
 # Load model with GPU availability check
 if torch.cuda.is_available():
     logger.debug("GPU is available. Proceeding with GPU setup.")
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto",
+        device_map="auto",
         torch_dtype=torch.bfloat16,
     )
 else:
     logger.warning("GPU is not available. Proceeding with CPU setup.")
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
+        device_map="cpu",
         low_cpu_mem_usage=True,
-        use_auth_token=os.getenv('HF_TOKEN'),
+        token=os.getenv('HF_TOKEN'),
     )
 
 model.eval()
@@ -54,8 +55,6 @@ pipe = pipeline(
 # Initialize HuggingFacePipeline model for LangChain
 chat_model = HuggingFacePipeline(pipeline=pipe)
 
-logger.debug("Model and tokenizer loaded successfully")
-
 # Define the conversation template for LangChain
 template = """<|im_start|>system
 {system_prompt}
@@ -70,12 +69,12 @@ template = """<|im_start|>system
 prompt = PromptTemplate(
     template=template, input_variables=["system_prompt", "history", "human_input"]
 )
-chain = LLMChain(llm=chat_model, prompt=prompt)
+chain = prompt | chat_model
 
 # Prediction function using LangChain and model
-def predict(message, history=[]):
+def predict(message, chat_history=[]):
     formatted_history = "\n".join(
-        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in history]
+        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
     )
     system_prompt = "You are a helpful coding assistant."
 
@@ -93,9 +92,10 @@ def predict(message, history=[]):
 # Gradio UI
 interface = gr.Interface(
     fn=predict,
-    inputs=gr.Textbox(label="User input"),
+    inputs=[
+        gr.Textbox(label="User input")
+    ],
     outputs="text",
-    allow_flagging='never',
     live=True,
 )
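The substantive change is the move from the LLMChain wrapper to LangChain's runnable composition: `chain = prompt | chat_model` builds a runnable that is called with one dict of the template's input variables via `.invoke()`, instead of `chain.run()`/`chain.predict()`. Below is a minimal sketch of how the updated `predict` would drive that chain; it assumes the `prompt`, `chat_model`, and `chain` objects defined earlier in app.py, and that each history entry is a dict with 'role' and 'content' keys, as in the diff.

    # Sketch only: assumes `chain = prompt | chat_model` and the <|im_start|> template from app.py above.
    def predict(message, chat_history=[]):
        # Render each history entry into the <|im_start|>...<|im_end|> blocks the template expects.
        formatted_history = "\n".join(
            f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history
        )
        # The piped runnable takes the prompt variables as a single dict; with
        # HuggingFacePipeline the invocation returns the model output as a string.
        return chain.invoke(
            {
                "system_prompt": "You are a helpful coding assistant.",
                "history": formatted_history,
                "human_input": message,
            }
        )

For example, predict("Explain list comprehensions", [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]) fills the template with the two formatted history turns before the new message.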