ZeyuXie committed on
Commit
629d1bf
1 Parent(s): d1abaff

Update app.py

Files changed (1): app.py +155 -98
app.py CHANGED
@@ -1,112 +1,169 @@
- """
- At the command line, you only need to run this once to install the package via pip:
-
- $ pip install google-generativeai
- """
-
- from pathlib import Path
  import os
  import json
- import re
-
- os.environ['HTTP_PROXY'] = 'http://127.0.0.1:58591'
- os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:58591'
-
- def get_event():
-     event_list = [
-         "burping_belching",          # 0
-         "car_horn_honking",
-         "cat_meowing",
-         "cow_mooing",
-         "dog_barking",
-         "door_knocking",
-         "door_slamming",
-         "explosion",
-         "gunshot",                   # 8
-         "sheep_goat_bleating",
-         "sneeze",
-         "spraying",
-         "thump_thud",
-         "train_horn",
-         "tapping_clicking_clanking",
-         "woman_laughing",
-         "duck_quacking",             # 16
-         "whistling",
-     ]
-     return event_list
-
- def get_prompt():
-     train_json_list = ["data/train_multi-event_v3.json",
-                        "data/train_single-event_multi_v3.json",
-                        "data/train_single-event_single_v3.json"]
-     learn_pair = ""
-     for train_json in train_json_list:
-         with open(train_json, 'r') as train_file:
-             for idx, line in enumerate(train_file):
-                 if idx >= 100: break
-                 data = json.loads(line.strip())
-                 learn_pair += f"{idx}:{data['captions']}~{data['onset']}. "
-     preffix_prompt = "I'm doing an audio event generation, which is a harmless job that will contain some sound events. For example, a gunshot is a sound that is harmless. " +\
-         "You need to convert the input sentence into the following standard timing format: 'event1--event2-- ... --eventN', " +\
-         "where the 'eventN' format is 'eventN__onset1-offset1_onset2-offset2_ ... _onsetK-offsetK'. " +\
-         "The 'onset-offset' inside needs to be determined based on common sense and the examples I provide, with a duration not less than 1 and not greater than 4. All 'onsetK-offsetK' placeholders should be replaced by numbers. " +\
-         "The very strict constraint is that the total duration is less than 10 seconds, meaning all times are less than 10. Events should preferably not overlap. " +\
-         "Now, I will provide you with 300 examples from the training set for your learning, each in the format 'index: input~output'. " +\
-         learn_pair
-     print(len(preffix_prompt))
-     return preffix_prompt
-
- def postprocess(caption):
-     caption = caption.strip('\n').strip(' ').strip('.')
-     caption = caption.replace('__', ' at ').replace('--', ' and ')
-     return caption
-
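- # e.g. postprocess("spraying__0.38-1.176_3.06-3.856--gunshot__1.729-3.729")
- #      returns "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729"
-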
- def preprocess_gemini(free_text_caption):
-     preffix_prompt = get_prompt()
-     import google.generativeai as genai
-     genai.configure(api_key="<GEMINI_API_KEY>")  # key redacted; supply your own
-     print(free_text_caption)
-     # Set up the model
-     generation_config = {
-         "temperature": 1,
-         "top_p": 0.95,
-         "top_k": 64,
-         "max_output_tokens": 8192,
-     }
-     model = genai.GenerativeModel(model_name="gemini-1.5-flash",
-                                   generation_config=generation_config)
-     prompt_parts = [
-         preffix_prompt +\
-         f"Please convert the following inputs into the standard timing format:{free_text_caption}. You should only output results in the standard timing format. Do not output anything other than the format and do not add symbols.",
-     ]
-     timestampCaption = model.generate_content(prompt_parts).text
-     print(timestampCaption)
-     return postprocess(timestampCaption)
-
- def preprocess_gpt(free_text_caption):
-     preffix_prompt = get_prompt()
-     from openai import OpenAI
-     client = OpenAI(api_key="<OPENAI_API_KEY>")  # key redacted; supply your own
-     completion_start = client.chat.completions.create(
-         model="gpt-4-1106-preview",
-         messages=[{
-             "role": "user",
-             "content": preffix_prompt +\
-                 f"Please convert the following inputs into the standard timing format:{free_text_caption}. You should only output results in the standard timing format. Do not output anything other than the format and do not add symbols."
-         }]
-     )
-     timestampCaption = completion_start.choices[0].message.content
-     return postprocess(timestampCaption)
-
- if __name__ == "__main__":
-     caption = preprocess_gemini("spraying two times then gunshot three times.")
-     print(caption)

  import os
  import json
+ import numpy as np
+ import torch
+ import soundfile as sf
+ import gradio as gr
+ from diffusers import DDPMScheduler
+ from pico_model import PicoDiffusion
+ from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
+ from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
+
+ class dotdict(dict):
+     """dot.notation access to dictionary attributes"""
+     __getattr__ = dict.get
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
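+
+ # dotdict lets the training-args JSON below be read as train_args.scheduler_name
+ # rather than train_args["scheduler_name"].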
+
+ class InferRunner:
+     def __init__(self, device):
+         # VAE maps diffusion latents back to mel spectrograms / waveforms.
+         vae_config = json.load(open("ckpts/ldm/vae_config.json"))
+         self.vae = AutoencoderKL(**vae_config).to(device)
+         vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
+         self.vae.load_state_dict(vae_weights)
+
+         # Training arguments are stored as JSON on the first line of summary.jsonl.
+         train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
+         self.pico_model = PicoDiffusion(
+             scheduler_name=train_args.scheduler_name,
+             unet_model_config_path=train_args.unet_model_config,
+             snr_gamma=train_args.snr_gamma,
+             freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
+             diffusion_pt="ckpts/pico_model/diffusion.pt",
+         ).eval().to(device)
+         self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ runner = InferRunner(device)  # load models once at startup; infer() depends on this global
+ event_list = get_event()
+
+ def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
+     # Diffusion sampling -> latents -> mel spectrogram -> 10 s of 16 kHz audio.
+     with torch.no_grad():
+         latents = runner.pico_model.demo_inference(caption, runner.scheduler, num_steps=num_steps, guidance_scale=guidance_scale, num_samples_per_prompt=1, disable_progress=True)
+         mel = runner.vae.decode_first_stage(latents)
+         wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
+     outpath = "output.wav"
+     sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
+     return outpath
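+
+ # e.g. infer("dog_barking at 0.562-2.562_4.25-6.25.") writes and returns "output.wav"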
+
+ def preprocess(caption):
+     # Step 1: free text -> timestamp caption; the result is returned twice to fill
+     # both the Step 2 prompt box and the preprocess-output box.
+     output = preprocess_gemini(caption)
+     return output, output
+
+ def update_textbox(event_name, current_text):
+     # Event buttons append "<event> two times." to the current free-text prompt.
+     print(event_name, current_text)
+     event = event_name + ' two times.'
+     if current_text:
+         return current_text.strip('.') + ' then ' + event
+     else:
+         return event
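+
+ # e.g. update_textbox("dog_barking", "a cat meows.")
+ #      returns "a cat meows then dog_barking two times."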
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         gr.Markdown("## PicoAudio")
+     with gr.Row():
+         description_text = "18 events supported:"
+         gr.Markdown(description_text)
+
+     # One button per supported event, in three rows of six.
+     btn_event = []
+     with gr.Row():
+         for i in range(6):
+             btn_event.append(gr.Button(event_list[i]))
+     with gr.Row():
+         for i in range(6, 12):
+             btn_event.append(gr.Button(event_list[i]))
+     with gr.Row():
+         for i in range(12, 18):
+             btn_event.append(gr.Button(event_list[i]))
+
+     with gr.Row():
+         gr.Markdown("## Step 1")
+     with gr.Row():
+         preprocess_description_text = "Preprocess: transfer free text into a timestamp caption via an LLM. " +\
+             "This demo uses Gemini as the preprocessor. If any errors occur, please try a few more times. " +\
+             "We also provide the GPT version consistent with the paper in the file 'Files/llm_preprocess.py'. You can use your own api_key to modify and run 'Files/inference.py' for local inference."
+         gr.Markdown(preprocess_description_text)
+     with gr.Row():
+         with gr.Column():
+             freetext_prompt = gr.Textbox(label="Prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
+                                          value="a dog barks three times.")
+             preprocess_run_button = gr.Button()
+             prompt = None  # placeholder; the real Textbox is created in Step 2
+         with gr.Column():
+             freetext_prompt_out = gr.Textbox(label="Preprocess output")
+     with gr.Row():
+         with gr.Column():
+             gr.Examples(
+                 examples=[["spraying two times then gunshot three times."],
+                           ["a dog barks three times."],
+                           ["cow mooing two times."]],
+                 inputs=[freetext_prompt],
+                 outputs=[prompt],
+             )
+         with gr.Column():
+             pass
+
+     with gr.Row():
+         gr.Markdown("## Step 2")
+     with gr.Row():
+         generate_description_text = "Generate audio based on the timestamp caption."
+         gr.Markdown(generate_description_text)
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
+                                 value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.")
+             generate_run_button = gr.Button()
+             with gr.Accordion("Advanced options", open=False):
+                 num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
+                 guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
+         with gr.Column():
+             outaudio = gr.Audio()
+
+     # Wiring: event buttons extend the free-text prompt, Step 1 runs the LLM
+     # preprocessor, Step 2 runs diffusion inference.
+     for i in range(18):
+         btn_event[i].click(fn=update_textbox, inputs=[gr.State(event_list[i]), freetext_prompt], outputs=freetext_prompt)
+     preprocess_run_button.click(fn=preprocess, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
+     generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])
+
+     with gr.Row():
+         with gr.Column():
+             gr.Examples(
+                 examples=[["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
+                           ["dog_barking at 0.562-2.562_4.25-6.25."],
+                           ["cow_mooing at 0.958-3.582_5.272-7.896."]],
+                 inputs=[prompt, num_steps, guidance_scale],
+                 outputs=[outaudio],
+             )
+         with gr.Column():
+             pass
+
+ demo.launch()
+
+ # Previous single-step gr.Interface version, kept for reference:
+ # description_text = f"18 events: {', '.join(event_list)}"
+ # prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
+ #                     value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.")
+ # outaudio = gr.Audio()
+ # num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
+ # guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
+ # gr_interface = gr.Interface(
+ #     fn=infer,
+ #     inputs=[prompt, num_steps, guidance_scale],
+ #     outputs=[outaudio],
+ #     title="PicoAudio",
+ #     description=description_text,
+ #     allow_flagging=False,
+ #     examples=[
+ #         ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
+ #         ["dog_barking at 0.562-2.562_4.25-6.25."],
+ #         ["cow_mooing at 0.958-3.582_5.272-7.896."],
+ #     ],
+ #     cache_examples="lazy",  # Turn on to cache.
+ # )
+ # gr_interface.queue(10).launch()
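
For local use outside the Gradio UI, the two steps can be chained directly. A minimal sketch, assuming app.py's infer() is in scope, the ckpts/ checkpoints are present, and llm_preprocess is configured with your own API key (Files/inference.py in the repo is the reference script):

    caption = preprocess_gemini("spraying two times then gunshot three times.")  # Step 1: free text -> timestamp caption
    print(infer(caption, num_steps=200, guidance_scale=3.0))                     # Step 2: writes and returns output.wav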