gufett0 committed
Commit f7aeb1e · Parent: e7fc85b

added new class

Files changed (3):
  1. backend.py +32 -17
  2. interface.py +65 -2
  3. requirements.txt +1 -1
backend.py CHANGED
@@ -33,9 +33,10 @@ model.eval()
 #disk_offload(model=model, offload_dir="offload")
 
 # what models will be used by LlamaIndex:
+"""Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
+Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)"""
 Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
-Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
-#Settings.llm = llm
+Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
 
 
 ############################---------------------------------
@@ -57,7 +58,8 @@ def build_index():
 
 
 @spaces.GPU(duration=20)
-async def handle_query(query_str, chathistory):
+def handle_query(query_str, chathistory):
+
     index = build_index()
 
     qa_prompt_str = (
@@ -73,32 +75,45 @@ async def handle_query(query_str, chathistory):
     chat_text_qa_msgs = [
         (
             "system",
-            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti.",
+            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
         ),
         ("user", qa_prompt_str),
     ]
     text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
 
     try:
+        # Create a streaming query engine
+        """query_engine = index.as_query_engine(text_qa_template=text_qa_template, streaming=False, similarity_top_k=1)
+
+        # Execute the query
+        streaming_response = query_engine.query(query_str)
+
+        r = streaming_response.response
+        cleaned_result = r.replace("<end_of_turn>", "").strip()
+        yield cleaned_result"""
+
+        # Stream the response
+        """outputs = []
+        for text in streaming_response.response_gen:
+
+            outputs.append(str(text))
+            yield "".join(outputs)"""
+
         memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
         chat_engine = index.as_chat_engine(
-            chat_mode="context",
-            memory=memory,
-            system_prompt=(
-                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti."
-            ),
+            chat_mode="context",
+            memory=memory,
+            system_prompt=(
+                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
+            ),
         )
 
-        # Stream the response
         response = chat_engine.stream_chat(query_str)
-        outputs = []
-
-        async for token in response.response_gen:
-            outputs.append(token)
-            yield "".join(outputs)
+        #response = chat_engine.chat(query_str)
+        for token in response.response_gen:
+            yield token
+
 
-    except StopAsyncIteration:
-        yield "No more responses to stream."
     except Exception as e:
         yield f"Error processing query: {str(e)}"
 
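Not shown in this commit is how handle_query is consumed. The sketch below is a purely illustrative Gradio wiring: the app.py file name, the gr.ChatInterface setup and the title are assumptions, not part of this diff. Because Gradio renders the value of each yield as the whole reply so far, the sketch accumulates the tokens yielded by handle_query into a running string.

# app.py (hypothetical, not part of this commit): minimal Gradio wiring for the
# handle_query generator defined in backend.py above.
import gradio as gr

from backend import handle_query


def chat_fn(message, history):
    # Gradio displays each yielded value as the reply so far, so the individual
    # tokens yielded by handle_query are accumulated into one string here.
    partial = ""
    for token in handle_query(message, history):
        partial += token
        yield partial


demo = gr.ChatInterface(fn=chat_fn, title="Ossy")

if __name__ == "__main__":
    demo.launch()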
interface.py CHANGED
@@ -6,8 +6,71 @@ import torch
 from transformers import TextIteratorStreamer
 from threading import Thread
 
-
+# for transformers 2
 class GemmaLLMInterface(CustomLLM):
+    def __init__(self, model_id: str = "google/gemma-2-2b-it", context_window: int = 8192, num_output: int = 2048):
+        self.model_id = model_id
+        self.context_window = context_window
+        self.num_output = num_output
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto",
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+        )
+        self.model.eval()
+
+    def _format_prompt(self, message: str) -> str:
+        return f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            context_window=self.context_window,
+            num_output=self.num_output,
+            model_name=self.model_id,
+        )
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        formatted_prompt = self._format_prompt(prompt)
+        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.num_output,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+            )
+
+        response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+        return CompletionResponse(text=response)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        formatted_prompt = self._format_prompt(prompt)
+        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
+
+        response = ""
+        with torch.no_grad():
+            for output in self.model.generate(
+                **inputs,
+                max_new_tokens=self.num_output,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+                streamer=True,
+            ):
+                token = self.tokenizer.decode(output, skip_special_tokens=True)
+                response += token
+                yield CompletionResponse(text=response, delta=token)
+
+
+    # for transformers 1
+    """class GemmaLLMInterface(CustomLLM):
     model: Any
     tokenizer: Any
     context_window: int = 8192
@@ -76,4 +139,4 @@ class GemmaLLMInterface(CustomLLM):
             for new_token in streamer:
                 yield CompletionResponse(text=new_token)
         except StopIteration:
-            return
+            return"""
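In transformers, the streamer argument of model.generate expects a streamer object (such as the TextIteratorStreamer already imported at the top of interface.py) rather than a boolean, and generate returns finished sequences instead of yielding tokens, so a token-by-token stream_complete usually pairs TextIteratorStreamer with a background Thread. Below is a minimal sketch of that variant, not the committed implementation; the StreamingGemmaLLMInterface name is invented here, and it assumes the self.model, self.tokenizer, self.num_output and self._format_prompt members defined by the class above.

# Streaming variant sketch only, under the assumptions stated above.
from threading import Thread
from typing import Any

from transformers import TextIteratorStreamer
from llama_index.core.llms import CompletionResponse, CompletionResponseGen
from llama_index.core.llms.callbacks import llm_completion_callback

from interface import GemmaLLMInterface  # the class added in this commit


class StreamingGemmaLLMInterface(GemmaLLMInterface):
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        formatted_prompt = self._format_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)

        # TextIteratorStreamer turns generation output into an iterator of decoded text chunks.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=self.num_output,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )

        # generate() blocks until the sequence is finished, so run it in a
        # background thread and consume the streamer on this one.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        response = ""
        for new_text in streamer:
            response += new_text
            yield CompletionResponse(text=response, delta=new_text)
        thread.join()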
requirements.txt CHANGED
@@ -6,7 +6,7 @@ llama-index-embeddings-instructor
 sentence-transformers==2.2.2
 llama-index-readers-web
 llama-index-readers-file
-gradio==4.17.0
+gradio
 transformers
 llama-cpp-agent>=0.2.25
 setuptools