# Spaces: Running on Zero
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from dataclasses import dataclass | |
from math import pi | |
from typing import Optional | |
import torch | |
import torch.nn as nn | |
import torch.utils.checkpoint | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...models.modeling_utils import ModelMixin | |
from ...utils import BaseOutput, logging | |
# Module-level logger for this file, created via the package's own `logging` helper (imported from `...utils`).
# The pylint directive silences `invalid-name`, presumably because pylint expects UPPER_CASE module-level names.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableAudioPositionalEmbedding(nn.Module):
    """
    Fourier-feature embedding for continuous scalar inputs (e.g. normalized time values).

    Each input value is multiplied by a learned frequency vector, mapped to sine and cosine features, and the raw input
    is prepended, so the output's last dimension has size `dim + 1`.

    Args:
        dim (`int`):
            Number of Fourier features. Must be even: half of them are sines and half are cosines.

    Raises:
        ValueError: If `dim` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        # An `assert` would be stripped under `python -O`; validate explicitly so an odd `dim` always fails loudly.
        if dim % 2 != 0:
            raise ValueError(f"`dim` must be even, got {dim}.")
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        """
        Embed `times`.

        Args:
            times (`torch.Tensor`):
                Tensor of scalar values with any shape `(*)`, including 0-d.

        Returns:
            `torch.Tensor` of shape `(*, dim + 1)`: `[times, sin(2*pi*w*times), cos(2*pi*w*times)]` concatenated along
            the last dimension, where `w` are the learned frequencies.
        """
        times = times[..., None]
        # Broadcast `(*, 1)` against `(dim // 2,)`. Unlike broadcasting against `weights[None]` (shape
        # `(1, dim // 2)`), this keeps the same rank as `times` for 0-d inputs, so the final concat lines up.
        freqs = times * self.weights * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        return torch.cat((times, fouriered), dim=-1)
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
    """
    Class for StableAudio projection layer's outputs.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the hidden-states of the text encoder.
        seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Hidden-states obtained by embedding the audio start time.
        seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Hidden-states obtained by embedding the audio end time.
    """

    # `@dataclass` turns these annotated class attributes into real fields with a generated keyword `__init__`,
    # which is how `BaseOutput` subclasses are meant to declare their outputs.
    text_hidden_states: Optional[torch.Tensor] = None
    seconds_start_hidden_states: Optional[torch.Tensor] = None
    seconds_end_hidden_states: Optional[torch.Tensor] = None
class StableAudioNumberConditioner(nn.Module):
    """
    A simple linear projection model to map numbers to a latent space.

    Inputs are clamped to `[min_value, max_value]`, rescaled to `[0, 1]`, Fourier-embedded with
    `StableAudioPositionalEmbedding`, and then linearly projected to `number_embedding_dim`.

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules. Must be strictly greater than `min_value`.
        internal_dim (`int`, defaults to 256):
            Dimensionality of the intermediate number hidden states. Must be even, as required by
            `StableAudioPositionalEmbedding`.

    Raises:
        ValueError: If `max_value <= min_value`.
    """

    def __init__(
        self,
        number_embedding_dim,
        min_value,
        max_value,
        internal_dim: int = 256,
    ):
        super().__init__()
        # `forward` divides by `max_value - min_value`. Equal bounds would yield 0/0 (NaN embeddings), and inverted
        # bounds do not describe a valid range, so reject both up front instead of failing silently later.
        if max_value <= min_value:
            raise ValueError(
                f"`max_value` ({max_value}) must be strictly greater than `min_value` ({min_value})."
            )
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            # The positional embedding prepends the raw input to its `internal_dim` Fourier features, hence `+ 1`.
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )
        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value

    def forward(
        self,
        floats: torch.Tensor,
    ):
        """
        Embed a tensor of numbers.

        Args:
            floats (`torch.Tensor`):
                Values to embed. Presumably a 1-D tensor of shape `(batch_size,)` holding times in seconds; TODO
                confirm against callers.

        Returns:
            `torch.Tensor` of shape `(floats.numel(), 1, number_embedding_dim)`.
        """
        floats = floats.clamp(self.min_value, self.max_value)
        # Rescale to [0, 1] so the Fourier embedding always sees the same input range regardless of the bounds.
        normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
        # Cast to the embedder's parameter dtype (e.g. when the module has been converted to half precision).
        embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)
        embedding = self.time_positional_embedding(normalized_floats)
        # Flatten all leading dims into a single batch axis and add a singleton sequence axis.
        return embedding.view(-1, 1, self.number_embedding_dim)
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map the conditioning values to a shared latent space.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the text encoder (T5).
        conditioning_dim (`int`):
            Dimensionality of the output conditioning tensors.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules.
    """

    # `register_to_config` records the constructor arguments in the model config, which `ConfigMixin` relies on
    # to save and re-instantiate the model.
    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        super().__init__()
        # No projection is needed when the text encoder already emits `conditioning_dim` features.
        self.text_projection = (
            nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
        )
        # Independent (non-shared) conditioners for the start and end times.
        self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
        self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)

    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        """
        Project each provided conditioning input into the shared latent space.

        Any input passed as `None` is returned as `None` in the corresponding output field.

        Args:
            text_hidden_states (`torch.Tensor`, *optional*):
                Text-encoder hidden states, passed through `text_projection`.
            start_seconds (`torch.Tensor`, *optional*):
                Audio start times, embedded by `start_number_conditioner`.
            end_seconds (`torch.Tensor`, *optional*):
                Audio end times, embedded by `end_number_conditioner`.

        Returns:
            `StableAudioProjectionModelOutput` holding the three (possibly `None`) projected tensors.
        """
        text_hidden_states = (
            text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
        )
        seconds_start_hidden_states = (
            start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds)
        )
        seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds)

        return StableAudioProjectionModelOutput(
            text_hidden_states=text_hidden_states,
            seconds_start_hidden_states=seconds_start_hidden_states,
            seconds_end_hidden_states=seconds_end_hidden_states,
        )