import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import logging
import gc

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Starting application...")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

try:
    logger.info("Loading tokenizer...")
    # Use the base model's tokenizer instead
    base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        use_fast=True,
        trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Tokenizer loaded successfully")

    logger.info("Loading fine-tuned model in 8-bit...")
    model_id = "htigenai/finetune_test_2"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # 8-bit quantization via bitsandbytes
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={0: "12GB", "cpu": "4GB"}
    )
    model.eval()
    logger.info("Model loaded successfully in 8-bit")

    # Clear any residual memory
    gc.collect()
    torch.cuda.empty_cache()

    def generate_text(prompt, max_tokens=100, temperature=0.7):
        try:
            # Format prompt with chat template
            formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"

            inputs = tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256
            ).to(model.device)

            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    no_repeat_ngram_size=3,
                    use_cache=True
                )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract the assistant's portion of the response
            if "### Assistant:" in response:
                response = response.split("### Assistant:")[-1].strip()

            # Clean up GPU memory between requests
            del outputs, inputs
            gc.collect()
            torch.cuda.empty_cache()

            return response

        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            return f"Error generating response: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(
                lines=3,
                placeholder="Enter your prompt here...",
                label="Input Prompt",
                max_lines=5
            ),
            gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Max Tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
        ],
        outputs=gr.Textbox(
            label="Generated Response",
            lines=5
        ),
        title="HTIGENAI Reflection Analyzer (8-bit)",
        description="Using the Llama 3.1 base tokenizer with the fine-tuned model. Keep prompts concise for best results.",
        examples=[
            ["What is machine learning?", 50, 0.7],
            ["Explain quantum computing", 50, 0.7],
        ],
        cache_examples=False
    )

    # Queue requests so the single GPU handles one generation at a time
    iface.queue()

    # Launch interface
    iface.launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True,
        max_threads=1
    )

except Exception as e:
    logger.error(f"Application startup failed: {str(e)}")
    raise
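
# A minimal client-side sketch (an assumption, not part of the original app): once
# the interface above is running on the default port 7860, it can be queried
# programmatically with gradio_client. "/predict" is the endpoint name that
# gr.Interface exposes by default; the positional arguments mirror the three
# inputs (prompt, max tokens, temperature). This helper is illustrative only
# and is never called by this script.
def _example_client_call(prompt="What is machine learning?", max_tokens=50, temperature=0.7):
    from gradio_client import Client  # pip install gradio_client

    client = Client("http://localhost:7860")  # assumed local URL and port
    return client.predict(prompt, max_tokens, temperature, api_name="/predict")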