# Last change: "Update app/main.py" by aryan10022001 (commit 0f325aa, verified).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pickle
import pandas as pd
import numpy as np
import xgboost
from xgboost import XGBRegressor
# Load the pre-trained prediction pipeline once, at import time, so every
# request reuses the same object. The path is resolved relative to the
# process's current working directory.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# ever point this at a pipeline file you built yourself or otherwise trust.
# The `with` block guarantees the file handle is closed after loading.
with open('app/pipe.pkl', 'rb') as _pipe_file:
    pipe = pickle.load(_pipe_file)
# Closed vocabularies accepted by the API. validate_input rejects any request
# whose team or city is not listed here, so only these values ever reach the
# pipeline (presumably the categories it was fitted on -- confirm before
# adding entries).
teams = [
    'Australia',
    'India',
    'Bangladesh',
    'New Zealand',
    'South Africa',
    'England',
    'West Indies',
    'Afghanistan',
    'Pakistan',
    'Sri Lanka',
]
cities = [
    'Colombo',
    'Mirpur',
    'Johannesburg',
    'Dubai',
    'Auckland',
    'Cape Town',
    'London',
    'Pallekele',
    'Barbados',
    'Sydney',
    'Melbourne',
    'Durban',
    'St Lucia',
    'Wellington',
    'Lauderhill',
    'Hamilton',
    'Centurion',
    'Manchester',
    'Abu Dhabi',
    'Mumbai',
    'Nottingham',
    'Southampton',
    'Mount Maunganui',
    'Chittagong',
    'Kolkata',
    'Lahore',
    'Delhi',
    'Nagpur',
    'Chandigarh',
    'Adelaide',
    'Bangalore',
    'St Kitts',
    'Cardiff',
    'Christchurch',
    'Trinidad',
]
# The ASGI application object. The /predict route below is registered on it,
# and the __main__ guard at the bottom of the file serves it with uvicorn.
app: FastAPI = FastAPI()
# Request body for POST /predict. Pydantic enforces only the field types
# declared here; allowed team/city names and numeric ranges are checked
# separately by validate_input.
class ScorePredictionRequest(BaseModel):
    """Current state of a 20-over innings, used to predict the batting side's score."""

    batting_team: str  # must be one of `teams`
    bowling_team: str  # must be one of `teams`
    city: str  # must be one of `cities`
    current_score: int  # runs scored so far; must be >= 0
    overs: float  # overs bowled, 0 to 20 inclusive
    wickets: int  # wickets fallen, 0 to 10 inclusive
    last_five: int  # runs scored in the last five overs; must be >= 0
# Response body for POST /predict.
class ScorePredictionResponse(BaseModel):
    """Score predicted by the model for the submitted innings state."""

    predicted_score: int  # pipeline output converted to a whole number of runs
def validate_input(data: ScorePredictionRequest) -> None:
    """Reject a prediction request whose values the model cannot use.

    Checks run in a fixed order and the first failure wins, so a client sees
    one error at a time.

    Args:
        data: The parsed request body.

    Raises:
        HTTPException: status 400 with a human-readable ``detail`` when a team
            or city is not in the known lists, a numeric field is out of
            range, or the request is internally inconsistent (the same team
            batting and bowling, or more runs in the last five overs than in
            the whole innings).
    """
    if data.batting_team not in teams:
        raise HTTPException(status_code=400, detail="Invalid batting team")
    if data.bowling_team not in teams:
        raise HTTPException(status_code=400, detail="Invalid bowling team")
    if data.city not in cities:
        raise HTTPException(status_code=400, detail="Invalid city")
    # Chained comparisons are also False for NaN, so NaN overs are rejected here.
    if not (0 <= data.overs <= 20):
        raise HTTPException(status_code=400, detail="Overs must be between 0 and 20")
    if not (0 <= data.wickets <= 10):
        raise HTTPException(status_code=400, detail="Wickets must be between 0 and 10")
    if data.current_score < 0:
        raise HTTPException(status_code=400, detail="Current score must be non-negative")
    if data.last_five < 0:
        raise HTTPException(status_code=400, detail="Runs scored in last 5 overs must be non-negative")
    # Cross-field consistency checks come last, so any request rejected by a
    # per-field check above still receives exactly the same error message.
    if data.batting_team == data.bowling_team:
        raise HTTPException(status_code=400, detail="Batting and bowling teams must be different")
    # Runs from the last five overs are a subset of the innings total.
    if data.last_five > data.current_score:
        raise HTTPException(status_code=400, detail="Runs scored in last 5 overs cannot exceed current score")
def _build_features(data: ScorePredictionRequest) -> pd.DataFrame:
    """Turn a validated request into the single-row frame passed to the pipeline."""
    # 20 overs x 6 balls = 120 balls in the innings.
    # NOTE(review): overs is treated as a decimal (5.3 -> 31.8 balls bowled),
    # not cricket notation (5.3 = 5 overs 3 balls = 33 balls). Kept as-is,
    # presumably to match how the pipeline's training features were built --
    # confirm before changing.
    balls_left = 120 - (data.overs * 6)
    wickets_left = 10 - data.wickets
    # No run rate exists before the first over; avoid dividing by zero.
    crr = data.current_score / data.overs if data.overs > 0 else 0
    # Column names and order must match the frame the pipeline was fitted on.
    return pd.DataFrame(
        {
            'batting_team': [data.batting_team],
            'bowling_team': [data.bowling_team],
            'city': [data.city],
            'current_score': [data.current_score],
            'balls_left': [balls_left],
            'wickets_left': [wickets_left],
            'crr': [crr],
            'last_five': [data.last_five],
        }
    )


@app.post("/predict", response_model=ScorePredictionResponse)
def predict_score(data: ScorePredictionRequest):
    """Predict the batting side's score from the current state of the innings.

    Responds with HTTP 400 when the request fails validation.
    """
    # Raises HTTPException(400) on unknown names, out-of-range values or
    # inconsistent input, before the model is called.
    validate_input(data)
    result = pipe.predict(_build_features(data))
    # int() truncates toward zero, as before. Runs already on the board can
    # never be taken away, so a model output below the current score is
    # clamped up to it instead of being returned as an impossible total.
    predicted_score = max(int(result[0]), data.current_score)
    return ScorePredictionResponse(predicted_score=predicted_score)
if __name__ == "__main__":
    # Running this file directly starts a uvicorn server for `app` on port
    # 8000, bound to all network interfaces.
    from uvicorn import run as serve_app

    serve_app(app, port=8000, host="0.0.0.0")