import spaces
import gradio as gr
import torch
import torchaudio
from threading import Thread
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer

# Model paths and configuration
model_path_1 = "./model"
model_path_2 = "./model2"
base_model_id = "Qwen/Qwen2-Audio-7B-Instruct"

# Dictionary to cache loaded models and processors
loaded_models = {}


# Load the model and processor
def load_model(model_path):
    # Return the cached pair if this model has already been loaded
    if model_path in loaded_models:
        return loaded_models[model_path]

    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )

    # Load the model
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    model.eval()

    # Store in cache
    loaded_models[model_path] = (model, processor)

    return model, processor


# Initialize the first model and processor
model, processor = load_model(model_path_1)


def process_output(output):
    # NOTE: the literal tag strings below are assumed reconstructions -- the
    # original angle-bracket tags were stripped when this file was extracted.
    # Replace them with the tags the fine-tuned model actually emits.
    if "<think>" in output:
        rest = output.split("<think>")[1]
        output = "\n" + rest
    elif "<semantic_elements>" in output:
        rest = output.split("<semantic_elements>")[1]
        output = "\n" + rest
    elif "<answer>" in output:
        rest = output.split("<answer>")[1]
        output = "\n" + rest
    elif "</think>" in output:
        rest = output.split("</think>")[0]
        output = rest + "\n\n\n"
    elif "</semantic_elements>" in output:
        rest = output.split("</semantic_elements>")[0]
        output = rest + "\n\n\n"
    elif "</answer>" in output:
        rest = output.split("</answer>")[0]
        output = rest + "\n\n"
    output = output.replace("\\n", "\n")
    output = output.replace("\\", "\n")
    output = output.replace("\n-", "-")
    return output


# process_audio_streaming is the only handler used by the Gradio interface
@spaces.GPU
def process_audio_streaming(audio_file, model_choice):
    # Load the selected model
    model_path = model_path_1 if model_choice == "Think" else model_path_2
    model, processor = load_model(model_path)

    # Load and process the audio with torchaudio
    waveform, sr = torchaudio.load(audio_file)

    # Resample to 16 kHz if needed
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Get the audio data as a NumPy array
    y = waveform.squeeze().numpy()

    # Sampling rate expected by the processor
    sampling_rate = 16000

    # Create conversation format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": "Describe the audio in detail."}
        ]}
    ]

    # Format the chat
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    # Process the inputs
    inputs = processor(
        text=chat_text,
        audios=[y],
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)

    # Create a streamer instance
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Accumulates the generated text across streamed chunks
    accumulated_output = ""

    # Generate the output with streaming
    with torch.no_grad():
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=768,
            do_sample=False,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        # Yield the accumulated output as new chunks arrive
        for output in streamer:
            output = process_output(output)
            accumulated_output += output  # Append the new chunk to the accumulated string
            yield accumulated_output


# Create Gradio interface for audio processing
audio_demo = gr.Interface(
    fn=process_audio_streaming,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Radio(["Think", "Think + Semantics"], label="Select Model", value="Think + Semantics")
    ],
    outputs=gr.Textbox(label="Generated Output", lines=30),
    title="SemThink",
    description="Upload an audio file and the model will provide a detailed analysis and description. Choose between different model versions.",
    examples=[["examples/1.wav", "Think + Semantics"]],  # Example uses the same default model as the radio button
    cache_examples=False,
    live=True  # Enable live updates
)

# Launch the app
if __name__ == "__main__":
    audio_demo.launch()
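
# --- Optional: querying the running app programmatically ---
# A minimal sketch (not part of the original app) of how the streaming endpoint
# above could be exercised with gradio_client once the server is running. The
# local URL and the use of the default "/predict" endpoint are assumptions;
# adjust them to your deployment.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   job = client.submit(
#       handle_file("examples/1.wav"),  # audio_file
#       "Think + Semantics",            # model_choice
#   )
#   for partial in job:  # iterate over the streamed, accumulated outputs
#       print(partial)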