Spaces:
Sleeping
Sleeping
File size: 6,293 Bytes
d50bd1e 5978ae3 ee4f393 5978ae3 9ed1e74 5978ae3 9ed1e74 5978ae3 ee4f393 5978ae3 ae68709 5978ae3 ee4f393 5978ae3 ee4f393 5978ae3 ee4f393 5978ae3 9ed1e74 5978ae3 9ed1e74 5978ae3 d50bd1e 5978ae3 ee4f393 5978ae3 ee4f393 5978ae3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os
from warnings import simplefilter
import traceback
simplefilter("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import json
import time
import google.generativeai as genai
try:
    # Prefer the pre-configured logger exposed by the ``logger`` module
    # (presumably project-local) when it is available.
    from logger import logging
except Exception:
    # Best-effort fallback to the standard library. ``Exception`` (instead of
    # a bare ``except``) keeps the fallback for a missing or broken ``logger``
    # module while letting KeyboardInterrupt/SystemExit propagate.
    import logging
# Example music prompts embedded in the system instruction below to show the
# model the style of prompt to produce.
# NOTE(review): "cheerly" and "aery" look like typos for "cheerful" and "airy";
# left untouched because this text is sent to the model verbatim.
music_prompt_examples = """
'A dynamic blend of hip-hop and orchestral elements, with sweeping strings and brass, evoking the vibrant energy of the city',
'Smooth jazz, with a saxophone solo, piano chords, and snare full drums',
'90s rock song with electric guitar and heavy drums, nightcore, 140bpm',
'lofi melody loop, A minor, 110 bpm, jazzy chords evoking a feeling of curiosity, relaxing, vinyl recording',
'J-Pop, 140bpm, 320kbps, 48kHz',
'funk, disco, R&B, AOR, soft rock, and boogie',
'a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130'.
"""
# Shape of the JSON object the model is asked to return; DescribeVideo parses
# the reply with json.loads.
json_schema = """
{"Content Description": "string", "Music Prompt": "string"}
"""
# System instruction handed to genai.GenerativeModel; interpolates the example
# prompts and the JSON schema defined above.
gemini_instructions = f"""
You are a music supervisor who analyzes the content and tone of images and videos to describe music that fits well with the mood, evokes emotions, and enhances the narrative of the visuals. Given an image or video, describe the scene and generate a prompt suitable for music generation models. Generate a music prompt based on the description, and use keywords if provided by the user:
{music_prompt_examples}
You must return your response using this JSON schema: {json_schema}
"""
class DescribeVideo:
    """Describe a video with Gemini and derive a music-generation prompt.

    On construction the class resolves the model alias, loads the API key
    from a JSON credentials file, configures ``genai`` and creates a
    ``GenerativeModel`` that uses ``gemini_instructions`` as its system
    instruction.
    """

    def __init__(self, model="flash"):
        """Set up the Gemini client.

        Args:
            model: Short model alias, ``"flash"`` or ``"pro"``.

        Raises:
            ValueError: If the alias is unknown or no API key is found.
        """
        self.model = self.get_model_name(model)
        api_key = self.load_api_key()
        self.is_safety_set = False
        self.safety_settings = self.get_safety_settings()
        genai.configure(api_key=api_key)
        self.mllm_model = genai.GenerativeModel(
            self.model, system_instruction=gemini_instructions
        )
        logging.info(f"Initialized DescribeVideo with model: {self.model}")

    def describe_video(self, video_path, genre=None, bpm=None, user_keywords=None):
        """Upload a video and return the model's parsed JSON description.

        Args:
            video_path: Path of the video file passed to ``genai.upload_file``.
            genre: Optional genre keyword(s) for the music prompt.
            bpm: Optional tempo; rendered as ``"<bpm> bpm"`` when truthy.
            user_keywords: Optional free-form keywords from the user.

        Returns:
            The object decoded from the model's JSON reply.

        Raises:
            ValueError: If the uploaded file ends up in the ``FAILED`` state,
                or if the reply is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass).
        """
        video_file = genai.upload_file(video_path)
        # Poll until the service has finished processing the upload.
        while video_file.state.name == "PROCESSING":
            time.sleep(0.25)
            video_file = genai.get_file(video_file.name)
        if video_file.state.name == "FAILED":
            logging.error(f"Failed to upload video: {video_file.state.name}")
            raise ValueError(f"Failed to upload video: {video_file.state.name}")

        # Build the keyword list in one place so a lone bpm does not produce
        # a dangling leading ", ".
        keywords = [kw for kw in (genre, user_keywords) if kw]
        if bpm:
            keywords.append(f"{bpm} bpm")
        additional_keywords = ", ".join(keywords)
        logging.info(f"Uploaded video: {video_path} and config: {additional_keywords}")

        user_prompt = "Explain what is happening in this video."
        if additional_keywords:
            user_prompt += f" The following keywords are provided by the user for generating the music prompt: {additional_keywords}"
        response = self.mllm_model.generate_content(
            [video_file, user_prompt],
            request_options={"timeout": 600},
            safety_settings=self.safety_settings,
        )
        logging.info(f"Generated : {video_path} with response: {response.text}")
        return self._parse_json_response(response.text)

    @staticmethod
    def _parse_json_response(text):
        """Decode a JSON reply that may be wrapped in a Markdown code fence.

        Args:
            text: Raw reply text, with or without a surrounding fence.

        Returns:
            The decoded JSON value.

        Raises:
            json.JSONDecodeError: If the unwrapped text is not valid JSON.
        """
        # str.strip(chars) removes a *set of characters*, not a prefix/suffix,
        # so the fence is removed explicitly instead.
        payload = text.strip()
        if payload.startswith("```"):
            payload = payload[3:]
            if payload.endswith("```"):
                payload = payload[:-3]
            # Drop the optional language tag that follows the opening fence.
            if payload[:4].lower() == "json":
                payload = payload[4:]
        return json.loads(payload)

    def __call__(self, video_path, genre=None, bpm=None, user_keywords=None):
        """Shorthand for :meth:`describe_video` with the same arguments."""
        return self.describe_video(video_path, genre, bpm, user_keywords)

    def reset_safety_settings(self):
        """Discard custom safety settings and restore the defaults."""
        logging.info("Resetting safety settings")
        self.is_safety_set = False
        self.safety_settings = self.get_safety_settings()

    def set_safety_settings(self, safety_settings):
        """Validate and install custom safety settings.

        Args:
            safety_settings: Dict mapping harm categories to block
                thresholds, given as enum member names or enum members.

        Raises:
            ValueError: If the argument is not a dict or contains an unknown
                category or threshold. The current settings are left
                untouched in that case.
        """
        # Validate first so a rejected value never replaces the active settings.
        if not isinstance(safety_settings, dict):
            raise ValueError("Safety settings must be a dictionary")
        categories = genai.types.HarmCategory.__members__
        thresholds = genai.types.HarmBlockThreshold.__members__
        for harm_category, harm_block_threshold in safety_settings.items():
            # __members__ is keyed by member name; enum members are matched
            # through their ``name`` so the defaults' own key type is accepted.
            if getattr(harm_category, "name", harm_category) not in categories:
                raise ValueError(f"Invalid harm category: {harm_category}")
            if (
                getattr(harm_block_threshold, "name", harm_block_threshold)
                not in thresholds
            ):
                raise ValueError(
                    f"Invalid harm block threshold: {harm_block_threshold}"
                )
        logging.info(f"Set safety settings: {safety_settings}")
        self.safety_settings = safety_settings
        self.is_safety_set = True

    def get_safety_settings(self):
        """Return the custom settings if set, otherwise the defaults.

        The defaults map the hate-speech, harassment, sexually-explicit and
        dangerous-content categories to ``BLOCK_NONE``.
        """
        if self.is_safety_set:
            return self.safety_settings
        category = genai.types.HarmCategory
        block_none = genai.types.HarmBlockThreshold.BLOCK_NONE
        return {
            category.HARM_CATEGORY_HATE_SPEECH: block_none,
            category.HARM_CATEGORY_HARASSMENT: block_none,
            category.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_none,
            category.HARM_CATEGORY_DANGEROUS_CONTENT: block_none,
        }

    @staticmethod
    def load_api_key(path="./creds.json"):
        """Read the ``google_api_key`` entry from a JSON credentials file.

        Args:
            path: Location of the credentials file.

        Returns:
            The API key string.

        Raises:
            ValueError: If the key is missing, empty or not a string.
        """
        with open(path, encoding="utf-8") as f:
            creds = json.load(f)
        api_key = creds.get("google_api_key")
        if not isinstance(api_key, str) or not api_key:
            logging.error(f"Google API key not found in {path}")
            raise ValueError(f"Gemini API key not found in {path}")
        return api_key

    @staticmethod
    def get_model_name(model):
        """Translate a short alias into the full Gemini model name.

        Args:
            model: ``"flash"`` or ``"pro"``.

        Returns:
            The corresponding ``models/...`` identifier.

        Raises:
            ValueError: If the alias is not recognised.
        """
        models = {
            "flash": "models/gemini-1.5-flash-latest",
            "pro": "models/gemini-1.5-pro-latest",
        }
        if model not in models:
            message = (
                f"Invalid model name '{model}'. "
                f"Valid options are: {', '.join(models)}"
            )
            logging.error(message)
            raise ValueError(message)
        logging.info(f"Selected model: {models[model]}")
        return models[model]
if __name__ == "__main__":
    # Manual smoke run against a sample video.
    video_path = "videos/3A49B385FD4A8FE2E3AEEF43C140D9AF_video_dashinit.mp4"
    dv = DescribeVideo(model="flash")
    # Pass the keyword hints explicitly (none for this run); calling
    # describe_video with only the path raised a TypeError.
    video_description = dv.describe_video(
        video_path, genre=None, bpm=None, user_keywords=None
    )
    print(video_description)
|