Technical Description

Here are some technical details about our Caltech CS156 model, including what the component models do, the data we use, input preprocessing, our aggregation method, and references. Academic papers about the project as a whole and about the different innovations in it are in preparation.


The most successful component models in our aggregate are neural networks, five in all. We also have other data-driven models as well as epidemiological models in the mix.

All the models were trained to forecast a full distribution rather than just a point prediction. They were rewarded for both accuracy and precision using the pinball loss function. This gave the aggregation algorithm added information about the performance of individual models. For instance, the model that is most uncertain about its prediction is rejected even if it gets the prediction right on average.
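
To make the pinball loss concrete, here is a minimal Python sketch. The function names and the two example forecasts are hypothetical illustrations, not taken from the project's code. It shows why a forecast with a wide quantile spread scores worse than a sharp one even when both get the median exactly right.

```python
def pinball_loss(y_true, y_pred, q):
    """Pinball (quantile) loss of one predicted quantile for one observation."""
    diff = y_true - y_pred
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    return q * diff if diff >= 0 else (q - 1) * diff


def mean_pinball_loss(y_true, quantile_preds):
    """Average pinball loss over a dict {quantile level: predicted value}."""
    losses = [pinball_loss(y_true, pred, q) for q, pred in quantile_preds.items()]
    return sum(losses) / len(losses)


# Two hypothetical forecasts for a day on which 10 deaths were observed.
# Both have the correct median, but the second is far less certain.
sharp = {0.1: 8.0, 0.5: 10.0, 0.9: 12.0}
vague = {0.1: 0.0, 0.5: 10.0, 0.9: 40.0}

sharp_loss = mean_pinball_loss(10.0, sharp)
vague_loss = mean_pinball_loss(10.0, vague)
```

Here `vague_loss` is larger than `sharp_loss`, which is the kind of signal an aggregation step can use to prefer the sharper model.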

Here are the component models that made it to the aggregate:

  1. Feed-forward Neural Network: The architecture consists of a single feed-forward neural network trained on mean-squared error of county-level daily deaths. It has a shallow architecture with ReLU activation and dropout. Model input consists of stationary and lagged dynamic features. A negative binomial distribution is assumed on top of the mean predictions to generate a quantile distribution.
  2. Quantile Neural Network: The architecture consists of nine independent feed-forward neural networks, each optimizing for county-level daily death quantile loss of one of nine quantiles. Each network has a shallow architecture with ReLU activation and dropout. Model input consists of stationary and lagged dynamic features. Output is concatenated to form a quantile distribution.
  3. LSTM (Long Short-Term Memory): This is a specialized neural-network architecture, and its learning objective is the difference between prediction for various quantiles. Demographic, economic, and healthcare features are used to initialize the LSTM layer, and other time-variant features, including mortality and mobility, are passed through as sequences.
  4. Conditional LSTM: This model uses static features such as demographic data to initialize the hidden states of the LSTM. For time-series data, it uses cases, deaths, daily increase in the total test results, and flu seasonality for the last 21 days to make predictions.
  5. Encoder-decoder conditional LSTM: This model takes 28 days of deaths, cases, and mobility data as input, along with fixed features, including population density and overall population of the relevant county. Using this architecture and input features, the LSTM can be trained to predict deaths at various time horizons, including 4 weeks.
  6. Autoregressive model: It takes as inputs smoothed NYT cases/deaths data. Predictions are a blend of county and country-level OLS models, using the previous 2 weeks to validate hyperparameters (e.g., window sizes per feature, growth bounds on model increases) and KMeans to apply error residuals sampled from validation.
  7. Seasonal Autoregressive model: The model uses a combination of ARMA and SARMA whose main inputs are previous death numbers, with adjustments based on other factors such as demographics and mobility. Integrated versions of both (ARIMA and SARIMA) were employed to deal with non-stationarity. The seasonal model SARMA received 70% of the weight in the final model.
  8. Decision Tree model: The model fits two decision-tree regressors to a county's daily features to predict the next day's cases and deaths for that county. The features used include the previous deaths and cases, two-week lagged mobility features, and population data. The model’s predictions are aggregated with the average of the existing data to reduce variance over time.
  9. Gradient-Boosted Decision Trees: The model is a combination of gradient-boosted tree regressors, all trained directly on quantile loss. For each county, five separate GBDT models are trained, and at forecast time the average of those five predictions is taken as the prediction for a given day.
  10. K-nearest-neighbors model: For every county, the algorithm compares the county's case trend over the last 28 days to historical case trends in all other counties since the start of the pandemic and finds the snapshots that most closely resemble today. It then predicts future deaths by bringing forward the death trends following those historical snapshots and applies dynamic time warping to make the timelines align as well as possible.
  11. Gaussian Process model: This model trains a separate Gaussian process regressor for each county using lagged cases and lagged case derivatives as reported by the NYT. The kernel and some hyperparameters were optimized by minimizing pinball loss on a 2-week validation set.
  12. Bayesian epidemiological model: Inspired by the Imperial College model, it fits to the per-day, per-country death data for the past 30 days, under the assumption that the true number of infections is proportional to the reported cases, and that the percentage of infections attributable to the elderly is the same across the country as it is in California.
  13. Two-group epidemiological model: With young and old people modelled as two separate but interacting groups (see 'Young and Old' model below), the interaction rates are calculated from mobility under the assumption that mobility is a fraction of pre-pandemic mobility, but that movement routes have not changed. We fit to deaths per county and the California age demographics for reported cases, assuming the latter is representative of the country.
  14. Curve-fitting model: It assumes that the number of deaths is proportional to the number of cases 25 days prior, and fits to the death data for the past 30 days. The case fatality rate for each county is a free parameter, which we fit for in a Bayesian framework.
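
The core relationship in the curve-fitting model (item 14) can be sketched in a few lines of Python. This is a simplified illustration under stated assumptions: it replaces the Bayesian fit described above with an ordinary least-squares point estimate, and all function and variable names are hypothetical.

```python
LAG_DAYS = 25    # delay between reported cases and deaths (from the description)
FIT_WINDOW = 30  # number of most recent days of death data used for the fit


def fit_case_fatality_rate(cases, deaths, lag=LAG_DAYS, window=FIT_WINDOW):
    """Least-squares estimate of r in deaths[t] ~ r * cases[t - lag].

    Assumption: a point estimate through the origin stands in for the
    Bayesian fit used by the actual model.
    """
    if len(cases) != len(deaths):
        raise ValueError("cases and deaths must cover the same days")
    n = len(deaths)
    start = max(lag, n - window)  # only days that have a lagged case count
    pairs = [(cases[t - lag], deaths[t]) for t in range(start, n)]
    denom = sum(c * c for c, _ in pairs)
    if denom == 0:
        return 0.0
    return sum(c * d for c, d in pairs) / denom


def forecast_deaths(cases, rate, horizon, lag=LAG_DAYS):
    """Daily deaths for the next `horizon` days implied by observed cases."""
    if not 0 < horizon <= lag:
        raise ValueError("horizon must be between 1 and lag days")
    n = len(cases)
    if n < lag:
        raise ValueError("need at least `lag` days of case data")
    return [rate * cases[n + h - lag] for h in range(horizon)]
```

Because deaths lag cases by 25 days in this model, forecasts up to 25 days ahead depend only on cases that have already been reported.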

Young and Old: We created an analytic model for the pandemic that treats the young (under 65) and the old (over 65) as two distinct groups as far as the spread of the virus is concerned. This was in reaction to the different social behavior of the two groups during the pandemic, as well as their different IFR (Infection Fatality Ratio). The model recursively predicts the number of newly infected within each group every day. It puts together established results about infectiousness, hospitalization, and mortality, as functions of time, to infer the newly-infected numbers based on available case and mortality data within the two groups. Instead of one reproduction number R, we have a 2x2 matrix describing within-group infections and between-group infections. Some simplifying assumptions were made, but the model had out-of-sample predictive power and created meaningful estimates of crossover infections where sufficient data were available: at the national and state level, and in some populous counties. The results are depicted in the Infection Dynamics graphic on the home page.
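
The role of the 2x2 matrix can be illustrated with a small Python sketch. It is a deliberately simplified, generation-based picture rather than the day-by-day recursion described above. The function name, the matrix convention (entry `[i][j]` is the average number of new infections in group `i` caused by one infected person in group `j`), and the numbers are all hypothetical.

```python
def next_generation(infected, R):
    """Apply a 2x2 reproduction matrix to (young, old) infection counts.

    Assumed convention: R[i][j] is the average number of new infections
    in group i caused by one infected person in group j (0 = young, 1 = old).
    """
    young, old = infected
    new_young = R[0][0] * young + R[0][1] * old
    new_old = R[1][0] * young + R[1][1] * old
    return new_young, new_old


# Hypothetical values chosen only to show the mechanics.
R = [[1.2, 0.1],
     [0.3, 0.5]]
infected = (100.0, 20.0)

new_young, new_old = next_generation(infected, R)
# Crossover infections: old people infected by young people.
crossover_young_to_old = R[1][0] * infected[0]
```

The off-diagonal entries are what a single reproduction number cannot express: they separate infections that cross between the two groups from those that stay within a group.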


Our team explored numerous data sets that are directly or indirectly related to the pandemic trajectory, and also some data sets that were conjectured to possibly have an effect. Some data were useful and others were not, as is often the case in a data-driven approach. The following data sets are the ones that ended up being part of the inputs for at least one component model.

COVID Epidemiology: We draw our numbers of total confirmed COVID cases and deaths in US counties from the New York Times. This serves both as an input and as a reference for judging model performance. We also draw international COVID data from Johns Hopkins. Some states also provide data on age breakdown of confirmed cases and deaths (e.g. California).

COVID Testing and Hospitalizations: The COVID-19 Tracking Project provides state-level data on testing volumes and COVID hospitalization numbers.

COVID-Related Policies: We draw state-level COVID policy actions from the Kaiser Family Foundation.

County Demographics: The US Census Bureau American Community Survey provides a broad array of demographic data by county, including population, population density, age distribution, gender ratio, number of inter-generational households, and education and income levels.

Mobility: We draw aggregated mobility data from Google and Descartes Labs at the county level.

County Geolocations: We use spatial data on county center longitude and latitude to inform geographic clustering.

Hospital Capacity: We use the number of ICU beds per county from Kaiser Health News to estimate hospital capacity for COVID cases.


In addition to standard input preprocessing techniques in Machine Learning that were applied for individual component models, the following preprocessing methods were applied to the raw input data and utilized throughout the project.

Outliers: A Hampel filter is used to detect and remove outliers in each time series for model training. The filter consists of a sliding window of 21 days. If a point lies outside 6 standard deviations of the sliding window median, the point is marked as an outlier and subsequently replaced by the median.
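
A minimal Python sketch of such a Hampel filter follows. Two details are assumptions rather than facts from the description: the standard deviation is estimated robustly from the median absolute deviation (MAD) with the conventional 1.4826 scale factor, and windows are truncated at the ends of the series. The function name is hypothetical.

```python
from statistics import median


def hampel_filter(series, window=21, n_sigmas=6.0):
    """Replace outliers with the median of a sliding window.

    Returns (cleaned series, list of outlier indices).
    Assumptions: sigma is estimated as 1.4826 * MAD, and the window is
    truncated near the start and end of the series.
    """
    half = window // 2
    cleaned = list(series)
    outliers = []
    for i, value in enumerate(series):
        lo = max(0, i - half)
        hi = min(len(series), i + half + 1)
        neighborhood = series[lo:hi]
        med = median(neighborhood)
        mad = median([abs(x - med) for x in neighborhood])
        sigma = 1.4826 * mad  # robust estimate of the standard deviation
        if abs(value - med) > n_sigmas * sigma:
            cleaned[i] = med
            outliers.append(i)
    return cleaned, outliers


# Hypothetical daily counts with one reporting spike on day 20.
counts = [10.0 + (day % 3) for day in range(40)]
counts[20] = 500.0
cleaned, outliers = hampel_filter(counts)
```

Neighborhoods are always taken from the original series, so a replaced point does not influence the test applied to its neighbors.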

Smoothing: Time series can be heavily smoothed to separate out the general trend from the weekly seasonality and stochastic noise. A Savitzky-Golay filter of sliding window length 15 and polyorder 1 implements smoothing through successive fitting of adjacent data points with low degree polynomials.
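
The following Python sketch shows what a Savitzky-Golay filter of polynomial order 1 does: it fits a straight line to each window by least squares and evaluates it at the point being smoothed. The function names are hypothetical, and the edge handling (reusing the first or last full window's line) is an assumption about how the ends of the series are treated.

```python
def fit_line(ys):
    """Least-squares line through (0, ys[0]), (1, ys[1]), ...

    Returns (slope, intercept).
    """
    n = len(ys)
    x_mean = (n - 1) / 2
    y_mean = sum(ys) / n
    sxx = sum((x - x_mean) ** 2 for x in range(n))
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in enumerate(ys))
    slope = sxy / sxx
    return slope, y_mean - slope * x_mean


def savgol_linear(series, window=15):
    """Savitzky-Golay smoothing with polynomial order 1.

    Assumption: near the ends, the line fitted to the first (or last)
    full window is evaluated at the edge positions.
    """
    n = len(series)
    if window % 2 == 0 or window < 3 or window > n:
        raise ValueError("window must be odd, at least 3, and fit the series")
    half = window // 2
    smoothed = []
    for i in range(n):
        # Clamp the window so it always holds `window` points of the series.
        start = min(max(i - half, 0), n - window)
        slope, intercept = fit_line(series[start:start + window])
        smoothed.append(intercept + slope * (i - start))
    return smoothed
```

For interior points, a first-order fit evaluated at the window center reduces to a plain centered moving average, so a 15-day window averages out the weekly reporting pattern.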


Different models have different strengths and weaknesses, and they work best in different locations and at different times. Our aggregation method capitalizes on these differences by extracting the best parts of each model and putting these parts together.

For instance, we take the models that are best at predicting the immediate future and create a reliable prediction for one week, then use that prediction to anchor other models that are not as good at short-term prediction. We take the models that are best in predicting the general trend of the pandemic, even if they get the absolute numbers wrong, and use them to create a "trend correction" function that anchors other models that may be better at getting the ballpark numbers right, but not as good at predicting the trend. The same principle is applied for weekly patterns and for predictions in sparsely populated counties.
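
One simple way to picture the anchoring idea is the Python sketch below. It is a hypothetical illustration only: the description does not specify how anchoring is computed, and the rescaling rule and function name here are assumptions.

```python
def anchor_forecast(long_range, short_range):
    """Rescale a long-range forecast to agree with a short-range one.

    Hypothetical rule: multiply the long-range forecast by a constant so
    that its total over the short-range horizon matches the short-range
    total. The shape (trend) of the long-range forecast is preserved.
    """
    k = len(short_range)
    base = sum(long_range[:k])
    if base == 0:
        return list(long_range)  # nothing to rescale against
    scale = sum(short_range) / base
    return [scale * x for x in long_range]


# Hypothetical daily-death forecasts for one county.
four_week_model = [10.0] * 28  # good trend, possibly wrong level
one_week_model = [12.0] * 7    # trusted for the first week
anchored = anchor_forecast(four_week_model, one_week_model)
```

After anchoring, the first week of the long-range forecast sums to the same total as the trusted short-range forecast, while later days keep the long-range model's relative trend.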

It should be noted that all of these aggregation stages are automated, data-driven, and adaptive. The aggregation algorithm looks at the track record of the models over the past 6-12 weeks and automatically decides what the best combination is for each stage.


  • Bathwal, R., Chitta, P., Tirumala, K., & Varadarajan, V. (2020). Ensemble Machine Learning Methods for Modeling COVID19 Deaths. arXiv preprint arXiv:2010.04052.
  • Rahil Bathwal, Predicting the Spread of COVID-19 Using Artificial Intelligence, SURF presentation, October 2020.
  • Nicholas Chang, Using Machine Learning to predict the spread of COVID-19, SURF presentation, October 2020.
  • Jethin Gowda, Predicting the Spread of COVID-19 using Artificial Intelligence, SURF presentation, October 2020.
  • Jin, Q. (2020). Time Warping Clustering for the Forecast and Analysis of COVID-19. Accepted for MIT IEEE Undergraduate Research Technology Conference, preprint.
  • Samuel Lushtak, Developing and Publishing Statistical Models of the COVID-19 Pandemic, SURF presentation, October 2020.
  • Mann, E., Chevalier, Q., Magda, G., & Gonzalgo, M. (2020, September 24). Forecasting The Spread of COVID-19 with Enhanced Linear Autoregression (Version 1.0.0), Zenodo.
  • Max Popken, Predicting COVID-19 Spread, SURF presentation, October 2020.
  • Jagath Vytheeswaran, Using Artificial Intelligence to Predict the Spread of COVID-19, SURF presentation, October 2020.

External References: These are selected papers that influenced our project. A comprehensive bibliography will appear in our published papers.

  • Michael Betancourt, A conceptual introduction to Hamiltonian Monte Carlo, arXiv preprint arXiv:1701.02434, 2017.
  • Seth Flaxman et al, Estimating the number of infections and the impact of non-pharmaceutical interventions on COVID-19 in 11 European countries, Imperial College Report 13, 2020.