C#でUMAPによる手書き文字(MNIST)のデータ次元削減を試してみる

この記事は公開から4年以上経過しています。

機械学習でお馴染みの次元削減サンプルといえばPythonとt-SNEの組み合わせですが、今回は.NET CoreのC#開発環境でUMAPを使ってNMINSTデータの次元削減を行なう方法を紹介します。

ここで紹介するコードはあくまで目的の機能を分かりやすく説明/実現するための最小限コードとなっています。このコードや説明を参考にされる場合は適切なエラーチェックや、充分な品質/動作検証を行って頂くようにお願い致します。

サンプルソースコード

ソースコードは以下の1本のみです。

  • 本体プログラム(Program.cs)

ここで紹介するサンプルコード一式はこちら(github)からダウンロード可能ですので、よろしければご利用ください(VisualStudioCode/.NETCore3.1/C#)。

本体プログラム(Program.cs)

プログラムのメイン処理です。

処理の流れは、

  1. gz形式のMNISTデータをインターネットからダウンロード
  2. gzファイルを展開して、ラベルとイメージデータを抽出
  3. UMAPで高次元データを低次元に圧縮
  4. 結果を散布図としてWEBブラウザに表示(HTMLファイル)

となります。

MNISTデータはヤン・ルカン氏のWEBサイトにて公開されているものを初回実行時(ファイルが未存在のときのみ)ダウンロードして利用するようにしています。
各データフォーマット仕様はデータダウンロードページ上に掲載されていますが、留意すべき点はデータがビッグエンディアンであること位でしょうか。

ちなみに、今回のグラフ描画にも過去のエントリと同様にXPlot.Plotlyを利用しています。

// Program.cs
using System.IO.Compression;
using System.Net;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using UMAP;
using XPlot.Plotly;

namespace UMAPExampleInCSharp
{
    class Program
    {
        // MNISTデータのダウンロード元URL
        const string LABEL_FILE_URL = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz";
        const string IMAGE_FILE_URL = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz";

        // MNISTデータファイル名
        const string LABEL_FILE_NAME = "t10k-labels-idx1-ubyte.gz";
        const string IMAGE_FILE_NAME = "t10k-images-idx3-ubyte.gz";

        // Main
        static void Main(string[] args)
        {
            // カレントにデータが存在しない場合はヤン・ルカン氏のWEBサイトからダウンロード
            using (var wc = new WebClient())
            {
                if (!File.Exists(LABEL_FILE_NAME))
                    wc.DownloadFile(LABEL_FILE_URL, LABEL_FILE_NAME);
                if (!File.Exists(IMAGE_FILE_NAME))
                    wc.DownloadFile(IMAGE_FILE_URL, IMAGE_FILE_NAME);
            }

            // MNISTラベルデータをロード
            var labels = GetLabels().ToArray();

            // MNIST画像データをロード
            var pixels = GetImages().ToArray();

            // UMAPによる次元削減
            var umap = new Umap();
            var epochs = umap.InitializeFit(pixels);
            for (var i = 0; i < epochs; ++i)
                umap.Step();
            var embedding = umap.GetEmbedding().AsEnumerable();

            // 次元削減後のデータからグラフを生成
            var graph = new Graph.Scatter()
            {
                x = embedding.Select((o) => o[0]),
                y = embedding.Select((o) => o[1]),
                text = labels.Select((o) => o.ToString()),
                mode = "markers",
                marker = new Graph.Marker { color = labels, colorscale = "Rainbow", showscale = true },
            };

            var chart = Chart.Plot(graph);
            chart.WithTitle("MNIST Embedded via UMAP");
            chart.WithXTitle("X");
            chart.WithYTitle("Y");
            chart.WithSize(800, 800);
            chart.Show();
        }

        // GZファイルからデータ抽出
        private static byte[] Unzip(string filePath)
        {
            using (var fs = new FileStream(filePath, FileMode.Open))
            {
                using (var gzs = new GZipStream(fs, CompressionMode.Decompress))
                {
                    using (var ms = new MemoryStream())
                    {
                        gzs.CopyTo(ms);
                        return ms.ToArray();
                    }
                }
            }
        }

        // MNISTデータファイルからラベルデータを取得
        private static IEnumerable<float> GetLabels()
        {
            var binData = Unzip(LABEL_FILE_NAME);
            // Id of MNIST data(2049)
            var magicNumber = BitConverter.ToInt32(binData.Take(4).Reverse().ToArray());
            // Number of images(10000)
            var numOfItems = BitConverter.ToInt32(binData.Skip(4).Take(4).Reverse().ToArray());

            return binData.Skip(8).Select((o) => (float)o);
        }

        // MNISTデータファイルから画像データを取得
        private static IEnumerable<float[]> GetImages()
        {
            var binData = Unzip(IMAGE_FILE_NAME);
            // Magic number(2051)
            var magicNumber = BitConverter.ToInt32(binData.Take(4).Reverse().ToArray());
            // Number of images(10000)
            var numOfItems = BitConverter.ToInt32(binData.Skip(4).Take(4).Reverse().ToArray());
            // Number of rows(28)
            var numOfRows = BitConverter.ToInt32(binData.Skip(8).Take(4).Reverse().ToArray());
            // Number of columns(28)
            var numOfColumns = BitConverter.ToInt32(binData.Skip(12).Take(4).Reverse().ToArray());

            var length = numOfRows * numOfColumns;

            for (int i = 16, max = binData.Count(); i < max; i += length)
            {
                // 1画像データ単位に分割
                var chunk = binData.Skip(i).Select((o) => (float)o).Take(length).ToArray();
                yield return chunk;
            }
        }
    }
}

実行結果

10000件と比較的データ数が多いため、環境によっては多少時間が掛かるかもしれませんが、プログラムが正常に完了するとWEBブラウザで以下のような散布図を表示します。

どうでしょうか。手書きの0〜9を784次元で表した画像イメージデータが2次元まで削減されても、おおむね10分類にクラス分けされているのが視覚的に判断できると思います。

データのダウンロードが失敗すると後続処理がエラーになるため、カレントディレクトリ内にあるデータファイルを削除してから実行してください。

file

参考ウェブサイトなど

以上です。

シェアする

  • このエントリーをはてなブックマークに追加

フォローする

コメント

  1. ゴン太 より:

    pythonとc#での、実行速度の違いが気になります。