aryan10022001 commited on
Commit
59eed39
1 Parent(s): 75429f5

Upload 3 files

Browse files
Files changed (3) hide show
  1. app/__init__.py +0 -0
  2. app/main.py +84 -0
  3. app/pipe.pkl +3 -0
app/__init__.py ADDED
File without changes
app/main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import pickle
4
+ import pandas as pd
5
+ import numpy as np
6
+ import xgboost
7
+ from xgboost import XGBRegressor
8
+
9
+ # Load the pre-trained pipeline
10
+ pipe = pickle.load(open('pipe.pkl', 'rb'))
11
+
12
+ # Define the list of teams and cities
13
+ teams = ['Australia', 'India', 'Bangladesh', 'New Zealand', 'South Africa', 'England',
14
+ 'West Indies', 'Afghanistan', 'Pakistan', 'Sri Lanka']
15
+
16
+ cities = ['Colombo', 'Mirpur', 'Johannesburg', 'Dubai', 'Auckland', 'Cape Town', 'London',
17
+ 'Pallekele', 'Barbados', 'Sydney', 'Melbourne', 'Durban', 'St Lucia', 'Wellington',
18
+ 'Lauderhill', 'Hamilton', 'Centurion', 'Manchester', 'Abu Dhabi', 'Mumbai',
19
+ 'Nottingham', 'Southampton', 'Mount Maunganui', 'Chittagong', 'Kolkata', 'Lahore',
20
+ 'Delhi', 'Nagpur', 'Chandigarh', 'Adelaide', 'Bangalore', 'St Kitts', 'Cardiff',
21
+ 'Christchurch', 'Trinidad']
22
+
23
+ # Initialize FastAPI
24
+ app = FastAPI()
25
+
26
+ # Define the request model
27
+ class ScorePredictionRequest(BaseModel):
28
+ batting_team: str
29
+ bowling_team: str
30
+ city: str
31
+ current_score: int
32
+ overs: float
33
+ wickets: int
34
+ last_five: int
35
+
36
+ # Define the response model
37
+ class ScorePredictionResponse(BaseModel):
38
+ predicted_score: int
39
+
40
+ # Ensure input values are valid
41
+ def validate_input(data: ScorePredictionRequest):
42
+ if data.batting_team not in teams:
43
+ raise HTTPException(status_code=400, detail="Invalid batting team")
44
+ if data.bowling_team not in teams:
45
+ raise HTTPException(status_code=400, detail="Invalid bowling team")
46
+ if data.city not in cities:
47
+ raise HTTPException(status_code=400, detail="Invalid city")
48
+ if not (0 <= data.overs <= 20):
49
+ raise HTTPException(status_code=400, detail="Overs must be between 0 and 20")
50
+ if not (0 <= data.wickets <= 10):
51
+ raise HTTPException(status_code=400, detail="Wickets must be between 0 and 10")
52
+ if data.current_score < 0:
53
+ raise HTTPException(status_code=400, detail="Current score must be non-negative")
54
+ if data.last_five < 0:
55
+ raise HTTPException(status_code=400, detail="Runs scored in last 5 overs must be non-negative")
56
+
57
+ # Define the prediction endpoint
58
+ @app.post("/predict", response_model=ScorePredictionResponse)
59
+ def predict_score(data: ScorePredictionRequest):
60
+ # Validate input
61
+ validate_input(data)
62
+
63
+ # Calculate additional features
64
+ balls_left = 120 - (data.overs * 6)
65
+ wickets_left = 10 - data.wickets
66
+ crr = data.current_score / data.overs if data.overs > 0 else 0
67
+
68
+ # Create input dataframe for the model
69
+ input_df = pd.DataFrame(
70
+ {'batting_team': [data.batting_team], 'bowling_team': [data.bowling_team], 'city': [data.city],
71
+ 'current_score': [data.current_score], 'balls_left': [balls_left], 'wickets_left': [wickets_left],
72
+ 'crr': [crr], 'last_five': [data.last_five]}
73
+ )
74
+
75
+ # Predict the score
76
+ result = pipe.predict(input_df)
77
+ predicted_score = int(result[0])
78
+
79
+ # Return the prediction
80
+ return ScorePredictionResponse(predicted_score=predicted_score)
81
+
82
+ if __name__ == "__main__":
83
+ import uvicorn
84
+ uvicorn.run(app, host="0.0.0.0", port=8000)
app/pipe.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cf3d0d97c298361ae8c0539456bfb5896aa499a0b00ed17ffabb2b19501a245
3
+ size 29835654