File size: 6,067 Bytes
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceefdf5
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from typing import Literal
from tempfile import NamedTemporaryFile
from pathlib import Path

import uuid
import shutil
import json
import asyncio
import toml

import torch

from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from pydub import AudioSegment

from .utils import (
    set_all_seeds,
)
from .palmchat import (
    palm_prompts,
    gen_text,
)

class MusicMaker:
    """Generate background music from text prompts with MusicGen.

    An audiocraft ``MusicGen`` model is loaded once in ``__init__`` and reused by
    ``text2music``. ``generate_prompt`` asks PaLM to turn story attributes into a
    MusicGen prompt. Generated audio files are written under ``./outputs`` using a
    random UUID as the file name.
    """

    # Directory (relative to the current working directory) that receives the
    # generated audio files.
    _OUTPUT_DIR = Path('.') / 'outputs'

    # Sampling settings applied to every generation. They are passed again on each
    # text2music call (together with ``duration``) so they never depend on what
    # ``set_generation_params`` does with arguments it is not given.
    _GENERATION_PARAMS = dict(use_sampling=True,
                              top_k=250,
                              top_p=0.0,
                              temperature=1.0,
                              cfg_coef=3.0)

    # Output formats text2music knows how to produce.
    _OUTPUT_FORMATS = ('wav', 'mp3')

    def __init__(self, model_size: Literal['small', 'medium', 'melody', 'large'] = 'large',
                 output_format: Literal['wav', 'mp3'] = 'mp3',
                 device: str = None) -> None:
        """Initialize the MusicMaker class and load the MusicGen model.

        Args:
            model_size (Literal['small', 'medium', 'melody', 'large'], optional): MusicGen
                checkpoint size. Defaults to 'large'.
            output_format (Literal['wav', 'mp3'], optional): Output audio format.
                Defaults to 'mp3'.
            device (str, optional): Device to run the model on. When falsy (the default
                is None), 'cuda' is used if available, otherwise 'cpu'.

        Raises:
            ValueError: If ``output_format`` is not 'wav' or 'mp3'.
            FileExistsError: If ``./outputs`` exists but is not a directory.
        """
        # Validate cheap arguments before the slow model load.
        if output_format not in self._OUTPUT_FORMATS:
            raise ValueError(
                f"Unsupported output format: {output_format!r} (expected 'wav' or 'mp3')."
            )

        self.__model_size = model_size
        self.__output_format = output_format
        if device:
            self.__device = device
        else:
            self.__device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        output_dir = self._OUTPUT_DIR
        if output_dir.exists() and not output_dir.is_dir():
            # Raised explicitly: an `assert` would be stripped under `python -O`.
            raise FileExistsError(
                f"A file with the same name as the desired directory ('{str(output_dir)}') already exists."
            )
        output_dir.mkdir(parents=True, exist_ok=True)

        print("Loading the MusicGen model into memory...")
        self.__mg_model = MusicGen.get_pretrained(self.model_size, device=self.device)
        self.__mg_model.set_generation_params(**self._GENERATION_PARAMS)

    @staticmethod
    def _resolve_seed(seed: int = None) -> int:
        """Return ``seed`` unchanged, or a fresh random seed when it is None or -1.

        0 is a valid seed, so only None and the -1 sentinel request a random one.
        """
        if seed is None or seed == -1:
            return int(torch.randint(0, 2**32 - 1, (1,)).item())
        return seed

    @staticmethod
    def _wav_to_mp3(src_file: str, dest_file: str) -> None:
        """Convert the WAV file at ``src_file`` into an MP3 file at ``dest_file`` (pydub)."""
        AudioSegment.from_wav(src_file).export(dest_file, format="mp3")

    def text2music(self, prompt: str, length: int = 60, seed: int = None) -> str:
        """Generate a piece of music from the prompt.

        Args:
            prompt (str): Prompt to generate the music from.
            length (int, optional): Length of the music in seconds. Defaults to 60.
            seed (int, optional): Seed to use for the generation. None or -1 picks a
                random seed. Defaults to None.

        Returns:
            str: Path to the generated music file inside ``./outputs``; the extension
            is '.mp3' or '.wav' depending on ``output_format``.

        Raises:
            ValueError: If ``length`` is not positive.
        """
        if length <= 0:
            raise ValueError(f"length must be positive, got {length!r}.")

        set_all_seeds(self._resolve_seed(seed))

        self.__mg_model.set_generation_params(duration=length, **self._GENERATION_PARAMS)
        output = self.__mg_model.generate(descriptions=[prompt], progress=True)[0]

        suffix = '.mp3' if self.output_format == 'mp3' else '.wav'
        output_path = (self._OUTPUT_DIR / str(uuid.uuid4())).with_suffix(suffix)

        with NamedTemporaryFile("wb", delete=True) as temp_file:
            # audio_write is given the temp file's name as a stem; the audio is then
            # read back from '<stem>.wav', a sibling file that NamedTemporaryFile
            # does not delete on its own.
            wav_path = Path(f'{temp_file.name}.wav')
            try:
                audio_write(temp_file.name, output.cpu(), self.__mg_model.sample_rate,
                            strategy="loudness", loudness_compressor=True)
                if self.output_format == 'mp3':
                    self._wav_to_mp3(str(wav_path), str(output_path))
                else:
                    shutil.copy(str(wav_path), str(output_path))
            finally:
                # Remove the sibling '.wav' so temp files do not accumulate per call.
                wav_path.unlink(missing_ok=True)

        return str(output_path)

    @staticmethod
    def _parse_prompt_response(response, response_txt) -> str:
        """Extract the 'primary_sentence' field from PaLM's JSON text reply.

        Args:
            response: Raw PaLM response object; only used for diagnostics on failure.
            response_txt: Text of the reply, expected to be a JSON object.

        Returns:
            str: The value stored under 'primary_sentence'.

        Raises:
            ValueError: If ``response_txt`` is not JSON or lacks 'primary_sentence'.
        """
        try:
            res_json = json.loads(response_txt)
            return res_json['primary_sentence']
        except (TypeError, ValueError, KeyError) as err:
            # TypeError: response_txt is not a str, or the JSON is not an object.
            # ValueError: invalid JSON (json.JSONDecodeError). KeyError: field missing.
            print("=== PaLM Response ===")
            # Read defensively so a missing attribute cannot mask the real error.
            print(getattr(response, 'filters', None))
            print(response_txt)
            print("=== PaLM Response ===")
            raise ValueError("The response from PaLM API is not in the expected format.") from err

    def generate_prompt(self, genre: str, place: str, mood: str,
                        title: str, chapter_title: str, chapter_plot: str,
                        timeout: float = 10) -> str:
        """Generate a prompt for background music based on the given attributes.

        Args:
            genre (str): Genre of the story.
            place (str): Place of the story.
            mood (str): Mood of the story.
            title (str): Title of the story.
            chapter_title (str): Title of the chapter.
            chapter_plot (str): Plot of the chapter.
            timeout (float, optional): Seconds to wait for PaLM. Defaults to 10.

        Returns:
            str: Generated prompt.

        Raises:
            TimeoutError: If PaLM does not answer within ``timeout`` seconds.
            ValueError: If PaLM's reply is not JSON with a 'primary_sentence' field.
        """
        # Build the PaLM query from the templates in palm_prompts.
        t = palm_prompts['music_gen']['gen_prompt']
        q = palm_prompts['music_gen']['query']
        query_string = t.format(input=q.format(genre=genre,
                                               place=place,
                                               mood=mood,
                                               title=title,
                                               chapter_title=chapter_title,
                                               chapter_plot=chapter_plot))
        try:
            response, response_txt = asyncio.run(asyncio.wait_for(
                gen_text(query_string, mode="text", use_filter=False),
                timeout=timeout,
            ))
        except asyncio.TimeoutError as err:
            raise TimeoutError("The response time for PaLM API exceeded the limit.") from err

        return self._parse_prompt_response(response, response_txt)

    @property
    def model_size(self):
        """Model size

        Returns:
            Literal['small', 'medium', 'melody', 'large']: The model size (read-only)
        """
        return self.__model_size

    @property
    def output_format(self):
        """Output format

        Returns:
            Literal['wav', 'mp3']: The output format (read-only)
        """
        return self.__output_format

    @property
    def device(self):
        """Device

        Returns:
            str | torch.device: The device passed in, or the auto-selected
            torch.device (read-only)
        """
        return self.__device