Upload 5 files
- LLM/chat.py +25 -0
- LLM/language_model.py +134 -0
- STT/lightning_whisper_mlx_handler.py +58 -0
- STT/whisper_stt_handler.py +113 -0
- baseHandler.py +51 -0
LLM/chat.py
ADDED
@@ -0,0 +1,25 @@
class Chat:
    """
    Handles the chat history, keeping only the most recent exchanges to avoid OOM issues.
    """

    def __init__(self, size):
        self.size = size
        self.init_chat_message = None
        # the maximum length must be even, since each step adds a user prompt and an assistant answer
        self.buffer = []

    def append(self, item):
        self.buffer.append(item)
        if len(self.buffer) == 2 * (self.size + 1):
            self.buffer.pop(0)
            self.buffer.pop(0)

    def init_chat(self, init_chat_message):
        self.init_chat_message = init_chat_message

    def to_list(self):
        if self.init_chat_message:
            return [self.init_chat_message] + self.buffer
        else:
            return self.buffer
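A minimal usage sketch of the Chat buffer (hypothetical example, not part of the uploaded files): with size=1 the buffer keeps roughly the last user/assistant pair, while the initial system message is always preserved.

from LLM.chat import Chat

chat = Chat(size=1)  # keep roughly the last exchange in memory
chat.init_chat({"role": "system", "content": "You are a helpful AI assistant."})

chat.append({"role": "user", "content": "Hello!"})
chat.append({"role": "assistant", "content": "Hi, how can I help?"})
chat.append({"role": "user", "content": "Tell me a joke."})
chat.append({"role": "assistant", "content": "..."})

# The oldest user/assistant pair has been evicted; the system message is kept.
print(chat.to_list())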
LLM/language_model.py
ADDED
@@ -0,0 +1,134 @@
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
)
import torch

from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
from nltk import sent_tokenize

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        warmup_gen_kwargs = {
            "min_new_tokens": self.gen_kwargs["min_new_tokens"],
            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            **self.gen_kwargs,
        }

        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            for _ in self.streamer:
                pass

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, prompt):
        logger.debug("inferring language model...")

        self.chat.append({"role": self.user_role, "content": prompt})
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()
        if self.device == "mps":
            generated_text = ""
            for new_text in self.streamer:
                generated_text += new_text
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield (sentences[0])
                    printable_text = new_text

        self.chat.append({"role": "assistant", "content": generated_text})

        # don't forget the last sentence
        yield printable_text
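A sketch of how this handler could be driven through the BaseHandler queue protocol (hypothetical wiring, not part of the uploaded files; assumes a CUDA machine and that min_new_tokens/max_new_tokens are supplied, since warmup() reads both from gen_kwargs):

from queue import Queue
from threading import Event, Thread

from LLM.language_model import LanguageModelHandler

stop_event = Event()
text_in, text_out = Queue(), Queue()

# warmup() indexes gen_kwargs["min_new_tokens"] and ["max_new_tokens"], so provide both.
handler = LanguageModelHandler(
    stop_event,
    queue_in=text_in,
    queue_out=text_out,
    setup_kwargs={
        "device": "cuda",
        "gen_kwargs": {"min_new_tokens": 10, "max_new_tokens": 128},
        "init_chat_role": "system",
    },
)

Thread(target=handler.run).start()
text_in.put("What is the capital of France?")
print(text_out.get())  # first generated sentence (or the full reply if it is one sentence)
text_in.put(b"END")    # sentinel: unblocks the queue and stops the handler thread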
STT/lightning_whisper_mlx_handler.py
ADDED
@@ -0,0 +1,58 @@
import logging
from time import perf_counter
from baseHandler import BaseHandler
from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class LightningWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
    ):
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # a single warmup step, since the MLX model is not compiled
        n_steps = 1
        dummy_input = np.array([0] * 512)

        for _ in range(n_steps):
            _ = self.model.transcribe(dummy_input)["text"].strip()

    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")

        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
        torch.mps.empty_cache()

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
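For reference, a standalone transcription sketch using the same LightningWhisperMLX calls the handler wraps (hypothetical, not part of the uploaded files; assumes an Apple Silicon machine with lightning_whisper_mlx installed and a 16 kHz mono waveform as a NumPy array):

import numpy as np
from lightning_whisper_mlx import LightningWhisperMLX

# Same constructor arguments as in setup() above.
model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at an assumed 16 kHz
text = model.transcribe(audio)["text"].strip()
print(text)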
STT/whisper_stt_handler.py
ADDED
@@ -0,0 +1,113 @@
from time import perf_counter
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)
import torch

from baseHandler import BaseHandler
from rich.console import Console
import logging

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class WhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        self.gen_kwargs = gen_kwargs

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)

        # compile
        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )
        self.warmup()

    def prepare_model_inputs(self, spoken_prompt):
        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        return input_features

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # 2 warmup steps when not compiling, or when the compile mode captures CUDA graphs; 1 step for the default compile mode
        n_steps = 1 if self.compile_mode == "default" else 2
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device,
        )
        if self.compile_mode not in (None, "default"):
            # generating more tokens than previously will trigger CUDA graphs capture
            # warm up with at least as many generated tokens as will be requested in subsequent generations
            warmup_gen_kwargs = {
                "min_new_tokens": self.gen_kwargs["min_new_tokens"],
                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
                **self.gen_kwargs,
            }
        else:
            warmup_gen_kwargs = self.gen_kwargs

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")

        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        pred_text = self.processor.batch_decode(
            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
        )[0]

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
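A sketch of the setup arguments this handler accepts (hypothetical values, not part of the uploaded files); note that any compile_mode other than None or "default" makes warmup() read min_new_tokens and max_new_tokens from gen_kwargs, so both must be set in that case:

from queue import Queue
from threading import Event

from STT.whisper_stt_handler import WhisperSTTHandler

stt = WhisperSTTHandler(
    Event(),
    queue_in=Queue(),
    queue_out=Queue(),
    setup_kwargs={
        "model_name": "distil-whisper/distil-large-v3",
        "device": "cuda",
        "torch_dtype": "float16",
        "compile_mode": "reduce-overhead",  # torch.compile mode that captures CUDA graphs
        "gen_kwargs": {"min_new_tokens": 0, "max_new_tokens": 128},
    },
)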
baseHandler.py
ADDED
@@ -0,0 +1,51 @@
from time import perf_counter
import logging

logger = logging.getLogger(__name__)


class BaseHandler:
    """
    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
    The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
    To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
    Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
    The cleanup method handles stopping the handler, and b"END" is placed in the output queue.
    """

    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
        self.stop_event = stop_event
        self.queue_in = queue_in
        self.queue_out = queue_out
        self.setup(*setup_args, **setup_kwargs)
        self._times = []

    def setup(self):
        pass

    def process(self):
        raise NotImplementedError

    def run(self):
        while not self.stop_event.is_set():
            input = self.queue_in.get()
            if isinstance(input, bytes) and input == b"END":
                # sentinel signal to avoid queue deadlock
                logger.debug("Stopping thread")
                break
            start_time = perf_counter()
            for output in self.process(input):
                self._times.append(perf_counter() - start_time)
                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
                self.queue_out.put(output)
                start_time = perf_counter()

        self.cleanup()
        self.queue_out.put(b"END")

    @property
    def last_time(self):
        return self._times[-1]

    def cleanup(self):
        pass
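A minimal sketch of a custom handler built on this base class (hypothetical example, not part of the uploaded files), showing the process generator and the b"END" sentinel described in the docstring:

from queue import Queue
from threading import Event, Thread

from baseHandler import BaseHandler


class UpperCaseHandler(BaseHandler):
    def setup(self, suffix="!"):
        self.suffix = suffix

    def process(self, text):
        # Each yielded item is placed in queue_out by run().
        yield text.upper() + self.suffix


stop_event = Event()
q_in, q_out = Queue(), Queue()
handler = UpperCaseHandler(stop_event, q_in, q_out, setup_kwargs={"suffix": "!"})

Thread(target=handler.run).start()
q_in.put("hello")
print(q_out.get())  # "HELLO!"
q_in.put(b"END")    # sentinel: stops the handler, which then puts b"END" in q_out
assert q_out.get() == b"END"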