Spaces:
Runtime error
Runtime error
Commit
·
25e0f62
1
Parent(s):
c654b20
add search type radio
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
|
|
33 |
filename="alpaca_train_400_epoch.pt"), map_location=device))
|
34 |
model.eval()
|
35 |
|
36 |
-
def respond(input):
|
37 |
model.eval()
|
38 |
src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
|
39 |
src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
|
@@ -46,18 +46,23 @@ def respond(input):
|
|
46 |
trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
|
47 |
|
48 |
out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
60 |
return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
|
61 |
|
62 |
-
iface = gr.Interface(fn=respond,
|
|
|
|
|
63 |
iface.launch()
|
|
|
33 |
filename="alpaca_train_400_epoch.pt"), map_location=device))
|
34 |
model.eval()
|
35 |
|
36 |
+
def respond(search_type, input):
|
37 |
model.eval()
|
38 |
src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
|
39 |
src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
|
|
|
46 |
trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)
|
47 |
|
48 |
out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
|
49 |
+
if search_type == "Greedy":
|
50 |
+
out = torch.nn.functional.softmax(out, dim=-1)
|
51 |
+
val, ix = out[:, -1].data.topk(1)
|
52 |
+
|
53 |
+
outputs[i] = ix[0][0]
|
54 |
+
if ix[0][0] == vocab_token_dict['<eos>']:
|
55 |
+
break
|
56 |
+
else:
|
57 |
+
out = torch.nn.functional.softmax(out, dim=-1)[:, -1].squeeze().detach().numpy()
|
58 |
+
ix = np.random.choice(np.arange(len(out)), 1, p=out)
|
59 |
+
|
60 |
+
outputs[i] = ix[0]
|
61 |
+
if ix[0] == vocab_token_dict['<eos>']:
|
62 |
+
break
|
63 |
return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])
|
64 |
|
65 |
+
iface = gr.Interface(fn=respond,
|
66 |
+
inputs=[gr.Radio(["Greedy", "Probabilistic"], label="Search Type"), "text"],
|
67 |
+
outputs="text")
|
68 |
iface.launch()
|