qqaatw commited on
Commit
96a5a56
1 Parent(s): a75e207
Files changed (4) hide show
  1. .gitattributes +27 -27
  2. README.md +29 -5
  3. app.py +56 -0
  4. requirements.txt +3 -0
.gitattributes CHANGED
@@ -1,27 +1,27 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bin.* filter=lfs diff=lfs merge=lfs -text
5
- *.bz2 filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.model filter=lfs diff=lfs merge=lfs -text
12
- *.msgpack filter=lfs diff=lfs merge=lfs -text
13
- *.onnx filter=lfs diff=lfs merge=lfs -text
14
- *.ot filter=lfs diff=lfs merge=lfs -text
15
- *.parquet filter=lfs diff=lfs merge=lfs -text
16
- *.pb filter=lfs diff=lfs merge=lfs -text
17
- *.pt filter=lfs diff=lfs merge=lfs -text
18
- *.pth filter=lfs diff=lfs merge=lfs -text
19
- *.rar filter=lfs diff=lfs merge=lfs -text
20
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
- *.tar.* filter=lfs diff=lfs merge=lfs -text
22
- *.tflite filter=lfs diff=lfs merge=lfs -text
23
- *.tgz filter=lfs diff=lfs merge=lfs -text
24
- *.xz filter=lfs diff=lfs merge=lfs -text
25
- *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,37 @@
1
  ---
2
- title: Realm
3
- emoji: 😻
4
  colorFrom: pink
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 2.8.9
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: REALM Demo
3
+ emoji: 💻
4
  colorFrom: pink
5
  colorTo: green
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
 
9
  ---
10
 
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (emoji-only character allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio` or `streamlit`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
import torch
from transformers import RealmForOpenQA, RealmRetriever

# Pretrained REALM open-domain QA checkpoint (Natural Questions corpus).
model_name = "google/realm-orqa-nq-openqa"

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The retriever bundles both the block-record corpus and the tokenizer.
retriever = RealmRetriever.from_pretrained(model_name)
tokenizer = retriever.tokenizer

openqa = RealmForOpenQA.from_pretrained(model_name, retriever=retriever)
openqa.to(device)

# Size of the stock corpus; user-supplied documents are appended after
# (and replaced at) this boundary.
default_num_block_records = openqa.config.num_block_records
15
def add_additional_documents(openqa, additional_documents):
    """Replace the user-supplied portion of the retrieval corpus.

    The first ``default_num_block_records`` entries (the stock wiki
    blocks) are always preserved; everything past that boundary is
    replaced by the documents given here, so repeated calls replace
    rather than accumulate user documents.

    Args:
        openqa: ``RealmForOpenQA`` model whose retriever corpus,
            block embeddings, and config are updated in place.
        additional_documents: Newline-separated document strings.
    """
    # Drop blank/whitespace-only lines so they do not become empty (yet
    # retrievable) corpus entries; an all-blank input would otherwise
    # feed an empty batch to the tokenizer.
    documents = [doc for doc in additional_documents.split("\n") if doc.strip()]

    retriever = openqa.retriever
    tokenizer = retriever.tokenizer

    if documents:
        # docs: truncate back to the stock corpus, then append the new
        # entries as UTF-8 byte strings (matching the corpus encoding).
        np_documents = np.array([doc.encode() for doc in documents], dtype=object)
        retriever.block_records = np.concatenate(
            (retriever.block_records[:default_num_block_records], np_documents),
            axis=0,
        )

        # embeds: project the new documents with the model's embedder so
        # the retriever can score them against questions.
        inputs = tokenizer(
            documents, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            projected_score = openqa.embedder(**inputs, return_dict=True).projected_score
        openqa.block_emb = torch.cat(
            (openqa.block_emb[:default_num_block_records], projected_score), dim=0
        )
    else:
        # Nothing usable was supplied: restore the stock corpus.
        retriever.block_records = retriever.block_records[:default_num_block_records]
        openqa.block_emb = openqa.block_emb[:default_num_block_records]

    openqa.config.num_block_records = default_num_block_records + len(documents)
36
def question_answer(question, additional_documents):
    """Answer ``question`` with REALM, optionally over extra documents.

    Args:
        question: Natural-language question string.
        additional_documents: Newline-separated documents to retrieve
            over in addition to the default wiki corpus; blank means
            the default corpus only.

    Returns:
        The decoded predicted answer span as a string.
    """
    question_ids = tokenizer(question, return_tensors="pt").input_ids

    if additional_documents.strip() != "":
        add_additional_documents(openqa, additional_documents)
    elif openqa.config.num_block_records != default_num_block_records:
        # A previous request appended documents; drop them so a blank
        # box really means "default wiki documents only".  Model state
        # is module-level and therefore shared across requests.
        openqa.retriever.block_records = openqa.retriever.block_records[
            :default_num_block_records
        ]
        openqa.block_emb = openqa.block_emb[:default_num_block_records]
        openqa.config.num_block_records = default_num_block_records

    with torch.no_grad():
        outputs = openqa(input_ids=question_ids.to(device), return_dict=True)

    return tokenizer.decode(outputs.predicted_answer_ids)
46
+
47
+
48
# Multi-line textbox for user documents, one document per line; leaving
# it empty falls back to the model's default wiki corpus.
additional_documents_input = gr.inputs.Textbox(
    lines=5,
    placeholder="Each line represents a document entry. Leave blank to use default wiki documents.",
)

# Wire the QA function into a simple two-input, one-output demo and
# serve it with request queueing enabled.
iface = gr.Interface(
    fn=question_answer,
    inputs=["text", additional_documents_input],
    outputs=["textbox"],
    allow_flagging="never",
)
iface.launch(enable_queue=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy
2
+ torch
3
+ transformers