# load the required libraries
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import streamlit as st

# load the DialogSum dialogue-summarization dataset (also the source of the
# in-context examples used below)
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)
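
# Each test record carries a 'dialogue' (multi-turn conversation) and a
# human-written 'summary'; the prompt builder below reads both fields,
# e.g. dataset['test'][40]['dialogue'] and dataset['test'][40]['summary'].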

# load the Google FLAN-T5 base model
model_name = 'google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# load the tokenizer that matches the model above
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
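
# Note: Streamlit reruns this whole script on every widget interaction, so the
# model and tokenizer above are reloaded each time. A minimal sketch of caching
# them instead (assumes a Streamlit version that provides st.cache_resource):
#
# @st.cache_resource
# def load_model_and_tokenizer(name):
#     return (AutoModelForSeq2SeqLM.from_pretrained(name),
#             AutoTokenizer.from_pretrained(name, use_fast=True))
#
# model, tokenizer = load_model_and_tokenizer(model_name)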

# dataset indices used as in-context examples
example_indices_full = [40]                               # one example  -> one-shot
example_indices_full_few_shot = [40, 80, 120, 200, 220]   # five examples -> few-shot
dash_line = '-' * 100  # separator for console output (unused in the Streamlit UI)

# zero-shot inference: summarize with no in-context examples in the prompt
def zero_shot(my_example):
    prompt = f"""
Dialogue:

{my_example}

What was going on?
"""

    # tokenize the prompt, generate up to 50 new tokens, and decode the result
    inputs = tokenizer(prompt, return_tensors='pt')
    output = tokenizer.decode(
        model.generate(
            inputs["input_ids"],
            max_new_tokens=50
        )[0],
        skip_special_tokens=True
    )

    return output

# build a prompt: each in-context example is a full dialogue followed by its
# reference summary, and the dialogue to summarize comes last
def my_prompt(example_indices, my_example):
    prompt = ''
    for index in example_indices:
        dialogue = dataset['test'][index]['dialogue']
        summary = dataset['test'][index]['summary']
        prompt += f"""
Dialogue:

{dialogue}

What was going on?
{summary}


"""

    prompt += f"""
Dialogue:

{my_example}

What was going on?
"""

    return prompt
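
# For instance, my_prompt([40], some_dialogue) yields one worked
# dialogue/summary pair drawn from dataset['test'], followed by some_dialogue
# and a final "What was going on?" cue the model is expected to complete.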


# one-shot inference: prepend a single worked example to the prompt
def one_shot(example_indices_full, my_example):
    inputs = tokenizer(my_prompt(example_indices_full, my_example), return_tensors='pt')
    output = tokenizer.decode(
        model.generate(
            inputs["input_ids"],
            max_new_tokens=50
        )[0],
        skip_special_tokens=True
    )
    return output

# few-shot inference: prepend five worked examples to the prompt
def few_shot(example_indices_full_few_shot, my_example):
    inputs = tokenizer(my_prompt(example_indices_full_few_shot, my_example), return_tensors='pt')
    output = tokenizer.decode(
        model.generate(
            inputs["input_ids"],
            max_new_tokens=50
        )[0],
        skip_special_tokens=True
    )
    return output
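
# one_shot and few_shot differ only in which index list they hand to my_prompt;
# a sketch of a shared helper both could call instead (hypothetical refactor,
# not wired in above):
#
# def summarize(prompt):
#     inputs = tokenizer(prompt, return_tensors='pt')
#     ids = model.generate(inputs["input_ids"], max_new_tokens=50)[0]
#     return tokenizer.decode(ids, skip_special_tokens=True)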

st.title("Google FLAN-T5 (Base) Prompt-Engineered Model: Zero-shot, One-shot, and Few-shot")

my_example = st.text_area("Enter a dialogue to summarize", value="Maaz: Jalal how are you?\nJalal: I am good, thank you.\nMaaz: Are you going to school tomorrow?\nJalal: No bro, I am not going to school tomorrow.\nMaaz: Why?\nJalal: I am working on a project. Do you want to work with me on my project?\nMaaz: Sorry, I have to go to school.")

if st.button("Run"):
    # run all three prompting strategies on the same input dialogue
    zero_shot_output = zero_shot(my_example)
    one_shot_output = one_shot(example_indices_full, my_example)
    few_shot_output = few_shot(example_indices_full_few_shot, my_example)

    st.header("Comparison of Outputs")

    # Create three columns
    col1, col2, col3 = st.columns(3)

    # Display outputs in respective columns
    with col1:
        st.subheader("Zero-shot Output")
        st.write(zero_shot_output)

    with col2:
        st.subheader("One-shot Output")
        st.write(one_shot_output)

    with col3:
        st.subheader("Few-shot Output")
        st.write(few_shot_output)
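
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py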