"""Gradio demo for Qwen-Audio-Chat: ask a text question about an audio clip.

If the model's reply contains timestamp tokens such as ``<|1.25|>``, consecutive
pairs are read as (start, end) in seconds; the matching segments are cut from
the most recent audio and attached to the chat as separate audio files.
"""

import ast
import copy
import os
import re
import secrets
import tempfile
from pathlib import Path

import gradio as gr
import torch
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(420)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-Audio-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-Audio-Chat", device_map="cuda", trust_remote_code=True
).eval()

# Root directory for extracted audio clips. Defaults to <system temp>/gradio
# and can be overridden through the GRADIO_TEMP_DIR environment variable.
uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
    Path(tempfile.gettempdir()) / "gradio"
)

# Timestamp tokens in model output, e.g. "<|1.25|>" (seconds).
_TS_PATTERN = re.compile(r"<\|\d{1,2}\.\d+\|>")

# Inside fenced code blocks these characters are rendered as HTML entities so
# the chatbot shows them literally instead of interpreting them as
# HTML/Markdown. A single-pass translate matches chained str.replace calls
# here because no replacement text contains a character replaced later.
_CODE_ESCAPES = str.maketrans(
    {
        "`": r"\`",
        "<": "&lt;",
        ">": "&gt;",
        " ": "&nbsp;",
        "*": "&ast;",
        "_": "&lowbar;",
        "-": "&#45;",
        ".": "&#46;",
        "!": "&#33;",
        "(": "&#40;",
        ")": "&#41;",
        "$": "&#36;",
    }
)


def _parse_text(text):
    """Convert plain/Markdown text into the HTML shown in the Gradio chatbot.

    Empty lines are dropped. A line containing ``` opens (odd occurrence) or
    closes (even occurrence) a code block: the opener becomes
    <pre><code class="language-LANG"> where LANG is the text after the last
    backtick, the closer becomes <br></code></pre>. Every other line after the
    first is prefixed with <br>; inside a code block its special characters
    are escaped as HTML entities.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            language = line.split("`")[-1]
            if fence_count % 2 == 1:
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def parse_task_history(task_history_str):
    """Return the task history as a list of (query, response) 2-tuples.

    Accepts either an already-built list (returned as-is, so callers that
    pass a list see in-place updates) or its repr string. Strings are parsed
    with ast.literal_eval, which only evaluates Python literals. Anything
    that is not a list of 2-tuples yields [].
    """
    if isinstance(task_history_str, list):
        parsed = task_history_str
    elif not task_history_str:
        return []
    else:
        try:
            parsed = ast.literal_eval(task_history_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            print(f"Error parsing task history: {e}")
            return []
    if isinstance(parsed, list) and all(
        isinstance(item, tuple) and len(item) == 2 for item in parsed
    ):
        return parsed
    print("Error parsing task history: Parsed task history is not a list of tuples")
    return []


def _build_model_turns(task_history):
    """Turn task history into Qwen chat turns plus the most recent audio path.

    An audio item ((path,), answer) is rendered as
    "Audio N: <audio>path</audio>\\n". While it has no answer yet, it is
    prepended to the next text query, so an upload and the question about it
    form one model turn. Audio left over at the end becomes its own final
    turn instead of being dropped.

    Returns:
        (turns, last_audio): turns is a list of (query, response) pairs whose
        last entry is the one to answer; last_audio is the newest audio path
        or None.
    """
    turns = []
    pending = ""
    audio_idx = 1
    last_audio = None
    for q, a in copy.deepcopy(task_history):
        if isinstance(q, (tuple, list)):
            last_audio = q[0]
            pending += f"Audio {audio_idx}: <audio>{q[0]}</audio>\n"
            audio_idx += 1
            if a is None:
                continue  # merge with the text that follows
        else:
            pending += q
        turns.append((pending, a))
        pending = ""
    if pending:
        turns.append((pending, None))
    return turns, last_audio


def _export_clips(audio_path, time_stamps):
    """Cut (start, end) segments out of audio_path and save each as a new file.

    time_stamps are raw tokens like "<|1.25|>" (seconds), taken in consecutive
    pairs; a trailing unpaired timestamp is ignored. Clips are written in the
    source file's format under a fresh random directory inside
    uploaded_file_dir. Returns the list of clip paths.
    """
    seconds = [float(t.replace("<|", "").replace("|>", "")) for t in time_stamps]
    audio_format = os.path.splitext(audio_path)[-1].replace(".", "")
    # Read the source audio.
    audio = AudioSegment.from_file(audio_path, format=audio_format)

    out_dir = Path(uploaded_file_dir) / secrets.token_hex(20)
    out_dir.mkdir(exist_ok=True, parents=True)

    clip_paths = []
    # Cut out each segment; AudioSegment slices are in milliseconds.
    for start, end in zip(seconds[0::2], seconds[1::2]):
        clip = audio[start * 1000 : end * 1000]
        # Save the clip.
        out_path = out_dir / f"tmp{secrets.token_hex(5)}.{audio_format}"
        clip.export(str(out_path), format=audio_format)
        clip_paths.append(out_path)
    return clip_paths


def _record_reply(chatbot, query, reply):
    """Attach reply to the pending (user, None) chatbot turn, or append a new turn."""
    if chatbot and chatbot[-1][1] is None:
        chatbot[-1] = (chatbot[-1][0], reply)
    else:
        chatbot.append((_parse_text(query), reply))


def predict(_chatbot, task_history_str, user_input=None):
    """Answer the pending turn with Qwen-Audio-Chat and show it in the chatbot.

    Args:
        _chatbot: list of (user, bot) display pairs; updated in place (None is
            treated as an empty list).
        task_history_str: the conversation as a list of (query, response)
            pairs, where an audio query is a 1-tuple (path,), or that list's
            repr string (see parse_task_history).
        user_input: optional text for the user's turn; defaults to the last
            history query when that is text.

    Returns:
        (response, _chatbot). The response has timestamp markers stripped
        when audio clips were extracted. It is "" when there was nothing to
        answer.
    """
    if _chatbot is None:
        _chatbot = []
    print("Predict - Start: task_history_str =", task_history_str)
    task_history = parse_task_history(task_history_str)

    query = user_input if user_input else (task_history[-1][0] if task_history else "")
    if not isinstance(query, str):
        # The last history item is an audio upload ((path,)), not text.
        query = ""
    print("User: " + _parse_text(query))
    if not task_history:
        return "", _chatbot

    turns, last_audio = _build_model_turns(task_history)
    if not turns:
        return "", _chatbot
    history, message = turns[:-1], turns[-1][0]
    raw_response, _ = model.chat(tokenizer, message, history=history)

    response = raw_response
    clip_paths = []
    time_stamps = _TS_PATTERN.findall(raw_response)
    if time_stamps and last_audio:
        response = raw_response.replace("<|", "").replace("|>", "")
        clip_paths = _export_clips(last_audio, time_stamps)

    _record_reply(_chatbot, query, response)
    for clip_path in clip_paths:
        _chatbot.append((None, (str(clip_path),)))

    # Store the answer so the next turn sends it as history. This is only
    # visible to the caller when task_history_str was passed as a list.
    task_history[-1] = (task_history[-1][0], raw_response)
    print("Predict - End: task_history =", task_history)
    return response, _chatbot


def regenerate(_chatbot, task_history):
    """Discard the last bot reply and ask the model to answer that turn again.

    Returns predict's (response, _chatbot), or ("", _chatbot) when there is
    no answered turn to regenerate.
    """
    if not isinstance(task_history, list):
        task_history = []
    if _chatbot is None:
        _chatbot = []
    print("Regenerate - Start: task_history =", task_history)
    if not task_history:
        return "", _chatbot
    item = task_history[-1]
    if item[1] is None:
        return "", _chatbot
    task_history[-1] = (item[0], None)

    # Remove extracted-clip turns (user part None), then clear the reply of
    # the text turn they belonged to so predict fills it again.
    while _chatbot and _chatbot[-1][0] is None:
        _chatbot.pop()
    if _chatbot:
        _chatbot[-1] = (_chatbot[-1][0], None)

    print("Regenerate - End: task_history =", task_history)
    return predict(_chatbot, task_history)


def add_text(history, task_history, text):
    """Append a user text turn to both the display history and the task history."""
    if not isinstance(task_history, list):
        task_history = []
    if history is None:
        history = []
    print("Add Text - Before: task_history =", task_history)
    history.append((_parse_text(text), None))
    task_history.append((text, None))
    print("Add Text - After: task_history =", task_history)
    return history, task_history


def add_file(history, task_history, file):
    """Append an uploaded audio file as a (path,) turn.

    file may be a tempfile-like object exposing .name or a plain path string.
    """
    if not isinstance(task_history, list):
        task_history = []
    if history is None:
        history = []
    print("Add File - Before: task_history =", task_history)
    if file is None:
        return history, task_history
    path = getattr(file, "name", file)
    history.append(((path,), None))
    task_history.append(((path,), None))
    print("Add File - After: task_history =", task_history)
    return history, task_history


def add_mic(history, task_history, file):
    """Append a microphone recording (a file path) as a (path,) audio turn.

    The recording is renamed to end in ".wav" (unless it already does) so its
    format can be inferred from the extension later.
    """
    if not isinstance(task_history, list):
        task_history = []
    if history is None:
        history = []
    print("Add Mic - Before: task_history =", task_history)
    if file is None:
        return history, task_history
    if file.lower().endswith(".wav"):
        file_with_extension = file
    else:
        file_with_extension = file + ".wav"
        os.rename(file, file_with_extension)
    history.append(((file_with_extension,), None))
    task_history.append(((file_with_extension,), None))
    print("Add Mic - After: task_history =", task_history)
    return history, task_history


def reset_user_input():
    """Clear the text box."""
    return gr.update(value="")


def reset_state(task_history):
    """Clear the task history in place and return an empty chatbot history."""
    if not isinstance(task_history, list):
        task_history = []
    print("Reset State - Before: task_history =", task_history)
    task_history.clear()
    print("Reset State - After: task_history =", task_history)
    return []


def _interface_predict(audio_path, text, state):
    """Adapt gr.Interface inputs (audio path, text, state) to predict().

    Appends the new audio and/or text to the conversation kept in state, runs
    the model, and returns (response_text, updated_state).
    """
    task_history = list(state) if isinstance(state, list) else []
    if audio_path:
        task_history.append(((audio_path,), None))
    if text:
        task_history.append((text, None))
    if not (audio_path or text):
        return "", task_history
    response, _ = predict([], task_history, text)
    return response, task_history


iface = gr.Interface(
    fn=_interface_predict,
    inputs=[
        # A file path is needed both by the model's <audio> tag and by pydub.
        gr.Audio(type="filepath", label="Audio Input"),
        gr.Textbox(label="Text Query"),
        gr.State(),
    ],
    outputs=["text", gr.State()],
    title="Audio-Text Interaction Model",
    description="This model can process an audio input along with a text query and provide a response.",
    theme="default",
    allow_flagging="never",
)
iface.launch()