Petr Tsvetkov committed on
Commit
347f566
1 Parent(s): 5bd86a2

Fix the synthetic data generation pipeline

Browse files
api_wrappers/grazie_wrapper.py CHANGED
@@ -32,7 +32,7 @@ def llm_request(prompt):
32
 
33
  while output is None:
34
  try:
35
- output = output = client.chat(
36
  chat=ChatPrompt()
37
  .add_system("You are a helpful assistant.")
38
  .add_user(prompt),
 
32
 
33
  while output is None:
34
  try:
35
+ output = client.chat(
36
  chat=ChatPrompt()
37
  .add_system("You are a helpful assistant.")
38
  .add_user(prompt),
dataset_statistics.py CHANGED
@@ -9,10 +9,7 @@ from scipy.stats import stats
9
  import config
10
 
11
 
12
- def get_statistics(row):
13
- start_msg = row["commit_msg_start"]
14
- end_msg = row["commit_msg_end"]
15
-
16
  edit_ops = Levenshtein.editops(start_msg, end_msg)
17
  n_deletes = sum([1 if op == 'delete' else 0 for op, _, _ in edit_ops])
18
  n_inserts = sum([1 if op == 'insert' else 0 for op, _, _ in edit_ops])
@@ -32,12 +29,18 @@ def get_statistics(row):
32
  "changes_norm": n_changes / len(end_msg),
33
 
34
  "lendiff": abs(len(start_msg) - len(end_msg)),
35
- "editdist": row["editdist_related"]
36
  }
37
 
38
 
 
 
 
 
 
 
39
  def get_statistics_for_df(df: pd.DataFrame):
40
- stats = [get_statistics(row) for _, row in
41
  df.iterrows()]
42
 
43
  assert len(stats) > 0
 
9
  import config
10
 
11
 
12
+ def get_statistics_for_sample(start_msg, end_msg, row=None):
 
 
 
13
  edit_ops = Levenshtein.editops(start_msg, end_msg)
14
  n_deletes = sum([1 if op == 'delete' else 0 for op, _, _ in edit_ops])
15
  n_inserts = sum([1 if op == 'insert' else 0 for op, _, _ in edit_ops])
 
29
  "changes_norm": n_changes / len(end_msg),
30
 
31
  "lendiff": abs(len(start_msg) - len(end_msg)),
32
+ "editdist": row["editdist_related"] if row is not None else Levenshtein.distance(start_msg, end_msg),
33
  }
34
 
35
 
36
+ def get_statistics_for_row(row):
37
+ start_msg = row["commit_msg_start"]
38
+ end_msg = row["commit_msg_end"]
39
+ return get_statistics_for_sample(start_msg, end_msg, row=row)
40
+
41
+
42
  def get_statistics_for_df(df: pd.DataFrame):
43
+ stats = [get_statistics_for_row(row) for _, row in
44
  df.iterrows()]
45
 
46
  assert len(stats) > 0
generation_steps/synthetic_end_to_start.py CHANGED
@@ -4,8 +4,8 @@ import pandas as pd
4
  from tqdm import tqdm
5
 
6
  import config
7
- import generate_annotated_diffs
8
  import dataset_statistics
 
9
  from api_wrappers import grazie_wrapper, hf_data_loader
10
  from generation_steps import examples
11
 
@@ -49,9 +49,8 @@ def generate_start_msg(end_msg, diff):
49
  for i in range(GENERATION_ATTEMPTS):
50
  start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
51
 
52
- stats = statistics.get_statistics(start_msg=start_msg_pred, end_msg=end_msg,
53
- annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg_pred,
54
- end_msg))
55
  if stats["insertions"] < REL_INSERTIONS_THRESHOLD:
56
  return start_msg_pred
57
  else:
 
4
  from tqdm import tqdm
5
 
6
  import config
 
7
  import dataset_statistics
8
+ import generate_annotated_diffs
9
  from api_wrappers import grazie_wrapper, hf_data_loader
10
  from generation_steps import examples
11
 
 
49
  for i in range(GENERATION_ATTEMPTS):
50
  start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
51
 
52
+ stats = dataset_statistics.get_statistics_for_sample(start_msg=start_msg_pred, end_msg=end_msg,)
53
+
 
54
  if stats["insertions"] < REL_INSERTIONS_THRESHOLD:
55
  return start_msg_pred
56
  else:
generation_steps/synthetic_start_to_end.py CHANGED
@@ -2,7 +2,6 @@ import pandas as pd
2
  from tqdm import tqdm
3
 
4
  import config
5
- import generate_annotated_diffs
6
  import dataset_statistics
7
  from api_wrappers import grazie_wrapper
8
  from generation_steps import examples
@@ -47,9 +46,7 @@ def generate_end_msg(start_msg, diff):
47
  for i in range(GENERATION_ATTEMPTS):
48
  end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
49
 
50
- stats = statistics.get_statistics(start_msg=start_msg, end_msg=end_msg_pred,
51
- annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg,
52
- end_msg_pred))
53
  if stats["deletions"] < REL_DELETIONS_THRESHOLD:
54
  return end_msg_pred
55
  else:
 
2
  from tqdm import tqdm
3
 
4
  import config
 
5
  import dataset_statistics
6
  from api_wrappers import grazie_wrapper
7
  from generation_steps import examples
 
46
  for i in range(GENERATION_ATTEMPTS):
47
  end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
48
 
49
+ stats = dataset_statistics.get_statistics_for_sample(start_msg=start_msg, end_msg=end_msg_pred, )
 
 
50
  if stats["deletions"] < REL_DELETIONS_THRESHOLD:
51
  return end_msg_pred
52
  else: