import os
import yaml
import json
import torch
import spaces
import librosa
import argparse
import numpy as np
import gradio as gr
from tqdm import tqdm
import soundfile as sf
from pydub import AudioSegment
from safetensors.torch import load_file
from huggingface_hub import snapshot_download

from data.data import get_audiotext_dataloader
from src.factory import create_model_and_transforms
from train.train_utils import Dict2Class, get_autocast, get_cast_dtype
HEADER = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <a href="https://github.com/NVIDIA/audio-flamingo" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
    <img src="https://github.com/NVIDIA/audio-flamingo/blob/main/assets/af_logo.png?raw=true" alt="Audio Flamingo 2 🔥🚀🔥" style="max-width: 120px; height: auto;">
  </a>
  <div>
    <h1>Audio Flamingo 2: An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities</h1>
    <h5 style="margin: 0;">If you like this demo, please give us a star ⭐ on GitHub or a 💖 on this Space.</h5>
  </div>
</div>
<div style="display: flex; justify-content: center; margin-top: 10px;">
  <a href="https://github.com/NVIDIA/audio-flamingo"><img src='https://img.shields.io/badge/Github-AudioFlamingo2-9C276A' style="margin-right: 5px;"></a>
  <a href="https://arxiv.org/abs/2503.03983"><img src="https://img.shields.io/badge/Arxiv-2503.03983-AD1C18" style="margin-right: 5px;"></a>
  <a href="https://huggingface.co/nvidia/audio-flamingo-2"><img src="https://img.shields.io/badge/🤗-Checkpoints-ED5A22.svg" style="margin-right: 5px;"></a>
  <a href="https://github.com/NVIDIA/audio-flamingo/stargazers"><img src="https://img.shields.io/github/stars/NVIDIA/audio-flamingo.svg?style=social"></a>
</div>
""")

def int16_to_float32(x):
    # Convert int16 PCM samples to float32 in [-1, 1].
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    # Convert float32 samples in [-1, 1] to int16 PCM.
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
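
# Note: chaining the two helpers quantizes a waveform to 16-bit precision, e.g.
# int16_to_float32(float32_to_int16(np.array([0.5]))) gives ~0.49998 because
# 0.5 * 32767 is truncated to 16383; load_audio() below relies on this round trip.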

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Download the model config and checkpoint chunks into the working directory
# (the access token is read from the Space secret).
api_key = os.getenv("my_secret")
snapshot_download(repo_id="nvidia/audio-flamingo-2-1.5B", local_dir="./", token=api_key)

config = yaml.load(open("configs/inference.yaml"), Loader=yaml.FullLoader)

data_config = config['data_config']
model_config = config['model_config']
clap_config = config['clap_config']
args = Dict2Class(config['train_config'])

model, tokenizer = create_model_and_transforms(
    **model_config,
    clap_config=clap_config,
    use_local_files=args.offline,
    gradient_checkpointing=args.gradient_checkpointing,
    freeze_lm_embeddings=args.freeze_lm_embeddings,
)

device_id = 0
model = model.to(device_id)
model.eval()

# Load metadata describing how the checkpoint was split into SafeTensors chunks
with open("safe_ckpt/metadata.json", "r") as f:
    metadata = json.load(f)

# Reconstruct the full state_dict
state_dict = {}

# Load each SafeTensors chunk
for chunk_name in metadata:
    chunk_path = f"safe_ckpt/{chunk_name}.safetensors"
    chunk_tensors = load_file(chunk_path)

    # Merge tensors into state_dict
    state_dict.update(chunk_tensors)

missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

autocast = get_autocast(
    args.precision, cache_enabled=(not args.fsdp)
)
cast_dtype = get_cast_dtype(args.precision)

def get_num_windows(T, sr):
    # Decide how many (possibly overlapping) windows are needed to cover T samples,
    # capped at max_num_window, and the padded length those windows span.
    window_length = int(float(clap_config["window_length"]) * sr)
    window_overlap = int(float(clap_config["window_overlap"]) * sr)
    max_num_window = int(clap_config["max_num_window"])

    num_windows = 1
    if T <= window_length:
        num_windows = 1
        full_length = window_length
    elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
        num_windows = max_num_window
        full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
    else:
        num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
        full_length = num_windows * window_length - (num_windows - 1) * window_overlap

    return num_windows, full_length
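
# Illustrative example (window_length=30.0 s, window_overlap=0.0 s and max_num_window=3
# are assumed values here, not necessarily what configs/inference.yaml specifies):
# a 75 s clip at 16 kHz gives T = 1_200_000 samples, so get_num_windows returns
# num_windows = 3 and full_length = 1_440_000; load_audio() zero-pads up to full_length.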

def read_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
    if file_path.endswith('.mp3'):
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]

        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)

        if audio.channels > 1:
            audio = audio.set_channels(1)

        data = np.array(audio.get_array_of_samples())
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))

    else:
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels

            max_frames = int((start + duration) * original_sr)

            audio.seek(int(start * original_sr))
            frames_to_read = min(max_frames, len(audio))
            data = audio.read(frames_to_read)

            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))

        if original_sr != target_sr:
            if channels == 1:
                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]

    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))

    assert len(data.shape) == 1, data.shape
    return data
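
# read_audio() yields a mono float32 waveform at target_sr, scaled to roughly [-1, 1];
# since load_audio() below always calls it with start=0.0, at most `duration` seconds
# of audio are kept.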

def load_audio(audio_path):
    sr = 16000
    window_length = int(float(clap_config["window_length"]) * sr)
    window_overlap = int(float(clap_config["window_overlap"]) * sr)
    max_num_window = int(clap_config["max_num_window"])
    duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]

    audio_data = read_audio(audio_path, sr, duration, 0.0)  # hard-code audio start to 0.0
    T = len(audio_data)
    num_windows, full_length = get_num_windows(T, sr)

    # zero-pad to full_length so the last window is complete (full_length accounts for overlap)
    if full_length > T:
        audio_data = np.append(audio_data, np.zeros(full_length - T))
    audio_data = audio_data.reshape(1, -1)
    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()

    audio_clips = []
    audio_embed_mask = torch.ones(num_windows)
    for i in range(num_windows):
        start = i * (window_length - window_overlap)
        audio_data_tensor_this = audio_data_tensor[:, start:start + window_length]
        audio_clips.append(audio_data_tensor_this)

    # defensive guard: never pass more than max_num_window clips to the model
    if len(audio_clips) > max_num_window:
        audio_clips = audio_clips[:max_num_window]
        audio_embed_mask = audio_embed_mask[:max_num_window]

    audio_clips = torch.cat(audio_clips)

    return audio_clips, audio_embed_mask
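
# load_audio() returns audio_clips with shape [num_windows, window_length] and an
# audio_embed_mask with shape [num_windows]; predict() adds the batch dimension via
# unsqueeze(0) before handing both to model.generate().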

def predict(filepath, question):
    audio_clips, audio_embed_mask = load_audio(filepath)
    audio_clips = audio_clips.to(device_id, dtype=cast_dtype, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device_id, dtype=cast_dtype, non_blocking=True)

    text_prompt = str(question).lower()

    sample = f"<audio>{text_prompt.strip()}{tokenizer.sep_token}"
    # None<|endofchunk|>{tokenizer.eos_token}"

    text = tokenizer(
        sample,
        max_length=512,
        padding="longest",
        truncation="only_first",
        return_tensors="pt"
    )
    input_ids = text["input_ids"].to(device_id, non_blocking=True)

    media_token_id = tokenizer.encode("<audio>")[-1]
    sep_token_id = tokenizer.sep_token_id

    prompt = input_ids

    with torch.no_grad():
        output = model.generate(
            audio_x=audio_clips.unsqueeze(0),
            audio_x_mask=audio_embed_mask.unsqueeze(0),
            lang_x=prompt,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=256,
            temperature=0.0)[0]

    output_decoded = tokenizer.decode(output).split(tokenizer.sep_token)[-1].replace(tokenizer.eos_token, '').replace(tokenizer.pad_token, '').replace('<|endofchunk|>', '')

    return output_decoded
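
# Standalone usage sketch (bypassing the Gradio UI), using one of the bundled examples:
#   answer = predict("./examples/soundcap1.wav", "What is the soundscape in this audio?")
#   print(answer)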

audio_examples = [
    ["./examples/soundcap1.wav", "What is the soundscape in this audio?"],
    ["./examples/muscicap1.wav", "Summarize the music content in a sentence."],
    ["./examples/mmau1.wav", "What specific sounds can be distinguished from the audio clip? (A) Helicopter and impact sounds (B) Whistling and chatter (C) Car honking and raindrops (D) Birds chirping and water flowing"],
]

demo = gr.Blocks()

with demo:
    gr.HTML(HEADER)
    gr.Interface(fn=predict,
                 inputs=[gr.Audio(type="filepath"), gr.Textbox(value='Describe the audio.', label='Question')],
                 outputs=[gr.Textbox(label="Audio Flamingo 2 Output")],
                 cache_examples=True,
                 examples=audio_examples,
                 title="Audio Flamingo 2 Demo",
                 description="Audio Flamingo 2 is NVIDIA's latest large audio-language model, capable of understanding audio inputs and answering open-ended questions about them. <br>" +
                             "**Audio Flamingo 2 is not an ASR model and has limited ability to recognize speech content. It primarily focuses on perception and understanding of non-speech sounds and music.**<br>" +
                             "This demo is hosted on the Stage 2 checkpoints and supports up to 90 seconds of audio. Stage 3 checkpoints, which support up to 5 minutes, will be released at a later point.")

demo.launch(share=True)