gufett0 committed on
Commit b210fbe · 1 Parent(s): aac5496

removed huggingface_hub

Files changed (2)
  1. app.py +45 -0
  2. backend.py +2 -0
app.py CHANGED
@@ -16,6 +16,51 @@ from llama_cpp import Llama
 import spaces
 
 
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse, CompletionResponseGen
+from llama_index.core.llms.callbacks import llm_completion_callback
+from typing import Any
+
+
+class GemmaLLMInterface(CustomLLM):
+    model: Any
+    tokenizer: Any
+    context_window: int = 8192
+    num_output: int = 2048
+    model_name: str = "gemma_2"
+
+    class Config:
+        protected_namespaces = ()
+
+    def _format_prompt(self, message: str) -> str:
+        return (
+            f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+        )
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        # Get LLM metadata.
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.num_output,
+            model_name=self.model_name,
+        )
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        prompt = self._format_prompt(prompt)
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+        output = self.model.generate(**inputs, max_length=self.num_output)
+        raw_response = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        # Strip the echoed prompt so only the model's reply is returned.
+        response = raw_response[len(prompt):]
+        return CompletionResponse(text=response)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        response = self.complete(prompt).text
+        for token in response:
+            yield CompletionResponse(text=token)
+
 #huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
 
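Below is a minimal usage sketch, not part of this commit, of how GemmaLLMInterface could be registered as the default llama_index LLM. The checkpoint ID google/gemma-2-2b-it and the sample query are illustrative assumptions, not taken from this repository:

# Hypothetical wiring sketch (assumed checkpoint ID, not from this commit).
from transformers import AutoTokenizer, AutoModelForCausalLM
from llama_index.core import Settings

model_id = "google/gemma-2-2b-it"  # assumption: any Gemma-2 chat checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Register the custom interface so llama_index components use it by default.
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)

print(Settings.llm.complete("What does this Space do?").text)

Note that stream_complete first runs a full complete call and then yields the finished text one character at a time, so the streaming is cosmetic rather than true token-by-token generation.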
backend.py CHANGED
@@ -13,6 +13,8 @@ from llama_cpp import Llama
 import spaces
 
 
+
+
 #huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
 