Joe99 commited on
Commit
7a7a56a
1 Parent(s): bbf4801

first finance bot

Browse files
Files changed (3) hide show
  1. app.py +54 -0
  2. gpt2talk.pt +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import gradio as gr
3
+ import warnings
4
+ import torch
5
+ warnings.simplefilter('ignore')
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
10
+ #add padding token, beginstring and endstring tokens
11
+ tokenizer.add_special_tokens(
12
+ {
13
+ "pad_token":"<pad>",
14
+ "bos_token":"<startstring>",
15
+ "eos_token":"<endstring>"
16
+ })
17
+ #add bot token since it is not a special token
18
+ tokenizer.add_tokens(["<bot>:"])
19
+
20
+ model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
21
+ model.resize_token_embeddings(len(tokenizer))
22
+ model.load_state_dict(torch.load('gpt2talk.pt', map_location=torch.device('cpu')))
23
+
24
+ model.eval()
25
+ def inference(quiz):
26
+ quiz1 = quiz
27
+ quiz = "<startstring>"+quiz+" <bot>:"
28
+
29
+ quiztoken = tokenizer(quiz,
30
+ return_tensors='pt'
31
+ )
32
+
33
+ answer = model.generate(**quiztoken, max_length=200, top_k=0.7,top_p=0.1)[0]
34
+ answer = tokenizer.decode(answer, skip_special_tokens=True)
35
+ answer = answer.replace(" <bot>:","").replace(quiz1,"") + '.'
36
+ return answer
37
+
38
+ def chatbot(input_text):
39
+ response = inference(input_text)
40
+ return response
41
+
42
+ # Create the Gradio interface
43
+ iface = gr.Interface(
44
+ fn=chatbot,
45
+ inputs=gr.Textbox(),
46
+ outputs=gr.Textbox(),
47
+ live=False, #set false to avoid caching
48
+ interpretation="chat",
49
+ title="ChatFinance",
50
+ description="Ask the a question and see its response!",
51
+ )
52
+
53
+ # Launch the Gradio interface
54
+ iface.launch()
gpt2talk.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbb4f3318512a3c112ad2fc12db6a7fb41b1beb98d4ead0825728a228705fb66
3
+ size 497826671
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==2.0.1
2
+ transformers==4.31.0
3
+ gradio==3.44.4
4
+ gradio_client==0.5.1