changed class interface
Files changed:
- backend.py    +0 -5
- interface.py  +11 -6
backend.py
CHANGED
@@ -28,19 +28,14 @@ model = AutoModelForCausalLM.from_pretrained(
     token=True)
 
 model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-
 model.eval()
 
-#from accelerate import disk_offload
-#disk_offload(model=model, offload_dir="offload")
-
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 
 Settings.llm = GemmaLLMInterface(model=model)
 #Settings.llm = GemmaLLMInterface(model_name=model_id)
 
-
 ############################---------------------------------
 
 # Get the parser
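For reference, the hunk above sits in a backend.py that loads Gemma with transformers and registers it as both the LlamaIndex LLM and the embedding model. A minimal, self-contained sketch of that setup follows; the import paths for Settings, InstructorEmbedding and GemmaLLMInterface, the model_id variable, and the dtype/device_map arguments are assumptions, since they are not visible in this hunk.

# Minimal sketch of the surrounding backend.py setup (assumptions noted inline).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from llama_index.core import Settings
# Assumption: InstructorEmbedding comes from the llama-index instructor integration,
# and GemmaLLMInterface is the CustomLLM defined in interface.py of this Space.
from llama_index.embeddings.instructor import InstructorEmbedding
from interface import GemmaLLMInterface

model_id = "google/gemma-2b-it"  # assumption: same checkpoint as the tokenizer below
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,   # assumption: dtype/device placement not shown in the hunk
    device_map="auto",
    token=True,
)
model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model.eval()

# what models will be used by LlamaIndex:
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model)

With this in place, any LlamaIndex component that resolves Settings.llm or Settings.embed_model picks up the Gemma interface and the Instructor embeddings without further wiring.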
interface.py
CHANGED
@@ -36,17 +36,22 @@ class GemmaLLMInterface(CustomLLM):
         outputs = self.model.generate(**inputs, max_length=self.num_output)
         response = self.model.tokenizer.decode(outputs[0], skip_special_tokens=True)
         response = response[len(prompt):].strip()
-        # Ensure we always return a non-empty response
         return CompletionResponse(text=response if response else "No response generated.")
 
     @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
-
-
+        prompt = self._format_prompt(prompt)
+        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        streamed_response = ""
+        for output in self.model.generate(**inputs, max_length=self.num_output, streaming=True):
+            new_token = self.model.tokenizer.decode(output[0], skip_special_tokens=True)
+            if new_token:
+                streamed_response += new_token
+                yield CompletionResponse(text=streamed_response, delta=new_token)
+
+        if not streamed_response:
             yield CompletionResponse(text="No response generated.", delta="No response generated.")
-        else:
-            for token in full_response:
-                yield CompletionResponse(text=token, delta=token)
 
 # for transformers 1
 """class GemmaLLMInterface(CustomLLM):
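One caveat on the new stream_complete: as far as I know, transformers' generate() has no streaming=True keyword; incremental decoding is usually done with a TextIteratorStreamer fed by generate() running in a background thread. Below is a sketch of the method along those lines, assuming the same attributes used in interface.py (self.model, self.num_output, self._format_prompt) and llama-index 0.10-style import paths; it is a possible alternative, not the committed code.

# Sketch: alternative stream_complete for GemmaLLMInterface using TextIteratorStreamer.
# Meant to be pasted into the class body; shown unindented here for readability.
from threading import Thread
from typing import Any

from transformers import TextIteratorStreamer
from llama_index.core.llms import CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback

@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
    prompt = self._format_prompt(prompt)
    inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.model.device)

    # Decoded text chunks arrive here as generate() produces tokens in another thread.
    streamer = TextIteratorStreamer(
        self.model.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    Thread(
        target=self.model.generate,
        kwargs=dict(**inputs, max_length=self.num_output, streamer=streamer),
    ).start()

    streamed_response = ""
    for new_token in streamer:
        if new_token:
            streamed_response += new_token
            yield CompletionResponse(text=streamed_response, delta=new_token)

    if not streamed_response:
        yield CompletionResponse(text="No response generated.", delta="No response generated.")

Running generate() in its own thread keeps token production off the consumer's path, so Gradio (or any caller iterating the generator) can render each delta as soon as the streamer yields it.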