Spaces:
Sleeping
Sleeping
File size: 4,985 Bytes
14dc68f 73eedaf 14dc68f 73eedaf 14dc68f 73eedaf 14dc68f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import importlib
import logging
import re
from typing import Dict, List
import openai
import weaviate
from weaviate.embedded import EmbeddedOptions
# default opt out of chromadb telemetry.
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModel
import torch
import numpy
# Hugging Face model used to embed task results and queries locally.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
# Load the tokenizer and model once at import time; the storage class below
# reuses these module-level objects for every embedding it computes.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Opt out of chromadb's anonymized telemetry by default. The top-level imports
# only bring in `chromadb.config.Settings`, so the `chromadb` module itself must
# be imported here; without it the next line raises NameError at import time.
import chromadb  # noqa: E402
# NOTE(review): this module-level client appears unused by this module;
# confirm whether it can be dropped.
client = chromadb.Client(Settings(anonymized_telemetry=False))
def can_import(module_name):
    """Report whether *module_name* can be imported in this environment.

    The module is actually imported (not merely located), so any ImportError
    raised while executing it also yields False.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    else:
        return True
# Fail fast with an actionable message when the Weaviate client package is
# missing. An explicit raise (rather than `assert`) keeps the check active under
# `python -O`; ImportError is also the exception type can_import() catches.
# The trailing "\033[0m" resets the red/bold terminal styling started at the
# beginning of the message.
# NOTE(review): `import weaviate` at the top of the file fails first when the
# package is absent, so this check is effectively a backstop.
if not can_import("weaviate"):
    raise ImportError(
        "\033[91m\033[1m"
        + "Weaviate storage requires package weaviate-client.\nInstall: pip install -r extensions/requirements.txt"
        + "\033[0m"
    )
def create_client(
    weaviate_url: str, weaviate_api_key: str, weaviate_use_embedded: bool
):
    """Build a Weaviate client, either embedded or pointed at a server.

    Args:
        weaviate_url: server URL; ignored when running embedded.
        weaviate_api_key: API key for the server; when empty/None, the client
            connects without authentication.
        weaviate_use_embedded: if true, start an embedded instance instead of
            connecting to ``weaviate_url``.

    Returns:
        A ``weaviate.Client``.
    """
    if not weaviate_use_embedded:
        credentials = None
        if weaviate_api_key:
            credentials = weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
        return weaviate.Client(weaviate_url, auth_client_secret=credentials)
    return weaviate.Client(embedded_options=EmbeddedOptions())
class WeaviateResultsStorage:
    """Store task results in a Weaviate class and retrieve them by relevance.

    Embeddings are computed locally with the module-level ``tokenizer`` and
    ``model``, then handed to Weaviate together with each stored object.
    """

    # Property template for the Weaviate class. create_schema adds the "class"
    # name on a per-instance copy, so this shared dict is never modified.
    schema = {
        "properties": [
            {"name": "result_id", "dataType": ["string"]},
            {"name": "task", "dataType": ["string"]},
            {"name": "result", "dataType": ["text"]},
        ]
    }

    # Allowed index names: a leading capital letter, then letters, digits or
    # underscores. Compiled once instead of on every create_schema call.
    _VALID_CLASS_NAME = re.compile(r"^[A-Z][a-zA-Z0-9_]*$")

    def __init__(
        self,
        openai_api_key: str,
        weaviate_url: str,
        weaviate_api_key: str,
        weaviate_use_embedded: bool,
        llm_model: str,
        llama_model_path: str,
        results_store_name: str,
        objective: str,
    ):
        """Connect to Weaviate and create (or reuse) the results index.

        Args:
            openai_api_key: assigned to the global ``openai.api_key``.
            weaviate_url: server URL, used when not running embedded.
            weaviate_api_key: optional API key for the server.
            weaviate_use_embedded: start an embedded Weaviate instead of
                connecting to ``weaviate_url``.
            llm_model: stored on the instance; not used by this class.
            llama_model_path: stored on the instance; not used by this class.
            results_store_name: Weaviate class name (see create_schema).
            objective: not used by this class.

        Raises:
            ValueError: if ``results_store_name`` is not a valid class name.
        """
        openai.api_key = openai_api_key
        self.client = create_client(
            weaviate_url, weaviate_api_key, weaviate_use_embedded
        )
        self.index_name = None
        self.create_schema(results_store_name)
        self.llm_model = llm_model
        self.llama_model_path = llama_model_path

    def create_schema(self, results_store_name: str):
        """Ensure a Weaviate class named *results_store_name* exists and select it.

        An existing class with a matching schema is reused; otherwise it is
        created. On success ``self.index_name`` is set to the name.

        Raises:
            ValueError: if the name does not match ``_VALID_CLASS_NAME``.
        """
        if not self._VALID_CLASS_NAME.match(results_store_name):
            raise ValueError(
                f"Invalid index name: {results_store_name}. "
                "Index names must start with a capital letter and "
                "contain only alphanumeric characters and underscores."
            )
        # Build a new dict instead of writing "class" into the class-level
        # template, which would leak the name across every instance.
        self.schema = {**self.schema, "class": results_store_name}
        if self.client.schema.contains(self.schema):
            logging.info(
                f"Index named {results_store_name} already exists. Reusing it."
            )
        else:
            logging.info(f"Creating index named {results_store_name}")
            self.client.schema.create_class(self.schema)
        self.index_name = results_store_name

    def add(
        self,
        task: Dict,
        result: Dict,
        result_id: int,
        vector: "List | None" = None,
    ):
        """Store *result* for *task* in the current index.

        Args:
            task: task record; only ``task["task_name"]`` is stored.
            result: result text; it is embedded with get_embedding (so it must
                support ``str.replace``) and stored verbatim.
            result_id: value stored in the ``result_id`` property.
            vector: ignored. The embedding is always recomputed from *result*;
                the parameter is optional so callers may omit it.
        """
        embedding = self.get_embedding(result)
        data_object = {
            "result_id": result_id,
            "task": task["task_name"],
            "result": result,
        }
        with self.client.batch as batch:
            batch.add_data_object(
                data_object=data_object, class_name=self.index_name, vector=embedding
            )

    def query(self, query: str, top_results_num: int) -> List[str]:
        """Return the task names of up to *top_results_num* relevant results.

        Runs a Weaviate hybrid search (alpha=0.5) combining the raw *query*
        text with its locally computed embedding.
        """
        query_embedding = self.get_embedding(query)
        results = (
            self.client.query.get(self.index_name, ["task"])
            .with_hybrid(query=query, alpha=0.5, vector=query_embedding)
            .with_limit(top_results_num)
            .do()
        )
        return self._extract_tasks(results)

    def _extract_tasks(self, data):
        """Pull the ``task`` values out of a raw GraphQL ``Get`` response.

        Missing or null levels in the response yield an empty list. Errors
        reported under the ``errors`` key are logged rather than dropped.
        """
        errors = data.get("errors")
        if errors:
            logging.warning(
                f"Weaviate query on {self.index_name} returned errors: {errors}"
            )
        # `or` guards against keys that are present but explicitly null.
        get_section = (data.get("data") or {}).get("Get") or {}
        task_data = get_section.get(self.index_name) or []
        return [item["task"] for item in task_data]

    def get_embedding(self, text: str) -> list:
        """Embed *text* with the module-level transformers model.

        Newlines are replaced by spaces, and the final hidden state of the
        first ([CLS]) token is returned as a plain list of floats.
        """
        text = text.replace("\n", " ")
        # truncation=True caps the input at the tokenizer's configured maximum
        # length; without it, overly long results fail inside the model.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # Output of the first ([CLS]) token of the single input sequence.
        # NOTE(review): the all-MiniLM-L6-v2 model card recommends mean pooling
        # over tokens; switching would change every vector, so existing indexes
        # would need re-embedding.
        return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy().tolist()
|