Spaces:
Sleeping
Sleeping
import gradio as gr | |
from datasets import load_dataset | |
import os | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
import torch | |
from threading import Thread | |
from sentence_transformers import SentenceTransformer | |
from datasets import load_dataset | |
# Hugging Face access token, passed to both the model and the tokenizer
# download below. A KeyError here means the HF_TOKEN variable is not set.
token = os.environ["HF_TOKEN"]

# Chat model: instruction-tuned Gemma 7B loaded in half precision.
# NOTE: dtype and device are hard-coded for a CUDA host; loading will fail on
# CPU-only hardware. A `torch.cuda.is_available()`-based fallback was
# considered by the author and left disabled — confirm before enabling.
_CHAT_MODEL_ID = "google/gemma-7b-it"
model = AutoModelForCausalLM.from_pretrained(
    _CHAT_MODEL_ID,
    torch_dtype=torch.float16,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained(_CHAT_MODEL_ID, token=token)
device = torch.device("cuda")
model = model.to(device)

# Sentence-embedding model used to encode incoming queries for retrieval.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Number of passages retrieved per query.
TOP_K = 3

# Retrieval corpus: a small (~3K-row) Wikipedia subset that already ships an
# "embedding" column, so only queries are embedded at runtime.
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
# Build a FAISS index over that column for nearest-neighbour lookups.
data.add_faiss_index("embedding")
def search(query: str, k: int = TOP_K):
    """Look up the ``k`` corpus rows closest to ``query``.

    The query is embedded with ``RAG`` and matched against the FAISS index
    built on the ``"embedding"`` column of ``data``.

    Args:
        query: Free-text user question.
        k: How many neighbours to fetch (defaults to ``TOP_K``).

    Returns:
        The retrieved examples as returned by
        ``Dataset.get_nearest_examples`` (column name -> list of values);
        the similarity scores are discarded.
    """
    query_vector = RAG.encode(query)
    _scores, examples = data.get_nearest_examples("embedding", query_vector, k=k)
    return examples
def prepare_prompt(query, retrieved_examples):
    """Build the LLM prompt for ``query`` from retrieved search results.

    Args:
        query: The user's question.
        retrieved_examples: Mapping of column name -> list of values, as
            returned by ``search``; must contain ``"title"``, ``"text"`` and
            ``"url"`` lists of equal length.

    Returns:
        ``(prompt, (titles, urls))``. Results are listed in reverse
        retrieval order (so, assuming ``search`` yields the nearest match
        first, the best passage sits last, right before generation starts).
        ``titles`` and ``urls`` are returned in that same order, so
        ``titles[i]``, ``urls[i]`` and the i-th passage in the prompt all
        describe the same article. Every retrieved passage is included, so a
        result set smaller than ``TOP_K`` is handled without IndexError.
    """
    header = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    # Reverse all three columns together; reversing only some of them would
    # pair titles with the wrong passage text and URL.
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]
    passages = "".join(
        f"Title: {title}, Text: {text}\n" for title, text in zip(titles, texts)
    )
    return header + passages, (titles, urls)
# @spaces.GPU
def talk(message, history):
    """Stream a retrieval-augmented answer to ``message``.

    Retrieves passages for the new message, wraps it in a prompt built from
    them, replays earlier turns from ``history`` through the tokenizer's chat
    template, and runs ``model.generate`` on a background thread, yielding
    the reply as it streams in.

    Args:
        message: The user's new chat message.
        history: Earlier turns; each item is indexed as ``item[0]`` (user
            text) and ``item[1]`` (assistant reply, or ``None`` if there was
            none).

    Yields:
        The reply generated so far (cumulative text, not a delta).
    """
    retrieved_examples = search(message)
    prompt, metadata = prepare_prompt(message, retrieved_examples)

    # Markdown links to the retrieved articles. ``metadata`` is the pair
    # ``(titles, urls)``, so the lists must be zipped into (title, url)
    # pairs; iterating ``metadata`` directly would try to unpack the whole
    # title list into two names and raise ValueError.
    titles, urls = metadata
    resources = "\nRESOURCES:\n" + ", ".join(
        f"[{title}]({url})" for title, url in zip(titles, urls)
    )

    chat = []
    for item in history:
        user_text, bot_text = item[0], item[1]
        if bot_text is None:
            # The Gemma chat template requires strictly alternating
            # user/assistant roles; an unanswered turn would produce two
            # consecutive user messages and make the template raise.
            continue
        chat.append({"role": "user", "content": user_text})
        # Strip any appended source list so it is not fed back to the model.
        cleaned_past = bot_text.split("\nRESOURCES:\n")[0]
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})

    prompt_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already contains the BOS token; letting the
    # tokenizer add special tokens again would prepend a second BOS.
    model_inputs = tokenizer(
        [prompt_text], return_tensors="pt", add_special_tokens=False
    ).to(device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # generate() blocks, so run it on a worker thread and consume the
    # streamer here to yield text incrementally.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted once generate() returns; reap the thread.
    worker.join()

    # NOTE: source links are built above but not shown yet. To append them:
    # partial_text += resources
    # yield partial_text
# Markdown header text for the chat UI (passed to gr.ChatInterface as its
# title and description). Blank lines separate the markdown paragraphs so the
# intro, the resource list and the closing notes render as separate blocks.
TITLE = "# RAG"
DESCRIPTION = """
A rag pipeline with a chatbot feature

Resources used to build this project :

* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column )
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot : https://huggingface.co/google/gemma-7b-it

If you want to support my work please click on the heart react button ❤️🤗

<sub><sup><sub><sup>psst, I am still open for work, so please reach out to me at https://not-lain.github.io/</sup></sub></sup></sub>
"""
# Chat window: labelled, with share/copy buttons, like buttons, and bubbles
# sized to their content rather than stretched to full width.
_chatbot = gr.Chatbot(
    show_label=True,
    show_share_button=True,
    show_copy_button=True,
    likeable=True,
    layout="bubble",
    bubble_full_width=False,
)

# Wire the streaming RAG handler into a chat UI and start the server.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=_chatbot,
    title=TITLE,
    description=DESCRIPTION,
    examples=[["what is machine learning"]],
    theme="Soft",
)
demo.launch(debug=True)