quantization added
Files changed:
- app.py +1 -1
- model.py +1 -1
- trainer.py +10 -1
app.py
CHANGED
@@ -34,7 +34,7 @@ def plot_mel_spectrogram(mel_spec):
 def get_or_load_model():
     if 'model' not in st.session_state or 'tokenizer' not in st.session_state or 'processor' not in st.session_state:
         ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
-        model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
+        model = SpeechLLMLightning.load_from_checkpoint(ckpt_path, quantize=True)
         tokenizer = model.llm_tokenizer
         model.eval()
         model.freeze()
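
Note on the new keyword argument: PyTorch Lightning's load_from_checkpoint forwards extra keyword arguments to the module's __init__ and lets them override hyperparameters stored via save_hyperparameters(), which is how quantize=True reaches SpeechLLMLightning even though the checkpoint was saved before the flag existed. A minimal sketch of that mechanism; ToyModule and "toy.ckpt" are hypothetical stand-ins, not code from this Space:

# Minimal sketch, assuming standard PyTorch Lightning behaviour.
import pytorch_lightning as pl
import torch.nn as nn

class ToyModule(pl.LightningModule):
    def __init__(self, hidden=8, quantize=False):
        super().__init__()
        self.save_hyperparameters()            # stores hidden/quantize in the checkpoint
        self.layer = nn.Linear(hidden, hidden)

# Extra kwargs passed to load_from_checkpoint override the saved hyperparameters:
# model = ToyModule.load_from_checkpoint("toy.ckpt", quantize=True)
# model.hparams.quantize  -> True
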
model.py
CHANGED
@@ -13,7 +13,7 @@ else:
 class HubertXCNNEnoder(nn.Module):
     def __init__(self, audio_enc_dim, llm_dim, finetune=False):
         super().__init__()
-        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k')
+        self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k').to(device)
         for param in self.encoder.parameters():
             param.requires_grad = False
 
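
The appended .to(device) assumes a module-level device variable produced by the same torch.cuda.is_available() check shown in the trainer.py hunks below. A minimal sketch of the intended pattern, with the device selection inlined for clarity; only the variable name and the encoder checkpoint come from the diff, the rest is illustrative:

import torch
from transformers import HubertModel

# Device selection mirroring the torch.cuda.is_available() check in trainer.py.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the frozen HuBERT encoder and move its weights to that device once,
# so audio features only need a matching .to(device) at inference time.
encoder = HubertModel.from_pretrained("facebook/hubert-xlarge-ll60k").to(device)
for param in encoder.parameters():
    param.requires_grad = False
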
trainer.py
CHANGED
@@ -6,6 +6,9 @@ from peft import LoraConfig, get_peft_model, PeftModel
 import pytorch_lightning as pl
 from model import HubertXCNNEnoder
 
+from torch.quantization import quantize_dynamic
+import torch.jit as jit
+
 
 if torch.cuda.is_available():
     # Set the device to CUDA
@@ -15,7 +18,7 @@ else:
     device = "cpu"
 
 class SpeechLLMLightning(pl.LightningModule):
-    def __init__(self, audio_enc_dim=512, llm_dim=2048, llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
+    def __init__(self, audio_enc_dim=512, llm_dim=2048, llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", quantize=True):
         super().__init__()
         self.save_hyperparameters()
 
@@ -48,6 +51,12 @@ class SpeechLLMLightning(pl.LightningModule):
         self.audio_encoder.eval()
         self.llm_model.eval()
 
+        if quantize:
+            self.llm_model = jit.script(self.llm_model)
+            self.llm_model = quantize_dynamic(
+                self.llm_model, {nn.Linear}, dtype=torch.qint8
+            )
+
 
     def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
         batch_size = mel.shape[0]
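
For context on the new block: torch.quantization.quantize_dynamic is an eager-mode API that swaps the weights of the listed module types (here nn.Linear, so import torch.nn as nn must already be in scope in trainer.py) for int8 tensors and dequantizes them on the fly during forward passes. The jit.script wrapper is not required for dynamic quantization itself; a minimal sketch of the more common eager-mode usage, on a small stand-in model rather than the TinyLlama LLM used by this Space:

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Stand-in model: dynamic quantization targets the nn.Linear layers inside it.
model = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
model.eval()

# Replace Linear weights with qint8 and dequantize on the fly in forward().
quantized = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = quantized(torch.randn(1, 512))
print(out.shape)  # torch.Size([1, 512])

Dynamic quantization mainly reduces weight memory and speeds up CPU inference; whether the extra jit.script step in this commit is needed depends on how the quantized module is used afterwards.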