C#と機械学習でコロナ感染の今後を予測してみる

この記事は公開から4年以上経過しています。

AIや機械学習分野の主流はPythonですが、業務アプリ分野ではシェアの高い.NET開発環境で手軽に機械学習が利用できるMicrosoft製の機械学習ライブラリML.NETを使いC#で時系列データ予測を行う方法を紹介します。

もし、.NETでディープラーニング(ニューラルネットワーク)を実装したいという場合は「.NET(C#)開発でディープラーニングを実装する方法」が、お役に立つかもしれません。

2021.12.5追記:
厚労省の陽性者数データ形式変更に伴うプログラムソースコードの修正と併せて、全国または各都道府県毎の予測ができるように変更しました。予測値など記事本編の内容については当時のものとなります。

ここで紹介するコードはあくまで目的の機能を分かりやすく説明/実現するための最小限コードとなっています。このコードや説明を参考にされる場合は適切なエラーチェックや、充分な品質/動作検証を行って頂くようにお願い致します。
また、ここで例として取り上げた予測データはあくまで機械学習を使った予測の一例であり、実際の感染状況と必ずしも一致するものではありませんのでご注意ください。

サンプルソースコード

ソースコードは以下の2部構成です。

  • 本体プログラム(Program.cs)
  • データクラス群(DataClasses.cs)

ここで紹介するサンプルコード一式はこちら(github)からダウンロード可能ですので、よろしければご利用ください(Visual Studio Code or 2019/.NET 5/C#)。

データクラス群(DataClasses.cs)

機械学習に利用する入出力データ、CSVデータをマッピングするためのクラス定義です。

// DataClasses.cs
using System;
using CsvHelper.Configuration;

namespace Covid19InJapan
{
    // Data class for pcr_tested_daily.csv
    public class PCRTestedDaily
    {
        public DateTime Date { get; set; }

        public int? Count { get; set; }
    }

    // Data class for newly_confirmed_cases_daily.csv
    public class NewlyConfirmedCasesDaily
    {
        public DateTime Date { get; set; }

        public int[] Counts { get; set; }
    }

    // Map pcr_tested_daily.csv to PCRTestedDaily class
    public class PCRTestedDailyMap : ClassMap<PCRTestedDaily>
    {
        public PCRTestedDailyMap()
        {
            Map(m => m.Date).Index(0);
            Map(m => m.Count).Index(1);
        }
    }

    // Map newly_confirmed_cases_daily.csv to NewlyConfirmedCasesDaily class
    public class NewlyConfirmedCasesDailyMap : ClassMap<NewlyConfirmedCasesDaily>
    {
        public NewlyConfirmedCasesDailyMap()
        {
            Map(m => m.Date).Index(0);
            Map(m => m.Counts).Index(1);
        }
    }

    // Input data for training
    public class InputData
    {
        // Date
        public DateTime Date { get; set; }

        // The number of positive patients
        public float PositiveRate { get; set; }

        // For debugging
        public override string ToString()
        {
            return $"{PositiveRate}";
        }
    }

    // Results of prediction engine
    public class OutputData
    {
        // The number of positive persons in the prediction period.
        public float[] ForecastedPositiveRates { get; set; }

        // The minumum number of positive persons in the prediction period.
        public float[] LowerBoundPositives { get; set; }

        // The maximum number of positive persons in the prediction period.
        public float[] UpperBoundPositives { get; set; }
    }
}

本体プログラム(Program.cs)

プログラムのメイン処理です。

処理の流れは、

  1. 最新の感染状況データ(日本全国の合計)を、厚労省のWEBサイトからダウンロード
  2. 陽性者数とPCR検査実施件数それぞれの7日移動平均を使い陽性率を算出
  3. 算出した陽性率でモデルのトレーニングを行い、時系列データ予測エンジンで90日先まで予測
  4. 現在の陽性率/予測結果の陽性率グラフをWEBブラウザに表示(HTMLファイル)

となります。

CSV読み込み、数値計算、グラフ表示などを自前で実装すると非効率なので、サードパーティー製のNuGetパッケージを利用しています。

// Program.cs
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using CsvHelper;
using MathNet.Numerics.Statistics;
using Microsoft.ML;
using Microsoft.ML.Transforms.TimeSeries;
using XPlot.Plotly;

namespace Covid19InJapan
{
    internal static class Program
    {
        // Filenames of the CSV
        private const string NEWLY_CONFIRMED_CASES_DAILY_FILENAME = "newly_confirmed_cases_daily.csv";

        private const string PCR_TESTED_DAILY_FILENAME = "pcr_tested_daily.csv";

        // URLs to download the CSV
        private const string NEWLY_CONFIRMED_CASES_DAILY_URL = "https://covid19.mhlw.go.jp/public/opendata/newly_confirmed_cases_daily.csv";

        private const string PCR_TESTED_DAILY_URL = "https://www.mhlw.go.jp/content/pcr_tested_daily.csv";

        // Window size for moving average
        private const int MA = 7;

        // Index number of the prefecture(0 is whole country)
        private const int PREF_INDEX = 0;

        // Prefecture names
        private static readonly string[] PREF_NAMES =
@"Japan,
Hokkaido,
Aomori,Iwate,Miyagi,Akita,Yamagata,Fukushima,
Ibaraki,Tochigi,Gunma,Saitama,Chiba,Tokyo,Kanagawa,
Niigata,Toyama,Ishikawa,Fukui,Yamanashi,Nagano,Gifu,Shizuoka,Aichi,
Mie,Shiga,Kyoto,Osaka,Hyogo,Nara,Wakayama,
Tottori,Shimane,Okayama,Hiroshima,Yamaguchi,
Tokushima,Kagawa,Ehime,Kochi,
Fukuoka,Saga,Nagasaki,Kumamoto,Oita,Miyazaki,Kagoshima,Okinawa".Split(',');

        // Main
        private static void Main(string[] _)
        {
            // Download the number of daily PCR positive data and the number of daily PCR tested data from the ministry of health, Japan, if that file does not exist
            using (var wc = new WebClient())
            {
                // Set UA(Bypass 403 Forbidden)
                wc.Headers["User-Agent"] = "Mozilla/5.0 (Windows NT 10.0; Win64; x64)";

                if (!File.Exists(NEWLY_CONFIRMED_CASES_DAILY_FILENAME))
                    wc.DownloadFile(NEWLY_CONFIRMED_CASES_DAILY_URL, NEWLY_CONFIRMED_CASES_DAILY_FILENAME);
                if (!File.Exists(PCR_TESTED_DAILY_FILENAME))
                    wc.DownloadFile(PCR_TESTED_DAILY_URL, PCR_TESTED_DAILY_FILENAME);
            }

            // Load the both CSV files
            IEnumerable<NewlyConfirmedCasesDaily> positiveDaily = Enumerable.Empty<NewlyConfirmedCasesDaily>();
            using (var reader = new StreamReader(NEWLY_CONFIRMED_CASES_DAILY_FILENAME, true))
            using (var csv = new CsvReader(reader, CultureInfo.InvariantCulture))
            {
                csv.Configuration.RegisterClassMap<NewlyConfirmedCasesDailyMap>();
                positiveDaily = csv.GetRecords<NewlyConfirmedCasesDaily>().OrderBy((d) => d.Date).ToArray();
            }
            IEnumerable<PCRTestedDaily> testedDaily = Enumerable.Empty<PCRTestedDaily>();
            using (var reader = new StreamReader(PCR_TESTED_DAILY_FILENAME, true))
            using (var csv = new CsvReader(reader, CultureInfo.InvariantCulture))
            {
                csv.Configuration.RegisterClassMap<PCRTestedDailyMap>();
                testedDaily = csv.GetRecords<PCRTestedDaily>().Where(s => s.Count.HasValue).ToArray();
            }

            // Unify the number of daily PCR positive data with the PCR tested data for the time series
            var rawData = positiveDaily.Join(testedDaily, (positive) => positive.Date, (tested) => tested.Date, (positive, tested) =>
            {
                return new { Date = positive.Date, Tested = (double)tested.Count, Positive = (double)positive.Counts[PREF_INDEX] };
            });//.Where(s => s.Date >= new DateTime(2021, 2, 8));

            // Calculate each of the positive rates and convert to an array of the InputData class
            var enmDates = rawData.Select((o) => o.Date).GetEnumerator();
            var enmMATested = rawData.Select((o) => o.Tested).MovingAverage(MA).GetEnumerator();
            var enmMAPositive = rawData.Select((o) => o.Positive).MovingAverage(MA).GetEnumerator();
            IEnumerable<InputData> GetInputData()
            {
                while (enmDates.MoveNext() && enmMATested.MoveNext() && enmMAPositive.MoveNext())
                {
                    var rate = 0f;
                    if (enmMATested.Current > 0)
                        rate = (float)(enmMAPositive.Current / enmMATested.Current) * 100;
                    yield return new InputData { Date = enmDates.Current, PositiveRate = rate };
                }
            }
            var inputData = GetInputData().ToArray();

            // Create the new ML Context
            var mlContext = new MLContext(0);

            // Create a new IDataView from an array of the InputData class
            var data = mlContext.Data.LoadFromEnumerable(inputData);

            // Create a SSA model for forecasting
            var model = mlContext.Forecasting.ForecastBySsa(
                outputColumnName: nameof(OutputData.ForecastedPositiveRates),
                inputColumnName: nameof(InputData.PositiveRate),
                windowSize: 14,
                seriesLength: 30,
                trainSize: inputData.Length,
                horizon: 90,
                confidenceLevel: 0.95f,
                confidenceLowerBoundColumn: nameof(OutputData.LowerBoundPositives),
                confidenceUpperBoundColumn: nameof(OutputData.UpperBoundPositives));

            // To train
            var transformer = model.Fit(data);

            // Create a prediction engine
            var forecastingEngine = transformer.CreateTimeSeriesEngine<InputData, OutputData>(mlContext);

            // Predict the number of positive patients in the next period
            var outputData = forecastingEngine.Predict();

            // Convert the result into chart data
            var actualData = inputData.Select((o) =>
            {
                return new { Date = o.Date, Rate = o.PositiveRate };
            });
            var predictiveData = outputData.ForecastedPositiveRates.Select((o, i) =>
            {
                return new { Date = positiveDaily.Last().Date.AddDays(i + 1), Rate = ReLU(o) };
            });
            var actualGraph = new Graph.Scattergl()
            {
                name = "Actuality",
                x = actualData.Select((o) => o.Date),
                y = actualData.Select((o) => o.Rate),
                mode = "lines+markers"
            };
            var predictiveGraph = new Graph.Scattergl()
            {
                name = "Prediction",
                x = predictiveData.Select((o) => o.Date),
                y = predictiveData.Select((o) => o.Rate),
                mode = "lines+markers"
            };

            // Show the result in a graph
            var chart = Chart.Plot(new[] { actualGraph, predictiveGraph });
            chart.WithTitle($"COVID-19 positive rate prediction in {PREF_NAMES[PREF_INDEX]}. ({MA}-day moving average)");
            chart.WithXTitle("Date");
            chart.WithYTitle("Rate(%)");
            chart.WithSize(800, 800);
            chart.Show();
        }

        // ReLU function
        private static float ReLU(float i) => i < 0 ? 0 : i;
    }
}

実行結果

プログラムが正常に完了すると、WEBブラウザで以下のようなグラフを表示します。
途中でデータのダウンロードに失敗した場合などは後のCSV処理でエラーになるため、カレントディレクトリ内にあるCSVファイルを削除してから実行してください。

グラフ描画はPlotlyを利用していますが、コンソールアプリケーションでもデータを即座に視覚化できるため、データサイエンスにもオススメです。

予測結果

CSVデータの先頭(2020/2/5)以降の推移を表示しています。赤色のグラフ部分が今回の機械学習による予測結果です。
これを見る限りでは、この記事を書いた2020/9/17時点では緩やかな現象傾向にあるように見えます。

file

例えば以下のように2020/9/1以降を拡大表示してみると、いつどれくらいの値なのかも確認できます。

file

2020.12.16更新:↓

<2020.12.16時点の前回と同条件の予測結果>

file

<緊急事態宣言解除2020.5.25〜のデータによる予測結果>

file

2020.12.31更新:↓

<2020.12.31時点の前回と同条件の予測結果>

file

前回更新と同様に未だ下降を予想していますが、これは学習用のデータが検査数/感染者数のみで緊急事態宣言の解除やGoToXXX運用の開始、気温低下といった感染拡大要因となる情報がないことで時系列の相関関係が弱くなっていることに起因しています。

<緊急事態宣言解除2020.5.25〜のデータによる予測結果>

file

こちらでは、マルサスモデル的な増加(オーバーシュート)を示しているようにも見えます。

<緊急事態宣言解除2020.5.25〜のデータによるコロナ日毎死者数予測結果>

file

上のグラフは、ここで紹介したプログラムをベースとしてコロナによる日毎死者数の7日移動平均を入力データとして予測した結果です。
気温のピークを迎えてない7月第4週辺りから既に増加に転じているのが分かります。

2021.1.15更新:↓

以下、本日時点の前回と同条件の予測結果です。

<前回緊急事態宣言解除2020.5.25〜のデータによる予測結果>

file

皆さんのご想像通りといったところでしょうか。

<前回緊急事態宣言解除2020.5.25〜のデータによるコロナ日毎死者数予測結果>

file

こちらは目を疑いますが、医療崩壊を考慮していない単純な推移による予測だけでこの結果ですと…もう言葉も出ません。


たった数十行程度のプログラムですが、.NETで機械学習が容易に実装できるML.NETには、とても大きな可能性が感じられました。

今回利用した時系列予測モデルForecastBySsaは設定パラメータが少なく数ある機械学習アルゴリズムの中でも入門に最適だと思いますので、興味がある方は是非試してみてください。

機械学習による予測結果ではコロナ陽性率が減少傾向にあるように見えていますが、正しい傾向を判断するためには上記サンプルプログラム以外のさまざまな条件や要因を考慮しての予測が必要です。これまでと同様以上に、しっかりと感染対策を続けていきましょう。

参考ウェブサイトなど

以上です。

シェアする

  • このエントリーをはてなブックマークに追加

フォローする