ierhon committed
Commit c22db75
1 Parent(s): cc641ed

Fix incorrect input length when a hash is passed as data

Files changed (1)
  1. chatbot_constructor.py +4 -1
chatbot_constructor.py CHANGED
@@ -29,7 +29,7 @@ def todset(text: str):
 def hash_str(data: str):
     return hashlib.md5(data.encode('utf-8')).hexdigest()
 
-def train(message: str = "", epochs: int = 16, learning_rate: float = 0.001, emb_size: int = 128, inp_len: int = 16, kernels_count: int = 8, kernel_size: int = 8, data: str = ""):
+def train(message: str = "", epochs: int = 16, learning_rate: float = 0.001, emb_size: int = 128, input_len: int = 16, kernels_count: int = 8, kernel_size: int = 8, data: str = ""):
     data_hash = None
     if "→" not in data or "\n" not in data:
         if data in os.listdir("cache"):
@@ -42,10 +42,13 @@ def train(message: str = "", epochs: int = 16, learning_rate: float = 0.001, emb
     tokenizer.fit_on_texts(list(dset.keys()))
 
     vocab_size = len(tokenizer.word_index) + 1
+    inp_len = input_len
     if data_hash is None:
         data_hash = hash_str(data)+"_"+str(epochs)+"_"+str(learning_rate)+"_"+str(emb_size)+"_"+str(inp_len)+"_"+str(kernels_count)+"_"+str(kernel_size)+".keras"
     elif message == "!getmodelhash":
         return data_hash
+    else:
+        inp_len = int(data_hash.split("_")[-3])
     if data_hash in os.listdir("cache"):
         model = load_model("cache/"+data_hash)
     else:
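
For context, a minimal sketch of why `int(data_hash.split("_")[-3])` recovers the input length: the cache filename is the md5 digest of the data followed by every hyperparameter, joined with underscores and ending in ".keras", so the third field from the end is always the input length. The helpers `make_cache_key` and `inp_len_from_key` below are illustrative names and not part of the repository; only `hash_str` and the filename layout are taken from the diff above.

# Minimal sketch (assumed helper names); mirrors the cache-key layout built in train().
import hashlib

def hash_str(data: str) -> str:
    return hashlib.md5(data.encode("utf-8")).hexdigest()

def make_cache_key(data: str, epochs: int = 16, learning_rate: float = 0.001,
                   emb_size: int = 128, inp_len: int = 16,
                   kernels_count: int = 8, kernel_size: int = 8) -> str:
    # md5 digest plus every hyperparameter, joined with "_", ending in ".keras"
    return "_".join([hash_str(data), str(epochs), str(learning_rate), str(emb_size),
                     str(inp_len), str(kernels_count), str(kernel_size)]) + ".keras"

def inp_len_from_key(key: str) -> int:
    # The md5 hex digest contains no "_", so field positions are stable:
    # [-1] is "<kernel_size>.keras", [-2] is kernels_count, [-3] is inp_len.
    return int(key.split("_")[-3])

key = make_cache_key("hello→hi\n", inp_len=24)
assert inp_len_from_key(key) == 24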