Prophet Evaluation Metrics
import pandas as pd
from fbprophet import Prophet
%matplotlib inline
Load Data
In this notebook we will use the Miles_Traveled dataset.
df = pd.read_csv('Miles_Traveled.csv')
df.head()
| | DATE | TRFVOLUSM227NFWA |
---|---|---|
0 | 1970-01-01 | 80173.0 |
1 | 1970-02-01 | 77442.0 |
2 | 1970-03-01 | 90223.0 |
3 | 1970-04-01 | 89956.0 |
4 | 1970-05-01 | 97972.0 |
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 588 entries, 0 to 587
Data columns (total 2 columns):
DATE 588 non-null object
TRFVOLUSM227NFWA 588 non-null float64
dtypes: float64(1), object(1)
memory usage: 9.3+ KB
df.describe()
TRFVOLUSM227NFWA
count 588.000000
mean 190420.380952
std 57795.538934
min 77442.000000
25% 133579.000000
50% 196797.500000
75% 243211.500000
max 288145.000000
Rename the columns to the names Prophet expects (ds for the dates, y for the values) and convert the date column to a datetime type.
df.columns = ['ds', 'y']
df['ds'] = pd.to_datetime(df['ds'])
df.head()
| | ds | y |
---|---|---|
0 | 1970-01-01 | 80173.0 |
1 | 1970-02-01 | 77442.0 |
2 | 1970-03-01 | 90223.0 |
3 | 1970-04-01 | 89956.0 |
4 | 1970-05-01 | 97972.0 |
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 588 entries, 0 to 587
Data columns (total 2 columns):
ds 588 non-null datetime64[ns]
y 588 non-null float64
dtypes: datetime64[ns](1), float64(1)
memory usage: 9.3 KB
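Assigning to `df.columns` relies on the column order in the CSV. As a hedged alternative, the same preparation can be done by renaming columns by name. The sketch below uses a three-row stand-in frame built from the values shown above; `raw` and `prepared` are illustrative names, not part of the notebook.

```python
import pandas as pd

# Small stand-in for the first rows of Miles_Traveled.csv (values from above)
raw = pd.DataFrame({
    "DATE": ["1970-01-01", "1970-02-01", "1970-03-01"],
    "TRFVOLUSM227NFWA": [80173.0, 77442.0, 90223.0],
})

# Rename by name so the result does not depend on column order
prepared = raw.rename(columns={"DATE": "ds", "TRFVOLUSM227NFWA": "y"})

# Parse the date strings into datetime values
prepared["ds"] = pd.to_datetime(prepared["ds"], format="%Y-%m-%d")

print(prepared.dtypes)
```

Renaming by name raises no error if a column is missing, so checking `prepared.columns` afterwards is a reasonable safeguard.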
Visualize Data
pd.plotting.register_matplotlib_converters()
df.plot(x='ds', y='y', figsize = (12,8));
Fit Prophet on the Training Set and Evaluate on the Test Set
len(df)
588
# Hold out the last 12 months as the test set and train on the rest
train = df.iloc[:len(df)-12]
test = df.iloc[len(df)-12:]
# Create Prophet instance
m = Prophet()
# fit model
m.fit(train)
# create a placeholder for future predictions
future = m.make_future_dataframe(periods=12, freq='MS')
# get forecast on future df
forecast = m.predict(future)
INFO:fbprophet:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:fbprophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
# print forecast tail
forecast.tail()
| | ds | trend | yhat_lower | yhat_upper | trend_lower | trend_upper | additive_terms | additive_terms_lower | additive_terms_upper | yearly | yearly_lower | yearly_upper | multiplicative_terms | multiplicative_terms_lower | multiplicative_terms_upper | yhat |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
583 | 2018-08-01 | 263410.800604 | 273946.550023 | 285503.932507 | 263343.077564 | 263477.133709 | 16448.013049 | 16448.013049 | 16448.013049 | 16448.013049 | 16448.013049 | 16448.013049 | 0.0 | 0.0 | 0.0 | 279858.813654 |
584 | 2018-09-01 | 263552.915940 | 256365.077279 | 267835.455188 | 263447.397126 | 263648.573373 | -1670.418537 | -1670.418537 | -1670.418537 | -1670.418537 | -1670.418537 | -1670.418537 | 0.0 | 0.0 | 0.0 | 261882.497404 |
585 | 2018-10-01 | 263690.446911 | 262760.269741 | 274866.907060 | 263542.899517 | 263810.852602 | 5305.505873 | 5305.505873 | 5305.505873 | 5305.505873 | 5305.505873 | 5305.505873 | 0.0 | 0.0 | 0.0 | 268995.952784 |
586 | 2018-11-01 | 263832.562247 | 249921.672602 | 261436.555424 | 263658.943503 | 263978.577730 | -8208.986942 | -8208.986942 | -8208.986942 | -8208.986942 | -8208.986942 | -8208.986942 | 0.0 | 0.0 | 0.0 | 255623.575305 |
587 | 2018-12-01 | 263970.093217 | 251560.897604 | 263186.227119 | 263746.322784 | 264156.491737 | -6922.716937 | -6922.716937 | -6922.716937 | -6922.716937 | -6922.716937 | -6922.716937 | 0.0 | 0.0 | 0.0 | 257047.376280 |
Plot Predicted Values Against True Values
ax = forecast.plot(x='ds', y='yhat', label='Predictions', legend=True, figsize=(12,8))
test.plot(x='ds', y='y', label='True Test Data', legend=True, ax=ax);
# Passing ax=ax draws both series on the same axes.
# Zoom into the forecast
ax = forecast.plot(x='ds', y='yhat', label='Predictions', legend=True, figsize=(12,8))
test.plot(x='ds', y='y', label='True Test Data', legend=True, ax=ax, xlim=('2018-01-01', '2019-01-01'));
Evaluate Predictions Using Statsmodels Root Mean Squared Error (RMSE)
from statsmodels.tools.eval_measures import rmse
# get predictions
predictions = forecast.iloc[-12:]['yhat']
predictions
576 243850.453937
577 235480.588794
578 262683.274392
579 262886.236399
580 272609.522601
581 272862.615300
582 279321.841101
583 279858.813654
584 261882.497404
585 268995.952784
586 255623.575305
587 257047.376280
Name: yhat, dtype: float64
# compare to test set
test['y']
576 245695.0
577 226660.0
578 268480.0
579 272475.0
580 286164.0
581 280877.0
582 288145.0
583 286608.0
584 260595.0
585 282174.0
586 258590.0
587 268413.0
Name: y, dtype: float64
# calculate rmse
rmse(predictions, test['y'])
8618.783155559411
test.mean()
y 268739.666667
dtype: float64
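To put the RMSE and the test-set mean side by side, the RMSE can be recomputed by hand and expressed as a fraction of the mean. This is a minimal sketch in plain Python using the predictions and test values printed above; the names `y_pred`, `y_true`, `rmse_value`, and `relative_rmse` are illustrative and not part of statsmodels or Prophet.

```python
import math

# yhat for the 12 held-out months, copied from the predictions output above
y_pred = [
    243850.453937, 235480.588794, 262683.274392, 262886.236399,
    272609.522601, 272862.615300, 279321.841101, 279858.813654,
    261882.497404, 268995.952784, 255623.575305, 257047.376280,
]

# Actual values for the same months, copied from test['y'] above
y_true = [
    245695.0, 226660.0, 268480.0, 272475.0, 286164.0, 280877.0,
    288145.0, 286608.0, 260595.0, 282174.0, 258590.0, 268413.0,
]

# RMSE: square root of the mean of the squared errors
squared_errors = [(p - t) ** 2 for p, t in zip(y_pred, y_true)]
rmse_value = math.sqrt(sum(squared_errors) / len(squared_errors))

# Scale the error by the average actual value to get a relative figure
mean_actual = sum(y_true) / len(y_true)
relative_rmse = rmse_value / mean_actual

print(f"RMSE: {rmse_value:.2f}")
print(f"RMSE as a share of the test mean: {relative_rmse:.1%}")
```

On these numbers the RMSE is roughly 3% of the average monthly value in the test set, which gives the raw error figure some context.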
Evaluate Predictions Using Prophet’s Evaluation Tools
Prophet has its own evaluation tools, which let us perform cross-validation over several sections of the dataset rather than only the last one, as we did above. Cross-validation requires three settings:
- Initial: the length of the initial training period
- Period: the spacing between successive cutoff dates
- Horizon: how far ahead each forecast extends
# Import Prophet's evaluation tools
from fbprophet.diagnostics import cross_validation, performance_metrics
from fbprophet.plot import plot_cross_validation_metric
# Define:
# Initial -- training window of 5 years, expressed in days
initial = 5 * 365
initial = str(initial) + ' days'
initial
'1825 days'
# Define:
# Period -- spacing between successive cutoff dates, also 5 years
period = 5 * 365
period = str(period) + ' days'
period
'1825 days'
# Define:
# Horizon -- one year ahead
horizon = 365
horizon = str(horizon) + ' days'
# perform cross-validation -- args: model, initial, period, horizon
df_cv = cross_validation(m, initial=initial, period=period, horizon=horizon)
INFO:fbprophet:Making 9 forecasts with cutoffs between 1976-12-11 00:00:00 and 2016-12-01 00:00:00
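The cutoff dates in the log line above follow from the three settings. The sketch below is a simplified reconstruction of that logic, not Prophet's own code: it assumes the training data runs monthly from 1970-01-01 to 2017-12-01 (the last 12 months were held out), starts at the latest cutoff that leaves a full horizon of data, and steps backwards by `period` while at least `initial` of history remains. All variable names are illustrative.

```python
from datetime import datetime, timedelta

# Assumed inputs, taken from the notebook's training set and settings
first_date = datetime(1970, 1, 1)
last_date = datetime(2017, 12, 1)
initial = timedelta(days=5 * 365)
period = timedelta(days=5 * 365)
horizon = timedelta(days=365)

# Walk backwards from the latest possible cutoff, one period at a time,
# stopping once less than `initial` of training history would remain
cutoffs = []
cutoff = last_date - horizon
while cutoff >= first_date + initial:
    cutoffs.append(cutoff)
    cutoff -= period
cutoffs.reverse()

print(len(cutoffs), cutoffs[0].date(), cutoffs[-1].date())
```

Because 1825 days is slightly shorter than five calendar years, the earlier cutoffs drift away from the first of the month, so forecast horizons measured from them are not whole months.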
df_cv.head()
| | ds | yhat | yhat_lower | yhat_upper | y | cutoff |
---|---|---|---|---|---|---|
0 | 1977-01-01 | 108479.087306 | 106990.603148 | 109837.430105 | 102445.0 | 1976-12-11 |
1 | 1977-02-01 | 102996.111502 | 101607.192966 | 104416.448231 | 102416.0 | 1976-12-11 |
2 | 1977-03-01 | 118973.317944 | 117460.910009 | 120468.316343 | 119960.0 | 1976-12-11 |
3 | 1977-04-01 | 120612.923539 | 119182.220594 | 122067.307008 | 121513.0 | 1976-12-11 |
4 | 1977-05-01 | 127883.031663 | 126442.196287 | 129351.375440 | 128884.0 | 1976-12-11 |
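# Summarize the cross-validation errors by forecast horizon
# mae: mean absolute error
# mape: mean absolute percentage error, as a fraction (0.02 = 2%)
# coverage: share of actual values inside [yhat_lower, yhat_upper]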
performance_metrics(df_cv)
| | horizon | mse | rmse | mae | mape | coverage |
---|---|---|---|---|---|---|
0 | 52 days | 2.402227e+07 | 4901.251892 | 4506.384371 | 0.027631 | 0.4 |
1 | 53 days | 2.150811e+07 | 4637.683407 | 4238.662732 | 0.024863 | 0.4 |
2 | 54 days | 1.807689e+07 | 4251.692535 | 3708.943275 | 0.019933 | 0.5 |
3 | 55 days | 2.298205e+07 | 4793.960154 | 4236.275244 | 0.023042 | 0.4 |
4 | 57 days | 2.078937e+07 | 4559.535784 | 3972.087270 | 0.021317 | 0.5 |
5 | 58 days | 2.306545e+07 | 4802.649969 | 4248.916338 | 0.022521 | 0.4 |
6 | 59 days | 3.794246e+07 | 6159.745363 | 5069.232548 | 0.026221 | 0.4 |
7 | 60 days | 3.875108e+07 | 6225.036249 | 5136.940670 | 0.026702 | 0.4 |
8 | 62 days | 3.722136e+07 | 6100.930821 | 4941.278113 | 0.025845 | 0.4 |
9 | 80 days | 3.260728e+07 | 5710.278221 | 4353.547479 | 0.023861 | 0.5 |
10 | 81 days | 3.285223e+07 | 5731.686347 | 4462.450667 | 0.024597 | 0.5 |
11 | 82 days | 3.281764e+07 | 5728.668498 | 4457.259054 | 0.023572 | 0.6 |
12 | 84 days | 3.390813e+07 | 5823.069158 | 4733.704727 | 0.024963 | 0.6 |
13 | 85 days | 2.998379e+07 | 5475.745696 | 4377.309899 | 0.022171 | 0.7 |
14 | 86 days | 2.918944e+07 | 5402.725633 | 4118.776449 | 0.020737 | 0.7 |
15 | 87 days | 2.788883e+07 | 5280.987591 | 3983.875446 | 0.019695 | 0.8 |
16 | 89 days | 1.453884e+07 | 3812.982510 | 3298.866149 | 0.016154 | 0.8 |
17 | 90 days | 2.084092e+07 | 4565.185871 | 3681.957136 | 0.017056 | 0.8 |
18 | 111 days | 2.006300e+07 | 4479.173895 | 3478.888509 | 0.016505 | 0.8 |
19 | 112 days | 2.060786e+07 | 4539.588179 | 3633.640152 | 0.017634 | 0.8 |
20 | 113 days | 2.237650e+07 | 4730.380639 | 3919.190632 | 0.019163 | 0.7 |
21 | 115 days | 2.180882e+07 | 4669.991853 | 3817.735659 | 0.018297 | 0.7 |
22 | 116 days | 2.112879e+07 | 4596.606343 | 3693.295346 | 0.017470 | 0.7 |
23 | 117 days | 1.974488e+07 | 4443.520604 | 3337.159852 | 0.015782 | 0.7 |
24 | 118 days | 2.098299e+07 | 4580.718952 | 3665.511183 | 0.017075 | 0.7 |
25 | 120 days | 1.929732e+07 | 4392.870982 | 3308.657741 | 0.015710 | 0.7 |
26 | 121 days | 4.005074e+07 | 6328.565250 | 4249.836052 | 0.019007 | 0.7 |
27 | 141 days | 2.986741e+07 | 5465.108854 | 3335.856109 | 0.015991 | 0.8 |
28 | 142 days | 3.348058e+07 | 5786.240648 | 3853.646111 | 0.019583 | 0.7 |
29 | 143 days | 3.436861e+07 | 5862.474577 | 3991.409535 | 0.019963 | 0.6 |
... | ... | ... | ... | ... | ... | ... |
69 | 271 days | 3.940425e+07 | 6277.280588 | 4952.925812 | 0.023246 | 0.4 |
70 | 273 days | 4.762939e+07 | 6901.404877 | 5510.902256 | 0.025811 | 0.4 |
71 | 274 days | 5.215876e+07 | 7222.102079 | 6036.799827 | 0.027817 | 0.3 |
72 | 294 days | 4.300135e+07 | 6557.541692 | 5087.841083 | 0.024501 | 0.4 |
73 | 295 days | 4.779605e+07 | 6913.468723 | 5677.222656 | 0.028653 | 0.3 |
74 | 296 days | 4.445228e+07 | 6667.254544 | 5359.639092 | 0.025830 | 0.3 |
75 | 298 days | 4.284148e+07 | 6545.340220 | 4996.540611 | 0.023579 | 0.4 |
76 | 299 days | 4.492137e+07 | 6702.340811 | 5343.691988 | 0.025058 | 0.3 |
77 | 300 days | 4.491255e+07 | 6701.682253 | 5334.940901 | 0.024989 | 0.3 |
78 | 301 days | 4.029280e+07 | 6347.660786 | 4861.651539 | 0.022849 | 0.4 |
79 | 303 days | 2.450723e+07 | 4950.477945 | 3736.409294 | 0.018257 | 0.5 |
80 | 304 days | 4.007715e+07 | 6330.651308 | 4329.508835 | 0.019772 | 0.5 |
81 | 325 days | 3.542847e+07 | 5952.181744 | 3763.155215 | 0.018206 | 0.6 |
82 | 326 days | 3.675447e+07 | 6062.546374 | 4119.368273 | 0.021009 | 0.5 |
83 | 327 days | 3.206591e+07 | 5662.677267 | 3570.147529 | 0.016933 | 0.6 |
84 | 329 days | 3.766847e+07 | 6137.464499 | 4036.418932 | 0.019430 | 0.6 |
85 | 330 days | 3.854161e+07 | 6208.188825 | 4294.374211 | 0.020697 | 0.6 |
86 | 331 days | 3.727481e+07 | 6105.310294 | 4132.975253 | 0.019910 | 0.7 |
87 | 332 days | 4.115499e+07 | 6415.215244 | 4711.526108 | 0.022264 | 0.6 |
88 | 334 days | 4.087113e+07 | 6393.053559 | 4646.719134 | 0.022078 | 0.6 |
89 | 335 days | 4.742068e+07 | 6886.267502 | 5329.021094 | 0.024717 | 0.5 |
90 | 355 days | 2.233198e+07 | 4725.672310 | 4003.530744 | 0.021388 | 0.5 |
91 | 356 days | 2.398709e+07 | 4897.661254 | 4302.563236 | 0.023667 | 0.4 |
92 | 357 days | 2.302278e+07 | 4798.206369 | 4128.684970 | 0.022024 | 0.5 |
93 | 359 days | 2.486947e+07 | 4986.930376 | 4432.355223 | 0.023565 | 0.4 |
94 | 360 days | 1.814608e+07 | 4259.821515 | 3750.359483 | 0.019596 | 0.5 |
95 | 361 days | 1.726110e+07 | 4154.647536 | 3473.037339 | 0.018212 | 0.5 |
96 | 362 days | 3.173990e+07 | 5633.817508 | 4404.300729 | 0.022034 | 0.4 |
97 | 364 days | 2.986513e+07 | 5464.900040 | 4229.869860 | 0.021378 | 0.5 |
98 | 365 days | 5.443147e+07 | 7377.768377 | 5621.707803 | 0.026524 | 0.4 |
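To make the column definitions concrete, the formulas for mae, mape, and coverage can be applied by hand to the five df_cv rows printed earlier. This is only an illustration of the formulas: `performance_metrics` aggregates over all cutoffs within a rolling window of horizons, so the figures below will not match any row of the table. The names `rows`, `mae`, `mape`, and `coverage` are illustrative.

```python
# First five rows of df_cv as printed above: (yhat, yhat_lower, yhat_upper, y)
rows = [
    (108479.087306, 106990.603148, 109837.430105, 102445.0),
    (102996.111502, 101607.192966, 104416.448231, 102416.0),
    (118973.317944, 117460.910009, 120468.316343, 119960.0),
    (120612.923539, 119182.220594, 122067.307008, 121513.0),
    (127883.031663, 126442.196287, 129351.375440, 128884.0),
]

n = len(rows)

# Mean absolute error: average size of the miss, in the units of y
mae = sum(abs(yhat - y) for yhat, _, _, y in rows) / n

# Mean absolute percentage error: the miss relative to the actual value
mape = sum(abs(yhat - y) / abs(y) for yhat, _, _, y in rows) / n

# Coverage: fraction of actual values that fall inside the interval
coverage = sum(lower <= y <= upper for _, lower, upper, y in rows) / n

print(f"MAE: {mae:.1f}  MAPE: {mape:.4f}  coverage: {coverage:.1f}")
```

In this small sample only the first row falls outside its interval, so four of the five actual values are covered.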
Plot Cross Validation Metrics
# plot rmse
plot_cross_validation_metric(df_cv, metric='rmse');
# plot mape
plot_cross_validation_metric(df_cv, metric='mape');