trying keras

- backend.py +16 -46
- interface.py +38 -2
- requirements.txt +3 -0
backend.py
CHANGED
@@ -14,32 +14,34 @@ from typing import Iterator, List, Any
 from llama_index.core.chat_engine import CondensePlusContextChatEngine
 from llama_index.core.llms import ChatMessage, MessageRole, CompletionResponse
 from IPython.display import Markdown, display
+import keras
+import keras_nlp
 #from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 #from llama_index import LangchainEmbedding, ServiceContext
 
-
-
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# Set the backend before importing Keras
+os.environ["KERAS_BACKEND"] = "jax"
+# Avoid memory fragmentation on JAX backend.
+os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
 
-"""
+
+os.getenv("KAGGLE_USERNAME")
+os.getenv("KAGGLE_KEY")
 
-    device_map="auto",
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    token=True)
+"""huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+login(huggingface_token)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")"""
 
-
-
+# Load Gemma using Keras
+gemma_model_id = "gemma2_instruct_2b_en"
+gemma = keras_nlp.models.GemmaCausalLM.from_preset(gemma_model_id)
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
 #Settings.embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2'))
 
-
-Settings.llm = GemmaLLMInterface()
+#Settings.llm = GemmaLLMInterface()
+Settings.llm = GemmaLLMInterface(model=gemma)
 
 documents_paths = {
     'blockchain': 'data/blockchainprova.txt',
@@ -47,7 +49,6 @@ documents_paths = {
     'payment': 'data/paymentprova.txt'
 }
 
-
 global session_state
 session_state = {"index": False,
                  "documents_loaded": False,
@@ -97,7 +98,6 @@ def handle_query(query_str: str,
 
     index = build_index("data/blockchainprova.txt")
 
-
     conversation: List[ChatMessage] = []
     for user, assistant in chat_history:
         conversation.extend([
@@ -106,33 +106,6 @@ def handle_query(query_str: str,
         ]
         )
 
-    """if not session_state["index"]:
-
-        matched_path = None
-        words = query_str.lower()
-        for key, path in documents_paths.items():
-            if key in words:
-                matched_path = path
-                break
-        if matched_path:
-            index = build_index(matched_path)
-            gr.Info("index built from the path matched to the query")
-            session_state["index"] = True
-
-        else:  ## ASK FOR CLARIFICATION
-
-            conversation.append(ChatMessage(role=MessageRole.SYSTEM, content=ISTR))
-
-            index = build_index("data/blockchainprova.txt")
-            gr.Info("index built after asking for clarification")
-
-
-    else:
-
-        index = build_index(matched_path)
-        #storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
-        #index = load_index_from_storage(storage_context)
-        gr.Info("index is true")"""
 
     try:
@@ -175,9 +148,6 @@ def handle_query(query_str: str,
     print(info_message)
     gr.Info(info_message)"""
 
-
-    #prompts_dict = chat_engine.get_prompts()
-    #display_prompt_dict(prompts_dict)
 
 
     #chat_engine.reset()
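One caveat about this hunk: Keras only reads KERAS_BACKEND once, at import time, yet the commit places "import keras" / "import keras_nlp" before the environment variables that select the JAX backend. A minimal sketch of the intended ordering (a hypothetical standalone setup, not the commit itself; it assumes KAGGLE_USERNAME and KAGGLE_KEY are provided as Space secrets so keras_nlp can download the Gemma preset from Kaggle):

import os

# Select the backend BEFORE importing Keras; it is read once at import time.
os.environ["KERAS_BACKEND"] = "jax"
# Avoid memory fragmentation on the JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras_nlp

# The Gemma preset is fetched from Kaggle, so credentials must be present.
assert os.getenv("KAGGLE_USERNAME") and os.getenv("KAGGLE_KEY"), "Kaggle credentials missing"

gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")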
interface.py
CHANGED
@@ -6,9 +6,11 @@ import torch
 from transformers import TextIteratorStreamer
 from threading import Thread
 from pydantic import Field, field_validator
+import keras
+import keras_nlp
 
 # for transformers 2 (__setattr__ is used to bypass the Pydantic check)
-class GemmaLLMInterface(CustomLLM):
+"""class GemmaLLMInterface(CustomLLM):
     def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
         super().__init__(**kwargs)
         object.__setattr__(self, "model_id", model_id)
@@ -65,5 +67,39 @@ class GemmaLLMInterface(CustomLLM):
         yield CompletionResponse(text=streamed_response, delta=new_text)
 
         if not streamed_response:
-            yield CompletionResponse(text="No response generated.", delta="No response generated.")
+            yield CompletionResponse(text="No response generated.", delta="No response generated.")"""
+
+
+class GemmaLLMInterface(CustomLLM):
+    model: keras_nlp.models.GemmaCausalLM = None
+    context_window: int = 8192
+    num_output: int = 2048
+    model_name: str = "gemma_2"
+
+    def _format_prompt(self, message: str) -> str:
+        return (
+            f"<start_of_turn>user\n{message}<end_of_turn>\n"
+            f"<start_of_turn>model\n"
+        )
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.num_output,
+            model_name=self.model_name,
+        )
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        prompt = self._format_prompt(prompt)
+        raw_response = self.model.generate(prompt, max_length=self.num_output)
+        # Strip the echoed prompt so only the model's answer remains.
+        response = raw_response[len(prompt):]
+        return CompletionResponse(text=response)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        # Generate the full response first, then re-emit it incrementally.
+        response = self.complete(prompt).text
+        streamed = ""
+        for token in response:
+            streamed += token
+            yield CompletionResponse(text=streamed, delta=token)
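The replacement class wraps the Keras model behind LlamaIndex's CustomLLM API. Note that stream_complete is pseudo-streaming: it first generates the entire answer via complete(), then replays it character by character, since keras_nlp's generate() is not given a token callback here. A minimal usage sketch (assumptions: the class lives in interface.py and gemma is the model loaded in backend.py):

from interface import GemmaLLMInterface

llm = GemmaLLMInterface(model=gemma)
print(llm.metadata)  # context_window=8192, num_output=2048, model_name='gemma_2'

# complete() returns the whole answer in one shot...
print(llm.complete("What is a blockchain?").text)

# ...while stream_complete() replays it character by character.
for chunk in llm.stream_complete("What is a blockchain?"):
    print(chunk.delta, end="", flush=True)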
requirements.txt
CHANGED
@@ -13,6 +13,9 @@ setuptools
 spaces
 pydantic
 ipython
+keras
+keras-nlp
+tensorflow
 #langchain
 #langchain-community
 #langchain_huggingface
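tensorflow is listed presumably because keras-nlp uses tf.data for preprocessing even on the JAX backend; a JAX install may also be needed for KERAS_BACKEND="jax" to work, unless the Space image already ships it. A quick post-install sanity check (a sketch; keras.backend.backend() reports the active backend in Keras 3):

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported

import keras
print(keras.backend.backend())  # expected: "jax"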