gufett0 committed
Commit bbe77cb · Parent: 975ddfc

trying keras

Files changed (3):
  1. backend.py +16 -46
  2. interface.py +38 -2
  3. requirements.txt +3 -0
backend.py CHANGED
@@ -14,32 +14,34 @@ from typing import Iterator, List, Any
 from llama_index.core.chat_engine import CondensePlusContextChatEngine
 from llama_index.core.llms import ChatMessage, MessageRole, CompletionResponse
 from IPython.display import Markdown, display
+# Set the backend before importing Keras.
+os.environ["KERAS_BACKEND"] = "jax"
+# Avoid memory fragmentation on the JAX backend.
+os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
+import keras
+import keras_nlp
 #from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 #from llama_index import LangchainEmbedding, ServiceContext
 
-huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-login(huggingface_token)
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+os.getenv("KAGGLE_USERNAME")
+os.getenv("KAGGLE_KEY")
 
-"""model_id = "google/gemma-2-2b-it"
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    token=True)
-
-model.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
-model.eval()"""
+"""huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+login(huggingface_token)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")"""
+
+# Load Gemma through Keras
+gemma_model_id = "gemma2_instruct_2b_en"
+gemma = keras_nlp.models.GemmaCausalLM.from_preset(gemma_model_id)
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 #Settings.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
 
-Settings.llm = GemmaLLMInterface()
+#Settings.llm = GemmaLLMInterface()
+Settings.llm = GemmaLLMInterface(model=gemma)
 
 documents_paths = {
     'blockchain': 'data/blockchainprova.txt',
@@ -47,7 +49,6 @@ documents_paths = {
     'payment': 'data/paymentprova.txt'
 }
 
-
 global session_state
 session_state = {"index": False,
                  "documents_loaded": False,
@@ -97,7 +98,6 @@ def handle_query(query_str: str,
 
     index = build_index("data/blockchainprova.txt")
 
-
     conversation: List[ChatMessage] = []
     for user, assistant in chat_history:
         conversation.extend([
@@ -106,33 +106,6 @@
         ]
     )
 
-    """if not session_state["index"]:
-        matched_path = None
-        words = query_str.lower()
-        for key, path in documents_paths.items():
-            if key in words:
-                matched_path = path
-                break
-        if matched_path:
-            index = build_index(matched_path)
-            gr.Info("index built from the path matched in the query")
-            session_state["index"] = True
-        else:  ## ASK FOR CLARIFICATION
-            conversation.append(ChatMessage(role=MessageRole.SYSTEM, content=ISTR))
-            index = build_index("data/blockchainprova.txt")
-            gr.Info("index built with a request for clarification")
-    else:
-        index = build_index(matched_path)
-        #storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
-        #index = load_index_from_storage(storage_context)
-        gr.Info("index is true")"""
 
     try:
@@ -175,9 +148,6 @@ def handle_query(query_str: str,
     print(info_message)
     gr.Info(info_message)"""
 
-
-    #prompts_dict = chat_engine.get_prompts()
-    #display_prompt_dict(prompts_dict)
 
 
     #chat_engine.reset()
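
A minimal sketch of the load path the new code relies on (not part of the commit; the prompt and variable names are illustrative): KERAS_BACKEND only takes effect if it is set before the first `import keras`, and keras_nlp resolves the gemma2_instruct_2b_en preset via Kaggle, reading KAGGLE_USERNAME and KAGGLE_KEY from the environment rather than from the discarded os.getenv() return values.

import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before `import keras`
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras_nlp  # backend is pinned by now

# Assumes KAGGLE_USERNAME / KAGGLE_KEY are already exported (e.g. as Space secrets).
gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
print(gemma.generate("What is a blockchain?", max_length=64))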
interface.py CHANGED
@@ -6,9 +6,11 @@ import torch
 from transformers import TextIteratorStreamer
 from threading import Thread
 from pydantic import Field, field_validator
+import keras
+import keras_nlp
 
 # for transformers 2 (__setattr__ is used to bypass the Pydantic check)
-class GemmaLLMInterface(CustomLLM):
+"""class GemmaLLMInterface(CustomLLM):
     def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
         super().__init__(**kwargs)
         object.__setattr__(self, "model_id", model_id)
@@ -65,5 +67,39 @@ class GemmaLLMInterface(CustomLLM):
             yield CompletionResponse(text=streamed_response, delta=new_text)
 
         if not streamed_response:
-            yield CompletionResponse(text="No response generated.", delta="No response generated.")
+            yield CompletionResponse(text="No response generated.", delta="No response generated.")"""
+
+
+class GemmaLLMInterface(CustomLLM):
+    model: keras_nlp.models.GemmaCausalLM = None
+    context_window: int = 8192
+    num_output: int = 2048
+    model_name: str = "gemma_2"
+
+    def _format_prompt(self, message: str) -> str:
+        return (
+            f"<start_of_turn>user\n{message}<end_of_turn>\n"
+            f"<start_of_turn>model\n"
+        )
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.num_output,
+            model_name=self.model_name,
+        )
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        prompt = self._format_prompt(prompt)
+        raw_response = self.model.generate(prompt, max_length=self.num_output)
+        # generate() returns the prompt plus the completion; strip the prompt.
+        response = raw_response[len(prompt):]
+        return CompletionResponse(text=response)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        full_response = self.complete(prompt).text
+        # Accumulate into a separate string rather than appending to the
+        # sequence being iterated, so each yield carries the text so far
+        # plus a one-character delta.
+        streamed = ""
+        for token in full_response:
+            streamed += token
+            yield CompletionResponse(text=streamed, delta=token)
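
For reference, a usage sketch of the rewritten interface (illustrative, assuming the preset download succeeds; the prompts are placeholders): GemmaCausalLM.generate() returns the prompt concatenated with the completion, which is why complete() slices the prompt off the front before wrapping the text.

import keras_nlp
from interface import GemmaLLMInterface

gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")
llm = GemmaLLMInterface(model=gemma)

# One-shot completion.
print(llm.complete("Explain smart contracts in one sentence.").text)

# Streaming: each chunk carries the accumulated text and a one-character delta.
for chunk in llm.stream_complete("What is a payment channel?"):
    print(chunk.delta, end="", flush=True)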
requirements.txt CHANGED
@@ -13,6 +13,9 @@ setuptools
 spaces
 pydantic
 ipython
+keras
+keras-nlp
+tensorflow
 #langchain
 #langchain-community
 #langchain_huggingface
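
Worth noting: tensorflow stays in the list even though the Keras backend is set to JAX, since keras-nlp's tokenizers and preprocessing layers are built on tf.data and need TensorFlow at runtime regardless of the compute backend.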