# Page-header residue captured with this file (not part of the program): "Spaces: Runtime error".
import gradio as gr | |
import os | |
import time | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig | |
import torch | |
from threading import Thread | |
import logging | |
import spaces | |
from functools import lru_cache | |
# Report GPU availability up front so the startup output shows which hardware is in use.
print(f"Is CUDA available: {torch.cuda.is_available()}")
# Only query the device name when CUDA is usable: torch.cuda.current_device()
# raises on hosts without CUDA, which would abort the app before it starts.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA device: none detected")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the HF_TOKEN environment variable (None when it is unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Static markup shown in the UI. Each constant is spelled out line by line so
# the exact text (including leading/trailing newlines) is explicit.

# Page heading rendered above the chat widget.
DESCRIPTION = (
    "\n"
    "<div>\n"
    '<h1 style="text-align: center;">ContenteaseAI custom trained model</h1>\n'
    "</div>\n"
)

# Footer markdown rendered below the chat widget.
LICENSE = (
    "\n"
    "<p/>\n"
    "---\n"
    "For more information, visit our [website](https://contentease.ai).\n"
)

# Placeholder markup passed to the chatbot component.
PLACEHOLDER = (
    "\n"
    '<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">\n'
    '<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">ContenteaseAI Custom AI trained model</h1>\n'
    '<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Enter the text extracted from the PDF:</p>\n'
    "</div>\n"
)

# Custom CSS handed to gr.Blocks: centre every <h1>.
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "display: block;\n"
    "}\n"
)
# Checkpoint to load and the bitsandbytes quantization settings used for it.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

bnb_config = BitsAndBytesConfig(
    **{
        # Load the weights in 4-bit precision.
        "load_in_4bit": True,
        # Quantization data type for the 4-bit weights.
        "bnb_4bit_quant_type": "nf4",
        # Enable double quantization.
        "bnb_4bit_use_double_quant": True,
        # Dtype used for computation on the quantized weights.
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
)
def load_model_and_tokenizer():
    """Load the tokenizer and the quantized causal LM identified by ``model_id``.

    The model is loaded with ``bnb_config`` as its quantization config and
    ``device_map="auto"``. The model's generation config is given a pad token
    id so that generation does not have to guess one.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: Any error raised while loading is logged (with traceback)
            and re-raised unchanged.
    """
    try:
        start_time = time.time()
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
        )
        # A tokenizer may define no pad token (pad_token_id is None); copying
        # None would leave the generation config without a pad id, so fall
        # back to the EOS id in that case.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        model.generation_config.pad_token_id = pad_token_id
        end_time = time.time()
        logger.info(f"Model and tokenizer loaded successfully in {end_time - start_time} seconds.")
        return model, tokenizer
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"Error loading model or tokenizer: {e}")
        raise
# Load the model and tokenizer once at import time; the generation code below
# uses these module-level objects. A load failure is logged and re-raised.
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as load_error:
    logger.error(f"Failed to load model and tokenizer: {load_error}")
    raise

# Token ids passed to generate() as end-of-sequence markers: the tokenizer's
# EOS id followed by the id of the "<|eot_id|>" token.
terminators = [tokenizer.eos_token_id]
terminators.append(tokenizer.convert_tokens_to_ids("<|eot_id|>"))
# System prompt sent with every chunk. The example response must itself be
# valid JSON (no trailing commas), since the prompt asks the model to return
# only JSON and the example is the format it will imitate.
SYS_PROMPT = """
Given the text of a hotel property improvement plan, extract the items to be replaced for only the Guest Rooms/ Suites, Guest Bathrooms/Suite Bathrooms.
First, find the section of the pdf which describes improvements to be done on the Guest Rooms and Guest Bathrooms, then find the items to be replaced.
Ignore items from other sections of the hotel.
Items to be replaced are usually preceded by the words replace, install, or provide.
Return the results as a JSON with "Guest Room" and "Guest Bathroom" as keys and each value the list of unique items to be replaced.
Return only the JSON with no extra text.
Example Text:
"
Site & Building Exterior
Replace all exterior decorative lighting
...
Guestrooms
Replace [ORG] C-Table.
Provide full length mirror.
Replace cabinets - Kitchen.
at doors where brass hardware finishes exist – replace with stainless
...
Guest Bathrooms - (FRCM) Replace mirrors. Install a vanity mirror that has integrated lighting
Guest Bathrooms - (FRCM) Replace artwork and decorative accessories.
...
Suites - Replace microwave, refrigerator, and associated casegood cabinet.
"
Example Response:
{
  "Guest Room": [
    "C-Table",
    "full length mirror",
    "kitchen cabinets",
    "stainless steel door hardware",
    "microwave",
    "refrigerator",
    "casegood cabinet"
  ],
  "Guest Bathroom": [
    "vanity mirror with integrated lighting",
    "artwork",
    "decorative accessories"
  ]
}
"""
def chunk_text(text, chunk_size=5000):
    """Split text into chunks of at most ``chunk_size`` whitespace-separated words.

    The text is split on whitespace and each chunk's words are re-joined with
    single spaces, so line breaks and repeated whitespace are not preserved.

    Args:
        text (str): The input text to be chunked.
        chunk_size (int): Maximum number of words per chunk. This counts
            words, not model tokens. Must be positive.

    Returns:
        list: The text chunks in order; empty when ``text`` contains no words.

    Raises:
        ValueError: If ``chunk_size`` is not positive. (A negative size would
            otherwise silently produce no chunks and drop the whole input.)
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    words = text.split()
    chunks = [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
    logger.info(f"Total chunks created: {len(chunks)}")
    return chunks
def combine_responses(responses):
    """Merge the per-chunk responses into a single output string.

    Args:
        responses (list): Response strings, one per chunk, in order.

    Returns:
        str: The responses joined by single spaces ("" for an empty list).
    """
    separator = " "
    return separator.join(responses)
def generate_response_for_chunk(chunk, history, temperature, max_new_tokens):
    """Generate the model's reply for one chunk of input text.

    Builds a conversation from ``SYS_PROMPT``, the prior turns in ``history``
    (all but the most recent one) and ``chunk`` as the final user message,
    runs ``model.generate`` in a background thread and collects the streamed
    text.

    Args:
        chunk (str): The user text for this request.
        history (list): Prior ``(user, assistant)`` turns. Not modified.
        temperature (float): Sampling temperature; ``0`` (or less) selects
            greedy decoding.
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        str: The generated text for this chunk.
    """
    start_time = time.time()

    # The most recent turn is left out, as before, but via a slice instead of
    # history.pop(): popping mutated the caller's list, so every additional
    # chunk silently dropped one more turn.
    prior_turns = history[:-1]

    conversation = [{"role": "system", "content": SYS_PROMPT}]
    for user, assistant in prior_turns:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": chunk})

    # add_generation_prompt=True appends the assistant-turn header so the
    # model answers the user message instead of continuing it.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only pass a temperature when sampling; a temperature of 0 is not a
    # valid sampling setting, so it means greedy decoding here.
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generate_kwargs["do_sample"] = False

    # generate() blocks, so it runs in a worker thread while this thread
    # drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = [text for text in streamer]
    # The stream has ended, so generation is finishing; wait for the worker so
    # it does not outlive this call.
    worker.join()

    end_time = time.time()
    logger.info(f"Time taken for generating response for a chunk: {end_time - start_time} seconds")
    return "".join(outputs)
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
    """Answer a chat message by running the model over word-chunks of it.

    The message is split with ``chunk_text``, each chunk is answered by
    ``generate_response_for_chunk``, and the replies are merged with
    ``combine_responses``. This is a generator that yields exactly once: the
    combined reply, or a fixed error message if anything fails.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The temperature for generating the response.
        max_new_tokens (int): The maximum number of new tokens to generate
            per chunk.

    Yields:
        str: The combined response, or an error message on failure.
    """
    try:
        start_time = time.time()
        chunks = chunk_text(message)
        responses = []
        for position, chunk in enumerate(chunks, start=1):
            logger.info(f"Processing chunk {position}/{len(chunks)}")
            # Each chunk gets its own copy of the history so the caller's list
            # is never modified and every chunk sees the same prior turns.
            response = generate_response_for_chunk(chunk, list(history), temperature, max_new_tokens)
            responses.append(response)
        final_output = combine_responses(responses)
        end_time = time.time()
        logger.info(f"Total time taken for generating response: {end_time - start_time} seconds")
        yield final_output
    except Exception as e:
        # Log with traceback; the user only sees a generic message.
        logger.exception(f"Error generating response: {e}")
        yield "An error occurred while generating the response. Please try again."
# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)

    # Extra generation controls. They are created with render=False and handed
    # to ChatInterface, which receives them as its additional inputs.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.95, label="Temperature", render=False
    )
    max_new_tokens_slider = gr.Slider(
        minimum=128, maximum=2000, step=1, value=700, label="Max new tokens", render=False
    )

    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_new_tokens_slider],
    )

    gr.Markdown(LICENSE)

# Start the web app when run as a script; a launch failure is logged, not re-raised.
if __name__ == "__main__":
    try:
        demo.launch(show_error=True)
    except Exception as launch_error:
        logger.error(f"Error launching Gradio demo: {launch_error}")