Jalalkhan912 committed on
Commit
3dbd5f7
·
verified ·
1 Parent(s): 195b7ab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # load important libraries
2
+ from datasets import load_dataset
3
+ from transformers import AutoModelForSeq2SeqLM
4
+ from transformers import AutoTokenizer
5
+ from transformers import GenerationConfig
6
+ import streamlit as st
7
+
8
# Load the DialogSum dataset (dialogue/summary pairs); the 'test' split is
# used later by my_prompt() to pull in-context demonstration examples.
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)

# Load google/flan-t5-base, an instruction-tuned seq2seq model suitable for
# zero/one/few-shot summarization without any fine-tuning.
model_name='google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Tokenizer matched to the model checkpoint; use_fast selects the Rust-backed
# fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
18
+
19
# Indices into dataset['test'] used as in-context demonstrations.
example_indices_full = [40]                              # single demo for one-shot
example_indices_full_few_shot = [40, 80, 120, 200, 220]  # five demos for few-shot

# Horizontal separator line: 99 dashes (identical to the string the original
# join-over-100-empty-strings expression produced).
dash_line = '-' * 99
23
+
24
# zero_shot inference
def zero_shot(my_example):
    """Summarize *my_example* with no in-context demonstrations.

    The dialogue is wrapped in a bare instruction template and handed
    straight to FLAN-T5; sampling (temperature 1.0) makes the output
    non-deterministic across runs.
    """
    prompt = f"""
Dialogue:

{my_example}

What was going on?
"""
    # Same sampling settings as the one-/few-shot paths so outputs compare fairly.
    config = GenerationConfig(max_new_tokens=80, do_sample=True, temperature=1.0)

    encoded = tokenizer(prompt, return_tensors='pt')
    generated_ids = model.generate(
        encoded["input_ids"],
        generation_config=config,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
45
+
46
# this prompt template will be used
def my_prompt(example_indices, my_example):
    """Build an in-context prompt from demonstration examples plus the query.

    For each index in *example_indices*, a solved (dialogue, summary) pair
    from the dialogsum test split is appended; the caller's dialogue is
    appended last with the question left unanswered for the model to complete.
    """
    sections = []
    for idx in example_indices:
        demo_dialogue = dataset['test'][idx]['dialogue']
        demo_summary = dataset['test'][idx]['summary']
        sections.append(f"""
Dialogue:

{demo_dialogue}

What was going on?
{demo_summary}


""")

    # Final, unanswered section containing the user's dialogue.
    sections.append(f"""
Dialogue:

{my_example}

What was going on?
""")

    return ''.join(sections)
72
+
73
+
74
+ # this is for one_shot
75
+ def one_shot(example_indices_full,my_example):
76
+ generation_config = GenerationConfig(max_new_tokens=80, do_sample=True, temperature=1.0)
77
+
78
+ inputs = tokenizer(my_prompt(example_indices_full,my_example), return_tensors='pt')
79
+ output = tokenizer.decode(
80
+ model.generate(
81
+ inputs["input_ids"],
82
+ generation_config=generation_config,
83
+ )[0],
84
+ skip_special_tokens=True
85
+ )
86
+ return output
87
+
88
# few_shot
def few_shot(example_indices_full_few_shot, my_example):
    """Summarize *my_example* using several in-context demonstrations.

    Identical to one_shot() except that *example_indices_full_few_shot*
    supplies multiple demonstration examples to my_prompt().
    """
    config = GenerationConfig(max_new_tokens=80, do_sample=True, temperature=1.0)

    prompt = my_prompt(example_indices_full_few_shot, my_example)
    token_batch = tokenizer(prompt, return_tensors='pt')
    generated_ids = model.generate(
        token_batch["input_ids"],
        generation_config=config,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
100
+
101
# --- Streamlit UI: collect a dialogue and compare the three prompting modes ---
st.title("Google FLAN-T5(Base) Prompt Engineered Model: Zero-shot, Single-shot, and Few-shot")

my_example = st.text_area("Enter dialogues to summarize", value="#Maaz#: Jalal how are you?\n#Jalal#: I am good thank you.\n#Maaz#: Are you going to school tomorrow.\n#Jalal#: No bro i am not going to school tomorrow.\n#Maaz#: why?\n#Jalal#: I am working on a project, are you want to work with me on my project?\n#Maaz#: sorry, i have to go to school.")

if st.button("Run"):
    # Run all three inference modes up front, then render them side by side.
    results = [
        ("Zero-shot Output", zero_shot(my_example)),
        ("One-shot Output", one_shot(example_indices_full, my_example)),
        ("Few-shot Output", few_shot(example_indices_full_few_shot, my_example)),
    ]

    st.header("**Comparison of Outputs**")

    # One column per prompting strategy.
    for column, (caption, summary_text) in zip(st.columns(3), results):
        with column:
            st.subheader(caption)
            st.write(summary_text)