not-lain committed on
Commit 1b7e4b0 • 1 Parent(s): 3ed215d

🌘w🌒

Files changed (2)
  1. app.py +90 -103
  2. requirements.txt +4 -4
app.py CHANGED
@@ -1,106 +1,90 @@
  import gradio as gr
- from datasets import load_dataset
- from sentence_transformers import SentenceTransformer
- from sentence_transformers.quantization import quantize_embeddings
- import faiss
- from usearch.index import Index
  import os
  import spaces
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import torch
  from threading import Thread
- from huggingface_hub import hf_hub_download

  token = os.environ["HF_TOKEN"]
- model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",
-     # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     torch_dtype=torch.float16,
-     token=token)
- tok = AutoTokenizer.from_pretrained("google/gemma-7b-it",token=token)
- device = torch.device('cuda')
  model = model.to(device)

- # Load titles and texts
- title_text_dataset = load_dataset(
-     "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
- ).select_columns(["title", "text"])

- # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
- path_int8_view = hf_hub_download(repo_id="sentence-transformers/quantized-retrieval",repo_type="space", filename="wikipedia_ubinary_faiss_1m.index")
- int8_view = Index.restore(path_int8_view, view=True)
-
- path_binary_index = hf_hub_download(repo_id="sentence-transformers/quantized-retrieval",repo_type="space", filename="wikipedia_ubinary_faiss_1m.index")
- binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
-     path_binary_index
- )
-
- # Load the SentenceTransformer model for embedding the queries
- model = SentenceTransformer(
-     "mixedbread-ai/mxbai-embed-large-v1",
-     prompts={
-         "retrieval": "Represent this sentence for searching relevant passages: ",
-     },
-     default_prompt_name="retrieval",
  )
-
-
- def search(
-     query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
- ):
-     # 1. Embed the query as float32
-     query_embedding = model.encode(query)
-
-     # 2. Quantize the query to ubinary
-     query_embedding_ubinary = quantize_embeddings(
-         query_embedding.reshape(1, -1), "ubinary"
  )

-     # 3. Search the binary index (either exact or approximate)
-     index = binary_index
-     _scores, binary_ids = index.search(
-         query_embedding_ubinary, top_k * rescore_multiplier
-     )
-     binary_ids = binary_ids[0]
-
-     # 4. Load the corresponding int8 embeddings
-     int8_embeddings = int8_view[binary_ids].astype(int)
-
-     # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
-     scores = query_embedding @ int8_embeddings.T
-
-     # 6. Sort the scores and return the top_k
-     indices = scores.argsort()[::-1][:top_k]
-     top_k_indices = binary_ids[indices]
-     top_k_scores = scores[indices]
-     top_k_titles, top_k_texts = zip(
-         *[
-             (title_text_dataset[idx]["title"], title_text_dataset[idx]["text"])
-             for idx in top_k_indices.tolist()
-         ]
-     )
-     df = {
-         "Score": [round(value, 2) for value in top_k_scores],
-         "Title": top_k_titles,
-         "Text": top_k_texts,
-     }
-
-     return df
-
- def prepare_prompt(query, df):
-     prompt = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
-     for data in df :
-         title = data["Title"]
-         text = data["Text"]
-         prompt+=f"Title: {title}, Text: {text}\n"
-     return prompt

  @spaces.GPU
  def talk(message, history):
-     df = search(message)
-     message = prepare_prompt(message,df)
      resources = "\nRESOURCES:\n"
-     for title in df["Title"][:3] :
-         resources+=f"[{title}](https://huggingface.co/spaces/not-lain/RAG), "
      chat = []
      for item in history:
          chat.append({"role": "user", "content": item[0]})
@@ -112,7 +96,8 @@ def talk(message, history):
      # Tokenize the messages string
      model_inputs = tok([messages], return_tensors="pt").to(device)
      streamer = TextIteratorStreamer(
-         tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
          model_inputs,
          streamer=streamer,
@@ -131,33 +116,35 @@ def talk(message, history):
      for new_text in streamer:
          partial_text += new_text
          yield partial_text
-     partial_text+= resources
      yield partial_text

-
-
-
  TITLE = "RAG"

  DESCRIPTION = """
  ## Resources used to build this project
- * https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
- * https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
- ## Retrival paramaters
- ```python
- top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
- ```
  ## Models
  the models used in this space are :
  * google/gemma-7b-it
- * mixedbread-ai/wikipedia-data-en-2023-11
  """

- demo = gr.ChatInterface(fn=talk,
-     chatbot=gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, likeable=True, layout="bubble", bubble_full_width=False),
-     theme="Soft",
-     examples=[["what is machine learning"]],
-     title=TITLE,
-     description=DESCRIPTION)
  demo.launch()
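The retrieval path deleted above follows the binary-quantization-plus-rescoring pattern from sentence-transformers: embed the query as float32, quantize it to unsigned binary for a cheap search of the binary FAISS index, then rescore the shortlisted candidates against their int8 document embeddings. A condensed sketch of that flow, assuming the same query encoder and reusing the `binary_index` (FAISS binary index) and `int8_view` (usearch int8 view) objects loaded in the removed code:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings

# Query encoder with the retrieval prompt, as in the removed code above.
encoder = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={"retrieval": "Represent this sentence for searching relevant passages: "},
    default_prompt_name="retrieval",
)


def rescore_search(query, binary_index, int8_view, top_k=10, rescore_multiplier=1):
    # 1. Embed the query as float32, then quantize it to unsigned binary.
    query_emb = encoder.encode(query)
    query_ubinary = quantize_embeddings(query_emb.reshape(1, -1), "ubinary")

    # 2. Cheap candidate search in the binary FAISS index.
    _scores, ids = binary_index.search(query_ubinary, top_k * rescore_multiplier)
    ids = ids[0]

    # 3. Rescore the candidates: float32 query against their int8 document embeddings.
    scores = query_emb @ int8_view[ids].astype(int).T

    # 4. Keep the best top_k after rescoring.
    order = scores.argsort()[::-1][:top_k]
    return ids[order], scores[order]
```

The binary index keeps the search cheap and the int8 view keeps memory low; the rescoring step then recovers most of the ranking quality of full-precision embeddings.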
 
  import gradio as gr
+ from datasets import load_dataset, Dataset
+
+ # import faiss
  import os
  import spaces
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import torch
  from threading import Thread
+ from ragatouille import RAGPretrainedModel
+ from datasets import load_dataset
+

  token = os.environ["HF_TOKEN"]
+ model = AutoModelForCausalLM.from_pretrained(
+     "google/gemma-7b-it",
+     # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     torch_dtype=torch.float16,
+     token=token,
+ )
+ tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
+ device = torch.device("cuda")
  model = model.to(device)
+ RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-v1")

+ # prepare data
+ # since data is too big we will only select the first 3K lines
+ dataset = load_dataset(
+     "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
  )
+ # init data
+ data = Dataset.from_dict({})
+ i = 0
+ for i, entry in enumerate(dataset):
+     # each entry has the following columns
+     # ['id', 'url', 'title', 'text']
+     data = data.add_item(entry)
+     if i == 3000:
+         break
+ # free memory
+ del dataset  # we keep data
+
+ # index data
+ documents = data["text"]
+ RAG.index(documents, index_name="wikipedia", use_faiss=True)
+ # free memory
+ del documents
+
+
+ def search(query, k: int = 5):
+     results = RAG.search(query, k=k)
+     # results are ordered by score (best first); each result has the following keys:
+     # {
+     #     "content": the retrieved passage text,
+     #     "score": relevance score (float),
+     #     "rank": position of the result after sorting by score (1, 2, 3, ...),
+     #     "document_id": identifier of the indexed document,
+     #     "passage_id": position of the passage in the indexed collection (the original row number),
+     # }
+     return [result["passage_id"] for result in results]
+
+
+ def prepare_prompt(query, indexes, data=data):
+     prompt = (
+         f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
      )
+     titles = []
+     urls = []
+     for i in indexes:
+         entry = data[i]
+         title = entry["title"]
+         text = entry["text"]
+         url = entry["url"]
+         titles.append(title)
+         urls.append(url)
+         prompt += f"Title: {title}, Text: {text}\n"
+     return prompt, (titles, urls)

  @spaces.GPU
  def talk(message, history):
+     indexes = search(message)
+     message, metadata = prepare_prompt(message, indexes)
      resources = "\nRESOURCES:\n"
+     for title, url in zip(*metadata):
+         resources += f"[{title}]({url}), "
      chat = []
      for item in history:
          chat.append({"role": "user", "content": item[0]})
      # Tokenize the messages string
      model_inputs = tok([messages], return_tensors="pt").to(device)
      streamer = TextIteratorStreamer(
+         tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+     )
      generate_kwargs = dict(
          model_inputs,
          streamer=streamer,

      for new_text in streamer:
          partial_text += new_text
          yield partial_text
+     partial_text += resources
      yield partial_text

  TITLE = "RAG"

  DESCRIPTION = """
  ## Resources used to build this project
+ * https://huggingface.co/mixedbread-ai/mxbai-colbert-large-v1
+ * me 😎
  ## Models
  the models used in this space are :
  * google/gemma-7b-it
+ * mixedbread-ai/mxbai-colbert-v1
  """

+ demo = gr.ChatInterface(
+     fn=talk,
+     chatbot=gr.Chatbot(
+         show_label=True,
+         show_share_button=True,
+         show_copy_button=True,
+         likeable=True,
+         layout="bubble",
+         bubble_full_width=False,
+     ),
+     theme="Soft",
+     examples=[["what is machine learning"]],
+     title=TITLE,
+     description=DESCRIPTION,
+ )
  demo.launch()
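The new version swaps that pipeline for late-interaction (ColBERT) retrieval through RAGatouille. Below is a minimal sketch of the index-then-search round trip the new app.py relies on, using only the calls that appear in the code above; the sample passages, index name, and query are illustrative:

```python
from ragatouille import RAGPretrainedModel

# Same pretrained ColBERT checkpoint as in app.py above.
RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-v1")

# Index a small collection of passages, then query it.
passages = [
    "Machine learning is a field of study in artificial intelligence.",
    "The Moon is Earth's only natural satellite.",
]
RAG.index(passages, index_name="demo", use_faiss=True)

for hit in RAG.search("what is machine learning", k=2):
    # Each hit carries the matched text plus its score, rank, and passage_id;
    # app.py keeps only passage_id and looks the row back up in `data`.
    print(hit["rank"], round(hit["score"], 2), hit["passage_id"], hit["content"][:60])
```

app.py keeps only `passage_id` from each hit and uses it to look the title, text, and URL back up in the `data` Dataset when building the prompt and the RESOURCES footer.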
requirements.txt CHANGED
@@ -1,6 +1,6 @@
  spaces
  torch==2.2.0
- git+https://github.com/huggingface/transformers/
- git+https://github.com/tomaarsen/sentence-transformers@feat/quantization
- usearch
- faiss-cpu
+ transformers
+ faiss-gpu
+ ragatouille
+ datasets