# NOTE: web-page residue captured together with this file (not part of the program):
#   pushpikaLiyanagama's picture | Upload 8 files | ecc3892 verified | raw | history blame | 1.92 kB
# inference.py
import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
# Application object that the route decorators in this module attach to.
app = FastAPI()

# The scaler and every model are deserialized once, when the module is
# imported, rather than on each request.
scaler = joblib.load("models/scaler.joblib")

# (target name, serialized model file) pairs, listed in load order.
_MODEL_FILES = (
    ("processing", "models/svm_model_processing.joblib"),
    ("perception", "models/svm_model_perception.joblib"),
    ("input", "models/svm_model_input.joblib"),
    ("understanding", "models/svm_model_understanding.joblib"),
)
models = {target: joblib.load(path) for target, path in _MODEL_FILES}
# Define the input schema
class InputData(BaseModel):
    # Request body for the /predict endpoint: twelve float fields.
    #
    # The declaration order is kept exactly as-is because predict() turns
    # this model into a one-row DataFrame whose columns follow field order.
    #
    # NOTE(review): the field names (including the spelling
    # "abstract_materiale") presumably mirror the column names the scaler
    # was fitted on -- confirm against the training data before renaming.

    course_overview: float
    reading_file: float
    abstract_materiale: float
    concrete_material: float
    visual_materials: float

    self_assessment: float
    exercises_submit: float
    quiz_submitted: float

    playing: float
    paused: float
    unstarted: float
    buffering: float
class PredictionResponse(BaseModel):
    # Response body for the /predict endpoint: one integer per target.
    # The four field names are the same as the keys of the module-level
    # `models` dict, which is what predict() returns its results under.

    processing: int
    perception: int
    input: int
    understanding: int
@app.post("/predict", response_model=PredictionResponse)
def predict(data: InputData):
    """Scale one set of input features and predict every target.

    The validated request body is converted to a single-row DataFrame,
    passed through the module-level ``scaler``, and then handed to each
    model in the module-level ``models`` dict.

    Args:
        data: The validated request body.

    Returns:
        A dict mapping each target name in ``models`` to the first
        predicted value for the row, converted to ``int``.

    Raises:
        HTTPException: With status 500 when anything in the conversion,
            scaling, or prediction steps raises. The original exception
            is attached as the cause.
    """
    try:
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecated
        # the old name; use the new one when it exists, and fall back so the
        # endpoint keeps working on Pydantic v1.
        if hasattr(data, "model_dump"):
            payload = data.model_dump()
        else:
            payload = data.dict()

        # One row; columns follow the field order declared on InputData.
        # All inputs are assumed to be numerical and to match the training
        # features -- no one-hot / dummy-variable alignment is done here.
        input_df = pd.DataFrame([payload])

        input_scaled = scaler.transform(input_df)

        # Each model predicts on the same scaled row; only the first (and
        # only) prediction is kept, coerced to a plain int for the response.
        return {
            target: int(model.predict(input_scaled)[0])
            for target, model in models.items()
        }
    except Exception as e:
        # NOTE(review): detail=str(e) sends the internal error text to the
        # client. Kept unchanged so existing consumers see the same payload,
        # but consider a generic message plus server-side logging.
        raise HTTPException(status_code=500, detail=str(e)) from e