# Hugging Face Space app (scraped "Spaces" page header; Space status at capture time: Sleeping)
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from peft import PeftModel | |
import spaces | |
# Model configurations
BASE_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # Base model
ADAPTER_MODEL = "Joash2024/Math-SmolLM2-1.7B"  # Our LoRA adapter

# Half precision is only a sensible default when a CUDA device is present.
# With device_map="auto" and no GPU the weights land on CPU, where float16 is
# slow and some torch builds lack fp16 CPU kernels; use float32 there instead.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Reuse the EOS token as the padding token so the tokenizer has one set.
tokenizer.pad_token = tokenizer.eos_token

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)

print("Loading LoRA adapter...")
# Wrap the base model with the adapter weights published at ADAPTER_MODEL;
# `model` is rebound to the wrapped (adapter-applied) model.
model = PeftModel.from_pretrained(
    model,
    ADAPTER_MODEL,
    torch_dtype=DTYPE,
    device_map="auto",
)
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
def format_prompt(function: str) -> str:
    """Wrap *function* in the fixed derivative-request prompt.

    The function text is inserted verbatim on the middle line. The prompt ends
    immediately after "is:", so whatever the model emits next is read as the
    answer.
    """
    prompt_lines = (
        "Given a mathematical function, find its derivative.",
        f"Function: {function}",
        "The derivative of this function is:",
    )
    return "\n".join(prompt_lines)
def generate_derivative(function: str, max_length: int = 100) -> str:
    """Generate the derivative of *function* with the fine-tuned model.

    Args:
        function: The function to differentiate as typed by the user
            (e.g. ``"x^2"``). It is embedded verbatim in the prompt built by
            ``format_prompt``.
        max_length: Maximum number of tokens to generate after the prompt.
            The prompt itself does not count against this budget, so long
            inputs still leave room for an answer.

    Returns:
        The model's continuation of the prompt, decoded with special tokens
        removed and surrounding whitespace stripped. It may be empty if the
        model produces no text.
    """
    prompt = format_prompt(function)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # NOTE(review): `spaces` is imported at file top, but nothing is decorated
    # with @spaces.GPU. If this Space runs on ZeroGPU hardware, this function
    # likely needs that decorator. Confirm.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Budget only the completion. With `max_length` the prompt tokens
            # consumed the same budget and long inputs yielded no answer.
            max_new_tokens=max_length,
            num_return_sequences=1,
            # Greedy decoding for deterministic answers. `temperature` is
            # omitted: it has no effect when do_sample=False and only warns.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated ids. Slicing the decoded text by
    # len(prompt) assumes decode(encode(prompt)) reproduces the prompt
    # character-for-character, which tokenizers do not guarantee.
    new_token_ids = outputs[0][prompt_len:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
def solve_derivative(function: str) -> str:
    """Build the text shown in the UI for a user-entered function.

    Args:
        function: Raw textbox contents. Leading and trailing whitespace is
            ignored. ``None``, empty, or whitespace-only input is rejected
            without calling the model.

    Returns:
        ``"Please enter a function"`` for blank input. Otherwise, a
        multi-line string that embeds the function and the model's answer.
    """
    # Treat None and whitespace-only text as "no input" instead of sending
    # a blank prompt to the model.
    function = (function or "").strip()
    if not function:
        return "Please enter a function"

    print(f"\nGenerating derivative for: {function}")
    derivative = generate_derivative(function)

    # NOTE(review): the "verify this step by step" lines are fixed template
    # text around the model output. No verification is actually performed.
    output = f"""Generated derivative: {derivative}
Let's verify this step by step:
1. Starting with f(x) = {function}
2. Applying differentiation rules
3. We get f'(x) = {derivative}"""
    return output
# Create Gradio interface
with gr.Blocks(title="Mathematics Derivative Solver") as demo:
    gr.Markdown("# Mathematics Derivative Solver")
    gr.Markdown("Using our fine-tuned model to solve derivatives")

    # Input area: the function textbox with the solve button beneath it.
    with gr.Row():
        with gr.Column():
            function_input = gr.Textbox(
                label="Enter a function",
                placeholder="Example: x^2, sin(x), e^x",
            )
            solve_btn = gr.Button("Find Derivative", variant="primary")

    # Result area.
    with gr.Row():
        output = gr.Textbox(
            label="Solution with Steps",
            lines=6,
        )

    # Example functions (reduced). Caching is disabled, so the examples are
    # not pre-computed with the model when the app is built.
    gr.Examples(
        examples=[
            ["x^2"],
            ["\\sin{\\left(x\\right)}"],
            ["e^x"],
        ],
        inputs=function_input,
        outputs=output,
        fn=solve_derivative,
        cache_examples=False,
    )

    # Run the solver from the button and also when Enter is pressed in the
    # textbox. Previously Enter did nothing.
    solve_btn.click(
        fn=solve_derivative,
        inputs=[function_input],
        outputs=output,
    )
    function_input.submit(
        fn=solve_derivative,
        inputs=[function_input],
        outputs=output,
    )
def _main() -> None:
    """Launch the Gradio app held in the module-level ``demo`` object."""
    demo.launch()


if __name__ == "__main__":
    _main()