# Spaces status: Runtime error
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
import pickle | |
import pandas as pd | |
import numpy as np | |
import xgboost | |
from xgboost import XGBRegressor | |
# Load the pre-trained model pipeline (an object exposing .predict() on a
# pandas DataFrame). The path is resolved relative to the working directory.
# NOTE: pickle.load can execute arbitrary code from the file, so only ever
# load an artifact you built or otherwise trust.
with open('app/pipe.pkl', 'rb') as _pipe_file:  # closes the handle after loading
    pipe = pickle.load(_pipe_file)
# Teams and venues accepted by the service. validate_input rejects any request
# that names a value outside these lists. Presumably they mirror the categories
# the pipeline saw during training (confirm against the training data).
teams = [
    'Australia',
    'India',
    'Bangladesh',
    'New Zealand',
    'South Africa',
    'England',
    'West Indies',
    'Afghanistan',
    'Pakistan',
    'Sri Lanka',
]
cities = [
    'Colombo',
    'Mirpur',
    'Johannesburg',
    'Dubai',
    'Auckland',
    'Cape Town',
    'London',
    'Pallekele',
    'Barbados',
    'Sydney',
    'Melbourne',
    'Durban',
    'St Lucia',
    'Wellington',
    'Lauderhill',
    'Hamilton',
    'Centurion',
    'Manchester',
    'Abu Dhabi',
    'Mumbai',
    'Nottingham',
    'Southampton',
    'Mount Maunganui',
    'Chittagong',
    'Kolkata',
    'Lahore',
    'Delhi',
    'Nagpur',
    'Chandigarh',
    'Adelaide',
    'Bangalore',
    'St Kitts',
    'Cardiff',
    'Christchurch',
    'Trinidad',
]
# Initialize FastAPI: the ASGI application object. The __main__ guard at the
# bottom of this module passes it to uvicorn.run() for local serving.
app = FastAPI()
# Define the request model
class ScorePredictionRequest(BaseModel):
    """JSON body of a score-prediction request.

    Only types are enforced here. Membership and range checks are done
    separately by validate_input.
    """
    batting_team: str  # must be one of `teams`
    bowling_team: str  # must be one of `teams`
    city: str  # must be one of `cities`
    current_score: int  # runs scored so far; must be >= 0
    overs: float  # overs bowled so far, 0-20; converted to balls as overs * 6
    wickets: int  # wickets fallen so far, 0-10 (wickets_left = 10 - wickets)
    last_five: int  # runs scored in the last 5 overs; must be >= 0
# Define the response model
class ScorePredictionResponse(BaseModel):
    """Response body carrying the model's prediction (presumably the projected innings total)."""
    predicted_score: int  # model output converted with int(), i.e. truncated
# Ensure input values are valid
def validate_input(data: ScorePredictionRequest):
    """Reject a prediction request whose fields the model cannot use.

    Checks run in order and only the first failure is reported.

    Args:
        data: The parsed request body.

    Raises:
        HTTPException: status 400 with a human-readable ``detail`` when a team
            or city is unknown, a numeric field is out of range, the same team
            is named as both batting and bowling side, or the runs in the last
            five overs exceed the current score.
    """
    if data.batting_team not in teams:
        raise HTTPException(status_code=400, detail="Invalid batting team")
    if data.bowling_team not in teams:
        raise HTTPException(status_code=400, detail="Invalid bowling team")
    if data.city not in cities:
        raise HTTPException(status_code=400, detail="Invalid city")
    if not (0 <= data.overs <= 20):
        raise HTTPException(status_code=400, detail="Overs must be between 0 and 20")
    if not (0 <= data.wickets <= 10):
        raise HTTPException(status_code=400, detail="Wickets must be between 0 and 10")
    if data.current_score < 0:
        raise HTTPException(status_code=400, detail="Current score must be non-negative")
    if data.last_five < 0:
        raise HTTPException(status_code=400, detail="Runs scored in last 5 overs must be non-negative")
    # The checks below come last so that inputs rejected before keep
    # receiving the same error message as before.
    # A side cannot bowl to itself.
    if data.batting_team == data.bowling_team:
        raise HTTPException(status_code=400, detail="Batting and bowling teams must be different")
    # Runs from the last five overs are a subset of the runs scored so far.
    if data.last_five > data.current_score:
        raise HTTPException(
            status_code=400,
            detail="Runs scored in last 5 overs cannot exceed current score",
        )
def _build_features(data: ScorePredictionRequest) -> pd.DataFrame:
    """Turn a validated request into the one-row feature frame passed to the pipeline."""
    # Balls remaining in a 20-over (120-ball) innings.
    # NOTE(review): overs is treated as a decimal number, so 5.3 becomes 31.8
    # balls rather than cricket notation "5 overs 3 balls". Confirm which
    # convention clients send and which one the model was trained with.
    balls_left = 120 - (data.overs * 6)
    wickets_left = 10 - data.wickets
    # Current run rate, with a guard against dividing by zero before any over is bowled.
    crr = data.current_score / data.overs if data.overs > 0 else 0
    # Column names (and presumably their order) must match the training features.
    return pd.DataFrame(
        {
            'batting_team': [data.batting_team],
            'bowling_team': [data.bowling_team],
            'city': [data.city],
            'current_score': [data.current_score],
            'balls_left': [balls_left],
            'wickets_left': [wickets_left],
            'crr': [crr],
            'last_five': [data.last_five],
        }
    )


# Define the prediction endpoint.
# The route decorator registers this handler on `app`; without it the service
# exposes no prediction endpoint. NOTE(review): the "/predict" path is chosen
# here; adjust it if clients expect a different one.
@app.post("/predict", response_model=ScorePredictionResponse)
def predict_score(data: ScorePredictionRequest):
    """Validate a request, run the pipeline and return the predicted score.

    Args:
        data: The parsed request body.

    Returns:
        ScorePredictionResponse whose ``predicted_score`` is the first model
        output, truncated toward zero by ``int()``.

    Raises:
        HTTPException: status 400 when validate_input rejects the request.
    """
    validate_input(data)
    input_df = _build_features(data)
    result = pipe.predict(input_df)
    predicted_score = int(result[0])
    return ScorePredictionResponse(predicted_score=predicted_score)
if __name__ == "__main__":
    # Development entry point for running this file directly. uvicorn is imported
    # only inside this guard, so importing the module elsewhere does not load it.
    import uvicorn

    # Listen on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)