File size: 1,922 Bytes
ecc3892
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# inference.py

import logging
from typing import List

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Read the fitted scaler and the per-target models from disk exactly once, at
# import time, so request handlers reuse the same in-memory objects.
scaler = joblib.load("models/scaler.joblib")
models = {
    target: joblib.load(path)
    for target, path in (
        ("processing", "models/svm_model_processing.joblib"),
        ("perception", "models/svm_model_perception.joblib"),
        ("input", "models/svm_model_input.joblib"),
        ("understanding", "models/svm_model_understanding.joblib"),
    )
}

# Define the input schema
class InputData(BaseModel):
    # Request body for POST /predict: one float per input feature.
    # The field names (and their declaration order) define the public request
    # schema and the keys of the dict predict() turns into a one-row DataFrame.
    # NOTE(review): the names presumably mirror the scaler's training columns
    # (including the "abstract_materiale" spelling) — confirm before renaming
    # or reordering, since either changes the API and the feature layout.
    course_overview: float
    reading_file: float
    abstract_materiale: float
    concrete_material: float
    visual_materials: float
    self_assessment: float
    exercises_submit: float
    quiz_submitted: float
    # Video-player state features; units/scale not visible here — TODO confirm
    # (e.g. counts vs. durations).
    playing: float
    paused: float
    unstarted: float
    buffering: float

class PredictionResponse(BaseModel):
    # Response body for POST /predict: one integer label per target.
    # The field names must match the keys of the module-level `models` dict,
    # because predict() returns a dict keyed by those target names.
    processing: int
    perception: int
    input: int
    understanding: int

logger = logging.getLogger(__name__)


@app.post("/predict", response_model=PredictionResponse)
def predict(data: InputData):
    """Predict every target label for a single feature vector.

    The validated request body is turned into a one-row DataFrame, scaled
    with the module-level ``scaler``, and passed to each model in ``models``.

    Args:
        data: Validated request body holding the numeric input features.

    Returns:
        A dict mapping each key of ``models`` to the integer label predicted
        for this row; FastAPI validates it against ``PredictionResponse``.

    Raises:
        HTTPException: 500 if building, scaling, or predicting fails. The
            underlying exception is logged server-side (with traceback) and
            chained as the cause, but its message is not echoed to the client
            so internal details such as paths or library errors do not leak.
    """
    try:
        # model_dump() is the pydantic v2 API; dict() is the v1 spelling and
        # is deprecated (warns) on v2.
        payload = data.model_dump() if hasattr(data, "model_dump") else data.dict()
        input_df = pd.DataFrame([payload])

        # A scaler fitted on a DataFrame records its training column order in
        # `feature_names_in_`, and scikit-learn rejects inputs whose columns
        # arrive in a different order. Align explicitly instead of relying on
        # the declaration order of InputData's fields.
        feature_names = getattr(scaler, "feature_names_in_", None)
        if feature_names is not None:
            input_df = input_df[list(feature_names)]

        input_scaled = scaler.transform(input_df)

        # One prediction per target; [0] selects the label for our single row.
        return {
            target: int(model.predict(input_scaled)[0])
            for target, model in models.items()
        }

    except Exception as e:
        # Top-level request boundary: record the full error for operators,
        # return a generic message to the caller.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from e