Shakhovak commited on
Commit
3fb88a6
1 Parent(s): 7042ec5

Upload 9 files

Browse files

Adding main files

Files changed (9) hide show
  1. Dockerfile +22 -0
  2. app.py +29 -0
  3. data/scripts.pkl +3 -0
  4. data/scripts_vectors.pkl +3 -0
  5. requirements.txt +7 -0
  6. retrieve_bot.py +72 -0
  7. static/style.css +223 -0
  8. templates/chat.html +80 -0
  9. utils.py +166 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the Flask chat bot, served by gunicorn on port 7860.
FROM python:3.9.13

# Install dependencies first so this layer stays cached when only code changes.
WORKDIR /code
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir -r /code/requirements.txt

# Run the application as an unprivileged user.
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# The project is copied exactly once, owned by the runtime user. The earlier
# root-owned copy into /code was never used (the app runs from $HOME/app) and
# only duplicated the data files in an extra image layer.
WORKDIR $HOME/app
COPY --chown=user . $HOME/app

CMD ["gunicorn", "-b", "0.0.0.0:7860", "app:app"]
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request
2
+ from retrieve_bot import ChatBot
3
+
4
# Flask entry point: serves the chat page and forwards messages to the bot.

app = Flask(__name__)

# A single bot instance is created and loaded at import time and is shared by
# every request handled by this process.
chatSheldon = ChatBot()
chatSheldon.load()


@app.route("/")
def index():
    """Render the chat page."""
    return render_template("chat.html")


@app.route("/get", methods=["GET", "POST"])
def chat():
    """Return the bot's reply to the ``msg`` request parameter."""
    # request.values merges query-string and form data, so the advertised GET
    # method works as well as POST. A missing "msg" still yields a 400 response.
    msg = request.values["msg"]
    return get_Chat_response(msg)


def get_Chat_response(text):
    """Return the bot's answer for ``text``."""
    return chatSheldon.generate_response(text)


if __name__ == "__main__":
    # Local development server only; debug mode must not be exposed publicly.
    app.run(debug=True, port=7860)
data/scripts.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd8ded525a9faf9031e899ba75c5b7f91fdc4052619a43ca1ff608a7cce73b42
3
+ size 2127113
data/scripts_vectors.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba242c25adc032bcf265fa1c805bf1f506150f181a6fc13f6753088af79cd9c7
3
+ size 71223174
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ sentence-transformers==2.2.2
2
+ flask==2.2.5
3
+ pandas==1.3.5
4
+ gunicorn==20.1.0
5
+ requests==2.27.1
6
+ datasets==2.13.2
7
+ transformers==4.37.2
retrieve_bot.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import pickle
3
+ from sentence_transformers import SentenceTransformer
4
+ from utils import encode, cosine_sim, top_candidates, candidates_reranking
5
+ from collections import deque
6
+ from transformers import pipeline
7
+ import torch
8
+ from transformers import AutoTokenizer
9
+
10
+ # this class represents the main functions of the retrieval bot
11
+
12
+
13
class ChatBot:
    """Retrieval chat bot: similarity search over stored vectors plus reranking."""

    def __init__(self):
        # Everything below except the history is populated by load().
        self.vect_data = []
        self.scripts = []
        # Rolling window holding the 5 most recent utterances/answers.
        self.conversation_history = deque([], maxlen=5)
        self.ranking_model = None
        self.reranking_model = None
        self.device = None
        self.tokenizer = None

    def load(self):
        """Load the datasets and models used by the chat bot.

        Must be called before generate_response(). Data files are read from
        the ``data`` folder; models are loaded by name from Hugging Face.
        """
        # NOTE: unpickling can execute arbitrary code. Only load pickle files
        # that ship with this project; never user-supplied ones.
        with open("data/scripts_vectors.pkl", "rb") as fp:
            self.vect_data = pickle.load(fp)
        self.scripts = pd.read_pickle("data/scripts.pkl")
        self.ranking_model = SentenceTransformer("sentence-transformers/LaBSE")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.reranking_model = pipeline(
            model="Shakhovak/RerankerModel_chat_bot",
            device=self.device,
            tokenizer=self.tokenizer,
        )

    @staticmethod
    def _best_candidate(reranked_dict, top_indexes):
        """Return the row index to answer with.

        Picks the candidate with the highest reranker score. When the
        reranker accepted no candidate, falls back to the first (best)
        similarity hit in ``top_indexes``.
        """
        if reranked_dict:
            return max(reranked_dict, key=reranked_dict.get)
        return top_indexes[0]

    def generate_response(self, utterance: str) -> str:
        """Find candidate answers for ``utterance``, rerank them, return the best.

        The utterance and the chosen answer are appended to the conversation
        history afterwards.
        """
        query_encoding = encode(
            utterance, self.ranking_model, contexts=self.conversation_history
        )
        bot_cosine_scores = cosine_sim(self.vect_data, query_encoding)
        top_scores, top_indexes = top_candidates(bot_cosine_scores, top=20)

        # Keep only the candidates the classifier labels as a fitting answer
        # (LABEL_0), together with the classifier score.
        reranked_dict = candidates_reranking(
            top_indexes,
            self.conversation_history,
            utterance,
            self.scripts,
            self.reranking_model,
        )

        # The previous ascending sort followed by taking the first key picked
        # the *least* confident candidate; the highest score must win.
        best_idx = self._best_candidate(reranked_dict, top_indexes)
        answer = self.scripts.iloc[best_idx]["answer"]

        self.conversation_history.append(utterance)
        self.conversation_history.append(answer)

        return answer
static/style.css ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
body,html{
	height: 100%;
	margin: 0;
	/* Solid colour fallback for browsers without gradient support. */
	background: rgb(44, 47, 59);
	/* The prefixed syntax names the starting side ("left"); "to right" is only
	   valid in the unprefixed form and made this declaration invalid. */
	background: -webkit-linear-gradient(left, rgb(40, 59, 34), rgb(54, 60, 70), rgb(32, 32, 43));
	background: linear-gradient(to right, rgb(38, 51, 61), rgb(50, 55, 65), rgb(33, 33, 78));
}
8
+
9
+ .chat{
10
+ margin-top: auto;
11
+ margin-bottom: auto;
12
+ }
13
+ .card{
14
+ height: 500px;
15
+ border-radius: 15px !important;
16
+ background-color: rgba(0,0,0,0.4) !important;
17
+ }
18
+ .contacts_body{
19
+ padding: 0.75rem 0 !important;
20
+ overflow-y: auto;
21
+ white-space: nowrap;
22
+ }
23
+ .msg_card_body{
24
+ overflow-y: auto;
25
+ }
26
+ .card-header{
27
+ border-radius: 15px 15px 0 0 !important;
28
+ border-bottom: 0 !important;
29
+ }
30
+ .card-footer{
31
+ border-radius: 0 0 15px 15px !important;
32
+ border-top: 0 !important;
33
+ }
34
+ .container{
35
+ align-content: center;
36
+ }
37
+ .search{
38
+ border-radius: 15px 0 0 15px !important;
39
+ background-color: rgba(0,0,0,0.3) !important;
40
+ border:0 !important;
41
+ color:white !important;
42
+ }
43
+ .search:focus{
44
+ box-shadow:none !important;
45
+ outline:0px !important;
46
+ }
47
+ .type_msg{
48
+ background-color: rgba(0,0,0,0.3) !important;
49
+ border:0 !important;
50
+ color:white !important;
51
+ height: 60px !important;
52
+ overflow-y: auto;
53
+ }
54
+ .type_msg:focus{
55
+ box-shadow:none !important;
56
+ outline:0px !important;
57
+ }
58
+ .attach_btn{
59
+ border-radius: 15px 0 0 15px !important;
60
+ background-color: rgba(0,0,0,0.3) !important;
61
+ border:0 !important;
62
+ color: white !important;
63
+ cursor: pointer;
64
+ }
65
+ .send_btn{
66
+ border-radius: 0 15px 15px 0 !important;
67
+ background-color: rgba(0,0,0,0.3) !important;
68
+ border:0 !important;
69
+ color: white !important;
70
+ cursor: pointer;
71
+ }
72
+ .search_btn{
73
+ border-radius: 0 15px 15px 0 !important;
74
+ background-color: rgba(0,0,0,0.3) !important;
75
+ border:0 !important;
76
+ color: white !important;
77
+ cursor: pointer;
78
+ }
79
+ .contacts{
80
+ list-style: none;
81
+ padding: 0;
82
+ }
83
+ .contacts li{
84
+ width: 100% !important;
85
+ padding: 5px 10px;
86
+ margin-bottom: 15px !important;
87
+ }
88
+ .active{
89
+ background-color: rgba(0,0,0,0.3);
90
+ }
91
+ .user_img{
92
+ height: 70px;
93
+ width: 70px;
94
+ border:1.5px solid #f5f6fa;
95
+
96
+ }
97
+ .user_img_msg{
98
+ height: 40px;
99
+ width: 40px;
100
+ border:1.5px solid #f5f6fa;
101
+
102
+ }
103
+ .img_cont{
104
+ position: relative;
105
+ height: 70px;
106
+ width: 70px;
107
+ }
108
+ .img_cont_msg{
109
+ height: 40px;
110
+ width: 40px;
111
+ }
112
+ .online_icon{
113
+ position: absolute;
114
+ height: 15px;
115
+ width:15px;
116
+ background-color: #4cd137;
117
+ border-radius: 50%;
118
+ bottom: 0.2em;
119
+ right: 0.4em;
120
+ border:1.5px solid white;
121
+ }
122
+ .offline{
123
+ background-color: #c23616 !important;
124
+ }
125
+ .user_info{
126
+ margin-top: auto;
127
+ margin-bottom: auto;
128
+ margin-left: 15px;
129
+ }
130
+ .user_info span{
131
+ font-size: 20px;
132
+ color: white;
133
+ }
134
+ .user_info p{
135
+ font-size: 10px;
136
+ color: rgba(255,255,255,0.6);
137
+ }
138
+ .video_cam{
139
+ margin-left: 50px;
140
+ margin-top: 5px;
141
+ }
142
+ .video_cam span{
143
+ color: white;
144
+ font-size: 20px;
145
+ cursor: pointer;
146
+ margin-right: 20px;
147
+ }
148
+ .msg_cotainer{
149
+ margin-top: auto;
150
+ margin-bottom: auto;
151
+ margin-left: 10px;
152
+ border-radius: 25px;
153
+ background-color: rgb(82, 172, 255);
154
+ padding: 10px;
155
+ position: relative;
156
+ }
157
+ .msg_cotainer_send{
158
+ margin-top: auto;
159
+ margin-bottom: auto;
160
+ margin-right: 10px;
161
+ border-radius: 25px;
162
+ background-color: #58cc71;
163
+ padding: 10px;
164
+ position: relative;
165
+ }
166
+ .msg_time{
167
+ position: absolute;
168
+ left: 0;
169
+ bottom: -15px;
170
+ color: rgba(255,255,255,0.5);
171
+ font-size: 10px;
172
+ }
173
+ .msg_time_send{
174
+ position: absolute;
175
+ right:0;
176
+ bottom: -15px;
177
+ color: rgba(255,255,255,0.5);
178
+ font-size: 10px;
179
+ }
180
+ .msg_head{
181
+ position: relative;
182
+ }
183
+ #action_menu_btn{
184
+ position: absolute;
185
+ right: 10px;
186
+ top: 10px;
187
+ color: white;
188
+ cursor: pointer;
189
+ font-size: 20px;
190
+ }
191
+ .action_menu{
192
+ z-index: 1;
193
+ position: absolute;
194
+ padding: 15px 0;
195
+ background-color: rgba(0,0,0,0.5);
196
+ color: white;
197
+ border-radius: 15px;
198
+ top: 30px;
199
+ right: 15px;
200
+ display: none;
201
+ }
202
+ .action_menu ul{
203
+ list-style: none;
204
+ padding: 0;
205
+ margin: 0;
206
+ }
207
+ .action_menu ul li{
208
+ width: 100%;
209
+ padding: 10px 15px;
210
+ margin-bottom: 5px;
211
+ }
212
+ .action_menu ul li i{
213
+ padding-right: 10px;
214
+ }
215
+ .action_menu ul li:hover{
216
+ cursor: pointer;
217
+ background-color: rgba(0,0,0,0.2);
218
+ }
219
+ @media(max-width: 576px){
220
+ .contacts_card{
221
+ margin-bottom: 15px !important;
222
+ }
223
+ }
templates/chat.html ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html>
	<head>
		<title>Chatbot</title>
		<!-- All assets are loaded here, after the doctype, so the page renders in
		     standards mode. Each library is loaded once. Bootstrap's JS bundle is
		     not loaded because the page uses no Bootstrap JS components. -->
		<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css" integrity="sha384-MCw98/SFnGE8fJT3GXwEOngsV7Zt27NXFoaoApmYm81iuXoPkFOJwJ8ERdknLPMO" crossorigin="anonymous">
		<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.5.0/css/all.css" integrity="sha384-B4dIYHKNBt8Bc12p+WXckhzcICo0wtJAoU8YZTY5qE0Id1GSseTk6S+L3BlXeVIU" crossorigin="anonymous">
		<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
		<link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='style.css')}}"/>
	</head>


	<body>
		<div class="container-fluid h-100">
			<div class="row justify-content-center h-100">
				<div class="col-md-8 col-xl-6 chat">
					<div class="card">
						<div class="card-header msg_head">
							<div class="d-flex bd-highlight">
								<div class="img_cont">
									<img src="https://stickerpacks.ru/wp-content/uploads/2023/04/nabor-stikerov-teorija-bolshogo-vzryva-5-dlja-telegram-3.webp" class="rounded-circle user_img">
									<span class="online_icon"></span>
								</div>
								<div class="user_info">
									<span>ChatBot</span>
									<p>Ask me anything!</p>
								</div>
							</div>
						</div>
						<div id="messageFormeight" class="card-body msg_card_body">


						</div>
						<div class="card-footer">
							<form id="messageArea" class="input-group">
								<input type="text" id="text" name="msg" placeholder="Type your message..." autocomplete="off" class="form-control type_msg" required/>
								<div class="input-group-append">
									<button type="submit" id="send" class="input-group-text send_btn"><i class="fas fa-location-arrow"></i></button>
								</div>
							</form>
						</div>
					</div>
				</div>
			</div>
		</div>

		<script>
			$(document).ready(function() {
				var BOT_AVATAR = "https://stickerpacks.ru/wp-content/uploads/2023/04/nabor-stikerov-teorija-bolshogo-vzryva-5-dlja-telegram-3.webp";
				var USER_AVATAR = "https://i.ibb.co/d5b84Xw/Untitled-design.png";
				var $log = $("#messageFormeight");

				// Zero-pad the minutes so 9:05 is not shown as "9:5".
				function currentTime() {
					var now = new Date();
					var minutes = now.getMinutes();
					return now.getHours() + ":" + (minutes < 10 ? "0" : "") + minutes;
				}

				// Build one chat row. The message is inserted with .text(), never as
				// HTML, so typed or returned markup is displayed instead of executed.
				function buildMessage(text, time, fromUser) {
					var $avatar = $('<div class="img_cont_msg"></div>').append(
						$('<img class="rounded-circle user_img_msg">').attr("src", fromUser ? USER_AVATAR : BOT_AVATAR)
					);
					var $bubble = $("<div></div>")
						.addClass(fromUser ? "msg_cotainer_send" : "msg_cotainer")
						.text(text)
						.append($("<span></span>").addClass(fromUser ? "msg_time_send" : "msg_time").text(time));
					var $row = $('<div class="d-flex mb-4"></div>')
						.addClass(fromUser ? "justify-content-end" : "justify-content-start");
					return fromUser ? $row.append($bubble, $avatar) : $row.append($avatar, $bubble);
				}

				function appendMessage($message) {
					$log.append($message);
					// Keep the newest message visible.
					$log.scrollTop($log[0].scrollHeight);
				}

				$("#messageArea").on("submit", function(event) {
					// Cancel the native submit first so a script error below cannot
					// trigger a full page reload.
					event.preventDefault();

					var str_time = currentTime();
					var rawText = $("#text").val();

					$("#text").val("");
					appendMessage(buildMessage(rawText, str_time, true));

					$.ajax({
						data: {
							msg: rawText,
						},
						type: "POST",
						url: "/get",
					}).done(function(data) {
						appendMessage(buildMessage(data, str_time, false));
					});
				});
			});
		</script>

	</body>
</html>
utils.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.metrics.pairwise import cosine_similarity
3
+ from scipy import sparse
4
+ import pandas as pd
5
+ import pickle
6
+ import random
7
+
8
+
9
def encode(texts, model, contexts=None, do_norm=True):
    """Encode a question and its context into one vector for similarity search.

    Args:
        texts: question text, passed to ``model.encode``.
        model: object exposing an ``encode(text)`` method.
        contexts: iterable of context strings; ``None`` (the default) is
            treated as an empty context instead of raising ``TypeError``.
        do_norm: unused; kept so existing callers keep working.

    Returns:
        The question vector and the context vector concatenated along the
        last axis.
    """
    # The context strings are joined without a separator on purpose: changing
    # this would change the vectors produced for the same inputs.
    context_text = "" if contexts is None else "".join(contexts)

    question_vectors = model.encode(texts)
    context_vectors = model.encode(context_text)

    return np.concatenate(
        [np.asarray(question_vectors), np.asarray(context_vectors)], axis=-1
    )
18
+
19
+
20
def cosine_sim(data_vectors, query_vectors) -> list:
    """Rank stored vectors by cosine similarity to the query.

    Args:
        data_vectors: sequence of equally sized vectors (one per script row).
        query_vectors: the query vector.

    Returns:
        A list of ``(similarity, [index])`` tuples sorted by descending
        similarity, where ``index`` is the position of the vector in
        ``data_vectors``. Entries whose similarity is exactly zero are left
        out. An empty input yields an empty list.
    """
    data = np.atleast_2d(np.asarray(data_vectors))
    query = np.atleast_2d(np.asarray(query_vectors))
    if data.size == 0 or query.size == 0:
        return []

    def _unit_rows(matrix):
        # Scale every row to unit length; all-zero rows are left as zeros.
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return matrix / norms

    similarity = (_unit_rows(query) @ _unit_rows(data).T).flatten()

    # Pair every score with its own index. The previous zip of all scores with
    # only the non-zero indices shifted the pairing whenever a score was zero.
    return sorted(
        ((float(similarity[i]), [int(i)]) for i in np.flatnonzero(similarity)),
        reverse=True,
    )
29
+
30
+
31
def scripts_rework(path, character, output_path="data/scripts.pkl"):
    """Split a script into question / answer / context rows for one character.

    Reads the CSV at ``path`` (columns ``person_scene`` and ``dialogue``) and
    writes a pickled DataFrame with the columns ``answer``, ``question`` and
    ``context``, one row per line spoken by ``character``:

    * ``answer``   - the character's line;
    * ``question`` - the line directly before it, or ``""`` when the
      character opens a scene (or the file);
    * ``context``  - up to three lines preceding the question, oldest first,
      never reaching across a ``"Scene"`` row.

    Args:
        path: CSV file to read.
        character: value of ``person_scene`` whose lines become answers.
        output_path: where the pickle is written.
    """
    df = pd.read_csv(path)

    # Number the scenes: every "Scene" row opens a new scene, except row 0,
    # which always belongs to scene 0. Assigning a whole column is used
    # because chained ``df.iloc[i][col] = v`` is not guaranteed to write back.
    opens_scene = df["person_scene"] == "Scene"
    if len(opens_scene):
        opens_scene.iloc[0] = False
    df["scene_count"] = opens_scene.cumsum()

    df = df.dropna().reset_index()

    speakers = df["person_scene"].tolist()
    lines = df["dialogue"].tolist()

    rows = []
    for pos, speaker in enumerate(speakers):
        if speaker != character:
            continue

        # Nothing precedes the first row; ``pos - 1`` must not wrap around to
        # the last row of the file.
        if pos == 0 or speakers[pos - 1] == "Scene":
            rows.append({"answer": lines[pos], "question": "", "context": []})
            continue

        # Walk backwards from the line before the question and stop at the
        # scene boundary, so lines of an earlier scene are never included and
        # lines of the current scene are never lost.
        context = []
        for back in range(2, 5):
            prev = pos - back
            if prev < 0 or speakers[prev] == "Scene":
                break
            context.append(lines[prev])
        context.reverse()

        rows.append(
            {
                "answer": lines[pos],
                "question": lines[pos - 1],
                "context": context,
            }
        )

    # Build the frame once; row-by-row DataFrame.append is quadratic and was
    # removed in pandas 2.0.
    scripts = pd.DataFrame(rows, columns=["answer", "question", "context"])
    scripts.to_pickle(output_path)
79
+
80
+
81
def encode_df_save(model):
    """Vectorize the reworked scripts and pickle the vectors.

    Reads ``data/scripts.pkl``, encodes every row's question together with
    its context using ``model`` and stores the resulting list in
    ``data/scripts_vectors.pkl``, the retrieval base of the ranking step.
    """
    frame = pd.read_pickle("data/scripts.pkl")
    vectors = [
        encode(question, model, context)
        for question, context in zip(frame["question"], frame["context"])
    ]
    with open("data/scripts_vectors.pkl", "wb") as out_file:
        pickle.dump(vectors, out_file)
92
+
93
+
94
def top_candidates(score_lst_sorted, top=1):
    """Split ranked ``(score, [index])`` pairs into scores and indices.

    Returns a tuple ``(scores, indexes)`` holding the first ``top`` scores
    and the first ``top`` row indices of ``score_lst_sorted``.
    """
    all_scores = []
    all_indexes = []
    for score, index_holder in score_lst_sorted:
        all_scores.append(score)
        all_indexes.append(index_holder[0])
    return all_scores[:top], all_indexes[:top]
101
+
102
+
103
def candidates_reranking(
    top_candidates_idx_lst, conversational_history, utterance, initial_df, pipeline
):
    """Score candidate answers with the classifier and keep the accepted ones.

    For every candidate index the conversation history, the utterance and the
    candidate's answer are joined with ``[SEP]`` and passed to ``pipeline``.
    Returns a dict mapping each index whose top prediction is ``LABEL_0`` to
    that prediction's score.
    """
    accepted = {}
    for candidate in top_candidates_idx_lst:
        model_input = " [SEP] ".join(
            (
                " ".join(conversational_history),
                utterance,
                initial_df.iloc[candidate]["answer"],
            )
        )
        best = pipeline(model_input)[0]
        if best["label"] != "LABEL_0":
            continue
        accepted[candidate] = best["score"]
    return accepted
124
+
125
+
126
def read_files_negative(
    path1,
    path2,
    n_questions=7062,
    scripts_path="data/scripts.pkl",
    output_path="data/scripts_for_reranker.pkl",
):
    """Build the classifier training set, including negative examples.

    Positive rows (label 0) are the reworked scripts. Negative rows (label 1)
    reuse each script's question and context but get an unrelated answer
    drawn from other scripts and from randomly sampled questions.

    Args:
        path1: iterable of CSV files read with ``sep='"'``; their
            ``dialogue`` column supplies negative answers.
        path2: CSV file whose ``line`` column supplies negative answers.
        n_questions: number of non-empty script questions sampled as extra
            negative answers. ``random.sample`` raises ``ValueError`` when
            fewer questions are available.
        scripts_path: pickle with the reworked scripts.
        output_path: where the resulting pickle is written.
    """
    other_scripts = [
        pd.read_csv(file, sep='"', on_bad_lines="warn") for file in path1
    ]
    total = pd.concat(other_scripts, ignore_index=True)

    rick_and_morty = pd.read_csv(path2)
    negative_lines_to_add = list(rick_and_morty["line"])
    negative_lines_to_add.extend(total["dialogue"])

    scripts_reopened = pd.read_pickle(scripts_path)
    scripts_reopened["label"] = 0
    questions = list(
        scripts_reopened.loc[scripts_reopened["question"] != "", "question"]
    )
    negative_lines_to_add.extend(random.sample(questions, n_questions))
    random.shuffle(negative_lines_to_add)

    # Work on an explicit copy: adding columns to a column subset of another
    # frame triggers SettingWithCopyWarning.
    scripts_negative = scripts_reopened[["question", "context"]].copy()
    scripts_negative["label"] = 1
    scripts_negative["answer"] = negative_lines_to_add[: len(scripts_negative)]

    fin_scripts = pd.concat([scripts_negative, scripts_reopened])

    fin_scripts = fin_scripts.sample(frac=1).reset_index(drop=True)
    fin_scripts["context"] = fin_scripts["context"].apply(lambda x: "".join(x))
    fin_scripts = fin_scripts[fin_scripts["question"] != ""]
    fin_scripts = fin_scripts[fin_scripts["answer"] != ""]
    fin_scripts["combined"] = (
        fin_scripts["context"]
        + "[SEP]"
        + fin_scripts["question"]
        + "[SEP]"
        + fin_scripts["answer"]
    )
    # NOTE(review): rows with a missing (NaN) answer are not dropped here, so
    # their "combined" value is NaN as well - confirm whether that is intended.
    fin_scripts.to_pickle(output_path)