# imagetomusic / app.py
# Page header captured with the file: last commit "Update app.py" (1675112)
# by Shad0ws; raw / history / blame views; size 5.43 kB.
import time
import gradio as gr
from sentence_transformers import SentenceTransformer
import httpx
import json
from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
#import subprocess
import os
import uuid
from tempfile import gettempdir
from PIL import Image
import cv2
from pprint import pprint
# Module-level setup: everything below is created once, at import time.

# Sentence-embedding model; passed to get_mubert_tags_embeddings() here and to
# get_tags_for_prompts() when a prompt is turned into tags.
minilm = SentenceTransformer('all-MiniLM-L6-v2')
# Tag embeddings computed once up front so each request can reuse them.
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
# Previous captioning backend, kept for reference:
# image_to_text = gr.Interface.load("spaces/doevent/image_to_text", api_key=os.environ['HF_TOKEN'])
# Remote app used to turn an image file into a text prompt; it is called with
# fn_index=1 and the first element of its result is used as the prompt.
image_to_text = gr.Blocks.load(name="spaces/banana-dev/demo-clip-interrogator")
def center_crop(img, dim: tuple = (512, 512)):
    """Return the centered crop of an image.

    Args:
        img: image array indexed as ``img[row, column, ...]``; only its
            ``shape`` and slicing are used.
        dim: requested crop size as ``(width, height)``. Each side is clamped
            to the corresponding image side, so a request larger than the
            image returns the image at full size along that axis.

    Returns:
        The slice of ``img`` that is ``min(dim[0], width)`` columns wide and
        ``min(dim[1], height)`` rows tall, centered on the image.
    """
    height, width = img.shape[0], img.shape[1]
    crop_width = min(dim[0], width)
    crop_height = min(dim[1], height)
    # Anchor on the top-left corner and add the full crop size. Slicing
    # "middle +/- half" loses one pixel whenever the crop size is odd.
    left = width // 2 - crop_width // 2
    top = height // 2 - crop_height // 2
    return img[top:top + crop_height, left:left + crop_width]
def scale_image(img, factor=1):
    """Resize an image by a single scale factor applied to both sides.

    Using one factor for width and height keeps the aspect ratio.

    Args:
        img: image array whose ``shape`` starts with (height, width).
        factor: multiplier applied to both dimensions.

    Returns:
        The result of ``cv2.resize`` at the scaled (width, height).
    """
    src_height, src_width = img.shape[0], img.shape[1]
    # cv2.resize takes the target size as (width, height).
    target_size = (int(src_width * factor), int(src_height * factor))
    return cv2.resize(img, target_size)
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Request a track from the Mubert API and wait until it can be downloaded.

    Args:
        tags: tags sent as the ``tags`` request parameter.
        pat: access token sent as the ``pat`` request parameter.
        duration: track length sent as the ``duration`` request parameter.
        maxit: maximum number of polling attempts, one per second.
        loop: request mode ``"loop"`` when true, ``"track"`` otherwise.

    Returns:
        The download URL of the track, once a GET on it answers 200.

    Raises:
        AssertionError: the API answered with a status other than 1; the
            message is the error text from the response.
        TimeoutError: the download URL never answered 200 within ``maxit``
            attempts.
    """
    mode = "loop" if loop else "track"
    r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
                   json={
                       "method": "RecordTrackTTM",
                       "params": {
                           "pat": pat,
                           "duration": duration,
                           "tags": tags,
                           "mode": mode
                       }
                   })
    pprint(r.text)
    rdata = json.loads(r.text)
    # Raised explicitly instead of using `assert`, which is removed under
    # `python -O`. The exception type is kept so existing handlers still match.
    if rdata['status'] != 1:
        raise AssertionError(rdata['error']['text'])
    trackurl = rdata['data']['tasks'][0]['download_link']
    # Poll the download link until it answers 200.
    for _ in range(maxit):
        r = httpx.get(trackurl)
        if r.status_code == 200:
            return trackurl
        time.sleep(1)
    # Falling off the end used to return None, which callers then tried to
    # use as a URL string.
    raise TimeoutError(f"Track was not ready after {maxit} attempts: {trackurl}")
def generate_track_by_prompt(image, email, duration, loop=False):
    """Generate a music track for an image and render a waveform video.

    Steps: normalize the image to an RGB PNG (downscaling large images), get a
    text prompt for it from ``image_to_text``, map the prompt to tags, request
    a track for those tags, download it, and overlay its waveform on the image
    with ffmpeg.

    Args:
        image: path of the input image file.
        email: passed to ``get_pat`` to obtain the API token.
        duration: track length in seconds; converted with ``int()``.
        loop: forwarded to ``get_track_by_tags``.

    Returns:
        Tuple ``(video_path, track_url, prompt, tags)``.

    Raises:
        gr.Error: on any failure; the message is the text of the underlying
            error.
    """
    # Local import: the module-level `import subprocess` is commented out.
    import subprocess

    tmp_dir = gettempdir()
    filepath_png = f"{tmp_dir}/{uuid.uuid4().hex}.png"
    filepath_mp3 = None
    try:
        # Normalize the input to an RGB PNG and read its size.
        with Image.open(image) as im:
            ratio_width, ratio_height = im.size
            im.convert("RGB").save(filepath_png)

        longest_side = max(ratio_width, ratio_height)
        if longest_side > 3501:
            # The message now states the limit that is actually enforced.
            raise gr.Error("Image width and height must not exceed 3501 px.")
        # Downscale large images; the first matching threshold wins.
        # NOTE(review): with the 3501 px limit above, the 3500 entry only
        # applies to a side of exactly 3501 px. Kept as in the original.
        for threshold, factor in ((3500, 0.2), (1800, 0.3), (900, 0.5)):
            if longest_side > threshold:
                image_g = cv2.imread(image)
                cv2.imwrite(filepath_png, scale_image(image_g, factor=factor))
                break

        # prompt = image_to_text(filepath_png, "Image Captioning", "", "Nucleus sampling")
        prompt = image_to_text(filepath_png, "ViT-L (best for Stable Diffusion 1.*)", "Fast", fn_index=1)[0]
        print(f"PROMPT: {prompt}")

        pat = get_pat(email)
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
        filepath = get_track_by_tags(tags, pat, int(duration), loop=loop)
        if not filepath:
            # No URL means the track never became available.
            raise gr.Error("Track generation timed out; please try again.")

        filename_mp3 = filepath.split("/")[-1]
        filepath_mp3 = f"{tmp_dir}/{filename_mp3}"
        filepath_mp4 = f"{tmp_dir}/{uuid.uuid4().hex}.mp4"

        # Argument lists keep the remote URL and the paths away from a shell,
        # and check=True surfaces a failed download instead of continuing.
        subprocess.run(["wget", "-O", filepath_mp3, filepath], check=True)

        # The waveform overlay is sized to the (possibly downscaled) image.
        with Image.open(filepath_png) as im:
            width, height = im.size
        print(f"{width}x{height}")
        subprocess.run(
            [
                "ffmpeg", "-hide_banner", "-loglevel", "warning", "-y",
                "-i", filepath_mp3,
                "-loop", "1", "-i", filepath_png,
                "-filter_complex",
                f"[0:a]showwaves=s={width}x{height}:colors=0xffffff:mode=cline,format=rgba[v];[1:v][v]overlay[outv]",
                "-map", "[outv]", "-map", "0:a",
                "-c:v", "libx264", "-r", "15",
                "-c:a", "copy", "-pix_fmt", "yuv420p",
                "-shortest", filepath_mp4,
            ],
            check=True,
        )
        return filepath_mp4, filepath, prompt, tags
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
    finally:
        # Remove intermediate files on success and on failure alike.
        for leftover in (filepath_png, filepath_mp3):
            if leftover and os.path.exists(leftover):
                os.remove(leftover)
# Input widgets, in the positional order of generate_track_by_prompt's
# parameters: image file path, e-mail text, duration slider.
_inputs = [
    gr.Image(type="filepath"),
    "text",
    gr.Slider(label="duration (seconds)", value=30, minimum=10, maximum=60),
]
# Output widgets, matching the returned tuple: video, audio, prompt, tags.
_outputs = [
    gr.Video(label="Video"),
    gr.Audio(label="Audio"),
    gr.Text(label="Prompt"),
    gr.Text(label="Tags"),
]
iface = gr.Interface(fn=generate_track_by_prompt, inputs=_inputs, outputs=_outputs)
# Enable the request queue, then start the app.
iface = iface if iface is None else iface
iface.queue().launch()