Leri777 committed
Commit 806020c
1 Parent(s): 7f13eee

Update app.py

Files changed (1):
  1. app.py +43 -21

app.py CHANGED
@@ -1,11 +1,15 @@
 import os
 import logging
+import time
+import random
 from logging.handlers import RotatingFileHandler
 import gradio as gr
 import torch
+from accelerate import Accelerator
 from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
 from langchain_huggingface import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain

 # Logging setup
 log_file = '/tmp/app_debug.log'
@@ -20,24 +24,31 @@ logger.debug("Application started")
 model_id = "google/gemma-2-9b-it"
 tokenizer = GemmaTokenizerFast.from_pretrained(model_id)

-# Load model with GPU availability check
-if torch.cuda.is_available():
-    logger.debug("GPU is available. Proceeding with GPU setup.")
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-    )
-else:
-    logger.warning("GPU is not available. Proceeding with CPU setup.")
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        low_cpu_mem_usage=True,
-        token=os.getenv('HF_TOKEN'),
-    )
+# Function to load model with GPU availability check
+def load_model():
+    if torch.cuda.is_available():
+        logger.debug("GPU is available. Proceeding with GPU setup.")
+        return AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto", torch_dtype=torch.bfloat16,
+        )
+    else:
+        logger.warning("GPU is not available. Proceeding with CPU setup.")
+        return AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto", low_cpu_mem_usage=True, token=os.getenv('HF_TOKEN'),
+        )

-model.eval()
+# Retry logic to load model with random delay
+model = None
+while model is None:
+    try:
+        model = load_model()
+        model.eval()
+    except Exception as e:
+        retry_delay = random.uniform(10, 30)  # Random delay between 10 to 30 seconds
+        logger.error(f"Failed to load model: {e}. Retrying in {retry_delay:.2f} seconds...")
+        time.sleep(retry_delay)

 # Create Hugging Face pipeline
 pipe = pipeline(
@@ -91,11 +102,22 @@ def predict(message, chat_history=[]):
 # Gradio UI
 interface = gr.Interface(
     fn=predict,
-    inputs=gr.Textbox(label="User input"),
-    outputs="text",
+    inputs=[
+        gr.Textbox(label="User input"),
+        gr.State(),
+    ],
+    outputs="text", allow_flagging='never',
     live=True,
 )

-interface.launch()
+# Retry logic to launch interface with random delay
+while True:
+    try:
+        interface.launch()
+        break
+    except Exception as e:
+        retry_delay = random.uniform(10, 30)  # Random delay between 10 to 30 seconds
+        logger.error(f"Failed to launch interface: {e}. Retrying in {retry_delay:.2f} seconds...")
+        time.sleep(retry_delay)

-logger.debug("Chat interface initialized and launched")
+logger.debug("Chat interface initialized and launched")
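
The two retry loops added in this commit share the same shape: call, catch, sleep a random 10-30 s, try again. A minimal sketch of how that pattern could be factored into a single helper; this refactor is not part of the commit, and retry_with_jitter and max_attempts are hypothetical names:

import logging
import random
import time

logger = logging.getLogger(__name__)

def retry_with_jitter(fn, min_delay=10, max_delay=30, max_attempts=None):
    """Call fn() until it succeeds, sleeping a random delay between failures."""
    attempt = 0
    while True:
        attempt += 1
        try:
            return fn()
        except Exception as e:
            # None = retry forever, matching the commit's behaviour;
            # a finite max_attempts re-raises once exhausted.
            if max_attempts is not None and attempt >= max_attempts:
                raise
            delay = random.uniform(min_delay, max_delay)
            logger.error(f"Attempt {attempt} failed: {e}. Retrying in {delay:.2f} seconds...")
            time.sleep(delay)

# Usage mirroring the commit:
#   model = retry_with_jitter(load_model)
#   model.eval()
#   retry_with_jitter(interface.launch)

Bounding attempts is worth considering here: a bare except Exception around interface.launch() retries forever even on non-transient failures, such as a port that is already in use.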
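
The interface change adds gr.State() to inputs while outputs remains "text". In Gradio, session state passed in as an input only persists across calls when a matching gr.State also appears in outputs and the function returns the updated value; the diff does not show whether this app's predict does that. A hedged sketch of the usual pattern, with a placeholder body standing in for the real Gemma pipeline call:

import gradio as gr

def predict(message, chat_history):
    # gr.State() starts out as None on the first call of a session.
    chat_history = chat_history or []
    reply = f"echo: {message}"             # placeholder for the model call
    chat_history.append((message, reply))  # record this turn
    # Returning the history as the second value is what refreshes gr.State.
    return reply, chat_history

interface = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label="User input"), gr.State()],
    outputs=[gr.Textbox(label="Response"), gr.State()],
    allow_flagging="never",
    live=True,
)

Separately, the hunk header shows def predict(message, chat_history=[]), a mutable default argument: the same list object is shared by every call that omits the argument, which is usually unintended for per-session chat history.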