socm / log_generation.py
spencer's picture
add normal files
6df828c
raw
history blame
3.1 kB
import glob
import string
from datetime import datetime
from pathlib import Path
import cv2
import yt_dlp
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from embeddings import VectorSearch, FaissIndex
def download_youtube(url, parent_dir="."):
def extract_youtube_id(url):
return url.split("watch?v=")[-1]
video_path = extract_youtube_id(url)
ydl_opts = {
"format": "mp4",
"outtmpl": f"{parent_dir}/{video_path}/{video_path}.%(ext)s",
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
error_code = ydl.download([url])
return error_code
def extract_video_frames(video_path, dims=(600, 400), sampling_rate=100):
video_dir = str(Path(video_path).parent)
video_name = str(Path(video_path).stem)
cap = cv2.VideoCapture(video_path)
i = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if i % sampling_rate == 0:
print(i)
frame = cv2.resize(frame, dims, fx=0, fy=0, interpolation=cv2.INTER_CUBIC)
timestamp = datetime.utcnow().timestamp()
cv2.imwrite(f"{video_dir}/{video_name}_{timestamp}_{i}.jpg", frame)
i += 1
cap.release()
cv2.destroyAllWindows()
def strip_punctuation(text):
return text.translate(str.maketrans("", "", string.punctuation))
def clean_response(act_text):
act_text = act_text.lower().replace("\n", "")
text_split = act_text.split("places")[0]
if not text_split:
text_split = act_text
try:
first_sent = sent_tokenize(text_split)[0]
except:
first_sent = text_split
list_split = first_sent.split(",")
no_spaces = list(map(str.strip, list_split))
return list(map(strip_punctuation, no_spaces))[:3]
def log_activity_from_image(image_file, frame, vlm, llm, vs, fi):
img_embed = vlm.get_image_emb(image_file)
fi.add(img_embed, [frame])
zs, places, objects = vs.prompt_activities(img_embed, 3)
# kwargs = {
# "top_p": 0.9,
# "temperature": 1.2,
# "max_new_tokens": 20,
# "return_full_text": False,
# }
activities_raw = llm(zs)
act_text = activities_raw[0]["generated_text"].lower()
activities_clean = clean_response(act_text)
log = (
f"{frame}:"
f"Places: {', '.join(places)}. "
f"Objects: {', '.join(objects)}. "
f"Activities: {', '.join(activities_clean)}"
)
# log = f'{zs} {", ".join(activities_clean)}'
return log
def generate_log(log_path, images_path, vlm, llm):
vs = VectorSearch()
fi = FaissIndex(768, f"{images_path}/video.index")
fi.reset()
with open(log_path, "w") as f:
for image in tqdm(sorted(glob.glob(f"{images_path}/*.jpg"))):
video_name, timestamp, frame = Path(image).stem.split("_")
try:
log = log_activity_from_image(image, frame, vlm, llm, vs, fi)
print(log)
f.write(f"{frame}:{log}\n")
except Exception as e:
print(e)
continue