Spaces:
Running
Running
sander-wood
commited on
Commit
•
8b1add7
1
Parent(s):
c95e20d
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,10 @@ import requests
|
|
9 |
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
10 |
from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, PreTrainedModel
|
11 |
|
12 |
-
|
|
|
|
|
|
|
13 |
|
14 |
description = """
|
15 |
<div>
|
@@ -130,7 +133,7 @@ def generate_abc(prompt,
|
|
130 |
print(f"Error: {e}")
|
131 |
exit()
|
132 |
|
133 |
-
model.load_state_dict(torch.load('pytorch_model.bin'))
|
134 |
model.eval()
|
135 |
|
136 |
tunes = ""
|
|
|
9 |
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
10 |
from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, PreTrainedModel
|
11 |
|
12 |
+
if torch.cuda.is_available():
|
13 |
+
device = torch.device("cuda")
|
14 |
+
else:
|
15 |
+
device = torch.device("cpu")
|
16 |
|
17 |
description = """
|
18 |
<div>
|
|
|
133 |
print(f"Error: {e}")
|
134 |
exit()
|
135 |
|
136 |
+
model.load_state_dict(torch.load('pytorch_model.bin', map_location=device))
|
137 |
model.eval()
|
138 |
|
139 |
tunes = ""
|