Spaces:
Runtime error
Runtime error
sradc
Log a hash of the client IP alongside a hash of the query, so the number of users and the number of queries can be tracked anonymously
a2309ed
import base64 | |
import hashlib | |
import os | |
import subprocess | |
from dataclasses import dataclass | |
from typing import Optional | |
import faiss | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from streamlit import runtime | |
from streamlit.logger import get_logger | |
from streamlit.runtime.scriptrunner import get_script_run_ctx | |
from pipeline import clip_wrapper | |
from pipeline.process_videos import DATAFRAME_PATH | |
# Module-level logger, obtained through Streamlit's logging helper.
logger = get_logger(__name__)
# Number of nearest frames one search returns; 21 fills seven rows of the
# three-column results grid.
NUM_FRAMES_TO_RETURN: int = 21
class SemanticSearcher:
    """Text-to-frame semantic search over precomputed frame embeddings.

    ``dataset`` columns whose names start with ``dim_`` are the embedding
    components; every other column (which must include ``video_id``,
    ``frame_idx``, ``timestamp`` and ``base64_image``) is kept as metadata and
    returned alongside each hit. Presumably one row per video frame — confirm
    against the pipeline that writes the parquet file.
    """

    def __init__(self, dataset: pd.DataFrame):
        dim_columns = dataset.filter(regex="^dim_").columns
        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        self.metadata = dataset.drop(columns=dim_columns)
        # Exact (brute-force) inner-product index. Vector i in the index is
        # row i of `self.metadata`, which is what lets `search` map faiss ids
        # back to metadata positionally with `iloc`.
        # NOTE(review): inner product equals cosine similarity only if both the
        # stored and the query embeddings are L2-normalised — confirm in
        # clip_wrapper.
        self.index = faiss.IndexFlatIP(len(dim_columns))
        self.index.add(np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32)))

    def search(self, query: str) -> list["SearchResult"]:
        """Return up to NUM_FRAMES_TO_RETURN frames most similar to ``query``.

        Hits are ordered by descending inner-product score, as returned by
        faiss. Fewer results come back when the index holds fewer vectors.
        """
        # faiss expects a C-contiguous float32 (n_queries, dim) matrix.
        query_vec = np.ascontiguousarray(
            self.embedder([query]).detach().numpy(), dtype=np.float32
        )
        scores, ids = self.index.search(query_vec, NUM_FRAMES_TO_RETURN)
        # faiss pads missing results with id -1; `iloc[-1]` would silently
        # return the last row as a bogus hit, so drop those slots.
        found = ids[0] >= 0
        hit_scores = scores[0][found]
        hit_ids = ids[0][found]
        return [
            SearchResult(
                video_id=row["video_id"],
                frame_idx=row["frame_idx"],
                timestamp=row["timestamp"],
                base64_image=row["base64_image"],
                score=score,
            )
            for score, (_, row) in zip(hit_scores, self.metadata.iloc[hit_ids].iterrows())
        ]
@st.cache_resource
def get_semantic_searcher():
    """Return the app's SemanticSearcher built from the parquet file at DATAFRAME_PATH.

    Streamlit re-executes the whole script on every widget interaction.
    Without caching, each rerun would re-read the parquet file,
    re-instantiate the CLIP wrapper and rebuild the faiss index.
    ``st.cache_resource`` builds the searcher once and shares that instance
    across reruns and sessions.
    NOTE(review): the shared instance is used concurrently by all sessions;
    `search` only reads the index, but confirm the embedder is safe to call
    from several threads.
    """
    return SemanticSearcher(pd.read_parquet(DATAFRAME_PATH))
def get_git_hash() -> Optional[str]:
    """Return the full commit hash of the checked-out HEAD, or None.

    None is returned when the hash cannot be determined. This covers two
    cases: git exits non-zero (e.g. the working directory is not a
    repository), or the ``git`` executable cannot be launched at all, which
    raises ``OSError``/``FileNotFoundError`` rather than
    ``CalledProcessError``.
    """
    try:
        # Silence git's "fatal: not a git repository" message; failure is
        # reported to the caller as None instead.
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
        )
    except (subprocess.CalledProcessError, OSError):
        return None
    return output.decode().strip()
@dataclass
class SearchResult:
    """One frame hit returned by ``SemanticSearcher.search``.

    The ``@dataclass`` decorator supplies the keyword constructor that
    ``search`` calls; without it the class takes no arguments.
    """

    # YouTube video id; becomes the `v=` parameter of the watch URL.
    video_id: str
    # Frame index (presumably within its video — confirm with the pipeline).
    frame_idx: int
    # Frame position in seconds; truncated to int for the URL's `t=` parameter.
    timestamp: float
    # Base64-encoded JPEG as bytes; decoded to text before being embedded in a data: URI.
    base64_image: bytes
    # Inner-product similarity to the query from the faiss index; higher is more similar.
    score: float
def get_video_url(video_id: str, timestamp: float) -> str:
    """Return a YouTube watch link for ``video_id`` that starts at ``timestamp``.

    The start offset is truncated with ``int()`` (so 12.9 becomes 12) before
    being placed in the ``t`` query parameter.
    """
    start_seconds = int(timestamp)
    base_url = "https://www.youtube.com/watch"
    return "{}?v={}&t={}".format(base_url, video_id, start_seconds)
# Page-global CSS for `.video-container` elements.
# NOTE(review): no markup emitted by display_search_results uses this class;
# it may be dead styling — confirm before removing.
_VIDEO_CONTAINER_CSS = """
<style>
.video-container {
    position: relative;
    padding-bottom: 56.25%;
    padding-top: 30px;
    height: 0;
    overflow: hidden;
}
.video-container iframe,
.video-container object,
.video-container embed {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
}
</style>
"""


def _result_html(result: SearchResult) -> str:
    """Return the HTML for one hit: the frame thumbnail wrapped in a link to the video."""
    url = get_video_url(result.video_id, result.timestamp)
    image_b64 = result.base64_image.decode()
    alt_text = f"frame {result.frame_idx} timestamp {int(result.timestamp)}"
    return (
        f'<a href="{url}">\n'
        f'<img src="data:image/jpeg;base64,{image_b64}" alt="{alt_text}" width="100%">\n'
        "</a>"
    )


def display_search_results(results: list[SearchResult]) -> None:
    """Render search hits as a grid of clickable thumbnails, three per row.

    Each thumbnail links to the YouTube video at the frame's timestamp. A
    short final row leaves its trailing columns empty. Nothing is rendered
    for an empty result list.
    """
    col_count = 3  # Number of videos per row
    if not results:
        return
    # The CSS applies to the whole page, so inject it once rather than once per result.
    st.markdown(_VIDEO_CONTAINER_CSS, unsafe_allow_html=True)
    for start in range(0, len(results), col_count):
        columns = st.columns(col_count)  # one new row per chunk of results
        for column, result in zip(columns, results[start : start + col_count]):
            with column:
                st.markdown(_result_html(result), unsafe_allow_html=True)
def get_remote_ip() -> Optional[str]:
    """Return the client IP address of the current Streamlit session, or None.

    Best effort: returns None in three cases.
    - There is no script-run context, e.g. the function is called outside a
      Streamlit run.
    - The session's client is not registered with the runtime.
    - Any step of the lookup raises.

    NOTE(review): relies on Streamlit's `streamlit.runtime` internals, which
    may change between releases; that is why every failure degrades to None.
    """
    try:
        ctx = get_script_run_ctx()
        if ctx is None:
            return None
        session_info = runtime.get_instance().get_client(ctx.session_id)
        if session_info is None:
            return None
        # Keep the attribute access inside the try so an API change also yields None.
        return session_info.request.remote_ip
    except Exception:
        return None
def _short_sha256(text: str, length: int = 10) -> str:
    """Return the first ``length`` hex characters of the SHA-256 digest of ``text``."""
    return hashlib.sha256(text.encode()).hexdigest()[:length]


def main():
    """Streamlit entry point: render the search page and answer the current query."""
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Visual content search over music videos")
    st.markdown("_App by Ben Tenmann and Sidney Radcliffe_")
    searcher = get_semantic_searcher()
    num_videos = len(searcher.metadata.video_id.unique())
    st.text_input(
        f"What are you looking for? Search over {num_videos} music videos.", key="query"
    )
    query = st.session_state["query"]
    if query:
        remote_ip = get_remote_ip()
        # get_remote_ip returns None when the IP cannot be determined; calling
        # .encode() on it would crash the page, so log a placeholder instead.
        # NOTE(review): a truncated, unsalted SHA-256 of an IPv4 address can be
        # reversed by brute force over the 2**32 address space, so this is
        # pseudonymous rather than anonymous. Consider a keyed hash (HMAC with
        # a secret) if anonymity matters.
        ip_sha256 = _short_sha256(remote_ip) if remote_ip is not None else "unknown"
        logger.info("sha256(ip)=%s sha256(query)=%s", ip_sha256, _short_sha256(query))
        st.text("Click image to open video")
        display_search_results(searcher.search(query))
    # Run git once per rerun instead of once for the check and again for the text.
    git_hash = get_git_hash()
    if git_hash:
        st.text(f"Build: {git_hash[0:7]}")


if __name__ == "__main__":
    main()