gufett0 committed
Commit f57e33c · 1 Parent(s): b7a41e7

changed class interface

Files changed (2)
  1. backend.py +0 -5
  2. interface.py +11 -6
backend.py CHANGED
@@ -28,19 +28,14 @@ model = AutoModelForCausalLM.from_pretrained(
     token=True)
 
 model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-
 model.eval()
 
-#from accelerate import disk_offload
-#disk_offload(model=model, offload_dir="offload")
-
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 
 Settings.llm = GemmaLLMInterface(model=model)
 #Settings.llm = GemmaLLMInterface(model_name=model_id)
 
-
 ############################---------------------------------
 
 # Get the parser
 
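For context on what the backend wiring above does: once Settings.embed_model and Settings.llm are set, every index built afterwards embeds documents with the Instructor model and answers queries through the custom Gemma wrapper. Below is a minimal sketch of how that is typically exercised; it is not part of this commit, and it assumes llama_index.core plus a hypothetical ./docs folder of text files.

# Sketch only (assumption, not code from this repo): query through the
# globally configured Settings.llm and Settings.embed_model.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

documents = SimpleDirectoryReader("./docs").load_data()     # hypothetical input folder
index = VectorStoreIndex.from_documents(documents)          # embeds via Settings.embed_model
query_engine = index.as_query_engine()                      # completes via Settings.llm (GemmaLLMInterface)
print(query_engine.query("What do these documents cover?"))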
interface.py CHANGED
@@ -36,17 +36,22 @@ class GemmaLLMInterface(CustomLLM):
         outputs = self.model.generate(**inputs, max_length=self.num_output)
         response = self.model.tokenizer.decode(outputs[0], skip_special_tokens=True)
         response = response[len(prompt):].strip()
-        # Ensure we always return a non-empty response
         return CompletionResponse(text=response if response else "No response generated.")
 
     @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-        full_response = self.complete(prompt).text
-        if not full_response:
+        prompt = self._format_prompt(prompt)
+        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        streamed_response = ""
+        for output in self.model.generate(**inputs, max_length=self.num_output, streaming=True):
+            new_token = self.model.tokenizer.decode(output[0], skip_special_tokens=True)
+            if new_token:
+                streamed_response += new_token
+                yield CompletionResponse(text=streamed_response, delta=new_token)
+
+        if not streamed_response:
             yield CompletionResponse(text="No response generated.", delta="No response generated.")
-        else:
-            for token in full_response:
-                yield CompletionResponse(text=token, delta=token)
 
 # for transformers 1
 """class GemmaLLMInterface(CustomLLM):