# Joash2024's picture
# feat: configure for ZeroGPU with A100
# 628f881
# raw
# history blame
# 3.43 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import spaces
# ---------------------------------------------------------------------------
# Model identifiers
# ---------------------------------------------------------------------------
# Base model that the adapter is applied on top of.
BASE_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
# Our LoRA adapter.
ADAPTER_MODEL = "Joash2024/Math-SmolLM2-1.7B"

# ---------------------------------------------------------------------------
# Tokenizer (loaded from the base model; EOS doubles as the padding token)
# ---------------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

# ---------------------------------------------------------------------------
# Base model in half precision, with automatic device placement
# ---------------------------------------------------------------------------
print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    use_safetensors=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

# ---------------------------------------------------------------------------
# Wrap the base model with the LoRA adapter and switch to evaluation mode
# ---------------------------------------------------------------------------
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(
    model,
    ADAPTER_MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
def format_prompt(function: str) -> str:
    """Build the text prompt that asks the model for a derivative.

    Args:
        function: The expression to differentiate; inserted verbatim.

    Returns:
        A three-line prompt ending with "The derivative of this function is:"
        (no trailing newline), ready to be tokenized.
    """
    # Assembled with explicit newlines so the prompt text does not depend on
    # the indentation of this source file.
    return (
        "Given a mathematical function, find its derivative.\n"
        f"Function: {function}\n"
        "The derivative of this function is:"
    )
@spaces.GPU
def generate_derivative(function: str, max_length: int = 100) -> str:
    """Generate the derivative of ``function`` with the loaded model.

    Args:
        function: The expression to differentiate; it is placed into the
            prompt produced by ``format_prompt``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
            NOTE(review): in transformers this limit counts the prompt tokens
            as well as the generated ones, so very long inputs leave little
            or no room for an answer -- consider ``max_new_tokens``.

    Returns:
        The text the model produced after the prompt, with surrounding
        whitespace removed.
    """
    prompt = format_prompt(function)

    # Tokenize and move the tensors to wherever the model lives.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            # Greedy decoding for deterministic output. No `temperature` is
            # passed: it only applies when sampling and otherwise just
            # triggers a warning.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count instead of slicing the decoded string by len(prompt), which breaks
    # whenever decoding does not reproduce the prompt character-for-character.
    completion_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
@spaces.GPU
def solve_derivative(function: str) -> str:
    """Run the model on ``function`` and format the answer for the UI.

    Args:
        function: Raw text from the input box. Leading and trailing
            whitespace is ignored.

    Returns:
        "Please enter a function" when the input is empty or only whitespace;
        otherwise a five-line message containing the generated derivative.
    """
    # Treat None, "" and whitespace-only input alike so that a blank entry
    # never reaches the model.
    function = function.strip() if function else ""
    if not function:
        return "Please enter a function"

    print(f"\nGenerating derivative for: {function}")
    derivative = generate_derivative(function)

    # Format output with step-by-step explanation. Joined with explicit
    # newlines so the text does not depend on source indentation.
    return "\n".join(
        (
            f"Generated derivative: {derivative}",
            "Let's verify this step by step:",
            f"1. Starting with f(x) = {function}",
            "2. Applying differentiation rules",
            f"3. We get f'(x) = {derivative}",
        )
    )
# Create Gradio interface.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation -- confirm the layout against the deployed UI.
with gr.Blocks(title="Mathematics Derivative Solver") as demo:
    gr.Markdown("# Mathematics Derivative Solver")
    gr.Markdown("Using our fine-tuned model to solve derivatives")

    # Input area: the function text box with the submit button beneath it.
    with gr.Row():
        with gr.Column():
            function_input = gr.Textbox(
                label="Enter a function",
                placeholder="Example: x^2, sin(x), e^x"
            )
            solve_btn = gr.Button("Find Derivative", variant="primary")

    # Output area.
    with gr.Row():
        output = gr.Textbox(
            label="Solution with Steps",
            lines=6
        )

    # Example functions (reduced)
    gr.Examples(
        examples=[
            ["x^2"],
            ["\\sin{\\left(x\\right)}"],
            ["e^x"]
        ],
        inputs=function_input,
        outputs=output,
        fn=solve_derivative,
        cache_examples=False  # Disable caching
    )

    # Connect the interface: event listeners must be registered while the
    # Blocks context is still open.
    solve_btn.click(
        fn=solve_derivative,
        inputs=[function_input],
        outputs=output
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()