from datetime import datetime
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
import gradio as gr
import plotly.graph_objects as go
from huggingface_hub import from_pretrained_keras
import os

# CSV the model is trained on, resolved relative to the working directory.
# The code below relies on it having 'Month' and '#Passengers' columns.
DATA_PATH = 'AirPassengers.csv'


def predictAirPassengers(df, split):
    """Train a random forest on calendar features and return truth + forecast.

    The rows of ``AirPassengers.csv`` are split chronologically: the first
    ``split`` percent train a ``RandomForestRegressor`` whose features are the
    calendar month and year of each row (plus any other non-target columns),
    and the model predicts the passenger count for the remaining rows.

    Args:
        df: Value of the Gradio ``Timeseries`` input. It is accepted only for
            interface compatibility and is ignored (it may be ``None`` when no
            example is loaded); the data always comes from ``DATA_PATH``.
        split: Percentage (1-100) of rows used for training.

    Returns:
        A long-format DataFrame with columns ``Month`` (datetime),
        ``Passengers`` and ``origin``: every CSV row tagged
        ``'ground truth'``, followed by one ``'prediction'`` row per held-out
        month (none when ``split`` is 100). Suitable for ``gr.LinePlot``
        coloured by ``origin``.

    Raises:
        ValueError: if ``split`` leaves no rows to train on.
    """
    ground_truth = pd.read_csv(DATA_PATH)
    ground_truth = ground_truth.rename(columns={'#Passengers': 'Passengers'})
    ground_truth['Month'] = pd.to_datetime(ground_truth['Month'])

    # Features: every column except the date and the target, plus the
    # calendar month and year derived from 'Month'.
    features = ground_truth.drop(columns=['Month', 'Passengers']).assign(
        months=ground_truth['Month'].dt.month,
        years=ground_truth['Month'].dt.year,
    )
    target = ground_truth['Passengers']

    # Chronological split: the first `split` percent of rows train the model.
    cut = int(len(target) * (split / 100))
    if cut < 1:
        raise ValueError(
            f'split={split}% leaves no rows to train on '
            f'({len(target)} rows in total)'
        )

    test_features = features.iloc[cut:]
    if test_features.empty:
        # split == 100: nothing left to forecast, and predict() rejects 0 samples.
        predicted = np.array([], dtype=int)
    else:
        model = RandomForestRegressor()
        model.fit(features.iloc[:cut], target.iloc[:cut])
        # Round to whole passengers; a bare astype(int) would truncate toward zero.
        predicted = np.rint(model.predict(test_features)).astype(int)

    predictions = pd.DataFrame({
        'Month': ground_truth['Month'].iloc[cut:].reset_index(drop=True),
        'Passengers': predicted,
        'origin': 'prediction',
    })
    ground_truth['origin'] = 'ground truth'
    # ignore_index avoids duplicate row labels between the two stacked parts.
    return pd.concat([ground_truth, predictions], ignore_index=True)


demo = gr.Interface(
    fn=predictAirPassengers,
    inputs=[
        gr.Timeseries(label="Input for the timeseries", max_rows=1, interactive=False),
        gr.Slider(1, 100, value=75, step=1, label="Train test split percentage"),
    ],
    outputs=[
        gr.LinePlot(x='Month', y='Passengers', color='origin'),
    ],
    examples=[
        [os.path.join(os.path.abspath(''), "AirPassengers_dt.csv"), 75],
    ],
)

demo.launch()