# src/text_embedding.py
import torch
from transformers import AutoTokenizer, AutoModel
import os
class TextExtractor:
def __init__(self, model_name, proxy=None):
"""
Initialize the TextExtractor with a specified model and optional proxy settings.
Parameters:
- model_name (str): The name of the pre-trained model to load from HuggingFace Hub.
- proxy (str, optional): The proxy address to use for HTTP and HTTPS requests.
"""
# if proxy is None:
# proxy = 'http://localhost:8234'
# if proxy:
# os.environ['HTTP_PROXY'] = proxy
# os.environ['HTTPS_PROXY'] = proxy
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
        except Exception:
            # Fall back to the local cache if the Hub cannot be reached.
            print('Falling back to local_files_only=True')
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_name, local_files_only=True)
self.model.eval()
def extract(self, sentences):
"""
Extract sentence embeddings for the provided sentences.
Parameters:
- sentences (list of str): A list of sentences to extract embeddings for.
Returns:
- torch.Tensor: The normalized sentence embeddings.
"""
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = self.model(**encoded_input)
sentence_embeddings = model_output[0][:, 0]
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
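# Usage sketch (model name matches the example in __main__ below):
#   extractor = TextExtractor('BAAI/bge-small-zh-v1.5')
#   vecs = extractor.extract(["样例数据-1", "样例数据-2"])  # L2-normalized tensor, shape (2, hidden_dim)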
import pandas as pd
def get_qas(excel_file=None):
    default_excel_file = 'data/output_fixid.xlsx'
    if excel_file is None:
        excel_file = default_excel_file
    # Read the Excel file
df = pd.read_excel(excel_file)
df = df[df["question"].notna()]
df = df[df["summary"].notna()]
datas = []
    # Iterate over each row of the DataFrame
for index, row in df.iterrows():
id = row['id']
question = row['question']
short_answer = row['summary']
category = row['category']
texts = [question, short_answer]
data_value = {
"texts":texts,
}
data = {
"id":id,
"value":data_value
}
datas.append(data)
return datas
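# Shape of each item returned by get_qas() (sketch, fields as read above):
#   {"id": <row id>, "value": {"texts": [<question>, <summary>]}}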
from tqdm import tqdm
def extract_embedding(datas, text_extractor):
    """
    Extract embeddings for each item in the provided data.
    Parameters:
    - datas (list of dict): A list of dictionaries containing text data.
    - text_extractor (TextExtractor): The extractor used to compute the embeddings.
    Returns:
    - list of dict: The input data with added embeddings.
    """
for data in tqdm(datas):
texts = data["value"]["texts"]
text = "。".join(texts)
embeddings = text_extractor.extract(text)
embeddings_list = embeddings.tolist() # Convert tensor to list of lists
data["value"]["embedding"] = embeddings_list
return datas
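# Note: each item's texts are joined into a single string before extraction, so the
# stored embedding is a nested list of shape (1, hidden_dim).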
def save_parquet(datas, file_path):
"""
Save the provided data to a Parquet file.
Parameters:
- datas (list of dict): A list of dictionaries containing text data and embeddings.
- file_path (str): The path to the output Parquet file.
"""
# Flatten the data for easier conversion to DataFrame
flattened_data = []
for data in datas:
id = data["id"]
texts = data["value"]["texts"]
text = "。".join(texts)
embedding = data["value"]["embedding"]
flattened_data.append({
"id": id,
"text": text,
"embedding": embedding
})
# Create DataFrame
df = pd.DataFrame(flattened_data)
# Save DataFrame to Parquet
df.to_parquet(file_path, index=False)
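# The resulting Parquet file has columns "id", "text", and "embedding"; get_id2embedding()
# below reads it back and unwraps the (1, hidden_dim) embedding lists.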
def get_id2embedding(regen=False, parquet_file='data/qa_with_embedding.parquet'):
    """
    Get a dictionary mapping IDs to embeddings. Regenerate embeddings if specified.
    Parameters:
    - regen (bool): Whether to regenerate embeddings.
    - parquet_file (str): The path to the Parquet file.
    Returns:
    - dict: A dictionary mapping IDs to list of float embeddings.
    """
    if regen or not os.path.exists(parquet_file):
        print("Regenerating embeddings...")
        model_name = 'BAAI/bge-small-zh-v1.5'
        text_extractor = TextExtractor(model_name)
datas = get_qas()
print("Extracting embeddings for", len(datas), "data items")
datas = extract_embedding(datas, text_extractor)
save_parquet(datas, parquet_file)
df = pd.read_parquet(parquet_file)
id2embedding = {}
for index, row in df.iterrows():
id = row['id']
embedding = row['embedding']
id2embedding[id] = embedding[0]
return id2embedding
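# Usage sketch: id2embedding = get_id2embedding(regen=False) maps each row id to a
# hidden_dim-length list of floats, rebuilding the Parquet file only if it is missing.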
from sklearn.metrics.pairwise import cosine_similarity
import heapq
def __get_id2top30map(id2embedding):
"""
Get a dictionary mapping IDs to their top 30 nearest neighbors based on cosine similarity.
Parameters:
- id2embedding (dict): A dictionary mapping IDs to list of float embeddings.
Returns:
- dict: A dictionary mapping each ID to a list of the top 30 nearest neighbor IDs.
"""
ids = list(id2embedding.keys())
embeddings = torch.tensor([id2embedding[id] for id in ids])
# Compute cosine similarity matrix
cos_sim_matrix = cosine_similarity(embeddings)
id2top30map = {}
for i, id in enumerate(ids):
# Get the similarity scores for the current ID
sim_scores = cos_sim_matrix[i]
        # Take the top 31 indices so the current ID (self-similarity is always the maximum) can be dropped
        top_indices = heapq.nlargest(31, range(len(sim_scores)), key=lambda x: sim_scores[x])
        top_indices.remove(i)  # Remove the index of the current ID
        # Map the remaining indices back to IDs, keeping at most 30
        top_30_ids = [ids[idx] for idx in top_indices[:30]]
id2top30map[id] = top_30_ids
return id2top30map
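# Result sketch: {id: [top-1 neighbor id, ..., top-30 neighbor id], ...}, ordered by
# descending cosine similarity and excluding the id itself.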
import pickle
def get_id2top30map(id2embedding=None):
    default_save_pkl = "data/id2top30map.pkl"
if id2embedding is None:
if os.path.exists(default_save_pkl):
with open(default_save_pkl, 'rb') as f:
id2top30map = pickle.load(f)
else:
print("No embedding found, generating new one...")
id2embedding = get_id2embedding(regen=False)
id2top30map = __get_id2top30map(id2embedding)
with open(default_save_pkl, 'wb') as f:
pickle.dump(id2top30map, f)
else:
id2top30map = __get_id2top30map(id2embedding)
return id2top30map
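# Usage sketch: get_id2top30map() loads data/id2top30map.pkl if present; otherwise it builds
# the mapping from get_id2embedding() and caches it to that pickle. Passing an explicit
# id2embedding rebuilds the mapping without touching the cache.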
if __name__ == '__main__':
if False:
# Example usage:
model_name = 'BAAI/bge-small-zh-v1.5'
sentences = ["样例数据-1", "样例数据-2"]
text_extractor = TextExtractor(model_name)
embeddings = text_extractor.extract(sentences)
print("Sentence embeddings:", embeddings)
datas = get_qas()
print("extract embedding for ", len(datas), " datas")
datas = extract_embedding(datas, text_extractor )
default_parquet_save_name = "data/qa_with_embedding.parquet"
save_parquet(datas, default_parquet_save_name)
if True:
id2embedding = get_id2embedding(regen=False)
print(len(id2embedding[4]))
        id2top30map = get_id2top30map(None)
print("ID to Top 30 Neighbors dictionary:", id2top30map[4])
    if True:
        # Breadth-first expansion over the nearest-neighbor graph, starting from one QA id.
        start_id = 332
        visited_ids = [start_id]
        current_queue = [start_id]
        expand_num = 5
        for iteration in range(10):
            current_node = current_queue.pop(0)
            top30 = id2top30map[current_node]
            current_expand = []
            for id in top30:
                if id not in visited_ids:
                    visited_ids.append(id)
                    current_queue.append(id)
                    current_expand.append(id)
                if len(current_expand) >= expand_num:
                    break
            display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expand])
            print(display_text)
from get_qa_and_image import get_qa_and_image
image_datas = get_qa_and_image()
id2index = {}
for i, data in enumerate(image_datas):
id2index[data['id']] = i
indexes = [id2index[i] for i in visited_ids if i in id2index]
image_names = [image_datas[index]['value']['image'] for index in indexes]
        target_copy_folder = "data/asso_collection"
        import shutil
        # Make sure the destination folder exists, then copy each associated image into it.
        os.makedirs(target_copy_folder, exist_ok=True)
        for image_name in image_names:
            shutil.copy(image_name, target_copy_folder)