import contextlib
import logging
import os

import faiss
import torch


# Module-level logger; records are tagged with this module's import name.
logger = logging.getLogger(__name__)
# NOTE: this configures the root logger at import time (INFO and above).
# It is a no-op when the host application has already installed handlers.
logging.basicConfig(level=logging.INFO)


class FaissIndex:
    """FAISS index paired with a parallel list of string ids, optionally persisted.

    ``self.id_list[i]`` is the id of the vector stored at FAISS position ``i``.
    When ``faiss_index_location`` is set, the index is written to that path and
    the ids to ``<path>.ids`` (one id per line) after every modification.
    """

    def __init__(
        self,
        embedding_size=None,
        faiss_index_location=None,
        indexer=faiss.IndexFlatIP,
    ):
        """Load the index stored at ``faiss_index_location`` or prepare a new one.

        Args:
            embedding_size: Vector dimensionality. Needed to create a new
                index; overwritten by the stored dimensionality when an
                existing index file is loaded.
            faiss_index_location: Path of the index file. ``None`` keeps the
                index purely in memory.
            indexer: Callable taking the dimensionality and returning an empty
                FAISS index; only used when a new index has to be created.

        Raises:
            ValueError: If neither ``embedding_size`` nor
                ``faiss_index_location`` is given, or if a non-empty index
                file exists without its companion ``.ids`` file.
        """
        if not (embedding_size or faiss_index_location):
            raise ValueError("Must provide embedding_size")

        self.embedding_size = embedding_size
        self.faiss_index_location = faiss_index_location
        # Always keep the factory so faiss_init() works for loaded indexes too.
        self.indexer = indexer
        self.index = None
        self.id_list = []

        if not faiss_index_location:
            # In-memory index: nothing to load, no directory to create.
            return

        if os.path.exists(faiss_index_location):
            self.index = faiss.read_index(faiss_index_location)
            logger.info(
                "Setting embedding size (%s) to match saved index", self.index.d
            )
            self.embedding_size = self.index.d
            self.id_list = self._read_ids()
        else:
            # The index itself is created lazily on the first add(); only make
            # sure the target directory exists. A bare filename has no parent
            # directory, and os.makedirs("") would raise.
            parent = os.path.dirname(faiss_index_location)
            if parent:
                os.makedirs(parent, exist_ok=True)

    def _ids_path(self):
        """Return the path of the ids file that accompanies the index file."""
        return self.faiss_index_location + ".ids"

    def _read_ids(self):
        """Read the ids file belonging to the index that was just loaded.

        Raises:
            ValueError: If the ids file is missing while the index holds vectors.
        """
        ids_path = self._ids_path()
        if not os.path.exists(ids_path):
            if self.index.ntotal > 0:
                raise ValueError("Index file exists but ids file does not")
            return []

        with open(ids_path) as f:
            ids = f.read().split("\n")
        # The file ends with a newline, so split() leaves one empty trailing
        # entry. Keeping it would shift every id added after a reload by one.
        if ids and ids[-1] == "":
            ids.pop()
        if len(ids) != self.index.ntotal:
            logger.warning(
                "Loaded %d ids for an index holding %d vectors",
                len(ids),
                self.index.ntotal,
            )
        return ids

    def _persist(self):
        """Write the index and the full id list to disk (no-op when in-memory)."""
        if not self.faiss_index_location:
            return
        faiss.write_index(self.index, self.faiss_index_location)
        # Rewrite instead of append so the file always mirrors id_list, even
        # if a stale ids file was left behind by an earlier run.
        with open(self._ids_path(), "w") as f:
            f.writelines(f"{item}\n" for item in self.id_list)

    def faiss_init(self):
        """Create a fresh, empty index and persist it if a location is set.

        Any ids held in memory are dropped, since they would not match the
        new empty index.

        Raises:
            ValueError: If no embedding size is known.
        """
        if not self.embedding_size:
            raise ValueError("Must provide embedding_size to create a new index")
        self.index = self.indexer(self.embedding_size)
        self.id_list = []
        self._persist()

    def add(self, inputs, ids, normalize=True):
        """Add vectors and their ids to the index, then persist both.

        Args:
            inputs: 2-D array with one row per vector. It is converted to
                C-contiguous float32 if necessary; when no conversion is
                needed and ``normalize`` is true, the caller's array is
                L2-normalised in place.
            ids: One string id per row of ``inputs``.
            normalize: L2-normalise the rows before adding them.

        Raises:
            ValueError: If the number of ids differs from the number of rows.
            TypeError: If any id is not a string.
        """
        ids = list(ids)
        # Validate before touching the index so a bad call leaves no partial state.
        if len(ids) != len(inputs):
            raise ValueError(
                f"Got {len(ids)} ids for {len(inputs)} input vectors"
            )
        if not all(isinstance(item, str) for item in ids):
            raise TypeError("ids must be strings")

        if self.index is None:
            self.faiss_init()

        inputs = inputs.astype("float32", order="C", copy=False)
        if normalize:
            faiss.normalize_L2(inputs)
        self.index.add(inputs)
        self.id_list.extend(ids)
        self._persist()

    def search(self, embedding, k=10, normalize=True):
        """Return the ``k`` nearest stored vectors for one or more queries.

        Args:
            embedding: 1-D query vector or 2-D array with one query per row.
                The caller's array is never modified.
            k: Number of neighbours requested per query.
            normalize: L2-normalise the query (a private copy) before searching.

        Returns:
            ``(D, I, labels)``: ``D`` and ``I`` are the distance and position
            arrays exactly as returned by FAISS. ``labels`` holds the ids of
            the hits in rank order, skipping the ``-1`` padding FAISS uses
            when fewer than ``k`` neighbours exist. It is a flat list for a
            single query and a list of lists for several queries.

        Raises:
            RuntimeError: If no index has been created or loaded yet.
        """
        if self.index is None:
            raise RuntimeError("Index is not initialised; add vectors first")

        # astype() copies by default, so normalisation cannot leak back into
        # the caller's array, and FAISS gets the float32 layout it requires.
        query = embedding.astype("float32", order="C")
        if query.ndim == 1:
            query = query.reshape(1, -1)
        if normalize:
            faiss.normalize_L2(query)

        D, I = self.index.search(query, k)
        per_query = [
            [self.id_list[pos] for pos in row if pos >= 0]
            for row in I
        ]
        labels = per_query[0] if len(per_query) == 1 else per_query
        return D, I, labels

    def reset(self):
        """Empty the index, forget all ids and delete the persisted files."""
        if self.index is not None:
            self.index.reset()
        self.id_list = []

        if not self.faiss_index_location:
            return
        # Remove each file independently: a missing index file must not
        # prevent the ids file from being deleted.
        for path in (self.faiss_index_location, self._ids_path()):
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)

    def __len__(self):
        """Return the number of stored vectors (0 before an index exists)."""
        return 0 if self.index is None else self.index.ntotal


class VectorSearch:
    """Build text prompts from the nearest "places" and "objects" of a query vector."""

    def __init__(self):
        """Load the ``places`` and ``objects`` indexes from ``faiss_indices/``."""
        self.places = self.load("places")
        self.objects = self.load("objects")

    def load(self, index_name):
        """Return the ``FaissIndex`` stored at ``faiss_indices/<index_name>.index``."""
        return FaissIndex(
            faiss_index_location=f"faiss_indices/{index_name}.index",
        )

    @staticmethod
    def _as_query(query_vec):
        """Convert a torch tensor to a private float32 NumPy array.

        The copy matters: ``Tensor.numpy()`` shares memory with the tensor, so
        an in-place normalisation during search would otherwise modify the
        caller's tensor. Moving to CPU first also makes CUDA tensors usable.
        Non-tensor inputs are passed through untouched.
        """
        if isinstance(query_vec, torch.Tensor):
            return (
                query_vec.detach()
                .to(device="cpu", dtype=torch.float32, copy=True)
                .contiguous()
                .numpy()
            )
        return query_vec

    def _top_labels(self, index, query_vec, k):
        """Return the labels of the ``k`` nearest entries of ``index``."""
        _, _, labels = index.search(self._as_query(query_vec), k=k)
        return labels

    def top_places(self, query_vec, k=5):
        """Return the labels of the ``k`` places closest to ``query_vec``."""
        return self._top_labels(self.places, query_vec, k)

    def top_objects(self, query_vec, k=5):
        """Return the labels of the ``k`` objects closest to ``query_vec``."""
        return self._top_labels(self.objects, query_vec, k)

    def prompt_activities(self, query_vec, k=5, one_shot=False):
        """Build an activity-guessing prompt from the nearest places and objects.

        Args:
            query_vec: Query embedding (torch tensor or array).
            k: Number of places and of objects to list.
            one_shot: Also build a variant prefixed with a worked example.

        Returns:
            ``(zero_shot, few_shot)`` when ``one_shot`` is true, otherwise
            ``(zero_shot, places, objects)``.
        """
        places = self.top_places(query_vec, k=k)
        objects = self.top_objects(query_vec, k=k)
        place_str = f"Places: {', '.join(places)}. "
        object_str = f"Objects: {', '.join(objects)}. "

        act_str = "I might be doing these 3 activities: "

        zs = place_str + object_str + act_str

        if one_shot:
            # act_str already ends with ": ", so the example must not add a
            # second colon, and the next line must not start with a space.
            example = (
                "Places: kitchen. Objects: coffee maker. "
                f"{act_str}eating, making breakfast, grinding coffee.\n"
            )
            return (zs, example + zs)

        return zs, places, objects

    def prompt_summary(self, state_history: list, k=5):
        """Build a summarisation prompt from a history of recognised states.

        Args:
            state_history: Records exposing ``places``, ``objects`` and
                ``activities`` as iterables of strings.
            k: Unused; kept so existing callers keep working.

        Returns:
            An event log with one line per record, followed by the question.
        """
        rec_strings = ["Event log:"]
        rec_strings.extend(
            f"Places: {', '.join(rec.places)}. "
            f"Objects: {', '.join(rec.objects)}. "
            f"Activities: {', '.join(rec.activities)} "
            for rec in state_history
        )
        question = "How would you summarize these events in a few full sentences? "
        return "\n".join(rec_strings) + "\n" + question