专栏首页XAIJava分布式神经网络库Deeplearning4j之上手实践手写数字图像识别与模型训练

Java分布式神经网络库Deeplearning4j之上手实践手写数字图像识别与模型训练

环境的搭建可以参考另一篇文章。
  • 第一步运行MnistImagePipelineExampleSave代码下载数据集,并进行训练和保存

需要下载一个文件(windows默认保存在C:\Users\Administrator\AppData\Local\Temp\dl4j_Mnist)。文件存在git。如果网络不好。建议手动下载并解压。然后注释掉代码中的下载方法即可。如图所示:

训练需要一段时间等待即可。时间长短取决于自己电脑配置。

  • 第二步运行MnistImagePipelineLoadChooser代码。并选中一个手写数字图像。进行识别测试
package org.deeplearning4j.examples.dataexamples;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.swing.*;
import java.io.File;
import java.util.Arrays;
import java.util.List;

/**
 * 
 * 给定用户一个文件选择框来选中要测试的手写数字图像
 * 0-9数字 白色或者黑色背景进行识别
 */
public class MnistImagePipelineLoadChooser {
    private static Logger log = LoggerFactory.getLogger(MnistImagePipelineLoadChooser.class);


    /*
    Create a popup window to allow you to chose an image file to test against the
    trained Neural Network
    Chosen images will be automatically
    scaled to 28*28 grayscale
     */
    public static String fileChose(){
        JFileChooser fc = new JFileChooser();
        int ret = fc.showOpenDialog(null);
        if (ret == JFileChooser.APPROVE_OPTION)
        {
            File file = fc.getSelectedFile();
            String filename = file.getAbsolutePath();
            return filename;
        }
        else {
            return null;
        }
    }

    public static void main(String[] args) throws Exception{
        int height = 28;
        int width = 28;
        int channels = 1;

        List<Integer> labelList = Arrays.asList(0,1,2,3,4,5,6,7,8,9);

        // pop up file chooser
        String filechose = fileChose().toString();

        //LOAD NEURAL NETWORK

        // MnistImagePipelineExampleSave训练并保存模型
        File locationToSave = new File("trained_mnist_model.zip");
        // 检查保存的模型是否存在
        if(locationToSave.exists()){
            System.out.println("\n######存在保存的训练模型######\n");
        }else{
            System.out.println("\n\n#######File not found!#######");
            System.out.println("This example depends on running ");
            System.out.println("MnistImagePipelineExampleSave");
            System.out.println("Run that Example First");
            System.out.println("#############################\n\n");


            System.exit(0);
        }

        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

        log.info("*********TEST YOUR IMAGE AGAINST SAVED NETWORK********");

        // 选择一个文件

        File file = new File(filechose);

        // 使用NativeImageLoader转换为数值矩阵

        NativeImageLoader loader = new NativeImageLoader(height, width, channels);

        // 得到图像并赋值INDArray

        INDArray image = loader.asMatrix(file);

        // 0-255
        // 0-1
        DataNormalization scaler = new ImagePreProcessingScaler(0,1);
        scaler.transform(image);
        // 传递到神经网络 并得到概率值
        INDArray output = model.output(image);

        log.info("## The FILE CHOSEN WAS " + filechose);
        log.info("## The Neural Nets Pediction ##");
        log.info("## list of probabilities per label ##");
        //log.info("## List of Labels in Order## ");
        //有序状态
        log.info(output.toString());
        log.info(labelList.toString());

    }



}
  • 选择图片运行后的结果
######Saved Model Found######

o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 2
o.n.n.Nd4jBlas - Number of threads used for BLAS: 2
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 7]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [4]; Memory: [1.8GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [OPENBLAS]
o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: NONE; inference: SEPARATE]
o.d.e.d.MnistImagePipelineLoadChooser - *********TEST YOUR IMAGE AGAINST SAVED NETWORK********
o.d.e.d.MnistImagePipelineLoadChooser - ## The FILE CHOSEN WAS C:\Users\Administrator\Desktop\93.png
o.d.e.d.MnistImagePipelineLoadChooser - ## The Neural Nets Pediction ##
o.d.e.d.MnistImagePipelineLoadChooser - ## list of probabilities per label ##
o.d.e.d.MnistImagePipelineLoadChooser - [0.00,  0.00,  0.00,  1.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00]
o.d.e.d.MnistImagePipelineLoadChooser - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
图中的数字为: 3
数字的置信度为:100.0%

Process finished with exit code 0

选择的图片为:

可见模型对黑白的手写数字识别度还算是可以的。

相关资料。建议还是去官网查阅。本博客只是进行上手实践

https://deeplearning4j.org/cn/

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • SpringMVC+Hibernate +MySql+ EasyUI实现CRUD(一)

    最新项目下载地址 访问地址 ? 1.基于easyui的 增 删 改 查 2.基于poi的导出excel 3.基于 SpringMVC HandlerInte...

    小帅丶
  • SpringMVC+Hibernate +MySql+ EasyUI实现POI导出Excel(二)

    注:使用的是MyEclipse 10.0 javaee 6.0 tomcat 6.0 导出指定列名。使用VO接受参数。 SpringMVC+Hibernate...

    小帅丶
  • Dubbo与Zookeeper、SpringMVC整合和使用(入门级)

    后续会补充完善SpringMVC部分 项目码云GIT地址:https://gitee.com/xshuai/dubbo/ 开发工具 MyEclipse 1...

    小帅丶
  • Android APP测试的日志文件抓取

      实时打印的主要有:logcat main,logcat radio,logcat events,tcpdump,还有高通平台的还会有QXDM日志

    流柯
  • Vert.x工具—使用Dropwizard Metrics对指标进行监控(Metrics使用教程)

        最近项目中需要针对Vert.x的运行效率进行监控,查阅Vert.x官文,发现目前提供了Dropwizard和Hawkular两种开箱即用的工具。本文将介...

    随风溜达的向日葵
  • Master Method Restated-主项定理-递归时间复杂度

    比较n^log b (a)与Θ(h(n)) 的大小(Θ的含义和“等于”类似,而大O的含义和“小于等于”类似,感觉好像这里都可以用):

    sr
  • python连接sql server并执

    python操作sql server,可以使用pymssql,成功安装pymssql后,按照如下的方法,可以连接数据库并执行查询操作:

    py3study
  • 【kafka】高吞吐源码分析-顺序写入与刷盘机制

    kafka作为一个处理实时数据和日志的管道,每秒可以处理几十万条消息。其瓶颈自然也在I/O层面,所以其高吞吐背后离不开如下几个特性:

    皮皮熊
  • CentOS7使用dnf安装mysql

    代码改变世界-coding
  • WebService获取数据实例及WSDL文件解读

    这是一个汇总webservice的网站:http://www.webxml.com.cn 里面有非常多可以供调用的WebService

    ZONGLYN

扫码关注云+社区

领取腾讯云代金券