当前位置: 移动技术网 > IT编程>开发语言>.net > 机器学习框架ML.NET学习笔记【5】手写数字识别(续)

机器学习框架ML.NET学习笔记【5】手写数字识别(续)

2019年05月31日  | 移动技术网IT编程  | 我要评论

精品技术论坛,郄英才是谁的秘书,多方通话

一、概述

 上一篇文章我们利用ml.net的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道,输入为文件名,输出为float数字,里面保存的是像素信息。

 样本包括6万张训练图片和1万张测试图片,图片为灰度图片,分辨率为20*20 。train_tags.tsv文件对每个图片的数值进行了标记,如下:

  

二、源码

 全部代码: 

namespace multiclassclassification_mnist
{
    class program
    {
        //assets files download from:https://gitee.com/seabluescn/ml_assets
        static readonly string assetsfolder = @"d:\stepbystep\blogs\ml_assets\mnist";
        static readonly string traintagspath = path.combine(assetsfolder, "train_tags.tsv");
        static readonly string traindatafolder = path.combine(assetsfolder, "train");
        static readonly string modelpath = path.combine(environment.currentdirectory, "data", "sdca-model.zip");

        static void main(string[] args)
        {
            mlcontext mlcontext = new mlcontext(seed: 1);
          
            trainandsavemodel(mlcontext);
            testsomepredictions(mlcontext);

            console.writeline("hit any key to finish the app");
            console.readkey();
        }

        public static void trainandsavemodel(mlcontext mlcontext)
        {
            // step 1: 准备数据
            var fulldata = mlcontext.data.loadfromtextfile<inputdata>(path: traintagspath, separatorchar: '\t', hasheader: false);
            var traintestdata = mlcontext.data.traintestsplit(fulldata, testfraction: 0.1);
            var traindata = traintestdata.trainset;
            var testdata = traintestdata.testset;

            // step 2: 配置数据处理管道        
            var dataprocesspipeline = mlcontext.transforms.custommapping(new loadimageconversion().getmapping(), contractname: "loadimageconversionaction")
               .append(mlcontext.transforms.conversion.mapvaluetokey("label", "number", keyordinality: valuetokeymappingestimator.keyordinality.byvalue))
               .append(mlcontext.transforms.normalizemeanvariance( outputcolumnname: "featuresnormalizedbymeanvar", inputcolumnname: "imagepixels"));


            // step 3: 配置训练算法 (using a maximum entropy classification model trained with the l-bfgs method)
            var trainer = mlcontext.multiclassclassification.trainers.lbfgsmaximumentropy(labelcolumnname: "label", featurecolumnname: "featuresnormalizedbymeanvar");
            var trainingpipeline = dataprocesspipeline.append(trainer)
                 .append(mlcontext.transforms.conversion.mapkeytovalue("predictnumber", "label"));


            // step 4: 训练模型使其与数据集拟合           
            itransformer trainedmodel = trainingpipeline.fit(traindata);          

            // step 5:评估模型的准确性           
            var predictions = trainedmodel.transform(testdata);
            var metrics = mlcontext.multiclassclassification.evaluate(data: predictions, labelcolumnname: "label", scorecolumnname: "score");
            printmulticlassclassificationmetrics(trainer.tostring(), metrics);
          
            // step 6:保存模型            
            mlcontext.model.save(trainedmodel, traindata.schema, modelpath);           
        }

        private static void testsomepredictions(mlcontext mlcontext)
        {
            // load model           
            itransformer trainedmodel = mlcontext.model.load(modelpath, out var modelinputschema);

            // create prediction engine 
            var predengine = mlcontext.model.createpredictionengine<inputdata, outputdata>(trainedmodel);
          
            directoryinfo testfolder = new directoryinfo(path.combine(assetsfolder, "test"));           
            foreach(var image in testfolder.getfiles())
            {
                count++;

                inputdata img = new inputdata()
                {
                    filename = image.name
                };
                var result = predengine.predict(img);
               
                console.writeline($"current source={img.filename},predictresult={result.getpredictresult()}");                
            }
        }       
    }

    class inputdata
    {
        [loadcolumn(0)]
        public string filename;

        [loadcolumn(1)]
        public string number;

        [loadcolumn(1)]
        public float serial;       
    }

    class outputdata : inputdata
    {
        public float[] score;
        public int getpredictresult()
        {
            float max = 0;
            int index = 0;
            for (int i = 0; i < score.length; i++)
            {
                if (score[i] > max)
                {
                    max = score[i];
                    index = i;
                }
            }
            return index;
        }       
    }   
}

  

三、分析

 整个处理流程和上一篇文章基本一致,这里解释两个不一样的地方。

1、自定义的图片读取处理通道

namespace multiclassclassification_mnist
{
    public class loadimageconversioninput
    {
        public string  filename { get; set; }
    }
 
    public class loadimageconversionoutput
    {
        [vectortype(400)]
        public float[] imagepixels { get; set; }
        public string imagepath;
    }

    [custommappingfactoryattribute("loadimageconversionaction")]
    public class loadimageconversion : custommappingfactory<loadimageconversioninput, loadimageconversionoutput>
    {       
        static readonly string traindatafolder = @"d:\stepbystep\blogs\ml_assets\mnist\train";

        public void customaction(loadimageconversioninput input, loadimageconversionoutput output)
        {  
            string imagepath = path.combine(traindatafolder, input.filename);
            output.imagepath = imagepath;

            bitmap bmp = image.fromfile(imagepath) as bitmap;           

            output.imagepixels = new float[400];
            for (int x = 0; x < 20; x++)
                for (int y = 0; y < 20; y++)
                {
                    var pixel = bmp.getpixel(x, y);
                    var gray = (pixel.r + pixel.g + pixel.b) / 3 / 16;
                    output.imagepixels[x + y * 20] = gray;
                }           
            bmp.dispose();                     
        }

        public override action<loadimageconversioninput, loadimageconversionoutput> getmapping()
              => customaction;
    }
}

 这里可以看出,我们自定义的数据处理通道,输入为文件名称,输出是一个float数组,这里数组必须要指定宽度,由于图片分辨率为20*20,所以数组宽度指定为400,输出imagepath为文件详细地址,用来调试使用,没有实际用途。处理思路非常简单,遍历每个pixel,计算其灰度值,为了减少工作量我们把灰度值进行缩小,除以了16 ,由于后面数据会做归一化,所以这里影响不是太明显。

 

2、模型测试

            directoryinfo testfolder = new directoryinfo(path.combine(assetsfolder, "test"));
            int count = 0;
            int success = 0;
            foreach(var image in testfolder.getfiles())
            {
                count++;

                inputdata img = new inputdata()
                {
                    filename = image.name
                };
                var result = predengine.predict(img);

                if(int.parse(image.name.substring(0,1))==result.getpredictresult())
                {
                    success++;
                }                
            }

 我们把测试目录里的全面图片读出遍历了一遍,将其测试结果和实际结果做了一次验证,实际上是把评估(evaluate)的事情又重复做了一次,两次测试的成功率基本接近。

 

四、关于图片特征提取

我们是采用图片所有像素的灰度值来作为特征值的,但必须要强调的是:像素值矩阵不是图片的典型特征。虽然有时候对于较规则的图片,通过像素提取方式进行计算,也可以取得很好的效果,但在处理稍微复杂一点的图片的时候,就不管用了,原因很明显,我们人类在分析图片内容时看到的特征更多是线条等信息,绝对不是像素值,看下图:

我们人类很容易就判断出这两个图片表达的是同一件事情,但其像素值特征却相差甚远。

 传统的图片特征提取方式很多,比如:sift、hog、lbp、haar等。 现在采用tensorflow的模型进行特征提取效果非常好。下一篇文章介绍图片分类时再进行详细介绍。 

 

五、资源获取

源码下载地址:https://github.com/seabluescn/study_ml.net

工程名称:multiclassclassification_mnist_useful

mnist资源获取:https://gitee.com/seabluescn/ml_assets

点击查看机器学习框架ml.net学习笔记系列文章目录

如对本文有疑问,请在下面进行留言讨论,广大热心网友会与你互动!! 点击进行留言回复

相关文章:

验证码:
移动技术网