sander-wood commited on
Commit
8b1add7
1 Parent(s): c95e20d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -9,7 +9,10 @@ import requests
9
  from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
  from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, PreTrainedModel
11
 
12
- device = torch.device("cpu")
 
 
 
13
 
14
  description = """
15
  <div>
@@ -130,7 +133,7 @@ def generate_abc(prompt,
130
  print(f"Error: {e}")
131
  exit()
132
 
133
- model.load_state_dict(torch.load('pytorch_model.bin'))
134
  model.eval()
135
 
136
  tunes = ""
 
9
  from samplings import top_p_sampling, top_k_sampling, temperature_sampling
10
  from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, PreTrainedModel
11
 
12
+ if torch.cuda.is_available():
13
+ device = torch.device("cuda")
14
+ else:
15
+ device = torch.device("cpu")
16
 
17
  description = """
18
  <div>
 
133
  print(f"Error: {e}")
134
  exit()
135
 
136
+ model.load_state_dict(torch.load('pytorch_model.bin', map_location=device))
137
  model.eval()
138
 
139
  tunes = ""