File size: 5,427 Bytes
e1b51d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1675112
e1b51d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import time

import gradio as gr
from sentence_transformers import SentenceTransformer

import httpx
import json

from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
#import subprocess
import os
import uuid
from tempfile import gettempdir
from PIL import Image
import cv2
from pprint import pprint

# Sentence encoder used to map image captions onto Mubert's music tags.
minilm = SentenceTransformer('all-MiniLM-L6-v2')
# Precomputed embeddings of every Mubert tag (helper lives in utils).
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)

# image_to_text = gr.Interface.load("spaces/doevent/image_to_text", api_key=os.environ['HF_TOKEN'])
# Remote Hugging Face Space (CLIP interrogator) used to caption the uploaded image.
image_to_text = gr.Blocks.load(name="spaces/banana-dev/demo-clip-interrogator")
def center_crop(img, dim: tuple = (512, 512)):
    """Return the center crop of ``img``.

    Args:
        img: image array (H x W[, C]) to be center cropped.
        dim: (width, height) of the desired crop; clamped to the image size.

    Returns:
        The cropped sub-array (a view into ``img``, not a copy).
    """
    height, width = img.shape[0], img.shape[1]

    # Clamp the requested crop to what the image actually has available.
    crop_width = min(dim[0], width)
    crop_height = min(dim[1], height)

    # Slice exactly crop_height x crop_width pixels. The previous
    # mid-ch2:mid+ch2 form silently dropped one pixel for odd sizes.
    y0 = height // 2 - crop_height // 2
    x0 = width // 2 - crop_width // 2
    return img[y0:y0 + crop_height, x0:x0 + crop_width]


def scale_image(img, factor=1):
    """Resize ``img`` by a uniform scale factor.

    Applying one factor to both axes keeps the aspect ratio intact.

    Args:
        img: image array to resize.
        factor: multiplicative scale applied to width and height.

    Returns:
        The resized image.
    """
    new_width = int(img.shape[1] * factor)
    new_height = int(img.shape[0] * factor)
    return cv2.resize(img, (new_width, new_height))
	
	
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Request a Mubert track for the given tags and wait for it to render.

    Args:
        tags: Mubert tag strings describing the desired music.
        pat: personal access token (see ``get_pat``).
        duration: track length in seconds.
        maxit: how many times (1 s apart) to poll for the rendered file.
        loop: request a seamless loop instead of a one-shot track.

    Returns:
        The download URL of the rendered track.

    Raises:
        AssertionError: if the Mubert API reports an error status.
        TimeoutError: if the track is not ready after ``maxit`` polls.
    """
    mode = "loop" if loop else "track"
    r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
                   json={
                       "method": "RecordTrackTTM",
                       "params": {
                           "pat": pat,
                           "duration": duration,
                           "tags": tags,
                           "mode": mode
                       }
                   })

    pprint(r.text)
    rdata = json.loads(r.text)
    assert rdata['status'] == 1, rdata['error']['text']
    trackurl = rdata['data']['tasks'][0]['download_link']

    # The track is rendered asynchronously: poll until the URL serves it.
    for _ in range(maxit):
        if httpx.get(trackurl).status_code == 200:
            return trackurl
        time.sleep(1)
    # Previously this fell through and returned None, which surfaced later
    # as a confusing AttributeError in the caller. Fail loudly instead.
    raise TimeoutError(f"Track was not ready after {maxit} polls: {trackurl}")


def generate_track_by_prompt(image, email, duration, loop=False):
    """Gradio handler: turn an uploaded image into music plus a waveform video.

    Pipeline: caption the image via the CLIP-interrogator Space, map the
    caption to Mubert tags, generate a track, then render an mp4 that
    overlays the track's waveform on the image with ffmpeg.

    Args:
        image: filesystem path of the uploaded image (gr.Image type="filepath").
        email: user e-mail used to obtain a Mubert access token.
        duration: requested track length in seconds.
        loop: request a seamless loop instead of a one-shot track.

    Returns:
        Tuple of (video path, track URL, generated prompt, Mubert tags).

    Raises:
        gr.Error: on any failure, so Gradio shows the message to the user.
    """
    import shlex

    try:
        # Normalize the upload to an RGB PNG in the temp dir and read its size.
        filename_png = f"{uuid.uuid4().hex}.png"
        filepath_png = f"{gettempdir()}/{filename_png}"

        with Image.open(image) as im:
            ratio_width, ratio_height = im.size
            im.convert("RGB").save(filepath_png)

        # Reject oversized images; downscale large ones so ffmpeg stays fast.
        # (The old message incorrectly claimed a 1024 px limit.)
        if ratio_width > 3501 or ratio_height > 3501:
            raise gr.Error("Image is too large: width and height must not exceed 3501 px.")
        elif ratio_width > 3500 or ratio_height > 3500:
            image_g = cv2.imread(image)
            scale_img = scale_image(image_g, factor=0.2)
            cv2.imwrite(filepath_png, scale_img)
        elif ratio_width > 1800 or ratio_height > 1800:
            image_g = cv2.imread(image)
            scale_img = scale_image(image_g, factor=0.3)
            cv2.imwrite(filepath_png, scale_img)
        elif ratio_width > 900 or ratio_height > 900:
            image_g = cv2.imread(image)
            scale_img = scale_image(image_g, factor=0.5)
            cv2.imwrite(filepath_png, scale_img)

        # prompt = image_to_text(filepath_png, "Image Captioning", "", "Nucleus sampling")
        prompt = image_to_text(filepath_png, "ViT-L (best for Stable Diffusion 1.*)", "Fast",  fn_index=1)[0]
        print(f"PROMPT: {prompt}")

        pat = get_pat(email)
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
        filepath = get_track_by_tags(tags, pat, int(duration), loop=loop)

        filename_mp3 = filepath.split("/")[-1]
        filepath_mp3 = f"{gettempdir()}/{filename_mp3}"
        filename_mp4 = f"{uuid.uuid4().hex}.mp4"
        filepath_mp4 = f"{gettempdir()}/{filename_mp4}"

        # Download the track with httpx (already a dependency). The previous
        # `os.system("wget ...")` ignored failures and passed an unquoted URL
        # to the shell.
        resp = httpx.get(filepath)
        resp.raise_for_status()
        with open(filepath_mp3, "wb") as f:
            f.write(resp.content)

        # Render the waveform video. Quote all interpolated paths so the
        # shell command cannot be broken by unexpected characters.
        with Image.open(filepath_png) as im:
            width, height = im.size
        print(f"{width}x{height}")
        command = (
            f'ffmpeg -hide_banner -loglevel warning -y -i {shlex.quote(filepath_mp3)} '
            f'-loop 1 -i {shlex.quote(filepath_png)} '
            f'-filter_complex "[0:a]showwaves=s={width}x{height}:colors=0xffffff:mode=cline,format=rgba[v];[1:v][v]overlay[outv]" '
            f'-map "[outv]" -map 0:a -c:v libx264 -r 15 -c:a copy -pix_fmt yuv420p -shortest {shlex.quote(filepath_mp4)}'
        )
        os.system(command)
        os.remove(filepath_png)
        os.remove(filepath_mp3)

        return filepath_mp4, filepath, prompt, tags
    except Exception as e:
        raise gr.Error(str(e))


# Wire the handler into a simple Gradio UI: image + e-mail + duration in;
# waveform video, audio URL, caption prompt and Mubert tags out.
input_widgets = [
    gr.Image(type="filepath"),
    "text",
    gr.Slider(label="duration (seconds)", value=30, minimum=10, maximum=60),
]
output_widgets = [
    gr.Video(label="Video"),
    gr.Audio(label="Audio"),
    gr.Text(label="Prompt"),
    gr.Text(label="Tags"),
]
iface = gr.Interface(fn=generate_track_by_prompt,
                     inputs=input_widgets,
                     outputs=output_widgets)
iface.queue().launch()