File size: 6,293 Bytes
d50bd1e
5978ae3
ee4f393
5978ae3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ed1e74
 
 
 
 
5978ae3
 
 
 
 
 
9ed1e74
 
5978ae3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee4f393
 
 
5978ae3
 
 
ae68709
5978ae3
 
 
 
 
 
 
ee4f393
 
 
5978ae3
ee4f393
 
 
 
 
 
 
 
 
 
 
5978ae3
 
ee4f393
5978ae3
 
 
 
9ed1e74
5978ae3
9ed1e74
5978ae3
d50bd1e
 
 
5978ae3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee4f393
 
 
5978ae3
 
 
 
 
 
 
 
 
 
 
 
ee4f393
5978ae3
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
from warnings import simplefilter
import traceback

simplefilter("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import json
import time
import google.generativeai as genai

try:
    # Prefer the project's own configured logger module when it exists.
    from logger import logging
except ImportError:
    # Fall back to the standard library only when that module is unavailable;
    # any other error raised inside `logger` should surface, not be hidden.
    import logging

# Example music prompts embedded in the system instruction to show the model
# the expected style of its "Music Prompt" output.
music_prompt_examples = """
'A dynamic blend of hip-hop and orchestral elements, with sweeping strings and brass, evoking the vibrant energy of the city',
'Smooth jazz, with a saxophone solo, piano chords, and snare full drums',
'90s rock song with electric guitar and heavy drums, nightcore, 140bpm',
'lofi melody loop, A minor, 110 bpm, jazzy chords evoking a feeling of curiosity, relaxing, vinyl recording',
'J-Pop, 140bpm, 320kbps, 48kHz',
'funk, disco, R&B, AOR, soft rock, and boogie',
'a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions bpm: 130'
"""

# Shape of the JSON object the model is asked to return.
json_schema = """
{"Content Description": "string", "Music Prompt": "string"}
"""

# System instruction for the Gemini model. The examples are explicitly labelled
# as examples so the model does not mistake them for user-provided keywords
# (user keywords are supplied separately in each request's prompt).
gemini_instructions = f"""
You are a music supervisor who analyzes the content and tone of images and videos to describe music that fits well with the mood, evokes emotions, and enhances the narrative of the visuals. Given an image or video, describe the scene and generate a prompt suitable for music generation models. Generate a music prompt based on the description, and incorporate any keywords provided by the user. Here are some example music prompts:

{music_prompt_examples}

You must return your response using this JSON schema: {json_schema}
"""


class DescribeVideo:
    """Describe a video with Gemini and derive a music-generation prompt from it.

    The underlying model is created with the module-level ``gemini_instructions``
    as its system instruction, which asks for a JSON answer following
    ``json_schema``.

    Attributes:
        model: Full Gemini model name resolved by ``get_model_name``.
        is_safety_set: True once custom settings were installed through
            ``set_safety_settings``; False while the defaults are in use.
        safety_settings: Harm-category -> block-threshold mapping passed to
            every ``generate_content`` call.
        mllm_model: The ``genai.GenerativeModel`` used for generation.
    """

    def __init__(self, model="flash"):
        """Configure the Gemini client and model.

        Args:
            model: Short model alias, ``"flash"`` or ``"pro"``.

        Raises:
            ValueError: If ``model`` is unknown, or ``./creds.json`` is not
                valid JSON or has no usable ``google_api_key``.
            OSError: If ``./creds.json`` cannot be opened.
        """
        self.model = self.get_model_name(model)
        api_key = self.load_api_key()
        # Must exist before get_safety_settings(), which reads it.
        self.is_safety_set = False
        self.safety_settings = self.get_safety_settings()

        genai.configure(api_key=api_key)
        self.mllm_model = genai.GenerativeModel(
            self.model, system_instruction=gemini_instructions
        )

        logging.info(f"Initialized DescribeVideo with model: {self.model}")

    def describe_video(self, video_path, genre=None, bpm=None, user_keywords=None):
        """Upload a video, ask Gemini to describe it, and return the parsed JSON.

        The uploaded file is deleted from the Gemini File API afterwards
        (best effort), whether or not generation succeeds.

        Args:
            video_path: Local path of the video to upload.
            genre: Optional genre keyword to steer the music prompt.
            bpm: Optional tempo; when truthy it is added as ``"<bpm> bpm"``.
            user_keywords: Optional free-form keyword string.

        Returns:
            The object decoded from the model's JSON reply (expected to follow
            ``json_schema``).

        Raises:
            ValueError: If the File API reports the upload as ``FAILED``, or the
                reply is not valid JSON (``json.JSONDecodeError``).
        """
        video_file = genai.upload_file(video_path)
        try:
            # Uploaded videos are processed asynchronously; poll until done.
            while video_file.state.name == "PROCESSING":
                time.sleep(0.25)
                video_file = genai.get_file(video_file.name)

            if video_file.state.name == "FAILED":
                logging.error(f"Failed to upload video: {video_file.state.name}")
                raise ValueError(f"Failed to upload video: {video_file.state.name}")

            additional_keywords = self._format_keywords(genre, bpm, user_keywords)

            logging.info(
                f"Uploaded video: {video_path} and config: {additional_keywords}"
            )

            user_prompt = "Explain what is happening in this video."
            if additional_keywords:
                user_prompt += f" The following keywords are provided by the user for generating the music prompt: {additional_keywords}"

            response = self.mllm_model.generate_content(
                [video_file, user_prompt],
                # Video understanding can be slow; allow up to 10 minutes.
                request_options={"timeout": 600},
                safety_settings=self.safety_settings,
            )

            response_text = response.text
            logging.info(f"Generated : {video_path} with response: {response_text}")

            return self._parse_json_response(response_text)
        finally:
            self._delete_uploaded_file(video_file)

    def __call__(self, video_path, genre=None, bpm=None, user_keywords=None):
        """Shorthand for ``describe_video`` with the same arguments."""
        return self.describe_video(video_path, genre, bpm, user_keywords)

    def reset_safety_settings(self):
        """Discard any custom safety settings and restore the defaults."""
        logging.info("Resetting safety settings")
        self.is_safety_set = False
        self.safety_settings = self.get_safety_settings()

    def set_safety_settings(self, safety_settings):
        """Validate and install custom safety settings.

        Args:
            safety_settings: Dict whose keys are ``genai.types.HarmCategory``
                members (or their names) and whose values are
                ``genai.types.HarmBlockThreshold`` members (or their names).

        Raises:
            ValueError: If ``safety_settings`` is not a dict or contains an
                unknown category/threshold. Nothing is changed in that case.
        """
        if not isinstance(safety_settings, dict):
            raise ValueError("Safety settings must be a dictionary")
        for harm_category, harm_block_threshold in safety_settings.items():
            if not self._is_enum_member_or_name(
                harm_category, genai.types.HarmCategory
            ):
                raise ValueError(f"Invalid harm category: {harm_category}")
            if not self._is_enum_member_or_name(
                harm_block_threshold, genai.types.HarmBlockThreshold
            ):
                raise ValueError(
                    f"Invalid harm block threshold: {harm_block_threshold}"
                )

        # Only commit once every entry has been validated.
        logging.info(f"Set safety settings: {safety_settings}")
        self.safety_settings = safety_settings
        self.is_safety_set = True

    def get_safety_settings(self):
        """Return the custom settings if set, else the defaults.

        The defaults use ``BLOCK_NONE`` for the hate-speech, harassment,
        sexually-explicit and dangerous-content categories.
        """
        if self.is_safety_set:
            return self.safety_settings

        return {
            genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai.types.HarmBlockThreshold.BLOCK_NONE,
            genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
            genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai.types.HarmBlockThreshold.BLOCK_NONE,
            genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
        }

    @staticmethod
    def load_api_key(path="./creds.json"):
        """Read the ``google_api_key`` entry from a JSON credentials file.

        Args:
            path: Path of the JSON file; relative paths resolve against the
                current working directory.

        Returns:
            The non-empty API key string.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If the file is not valid JSON, or the key is missing,
                not a string, or blank.
        """
        with open(path, encoding="utf-8") as f:
            creds = json.load(f)

        api_key = creds.get("google_api_key") if isinstance(creds, dict) else None
        if not isinstance(api_key, str) or not api_key.strip():
            logging.error(f"Google API key not found in {path}")
            raise ValueError(f"Gemini API key not found in {path}")
        return api_key

    @staticmethod
    def get_model_name(model):
        """Map a short alias (``"flash"``/``"pro"``) to a full Gemini model name.

        Raises:
            ValueError: If ``model`` is not a known alias.
        """
        models = {
            "flash": "models/gemini-1.5-flash-latest",
            "pro": "models/gemini-1.5-pro-latest",
        }

        if model not in models:
            message = f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
            logging.error(message)
            raise ValueError(message)

        logging.info(f"Selected model: {models[model]}")
        return models[model]

    @staticmethod
    def _format_keywords(genre=None, bpm=None, user_keywords=None):
        """Join the optional hints into one comma-separated string.

        Empty/None parts are skipped, so ``bpm=120`` alone gives ``"120 bpm"``
        rather than ``", 120 bpm"``.
        """
        parts = [genre, user_keywords, f"{bpm} bpm" if bpm else None]
        return ", ".join(part for part in parts if part)

    @staticmethod
    def _parse_json_response(text):
        """Decode a JSON reply, tolerating a surrounding Markdown code fence.

        Handles plain JSON as well as replies such as ```` ```json\\n{...}\\n``` ````
        (any case for the ``json`` tag, any surrounding whitespace).
        """
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
            # Drop an optional language tag directly after the opening fence.
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return json.loads(cleaned)

    @staticmethod
    def _is_enum_member_or_name(value, enum_cls):
        """Return True if ``value`` is a member of ``enum_cls`` or one of its member names."""
        if isinstance(value, enum_cls):
            return True
        # The isinstance guard keeps unhashable values from raising TypeError.
        return isinstance(value, str) and value in enum_cls.__members__

    @staticmethod
    def _delete_uploaded_file(video_file):
        """Best-effort removal of an uploaded file from the Gemini File API."""
        try:
            genai.delete_file(video_file.name)
        except Exception as err:  # cleanup must never mask the real result/error
            logging.warning(
                f"Could not delete uploaded file {video_file.name}: {err}"
            )


if __name__ == "__main__":
    # Example run: needs ./creds.json containing "google_api_key" and the
    # sample video below, both relative to the current working directory.
    video_path = "videos/3A49B385FD4A8FE2E3AEEF43C140D9AF_video_dashinit.mp4"
    dv = DescribeVideo(model="flash")
    # Pass every describe_video argument explicitly; calling it with only the
    # path fails when genre/bpm/user_keywords are required parameters.
    video_description = dv.describe_video(
        video_path, genre=None, bpm=None, user_keywords=None
    )
    print(video_description)