# ds3-ml-model / app.py
# Chaninder Rishi — "Update app.py" (commit 4df6604)
import streamlit as st
import pandas as pd
import numpy as np
import csv
import json
import matplotlib.pyplot as plt
import ast
#import pickle
import sklearn
from sklearn import linear_model
# --- Load data and build features --------------------------------------------
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the CSV is re-read and the model re-fit each time. Consider
# caching (st.cache_data / st.cache_resource, depending on the installed
# Streamlit version).
df = pd.read_csv('emily_election.csv')
df['spend'] = df['cum_spend']
# Take the leading integer of the runtime text. Presumably values look like
# "12 days ..."; splitting on 'day' also accepts the singular "1 day".
df['runtime'] = df['cumulative_ad_runtime'].apply(lambda s: int(s.split('day')[0]))
# These columns hold dict literals as text (parsed with ast.literal_eval, never
# eval). Each row's feature is the sum of the dict's values.
df['impressions'] = df['cumulative_impressions_by_region'].apply(
    lambda d: np.array(list(ast.literal_eval(d).values())).sum())
# feature 3 (for later): audience_size is computed but not used by the model yet
df['audience_size'] = df['cumulative_est_audience'].apply(
    lambda d: np.array(list(ast.literal_eval(d).values())).sum())
data = df[['runtime', 'spend', 'audience_size', 'impressions']]

# --- Train/test split ----------------------------------------------------------
# A fixed seed keeps the split (and therefore the fitted model) identical
# across Streamlit reruns; unseeded, every user input refits on a new split.
TRAIN_FRACTION = 0.8
msk = np.random.default_rng(0).random(len(data)) < TRAIN_FRACTION
# .copy() so the column assignments below act on independent frames rather
# than on slices of `data` (avoids SettingWithCopy / chained assignment).
train = data[msk].copy()
test = data[~msk].copy()
train['spend'] = train['spend'].astype('float')
# Keep only ads with spend > 250 and runtime > 4 days for fitting.
new_train = train[(train['spend'] > 250) & (train['runtime'] > 4)].copy()

# --- Fit -----------------------------------------------------------------------
# This model predicts log(impressions) from log(runtime) and log(spend).
regr = linear_model.LinearRegression()
new_train['log_runtime'] = np.log(new_train['runtime'])
new_train['log_spend'] = np.log(new_train['spend'])
new_train['log_impressions'] = np.log(new_train['impressions'])
# log(0) -> -inf; turn infinities into NaN so dropna removes those rows.
new_train.replace([np.inf, -np.inf], np.nan, inplace=True)
new_train.dropna(inplace=True)
print(new_train.to_string())  # debug dump of the training rows (server console)
# Feature column order is [log_runtime, log_spend]; predictions must match it.
x = np.asanyarray(new_train[['log_runtime', 'log_spend']])
y = np.asanyarray(new_train[['log_impressions']])
regr.fit(x, y)
spend = st.number_input('Enter Spend (in dollars): ')
runtime = st.number_input('Enter runtime (in days)')
# Half-width of the reported interval, in log-impressions space.
# NOTE(review): this is a hard-coded constant, not derived from the fitted
# model's residual spread, and 1.65 is the usual z-value for a two-sided 90%
# interval. Confirm that the "70%" label below is intended.
LOG_INTERVAL_HALF_WIDTH = 1.65
# Both inputs must be strictly positive: np.log of 0 or a negative number
# yields -inf/NaN, which LinearRegression.predict rejects.
if spend > 0 and runtime > 0:
    # Feature order must match training: [log_runtime, log_spend].
    pred = regr.predict([np.log([runtime, spend])])
    log_center = pred[0][0]
    st.write('70% confidence interval for number of impressions is: {} to {} hits'.format(int(np.exp(log_center - LOG_INTERVAL_HALF_WIDTH)), int(np.exp(log_center + LOG_INTERVAL_HALF_WIDTH))))