import os
import logging
import time
import random
from logging.handlers import RotatingFileHandler
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
# Logging setup
log_file = '/tmp/app_debug.log'
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.debug("Application started")
model_id = "google/gemma-2-9b-it"
# The Gemma 2 checkpoints are gated, so pass the token here as well (the CPU fallback
# below already does); huggingface_hub also picks up HF_TOKEN from the environment.
tokenizer = GemmaTokenizerFast.from_pretrained(model_id, token=os.getenv('HF_TOKEN'))
# Function to load the model with a GPU availability check
def load_model():
    attempts = 0
    while attempts < 5:  # Try up to 5 times to get a GPU
        if torch.cuda.is_available():
            logger.debug("GPU is available. Proceeding with GPU setup.")
            try:
                return AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    token=os.getenv('HF_TOKEN'),
                )
            except Exception as e:
                logger.error(f"Error initializing model with GPU: {e}. Retrying...")
                attempts += 1
                time.sleep(random.uniform(20, 60))  # Wait before retrying
        else:
            logger.warning("GPU is not available. Retrying GPU initialization...")
            attempts += 1
            time.sleep(random.uniform(20, 60))
    # If a GPU never became available, fall back to CPU
    logger.warning("Falling back to CPU setup after multiple attempts.")
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        low_cpu_mem_usage=True,
        token=os.getenv('HF_TOKEN'),
    )
# Retry logic to load the model with a random delay
model = None
while model is None:
    try:
        model = load_model()
        model.eval()
    except Exception as e:
        retry_delay = random.uniform(10, 30)  # Random delay between 10 and 30 seconds
        logger.error(f"Failed to load model: {e}. Retrying in {retry_delay:.2f} seconds...")
        time.sleep(retry_delay)
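# Optional addition (not in the original): record where the model actually ended up,
# since load_model() may have fallen back to CPU. hf_device_map is only present when
# accelerate dispatched the model via device_map.
logger.debug(f"Model loaded on: {getattr(model, 'hf_device_map', model.device)}")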
# Create the Hugging Face pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,
    do_sample=True,  # sampling must be enabled for temperature/top_k/top_p to take effect
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)
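# Optional sanity check (not in the original): run a one-off generation through the raw
# pipeline when the (arbitrarily named) SMOKE_TEST env var is set, e.g. SMOKE_TEST=1.
if os.getenv("SMOKE_TEST"):
    sample = pipe("Write a haiku about GPUs.", max_new_tokens=64)
    logger.debug(f"Smoke test output: {sample[0]['generated_text']}")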
# Initialize the HuggingFacePipeline model for LangChain
chat_model = HuggingFacePipeline(pipeline=pipe)
# Define the conversation template for LangChain.
# Note: this uses ChatML-style <|im_start|>/<|im_end|> markers, which Gemma 2 treats as
# plain text; the model's native chat format uses <start_of_turn>/<end_of_turn> turns.
template = """<|im_start|>system
{system_prompt}
<|im_end|>
{history}
<|im_start|>user
{human_input}
<|im_end|>
<|im_start|>assistant"""
# Create the LangChain prompt and chain
prompt = PromptTemplate(
    template=template, input_variables=["system_prompt", "history", "human_input"]
)
chain = prompt | chat_model
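# Note (added): `prompt | chat_model` composes a LangChain Expression Language (LCEL)
# runnable; it is called with .invoke(dict) in predict() below, not with the legacy
# LLMChain .run() API.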
# Prediction function using the LangChain chain
def predict(message, chat_history=None):
    chat_history = chat_history or []
    formatted_history = "\n".join(
        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
    )
    system_prompt = "You are a helpful coding assistant."
    try:
        # The prompt | chat_model runnable is called with .invoke() (it has no .run() method)
        result = chain.invoke({
            "system_prompt": system_prompt,
            "history": formatted_history,
            "human_input": message
        })
        return result
    except Exception as e:
        logger.exception(f"Error during prediction: {e}")
        return "An error occurred."
# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="User input")
    ],
    outputs="text",
    allow_flagging='never',
    live=True,
)
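# Note (added): live=True makes Gradio re-run predict() whenever the input changes rather
# than on an explicit submit; with a 9B model each call is costly, so removing live=True
# may be preferable.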
# Retry logic to launch the interface with a random delay
while True:
    try:
        interface.launch()
        break
    except Exception as e:
        retry_delay = random.uniform(10, 30)  # Random delay between 10 and 30 seconds
        logger.error(f"Failed to launch interface: {e}. Retrying in {retry_delay:.2f} seconds...")
        time.sleep(retry_delay)
logger.debug("Chat interface initialized and launched")