CaoHaiNam committed on
Commit
3ca6892
·
1 Parent(s): 3a379e2

update code

Browse files
Files changed (3) hide show
  1. parameters.py +4 -3
  2. siameser.py +1 -13
  3. utils.py +19 -3
parameters.py CHANGED
@@ -1,7 +1,8 @@
1
  # transformer model
2
  embedding_model = 'CaoHaiNam/vietnamese-address-embedding'
3
- local_embedding_model = 'embedding-model'
4
-
5
 
6
  NORM_ADDS_FILE_ALL_1 = 'data/standard_address_all_1.json'
7
- STD_EMBEDDING_FILE_ALL_1 = 'data/address_matrix_all_1.pt'
 
 
 
 
1
  # transformer model
2
  embedding_model = 'CaoHaiNam/vietnamese-address-embedding'
 
 
3
 
4
  NORM_ADDS_FILE_ALL_1 = 'data/standard_address_all_1.json'
5
+ STD_EMBEDDING_FILE_ALL_1 = 'data/address_matrix_all_1.pt'
6
+
7
+ LOG_DIRECTORY = 'logs'
8
+ LOG_RESULT_FILE = 'logs.json'
siameser.py CHANGED
@@ -13,14 +13,8 @@ device = torch.device('cpu')
13
 
14
  class Siameser:
15
  def __init__(self, model_name=None, stadard_scope=None):
16
- # print('Load model')
17
  print("Load sentence embedding model (If this is the first time you run this repo, It could be take time to download sentence embedding model)")
18
  self.threshold = 0.61
19
- # if os.path.isdir(parameters.local_embedding_model):
20
- # self.embedding_model = SentenceTransformer(parameters.local_embedding_model).to(device)
21
- # else:
22
- # self.embedding_model = SentenceTransformer(parameters.embedding_model).to(device)
23
- # self.embedding_model.save(parameters.local_embedding_model)
24
  self.embedding_model = SentenceTransformer(parameters.embedding_model).to(device)
25
 
26
  if stadard_scope == 'all':
@@ -55,10 +49,8 @@ class Siameser:
55
  else:
56
  score = F.cosine_similarity(raw_add_vectors, self.std_embeddings)
57
  s, top_k = score.topk(1)
58
- # print(s, top_k)
59
- # return
60
  s, idx = s.tolist()[0], top_k.tolist()[0]
61
- # if s < 0.57:
62
  if s < self.threshold:
63
  return {'Format Error': 'Xâu truyền vào không phải địa chỉ, mời nhập lại.'}
64
  std_add = self.NORM_ADDS[str(idx)]
@@ -75,8 +67,6 @@ class Siameser:
75
  score = F.cosine_similarity(raw_add_vectors, self.std_embeddings)
76
  s, top_k = score.topk(k)
77
  s, top_k = s.tolist(), top_k.tolist()
78
- # print(s, top_k)
79
- # return
80
 
81
  if s[0] < self.threshold:
82
  return {'Format Error': 'Dường như xâu truyền vào không phải địa chỉ, mời nhập lại.'}, {}
@@ -86,6 +76,4 @@ class Siameser:
86
  std_add = self.NORM_ADDS[str(idx)]
87
  top_std_adds.append(utils.get_full_result(raw_add_, std_add, round(score, 4)))
88
 
89
- x1, x2 = top_std_adds[0], top_std_adds[1]
90
-
91
  return top_std_adds[0], top_std_adds
 
13
 
14
  class Siameser:
15
  def __init__(self, model_name=None, stadard_scope=None):
 
16
  print("Load sentence embedding model (If this is the first time you run this repo, It could be take time to download sentence embedding model)")
17
  self.threshold = 0.61
 
 
 
 
 
18
  self.embedding_model = SentenceTransformer(parameters.embedding_model).to(device)
19
 
20
  if stadard_scope == 'all':
 
49
  else:
50
  score = F.cosine_similarity(raw_add_vectors, self.std_embeddings)
51
  s, top_k = score.topk(1)
52
+
 
53
  s, idx = s.tolist()[0], top_k.tolist()[0]
 
54
  if s < self.threshold:
55
  return {'Format Error': 'Xâu truyền vào không phải địa chỉ, mời nhập lại.'}
56
  std_add = self.NORM_ADDS[str(idx)]
 
67
  score = F.cosine_similarity(raw_add_vectors, self.std_embeddings)
68
  s, top_k = score.topk(k)
69
  s, top_k = s.tolist(), top_k.tolist()
 
 
70
 
71
  if s[0] < self.threshold:
72
  return {'Format Error': 'Dường như xâu truyền vào không phải địa chỉ, mời nhập lại.'}, {}
 
76
  std_add = self.NORM_ADDS[str(idx)]
77
  top_std_adds.append(utils.get_full_result(raw_add_, std_add, round(score, 4)))
78
 
 
 
79
  return top_std_adds[0], top_std_adds
utils.py CHANGED
@@ -1,6 +1,9 @@
1
  # import numpy as np
2
  import re
3
  import string
 
 
 
4
 
5
  # delete tone and lower
6
  anphabet = ['a', 'ă', 'â', 'b', 'c', 'd',
@@ -39,10 +42,9 @@ def remove_accent(text):
39
  # remove punctuation
40
  def remove_punctuation(text):
41
 
42
- punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
43
  whitespace = ' '
44
  for i in text:
45
- if i in punctuation:
46
  text = text.replace(i, whitespace)
47
  return ' '.join(text.split())
48
 
@@ -95,4 +97,18 @@ def get_full_result(raw_address, std_address, score):
95
  full_result['detail_address'] = get_detail_address(raw_address, std_address)
96
  full_result['main_address'] = std_address
97
  full_result['similarity_score'] = score
98
- return full_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # import numpy as np
2
  import re
3
  import string
4
+ import json
5
+ from datetime import datetime
6
+ from typing import Text, Dict
7
 
8
  # delete tone and lower
9
  anphabet = ['a', 'ă', 'â', 'b', 'c', 'd',
 
42
  # remove punctuation
43
  def remove_punctuation(text):
44
 
 
45
  whitespace = ' '
46
  for i in text:
47
+ if i in string.punctuation:
48
  text = text.replace(i, whitespace)
49
  return ' '.join(text.split())
50
 
 
97
  full_result['detail_address'] = get_detail_address(raw_address, std_address)
98
  full_result['main_address'] = std_address
99
  full_result['similarity_score'] = score
100
+ return full_result
101
+
102
+
103
def save_result(file_path: Text, result: Dict) -> None:
    """Append ``result`` to the JSON log stored at ``file_path``.

    The log file holds a single JSON list. Each call reads that list, appends
    an entry of the form ``{"result": result, "created_at": "<timestamp>"}``
    and rewrites the whole file (UTF-8, non-ASCII kept as-is, 4-space indent).

    Args:
        file_path: Path of the JSON log file. If the file does not exist yet,
            a new log containing only this entry is created; the parent
            directory must already exist.
        result: JSON-serialisable mapping to record.

    Raises:
        json.JSONDecodeError: If the existing file is not valid JSON.
        OSError: If the file cannot be read or written for reasons other
            than not existing on read.
    """
    log_sample = {
        'result': result,
        # Naive local time, formatted to second precision.
        'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    # Context managers guarantee the handles are closed (and the write is
    # flushed) even if JSON parsing or serialisation raises.
    try:
        with open(file_path, "r", encoding="utf8") as log_file:
            logs = json.load(log_file)
    except FileNotFoundError:
        # First call on a fresh deployment: start an empty log.
        logs = []

    logs.append(log_sample)

    with open(file_path, "w", encoding="utf8") as log_file:
        json.dump(logs, log_file, ensure_ascii=False, indent=4)