# Spaces: Runtime error
import os | |
import cohere | |
import gradio as gr | |
import numpy as np | |
import pinecone | |
import torch | |
from transformers import AutoModel, AutoTokenizer | |
# Cohere client; the API key is read from the COHERE_API environment variable
# (an empty string is passed when the variable is unset).
co = cohere.Client(os.getenv('COHERE_API', ''))

# Configure the Pinecone SDK from PINECONE_API / PINECONE_ENV, again falling
# back to empty strings when a variable is missing.
pinecone.init(
    api_key=os.getenv('PINECONE_API', ''),
    environment=os.getenv('PINECONE_ENV', ''),
)

# Disabled SGPT setup, kept for reference alongside the commented-out search
# code in query():
# model = AutoModel.from_pretrained('monsoon-nlp/gpt-nyc')
# tokenizer = AutoTokenizer.from_pretrained('monsoon-nlp/gpt-nyc')
# zos = np.zeros(4096-1024).tolist()
def list_me(matches):
    """Render search matches as HTML ``<li>`` items linking to r/AskNYC.

    Args:
        matches: iterable of match records. Each record is indexed with
            ``'id'`` and ``'metadata'``; the metadata must contain
            ``'question'`` and may contain ``'body'``.

    Returns:
        All ``<li>`` elements concatenated into one string (``''`` when
        there are no matches).
    """
    import html  # local import so the block does not rely on file-level imports

    items = []
    for match in matches:
        metadata = match['metadata']

        # The '/mini' -> '/' rewrite is applied to the link only, so it can no
        # longer corrupt question/body text that happens to contain '/mini'.
        # NOTE(review): presumably ids stored in the "mini" namespace carry a
        # 'mini' prefix that is not part of the Reddit id -- confirm.
        url = ('https://reddit.com/r/AskNYC/comments/' + match['id']).replace('/mini', '/')

        # Question and body are external text: escape them so they cannot
        # inject markup into the page.
        parts = [
            '<li><a target="_blank" href="' + html.escape(url) + '">',
            html.escape(metadata['question']),
            '</a>',
        ]
        if 'body' in metadata:
            parts.append('<br/>' + html.escape(metadata['body']))
        parts.append('</li>')
        items.append(''.join(parts))

    return ''.join(items)
def query(question):
    """Return an HTML fragment listing the stored threads closest to *question*.

    The question is embedded with Cohere's 'large' model, the resulting
    vector is used to query the "gptnyc" Pinecone index for the top two
    matches (with metadata), and those matches are rendered by list_me()
    under a "Cohere" heading.
    """
    # Cohere search
    embedded = co.embed(model='large', texts=[question])

    index = pinecone.Index("gptnyc")
    hits = index.query(
        vector=embedded.embeddings[0],
        top_k=2,
        include_metadata=True,
    )

    # SGPT search (disabled; kept for reference)
    # batch_tokens = tokenizer(
    #     [question],
    #     padding=True,
    #     truncation=True,
    #     return_tensors="pt"
    # )
    # with torch.no_grad():
    #     last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state
    # weights = (
    #     torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
    #     .unsqueeze(0)
    #     .unsqueeze(-1)
    #     .expand(last_hidden_state.size())
    #     .float().to(last_hidden_state.device)
    # )
    # input_mask_expanded = (
    #     batch_tokens["attention_mask"]
    #     .unsqueeze(-1)
    #     .expand(last_hidden_state.size())
    #     .float()
    # )
    # sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    # sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
    # embeddings = sum_embeddings / sum_mask
    # closest_sgpt = index.query(
    #     top_k=2,
    #     include_metadata=True,
    #     namespace="mini",
    #     vector=embeddings[0].tolist() + zos,
    # )

    cohere_section = '<h3>Cohere</h3><ul>' + list_me(hits['matches']) + '</ul>'
    # Disabled second section:
    # '<h3>SGPT</h3><ul>' + list_me(closest_sgpt['matches']) + '</ul>'
    return cohere_section
# Gradio UI: one text input passed to query(), whose return value is shown as HTML.
iface = gr.Interface(fn=query, inputs="text", outputs="html")

# Start serving the interface.
iface.launch()