Copycats commited on
Commit
3890d25
1 Parent(s): 8023460

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -2,13 +2,14 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoModelForQuestionAnswering, AutoTokenizer
4
 
 
5
 
6
  @st.cache(allow_output_mutation=True)
7
  def get_model():
8
  # Load fine-tuned MRC model by HuggingFace Model Hub
9
  HUGGINGFACE_MODEL_PATH = "bespin-global/klue-bert-base-aihub-mrc"
10
  tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_PATH)
11
- model = AutoModelForQuestionAnswering.from_pretrained(HUGGINGFACE_MODEL_PATH)
12
 
13
  return tokenizer, model
14
 
@@ -62,7 +63,7 @@ if st.button("Submit", key='question'):
62
  return_token_type_ids=False,
63
  return_offsets_mapping=True
64
  )
65
- encodings = {key: torch.tensor([val]) for key, val in encodings.items()}
66
 
67
  # Predict
68
  pred = model(encodings["input_ids"], attention_mask=encodings["attention_mask"])
 
2
  import torch
3
  from transformers import AutoModelForQuestionAnswering, AutoTokenizer
4
 
5
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
6
 
7
  @st.cache(allow_output_mutation=True)
8
  def get_model():
9
  # Load fine-tuned MRC model by HuggingFace Model Hub
10
  HUGGINGFACE_MODEL_PATH = "bespin-global/klue-bert-base-aihub-mrc"
11
  tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_PATH)
12
+ model = AutoModelForQuestionAnswering.from_pretrained(HUGGINGFACE_MODEL_PATH).to(device)
13
 
14
  return tokenizer, model
15
 
 
63
  return_token_type_ids=False,
64
  return_offsets_mapping=True
65
  )
66
+ encodings = {key: torch.tensor([val]).to(device) for key, val in encodings.items()}
67
 
68
  # Predict
69
  pred = model(encodings["input_ids"], attention_mask=encodings["attention_mask"])