Joe99 commited on
Commit
428c8b1
1 Parent(s): ea73e99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import transformers
2
  import gradio as gr
3
- import warnings
4
  import torch
5
- warnings.simplefilter('ignore')
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
@@ -16,11 +16,11 @@ tokenizer.add_special_tokens(
16
  })
17
  #add bot token since it is not a special token
18
  tokenizer.add_tokens(["<bot>:"])
19
-
20
  model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
21
  model.resize_token_embeddings(len(tokenizer))
22
- model.load_state_dict(torch.load('gpt2talk.pt', map_location=torch.device('cpu')))
23
-
24
  model.eval()
25
  def inference(quiz):
26
  quiz1 = quiz
@@ -40,6 +40,7 @@ def chatbot(input_text):
40
  return response
41
 
42
  # Create the Gradio interface
 
43
  iface = gr.Interface(
44
  fn=chatbot,
45
  inputs=gr.Textbox(),
@@ -49,6 +50,6 @@ iface = gr.Interface(
49
  title="ChatFinance",
50
  description="Ask the a question and see its response!",
51
  )
52
-
53
  # Launch the Gradio interface
54
  iface.launch()
 
1
  import transformers
2
  import gradio as gr
3
+ # import warnings
4
  import torch
5
+ # warnings.simplefilter('ignore')
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
 
16
  })
17
  #add bot token since it is not a special token
18
  tokenizer.add_tokens(["<bot>:"])
19
+ print("=====Done 1")
20
  model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
21
  model.resize_token_embeddings(len(tokenizer))
22
+ model.load_state_dict(torch.load('./gpt2talk.pt', map_location=torch.device('cpu')))
23
+ print("=====Done 2")
24
  model.eval()
25
  def inference(quiz):
26
  quiz1 = quiz
 
40
  return response
41
 
42
  # Create the Gradio interface
43
+ print("=====Done 3")
44
  iface = gr.Interface(
45
  fn=chatbot,
46
  inputs=gr.Textbox(),
 
50
  title="ChatFinance",
51
  description="Ask the a question and see its response!",
52
  )
53
+ print("=====Done 4")
54
  # Launch the Gradio interface
55
  iface.launch()