Spaces:
Runtime error
Runtime error
aryan10022001
commited on
Commit
•
59eed39
1
Parent(s):
75429f5
Upload 3 files
Browse files- app/__init__.py +0 -0
- app/main.py +84 -0
- app/pipe.pkl +3 -0
app/__init__.py
ADDED
File without changes
|
app/main.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
import pickle
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import xgboost
|
7 |
+
from xgboost import XGBRegressor
|
8 |
+
|
9 |
+
# Load the pre-trained pipeline
|
10 |
+
pipe = pickle.load(open('pipe.pkl', 'rb'))
|
11 |
+
|
12 |
+
# Define the list of teams and cities
|
13 |
+
teams = ['Australia', 'India', 'Bangladesh', 'New Zealand', 'South Africa', 'England',
|
14 |
+
'West Indies', 'Afghanistan', 'Pakistan', 'Sri Lanka']
|
15 |
+
|
16 |
+
cities = ['Colombo', 'Mirpur', 'Johannesburg', 'Dubai', 'Auckland', 'Cape Town', 'London',
|
17 |
+
'Pallekele', 'Barbados', 'Sydney', 'Melbourne', 'Durban', 'St Lucia', 'Wellington',
|
18 |
+
'Lauderhill', 'Hamilton', 'Centurion', 'Manchester', 'Abu Dhabi', 'Mumbai',
|
19 |
+
'Nottingham', 'Southampton', 'Mount Maunganui', 'Chittagong', 'Kolkata', 'Lahore',
|
20 |
+
'Delhi', 'Nagpur', 'Chandigarh', 'Adelaide', 'Bangalore', 'St Kitts', 'Cardiff',
|
21 |
+
'Christchurch', 'Trinidad']
|
22 |
+
|
23 |
+
# Initialize FastAPI
|
24 |
+
app = FastAPI()
|
25 |
+
|
26 |
+
# Define the request model
|
27 |
+
class ScorePredictionRequest(BaseModel):
|
28 |
+
batting_team: str
|
29 |
+
bowling_team: str
|
30 |
+
city: str
|
31 |
+
current_score: int
|
32 |
+
overs: float
|
33 |
+
wickets: int
|
34 |
+
last_five: int
|
35 |
+
|
36 |
+
# Define the response model
|
37 |
+
class ScorePredictionResponse(BaseModel):
|
38 |
+
predicted_score: int
|
39 |
+
|
40 |
+
# Ensure input values are valid
|
41 |
+
def validate_input(data: ScorePredictionRequest):
|
42 |
+
if data.batting_team not in teams:
|
43 |
+
raise HTTPException(status_code=400, detail="Invalid batting team")
|
44 |
+
if data.bowling_team not in teams:
|
45 |
+
raise HTTPException(status_code=400, detail="Invalid bowling team")
|
46 |
+
if data.city not in cities:
|
47 |
+
raise HTTPException(status_code=400, detail="Invalid city")
|
48 |
+
if not (0 <= data.overs <= 20):
|
49 |
+
raise HTTPException(status_code=400, detail="Overs must be between 0 and 20")
|
50 |
+
if not (0 <= data.wickets <= 10):
|
51 |
+
raise HTTPException(status_code=400, detail="Wickets must be between 0 and 10")
|
52 |
+
if data.current_score < 0:
|
53 |
+
raise HTTPException(status_code=400, detail="Current score must be non-negative")
|
54 |
+
if data.last_five < 0:
|
55 |
+
raise HTTPException(status_code=400, detail="Runs scored in last 5 overs must be non-negative")
|
56 |
+
|
57 |
+
# Define the prediction endpoint
|
58 |
+
@app.post("/predict", response_model=ScorePredictionResponse)
|
59 |
+
def predict_score(data: ScorePredictionRequest):
|
60 |
+
# Validate input
|
61 |
+
validate_input(data)
|
62 |
+
|
63 |
+
# Calculate additional features
|
64 |
+
balls_left = 120 - (data.overs * 6)
|
65 |
+
wickets_left = 10 - data.wickets
|
66 |
+
crr = data.current_score / data.overs if data.overs > 0 else 0
|
67 |
+
|
68 |
+
# Create input dataframe for the model
|
69 |
+
input_df = pd.DataFrame(
|
70 |
+
{'batting_team': [data.batting_team], 'bowling_team': [data.bowling_team], 'city': [data.city],
|
71 |
+
'current_score': [data.current_score], 'balls_left': [balls_left], 'wickets_left': [wickets_left],
|
72 |
+
'crr': [crr], 'last_five': [data.last_five]}
|
73 |
+
)
|
74 |
+
|
75 |
+
# Predict the score
|
76 |
+
result = pipe.predict(input_df)
|
77 |
+
predicted_score = int(result[0])
|
78 |
+
|
79 |
+
# Return the prediction
|
80 |
+
return ScorePredictionResponse(predicted_score=predicted_score)
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
import uvicorn
|
84 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
app/pipe.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4cf3d0d97c298361ae8c0539456bfb5896aa499a0b00ed17ffabb2b19501a245
|
3 |
+
size 29835654
|