Siddhant committed on
Commit
d4b17a2
1 Parent(s): 715b732

Upload 5 files

LLM/chat.py ADDED
@@ -0,0 +1,25 @@
+ class Chat:
+     """
+     Handles the chat history, keeping only a bounded number of turns to avoid OOM issues.
+     """
+
+     def __init__(self, size):
+         self.size = size
+         self.init_chat_message = None
+         # the buffer is trimmed in pairs, since each step adds a user prompt and an assistant answer
+         self.buffer = []
+
+     def append(self, item):
+         self.buffer.append(item)
+         if len(self.buffer) == 2 * (self.size + 1):
+             self.buffer.pop(0)
+             self.buffer.pop(0)
+
+     def init_chat(self, init_chat_message):
+         self.init_chat_message = init_chat_message
+
+     def to_list(self):
+         if self.init_chat_message:
+             return [self.init_chat_message] + self.buffer
+         else:
+             return self.buffer
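A quick sanity-check sketch of the trimming behaviour (assuming the repository root is on PYTHONPATH so `LLM.chat` is importable): with size=1, older user/assistant pairs are dropped and the history stays bounded.

    # Illustrative only, not part of this commit.
    from LLM.chat import Chat

    chat = Chat(size=1)
    for turn in range(3):
        chat.append({"role": "user", "content": f"question {turn}"})
        chat.append({"role": "assistant", "content": f"answer {turn}"})
        # older pairs are dropped once 2 * (size + 1) entries are reached
        print(len(chat.to_list()))  # prints 2, 2, 2
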
LLM/language_model.py ADDED
@@ -0,0 +1,134 @@
+ from threading import Thread
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     pipeline,
+     TextIteratorStreamer,
+ )
+ import torch
+
+ from LLM.chat import Chat
+ from baseHandler import BaseHandler
+ from rich.console import Console
+ import logging
+ from nltk import sent_tokenize
+
+ logging.basicConfig(
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ console = Console()
+
+
+ class LanguageModelHandler(BaseHandler):
+     """
+     Handles the language model part.
+     """
+
+     def setup(
+         self,
+         model_name="microsoft/Phi-3-mini-4k-instruct",
+         device="cuda",
+         torch_dtype="float16",
+         gen_kwargs={},
+         user_role="user",
+         chat_size=1,
+         init_chat_role=None,
+         init_chat_prompt="You are a helpful AI assistant.",
+     ):
+         self.device = device
+         self.torch_dtype = getattr(torch, torch_dtype)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
+         ).to(device)
+         self.pipe = pipeline(
+             "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
+         )
+         self.streamer = TextIteratorStreamer(
+             self.tokenizer,
+             skip_prompt=True,
+             skip_special_tokens=True,
+         )
+         self.gen_kwargs = {
+             "streamer": self.streamer,
+             "return_full_text": False,
+             **gen_kwargs,
+         }
+
+         self.chat = Chat(chat_size)
+         if init_chat_role:
+             if not init_chat_prompt:
+                 raise ValueError(
+                     "An initial prompt needs to be specified when setting init_chat_role."
+                 )
+             self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+         self.user_role = user_role
+
+         self.warmup()
+
+     def warmup(self):
+         logger.info(f"Warming up {self.__class__.__name__}")
+
+         dummy_input_text = "Write me a poem about Machine Learning."
+         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
+         # min/max_new_tokens are expected to be provided via gen_kwargs
+         warmup_gen_kwargs = {
+             "min_new_tokens": self.gen_kwargs["min_new_tokens"],
+             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+             **self.gen_kwargs,
+         }
+
+         n_steps = 2
+
+         if self.device == "cuda":
+             start_event = torch.cuda.Event(enable_timing=True)
+             end_event = torch.cuda.Event(enable_timing=True)
+             torch.cuda.synchronize()
+             start_event.record()
+
+         for _ in range(n_steps):
+             thread = Thread(
+                 target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
+             )
+             thread.start()
+             for _ in self.streamer:
+                 pass
+
+         if self.device == "cuda":
+             end_event.record()
+             torch.cuda.synchronize()
+
+             logger.info(
+                 f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+             )
+
+     def process(self, prompt):
+         logger.debug("inferring language model...")
+
+         self.chat.append({"role": self.user_role, "content": prompt})
+         thread = Thread(
+             target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
+         )
+         thread.start()
+         if self.device == "mps":
+             generated_text = ""
+             for new_text in self.streamer:
+                 generated_text += new_text
+             printable_text = generated_text
+             torch.mps.empty_cache()
+         else:
+             generated_text, printable_text = "", ""
+             for new_text in self.streamer:
+                 generated_text += new_text
+                 printable_text += new_text
+                 sentences = sent_tokenize(printable_text)
+                 if len(sentences) > 1:
+                     yield (sentences[0])
+                     printable_text = new_text
+
+         self.chat.append({"role": "assistant", "content": generated_text})
+
+         # don't forget the last sentence
+         yield printable_text
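The sentence-level chunking in `process` above can be followed more easily in isolation. Below is a minimal, self-contained sketch of the same idea, with a plain list of text chunks standing in for the `TextIteratorStreamer`; `stream_sentences` is a hypothetical helper, not part of this commit, and nltk's "punkt" tokenizer data must be downloaded first.

    # Standalone sketch of the sentence-chunking loop; illustrative only.
    from nltk import sent_tokenize  # requires nltk.download("punkt")

    def stream_sentences(chunks):
        generated, printable = "", ""
        for new_text in chunks:
            generated += new_text
            printable += new_text
            sentences = sent_tokenize(printable)
            if len(sentences) > 1:
                # emit the first complete sentence, keep streaming the rest
                yield sentences[0]
                printable = new_text
        yield printable  # don't forget the last sentence

    print(list(stream_sentences(["Hello ", "world. ", "How ", "are ", "you?"])))
    # -> ['Hello world.', 'How are you?']
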
STT/lightning_whisper_mlx_handler.py ADDED
@@ -0,0 +1,58 @@
+ import logging
+ from time import perf_counter
+ from baseHandler import BaseHandler
+ from lightning_whisper_mlx import LightningWhisperMLX
+ import numpy as np
+ from rich.console import Console
+ import torch
+
+ logging.basicConfig(
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ console = Console()
+
+
+ class LightningWhisperSTTHandler(BaseHandler):
+     """
+     Handles the Speech To Text generation using a Whisper model.
+     """
+
+     def setup(
+         self,
+         model_name="distil-large-v3",
+         device="cuda",
+         torch_dtype="float16",
+         compile_mode=None,
+         gen_kwargs={},
+     ):
+         if len(model_name.split("/")) > 1:
+             # keep only the checkpoint name, e.g. "distil-whisper/distil-large-v3" -> "distil-large-v3"
+             model_name = model_name.split("/")[-1]
+         self.device = device
+         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+         self.warmup()
+
+     def warmup(self):
+         logger.info(f"Warming up {self.__class__.__name__}")
+
+         # a single warmup step is enough for the MLX model (no compilation involved)
+         n_steps = 1
+         dummy_input = np.array([0] * 512)
+
+         for _ in range(n_steps):
+             _ = self.model.transcribe(dummy_input)["text"].strip()
+
+     def process(self, spoken_prompt):
+         logger.debug("inferring whisper...")
+
+         global pipeline_start
+         pipeline_start = perf_counter()
+
+         pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+         torch.mps.empty_cache()
+
+         logger.debug("finished whisper inference")
+         console.print(f"[yellow]USER: {pred_text}")
+
+         yield pred_text
STT/whisper_stt_handler.py ADDED
@@ -0,0 +1,113 @@
+ from time import perf_counter
+ from transformers import (
+     AutoModelForSpeechSeq2Seq,
+     AutoProcessor,
+ )
+ import torch
+
+ from baseHandler import BaseHandler
+ from rich.console import Console
+ import logging
+
+ logging.basicConfig(
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ console = Console()
+
+
+ class WhisperSTTHandler(BaseHandler):
+     """
+     Handles the Speech To Text generation using a Whisper model.
+     """
+
+     def setup(
+         self,
+         model_name="distil-whisper/distil-large-v3",
+         device="cuda",
+         torch_dtype="float16",
+         compile_mode=None,
+         gen_kwargs={},
+     ):
+         self.device = device
+         self.torch_dtype = getattr(torch, torch_dtype)
+         self.compile_mode = compile_mode
+         self.gen_kwargs = gen_kwargs
+
+         self.processor = AutoProcessor.from_pretrained(model_name)
+         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+             model_name,
+             torch_dtype=self.torch_dtype,
+         ).to(device)
+
+         # compile
+         if self.compile_mode:
+             self.model.generation_config.cache_implementation = "static"
+             self.model.forward = torch.compile(
+                 self.model.forward, mode=self.compile_mode, fullgraph=True
+             )
+         self.warmup()
+
+     def prepare_model_inputs(self, spoken_prompt):
+         input_features = self.processor(
+             spoken_prompt, sampling_rate=16000, return_tensors="pt"
+         ).input_features
+         input_features = input_features.to(self.device, dtype=self.torch_dtype)
+
+         return input_features
+
+     def warmup(self):
+         logger.info(f"Warming up {self.__class__.__name__}")
+
+         # 2 warmup steps for no compile, or for a compile mode with CUDA graphs capture
+         n_steps = 1 if self.compile_mode == "default" else 2
+         dummy_input = torch.randn(
+             (1, self.model.config.num_mel_bins, 3000),
+             dtype=self.torch_dtype,
+             device=self.device,
+         )
+         if self.compile_mode not in (None, "default"):
+             # generating more tokens than previously will trigger CUDA graphs capture,
+             # so warmup should generate at least as many tokens as subsequent generations will target
+             warmup_gen_kwargs = {
+                 "min_new_tokens": self.gen_kwargs["min_new_tokens"],
+                 "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+                 **self.gen_kwargs,
+             }
+         else:
+             warmup_gen_kwargs = self.gen_kwargs
+
+         if self.device == "cuda":
+             start_event = torch.cuda.Event(enable_timing=True)
+             end_event = torch.cuda.Event(enable_timing=True)
+             torch.cuda.synchronize()
+             start_event.record()
+
+         for _ in range(n_steps):
+             _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
+
+         if self.device == "cuda":
+             end_event.record()
+             torch.cuda.synchronize()
+
+             logger.info(
+                 f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+             )
+
+     def process(self, spoken_prompt):
+         logger.debug("inferring whisper...")
+
+         global pipeline_start
+         pipeline_start = perf_counter()
+
+         input_features = self.prepare_model_inputs(spoken_prompt)
+         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
+         pred_text = self.processor.batch_decode(
+             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
+         )[0]
+
+         logger.debug("finished whisper inference")
+         console.print(f"[yellow]USER: {pred_text}")
+
+         yield pred_text
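A hedged illustration of how this handler might be instantiated (the queue/event wiring and the concrete values below are assumptions, not part of this commit): when compile_mode is set to a CUDA-graphs mode, gen_kwargs must carry min_new_tokens and max_new_tokens, since the warmup above indexes them directly.

    # Illustrative instantiation only.
    from queue import Queue
    from threading import Event
    from STT.whisper_stt_handler import WhisperSTTHandler

    stt = WhisperSTTHandler(
        stop_event=Event(),
        queue_in=Queue(),
        queue_out=Queue(),
        setup_kwargs={
            "device": "cuda",
            "compile_mode": "reduce-overhead",  # triggers CUDA graphs capture during warmup
            # required by the warmup path when compile_mode is neither None nor "default"
            "gen_kwargs": {"min_new_tokens": 0, "max_new_tokens": 128},
        },
    )
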
baseHandler.py ADDED
@@ -0,0 +1,51 @@
+ from time import perf_counter
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseHandler:
+     """
+     Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
+     The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
+     To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
+     Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
+     The `cleanup` method handles stopping the handler, and b"END" is placed in the output queue.
+     """
+
+     def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
+         self.stop_event = stop_event
+         self.queue_in = queue_in
+         self.queue_out = queue_out
+         self.setup(*setup_args, **setup_kwargs)
+         self._times = []
+
+     def setup(self):
+         pass
+
+     def process(self):
+         raise NotImplementedError
+
+     def run(self):
+         while not self.stop_event.is_set():
+             input = self.queue_in.get()
+             if isinstance(input, bytes) and input == b"END":
+                 # sentinel signal to avoid queue deadlock
+                 logger.debug("Stopping thread")
+                 break
+             start_time = perf_counter()
+             for output in self.process(input):
+                 self._times.append(perf_counter() - start_time)
+                 logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
+                 self.queue_out.put(output)
+                 start_time = perf_counter()
+
+         self.cleanup()
+         self.queue_out.put(b"END")
+
+     @property
+     def last_time(self):
+         return self._times[-1]
+
+     def cleanup(self):
+         pass
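A minimal sketch of the queue protocol described in the docstring; EchoHandler below is a made-up handler used only to show how the b"END" sentinel flows through `run`.

    # Toy handler, not part of this commit; demonstrates the run()/sentinel protocol.
    from queue import Queue
    from threading import Event, Thread

    from baseHandler import BaseHandler


    class EchoHandler(BaseHandler):
        def process(self, chunk):
            yield chunk.upper()


    stop_event, q_in, q_out = Event(), Queue(), Queue()
    handler = EchoHandler(stop_event, q_in, q_out)
    Thread(target=handler.run).start()

    q_in.put("hello")
    q_in.put(b"END")               # sentinel unblocks the handler and stops it cleanly
    print(q_out.get())             # -> HELLO
    assert q_out.get() == b"END"   # the sentinel is forwarded to the next queue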