not-lain's picture
🌘w🌖
e4b2161
raw
history blame
No virus
4.14 kB
import gradio as gr
from datasets import load_dataset
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
# Single source of truth for the chat model id (used for both model and tokenizer).
MODEL_ID = "google/gemma-7b-it"
# Read from the environment; raises KeyError at startup if HF_TOKEN is unset.
token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token)
# The app assumes a CUDA device; weights are moved there once, at import time.
device = torch.device("cuda")
model = model.to(device)
# Sentence-transformer used to encode user queries for nearest-neighbour search.
# NOTE(review): DESCRIPTION says the dataset's "embedding" column was built with
# mxbai-colbert-large-v1, while queries are encoded with mxbai-embed-large-v1.
# Retrieval is only meaningful if both live in the same embedding space — confirm.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Number of passages retrieved per query (default k for search()).
TOP_K = 3
# Prepare data: no row selection happens here — the dataset repo itself is a
# small Wikipedia subset (3,000 rows, per its name) with a precomputed
# "embedding" column.
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
# Build an in-memory FAISS index over that column so search() can query it.
data.add_faiss_index("embedding")
@spaces.GPU
def search(query: str, k: int = TOP_K):
    """Look up the ``k`` passages nearest to ``query`` in the FAISS index.

    The query is embedded with the ``RAG`` sentence-transformer and matched
    against the index built on ``data``'s "embedding" column. Similarity
    scores are dropped; only the retrieved examples are returned (indexed by
    column name, e.g. "title", "text", "url").
    """
    query_vector = RAG.encode(query)
    _scores, nearest = data.get_nearest_examples("embedding", query_vector, k=k)
    return nearest
def prepare_prompt(query, retrieved_examples):
    """Build the LLM prompt from a query and its retrieved passages.

    Args:
        query: The user's question.
        retrieved_examples: Mapping with parallel "title", "text" and "url"
            sequences, as returned by ``search``.

    Returns:
        ``(prompt, (titles, urls))``. The prompt is the query line followed by
        one "Title: ..., Text: ..." line per retrieved passage. ``titles`` and
        ``urls`` are in the same order as the passages in the prompt, so
        ``titles[i]`` always belongs with ``urls[i]``.
    """
    # Reverse every column the same way (presumably so the closest match ends
    # up last, next to where generation starts — TODO confirm intent). All
    # three must be reversed together, or titles get paired with the wrong
    # texts and urls.
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]

    header = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    # Use every retrieved passage rather than a fixed TOP_K, so any k passed
    # to search() works without IndexError or silent truncation.
    passages = "".join(
        f"Title: {title}, Text: {text}\n" for title, text in zip(titles, texts)
    )
    return header + passages, (titles, urls)
@spaces.GPU
def talk(message, history):
    """Stream a retrieval-augmented reply to ``message``.

    Retrieves passages with ``search``, folds them into the prompt with
    ``prepare_prompt``, replays earlier turns from ``history`` through the
    tokenizer's chat template, and streams the model's output.

    Args:
        message: The latest user message.
        history: Earlier turns; each item is indexed as ``item[0]`` (user
            text) and ``item[1]`` (assistant reply, possibly ``None``).

    Yields:
        The reply accumulated so far. The final yield also appends a markdown
        "RESOURCES" footer linking the retrieved articles.
    """
    resources_header = "\nRESOURCES:\n"

    retrieved_examples = search(message)
    prompt, (titles, urls) = prepare_prompt(message, retrieved_examples)
    # prepare_prompt returns two parallel lists; zip them so each title is
    # paired with its own url. Iterating the (titles, urls) tuple directly
    # would unpack each whole list into (title, url) and raise ValueError
    # unless exactly two results were retrieved.
    resources = resources_header + ", ".join(
        f"[{title}]({url})" for title, url in zip(titles, urls)
    )

    # Rebuild the conversation for the chat template, stripping the
    # RESOURCES footer this function appended to earlier assistant replies.
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            cleaned_past = item[1].split(resources_header)[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})

    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)

    # timeout=10.0 bounds how long the loop below waits for each new chunk.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # generate() blocks until it finishes, so run it in a worker thread and
    # consume decoded text from the streamer here as it is produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
    yield partial_text + resources
TITLE = "RAG"
# Description passed to gr.ChatInterface below; written as Markdown/HTML
# (bullet list, nested <sub>/<sup> tags for the small footer line).
DESCRIPTION = """
A RAG pipeline with a chatbot feature
Resources used to build this project:
* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column)
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot : google/gemma-7b-it
If you want to support my work please click on the heart react button ❤️🤗
<sub><sup><sub><sup>psst, I am still open for work, so please reach out to me at https://not-lain.github.io/</sup></sub></sup></sub>
"""
# Chat window: labelled, with share/copy buttons, like buttons, and compact
# (not full-width) message bubbles.
_chatbot = gr.Chatbot(
    layout="bubble",
    bubble_full_width=False,
    likeable=True,
    show_label=True,
    show_copy_button=True,
    show_share_button=True,
)

# Wire talk() into a chat UI; the single example gives a one-click first query.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=_chatbot,
    title=TITLE,
    description=DESCRIPTION,
    examples=[["what is machine learning"]],
    theme="Soft",
)
demo.launch()