vrushabh94
commited on
Commit
•
0251703
1
Parent(s):
5f9dcd8
Removed the sharing
Browse files
app.py
CHANGED
@@ -11,12 +11,14 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
11 |
premise = ""
|
12 |
hypothesis = ""
|
13 |
|
14 |
-
def zeroShotClassification(text_input,
|
|
|
|
|
15 |
input = tokenizer(text_input, hypothesis, truncation=True, return_tensors="pt")
|
16 |
output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
|
17 |
prediction = torch.softmax(output["logits"][0], -1).tolist()
|
18 |
-
|
19 |
-
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction,
|
20 |
return prediction
|
21 |
|
22 |
examples = [
|
@@ -28,5 +30,5 @@ examples = [
|
|
28 |
["Submission Receipt", "Submission Receipt"]
|
29 |
]
|
30 |
|
31 |
-
demo = gr.Interface(fn=zeroShotClassification, inputs=[gr.Textbox(label="Input"), gr.Textbox(label="Candidate Labels", value="Meeting Minutes / Outcomes,Submission Receipt")], outputs=gr.Label(label="Classification"), examples=examples)
|
32 |
demo.launch();
|
|
|
11 |
premise = ""
|
12 |
hypothesis = ""
|
13 |
|
14 |
+
def zeroShotClassification(text_input, candidate_labels, hypothesis):
|
15 |
+
print(text_input)
|
16 |
+
print(candidate_labels)
|
17 |
input = tokenizer(text_input, hypothesis, truncation=True, return_tensors="pt")
|
18 |
output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
|
19 |
prediction = torch.softmax(output["logits"][0], -1).tolist()
|
20 |
+
labels = [label.strip(' ') for label in candidate_labels.split(',')]
|
21 |
+
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, labels)}
|
22 |
return prediction
|
23 |
|
24 |
examples = [
|
|
|
30 |
["Submission Receipt", "Submission Receipt"]
|
31 |
]
|
32 |
|
33 |
+
demo = gr.Interface(fn=zeroShotClassification, inputs=[gr.Textbox(label="Input"), gr.Textbox(label="Candidate Labels", value="Meeting Minutes / Outcomes,Submission Receipt"), gr.Textbox(label="Hypothesys", value="Meeting Minutes / Outcomes,Submission Receipt")], outputs=gr.Label(label="Classification"), examples=examples)
|
34 |
demo.launch();
|