|
|
|
|
|
import os |
|
import pickle |
|
import torch |
|
import re |
|
|
|
from typing import List |
|
from datetime import datetime, timedelta |
|
from enum import Enum |
|
from sentence_transformers import util |
|
from fastapi import APIRouter |
|
from fastapi.responses import PlainTextResponse |
|
|
|
try:
    # Normal case: module is imported as part of the `routers` package.
    from .rag import EMBEDDING_CTX
    from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
except ImportError:
    # Fallback for running this file directly as a script (no package context).
    # Narrowed from a bare `except:` so real errors inside `rag`/`utils_gitea`
    # (or KeyboardInterrupt) are not silently swallowed.
    from rag import EMBEDDING_CTX
    from utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
|
|
|
|
|
router = APIRouter()

# Only these issue fields are requested from the Gitea API; everything else
# is dropped to keep the fetched payloads (and the pickle cache) small.
issue_attr_filter = {'number', 'title', 'body',
                     'state', 'updated_at', 'created_at'}
|
|
|
|
|
class State(str, Enum):
    """Issue-state filter for similarity queries.

    The ``str`` mixin lets FastAPI parse the value directly from a query
    parameter and lets ``state.value`` index the per-repo boolean masks.
    """
    opened = "opened"
    closed = "closed"
    all = "all"
|
|
|
|
|
class _Data(dict):
    """Per-repository store of issue titles, states and embeddings.

    Maps a repo name to a dict with keys:
      - ``updated_at``:  latest issue ``updated_at`` string seen, or None
      - ``arrays_size``: allocated capacity of the index-by-issue-number arrays
      - ``titles``:      list indexed by issue number (None = issue not seen)
      - ``embeddings``:  tensor of embeddings indexed by issue number
      - ``opened`` / ``closed``: boolean masks indexed by issue number
    """
    # Pickled cache of previously computed embeddings (relative to the app cwd).
    cache_path = "routers/rag/embeddings_issues.pkl"

    @staticmethod
    def _create_issue_string(title, body):
        """Return the text to embed: title + body with report boilerplate removed."""
        cleaned_body = body.replace('\r', '')
        cleaned_body = cleaned_body.replace('**System Information**\n', '')
        cleaned_body = cleaned_body.replace('**Blender Version**\n', '')
        cleaned_body = cleaned_body.replace(
            'Worked: (newest version of Blender that worked as expected)\n', '')
        cleaned_body = cleaned_body.replace(
            '**Short description of error**\n', '')
        cleaned_body = cleaned_body.replace('**Addon Information**\n', '')
        cleaned_body = cleaned_body.replace(
            '**Exact steps for others to reproduce the error**\n', '')
        cleaned_body = cleaned_body.replace(
            '[Please describe the exact steps needed to reproduce the issue]\n', '')
        cleaned_body = cleaned_body.replace(
            '[Please fill out a short description of the error here]\n', '')
        cleaned_body = cleaned_body.replace(
            '[Based on the default startup or an attached .blend file (as simple as possible)]\n', '')
        # Strip build metadata (branch, commit date, hash) emitted by the bug reporter.
        cleaned_body = re.sub(
            r', branch: .+?, commit date: \d{4}-\d{2}-\d{2} \d{2}:\d{2}, hash: `.+?`', '', cleaned_body)
        # Collapse attachment paths to the literal word 'attachment'.
        cleaned_body = re.sub(
            r'\/?attachments\/[a-zA-Z0-9\-]+', 'attachment', cleaned_body)
        # Reduce any remaining URL to its last path component.
        cleaned_body = re.sub(
            r'https?:\/\/[^\s/]+(?:\/[^\s/]+)*\/([^\s/]+)', lambda match: match.group(1), cleaned_body)

        return title + '\n' + cleaned_body

    @staticmethod
    def _find_latest_date(issues, default_str=None):
        """Return the most recent 'updated_at' among `issues`, or `default_str`.

        `max(..., default=...)` already handles the empty case, so no separate
        emptiness check is needed.
        """
        return max((issue['updated_at'] for issue in issues), default=default_str)

    @classmethod
    def _create_strings_to_embbed(cls, issues):
        """Build the embedding input string for every issue in `issues`."""
        texts_to_embed = [cls._create_issue_string(
            issue['title'], issue['body']) for issue in issues]

        return texts_to_embed

    def _data_ensure_size(self, repo, size_new):
        """Grow the per-repo arrays so issue number `size_new` fits.

        Arrays are allocated in ARRAY_CHUNK_SIZE chunks with existing contents
        preserved. No-op when the current allocation already suffices.
        """
        ARRAY_CHUNK_SIZE = 4096

        updated_at_old = None
        arrays_size_old = 0
        titles_old = []
        try:
            arrays_size_old = self[repo]['arrays_size']
            if size_new <= arrays_size_old:
                return
            updated_at_old = self[repo]['updated_at']
            titles_old = self[repo]['titles']
        except KeyError:
            # First allocation for this repo — keep the defaults above.
            pass

        # Round up to the next chunk boundary. The +1 chunk guarantees
        # arrays_size_new > size_new, so indexing by issue number stays in range.
        arrays_size_new = ARRAY_CHUNK_SIZE * \
            (int(size_new / ARRAY_CHUNK_SIZE) + 1)

        data_new = {
            'updated_at': updated_at_old,
            'arrays_size': arrays_size_new,
            'titles': titles_old + [None] * (arrays_size_new - arrays_size_old),
            'embeddings': torch.empty((arrays_size_new, *EMBEDDING_CTX.embedding_shape),
                                      dtype=EMBEDDING_CTX.embedding_dtype,
                                      device=EMBEDDING_CTX.embedding_device),
            'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
            'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
        }

        try:
            # Carry over whatever was already computed for this repo.
            data_new['embeddings'][:arrays_size_old] = self[repo]['embeddings']
            data_new['opened'][:arrays_size_old] = self[repo]['opened']
            data_new['closed'][:arrays_size_old] = self[repo]['closed']
        except KeyError:
            # Nothing to copy on the first allocation.
            pass

        self[repo] = data_new

    def _embeddings_generate(self, repo):
        """Populate self[repo] from the pickle cache or by embedding all issues."""
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as file:
                data = pickle.load(file)
                self.update(data)
                if repo in self:
                    return

        issues = gitea_fetch_issues('blender', repo, state='all', since=None,
                                    issue_attr_filter=issue_attr_filter)

        print("Embedding Issues...")
        texts_to_embed = self._create_strings_to_embbed(issues)
        embeddings = EMBEDDING_CTX.encode(texts_to_embed)

        # NOTE(review): sizing by issues[0] assumes the fetch returns the
        # highest issue number first — confirm against gitea_fetch_issues.
        self._data_ensure_size(repo, int(issues[0]['number']))
        self[repo]['updated_at'] = self._find_latest_date(issues)

        titles = self[repo]['titles']
        embeddings_new = self[repo]['embeddings']
        opened = self[repo]['opened']
        closed = self[repo]['closed']

        for i, issue in enumerate(issues):
            number = int(issue['number'])
            titles[number] = issue['title']
            embeddings_new[number] = embeddings[i]
            if issue['state'] == 'open':
                opened[number] = True
            if issue['state'] == 'closed':
                closed[number] = True

    def _embeddings_updated_get(self, repo):
        """Return self[repo], re-embedding any issues updated since the last call."""
        with EMBEDDING_CTX.lock:
            if repo not in self:
                self._embeddings_generate(repo)

            date_old = self[repo]['updated_at']

            issues = gitea_fetch_issues(
                'blender', repo, since=date_old, issue_attr_filter=issue_attr_filter)

            date_new = self._find_latest_date(issues, date_old)

            if date_new == date_old:
                # Nothing newer than what we already have.
                return self[repo]

            self[repo]['updated_at'] = date_new

            # Drop issues stamped exactly with the old date: already processed
            # (the `since` fetch is inclusive of that timestamp).
            issues = [issue for issue in issues if issue['updated_at'] != date_old]

            self._data_ensure_size(repo, int(issues[0]['number']))

            updated_at = gitea_issues_body_updated_at_get(issues)
            issues_to_embed = []

            for i, issue in enumerate(issues):
                number = int(issue['number'])
                self[repo]['opened'][number] = issue['state'] == 'open'
                self[repo]['closed'][number] = issue['state'] == 'closed'

                title_old = self[repo]['titles'][number]
                if title_old != issue['title']:
                    self[repo]['titles'][number] = issue['title']
                    issues_to_embed.append(issue)
                elif not updated_at or updated_at[i] >= date_old:
                    # Body changed since date_old (or body-change dates are
                    # unavailable): re-embed to be safe.
                    issues_to_embed.append(issue)

            if issues_to_embed:
                print(f"Embedding {len(issues_to_embed)} issue{'s' if len(issues_to_embed) > 1 else ''}")
                texts_to_embed = self._create_strings_to_embbed(issues_to_embed)
                embeddings = EMBEDDING_CTX.encode(texts_to_embed)

                for i, issue in enumerate(issues_to_embed):
                    number = int(issue['number'])
                    self[repo]['embeddings'][number] = embeddings[i]

        return self[repo]

    def _sort_similarity(self,
                         repo: str,
                         query_emb: List[torch.Tensor],
                         limit: int,
                         state: State = State.opened) -> list:
        """Return up to `limit` "#number: title" lines most similar to `query_emb`.

        Closed issues are wrapped in '~~' (markdown strikethrough). Only issues
        matching `state` are searched.
        """
        duplicates = []

        data = self[repo]
        embeddings = data['embeddings']
        mask_opened = data["opened"]

        if state == State.all:
            mask = mask_opened | data["closed"]
        else:
            mask = data[state.value]

        # Search only rows selected by the mask; `true_indices` maps the
        # compacted corpus ids back to real issue numbers.
        embeddings = embeddings[mask]
        true_indices = mask.nonzero(as_tuple=True)[0]

        ret = util.semantic_search(
            query_emb, embeddings, top_k=limit, score_function=util.dot_score)

        for score in ret[0]:
            corpus_id = score['corpus_id']
            number = true_indices[corpus_id].item()
            closed_char = "" if mask_opened[number] else "~~"
            text = f"{closed_char}#{number}{closed_char}: {data['titles'][number]}"
            duplicates.append(text)

        return duplicates

    def find_relatedness(self, repo: str, number: int, limit: int = 20, state: State = State.opened):
        """Return a newline-joined list of issues related to `number`.

        The queried issue itself is removed from the results when it appears
        first. Returns '' when nothing matches.
        """
        data = self._embeddings_updated_get(repo)

        # Use the cached embedding when the issue is known locally;
        # otherwise fetch the issue from Gitea and embed it on the fly.
        if data['titles'][number] is not None:
            new_embedding = data['embeddings'][number]
        else:
            gitea_issue = gitea_json_issue_get('blender', repo, number)
            text_to_embed = self._create_issue_string(
                gitea_issue['title'], gitea_issue['body'])

            new_embedding = EMBEDDING_CTX.encode([text_to_embed])

        duplicates = self._sort_similarity(
            repo, new_embedding, limit=limit, state=state)

        if not duplicates:
            return ''

        # The query issue usually ranks first; drop it from its own results.
        if match := re.search(r'(~~)?#(\d+)(~~)?:', duplicates[0]):
            number_cached = int(match.group(2))
            if number_cached == number:
                return '\n'.join(duplicates[1:])

        return '\n'.join(duplicates)
|
|
|
|
|
# Module-level singleton: caches embeddings across requests (and processes,
# via the pickle file at _Data.cache_path).
G_data = _Data()
|
|
|
|
|
@router.get("/find_related/{repo}/{number}", response_class=PlainTextResponse)
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened) -> str:
    """Return a plain-text, newline-separated list of issues related to `number`."""
    return G_data.find_relatedness(repo, number, limit=limit, state=state)
|
|
|
|
|
if __name__ == "__main__":
    # Standalone mode: optionally refresh the on-disk embedding cache,
    # then run two sample relatedness queries.
    update_cache = True
    if update_cache:
        # Bring both repos up to date before persisting.
        G_data._embeddings_updated_get('blender')
        G_data._embeddings_updated_get('blender-addons')
        with open(G_data.cache_path, "wb") as file:

            # Move tensors to CPU first so the pickle loads on GPU-less hosts.
            for val in G_data.values():
                val['embeddings'] = val['embeddings'].to(torch.device('cpu'))

            pickle.dump(dict(G_data), file, protocol=pickle.HIGHEST_PROTOCOL)

        # Move embeddings back for the queries below.
        # NOTE(review): hard-codes 'cuda' — fails on CPU-only machines; confirm intended.
        for val in G_data.values():
            val['embeddings'] = val['embeddings'].to(torch.device('cuda'))

    related1 = G_data.find_relatedness(
        'blender', 111434, limit=20, state=State.all)
    related2 = G_data.find_relatedness('blender-addons', 104399, limit=20)

    print("These are the 20 most related issues:")
    print(related1)
    print()
    print("These are the 20 most related issues:")
    print(related2)
|
|