gufett0 committed on
Commit cf360c7 · 1 Parent(s): 86b68c0

changed class interface with iterator

Files changed (2)
  1. backend.py +2 -37
  2. interface.py +4 -4
backend.py CHANGED

@@ -34,9 +34,7 @@ model.eval()
 
 # what models will be used by LlamaIndex:
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
-
 Settings.llm = GemmaLLMInterface()
-#Settings.llm = GemmaLLMInterface(model_name=model_id)
 
 ############################---------------------------------
 
@@ -60,43 +58,8 @@ def build_index():
 def handle_query(query_str, chathistory) -> Iterator[str]:
 
     index = build_index()
-
-    qa_prompt_str = (
-        "Context information is below.\n"
-        "---------------------\n"
-        "{context_str}\n"
-        "---------------------\n"
-        "Given the context information and not prior knowledge, "
-        "answer the question: {query_str}\n"
-    )
-
-    # Text QA Prompt
-    chat_text_qa_msgs = [
-        (
-            "system",
-            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
-        ),
-        ("user", qa_prompt_str),
-    ]
-    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
 
     try:
-        # Create a streaming query engine
-        """query_engine = index.as_query_engine(text_qa_template=text_qa_template, streaming=False, similarity_top_k=1)
-
-        # Execute the query
-        streaming_response = query_engine.query(query_str)
-
-        r = streaming_response.response
-        cleaned_result = r.replace("<end_of_turn>", "").strip()
-        yield cleaned_result"""
-
-        # Stream the response
-        """outputs = []
-        for text in streaming_response.response_gen:
-
-            outputs.append(str(text))
-            yield "".join(outputs)"""
 
         memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
         chat_engine = index.as_chat_engine(
@@ -112,6 +75,8 @@ def handle_query(query_str, chathistory) -> Iterator[str]:
         response = chat_engine.stream_chat(query_str)
         #response = chat_engine.chat(query_str)
         for token in response.response_gen:
+            if not token.startswith("system:") and not token.startswith("user:"):
+
             outputs.append(str(token))
             print(f"Generated token: {token}")
             yield "".join(outputs)
interface.py CHANGED
@@ -7,19 +7,19 @@ from transformers import TextIteratorStreamer
 from threading import Thread
 from pydantic import Field, field_validator
 
-# for transformers 2
+# for transformers 2 (__setattr__ is used to bypass Pydantic check )
 class GemmaLLMInterface(CustomLLM):
     def __init__(self, model_id: str = "google/gemma-2-2b-it", **kwargs):
         super().__init__(**kwargs)
-        object.__setattr__(self, "model_id", model_id)  # Bypass Pydantic for model_id
+        object.__setattr__(self, "model_id", model_id)
         model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
         )
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        object.__setattr__(self, "model", model)  # Bypass Pydantic for model
-        object.__setattr__(self, "tokenizer", tokenizer)  # Bypass Pydantic for tokenizer
+        object.__setattr__(self, "model", model)
+        object.__setattr__(self, "tokenizer", tokenizer)
         object.__setattr__(self, "context_window", 8192)
         object.__setattr__(self, "num_output", 2048)
 
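
The interface.py hunk only moves the explanation of the object.__setattr__ trick into the class-level comment; the behavior is unchanged. For context: CustomLLM is a Pydantic model, so a plain self.model = ... assignment in __init__ goes through Pydantic's __setattr__ and is rejected for undeclared fields, while object.__setattr__ writes straight into the instance. A standalone sketch with a plain pydantic BaseModel (a stand-in for illustration, not the real CustomLLM):

from pydantic import BaseModel

class Sketch(BaseModel):
    name: str = "demo"

s = Sketch()
try:
    s.weights = "..."  # rejected: "weights" is not a declared field
except (ValueError, AttributeError) as err:  # exact exception varies by pydantic version
    print(err)

object.__setattr__(s, "weights", "...")  # skips Pydantic's __setattr__ entirely
print(s.weights)  # attribute now lives on the instance, unvalidated
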