gemma-2-9b-it1 / app.py
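"""Gradio chat app serving google/gemma-2-9b-it through a Transformers
text-generation pipeline wrapped in LangChain, with retry logic around GPU
availability, model loading, and interface launch."""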
import os
import logging
import time
import random
from logging.handlers import RotatingFileHandler
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
# Logging setup
log_file = '/tmp/app_debug.log'
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.debug("Application started")
model_id = "google/gemma-2-9b-it"
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
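# Note: google/gemma-2-9b-it is a gated checkpoint on the Hugging Face Hub; downloading it
# typically requires authentication (e.g. the HF_TOKEN environment variable used in the CPU fallback below).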
# Function to load model with GPU availability check
def load_model():
    attempts = 0
    while attempts < 5:  # Try up to 5 times to get a GPU
        if torch.cuda.is_available():
            logger.debug("GPU is available. Proceeding with GPU setup.")
            try:
                return AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto", torch_dtype=torch.bfloat16,
                )
            except Exception as e:
                logger.error(f"Error initializing model with GPU: {e}. Retrying...")
                attempts += 1
                time.sleep(random.uniform(20, 60))  # Wait before retrying
        else:
            logger.warning("GPU is not available. Retrying GPU initialization...")
            attempts += 1
            time.sleep(random.uniform(20, 60))
    # If the GPU is still not available, fall back to CPU
    logger.warning("Falling back to CPU setup after multiple attempts.")
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto", low_cpu_mem_usage=True, token=os.getenv('HF_TOKEN'),
    )
# Retry logic to load model with random delay
model = None
while model is None:
try:
model = load_model()
model.eval()
except Exception as e:
retry_delay = random.uniform(10, 30) # Random delay between 10 to 30 seconds
logger.error(f"Failed to load model: {e}. Retrying in {retry_delay:.2f} seconds...")
time.sleep(retry_delay)
# Create Hugging Face pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)
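# Note: max_length bounds the total sequence (prompt plus completion); to cap only the
# generated text, max_new_tokens is the usual alternative.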
# Initialize HuggingFacePipeline model for LangChain
chat_model = HuggingFacePipeline(pipeline=pipe)
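# HuggingFacePipeline exposes the transformers pipeline as a LangChain-compatible LLM,
# which is what allows the prompt | chat_model composition below.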
# Define the conversation template for LangChain
template = """<|im_start|>system
{system_prompt}
<|im_end|>
{history}
<|im_start|>user
{human_input}
<|im_end|>
<|im_start|>assistant"""
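# Note: the markers above follow the ChatML convention (<|im_start|>/<|im_end|>) rather than
# Gemma's native <start_of_turn> chat format; they are passed to the model as plain text.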
# Create LangChain prompt and chain
prompt = PromptTemplate(
    template=template, input_variables=["system_prompt", "history", "human_input"]
)
chain = prompt | chat_model
# Prediction function using LangChain and model
def predict(message, chat_history=[]):
    formatted_history = "\n".join(
        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
    )
    system_prompt = "You are a helpful coding assistant."
    try:
        # prompt | chat_model is a Runnable, so it is called with .invoke() rather than .run()
        result = chain.invoke({
            "system_prompt": system_prompt,
            "history": formatted_history,
            "human_input": message,
        })
        return result
    except Exception as e:
        logger.exception(f"Error during prediction: {e}")
        return "An error occurred."
# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="User input"),
    ],
    outputs="text",
    allow_flagging="never",
    live=True,
)
# Retry logic to launch interface with random delay
while True:
    try:
        interface.launch()
        break
    except Exception as e:
        retry_delay = random.uniform(10, 30)  # Random delay between 10 and 30 seconds
        logger.error(f"Failed to launch interface: {e}. Retrying in {retry_delay:.2f} seconds...")
        time.sleep(retry_delay)
logger.debug("Chat interface initialized and launched")