Ravi21 committed
Commit 3d0c609
Parent: debd5f8

Update model.py

Files changed (1)
  1. model.py +21 -9
model.py CHANGED
@@ -1,9 +1,11 @@
 from threading import Thread
 from typing import Iterator
 
-
+import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-model="/kaggle/input/deberta-v3-large-hf-weights"
+
+model_id = 'meta-llama/Llama-2-13b-chat-hf'
+
 if torch.cuda.is_available():
     config = AutoConfig.from_pretrained(model_id)
     config.pretraining_tp = 1
@@ -17,19 +19,28 @@ if torch.cuda.is_available():
 else:
     model = None
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-def preprocess(sample):
-    first_sentences = [sample["prompt"]] * 5
-    second_sentences = [sample[option] for option in "ABCDE"]
-    tokenized_sentences = tokenizer(first_sentences, second_sentences, truncation=True, padding=True, return_tensors="pt")
-    sample["input_ids"] = tokenized_sentences["input_ids"]
-    sample["attention_mask"] = tokenized_sentences["attention_mask"]
-    return sample
+
+
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+               system_prompt: str) -> str:
+    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    # The first user input is _not_ stripped
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+    message = message.strip() if do_strip else message
+    texts.append(f'{message} [/INST]')
+    return ''.join(texts)
+
 
 def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
     prompt = get_prompt(message, chat_history, system_prompt)
     input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
     return input_ids.shape[-1]
 
+
 def run(message: str,
         chat_history: list[tuple[str, str]],
         system_prompt: str,
@@ -61,3 +72,4 @@ def run(message: str,
     for text in streamer:
         outputs.append(text)
         yield ''.join(outputs)
+
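
Note: the new get_prompt helper assembles the Llama-2 chat template (the <s>[INST] / <<SYS>> markers) from the system prompt and the accumulated chat history. A minimal usage sketch, with made-up inputs for illustration only:

    # Hypothetical inputs, not taken from this commit.
    system_prompt = 'You are a helpful assistant.'
    chat_history = [('Hi!', 'Hello! How can I help?')]

    prompt = get_prompt('What is 2 + 2?', chat_history, system_prompt)
    print(prompt)
    # <s>[INST] <<SYS>>
    # You are a helpful assistant.
    # <</SYS>>
    #
    # Hi! [/INST] Hello! How can I help? </s><s>[INST] What is 2 + 2? [/INST]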
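The diff truncates the body of run() and shows only its final loop. For context, here is a minimal sketch of how Thread and TextIteratorStreamer (both imported at the top of the file) are typically wired together to produce that loop; the body and the default generation parameters below are assumptions, not content of this commit:

    # Sketch only: parameter defaults and the generate() call are assumed,
    # not shown in this diff.
    def run(message: str,
            chat_history: list[tuple[str, str]],
            system_prompt: str,
            max_new_tokens: int = 1024,
            temperature: float = 0.8,
            top_p: float = 0.95,
            top_k: int = 50) -> Iterator[str]:
        prompt = get_prompt(message, chat_history, system_prompt)
        # Assumes the CUDA branch above loaded the model onto the GPU.
        inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')

        # The streamer yields decoded text as model.generate, running in a
        # background thread, produces tokens, so callers can consume partial
        # output immediately.
        streamer = TextIteratorStreamer(tokenizer,
                                        timeout=10.,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:  # this is the loop visible at the end of the diff
            outputs.append(text)
            yield ''.join(outputs)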