Tonic commited on
Commit
4591b17
β€’
1 Parent(s): 6d9b134

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -10
app.py CHANGED
@@ -49,8 +49,6 @@ def predict(_chatbot, task_history, user_input):
49
  print("User: " + _parse_text(query))
50
  if not task_history:
51
  return _chatbot
52
-
53
- query = task_history[-1][0]
54
  history_cp = copy.deepcopy(task_history)
55
  history_filter = []
56
  audio_idx = 1
@@ -59,10 +57,8 @@ def predict(_chatbot, task_history, user_input):
59
  for i, item in enumerate(history_cp):
60
  if len(item) != 2:
61
  print(f"Error: Expected a tuple of length 2, but got {item}")
62
- continue
63
-
64
- q, a = item
65
-
66
  if isinstance(q, (tuple, list)):
67
  last_audio = q[0]
68
  q = f'Audio {audio_idx}: <audio>{q[0]}</audio>'
@@ -72,13 +68,10 @@ def predict(_chatbot, task_history, user_input):
72
  pre += q
73
  history_filter.append((pre, a))
74
  pre = ""
75
-
76
  if not history_filter:
77
- return _chatbot
78
-
79
  history, message = history_filter[:-1], history_filter[-1][0]
80
  response, history = model.chat(tokenizer, message, history=history)
81
-
82
  ts_pattern = r"<\|\d{1,2}\.\d+\|>"
83
  all_time_stamps = re.findall(ts_pattern, response)
84
  if (len(all_time_stamps) > 0) and (len(all_time_stamps) % 2 ==0) and last_audio:
 
49
  print("User: " + _parse_text(query))
50
  if not task_history:
51
  return _chatbot
 
 
52
  history_cp = copy.deepcopy(task_history)
53
  history_filter = []
54
  audio_idx = 1
 
57
  for i, item in enumerate(history_cp):
58
  if len(item) != 2:
59
  print(f"Error: Expected a tuple of length 2, but got {item}")
60
+ continue
61
+ q, a = item
 
 
62
  if isinstance(q, (tuple, list)):
63
  last_audio = q[0]
64
  q = f'Audio {audio_idx}: <audio>{q[0]}</audio>'
 
68
  pre += q
69
  history_filter.append((pre, a))
70
  pre = ""
 
71
  if not history_filter:
72
+ return _chatbot
 
73
  history, message = history_filter[:-1], history_filter[-1][0]
74
  response, history = model.chat(tokenizer, message, history=history)
 
75
  ts_pattern = r"<\|\d{1,2}\.\d+\|>"
76
  all_time_stamps = re.findall(ts_pattern, response)
77
  if (len(all_time_stamps) > 0) and (len(all_time_stamps) % 2 ==0) and last_audio: