Germano Cavalcante committed
Commit ed15883 · 1 Parent(s): 5def575

New tool Wiki Search

Searches the manual or any other embedded information to find
context for a query.
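The new route is mounted under the /api/v1 prefix alongside the other tools. A minimal sketch of calling it from a client once the server is running; the host and port are assumptions, not part of this commit:

    # Minimal sketch: query the new Wiki Search endpoint.
    # Host/port are assumptions; the route and "query" parameter come from the diff below.
    import requests

    resp = requests.get("http://localhost:8000/api/v1/wiki_search",
                        params={"query": "Set Snap Base"})
    print(resp.text)  # manual sections most related to the query, prefixed with BASE_URL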

main.py CHANGED
@@ -6,7 +6,7 @@ from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
 from huggingface_hub import login
 from config import settings
-from routers import tool_bpy_doc, tool_gpu_checker, tool_calls, tool_find_related
+from routers import tool_bpy_doc, tool_gpu_checker, tool_calls, tool_find_related, tool_wiki_search
 
 login(settings.huggingface_key)
 
@@ -30,6 +30,9 @@ app.include_router(
 app.include_router(
     tool_find_related.router, prefix="/api/v1", tags=["Tools"])
 
+app.include_router(
+    tool_wiki_search.router, prefix="/api/v1", tags=["Tools"])
+
 app.include_router(
     tool_calls.router, prefix="/api/v1", tags=["Function Calls"])
 
routers/embedding/__init__.py ADDED
@@ -0,0 +1,115 @@
+# routers/embedding/__init__.py
+
+import os
+import re
+import sys
+import threading
+import torch
+from sentence_transformers import SentenceTransformer, util
+
+
+class EmbeddingContext:
+    # These don't change
+    TOKEN_LEN_MAX_FOR_EMBEDDING = 512
+
+    # Set when creating the object
+    lock = None
+    model = None
+    openai_client = None
+    model_name = ''
+    config_type = ''
+    embedding_shape = None
+    embedding_dtype = None
+    embedding_device = None
+
+    # Updates constantly
+    data = {}
+
+    def __init__(self):
+        try:
+            from config import settings
+        except ImportError:
+            sys.path.append(os.path.abspath(
+                os.path.join(os.path.dirname(__file__), '../..')))
+            from config import settings
+
+        self.lock = threading.Lock()
+        config_type = settings.embedding_api
+        model_name = settings.embedding_model
+
+        if config_type == 'sbert':
+            self.model = SentenceTransformer(model_name, use_auth_token=False)
+            self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
+            print("Max Sequence Length:", self.model.max_seq_length)
+
+            self.encode = self.encode_sbert
+            if torch.cuda.is_available():
+                self.model = self.model.to('cuda')
+
+        elif config_type == 'openai':
+            from openai import OpenAI
+            self.openai_client = OpenAI(
+                # base_url = settings.openai_api_base
+                api_key=settings.OPENAI_API_KEY,
+            )
+            self.encode = self.encode_openai
+
+        self.model_name = model_name
+        self.config_type = config_type
+
+        tmp = self.encode(['tmp'])
+        self.embedding_shape = tmp.shape[1:]
+        self.embedding_dtype = tmp.dtype
+        self.embedding_device = tmp.device
+
+    def encode(self, texts_to_embed):
+        pass
+
+    def encode_sbert(self, texts_to_embed):
+        return self.model.encode(texts_to_embed, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True)
+
+    def encode_openai(self, texts_to_embed):
+        import math
+        import time
+
+        tokens_count = 0
+        for text in texts_to_embed:
+            tokens_count += len(self.get_tokens(text))
+
+        chunks_num = math.ceil(tokens_count / 500000)
+        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)
+
+        embeddings = []
+        for i in range(chunks_num):
+            start = i * chunk_size
+            end = start + chunk_size
+            chunk = texts_to_embed[start:end]
+
+            embeddings_tmp = self.openai_client.embeddings.create(
+                model=self.model_name,
+                input=chunk,
+            ).data
+
+            if embeddings_tmp is None:
+                break
+
+            embeddings.extend(embeddings_tmp)
+
+            if i < chunks_num - 1:
+                time.sleep(60)  # Wait 1 minute before the next call
+
+        return torch.stack([torch.tensor(embedding.embedding, dtype=torch.float32) for embedding in embeddings])
+
+    def get_tokens(self, text):
+        if self.model:
+            return self.model.tokenizer.tokenize(text)
+
+        tokens = []
+        for token in re.split(r'(\W|\b)', text):
+            if token.strip():
+                tokens.append(token)
+
+        return tokens
+
+
+EMBEDDING_CTX = EmbeddingContext()
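EMBEDDING_CTX is created at import time, so every router shares one model and one lock. A minimal sketch of using it directly, assuming the 'sbert' backend and a valid config module (both assumptions about the runtime environment):

    # Minimal sketch (assumes settings.embedding_api == 'sbert' and a working config).
    from routers.embedding import EMBEDDING_CTX
    from sentence_transformers import util

    query_emb = EMBEDDING_CTX.encode(["How do I bisect an object?"])
    corpus_emb = EMBEDDING_CTX.encode(["Bisect Object", "Set Snap Base"])

    # Embeddings are normalized, so the dot score is equivalent to cosine similarity.
    hits = util.semantic_search(query_emb, corpus_emb, top_k=1,
                                score_function=util.dot_score)
    print(hits[0][0]["corpus_id"])  # index of the closest corpus text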
routers/{tool_find_related_cache.pkl → embedding/embeddings_issues.pkl} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a181cc69d535d6502588e4c14bea367d74dfaca17a5602a23a72def479f592cc
-size 723433353
+oid sha256:7c3c012a8f86440dacedd6f1e4e9ea9f41f096031c0ac1ed5cdf64a9a8d46e42
+size 723452942
routers/embedding/embeddings_manual.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ed7475fc8ffda0d9e9deb6480b7152b53657f0fe6a6140bcb60360e425e7a01
+size 18659241
routers/tool_calls.py CHANGED
@@ -8,10 +8,12 @@ try:
     from .tool_gpu_checker import gpu_checker_get_message
     from .tool_bpy_doc import bpy_doc_get_documentation
     from .tool_find_related import find_relatedness
+    from .tool_wiki_search import wiki_search
 except:
     from tool_gpu_checker import gpu_checker_get_message
     from tool_bpy_doc import bpy_doc_get_documentation
     from tool_find_related import find_relatedness
+    from tool_wiki_search import wiki_search
 
 
 class ToolCallFunction(BaseModel):
@@ -43,6 +45,8 @@ def process_tool_call(tool_call: ToolCallInput) -> Dict:
         elif function_name == "find_related":
             output["output"] = find_relatedness(
                 function_args["repo"], function_args["number"])
+        elif function_name == "wiki_search":
+            output["output"] = wiki_search(function_args["query"])
     except json.JSONDecodeError as e:
         error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {tool_call.function.arguments}"
         output["output"] = error_message
routers/tool_find_related.py CHANGED
@@ -1,22 +1,39 @@
-# find_related.py
+# routers/find_related.py
 
 import os
 import pickle
-import re
 import torch
-import threading
+import re
 
+from typing import List
 from datetime import datetime, timedelta
 from enum import Enum
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import util
 from fastapi import APIRouter
 
 try:
+    from .embedding import EMBEDDING_CTX
     from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
 except:
+    from embedding import EMBEDDING_CTX
     from utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
 
 
+router = APIRouter()
+
+issue_attr_filter = {'number', 'title', 'body',
+                     'state', 'updated_at', 'created_at'}
+
+G_cache_path = "routers/embedding/embeddings_issues.pkl"
+G_data = {}
+
+
+class State(str, Enum):
+    opened = "opened"
+    closed = "closed"
+    all = "all"
+
+
 def _create_issue_string(title, body):
     cleaned_body = body.replace('\r', '')
     cleaned_body = cleaned_body.replace('**System Information**\n', '')
@@ -51,283 +68,149 @@ def _find_latest_date(issues, default_str=None):
     return max((issue['updated_at'] for issue in issues), default=default_str)
 
 
-class EmbeddingContext:
-    # These don't change
-    TOKEN_LEN_MAX_FOR_EMBEDDING = 512
-    TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
-    ARRAY_CHUNK_SIZE = 4096
-    issue_attr_filter = {'number', 'title', 'body',
-                         'state', 'updated_at', 'created_at'}
-    cache_path = "routers/tool_find_related_cache.pkl"
-
-    # Set when creating the object
-    lock = None
-    model = None
-    openai_client = None
-    model_name = ''
-    config_type = ''
-    embedding_shape = None
-    embedding_dtype = None
-    embedding_device = None
-
-    # Updates constantly
-    data = {}
-    black_list = {'blender': {109399, 113157, 114706},
-                  'blender-addons': set()}
-
-    def __init__(self):
-        self.lock = threading.Lock()
-
-        try:
-            from config import settings
-        except:
-            import sys
-            sys.path.append(os.path.abspath(
-                os.path.join(os.path.dirname(__file__), '..')))
-            from config import settings
-
-        config_type = settings.embedding_api
-        model_name = settings.embedding_model
-
-        if config_type == 'sbert':
-            self.model = SentenceTransformer(model_name, use_auth_token=False)
-            self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
-            print("Max Sequence Length:", self.model.max_seq_length)
-
-            self.encode = self.encode_sbert
-            if torch.cuda.is_available():
-                self.model = self.model.to('cuda')
-
-        elif config_type == 'openai':
-            from openai import OpenAI
-            self.openai_client = OpenAI(
-                # base_url = settings.openai_api_base
-                api_key=settings.OPENAI_API_KEY,
-            )
-            self.encode = self.encode_openai
-
-        self.model_name = model_name
-        self.config_type = config_type
-
-        tmp = self.encode(['tmp'])
-        self.embedding_shape = tmp.shape[1:]
-        self.embedding_dtype = tmp.dtype
-        self.embedding_device = tmp.device
-
-    def encode(self, texts_to_embed):
-        pass
-
-    def encode_sbert(self, texts_to_embed):
-        return self.model.encode(texts_to_embed, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True)
-
-    def encode_openai(self, texts_to_embed):
-        import math
-        import time
-
-        tokens_count = 0
-        for text in texts_to_embed:
-            tokens_count += len(self.get_tokens(text))
-
-        chunks_num = math.ceil(tokens_count / 500000)
-        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)
-
-        embeddings = []
-        for i in range(chunks_num):
-            start = i * chunk_size
-            end = start + chunk_size
-            chunk = texts_to_embed[start:end]
-
-            embeddings_tmp = self.openai_client.embeddings.create(
-                model=self.model_name,
-                input=chunk,
-            ).data
-
-            if embeddings_tmp is None:
-                break
-
-            embeddings.extend(embeddings_tmp)
-
-            if i < chunks_num - 1:
-                time.sleep(60)  # Wait 1 minute before the next call
-
-        return torch.stack([torch.tensor(embedding.embedding, dtype=torch.float32) for embedding in embeddings])
-
-    def get_tokens(self, text):
-        if self.model:
-            return self.model.tokenizer.tokenize(text)
-
-        tokens = []
-        for token in re.split(r'(\W|\b)', text):
-            if token.strip():
-                tokens.append(token)
-
-        return tokens
-
-    def create_strings_to_embbed(self, issues, black_list):
-        texts_to_embed = [_create_issue_string(
-            issue['title'], issue['body']) for issue in issues]
-
-        # Create issue blacklist (for keepping track)
-        token_count = 0
-        for i, text in enumerate(texts_to_embed):
-            tokens = self.get_tokens(text)
-            tokens_len = len(tokens)
-            token_count += tokens_len
-
-            if tokens_len > self.TOKEN_LEN_MAX_BALCKLIST:
-                # Only use the first TOKEN_LEN_MAX tokens
-                black_list.add(int(issues[i]['number']))
-                if self.config_type == 'openai':
-                    texts_to_embed[i] = ' '.join(
-                        tokens[:self.TOKEN_LEN_MAX_BALCKLIST])
-
-        return texts_to_embed
-
-    def data_ensure_size(self, repo, size_new):
-        updated_at_old = None
-        arrays_size_old = 0
-        titles_old = []
-        try:
-            arrays_size_old = self.data[repo]['arrays_size']
-            if size_new <= arrays_size_old:
-                return
-        except:
-            pass
-
-        arrays_size_new = self.ARRAY_CHUNK_SIZE * \
-            (int(size_new / self.ARRAY_CHUNK_SIZE) + 1)
-
-        data_new = {
-            'updated_at': updated_at_old,
-            'arrays_size': arrays_size_new,
-            'titles': titles_old + [None] * (arrays_size_new - arrays_size_old),
-            'embeddings': torch.empty((arrays_size_new, *self.embedding_shape),
-                                      dtype=self.embedding_dtype,
-                                      device=self.embedding_device),
-            'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
-            'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
-        }
-
-        try:
-            data_new['embeddings'][:arrays_size_old] = self.data[repo]['embeddings']
-            data_new['opened'][:arrays_size_old] = self.data[repo]['opened']
-            data_new['closed'][:arrays_size_old] = self.data[repo]['closed']
-        except:
-            pass
-
-        self.data[repo] = data_new
-
-    def embeddings_generate(self, repo):
-        if os.path.exists(self.cache_path):
-            with open(self.cache_path, 'rb') as file:
-                self.data = pickle.load(file)
-                if repo in self.data:
-                    return
-
-        if not repo in self.black_list:
-            self.black_list[repo] = {}
-
-        black_list = self.black_list[repo]
-
-        issues = gitea_fetch_issues('blender', repo, state='all', since=None,
-                                    issue_attr_filter=self.issue_attr_filter, exclude=black_list)
-
-        # issues = sorted(issues, key=lambda issue: int(issue['number']))
-
-        print("Embedding Issues...")
-        texts_to_embed = self.create_strings_to_embbed(issues, black_list)
-        embeddings = self.encode(texts_to_embed)
-
-        self.data_ensure_size(repo, int(issues[0]['number']))
-        self.data[repo]['updated_at'] = _find_latest_date(issues)
-
-        titles = self.data[repo]['titles']
-        embeddings_new = self.data[repo]['embeddings']
-        opened = self.data[repo]['opened']
-        closed = self.data[repo]['closed']
-
-        for i, issue in enumerate(issues):
-            number = int(issue['number'])
-            titles[number] = issue['title']
-            embeddings_new[number] = embeddings[i]
-            if issue['state'] == 'open':
-                opened[number] = True
-            if issue['state'] == 'closed':
-                closed[number] = True
-
-    def embeddings_updated_get(self, repo):
-        with self.lock:
-            try:
-                data = self.data[repo]
-            except:
-                self.embeddings_generate(repo)
-                data = self.data[repo]
-
-            black_list = self.black_list[repo]
-            date_old = data['updated_at']
-
-            issues = gitea_fetch_issues(
-                'blender', repo, since=date_old, issue_attr_filter=self.issue_attr_filter, exclude=black_list)
-
-            # Get the most recent date
-            date_new = _find_latest_date(issues, date_old)
-
-            if date_new == date_old:
-                # Nothing changed
-                return data
-
-            data['updated_at'] = date_new
-
-            # autopep8: off
-            # Consider that if the time hasn't changed, it's the same issue.
-            issues = [issue for issue in issues if issue['updated_at'] != date_old]
-
-            self.data_ensure_size(repo, int(issues[0]['number']))
-
-            updated_at = gitea_issues_body_updated_at_get(issues)
-            issues_to_embed = []
-
-            for i, issue in enumerate(issues):
-                number = int(issue['number'])
-                if issue['state'] == 'open':
-                    data['opened'][number] = True
-                if issue['state'] == 'closed':
-                    data['closed'][number] = True
-
-                title_old = data['titles'][number]
-                if title_old != issue['title']:
-                    data['titles'][number] = issue['title']
-                    issues_to_embed.append(issue)
-                elif updated_at[i] >= date_old:
-                    issues_to_embed.append(issue)
-
-            if issues_to_embed:
-                texts_to_embed = self.create_strings_to_embbed(issues_to_embed, black_list)
-                embeddings = self.encode(texts_to_embed)
-
-                for i, issue in enumerate(issues_to_embed):
-                    number = int(issue['number'])
-                    data['embeddings'][number] = embeddings[i]
-
-            # autopep8: on
-            return data
-
-
-router = APIRouter()
-EMBEDDING_CTX = EmbeddingContext()
-# EMBEDDING_CTX.embeddings_generate('blender', 'blender')
-# EMBEDDING_CTX.embeddings_generate('blender', 'blender-addons')
-
-
-# Define your Enum class
-class State(str, Enum):
-    opened = "opened"
-    closed = "closed"
-    all = "all"
+def _create_strings_to_embbed(issues):
+    texts_to_embed = [_create_issue_string(
+        issue['title'], issue['body']) for issue in issues]
+
+    return texts_to_embed
+
+
+def _data_ensure_size(repo, size_new):
+    global G_data
+
+    ARRAY_CHUNK_SIZE = 4096
+
+    updated_at_old = None
+    arrays_size_old = 0
+    titles_old = []
+    try:
+        arrays_size_old = G_data[repo]['arrays_size']
+        if size_new <= arrays_size_old:
+            return
+    except:
+        pass
+
+    arrays_size_new = ARRAY_CHUNK_SIZE * (int(size_new / ARRAY_CHUNK_SIZE) + 1)
+
+    data_new = {
+        'updated_at': updated_at_old,
+        'arrays_size': arrays_size_new,
+        'titles': titles_old + [None] * (arrays_size_new - arrays_size_old),
+        'embeddings': torch.empty((arrays_size_new, *EMBEDDING_CTX.embedding_shape),
+                                  dtype=EMBEDDING_CTX.embedding_dtype,
+                                  device=EMBEDDING_CTX.embedding_device),
+        'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
+        'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
+    }
+
+    try:
+        data_new['embeddings'][:arrays_size_old] = G_data[repo]['embeddings']
+        data_new['opened'][:arrays_size_old] = G_data[repo]['opened']
+        data_new['closed'][:arrays_size_old] = G_data[repo]['closed']
+    except:
+        pass
+
+    G_data[repo] = data_new
+
+
+def _embeddings_generate(repo):
+    global G_data
+
+    if os.path.exists(G_cache_path):
+        with open(G_cache_path, 'rb') as file:
+            G_data = pickle.load(file)
+            if repo in G_data:
+                return
+
+    issues = gitea_fetch_issues('blender', repo, state='all', since=None,
+                                issue_attr_filter=issue_attr_filter)
+
+    # issues = sorted(issues, key=lambda issue: int(issue['number']))
+
+    print("Embedding Issues...")
+    texts_to_embed = _create_strings_to_embbed(issues)
+    embeddings = EMBEDDING_CTX.encode(texts_to_embed)
+
+    _data_ensure_size(repo, int(issues[0]['number']))
+    G_data[repo]['updated_at'] = _find_latest_date(issues)
+
+    titles = G_data[repo]['titles']
+    embeddings_new = G_data[repo]['embeddings']
+    opened = G_data[repo]['opened']
+    closed = G_data[repo]['closed']
+
+    for i, issue in enumerate(issues):
+        number = int(issue['number'])
+        titles[number] = issue['title']
+        embeddings_new[number] = embeddings[i]
+        if issue['state'] == 'open':
+            opened[number] = True
+        if issue['state'] == 'closed':
+            closed[number] = True
+
+
+def _embeddings_updated_get(repo):
+    global G_data
+
+    with EMBEDDING_CTX.lock:
+        try:
+            data_repo = G_data[repo]
+        except:
+            _embeddings_generate(repo)
+            data_repo = G_data[repo]
+
+        date_old = data_repo['updated_at']
+
+        issues = gitea_fetch_issues(
+            'blender', repo, since=date_old, issue_attr_filter=issue_attr_filter)
+
+        # Get the most recent date
+        date_new = _find_latest_date(issues, date_old)
+
+        if date_new == date_old:
+            # Nothing changed
+            return data_repo
+
+        data_repo['updated_at'] = date_new
+
+        # autopep8: off
+        # Consider that if the time hasn't changed, it's the same issue.
+        issues = [issue for issue in issues if issue['updated_at'] != date_old]
+
+        _data_ensure_size(repo, int(issues[0]['number']))
+
+        updated_at = gitea_issues_body_updated_at_get(issues)
+        issues_to_embed = []
+
+        for i, issue in enumerate(issues):
+            number = int(issue['number'])
+            if issue['state'] == 'open':
+                data_repo['opened'][number] = True
+            if issue['state'] == 'closed':
+                data_repo['closed'][number] = True
+
+            title_old = data_repo['titles'][number]
+            if title_old != issue['title']:
+                data_repo['titles'][number] = issue['title']
+                issues_to_embed.append(issue)
+            elif updated_at[i] >= date_old:
+                issues_to_embed.append(issue)
+
+        if issues_to_embed:
+            print(f"Embedding {len(issues_to_embed)} issue{'s' if len(issues_to_embed) > 1 else ''}")
+            texts_to_embed = _create_strings_to_embbed(issues_to_embed)
+            embeddings = EMBEDDING_CTX.encode(texts_to_embed)
+
+            for i, issue in enumerate(issues_to_embed):
+                number = int(issue['number'])
+                data_repo['embeddings'][number] = embeddings[i]
+
+        # autopep8: on
+        return data_repo
 
 
 def _sort_similarity(data: dict,
-                     query_emb: torch.Tensor,
+                     query_emb: List[torch.Tensor],
                      limit: int,
                      state: State = State.opened) -> list:
     duplicates = []
@@ -356,7 +239,7 @@
 
 
 def find_relatedness(repo: str, number: int, limit: int = 20, state: State = State.opened):
-    data = EMBEDDING_CTX.embeddings_updated_get(repo)
+    data = _embeddings_updated_get(repo)
 
     # Check if the embedding already exists.
     if data['titles'][number] is not None:
@@ -383,7 +266,7 @@ def find_relatedness(repo: str, number: int, limit: int = 20, state: State = Sta
 
 
 @router.get("/find_related/{repo}/{number}")
-def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened):
+def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened) -> str:
     related = find_relatedness(repo, number, limit=limit, state=state)
     return related
 
@@ -391,28 +274,26 @@
 if __name__ == "__main__":
     update_cache = True
     if update_cache:
-        EMBEDDING_CTX.embeddings_updated_get('blender')
-        EMBEDDING_CTX.embeddings_updated_get('blender-addons')
-        cache_path = EMBEDDING_CTX.cache_path
-        with open(cache_path, "wb") as file:
+        _embeddings_updated_get('blender')
+        _embeddings_updated_get('blender-addons')
+        with open(G_cache_path, "wb") as file:
             # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
-            for val in EMBEDDING_CTX.data.values():
+            for val in G_data.values():
                 val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
 
-            pickle.dump(EMBEDDING_CTX.data, file,
-                        protocol=pickle.HIGHEST_PROTOCOL)
-    else:
-        # Converting the embeddings to be GPU.
-        for val in EMBEDDING_CTX.data.values():
-            val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
-
-    # 'blender/blender/111434' must print #96153, #83604 and #79762
-    related1 = find_relatedness(
-        'blender', 111434, limit=20, state=State.all)
-    related2 = find_relatedness('blender-addons', 104399, limit=20)
-
-    print("These are the 20 most related issues:")
-    print(related1)
-    print()
-    print("These are the 20 most related issues:")
-    print(related2)
+            pickle.dump(G_data, file, protocol=pickle.HIGHEST_PROTOCOL)
+
+    # Converting the embeddings to be GPU.
+    for val in G_data.values():
+        val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
+
+    # 'blender/blender/111434' must print #96153, #83604 and #79762
+    related1 = find_relatedness(
+        'blender', 111434, limit=20, state=State.all)
+    related2 = find_relatedness('blender-addons', 104399, limit=20)
+
+    print("These are the 20 most related issues:")
+    print(related1)
+    print()
+    print("These are the 20 most related issues:")
+    print(related2)
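With the EmbeddingContext moved into routers/embedding, this module now only manages the issue cache; the HTTP surface is unchanged. A minimal sketch of exercising the endpoint, with host and port as assumptions:

    # Minimal sketch: the refactor keeps the same route and parameters.
    import requests

    resp = requests.get("http://localhost:8000/api/v1/find_related/blender/111434",
                        params={"limit": 20, "state": "all"})
    print(resp.text)  # issues most related to blender/blender#111434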
routers/tool_wiki_search.py ADDED
@@ -0,0 +1,296 @@
+# routers/wiki_search.py
+
+import os
+import pickle
+import re
+import torch
+from typing import Dict, List
+from sentence_transformers import util
+from fastapi import APIRouter
+
+try:
+    from .embedding import EMBEDDING_CTX
+except:
+    from embedding import EMBEDDING_CTX
+
+router = APIRouter()
+
+MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
+BASE_URL = "https://docs.blender.org/manual/en/dev"
+G_cache_path = "routers/embedding/embeddings_manual.pkl"
+G_data = None
+
+
+def _embeddings_generate():
+    global G_data
+
+    if os.path.exists(G_cache_path):
+        with open(G_cache_path, 'rb') as file:
+            G_data = pickle.load(file)
+            return G_data
+
+    # path = 'addons/3d_view'
+    G_data = parse_file_recursive(MANUAL_DIR, 'index.rst')
+    G_data['toctree']["copyright"] = parse_file_recursive(
+        MANUAL_DIR, 'copyright.rst')
+
+    # Create a list to store the text files
+    texts = get_texts_recursive(G_data)
+
+    print("Embedding Texts...")
+    G_data['texts'] = texts
+    G_data['embeddings'] = EMBEDDING_CTX.encode(texts)
+
+    with open(G_cache_path, "wb") as file:
+        # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
+        G_data['embeddings'] = G_data['embeddings'].to(
+            torch.device('cpu'))
+
+        pickle.dump(G_data, file, protocol=pickle.HIGHEST_PROTOCOL)
+
+    return G_data
+
+
+def reduce_text(text):
+    # Remove repeated characters
+    text = re.sub(r'%{2,}', '', text)  # Title
+    text = re.sub(r'#{2,}', '', text)  # Title
+    text = re.sub(r'\*{3,}', '', text)  # Title
+    text = re.sub(r'={3,}', '', text)  # Topic
+    text = re.sub(r'\^{3,}', '', text)
+    text = re.sub(r'-{3,}', '', text)
+
+    text = re.sub(r'(\s*\n\s*)+', '\n', text)
+    return text
+
+
+def parse_file_recursive(filedir, filename):
+    with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
+        content = file.read()
+
+    parsed_data = {}
+
+    if not filename.endswith('index.rst'):
+        body = content.strip()
+    else:
+        parts = content.split(".. toctree::")
+        body = parts[0].strip()
+
+        if len(parts) > 1:
+            parsed_data["toctree"] = {}
+            for part in parts[1:]:
+                toctree_entries = part.split('\n')
+                line = toctree_entries[0]
+                for entry in toctree_entries[1:]:
+                    entry = entry.strip()
+                    if not entry:
+                        continue
+
+                    if entry.startswith('/'):
+                        # relative path.
+                        continue
+
+                    if not entry.endswith('.rst'):
+                        continue
+
+                    if entry.endswith('/index.rst'):
+                        entry_name = entry[:-10]
+                        filedir_ = os.path.join(filedir, entry_name)
+                        filename_ = 'index.rst'
+                    else:
+                        entry_name = entry[:-4]
+                        filedir_ = filedir
+                        filename_ = entry
+
+                    parsed_data['toctree'][entry_name] = parse_file_recursive(
+                        filedir_, filename_)
+
+    # The '\n' at the end of the file resolves regex patterns
+    parsed_data['body'] = body + '\n'
+
+    return parsed_data
+
+
+def split_into_topics(text: str, prefix: str = '') -> Dict[str, List[str]]:
+    """
+    Splits a text into sections based on titles and subtitles, and organizes them into a dictionary.
+
+    Args:
+        text (str): The input text to be split. The text should contain titles marked by asterisks (***)
+                    or subtitles marked by equal signs (===).
+        prefix (str): prefix to titles and subtitles
+
+    Returns:
+        Dict[str, List[str]]: A dictionary where keys are section titles or subtitles, and values are lists of
+                              strings corresponding to the content under each title or subtitle.
+
+    Example:
+        text = '''
+        *********************
+        The Blender Community
+        *********************
+
+        Being freely available from the start.
+
+        Independent Sites
+        =================
+
+        There are `several independent websites.
+
+        Getting Support
+        ===============
+
+        Blender's community is one of its greatest features.
+        '''
+
+        result = split_into_topics(text)
+        # result will be:
+        # {
+        #     "# The Blender Community": [
+        #         "Being freely available from the start."
+        #     ],
+        #     "# The Blender Community | Independent Sites": [
+        #         "There are `several independent websites."
+        #     ],
+        #     "# The Blender Community | Getting Support": [
+        #         "Blender's community is one of its greatest features."
+        #     ]
+        # }
+    """
+
+    # Remove patterns ".. word::" and ":word:"
+    text = re.sub(r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:', '', text)
+
+    # Regular expression to find titles and subtitles
+    pattern = r'([\*|#|%]{3,}\n[^\n]+\n[\*|#|%]{3,}|(?:={3,}\n)?[^\n]+\n={3,}\n)'
+
+    # Split text by found patterns
+    sections = re.split(pattern, text)
+
+    # Remove possible white spaces at the beginning and end of each section
+    sections = [section for section in sections if section.strip()]
+
+    # Separate sections into a dictionary
+    topics = {}
+    current_title = ''
+    current_topic = prefix
+
+    for section in sections:
+        if match := re.match(r'[\*|#|%]{3,}\n([^\n]+)\n[\*|#|%]{3,}', section):
+            current_topic = current_title = f'{prefix}# {match.group(1)}'
+            topics[current_topic] = []
+        elif match := re.match(r'(?:={3,}\n)?([^\n]+)\n={3,}\n', section):
+            current_topic = current_title + ' | ' + match.group(1)
+            topics[current_topic] = []
+        else:
+            if current_topic == prefix:
+                raise
+            topics[current_topic].append(section)
+
+    return topics
+
+
+# Function to split the text into chunks of a maximum number of tokens
+def split_into_many(page_body, prefix=''):
+    tokenizer = EMBEDDING_CTX.model.tokenizer
+    max_tokens = EMBEDDING_CTX.model.max_seq_length
+    topics = split_into_topics(page_body, prefix)
+
+    for topic, content_list in topics.items():
+        title = topic + ':\n'
+        title_tokens_len = len(tokenizer.tokenize(title))
+        content_list_new = []
+        for content in content_list:
+            content_reduced = reduce_text(content)
+            content_tokens_len = len(tokenizer.tokenize(content_reduced))
+            if title_tokens_len + content_tokens_len <= max_tokens:
+                content_list_new.append(content_reduced)
+                continue
+
+            # Split the text into sentences
+            paragraphs = content_reduced.split('.\n')
+            sentences = ''
+            tokens_so_far = title_tokens_len
+
+            # Loop through the sentences and tokens joined together in a tuple
+            for sentence in paragraphs:
+                sentence += '.\n'
+
+                # Get the number of tokens for each sentence
+                n_tokens = len(tokenizer.tokenize(sentence))
+
+                # If the number of tokens so far plus the number of tokens in the current sentence is greater
+                # than the max number of tokens, then add the chunk to the list of chunks and reset
+                # the chunk and tokens so far
+                if tokens_so_far + n_tokens > max_tokens:
+                    content_list_new.append(sentences)
+                    sentences = ''
+                    tokens_so_far = title_tokens_len
+
+                sentences += sentence
+                tokens_so_far += n_tokens
+
+            if sentences:
+                content_list_new.append(sentences)
+
+        # Replace content_list
+        content_list.clear()
+        content_list.extend(content_list_new)
+
+    result = []
+    for topic, content_list in topics.items():
+        for content in content_list:
+            result.append(topic + ':\n' + content)
+
+    return result
+
+
+def get_texts_recursive(page, path=''):
+    result = split_into_many(page['body'], path)
+
+    try:
+        for key in page['toctree'].keys():
+            page_child = page['toctree'][key]
+            result.extend(get_texts_recursive(page_child, f'{path}/{key}'))
+    except KeyError:
+        pass
+
+    return result
+
+
+def _sort_similarity(data, text_to_search, limit):
+    results = []
+
+    query_emb = EMBEDDING_CTX.encode([text_to_search])
+    ret = util.semantic_search(
+        query_emb, data['embeddings'], top_k=limit, score_function=util.dot_score)
+
+    texts = data['texts']
+    for score in ret[0]:
+        corpus_id = score['corpus_id']
+        text = texts[corpus_id]
+        results.append(text)
+
+    return results
+
+
+@router.get("/wiki_search")
+def wiki_search(query: str = "") -> str:
+    data = _embeddings_generate()
+    texts = _sort_similarity(data, query, 5)
+
+    result = f'BASE_URL: {BASE_URL}\n'
+    for text in texts:
+        index = text.find('#')
+        result += f'''---
+{text[:index] + '.html'}
+{text[index:]}
+
+'''
+    return result
+
+
+if __name__ == '__main__':
+    tests = ["Set Snap Base", "Building the Manual", "Bisect Object"]
+    result = wiki_search(tests[0])
+    print(result)
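Each embedded chunk starts with its manual path and topic heading ('<path># Title | Subtitle:'), which is why wiki_search can rebuild a documentation URL by splitting at the first '#'. A minimal sketch of that formatting step; the sample text is illustrative, not taken from the real cache:

    # Minimal sketch of the URL reconstruction done in wiki_search().
    text = "/modeling/meshes/editing/mesh/bisect# Bisect | Usage:\nDetails..."
    index = text.find('#')

    BASE_URL = "https://docs.blender.org/manual/en/dev"
    print(BASE_URL + text[:index] + '.html')  # link into the online manual
    print(text[index:])                       # topic heading plus its content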
utils/generate_blender_doc.py DELETED
@@ -1,194 +0,0 @@
-import os
-import sys
-import re
-from sentence_transformers import util
-
-script_dir = os.path.dirname(os.path.realpath(__file__))
-parent_dir = os.path.dirname(script_dir)
-sys.path.append(parent_dir)
-
-# autopep8: off
-from routers.tool_find_related import EMBEDDING_CTX
-# autopep8: on
-
-MANUAL_DIR = "D:/BlenderDev/blender-manual/manual/"
-BASE_URL = "https://docs.blender.org/manual/en/dev"
-
-
-def process_text(text):
-    # Remove repeated characters
-    text = re.sub(r'%{2,}', '', text)
-    text = re.sub(r'#{2,}', '', text)
-    text = re.sub(r'={3,}', '', text)
-    text = re.sub(r'\*{3,}', '', text)
-    text = re.sub(r'\^{3,}', '', text)
-    text = re.sub(r'-{3,}', '', text)
-
-    # Remove patterns ".. word:: " and ":word:"
-    text = re.sub(r'\.\. \S+', '', text)
-    text = re.sub(r':\w+:', '', text)
-
-    text = re.sub(r'(\s*\n\s*)+', '\n', text)
-    return text
-
-
-def parse_file(filedir, filename):
-    with open(os.path.join(filedir, filename), 'r', encoding='utf-8') as file:
-        content = file.read()
-
-    parsed_data = {}
-
-    if not filename.endswith('index.rst'):
-        body = content.strip()
-    else:
-        parts = content.split(".. toctree::")
-        body = parts[0].strip()
-
-        if len(parts) > 1:
-            parsed_data["toctree"] = {}
-            for part in parts[1:]:
-                toctree_entries = part.split('\n')
-                line = toctree_entries[0]
-                for entry in toctree_entries[1:]:
-                    entry = entry.strip()
-                    if not entry:
-                        continue
-
-                    if entry.startswith('/'):
-                        # relative path.
-                        continue
-
-                    if not entry.endswith('.rst'):
-                        continue
-
-                    if entry.endswith('/index.rst'):
-                        entry_name = entry[:-10]
-                        filedir_ = os.path.join(filedir, entry_name)
-                        filename_ = 'index.rst'
-                    else:
-                        entry_name = entry[:-4]
-                        filedir_ = filedir
-                        filename_ = entry
-
-                    parsed_data['toctree'][entry_name] = parse_file(
-                        filedir_, filename_)
-
-    processed_text = process_text(body)
-    tokens = EMBEDDING_CTX.model.tokenizer.tokenize(processed_text)
-    if len(tokens) > EMBEDDING_CTX.model.max_seq_length:
-        pass
-    # parsed_data['body'] = body
-    parsed_data['processed_text'] = processed_text
-    parsed_data['n_tokens'] = len(tokens)
-
-    return parsed_data
-
-
-# Function to split the text into chunks of a maximum number of tokens
-def split_into_many(text, max_tokens):
-
-    # Split the text into sentences
-    paragraphs = text.split('.\n')
-
-    # Get the number of tokens for each sentence
-    n_tokens = [len(EMBEDDING_CTX.model.tokenizer.tokenize(" " + sentence))
-                for sentence in paragraphs]
-
-    chunks = []
-    tokens_so_far = 0
-    chunk = []
-
-    # Loop through the sentences and tokens joined together in a tuple
-    for sentence, token in zip(paragraphs, n_tokens):
-
-        # If the number of tokens so far plus the number of tokens in the current sentence is greater
-        # than the max number of tokens, then add the chunk to the list of chunks and reset
-        # the chunk and tokens so far
-        if tokens_so_far + token > max_tokens:
-            chunks.append((".\n".join(chunk) + ".", tokens_so_far))
-            chunk = []
-            tokens_so_far = 0
-
-        # If the number of tokens in the current sentence is greater than the max number of
-        # tokens, go to the next sentence
-        if token > max_tokens:
-            continue
-
-        # Otherwise, add the sentence to the chunk and add the number of tokens to the total
-        chunk.append(sentence)
-        tokens_so_far += token + 1
-
-    if chunk:
-        chunks.append((".\n".join(chunk) + ".", tokens_so_far))
-
-    return chunks
-
-
-def get_texts(data, path):
-    result = []
-    processed_texts = [data['processed_text']]
-    processed_tokens = [data['n_tokens']]
-    max_tokens = EMBEDDING_CTX.model.max_seq_length
-
-    data_ = data
-    for key in path:
-        data_ = data_['toctree'][key]
-        processed_texts.append(data_['processed_text'])
-        processed_tokens.append(data_['n_tokens'])
-
-    if processed_tokens[-1] > max_tokens:
-        chunks = split_into_many(processed_texts[-1], max_tokens)
-    else:
-        chunks = [(processed_texts[-1], processed_tokens[-1])]
-
-    for text, n_tokens in chunks:
-        # Add context to the text if we have space
-        for i in range(len(processed_texts) - 2, -1, -1):
-            n_tokens_parent = processed_tokens[i]
-            if n_tokens + n_tokens_parent >= max_tokens:
-                break
-
-            text_parent = processed_texts[i]
-            text = text_parent + '\n' + text
-            n_tokens += n_tokens_parent
-
-        result.append([path, text])
-
-    try:
-        for key in data_['toctree'].keys():
-            result.extend(get_texts(data, path + [key]))
-    except KeyError:
-        pass
-
-    return result
-
-
-def _sort_similarity(chunks, embeddings, text_to_search, limit):
-    results = []
-
-    query_emb = EMBEDDING_CTX.encode([text_to_search])
-    ret = util.semantic_search(
-        query_emb, embeddings, top_k=limit, score_function=util.dot_score)
-
-    for score in ret[0]:
-        corpus_id = score['corpus_id']
-        chunk = chunks[corpus_id]
-        path = chunk[0]
-        results.append(path)
-
-    return results
-
-
-if __name__ == '__main__':
-    # path = 'addons/3d_view'
-    data = parse_file(MANUAL_DIR, 'index.rst')
-    data['toctree']["copyright"] = parse_file(MANUAL_DIR, 'copyright.rst')
-
-    # Create a list to store the text files
-    chunks = []
-    chunks.extend(get_texts(data, []))
-
-    embeddings = EMBEDDING_CTX.encode([text for path, text in chunks])
-
-    result = _sort_similarity(chunks, embeddings, "Set Snap Base", 50)
-    print(result)