gufett0 committed
Commit d3df8fd · 1 Parent(s): 57b8c08

added new class

Files changed (2):
  1. backend.py +5 -5
  2. interface.py +26 -44
backend.py CHANGED
@@ -20,23 +20,23 @@ login(huggingface_token)
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-"""model_id = "google/gemma-2-2b-it"
+model_id = "google/gemma-2-2b-it"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto", ## change this back to auto!!!
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
     token=True)
-model.eval()"""
+model.eval()
 
 #from accelerate import disk_offload
 #disk_offload(model=model, offload_dir="offload")
 
 # what models will be used by LlamaIndex:
 """Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
-Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)"""
-Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
-Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
+Settings.llm = GemmaLLMInterface(model=model)"""
+gemma_llm = GemmaLLMInterface(model_name=model_id)
+
 
 ############################---------------------------------
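Note: after this change backend.py constructs gemma_llm but no longer assigns Settings.llm, so LlamaIndex will not pick up the Gemma interface unless it is registered elsewhere. A minimal sketch of that wiring (not part of this commit, assuming the Settings object from llama_index.core and the InstructorEmbedding wrapper already used in backend.py):

    # Minimal sketch (not in this commit): register the models globally so
    # LlamaIndex components pick them up by default.
    from llama_index.core import Settings

    Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
    Settings.llm = gemma_llm  # the instance constructed above in backend.py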
interface.py CHANGED
@@ -9,66 +9,48 @@ from pydantic import Field, field_validator
 
 # for transformers 2
 class GemmaLLMInterface(CustomLLM):
-    def __init__(self, model_id: str = "google/gemma-2b-it", context_window: int = 8192, num_output: int = 2048):
-        self.model_id = model_id
-        self.context_window = context_window
-        self.num_output = num_output
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model: AutoModelForCausalLM = None
+    tokenizer: AutoTokenizer = None
+    context_window: int = 8192
+    num_output: int = 2048
+    model_name: str = "gemma_2"
+
+    def __init__(self, model_name: str):
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            device_map="auto",
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+            model_name,
+            device_map="auto",  # Set device mapping automatically
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
         )
-        self.model.eval()
-
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
     def _format_prompt(self, message: str) -> str:
-        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+        return (
+            f"<start_of_turn>user\n{message}<end_of_turn>\n" f"<start_of_turn>model\n"
+        )
 
     @property
     def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
         return LLMMetadata(
             context_window=self.context_window,
             num_output=self.num_output,
-            model_name=self.model_id,
+            model_name=self.model_name,
         )
-
-    def _prepare_inputs(self, prompt: str) -> dict:
-        formatted_prompt = self._format_prompt(prompt)
-        inputs = self.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True).to(self.model.device)
-        if inputs["input_ids"].shape[1] > self.context_window:
-            inputs["input_ids"] = inputs["input_ids"][:, -self.context_window:]
-        return inputs
-
-    def _generate(self, inputs: dict) -> Iterator[str]:
-        for output in self.model.generate(
-            **inputs,
-            max_new_tokens=self.num_output,
-            do_sample=True,
-            top_p=0.9,
-            top_k=50,
-            temperature=0.7,
-            num_beams=1,
-            repetition_penalty=1.1,
-            streamer=None,
-            return_dict_in_generate=True,
-            output_scores=False,
-        ):
-            new_tokens = output.sequences[:, inputs["input_ids"].shape[-1]:]
-            yield self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
-
+
     @llm_completion_callback()
     def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
-        inputs = self._prepare_inputs(prompt)
-        response = "".join(self._generate(inputs))
+        prompt = self._format_prompt(prompt)
+        inputs = self.tokenizer(prompt, return_tensors="pt")  # Tokenize the prompt
+        raw_response = self.model.generate(**inputs, max_length=self.num_output)
+        response = self.tokenizer.decode(raw_response[0], skip_special_tokens=True)
         return CompletionResponse(text=response)
 
     @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-        inputs = self._prepare_inputs(prompt)
-        response = ""
-        for new_token in self._generate(inputs):
-            response += new_token
-            yield CompletionResponse(text=response, delta=new_token)
+        response = self.complete(prompt).text
+        for token in response:
+            response += token
+            yield CompletionResponse(text=response, delta=token)
 
 
 # for transformers 1
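
Review notes on the new class:
- CustomLLM is a Pydantic model in LlamaIndex, so assigning self.model inside __init__ without first calling super().__init__() will typically raise a validation error.
- model.generate(max_length=self.num_output) caps prompt plus completion tokens; max_new_tokens would match the old behavior. Decoding raw_response[0] also returns the prompt text along with the completion, since generate returns the full sequence.
- In stream_complete, response += token grows the accumulator while the loop iterates the original string, so the cumulative text ends up duplicated.

A minimal sketch addressing these points (not part of this commit, assuming the same imports as interface.py; unchanged methods such as _format_prompt and metadata are omitted):

    # Minimal sketch (not in this commit); assumes interface.py's imports:
    # CustomLLM, CompletionResponse, CompletionResponseGen,
    # llm_completion_callback, AutoModelForCausalLM, AutoTokenizer, torch, Any.
    class GemmaLLMInterface(CustomLLM):
        model: Any = None      # Any sidesteps Pydantic validation of HF types
        tokenizer: Any = None
        context_window: int = 8192
        num_output: int = 2048
        model_name: str = "gemma_2"

        def __init__(self, model_name: str, **kwargs: Any):
            super().__init__(**kwargs)  # initialize the Pydantic base first
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        @llm_completion_callback()
        def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
            text = self._format_prompt(prompt)
            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
            output = self.model.generate(**inputs, max_new_tokens=self.num_output)
            # Slice off the prompt tokens so only newly generated text is returned.
            new_tokens = output[0][inputs["input_ids"].shape[-1]:]
            return CompletionResponse(
                text=self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            )

        @llm_completion_callback()
        def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
            full_text = self.complete(prompt).text
            accumulated = ""
            for ch in full_text:   # accumulate separately; never mutate the
                accumulated += ch  # string being iterated
                yield CompletionResponse(text=accumulated, delta=ch)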