not-lain commited on
Commit
a10540b
1 Parent(s): 47f9330

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -110
app.py CHANGED
@@ -1,117 +1,8 @@
1
  import gradio as gr
2
- from datasets import load_dataset
3
 
4
- import os
5
- import spaces
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
- import torch
8
- from threading import Thread
9
- from sentence_transformers import SentenceTransformer
10
- from datasets import load_dataset
11
- import time
12
 
13
- token = os.environ["HF_TOKEN"]
14
- model = AutoModelForCausalLM.from_pretrained(
15
- "google/gemma-7b-it",
16
- # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
17
- torch_dtype=torch.float16,
18
- token=token,
19
- )
20
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
21
- device = torch.device("cuda")
22
- model = model.to(device)
23
- RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
24
- TOP_K = 1
25
- HEADER = "\n# RESOURCES:\n"
26
- # prepare data
27
- # since data is too big we will only select the first 3K lines
28
-
29
- data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
30
-
31
- # index dataset
32
- data.add_faiss_index("embedding")
33
-
34
-
35
- def search(query: str, k: int = TOP_K):
36
- embedded_query = RAG.encode(query)
37
- scores, retrieved_examples = data.get_nearest_examples(
38
- "embedding", embedded_query, k=k
39
- )
40
- return retrieved_examples
41
-
42
-
43
- def prepare_prompt(query, retrieved_examples):
44
- prompt = (
45
- f"Query: {query}\nContinue to answer the query in short sentences by using the Search Results:\n"
46
- )
47
- urls = []
48
- titles = retrieved_examples["title"][::-1]
49
- texts = retrieved_examples["text"][::-1]
50
- urls = retrieved_examples["url"][::-1]
51
- titles = titles[::-1]
52
- for i in range(TOP_K):
53
- prompt += f"* {texts[i]}\n"
54
- return prompt, zip(titles, urls)
55
-
56
-
57
- @spaces.GPU(duration=150)
58
  def talk(message, history):
59
- print("history, ", history)
60
- print("message ", message)
61
- print("searching dataset ...")
62
- retrieved_examples = search(message)
63
- print("preparing prompt ...")
64
- message, metadata = prepare_prompt(message, retrieved_examples)
65
- resources = HEADER
66
- print("preparing metadata ...")
67
- for title, url in metadata:
68
- resources += f"[{title}]({url}), "
69
- print("preparing chat template ...")
70
- chat = []
71
- for item in history:
72
- chat.append({"role": "user", "content": item[0]})
73
- cleaned_past = item[1].split(HEADER)[0]
74
- chat.append({"role": "assistant", "content": cleaned_past})
75
- chat.append({"role": "user", "content": message})
76
- messages = tokenizer.apply_chat_template(
77
- chat, tokenize=False, add_generation_prompt=True
78
- )
79
- print("chat template prepared, ", messages)
80
- print("tokenizing input ...")
81
- # Tokenize the messages string
82
- model_inputs = tokenizer([messages], return_tensors="pt").to(device)
83
- streamer = TextIteratorStreamer(
84
- tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
85
- )
86
- generate_kwargs = dict(
87
- model_inputs,
88
- streamer=streamer,
89
- max_new_tokens=1024,
90
- do_sample=True,
91
- top_p=0.95,
92
- top_k=1000,
93
- temperature=0.75,
94
- num_beams=1,
95
- )
96
- print("initializing thread ...")
97
- t = Thread(target=model.generate, kwargs=generate_kwargs)
98
- t.start()
99
- time.sleep(1)
100
- # Initialize an empty string to store the generated text
101
- partial_text = ""
102
- i = 0
103
- while t.is_alive():
104
- try:
105
- for new_text in streamer:
106
- if new_text is not None:
107
- partial_text += new_text
108
- yield partial_text
109
- except Exception as e:
110
- print(f"retry number {i}\n LOGS:\n")
111
- i+=1
112
- print(e, e.args)
113
- partial_text += resources
114
- yield partial_text
115
 
116
 
117
  TITLE = "# RAG"
@@ -127,6 +18,8 @@ Resources used to build this project :
127
  * chatbot : https://huggingface.co/google/gemma-7b-it
128
 
129
  If you want to support my work consider clicking on the heart react button ❤️🤗
 
 
130
  """
131
 
132
 
 
1
  import gradio as gr
 
2
 
 
 
 
 
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def talk(message, history):
5
+ return "hi"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  TITLE = "# RAG"
 
18
  * chatbot : https://huggingface.co/google/gemma-7b-it
19
 
20
  If you want to support my work consider clicking on the heart react button ❤️🤗
21
+
22
+ (testing the ui)
23
  """
24
 
25