Spaces:
Runtime error
Runtime error
import os | |
import subprocess | |
import datasets | |
from langchain.docstore.document import Document | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import FAISS | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
import json | |
from transformers.agents import Tool | |
from langchain_core.vectorstores import VectorStore | |
from transformers.agents import HfEngine, ReactJsonAgent | |
# Runtime dependency installation.
#
# Each pip command gets its own subprocess.run call: subprocess.run takes a
# single command, and a second positional string would be interpreted as the
# `bufsize` argument and raise a TypeError.
# Return codes are deliberately not checked (best effort, as before).
# NOTE(review): these commands run after the langchain/transformers imports at
# the top of the file, so they cannot help an import that has already failed.

# Install flash attention
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Extend the current environment instead of replacing it, so PATH and the
    # other inherited variables stay visible to the shell and to pip.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# Install transformers from source with the agents extra. The requirement is
# quoted so the shell does not interpret the '#' or the square brackets.
subprocess.run(
    "pip install 'git+https://github.com/huggingface/transformers.git#egg=transformers[agents]'",
    shell=True,
)
# Install RAG dependencies
subprocess.run(
    "pip install langchain sentence-transformers faiss-cpu",
    shell=True,
)
import copy | |
import spaces | |
import time | |
import torch | |
from threading import Thread | |
from typing import List, Dict, Union | |
import urllib | |
from PIL import Image | |
import io | |
import datasets | |
import gradio as gr | |
from transformers import AutoProcessor, TextIteratorStreamer | |
from transformers import Idefics2ForConditionalGeneration | |
# Device that the model weights and the model inputs are moved to.
DEVICE = torch.device("cuda")

# Checkpoint served by the demo.
# Alternative checkpoint kept for reference: "Ali-C137/idefics2-8b-chatty-yalla"
_chatty_checkpoint = "HuggingFaceM4/idefics2-8b-chatty"
_chatty_model = Idefics2ForConditionalGeneration.from_pretrained(
    _chatty_checkpoint,
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
)

# Maps a selectable model name to the loaded model, already placed on DEVICE.
MODELS = {"idefics2-8b-chatty": _chatty_model.to(DEVICE)}

# Processor used to build the chat prompt and the model inputs.
# Alternative checkpoint kept for reference: "Ali-C137/idefics2-8b-chatty-yalla"
PROCESSOR = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
# Load the custom dataset used as the retrieval knowledge base.
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

# Process the documents: wrap each record in a Document and keep the second
# "/"-separated component of its "source" field as metadata, so that searches
# can later be filtered on it.
source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
    for doc in knowledge_base
]

# Split the documents into chunks (chunk_size=500) and keep the first 1000 chunks.
docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]

# Create embeddings and vector store.
# The embedding model id has to be passed as the `model_name` keyword:
# HuggingFaceEmbeddings does not accept positional arguments.
embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)
class RetrieverTool(Tool):
    """Agent tool that runs a similarity search against a vector store.

    The tool takes a free-text ``query`` and an optional ``source`` filter and
    returns the content of the three closest documents as a single string.
    """

    name = "retriever"
    description = "Retrieves documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        },
        "source": {
            "type": "text",
            # Filled in per instance by __init__, once the available sources are known.
            "description": "",
        },
    }
    output_type = "text"

    def __init__(self, vectordb: VectorStore, all_sources: str, **kwargs):
        """Bind the tool to a vector store and advertise the available sources.

        Args:
            vectordb: Vector store queried by :meth:`forward`.
            all_sources: The source names that may be used as a filter; the
                value is interpolated as-is into the ``source`` description.
            **kwargs: Forwarded to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.vectordb = vectordb
        # Work on a per-instance copy of the schema: writing into the
        # class-level dict would leak this instance's source list into every
        # other RetrieverTool instance.
        self.inputs = {key: dict(spec) for key, spec in type(self).inputs.items()}
        self.inputs["source"]["description"] = (
            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
        )

    def forward(self, query: str, source: str = None) -> str:
        """Search the vector store and format the hits.

        Args:
            query: Text to search for; must be a ``str``.
            source: Optional filter. Either a single source name, or the
                string representation of a list of names (single or double
                quoted). A falsy value disables filtering.

        Returns:
            The retrieved document contents joined by a separator line, or a
            hint to drop the filter when nothing matched.

        Raises:
            AssertionError: If ``query`` is not a string.
            json.JSONDecodeError: If ``source`` looks like a list but cannot
                be parsed.
        """
        assert isinstance(query, str), "Your search query must be a string"

        if source and isinstance(source, str):
            if "[" in source:
                # String representation of a list: normalise quotes and parse it.
                source = json.loads(source.replace("'", '"'))
            else:
                # A bare source name is wrapped directly; round-tripping it
                # through str()/json.loads() breaks on names with apostrophes.
                source = [source]

        docs = self.vectordb.similarity_search(
            query,
            filter=({"source": source} if source else None),
            k=3,
        )
        if not docs:
            return "No documents found with this filtering. Try removing the source filter."
        return "Retrieved documents:\n\n" + "\n===Document===\n".join(
            [doc.page_content for doc in docs]
        )
from transformers.agents import HfEngine, ReactJsonAgent | |
# Initialize the LLM engine and the agent with the retriever tool
llm_engine = HfEngine("meta-llama/Meta-Llama-3-8B-Instruct")
# Distinct source names found in the metadata of the indexed chunks; they are
# listed in the retriever tool's description of its `source` input.
all_sources = list({chunk.metadata["source"] for chunk in docs_processed})
retriever_tool = RetrieverTool(vectordb=vectordb, all_sources=all_sources)
agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine)
# Should change this section for the finetuned model
# Conversation prefix copied in front of every request: one system message
# followed by an assistant greeting, each holding a single text content item.
SYSTEM_PROMPT = [
    {"role": role, "content": [{"type": "text", "text": text}]}
    for role, text in (
        (
            "system",
            "You are YALLA, a personalized AI chatbot assistant designed to enhance the user's experience in Morocco. Your mission is to provide accurate, real-time, and culturally rich information to make their visit enjoyable and stress-free. You can handle text and image inputs, offering recommendations on transport, event schedules, dining, accommodations, and cultural experiences. You can also perform real-time web searches and use various APIs to assist users effectively. Always be respectful, polite, and inclusive, and strive to offer truthful and helpful responses.",
        ),
        (
            "assistant",
            "Hello, I'm YALLA, your personalized AI assistant for exploring Morocco. How can I assist you today?",
        ),
    )
]
# Absolute directory of this file. Using abspath keeps the example paths valid
# when __file__ has no directory component: a bare dirname would be "" and the
# paths would then point at "/example_images/..." in the filesystem root.
examples_path = os.path.dirname(os.path.abspath(__file__))

# Example prompts offered in the chat UI. Every example is a one-element row
# holding a message dict with a "text" prompt and the image "files" it refers to.
EXAMPLES = [
    [
        {
            "text": example_text,
            "files": [os.path.join(examples_path, "example_images", image_name)],
        }
    ]
    for example_text, image_name in (
        (
            "For 2024, the interest expense is twice what it was in 2014, and the long-term debt is 10% higher than its 2015 level. Can you calculate the combined total of the interest and long-term debt for 2024?",
            "mmmu_example_2.png",
        ),
        ("What's in the image?", "plant_bulb.webp"),
        ("Describe the image", "baguettes_guarding_paris.png"),
        ("Read what's written on the paper", "paper_with_text.png"),
    )
]

# Avatar image used for the assistant side of the chat
# (the previous value was "IDEFICS_logo.png").
BOT_AVATAR = "YALLA_logo.png"
# Chatbot utils | |
def turn_is_pure_media(turn):
    """Tell whether a chat-history turn carries media only.

    A turn is indexed as ``(user_part, assistant_part)``; it counts as pure
    media when the assistant part (index 1) is ``None``.
    """
    assistant_part = turn[1]
    return assistant_part is None
def load_image_from_url(url):
    """Fetch ``url`` and open the downloaded bytes as a PIL image.

    Args:
        url: Location passed to ``urllib.request.urlopen``.

    Returns:
        The image opened from an in-memory copy of the response body.
    """
    # A plain `import urllib` does not guarantee that the `request` submodule
    # has been loaded, so import it explicitly where it is used.
    import urllib.request

    with urllib.request.urlopen(url) as response:
        image_data = response.read()
    # The whole payload is buffered in memory, so the image does not depend on
    # the (already closed) HTTP response.
    image_stream = io.BytesIO(image_data)
    return Image.open(image_stream)
def img_to_bytes(image_path):
    """Return the image at ``image_path`` re-encoded as JPEG bytes.

    The image is converted to RGB before encoding.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts.

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    buffer = io.BytesIO()
    # Both the image opened from disk and its RGB copy are closed on exit;
    # previously only the converted copy was closed, leaking the source file.
    with Image.open(image_path) as source_image:
        with source_image.convert(mode="RGB") as rgb_image:
            rgb_image.save(buffer, format="JPEG")
    return buffer.getvalue()
def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
) -> "tuple[List[Dict[str, Union[List, str]]], List]":
    """
    Produces the message list and the image list that go inside the processor.
    It handles the potential image(s), the history and the system conditioning.

    Args:
        user_prompt: Current input, a dict with a "text" entry and a "files"
            entry (image paths; may be empty).
        chat_history: Previous turns. A turn is either pure media
            (``turn[1] is None``, image path at ``turn[0][0]``) or a
            ``(user_utterance, assistant_utterance)`` pair of strings.

    Returns:
        A ``(messages, images)`` tuple: the chat messages, starting with a copy
        of ``SYSTEM_PROMPT``, and the opened images in the order in which their
        ``{"type": "image"}`` placeholders appear in the messages.
    """
    # Copy so that the module-level prompt is never modified by a request.
    resulting_messages = copy.deepcopy(SYSTEM_PROMPT)
    resulting_images = []

    # Collect images referenced by user messages of the system prompt, if any.
    for resulting_message in resulting_messages:
        if resulting_message["role"] == "user":
            for content in resulting_message["content"]:
                if content["type"] == "image":
                    resulting_images.append(load_image_from_url(content["image"]))

    # Format history
    for turn in chat_history:
        # Open a new user message unless one is already being filled
        # (consecutive media turns accumulate in the same user message).
        if not resulting_messages or resulting_messages[-1]["role"] != "user":
            resulting_messages.append({"role": "user", "content": []})

        if turn_is_pure_media(turn):
            media = turn[0][0]
            resulting_messages[-1]["content"].append({"type": "image"})
            resulting_images.append(Image.open(media))
        else:
            user_utterance, assistant_utterance = turn
            resulting_messages[-1]["content"].append(
                {"type": "text", "text": user_utterance.strip()}
            )
            # The assistant message must carry the assistant's reply; it used
            # to repeat the user's utterance, corrupting the history.
            resulting_messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": assistant_utterance.strip()}],
                }
            )

    # Format current input
    files = user_prompt["files"]
    if not files:
        resulting_messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": user_prompt["text"]}],
            }
        )
    else:
        # Choosing to put the image first (i.e. before the text), but this is an arbitrary choice.
        image_placeholders = [{"type": "image"} for _ in files]
        resulting_messages.append(
            {
                "role": "user",
                "content": image_placeholders + [{"type": "text", "text": user_prompt["text"]}],
            }
        )
        resulting_images.extend(Image.open(path) for path in files)

    return resulting_messages, resulting_images
def extract_images_from_msg_list(msg_list):
    """Collect every PIL image found among the messages' content items.

    Each message's "content" entry is scanned in order; items that are
    ``Image.Image`` instances are returned as one flat list.
    """
    return [
        item
        for message in msg_list
        for item in message["content"]
        if isinstance(item, Image.Image)
    ]
def model_inference(
    user_prompt,
    chat_history,
    model_selector,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    """Stream the model's answer to the current chat input.

    Args:
        user_prompt: Dict with a "text" entry and a "files" entry (image paths).
        chat_history: Previous turns, as consumed by
            ``format_user_prompt_with_im_history_and_system_conditioning``.
        model_selector: Key into ``MODELS`` selecting the model to run.
        decoding_strategy: "Greedy" or "Top P Sampling".
        temperature: Sampling temperature (only used for "Top P Sampling").
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Repetition penalty passed to ``generate``.
        top_p: Nucleus sampling threshold (only used for "Top P Sampling").

    Yields:
        The text generated so far, growing with every streamed chunk.

    Raises:
        gr.Error: If the text query is empty.
        AssertionError: If ``decoding_strategy`` is not a supported value.
    """
    # NOTE(review): the file imports `spaces` but this function is not wrapped
    # in `spaces.GPU`; confirm whether the deployment requires it.
    query_is_empty = user_prompt["text"].strip() == ""
    # The errors have to be raised: merely constructing gr.Error has no effect
    # and would let an empty prompt go on to generation.
    if query_is_empty and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")
    if query_is_empty and user_prompt["files"]:
        raise gr.Error("Please input a text query along the image(s).")

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,
    )

    # Common parameters to all decoding strategies
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Creating model inputs
    resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )
    prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
    inputs = PROCESSOR(
        text=prompt,
        images=resulting_images if resulting_images else None,
        return_tensors="pt",
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generation_args.update(inputs)

    # Use the agent to perform RAG
    # NOTE(review): the agent's answer is only printed; it is not added to the
    # prompt built above, so it does not influence the generated text.
    agent_output = agent.run(user_prompt["text"])
    print("Agent output:", agent_output)

    # Stream the generated text: generation runs in a worker thread while this
    # generator consumes the streamer.
    thread = Thread(
        target=MODELS[model_selector].generate,
        kwargs=generation_args,
    )
    thread.start()

    end_of_utterance = "<end_of_utterance>"
    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)
        acc_text += text_token
        if acc_text.endswith(end_of_utterance):
            # Drop the end marker so that it is never shown to the user.
            acc_text = acc_text[: -len(end_of_utterance)]
        yield acc_text

    # The streamer is exhausted once generation has finished; reap the worker.
    thread.join()
    print("Success - generated the following text:", acc_text)
    print("-----")
# Dataset schema for one recorded interaction: the selected model, the images,
# the conversation turns and the generation settings.
# NOTE(review): FEATURES is not referenced anywhere else in the visible file.
FEATURES = datasets.Features(
    {
        "model_selector": datasets.Value("string"),
        "images": datasets.Sequence(datasets.Image(decode=True)),
        "conversation": datasets.Sequence({"User": datasets.Value("string"), "Assistant": datasets.Value("string")}),
        "decoding_strategy": datasets.Value("string"),
        "temperature": datasets.Value("float32"),
        "max_new_tokens": datasets.Value("int32"),
        "repetition_penalty": datasets.Value("float32"),
        # top_p is a fraction (the UI slider ranges from 0.01 to 0.99), so it
        # must be stored as a float; an int32 column cannot represent it.
        "top_p": datasets.Value("float32"),
    }
)
# Hyper-parameters for generation.
# These components are created at module level and handed to the chat
# interface as additional inputs.
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.1,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    # The previous help text was copied from the Top P slider and described
    # "higher values", which makes no sense for a two-option choice.
    info="Greedy always picks the most likely next token. Top P Sampling samples among the most likely tokens.",
)
temperature = gr.Slider(
    # Sampling requires a strictly positive temperature, so 0.0 must not be
    # selectable; the lower bound is one slider step above zero.
    minimum=0.1,
    maximum=5.0,
    value=0.4,
    step=0.1,
    visible=False,  # revealed when a sampling strategy is selected
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    visible=False,  # revealed when "Top P Sampling" is selected
    interactive=True,
    label="Top P",
    info="Higher values is equivalent to sampling more low-probability tokens.",
)

# Chat display; the second avatar is the assistant's.
chatbot = gr.Chatbot(
    label="YALLA-Chatty",
    avatar_images=[None, BOT_AVATAR],
    height=450,
)
# Assemble the UI: a hidden model selector, handlers that reveal the sampling
# sliders for the matching decoding strategy, and the chat interface itself.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# 🇲🇦 YALLA ")
    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            # A concrete list rather than a dict view.
            choices=list(MODELS),
            value=list(MODELS.keys())[0],
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=False,
        )

    # Show the temperature slider only for sampling-based strategies.
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection
                in [
                    "contrastive_sampling",
                    "beam_sampling",
                    "Top P Sampling",
                    "sampling_top_k",
                ]
            )
        ),
        inputs=decoding_strategy,
        outputs=temperature,
    )
    # Show the top-p slider only for "Top P Sampling".
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
        inputs=decoding_strategy,
        outputs=top_p,
    )

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        examples=EXAMPLES,
        # model_inference reads user_prompt["text"] and user_prompt["files"],
        # and EXAMPLES are {"text", "files"} dicts: both need the multimodal
        # input. With multimodal=False the function would receive a plain str.
        multimodal=True,
        cache_examples=False,
        additional_inputs=[
            model_selector,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
        ],
    )

demo.launch()