# Voice-CPU / app.py
# Author: Staticaliza — "Update app.py" (commit 3a7347e, verified)
# Source: Hugging Face Spaces file view (raw / history / blame), 1.3 kB
import torch
from transformers import TextIteratorStreamer
import threading
class ModelWrapper:
    """Lazily load a quantized causal LM when GPU time is granted and stream output.

    The model is NOT loaded in ``__init__``: on HF Spaces "ZeroGPU", a GPU is
    only attached while a ``@spaces.GPU``-decorated call is running, so loading
    is deferred to the first ``generate`` call.
    """

    def __init__(self):
        # Sentinel: model is loaded on first generate() call, once a GPU exists.
        self.model = None

    # NOTE(review): `spaces`, `AutoGPTQForCausalLM`, `model_id` and `tokenizer`
    # are free names here — presumably defined earlier in this file; verify.
    @spaces.GPU
    def generate(self, prompt):
        """Stream generated text for *prompt*.

        Yields the cumulative generated string (each yield is the full text so
        far, not just the new chunk) — the shape Gradio streaming outputs expect.

        Args:
            prompt: input text to continue.

        Yields:
            str: progressively longer generated text, up to 512 new tokens.
        """
        if self.model is None:
            # First call with a GPU attached: load the quantized weights now.
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map='auto',
                trust_remote_code=True,
            )
            self.model.eval()

        # Tokenize and move input tensors onto the allocated GPU.
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')

        # Streamer yields decoded text pieces as tokens are produced.
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Run generate() in a worker thread so this generator can yield partial
        # output while tokens are still being produced.
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        # Streamer is exhausted once generation ends; join so the worker thread
        # is not left running (fix: original never joined the thread).
        thread.join()