Faith-theAnalyst commited on
Commit
82bd30a
1 Parent(s): 6731796

Add application file

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. Dockerfile +12 -0
  3. main.py +94 -0
  4. requirements.txt +37 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ /venv_api
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM Python:3.11
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
12
+
main.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI,HTTPException
2
+ from typing import Literal,List
3
+ import uvicorn
4
+ from pydantic import BaseModel
5
+ import pandas as pd
6
+ import os
7
+ import pickle
8
+
9
+ # setup
10
+ SRC = os.path.abspath('./SRC/Assets')
11
+
12
+ # Load the pipeline using pickle
13
+ pipeline_path = os.path.join(SRC, 'pipeline.pkl')
14
+ with open(pipeline_path, 'rb') as file:
15
+ pipeline = pickle.load(file)
16
+
17
+ # Load the encoder using pickle
18
+ model_path = os.path.join(SRC, 'rfc_model.pkl')
19
+ with open(model_path, 'rb') as file:
20
+ model = pickle.load(file)
21
+
22
+ app = FastAPI(
23
+ title= 'Income Classification FastAPI',
24
+ description='A FastAPI service to classify individuals based on income level using a trained machine learning model.',
25
+ version= '1.0.0'
26
+ )
27
+
28
+ class IncomePredictionInput(BaseModel):
29
+ age: int
30
+ gender: str
31
+ education: str
32
+ worker_class: str
33
+ marital_status: str
34
+ race: str
35
+ is_hispanic: str
36
+ employment_commitment: str
37
+ employment_stat: int
38
+ wage_per_hour: int
39
+ working_week_per_year: int
40
+ industry_code: int
41
+ industry_code_main: str
42
+ occupation_code: int
43
+ occupation_code_main: str
44
+ total_employed: int
45
+ household_summary: str
46
+ vet_benefit: int
47
+ tax_status: str
48
+ gains: int
49
+ losses: int
50
+ stocks_status: int
51
+ citizenship: str
52
+ importance_of_record: float
53
+
54
+
55
+ class IncomePredictionOutput(BaseModel):
56
+ income_prediction: str
57
+ prediction_probability: float
58
+
59
+
60
+ # get
61
+ @app.get('/')
62
+ def home():
63
+ return {
64
+ 'message': 'Income Classification FastAPI',
65
+ 'description': 'FastAPI service to classify individuals based on income level.',
66
+ 'instruction': 'Click here (/docs) to access API documentation and test endpoints.'
67
+ }
68
+
69
+
70
+ # post
71
+ @app.post('/classify', response_model=IncomePredictionOutput)
72
+ def income_classification(income: IncomePredictionInput):
73
+ try:
74
+ # Convert input data to DataFrame
75
+ input_df = pd.DataFrame([dict(income)])
76
+
77
+ # Preprocess the input data through the pipeline
78
+ input_df_transformed = pipeline.transform(input_df)
79
+
80
+ # Make predictions
81
+ prediction = model.predict(input_df_transformed)
82
+ probability = model.predict_proba(input_df_transformed).max(axis=1)[0]
83
+
84
+ prediction_result = "Above Limit" if prediction[0] == 1 else "Below Limit"
85
+ return {"income_prediction": prediction_result, "prediction_probability": probability}
86
+
87
+ except Exception as e:
88
+ error_detail = str(e)
89
+ raise HTTPException(status_code=500, detail=f"Error during classification: {error_detail}")
90
+
91
+
92
+ if __name__ == '__main__':
93
+ uvicorn.run('main:app', reload=True)
94
+
requirements.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.6.0
2
+ anyio==4.2.0
3
+ certifi==2023.11.17
4
+ click==8.1.7
5
+ colorama==0.4.6
6
+ dnspython==2.4.2
7
+ email-validator==2.1.0.post1
8
+ fastapi==0.108.0
9
+ h11==0.14.0
10
+ httpcore==1.0.2
11
+ httptools==0.6.1
12
+ httpx==0.26.0
13
+ idna==3.6
14
+ itsdangerous==2.1.2
15
+ Jinja2==3.1.2
16
+ MarkupSafe==2.1.3
17
+ numpy==1.26.3
18
+ orjson==3.9.10
19
+ pandas==2.1.4
20
+ pydantic==2.5.3
21
+ pydantic-extra-types==2.4.0
22
+ pydantic-settings==2.1.0
23
+ pydantic_core==2.14.6
24
+ python-dateutil==2.8.2
25
+ python-dotenv==1.0.0
26
+ python-multipart==0.0.6
27
+ pytz==2023.3.post1
28
+ PyYAML==6.0.1
29
+ six==1.16.0
30
+ sniffio==1.3.0
31
+ starlette==0.32.0.post1
32
+ typing_extensions==4.9.0
33
+ tzdata==2023.4
34
+ ujson==5.9.0
35
+ uvicorn==0.25.0
36
+ watchfiles==0.21.0
37
+ websockets==12.0