import threading

import spaces
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer

model_id = "<your-gptq-model-id>"  # set to the GPTQ model repo you want to serve
tokenizer = AutoTokenizer.from_pretrained(model_id)


class ModelWrapper:
    def __init__(self):
        self.model = None  # Defer loading until a GPU is actually allocated

    @spaces.GPU
    def generate(self, prompt):
        if self.model is None:
            # Load the quantized model the first time a GPU is allocated
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map='auto',
                trust_remote_code=True,
            )
            self.model.eval()

        # Tokenize the input prompt and move it to the GPU
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')

        # Streamer that yields decoded text as tokens are produced;
        # skip_prompt avoids echoing the input prompt back to the caller
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Run generation in a background thread so this thread can
        # consume the streamer and yield text in real time
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the accumulated text as new chunks arrive
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
        thread.join()
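
# For context, a minimal sketch of how the wrapper might be wired up.
# The @spaces.GPU decorator implies a Hugging Face ZeroGPU Space, which is
# typically fronted by Gradio; the interface layout below is an assumption,
# not part of the original snippet. Gradio streams partial outputs
# automatically when the handler is a generator function, as generate() is.
import gradio as gr

wrapper = ModelWrapper()

demo = gr.Interface(
    fn=wrapper.generate,  # generator: each yield updates the output box
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()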