Spaces:
Running
Running
pseudotensor
committed on
Commit
•
8f3dc34
1
Parent(s):
8cb62ff
Update with h2oGPT hash 3e927fb6330dd3d1256b47eb201bd376230dd20a
Browse files
- generate.py +3 -7
- utils.py +0 -50
generate.py
CHANGED
@@ -3,8 +3,9 @@ import sys
|
|
3 |
import os
|
4 |
import traceback
|
5 |
import typing
|
|
|
6 |
|
7 |
-
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, KThread, wrapped_partial
|
8 |
|
9 |
SEED = 1236
|
10 |
set_seed(SEED)
|
@@ -828,15 +829,10 @@ def evaluate(
|
|
828 |
skip_prompt = False
|
829 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
|
830 |
gen_kwargs.update(dict(streamer=streamer))
|
831 |
-
if debug:
|
832 |
-
KThread.show_threads()
|
833 |
target_func = generate_with_exceptions
|
834 |
-
if concurrency_count == 1:
|
835 |
-
# otherwise can't do this
|
836 |
-
KThread.kill_threads(target_func.__name__, debug=debug)
|
837 |
target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
|
838 |
raise_generate_gpu_exceptions, **gen_kwargs)
|
839 |
-
thread = KThread(target=target)
|
840 |
thread.start()
|
841 |
outputs = ""
|
842 |
for new_text in streamer:
|
|
|
3 |
import os
|
4 |
import traceback
|
5 |
import typing
|
6 |
+
from threading import Thread
|
7 |
|
8 |
+
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
|
9 |
|
10 |
SEED = 1236
|
11 |
set_seed(SEED)
|
|
|
829 |
skip_prompt = False
|
830 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
|
831 |
gen_kwargs.update(dict(streamer=streamer))
|
|
|
|
|
832 |
target_func = generate_with_exceptions
|
|
|
|
|
|
|
833 |
target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
|
834 |
raise_generate_gpu_exceptions, **gen_kwargs)
|
835 |
+
thread = Thread(target=target)
|
836 |
thread.start()
|
837 |
outputs = ""
|
838 |
for new_text in streamer:
|
utils.py
CHANGED
@@ -244,56 +244,6 @@ class NullContext(threading.local):
|
|
244 |
pass
|
245 |
|
246 |
|
247 |
-
class KThread(threading.Thread):
|
248 |
-
"""Thread with a kill method."""
|
249 |
-
|
250 |
-
def __init__(self, *args, **keywords):
|
251 |
-
threading.Thread.__init__(self, *args, **keywords)
|
252 |
-
self.killed = False
|
253 |
-
|
254 |
-
def start(self):
|
255 |
-
"""Start the thread."""
|
256 |
-
self.__run_backup = self.run
|
257 |
-
self.run = self.__run # Force the Thread to install our trace.
|
258 |
-
threading.Thread.start(self)
|
259 |
-
|
260 |
-
def __run(self):
|
261 |
-
"""install trace."""
|
262 |
-
sys.settrace(self.globaltrace)
|
263 |
-
self.__run_backup()
|
264 |
-
self.run = self.__run_backup
|
265 |
-
|
266 |
-
def globaltrace(self, frame, why, arg):
|
267 |
-
if why == 'call':
|
268 |
-
return self.localtrace
|
269 |
-
else:
|
270 |
-
return None
|
271 |
-
|
272 |
-
def localtrace(self, frame, why, arg):
|
273 |
-
if self.killed:
|
274 |
-
if why == 'line':
|
275 |
-
raise SystemExit()
|
276 |
-
return self.localtrace
|
277 |
-
|
278 |
-
def kill(self):
|
279 |
-
self.killed = True
|
280 |
-
|
281 |
-
@staticmethod
|
282 |
-
def show_threads():
|
283 |
-
for thread in threading.enumerate():
|
284 |
-
print(thread.name, flush=True)
|
285 |
-
|
286 |
-
@staticmethod
|
287 |
-
def kill_threads(name, debug=False):
|
288 |
-
for thread in threading.enumerate():
|
289 |
-
if name in thread.name:
|
290 |
-
if debug:
|
291 |
-
print("Trying to kill %s %s" % (thread.ident, thread), flush=True)
|
292 |
-
thread.kill()
|
293 |
-
if debug:
|
294 |
-
print(thread, flush=True)
|
295 |
-
|
296 |
-
|
297 |
def wrapped_partial(func, *args, **kwargs):
|
298 |
"""
|
299 |
Give partial properties of normal function, like __name__ attribute etc.
|
|
|
244 |
pass
|
245 |
|
246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
def wrapped_partial(func, *args, **kwargs):
|
248 |
"""
|
249 |
Give partial properties of normal function, like __name__ attribute etc.
|