|
|
|
|
|
|
|
|
|
|
|
"""Dataset of music tracks with rich metadata. |
|
""" |
|
from dataclasses import dataclass, field, fields, replace |
|
import gzip |
|
import json |
|
import logging |
|
from pathlib import Path |
|
import random |
|
import typing as tp |
|
|
|
import torch |
|
|
|
from .info_audio_dataset import ( |
|
InfoAudioDataset, |
|
AudioInfo, |
|
get_keyword_list, |
|
get_keyword, |
|
get_string |
|
) |
|
from ..modules.conditioners import ( |
|
ConditioningAttributes, |
|
JointEmbedCondition, |
|
WavCondition, |
|
) |
|
from ..utils.utils import warn_once |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class MusicInfo(AudioInfo):
    """Audio segment info extended with optional music metadata.

    Every music-specific field defaults to None so an instance can be created
    from partial metadata. `self_wav` and `joint_embed` are never read by
    `from_dict`; they are meant to be assigned once the instance exists.
    """
    # Music metadata attributes.
    title: tp.Optional[str] = None
    artist: tp.Optional[str] = None
    key: tp.Optional[str] = None
    bpm: tp.Optional[float] = None
    genre: tp.Optional[str] = None
    moods: tp.Optional[list] = None
    keywords: tp.Optional[list] = None
    description: tp.Optional[str] = None
    name: tp.Optional[str] = None
    instrument: tp.Optional[str] = None

    # Waveform condition associated with this segment (set after construction).
    self_wav: tp.Optional[WavCondition] = None

    # Joint embedding conditions, keyed by attribute name (set after construction).
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_music_meta(self) -> bool:
        """Whether the `name` field has been populated."""
        return not (self.name is None)

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Spread every dataclass field over a fresh `ConditioningAttributes`.

        `self_wav` goes to the `wav` mapping, each entry of `joint_embed` goes to the
        `joint_embed` mapping, and all remaining fields go to the `text` mapping
        (list values are flattened into a single space-separated string).
        """
        attributes = ConditioningAttributes()
        for spec in fields(self):
            name = spec.name
            current = getattr(self, name)
            if name == 'self_wav':
                attributes.wav[name] = current
                continue
            if name == 'joint_embed':
                for embed_name, embed_value in current.items():
                    attributes.joint_embed[embed_name] = embed_value
                continue
            attributes.text[name] = ' '.join(current) if isinstance(current, list) else current
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing callable for `attribute`, or None if it needs none."""
        if attribute == 'bpm':
            return get_bpm
        if attribute == 'key':
            return get_musical_key
        if attribute in ('moods', 'keywords'):
            return get_keyword_list
        if attribute in ('genre', 'name', 'instrument'):
            return get_keyword
        if attribute in ('title', 'artist', 'description'):
            return get_string
        return None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build an instance from a raw metadata dictionary.

        Args:
            dictionary (dict): Raw values keyed by field name; unknown keys are ignored.
            fields_required (bool): When True, a missing field raises unless it is
                listed as optional below.
        Returns:
            An instance of `cls` populated with the preprocessed values.
        Raises:
            KeyError: If `fields_required` is True and a non-optional field is missing.
        """
        # Fields that are filled in after construction and therefore never read here.
        deferred = ('self_wav', 'joint_embed')
        # Fields allowed to be absent even when `fields_required` is True.
        may_be_absent = ('keywords',)

        init_kwargs: tp.Dict[str, tp.Any] = {}
        for spec in fields(cls):
            name = spec.name
            if name in deferred:
                continue
            if name not in dictionary:
                if fields_required and name not in may_be_absent:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            getter: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            raw = dictionary[name]
            init_kwargs[name] = getter(raw) if getter else raw
        return cls(**init_kwargs)
|
|
|
|
|
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
    """Augment MusicInfo description with additional metadata fields and potential dropout.
    Additional textual attributes are added given probability `merge_text_p` and
    the original textual description is dropped from the augmented description given probability `drop_desc_p`.

    Args:
        music_info (MusicInfo): The music metadata to augment.
        merge_text_p (float): Probability of merging additional metadata to the description.
            If provided value is 0, then no merging is performed.
        drop_desc_p (float): Probability of dropping the original description on text merge.
            If provided value is 0, then no drop out is performed.
        drop_other_p (float): Probability of dropping each of the other fields used for
            text augmentation. If provided value is 0, every eligible field is kept.
    Returns:
        MusicInfo: A copy of the input with the augmented textual description;
            the input object itself is left untouched.
    Raises:
        ValueError: If an eligible field holds a value that cannot be rendered as text.
    """
    mergeable_fields = ('key', 'bpm', 'genre', 'moods', 'instrument', 'keywords')

    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
        valid_field_name = field_name in mergeable_fields
        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
        # `drop_other_p` is a *drop* probability: the field survives only when the
        # draw lands outside the drop region (the previous `<` test kept fields with
        # probability `drop_other_p`, i.e. never for the default of 0).
        keep_field = random.uniform(0, 1) >= drop_other_p
        return valid_field_name and valid_field_value and keep_field

    def process_value(v: tp.Any) -> str:
        if isinstance(v, (int, float, str)):
            return str(v)
        if isinstance(v, list):
            # Stringify items so lists holding non-string entries do not break the join.
            return ", ".join(str(item) for item in v)
        raise ValueError(f"Unknown type for text value! ({type(v), v})")

    description = music_info.description

    metadata_text = ""
    if random.uniform(0, 1) < merge_text_p:
        meta_pairs = [
            f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
            for _field in fields(music_info)
            if is_valid_field(_field.name, getattr(music_info, _field.name))
        ]
        # Shuffle so the order of the merged attributes carries no information.
        random.shuffle(meta_pairs)
        metadata_text = ". ".join(meta_pairs)
        description = description if not random.uniform(0, 1) < drop_desc_p else None
        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")

    if description is None:
        description = metadata_text if len(metadata_text) > 1 else None
    else:
        description = ". ".join([description.rstrip('.'), metadata_text])
    description = description.strip() if description else None

    # `replace` without changes makes a shallow copy, so the caller's object is not mutated.
    music_info = replace(music_info)
    music_info.description = description
    return music_info
|
|
|
|
|
class Paraphraser:
    """Randomly replaces a track description with one of its stored paraphrases.

    Args:
        paraphrase_source (str or Path): Path to a JSON file (gzip-compressed when the
            name ends with '.gz') holding a dict that maps an info path
            (the audio path with a '.json' suffix) to a list of paraphrases.
        paraphrase_p (float): Probability of returning a paraphrase instead of the
            original description.
    """
    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
        self.paraphrase_p = paraphrase_p
        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
            self.paraphrase_source = json.loads(f.read())
        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")

    def sample_paraphrase(self, audio_path: str, description: str):
        """Return a random paraphrase for `audio_path`, or `description` unchanged.

        The original description is returned when the random draw falls outside
        `paraphrase_p`, when the track has no entry in the paraphrase source, or
        when its entry is empty.
        """
        if random.random() >= self.paraphrase_p:
            return description
        # Keys of a dict loaded from JSON are always `str`; a `Path` neither compares
        # equal to nor hashes like a `str`, so the lookup key must be converted.
        info_path = str(Path(audio_path).with_suffix('.json'))
        if info_path not in self.paraphrase_source:
            warn_once(logger, f"{info_path} not in paraphrase source!")
            return description
        candidates = self.paraphrase_source[info_path]
        if not candidates:
            # Nothing to choose from: `random.choice` would raise on an empty sequence.
            return description
        new_desc = random.choice(candidates)
        logger.debug(f"{description} -> {new_desc}")
        return new_desc

    def sample(self, audio_path: str, description: str):
        """Short alias of `sample_paraphrase`, so callers using either name work."""
        return self.sample_paraphrase(audio_path, description)
|
|
|
|
|
class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata to the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str], optional): A list of attributes for which joint
            embedding metadata is returned. Defaults to no attribute.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict whose keys are the
            original info paths (e.g. track_path.json) and whose values are lists of possible
            paraphrases.
        paraphrase_p (float): Probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.Optional[tp.List[str]] = None,
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        # The parent must return the segment info alongside the waveform: `__getitem__` relies on it.
        kwargs['return_info'] = True
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        # Avoid a shared mutable default: every instance gets its own empty list.
        self.joint_embed_attributes = joint_embed_attributes if joint_embed_attributes is not None else []
        self.paraphraser = None
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)

    def __getitem__(self, index):
        """Return the waveform at `index` together with its `MusicInfo`.

        Music metadata is read from the '.json' file sitting next to the audio file when
        it exists; otherwise the `MusicInfo` only carries the base segment info.
        """
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        music_info_path = Path(info.meta.path).with_suffix('.json')

        if music_info_path.exists():
            with open(music_info_path, 'r') as json_file:
                music_data = json.load(json_file)
                # Segment info takes precedence over same-named keys of the metadata file.
                music_data.update(info_data)
                music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
            if self.paraphraser is not None:
                # `Paraphraser` exposes `sample_paraphrase`; calling `sample` raised AttributeError.
                music_info.description = self.paraphraser.sample_paraphrase(
                    music_info.meta.path, music_info.description)
            if self.merge_text_p:
                music_info = augment_music_info_description(
                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
        else:
            music_info = MusicInfo.from_dict(info_data, fields_required=False)

        # `wav[None]` indexes with None, i.e. prepends one axis to the waveform.
        music_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        for att in self.joint_embed_attributes:
            att_value = getattr(music_info, att)
            joint_embed_cond = JointEmbedCondition(
                wav[None], [att_value], torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
            music_info.joint_embed[att] = joint_embed_cond

        return wav, music_info
|
|
|
|
|
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Normalize a musical key annotation.

    Returns None for anything that is not a string, for the empty string, for the
    literal 'None', and for values containing a comma (several keys listed at once).
    Any other value is returned stripped and lower-cased.
    """
    if not isinstance(value, str):
        return None
    # Empty / placeholder values and multi-key annotations are all discarded.
    if value in ('', 'None') or ',' in value:
        return None
    return value.strip().lower()
|
|
|
|
|
def get_bpm(value: tp.Any) -> tp.Optional[float]:
    """Preprocess a raw BPM value to a float.

    Args:
        value: Raw value, typically a string or a number; may be None.
    Returns:
        The value converted with `float`, or None when it is None or cannot be converted.
    """
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        # ValueError: unparsable string; TypeError: non-scalar value such as a list or dict.
        return None
|
|