import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
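
# google/gemma-7b-it is gated on the Hub, so HF_TOKEN must be set in the
# environment (e.g. as a Space secret) with access to the model.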
token = os.environ["HF_TOKEN"]

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)
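
# Query encoder for retrieval; it must match the model that produced the
# dataset's "embedding" column, or nearest-neighbour search is meaningless.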
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

TOP_K = 1  # number of passages to retrieve per query
HEADER = "\n# RESOURCES:\n"  # separates the answer text from the appended source links

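# Wikipedia subset (~3000 rows) that ships with a precomputed "embedding" column.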
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
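
# Build an in-memory FAISS index over the "embedding" column so that
# `get_nearest_examples` can run nearest-neighbour search against it.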
data.add_faiss_index("embedding")


def search(query: str, k: int = TOP_K):
    """Embed the query and return the k nearest rows from the dataset."""
    embedded_query = RAG.encode(query)
    scores, retrieved_examples = data.get_nearest_examples(
        "embedding", embedded_query, k=k
    )
    return retrieved_examples


def prepare_prompt(query, retrieved_examples):
    """Append the retrieved passages to the query and collect (title, url) pairs."""
    prompt = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    # Keep titles, texts, and urls in the same order so each link matches its passage.
    titles = retrieved_examples["title"]
    texts = retrieved_examples["text"]
    urls = retrieved_examples["url"]
    for i in range(TOP_K):
        prompt += f"* {texts[i]}\n"
    return prompt, zip(titles, urls)


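# `spaces.GPU` requests a ZeroGPU slot for up to `duration` seconds per call
# when running on Hugging Face Spaces.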
@spaces.GPU(duration=150)
def talk(message, history):
    print("history, ", history)
    print("message ", message)
    print("searching dataset ...")
    retrieved_examples = search(message)
    print("preparing prompt ...")
    message, metadata = prepare_prompt(message, retrieved_examples)
    resources = HEADER
    print("preparing metadata ...")
    for title, url in metadata:
        resources += f"[{title}]({url}), "
    print("preparing chat template ...")
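    # Rebuild the conversation for the chat template, stripping the sources
    # block (everything after HEADER) from past replies so it is not re-fed.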
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        cleaned_past = item[1].split(HEADER)[0]
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print("chat template prepared, ", messages)
    print("tokenizing input ...")
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
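    # Standard transformers streaming setup: `model.generate` runs in a
    # background thread and pushes decoded text into `streamer`, which the
    # loop below drains, yielding partial output so Gradio can stream it.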
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    print("initializing thread ...")
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    time.sleep(1)  # brief pause so the generation thread has started producing tokens
    partial_text = ""
    for new_text in streamer:
        if new_text is not None:
            partial_text += new_text
            yield partial_text
    # Once the answer has fully streamed, append the source links.
    partial_text += resources
    yield partial_text


TITLE = "# RAG"

DESCRIPTION = """
A RAG pipeline with a chatbot feature

Resources used to build this project:

* embedding model: https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset: https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (the embedding column was created with mxbai-embed-large-v1)
* faiss docs: https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot: https://huggingface.co/google/gemma-7b-it

If you want to support my work, please click on the heart react button ❤️🤗

<sub><sup><sub><sup>psst, I am still open for work, so please reach out to me at https://not-lain.github.io/</sup></sub></sup></sub>
"""

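# `gr.ChatInterface` calls `talk(message, history)` on every submission and
# renders each yielded string as the progressively updated assistant reply.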
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch(debug=True)