Spaces:
Runtime error
Runtime error
File size: 1,836 Bytes
5cbf359 962cccf 6ee5519 9fd7f68 a5beb04 eec5a54 f4de9a0 eec5a54 9fd7f68 6ee5519 9fd7f68 6ee5519 9fd7f68 eec5a54 9fd7f68 5cbf359 eec5a54 5cbf359 eec5a54 5cbf359 eec5a54 57cdbdf eec5a54 5cbf359 eec5a54 57cdbdf eec5a54 9fd7f68 eec5a54 57cdbdf eec5a54 9fd7f68 5cbf359 eec5a54 9fd7f68 57cdbdf 032d6c3 57cdbdf f4de9a0 57cdbdf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import warnings
from typing import Dict
import spaces
# Device the model weights are placed on; hard-coded to CUDA (no CPU fallback).
device = "cuda"

# Seed PyTorch's global RNG so sampled generations are more repeatable
# between runs.
torch.manual_seed(0)

# Silence every Python warning process-wide (this also hides library
# deprecation notices).
warnings.filterwarnings("ignore")

# Hub id of the chat model served by this app.
model_path = "microsoft/Phi-3-mini-4k-instruct"

# Keyword arguments forwarded to the pipeline on every generation call:
# at most 50 new tokens, only the newly generated text in the result,
# and low-temperature sampling.
generation_args = dict(
    max_new_tokens=50,
    return_full_text=False,
    temperature=0.1,
    do_sample=True,
)
# Pipelines already built by load_model_pipeline, keyed by model path, so
# each model is loaded at most once per process and then kept in memory.
_PIPELINE_CACHE: Dict[str, object] = {}


def load_model_pipeline(model_path: str):
    """Return a text-generation pipeline for ``model_path``, building it once.

    On the first call for a given path, this loads the causal-LM weights onto
    the module-level ``device`` (dtype picked with ``torch_dtype="auto"``) and
    the matching tokenizer. It wraps both in a ``transformers``
    ``"text-generation"`` pipeline. Later calls with the same path return the
    cached pipeline instead of reloading. (The previous version cached a
    single pipeline and returned it for *any* path after the first call.)

    Args:
        model_path: Model identifier or local directory passed to
            ``from_pretrained``.

    Returns:
        The (cached) ``transformers`` text-generation pipeline.
    """
    cached = _PIPELINE_CACHE.get(model_path)
    if cached is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=device,
            torch_dtype="auto",
            # Allow the checkpoint's own modelling code to be executed.
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        cached = pipeline("text-generation", model=model, tokenizer=tokenizer)
        _PIPELINE_CACHE[model_path] = cached
    # Keep the old single-slot attribute populated for any code that read
    # ``load_model_pipeline.pipe`` directly.
    load_model_pipeline.pipe = cached
    return cached
# Build the shared pipeline once at import time; every request below reuses
# this already-loaded model instead of loading it per call.
pipe = load_model_pipeline(model_path=model_path)
@spaces.GPU(duration=50)  # NOTE(review): ZeroGPU allocation; duration presumably seconds per call
def generate_logic(llm_output: str) -> str:
    """Ask the chat pipeline for a detailed answer about ``llm_output``.

    The input is embedded in a fixed instruction and sent as the user turn.
    A short system message precedes it. Generation uses the module-level
    ``generation_args``.

    Args:
        llm_output: Text to respond to. Despite the name,
            ``process_description`` passes the user's description here.

    Returns:
        The ``generated_text`` of the first pipeline result. With
        ``return_full_text=False`` in ``generation_args``, this is only the
        newly generated completion.
    """
    system_turn = {"role": "system", "content": "Please provide a detailed response."}
    user_turn = {
        "role": "user",
        "content": f"\nProvide a detailed response based on the description: '{llm_output}'.\n",
    }
    outputs = pipe([system_turn, user_turn], **generation_args)
    text = outputs[0]["generated_text"]
    # Echo the result to stdout for debugging.
    print(f"Generated Text: {text}")
    return text
def process_description(description: str) -> str:
    """Return the model's raw text response to ``description``.

    Public entry point that forwards ``description`` to ``generate_logic``
    and returns its output unchanged. The stray trailing ``|`` in the
    previous version, which made this line a SyntaxError, is removed.
    """
    return generate_logic(description)