changed class interface with iterator
- backend.py +2 -37
- interface.py +4 -4
backend.py CHANGED
@@ -34,9 +34,7 @@ model.eval()
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
-
 Settings.llm = GemmaLLMInterface()
-#Settings.llm = GemmaLLMInterface(model_name=model_id)
 
 ############################---------------------------------
 
@@ -60,43 +58,8 @@ def build_index():
 def handle_query(query_str, chathistory) -> Iterator[str]:
 
     index = build_index()
-
-    qa_prompt_str = (
-        "Context information is below.\n"
-        "---------------------\n"
-        "{context_str}\n"
-        "---------------------\n"
-        "Given the context information and not prior knowledge, "
-        "answer the question: {query_str}\n"
-    )
-
-    # Text QA Prompt
-    chat_text_qa_msgs = [
-        (
-            "system",
-            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
-        ),
-        ("user", qa_prompt_str),
-    ]
-    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
 
     try:
-        # Create a streaming query engine
-        """query_engine = index.as_query_engine(text_qa_template=text_qa_template, streaming=False, similarity_top_k=1)
-
-        # Execute the query
-        streaming_response = query_engine.query(query_str)
-
-        r = streaming_response.response
-        cleaned_result = r.replace("<end_of_turn>", "").strip()
-        yield cleaned_result"""
-
-        # Stream the response
-        """outputs = []
-        for text in streaming_response.response_gen:
-
-            outputs.append(str(text))
-            yield "".join(outputs)"""
 
         memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
         chat_engine = index.as_chat_engine(
@@ -112,6 +75,8 @@ def handle_query(query_str, chathistory) -> Iterator[str]:
         response = chat_engine.stream_chat(query_str)
         #response = chat_engine.chat(query_str)
         for token in response.response_gen:
+            if not token.startswith("system:") and not token.startswith("user:"):
+
             outputs.append(str(token))
             print(f"Generated token: {token}")
             yield "".join(outputs)
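With this change, handle_query is a plain generator: it yields the accumulated answer after every streamed token and skips tokens that echo a "system:" or "user:" role prefix. A minimal usage sketch follows; the Gradio wiring and the module layout are assumptions for illustration, not part of this commit.

# Minimal usage sketch (assumption, not part of this commit): wiring the
# streaming generator into a Gradio ChatInterface. handle_query yields the
# accumulated answer after each token, so the UI can redraw it as it grows.
import gradio as gr

from backend import handle_query  # assumed module layout

demo = gr.ChatInterface(fn=handle_query)  # fn(message, history) -> Iterator[str]

if __name__ == "__main__":
    # Outside Gradio, the generator can also be drained directly:
    # for partial in handle_query("What do the indexed documents say?", []):
    #     print(partial)
    demo.launch()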
interface.py CHANGED
@@ -7,19 +7,19 @@ from transformers import TextIteratorStreamer
 from threading import Thread
 from pydantic import Field, field_validator
 
-# for transformers 2
+# for transformers 2 (__setattr__ is used to bypass the Pydantic check)
 class GemmaLLMInterface(CustomLLM):
     def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
         super().__init__(**kwargs)
-        object.__setattr__(self, "model_id", model_id)
+        object.__setattr__(self, "model_id", model_id)
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             device_map="auto",
             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
         )
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        object.__setattr__(self, "model", model)
-        object.__setattr__(self, "tokenizer", tokenizer)
+        object.__setattr__(self, "model", model)
+        object.__setattr__(self, "tokenizer", tokenizer)
         object.__setattr__(self, "context_window", 8192)
         object.__setattr__(self, "num_output", 2048)
 
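The object.__setattr__ calls exist because LlamaIndex's CustomLLM is a Pydantic model, and Pydantic rejects assignment of attributes that are not declared fields. A self-contained sketch of that pattern; the Demo class and its attributes are illustrative only, not the Space's code.

# Illustrative sketch of the __setattr__ pattern used above (not the Space's code).
# A Pydantic model refuses `self.x = ...` for attributes that are not declared
# fields; object.__setattr__ goes around Pydantic's __setattr__ and stores the
# value on the instance anyway.
from pydantic import BaseModel


class Demo(BaseModel):
    declared: int = 0

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # self.undeclared = "hello"   # raises: "Demo" object has no field "undeclared"
        object.__setattr__(self, "undeclared", "hello")  # bypasses the check


d = Demo()
print(d.undeclared)  # -> hello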