BeveledCube commited on
Commit
564bd7c
·
1 Parent(s): e8caf01

Added max new tokens value

Browse files
main.py CHANGED
@@ -1,5 +1,5 @@
1
  from flask import Flask, request, render_template, jsonify
2
- from models import tiny as chatbot
3
 
4
  app = Flask("AI API")
5
 
 
1
  from flask import Flask, request, render_template, jsonify
2
+ from models import tinystories as chatbot
3
 
4
  app = Flask("AI API")
5
 
models/fast.py CHANGED
@@ -11,6 +11,6 @@ def load():
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/gpt2.py CHANGED
@@ -16,6 +16,6 @@ def generate(input_text):
16
  attention_mask = tf.ones_like(input_ids)
17
 
18
  # Generate output using the model
19
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
20
 
21
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
16
  attention_mask = tf.ones_like(input_ids)
17
 
18
  # Generate output using the model
19
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
20
 
21
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/hermes.py CHANGED
@@ -2,24 +2,17 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
2
 
3
  model_name = "NousResearch/Hermes-2-Pro-Llama-3-8B"
4
 
5
- model = None
6
- tokenizer = None
7
 
8
  # Example messages input
9
  # messages = [
10
  # {"role": "system", "content": "You are Hermes 2."},
11
  # {"role": "user", "content": "Hello, who are you?"}
12
  #]
13
-
14
- def load():
15
- global model
16
- global tokenizer
17
-
18
- model = AutoModelForCausalLM.from_pretrained(model_name)
19
- tokenizer = AutoTokenizer.from_pretrained(model_name)
20
 
21
  def generate(messages):
22
  gen_input = tokenizer.apply_chat_template(messages, return_tensors="pt")
23
- output_ids = model.generate(**gen_input, num_beams=5, no_repeat_ngram_size=2)
24
 
25
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
2
 
3
  model_name = "NousResearch/Hermes-2-Pro-Llama-3-8B"
4
 
5
+ model = AutoModelForCausalLM.from_pretrained(model_name)
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
 
8
  # Example messages input
9
  # messages = [
10
  # {"role": "system", "content": "You are Hermes 2."},
11
  # {"role": "user", "content": "Hello, who are you?"}
12
  #]
 
 
 
 
 
 
 
13
 
14
  def generate(messages):
15
  gen_input = tokenizer.apply_chat_template(messages, return_tensors="pt")
16
+ output_ids = model.generate(**gen_input, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
17
 
18
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/llama2.py CHANGED
@@ -11,6 +11,6 @@ def load():
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/llama3.py CHANGED
@@ -11,6 +11,6 @@ def load():
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/llamatiny.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
4
+
5
+ def load():
6
+ global model
7
+ global tokenizer
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def generate(input_text):
13
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
15
+
16
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/mamba.py CHANGED
@@ -11,6 +11,6 @@ def load():
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/{tiny.py → tinystories.py} RENAMED
@@ -11,6 +11,6 @@ def load():
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
- output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
11
 
12
  def generate(input_text):
13
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2, max_new_tokens=100)
15
 
16
  return tokenizer.decode(output_ids[0], skip_special_tokens=True)