paulperry commited on
Commit
5a142b8
1 Parent(s): c564746

Add application file

Browse files
Files changed (1) hide show
  1. app.py +16 -0
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
4
+
5
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
6
+ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
7
+ model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
8
+
9
+ input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")
10
+
11
+ generated = model.generate(input_ids=input_dict["input_ids"])
12
+ outstring = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
13
+ print(outstring)
14
+
15
+ st.write(outstring)
16
+