Javierss commited on
Commit
419bab3
·
1 Parent(s): fa48fcf

Update model download method

Browse files
Files changed (1) hide show
  1. game.py +11 -4
game.py CHANGED
@@ -67,6 +67,13 @@ class Model_class:
67
 
68
  def __init__(self, lang=0, model_type="SentenceTransformer"):
69
 
 
 
 
 
 
 
 
70
  # Check if the model exists, clone it if it doesn't
71
  if not os.path.exists(
72
  os.path.join(self.base_path, "config/strans_models/")
@@ -76,25 +83,25 @@ class Model_class:
76
  if lang == 1:
77
  if model_type == "word2vec":
78
  self.model = KeyedVectors.load(
79
- model_path,
80
  mmap="r",
81
  )
82
  elif model_type == "SentenceTransformer":
83
  self.model = KeyedVectors.load(
84
- model_path,
85
  mmap="r",
86
  )
87
 
88
  else:
89
  if model_type == "word2vec":
90
  self.model = KeyedVectors.load(
91
- model_path,
92
  mmap="r",
93
  )
94
 
95
  elif model_type == "SentenceTransformer":
96
  self.model = KeyedVectors.load(
97
- model_path,
98
  mmap="r",
99
  )
100
 
 
67
 
68
  def __init__(self, lang=0, model_type="SentenceTransformer"):
69
 
70
+ if model_type == "SentenceTransformer":
71
+ repo_url = "git@hf.co:Jsevisal/strans_models"
72
+ dest_path = "config/strans_models/"
73
+ else:
74
+ repo_url = "git@hf.co:Jsevisal/w2v_models"
75
+ dest_path = "config/w2v_models/"
76
+
77
  # Check if the model exists, clone it if it doesn't
78
  if not os.path.exists(
79
  os.path.join(self.base_path, "config/strans_models/")
 
83
  if lang == 1:
84
  if model_type == "word2vec":
85
  self.model = KeyedVectors.load(
86
+ os.path.join(model_path, "eng_w2v_model"),
87
  mmap="r",
88
  )
89
  elif model_type == "SentenceTransformer":
90
  self.model = KeyedVectors.load(
91
+ os.path.join(model_path, "eng_strans_model"),
92
  mmap="r",
93
  )
94
 
95
  else:
96
  if model_type == "word2vec":
97
  self.model = KeyedVectors.load(
98
+ os.path.join(model_path, "esp_w2v_model"),
99
  mmap="r",
100
  )
101
 
102
  elif model_type == "SentenceTransformer":
103
  self.model = KeyedVectors.load(
104
+ os.path.join(model_path, "esp_strans_model"),
105
  mmap="r",
106
  )
107