not-lain's picture
🌘w🌖
9e442ad verified
raw
history blame
4.6 kB
import gradio as gr
from datasets import load_dataset
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import time
# Hugging Face access token passed to the from_pretrained calls below;
# os.environ[...] raises KeyError at import time if HF_TOKEN is not set.
token = os.environ["HF_TOKEN"]

# Generator LLM: weights loaded in float16, then moved onto the CUDA device.
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    token=token,
    torch_dtype=torch.float16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)

# Sentence-transformer used by search() to embed incoming queries.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Number of passages retrieved per query (search()'s default k).
TOP_K = 1
# Marker written before the source links in every reply; talk() splits past
# replies on it so the links are not fed back to the model.
HEADER = "\n# RESOURCES:\n"

# Retrieval corpus: a pre-embedded Wikipedia subset (about 3,000 rows, judging by
# the dataset name) is used because the full dump is too big for this app.
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
# Build a FAISS index over the precomputed "embedding" column for search().
data.add_faiss_index("embedding")
def search(query: str, k: int = TOP_K):
    """Return the ``k`` dataset rows whose embeddings are nearest to ``query``.

    The query is embedded with the ``RAG`` sentence-transformer and looked up in
    the FAISS index built on the ``"embedding"`` column of ``data``. Similarity
    scores are discarded; only the retrieved rows are returned, in the form
    produced by ``get_nearest_examples`` (indexed by column name downstream).
    """
    query_vector = RAG.encode(query)
    _scores, nearest = data.get_nearest_examples("embedding", query_vector, k=k)
    return nearest
def prepare_prompt(query, retrieved_examples):
    """Build the RAG user prompt and the source metadata for ``query``.

    Args:
        query: The user's question.
        retrieved_examples: Mapping of column name to per-row values, as returned
            by ``search``; must provide ``"title"``, ``"text"`` and ``"url"``.

    Returns:
        A ``(prompt, metadata)`` pair. ``prompt`` is the query instruction
        followed by one ``* <text>`` bullet per retrieved row. ``metadata`` is an
        iterator of ``(title, url)`` pairs, in the same order as the bullets.
    """
    prompt = (
        f"Query: {query}\nContinue to answer the query in short sentences by using the Search Results:\n"
    )
    # Rows are reversed (presumably moving the best match, which
    # get_nearest_examples lists first, to the end of the prompt). All three
    # columns are reversed together so every title stays paired with its url.
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]
    # Iterate over what was actually retrieved rather than a global count, so a
    # different k (or fewer rows than expected) cannot raise IndexError.
    prompt += "".join(f"* {text}\n" for text in texts)
    return prompt, zip(titles, urls)
def _build_chat(history, prompt):
    """Turn Gradio history pairs plus the new RAG prompt into chat-template messages.

    Each history item is read as ``item[0]`` (user text) and ``item[1]``
    (assistant reply). Everything from ``HEADER`` onward is stripped from past
    replies, so earlier source links are not sent back to the model.
    """
    chat = []
    for item in history:
        user_msg, assistant_msg = item[0], item[1]
        chat.append({"role": "user", "content": user_msg})
        # Guard against a None reply instead of crashing on .split().
        cleaned_past = (assistant_msg or "").split(HEADER)[0]
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})
    return chat


@spaces.GPU(duration=150)
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the reply.

    Args:
        message: The user's latest chat message.
        history: Gradio chat history; each item is indexed as
            ``[user_text, assistant_text]``.

    Yields:
        The reply generated so far. The final value additionally ends with
        ``HEADER`` followed by comma-separated Markdown links to the
        retrieved sources.
    """
    from queue import Empty

    print("history, ", history)
    print("message ", message)
    print("searching dataset ...")
    retrieved_examples = search(message)
    print("preparing prompt ...")
    prompt, metadata = prepare_prompt(message, retrieved_examples)
    print("preparing metadata ...")
    resources = HEADER + ", ".join(f"[{title}]({url})" for title, url in metadata)
    print("preparing chat template ...")
    chat = _build_chat(history, prompt)
    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print("chat template prepared, ", messages)
    print("tokenizing input ...")
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    # With a timeout, iterating the streamer raises queue.Empty when no new text
    # arrives within 10 s, instead of blocking forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    print("initializing thread ...")
    # generate() runs in a background thread and pushes text into the streamer,
    # which this generator drains and yields incrementally.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    retries = 0
    while True:
        try:
            for new_text in streamer:
                partial_text += new_text
                yield partial_text
            break  # the streamer delivered its end-of-generation signal
        except Empty as e:
            if not thread.is_alive():
                # generate() exited without ending the stream (e.g. it raised):
                # stop waiting and return whatever text was produced.
                break
            # Still generating, just slowly: keep waiting on the same streamer.
            print(f"retry number {retries}\n LOGS:\n")
            retries += 1
            print(e, e.args)
    thread.join()
    partial_text += resources
    yield partial_text
TITLE = "# RAG"
# NOTE(review): the dataset line below says its "embedding" column was built with
# mxbai-colbert-large-v1, while queries are encoded with mxbai-embed-large-v1
# (the RAG model). Nearest-neighbour search is only meaningful if both sides use
# the same embedding model — confirm which model produced the dataset column.
DESCRIPTION = """
A RAG pipeline with a chatbot feature
Resources used to build this project :
* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column )
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot : https://huggingface.co/google/gemma-7b-it
If you want to support my work consider clicking on the heart react button ❤️🤗
"""
# Chat UI wired to the streaming talk() generator.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    examples=[["what's anarchy ? "]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch only when run as a script, so importing this module (e.g. for tooling
# or reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)