File size: 4,888 Bytes
1801c3b
b4233e1
1801c3b
bac9ad8
1801c3b
15402aa
1801c3b
 
 
 
 
a2309ed
b4233e1
a2309ed
1801c3b
 
7969559
1801c3b
cd44e6d
b4233e1
cd44e6d
1801c3b
 
 
 
 
 
 
 
 
 
 
 
cd44e6d
1801c3b
 
 
 
 
4343947
1801c3b
 
 
 
 
 
7969559
 
 
1801c3b
 
bac9ad8
15402aa
 
 
 
 
bac9ad8
 
1801c3b
 
 
 
 
4343947
1801c3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7969559
1801c3b
 
 
 
 
 
 
 
 
 
a2309ed
 
 
 
 
 
 
 
 
 
 
 
 
 
1801c3b
 
fe577c4
ba7f790
7969559
498eef9
7b9d833
 
 
fe577c4
1801c3b
a2309ed
 
 
7453c0d
7969559
15402aa
 
1801c3b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import base64
import hashlib
import os
import subprocess
from dataclasses import dataclass
from typing import Optional

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from streamlit import runtime
from streamlit.logger import get_logger
from streamlit.runtime.scriptrunner import get_script_run_ctx

from pipeline import clip_wrapper
from pipeline.process_videos import DATAFRAME_PATH

# Number of nearest neighbours requested from the FAISS index per query.
NUM_FRAMES_TO_RETURN = 21
# Module-level logger obtained through Streamlit's logging helper.
logger = get_logger(__name__)


class SemanticSearcher:
    """Text-to-frame search over a table of precomputed frame embeddings.

    The embedding of each row is read from the columns whose names start
    with ``dim_``; every other column is kept as per-frame metadata. Queries
    are embedded with ``ClipWrapper().texts2vec`` and matched against a
    ``faiss.IndexFlatIP`` (inner-product) index.
    """

    def __init__(self, dataset: pd.DataFrame):
        """Build the index from ``dataset``.

        Args:
            dataset: One row per frame. Must contain the ``dim_*`` embedding
                columns plus the metadata columns read by :meth:`search`
                (``video_id``, ``frame_idx``, ``timestamp``,
                ``base64_image``).
        """
        dim_columns = dataset.filter(regex="^dim_").columns

        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        self.metadata = dataset.drop(columns=dim_columns)
        self.index = faiss.IndexFlatIP(len(dim_columns))
        # FAISS expects a C-contiguous float32 matrix.
        self.index.add(np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32)))

    def search(
        self, query: str, num_results: Optional[int] = None
    ) -> list["SearchResult"]:
        """Return the frames that best match ``query``, best match first.

        Args:
            query: Free-text search string.
            num_results: Maximum number of hits to return. Defaults to
                ``NUM_FRAMES_TO_RETURN`` when omitted.

        Returns:
            At most ``num_results`` results, in the order FAISS reports them.
            Empty when the index holds no vectors.
        """
        requested = NUM_FRAMES_TO_RETURN if num_results is None else num_results
        # Never ask for more neighbours than the index holds: FAISS pads the
        # missing ones with label -1, and ``iloc[-1]`` would silently return
        # the last metadata row as a bogus hit.
        k = min(requested, self.index.ntotal)
        if k <= 0:
            return []

        # NOTE(review): ``.detach().numpy()`` suggests the embedder returns a
        # CPU torch tensor -- confirm against clip_wrapper.
        query_vec = self.embedder([query]).detach().numpy()
        scores, indices = self.index.search(query_vec, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0:  # defensive: padding label for a missing neighbour
                continue
            # Row positions in ``metadata`` match the order vectors were added.
            row = self.metadata.iloc[idx]
            results.append(
                SearchResult(
                    video_id=row["video_id"],
                    frame_idx=row["frame_idx"],
                    timestamp=row["timestamp"],
                    base64_image=row["base64_image"],
                    score=float(score),
                )
            )
        return results


@st.cache_resource
def get_semantic_searcher():
    """Load the frame table from ``DATAFRAME_PATH`` and wrap it in a searcher.

    Decorated with ``st.cache_resource`` so the parquet load and index build
    are handled by Streamlit's resource cache.
    """
    frame_table = pd.read_parquet(DATAFRAME_PATH)
    searcher = SemanticSearcher(frame_table)
    return searcher


@st.cache_data
def get_git_hash() -> Optional[str]:
    """Return the commit hash printed by ``git rev-parse HEAD``.

    Returns:
        The stripped command output, or ``None`` when the command exits
        non-zero or the ``git`` executable cannot be launched at all.
    """
    try:
        output = subprocess.check_output(["git", "rev-parse", "HEAD"])
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError / PermissionError raised when git
        # is not installed or not executable in the deployment environment.
        return None
    return output.decode().strip()


@dataclass
class SearchResult:
    """One frame returned by a search, with the data needed to render it."""

    # Video identifier; inserted into the YouTube watch URL by get_video_url.
    video_id: str
    # Frame index as stored in the metadata table; shown in the image alt text.
    frame_idx: int
    # Position of the frame in the video; truncated to whole seconds for the URL.
    timestamp: float
    # Base64-encoded JPEG data; the display code calls ``.decode()`` on it
    # before embedding it in a ``data:image/jpeg;base64,`` URI, so it is bytes.
    base64_image: bytes
    # Similarity score reported by the index search for this frame.
    score: float


def get_video_url(video_id: str, timestamp: float) -> str:
    """Build a YouTube watch URL that starts playback at ``timestamp``.

    Args:
        video_id: Value placed in the ``v`` query parameter.
        timestamp: Start offset; truncated toward zero to whole seconds for
            the ``t`` query parameter.

    Returns:
        The watch URL as a string.
    """
    start_seconds = int(timestamp)
    return f"https://www.youtube.com/watch?v={video_id}&t={start_seconds}"


def display_search_results(results: list[SearchResult]) -> None:
    """Render search results as a three-column grid of clickable thumbnails.

    Each thumbnail is an inline base64 JPEG wrapped in a link to the video at
    the frame's timestamp. Nothing is rendered when ``results`` is empty.

    Args:
        results: Search hits to display, in the order they should appear
            (left to right, top to bottom).
    """
    if not results:
        return

    col_count = 3  # Number of videos per row

    # The stylesheet is identical for every result, so emit it once instead of
    # once per thumbnail. (The ``.video-container`` rules are not referenced by
    # the markup produced below; they are kept for parity with the original.)
    st.markdown(
        """
        <style>
        .video-container {
            position: relative;
            padding-bottom: 56.25%;
            padding-top: 30px;
            height: 0;
            overflow: hidden;
        }
        .video-container iframe,
        .video-container object,
        .video-container embed {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    row = None
    for i, result in enumerate(results):
        col_num = i % col_count
        if col_num == 0:
            row = st.columns(col_count)  # Start a new row of columns

        # The image payload may arrive as bytes (needs decoding) or already
        # as text; accept both so a str value does not raise AttributeError.
        image_b64 = result.base64_image
        if isinstance(image_b64, (bytes, bytearray)):
            image_b64 = image_b64.decode()

        with row[col_num]:
            st.markdown(
                f"""
                <a href="{get_video_url(result.video_id, result.timestamp)}">
                <img src="data:image/jpeg;base64,{image_b64}" alt="frame {result.frame_idx} timestamp {int(result.timestamp)}" width="100%">
                </a>
                """,
                unsafe_allow_html=True,
            )


def get_remote_ip() -> Optional[str]:
    """Return the remote IP of the client behind the current script run.

    This is a best-effort lookup through Streamlit's runtime objects.

    Returns:
        The client's remote IP, or ``None`` when there is no script-run
        context, no client is registered for the session, or the lookup
        fails for any reason.
    """
    try:
        ctx = get_script_run_ctx()
        if ctx is None:
            return None
        session_info = runtime.get_instance().get_client(ctx.session_id)
        if session_info is None:
            return None
        # Kept inside the try so a client object without ``request`` /
        # ``remote_ip`` degrades to None instead of raising.
        return session_info.request.remote_ip
    except Exception:
        # Deliberately broad: the IP is only used for logging, so a failed
        # lookup must never break the page.
        return None


def main():
    """Render the search page: header, query box, results grid and build tag.

    When a query is present, truncated SHA-256 digests of the query and of
    the client's IP are logged (never the raw values) before the results are
    displayed.
    """
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Visual content search over music videos")
    st.markdown("_App by Ben Tenmann and Sidney Radcliffe_")
    searcher = get_semantic_searcher()
    num_videos = len(searcher.metadata.video_id.unique())
    st.text_input(
        f"What are you looking for? Search over {num_videos} music videos.", key="query"
    )
    query = st.session_state["query"]
    if query:
        query_sha256 = hashlib.sha256(query.encode()).hexdigest()[:10]
        # get_remote_ip() returns None when the client cannot be resolved;
        # calling .encode() on that would raise and abort the search.
        remote_ip = get_remote_ip()
        if remote_ip:
            ip_sha256 = hashlib.sha256(remote_ip.encode()).hexdigest()[:10]
        else:
            ip_sha256 = "unknown"
        logger.info(f"sha256(ip)={ip_sha256} sha256(query)={query_sha256}")
        st.text("Click image to open video")
        display_search_results(searcher.search(query))
    git_hash = get_git_hash()
    if git_hash:
        st.text(f"Build: {git_hash[0:7]}")


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()