SAMBOOM committed
Commit e8f6a07
1 Parent(s): e3c388b

Update app.py

Files changed (1)
  1. app.py +5 -24
app.py CHANGED
@@ -1,16 +1,16 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelWithLMHead
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
 @st.cache_data
 def load_model(model_name):
-    model = AutoModelWithLMHead.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    model = AutoModelForSeq2SeqLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
     return model
 
 model = load_model("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 def infer(input_ids, max_length, temperature, top_k, top_p):
-
+    input_ids = torch.tensor(input_ids).unsqueeze(0)
     output_sequences = model.generate(
         input_ids=input_ids,
         max_length=max_length,
@@ -20,8 +20,8 @@ def infer(input_ids, max_length, temperature, top_k, top_p):
         do_sample=True,
         num_return_sequences=1
     )
-
     return output_sequences
+
 default_value = "Ask me anything!"
 
 #prompts
@@ -39,28 +39,9 @@ if encoded_prompt.size()[-1] == 0:
 else:
     input_ids = encoded_prompt
 
-
 output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
 
-
-
 for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
-    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
     generated_sequences = generated_sequence.tolist()
-
-    # Decode text
     text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-
-    # Remove all text after the stop token
-    #text = text[: text.find(args.stop_token) if args.stop_token else None]
-
-    # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
-    total_sequence = (
-        sent + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
-    )
-
-    generated_sequences.append(total_sequence)
-    print(total_sequence)
-
-
-    st.write(generated_sequences[-1])
+    st.write(text)
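Two issues survive this commit. First, the new line 13 calls torch.tensor, but the file never imports torch, so infer will raise a NameError on first use. Second, mistralai/Mixtral-8x7B-Instruct-v0.1 is a decoder-only model; AutoModelForSeq2SeqLM has no model class registered for its config and will refuse to load it, so AutoModelForCausalLM is the auto class that actually matches. Smaller nits: load_model ignores its model_name argument, and st.cache_data pickles its return value on every rerun, which is why Streamlit's docs point to st.cache_resource for models. A minimal sketch of the loading half with those fixes applied; the repo id and the shown generate arguments come from the diff, the rest is a suggestion, not the committed code:

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

@st.cache_resource  # cache_data would try to serialize the model itself
def load_model(model_name):
    # honor the argument instead of re-hardcoding the repo id
    return AutoModelForCausalLM.from_pretrained(model_name)

model = load_model(MODEL_ID)

def infer(input_ids, max_length, temperature, top_k, top_p):
    # build a batch of one; generate expects a 2-D tensor of token ids
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,  # lines 17-19 are elided in the diff;
        top_k=top_k,              # these three arguments are inferred
        top_p=top_p,              # from infer's own signature
        do_sample=True,
        num_return_sequences=1,
    )
    return output_sequences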
 
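One behavioral change worth noting: the deleted loop stripped the prompt text from the decoded output before displaying it, while the new loop writes the full decoding, so with a causal LM (whose generate output begins with the prompt tokens) the app now echoes the user's prompt ahead of the completion. The generated_sequences = generated_sequence.tolist() line is also dead code after this commit. A common pattern is to slice off the prompt ids before decoding; a sketch, assuming encoded_prompt is the tensor built earlier in the file:

# Sketch: decode only the newly generated tokens, not the prompt echo.
prompt_length = encoded_prompt.size(-1)
for generated_sequence in output_sequences:
    new_tokens = generated_sequence[prompt_length:]
    text = tokenizer.decode(
        new_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    st.write(text)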