Spaces:
Runtime error
Runtime error
import time | |
import gradio as gr | |
from sentence_transformers import SentenceTransformer | |
import httpx | |
import json | |
from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat | |
#import subprocess | |
import os | |
import uuid | |
from tempfile import gettempdir | |
from PIL import Image | |
import cv2 | |
from pprint import pprint | |
# Shared resources, created once when the module is imported.
_EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
_CAPTION_SPACE_NAME = "spaces/banana-dev/demo-clip-interrogator"

# Sentence-embedding model used to match prompts against Mubert tags.
minilm = SentenceTransformer(_EMBEDDING_MODEL_NAME)
# Embeddings of the Mubert tags, computed with the model above.
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
# Remote Space used to turn an image into a text prompt.
# (Previously: gr.Interface.load("spaces/doevent/image_to_text", api_key=os.environ['HF_TOKEN']))
image_to_text = gr.Blocks.load(name=_CAPTION_SPACE_NAME)
def center_crop(img, dim: tuple = (512, 512)):
    """Return the centre region of ``img``, cropped to at most ``dim``.

    Args:
        img: image array indexed as ``img[row, column, ...]``; only its
            ``shape`` and slicing are used.
        dim: requested ``(width, height)`` of the crop. A side larger than
            the image is clamped to the image size.

    Returns:
        The slice of ``img`` that is ``min(dim[0], width)`` columns wide and
        ``min(dim[1], height)`` rows tall, centred in the image.
    """
    height, width = img.shape[0], img.shape[1]
    crop_width = min(dim[0], width)
    crop_height = min(dim[1], height)
    # Work from the top-left corner and slice the full extent. Halving the
    # crop size and slicing around the midpoint loses one row/column whenever
    # the crop size is odd.
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    return img[top:top + crop_height, left:left + crop_width]
def scale_image(img, factor=1):
    """Resize an image by a single scale factor.

    Both sides are multiplied by the same factor, so the width/height ratio
    is retained (up to integer truncation).

    Args:
        img: image to be resized
        factor: multiplier applied to the width and the height
    """
    source_height, source_width = img.shape[0], img.shape[1]
    # cv2.resize takes the target size as (width, height).
    target_size = (int(source_width * factor), int(source_height * factor))
    return cv2.resize(img, target_size)
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Request a track from the Mubert API and wait until it can be downloaded.

    Args:
        tags: sent as the ``tags`` request parameter.
        pat: sent as the ``pat`` request parameter.
        duration: sent as the ``duration`` request parameter.
        maxit: maximum number of polls of the download link, one second apart.
        loop: request mode ``"loop"`` when true, ``"track"`` otherwise.

    Returns:
        The track's download URL, once a GET on it answers with status 200.

    Raises:
        RuntimeError: if the API response does not report ``status == 1``.
        TimeoutError: if the download link is still not ready after ``maxit``
            polls.
    """
    mode = "loop" if loop else "track"
    r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
                   json={
                       "method": "RecordTrackTTM",
                       "params": {
                           "pat": pat,
                           "duration": duration,
                           "tags": tags,
                           "mode": mode
                       }
                   })
    pprint(r.text)
    rdata = json.loads(r.text)
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # under ``python -O``, which would let a failed request slip through.
    if rdata.get('status') != 1:
        error_text = (rdata.get('error') or {}).get('text')
        raise RuntimeError(error_text or f"Unexpected Mubert response: {r.text}")
    trackurl = rdata['data']['tasks'][0]['download_link']

    for attempt in range(maxit):
        if httpx.get(trackurl).status_code == 200:
            return trackurl
        # No point in sleeping after the final attempt.
        if attempt < maxit - 1:
            time.sleep(1)
    # Fail loudly rather than returning None, which the caller would only
    # trip over later with an unrelated-looking AttributeError.
    raise TimeoutError(f"Track was not ready after {maxit} attempts: {trackurl}")
def generate_track_by_prompt(image, email, duration, loop=False):
    """Turn an image into a music track plus a waveform video.

    The image is normalised to an RGB PNG (and downscaled when large),
    captioned through ``image_to_text``, the caption is mapped to Mubert tags,
    a track is requested for those tags, and ffmpeg overlays the track's
    waveform on the image.

    Args:
        image: path of the input image file.
        email: passed to ``get_pat`` to obtain the Mubert token.
        duration: requested track length; converted with ``int``.
        loop: forwarded to ``get_track_by_tags``.

    Returns:
        ``(video_path, track_url, prompt, tags)``.

    Raises:
        gr.Error: on any failure, carrying the underlying error message.
    """
    import subprocess  # local import: the module-level import is commented out

    temp_dir = gettempdir()
    filepath_png = f"{temp_dir}/{uuid.uuid4().hex}.png"
    filepath_mp3 = None
    try:
        with Image.open(image) as im:
            ratio_width, ratio_height = im.size
            im.convert("RGB").save(filepath_png)

        longest_side = max(ratio_width, ratio_height)
        if longest_side > 3501:
            # The message now states the limit that is actually enforced.
            raise gr.Error("Image dimensions must not exceed 3501 px in width or height.")
        # Downscale large images; the first matching threshold wins.
        # NOTE(review): with the 3501 limit above, the 3500 entry only applies
        # to a side of exactly 3501 px - confirm the intended thresholds.
        for threshold, factor in ((3500, 0.2), (1800, 0.3), (900, 0.5)):
            if longest_side > threshold:
                image_g = cv2.imread(image)
                if image_g is None:
                    raise RuntimeError("OpenCV could not read the image for resizing.")
                cv2.imwrite(filepath_png, scale_image(image_g, factor=factor))
                break

        # prompt = image_to_text(filepath_png, "Image Captioning", "", "Nucleus sampling")
        prompt = image_to_text(filepath_png, "ViT-L (best for Stable Diffusion 1.*)", "Fast", fn_index=1)[0]
        print(f"PROMPT: {prompt}")
        pat = get_pat(email)
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
        track_url = get_track_by_tags(tags, pat, int(duration), loop=loop)
        if not track_url:
            raise RuntimeError("Mubert did not return a downloadable track in time.")

        filepath_mp3 = f"{temp_dir}/{track_url.split('/')[-1]}"
        filepath_mp4 = f"{temp_dir}/{uuid.uuid4().hex}.mp4"
        # Argument lists (no shell) keep the API-supplied URL from being
        # interpreted as shell syntax; check=True surfaces a failed download.
        # -O writes to exactly the path that ffmpeg reads below.
        subprocess.run(["wget", track_url, "-O", filepath_mp3], check=True)

        # The waveform is rendered at the size of the (possibly scaled) image.
        with Image.open(filepath_png) as im:
            width, height = im.size
        print(f"{width}x{height}")
        filter_graph = (
            f"[0:a]showwaves=s={width}x{height}:colors=0xffffff:mode=cline,format=rgba[v];"
            "[1:v][v]overlay[outv]"
        )
        subprocess.run(
            [
                "ffmpeg", "-hide_banner", "-loglevel", "warning", "-y",
                "-i", filepath_mp3,
                "-loop", "1", "-i", filepath_png,
                "-filter_complex", filter_graph,
                "-map", "[outv]", "-map", "0:a",
                "-c:v", "libx264", "-r", "15",
                "-c:a", "copy",
                "-pix_fmt", "yuv420p",
                "-shortest", filepath_mp4,
            ],
            check=True,
        )
        return filepath_mp4, track_url, prompt, tags
    except gr.Error:
        # Already a user-facing error; pass it through instead of wrapping it
        # a second time.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e
    finally:
        # Remove the intermediate files on success and on failure alike.
        for temp_path in (filepath_png, filepath_mp3):
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)
# Gradio UI: the inputs map positionally onto
# generate_track_by_prompt(image, email, duration).
input_components = [
    gr.Image(type="filepath"),
    "text",
    gr.Slider(label="duration (seconds)", value=30, minimum=10, maximum=60),
]
output_components = [
    gr.Video(label="Video"),
    gr.Audio(label="Audio"),
    gr.Text(label="Prompt"),
    gr.Text(label="Tags"),
]
iface = gr.Interface(
    fn=generate_track_by_prompt,
    inputs=input_components,
    outputs=output_components,
)
# Enable request queuing, then start the app.
iface.queue().launch()