gufett0 committed
Commit 1c8dd0f · 1 Parent(s): 708da42

added text iterator

Files changed (1): backend.py (+24 -17)
backend.py CHANGED
@@ -13,26 +13,13 @@ from llama_cpp import Llama
import spaces
from huggingface_hub import login

+ from transformers import TextIteratorStreamer
+ import threading


huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

- """hf_hub_download(
-     repo_id="google/gemma-2-2b-it-GGUF",
-     filename="2b_it_v2.gguf",
-     local_dir="./models",
-     token=huggingface_token
- )
-
- llm = Llama(
-     model_path=f"models/2b_it_v2.gguf",
-     flash_attn=True,
-     _gpu_layers=81,
-     n_batch=1024,
-     n_ctx=8192,
- )"""
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
@@ -85,8 +72,28 @@ def handle_query(query_str, chathistory):
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
-
+
+     # Create the query engine
+     query_engine = index.as_query_engine(text_qa_template=text_qa_template)
+
    try:
+         # Setup the TextIteratorStreamer for streaming the response
+         streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+         # Create a thread to run the generation in the background
+         def generate_response():
+             query_engine.query(query_str, streamer=streamer)
+
+         generation_thread = threading.Thread(target=generate_response)
+         generation_thread.start()
+
+         # Stream tokens as they are generated
+         for new_text in streamer:
+             yield new_text
+     except Exception as e:
+         yield f"Error processing query: {str(e)}"
+
+     """ try:
        result = index.as_query_engine(text_qa_template=text_qa_template).query(query_str)
        response_text = result.response

@@ -95,7 +102,7 @@ def handle_query(query_str, chathistory):

        yield cleaned_result
    except Exception as e:
-         yield f"Error processing query: {str(e)}"
+         yield f"Error processing query: {str(e)}""""



 
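For reference, the pattern `TextIteratorStreamer` is designed for in Transformers runs `model.generate()` in a background thread while the main thread iterates the streamer; the diff above instead passes `streamer=` through `query_engine.query()`, and whether LlamaIndex forwards that keyword to the model is not shown here. Below is a minimal standalone sketch of the plain Transformers pattern, assuming the commit's `google/gemma-2-2b-it` checkpoint; the `stream_answer` helper, prompt handling, and generation parameters are illustrative assumptions, not code from backend.py:

```python
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "google/gemma-2-2b-it"  # same checkpoint as in the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

def stream_answer(prompt: str):
    # Hypothetical helper, not part of backend.py.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed input out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until completion, so it runs in a worker thread
    # while this generator yields decoded text chunks as they arrive.
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256},
    )
    thread.start()
    for new_text in streamer:
        yield new_text
    thread.join()
```

Recent versions of LlamaIndex also expose a native streaming path that avoids the manual thread: `index.as_query_engine(streaming=True, text_qa_template=...)` returns streaming responses whose `response_gen` can be iterated chunk by chunk.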