前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >SpringBoot用深度学习模型识别数字:开发详解

SpringBoot用深度学习模型识别数字:开发详解

作者头像
程序员欣宸
发布2021-12-07 10:23:23
9530
发布2021-12-07 10:23:23
举报
文章被收录于专栏:实战docker实战docker

本篇概览

  • 前文《三分钟体验:SpringBoot用深度学习模型识别数字》中,咱们轻点鼠标体验了一个Java应用,该应用集成了深度学习模型,能识别出图像中的手写数字,那篇文章以体验和操作为主,并没有谈到背后的实现
  • 此刻的您,如果之前对深度学习了解不多,只随着《三分钟体验:SpringBoot用深度学习模型识别数字》做过简单体验,现在应该一头雾水,心中可能有以下疑问:
  • 前文提到的模型文件minist-model.zip是什么?怎么来的?
  • SpringBoot拿到模型文件后,怎么用的?和识别功能有什么关系?
  • 今天咱们一起来写代码,开发出前文的那个应用,也会解答您的疑问,整篇文章由以下内容构成:
  • 全流程概览
  • 训练实战
  • SpringBoot应用的设计阶段,重点说明第一个关键点
  • SpringBoot应用的设计阶段,重点说明第二个关键点
  • SpringBoot应用的设计阶段,梳理完整的流程
  • SpringBoot应用编码
  • 将应用制作成docker镜像

全流程概览

  • 简单的说,如果想通过深度学习来解决实际问题(以图片识别为例),需要要做好以下两件事:
  • 拿已有数据做训练,训练结果保存在模型文件中
  • 在业务代码中使用模型文件来识别图片
  • 先说说训练:
  • 提前准备一些手写数字图片,这些图片对应的数字已确定,如下图,目录9下面全是9的手写体
  1. 编写训练模型的代码,设置神经网络的各种参数,例如:归一化、激活函数、损失函数、卷积层数等等
  2. 执行训练模型的代码,将上述图片拿去训练
  3. 将训练结果保存到文件中,这个文件就是模型文件,前文中就是minist-model.zip
  4. 至此训练完成
  5. 接下来就是使用此模型文件解决实际问题
  6. 开发一个业务应用
  7. 应用中加载前面训练生成的模型文件
  8. 收到用户提交的数据,交给模型处理
  9. 将处理结果返回给用户
  10. 以上就是将深度学习用于业务场景的大致过程,接下来咱们从训练出发,开始实战过程

源码下载

名称

链接

备注

项目主页

该项目在GitHub上的主页

git仓库地址(https)

该项目源码的仓库地址,https协议

git仓库地址(ssh)

git@github.com:zq2599/blog_demos.git

该项目源码的仓库地址,ssh协议

  • 这个git项目中有多个文件夹,《DL4J实战》系列的源码在dl4j-tutorials文件夹下,如下图红框所示:
  • dl4j-tutorials文件夹下有多个子工程,本次实战代码在dlfj-tutorials目录下,如下图红框:

训练实战

  • 整个训练的详细过程,也就是minist-model.zip的生成过程,已在《DL4J实战之三:经典卷积实例(LeNet-5)》一文中详细说明,也给出了完整的代码,您只要按照文中所说操作一遍即可,不过操作之前有个问题要特别注意:
  • 《DL4J实战之三:经典卷积实例(LeNet-5)》一文中使用的deeplearning4j框架的版本是1.0.0-beta6,请您将其改为1.0.0-beta7,具体的改动方法是打开simple-convolution工程的pom.xml文件(注意是simple-convolution工程,不是它的父工程dlfj-tutorials),修改下图红框2位置的内容:
  • 改完后即可运行程序生成模型文件,然后进入使用模型的阶段
  • 您可能会问,为什么之 《DL4J实战之三:经典卷积实例(LeNet-5)》一文中会用1.0.0-beta6版本呢?其实那里是为了使用GPU加速训练过程,那时候1.0.0-beta7不支持CUDA-9.2,在本文中不会用到GPU加速,因此推荐使用1.0.0-beta7版本
  • 接下来开始开发SpringBoot应用,在应用中使用模型去识别图片

SpringBoot应用设计(第一个关键点)

  • 在设计阶段有两个关键点要提前注意,第一个和图片大小有关:在训练模型时,使用的图片都是28*28像素,所以咱们的应用在收到用户提交的图片后,需要做缩放处理将其缩放到28*28像素

SpringBoot应用设计(第二个关键点)

  • 再看看用于训练的图片,如下图,所有图片都是黑底白字:
  • 那么问题来了:模型是根据黑底白字训练的,无法识别白底黑字,遇到白底黑字的图片该如何处理呢?
  • 所以如果用户输入的是白底黑字的图片,咱们的程序来做颜色反转,将其变为黑底白字再做识别

SpringBoot应用设计(流程设计)

  • 现在咱们来梳理一下整个流程,如果用户输入的是白底黑字的图片,那么整个应用的处理流程如下:
  • 如果用户输入的是黑底白字的图片,只需要将上述流程中的反色处理去掉即可
  • 为白底黑字图片提供专用接口predict-with-white-background
  • 为黑底白字图片提供专用接口predict-with-black-background
  • 现在设计工作已经完成,可以开始编码了

使用模型(编码)

  • 为了便于管理demo代码和依赖库的版本,《DL4J实战之一:准备》一文中创建了名为dl4j-tutorials的maven工程,咱们今天新建的应用也是dl4j-tutorials的子工程,名为predict-number-image,其自身的pom.xml内容如下:
代码语言:javascript
复制
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>dlfj-tutorials</artifactId>
        <groupId>com.bolingcavalry</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>predict-number-image</artifactId>

    <packaging>jar</packaging>

    <!--不用spring-boot-starter-parent作为parent时的配置-->
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-dependencies</artifactId>
                <version>${springboot.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>

        <dependency>
            <groupId>com.bolingcavalry</groupId>
            <artifactId>commons</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <!--注意要用nd4j-native-platform,否则容器启动时报错:no jnind4jcpu in java.library.path-->
            <!--<artifactId>${nd4j.backend}</artifactId>-->
            <artifactId>nd4j-native-platform</artifactId>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <!-- 如果父工程不是springboot,就要用以下方式使用插件,才能生成正常的jar -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <mainClass>com.bolingcavalry.predictnumber.PredictNumberApplication</mainClass>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
  • 可见predict-number-image工程与其父工程dlfj-tutorials的关系不大,仅仅使用了父工程定义的几个库的版本号而已,您也可以独立创建一个没有父子关系的工程;
  • 新建配置文件application.properties,里面有图片相关的配置:
代码语言:javascript
复制
# 上传文件总的最大值
spring.servlet.multipart.max-request-size=1024MB

# 单个文件的最大值
spring.servlet.multipart.max-file-size=10MB

# 处理图片文件的目录
predict.imagefilepath=/app/images/

# 模型所在位置
predict.modelpath=/app/model/minist-model.zip
  • 将处理图片所需的静态方法集中在ImageFileUtil.java的文件中,主要是save(存到磁盘上)、resize(缩放)、colorRevert(反色)、clear(清理)、getGrayImageFeatures(提取特征,操作和训练时的是一样的):
代码语言:javascript
复制
package com.bolingcavalry.commons.utils;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.UUID;

@Slf4j
public class ImageFileUtil {

    /**
     * 调整后的文件宽度
     */
    public static final int RESIZE_WIDTH = 28;

    /**
     * 调整后的文件高度
     */
    public static final int RESIZE_HEIGHT = 28;

    /**
     * 将上传的文件存在服务器上
     * @param base 要处理的文件所在的目录
     * @param file 要处理的文件
     * @return
     */
    public static String save(String base, MultipartFile file) {

        // 检查是否为空
        if (file.isEmpty()) {
            log.error("invalid file");
            return null;
        }

        // 文件名来自原始文件
        String fileName = file.getOriginalFilename();

        // 要保存的位置
        File dest = new File(base + fileName);

        // 开始保存
        try {
            file.transferTo(dest);
        } catch (IOException e) {
            log.error("upload fail", e);
            return null;
        }

        return fileName;
    }

    /**
     * 将图片转为28*28像素
     * @param base     处理文件的目录
     * @param fileName 待调整的文件名
     * @return
     */
    public static String resize(String base, String fileName) {

        // 新文件名是原文件名在加个随机数后缀,而且扩展名固定为png
        String resizeFileName = fileName.substring(0, fileName.lastIndexOf(".")) + "-" + UUID.randomUUID() + ".png";

        log.info("start resize, from [{}] to [{}]", fileName, resizeFileName);

        try {
            // 读原始文件
            BufferedImage bufferedImage = ImageIO.read(new File(base + fileName));

            // 缩放后的实例
            Image image = bufferedImage.getScaledInstance(RESIZE_WIDTH, RESIZE_HEIGHT, Image.SCALE_SMOOTH);

            BufferedImage resizeBufferedImage = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            Graphics graphics = resizeBufferedImage.getGraphics();

            // 绘图
            graphics.drawImage(image, 0, 0, null);
            graphics.dispose();

            // 转换后的图片写文件
            ImageIO.write(resizeBufferedImage, "png", new File(base + resizeFileName));

        } catch (Exception exception) {
            log.info("resize error from [{}] to [{}], {}", fileName, resizeFileName, exception);
            resizeFileName = null;
        }

        log.info("finish resize, from [{}] to [{}]", fileName, resizeFileName);

        return resizeFileName;
    }

    /**
     * 将RGB转为int数字
     * @param alpha
     * @param red
     * @param green
     * @param blue
     * @return
     */
    private static int colorToRGB(int alpha, int red, int green, int blue) {
        int pixel = 0;

        pixel += alpha;
        pixel = pixel << 8;

        pixel += red;
        pixel = pixel << 8;

        pixel += green;
        pixel = pixel << 8;

        pixel += blue;

        return pixel;
    }

    /**
     * 反色处理
     * @param base 处理文件的目录
     * @param src 用于处理的源文件
     * @return 反色处理后的新文件
     * @throws IOException
     */
    public static String colorRevert(String base, String src) throws IOException {
        int color, r, g, b, pixel;

        // 读原始文件
        BufferedImage srcImage = ImageIO.read(new File(base + src));

        // 修改后的文件
        BufferedImage destImage = new BufferedImage(srcImage.getWidth(), srcImage.getHeight(), srcImage.getType());

        for (int i=0; i<srcImage.getWidth(); i++) {

            for (int j=0; j<srcImage.getHeight(); j++) {
                color = srcImage.getRGB(i, j);
                r = (color >> 16) & 0xff;
                g = (color >> 8) & 0xff;
                b = color & 0xff;
                pixel = colorToRGB(255, 0xff - r, 0xff - g, 0xff - b);
                destImage.setRGB(i, j, pixel);
            }
        }

        // 反射文件的名字
        String revertFileName =  src.substring(0, src.lastIndexOf(".")) + "-revert.png";

        // 转换后的图片写文件
        ImageIO.write(destImage, "png", new File(base + revertFileName));

        return revertFileName;
    }

    /**
     * 取黑白图片的特征
     * @param base
     * @param fileName
     * @return
     * @throws Exception
     */
    public static INDArray getGrayImageFeatures(String base, String fileName) throws Exception {
        log.info("start getImageFeatures [{}]", base + fileName);

        // 和训练模型时一样的设置
        ImageRecordReader imageRecordReader = new ImageRecordReader(RESIZE_HEIGHT, RESIZE_WIDTH, 1);

        FileSplit fileSplit = new FileSplit(new File(base + fileName),
                NativeImageLoader.ALLOWED_FORMATS);

        imageRecordReader.initialize(fileSplit);

        DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, 1);
        dataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));

        // 取特征
        return dataSetIterator.next().getFeatures();
    }

    /**
     * 批量清理文件
     * @param base      处理文件的目录
     * @param fileNames 待清理文件集合
     */
    public static void clear(String base, String...fileNames) {
        for (String fileName : fileNames) {

            if (null==fileName) {
                continue;
            }

            File file = new File(base + fileName);

            if (file.exists()) {
                file.delete();
            }
        }
    }
}
  • 定义service层,只有一个方法,可以通过入参决定是否做反色处理:
代码语言:javascript
复制
package com.bolingcavalry.predictnumber.service;

import org.springframework.web.multipart.MultipartFile;

public interface PredictService {

    /**
     * 取得上传的图片,做转换后识别成数字
     * @param file 上传的文件
     * @param isNeedRevert 是否要做反色处理
     * @return
     */
    int predict(MultipartFile file, boolean isNeedRevert) throws Exception ;
}
  • sevice层的实现,也是本篇的核心,有几处要注意的地方稍后会提到:
代码语言:javascript
复制
package com.bolingcavalry.predictnumber.service.impl;

import com.bolingcavalry.commons.utils.ImageFileUtil;
import com.bolingcavalry.predictnumber.service.PredictService;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.PostConstruct;
import java.io.File;

@Service
@Slf4j
public class PredictServiceImpl implements PredictService {

    /**
     * -1表示识别失败
     */
    private static final int RLT_INVALID = -1;

    /**
     * 模型文件的位置
     */
    @Value("${predict.modelpath}")
    private String modelPath;

    /**
     * 处理图片文件的目录
     */
    @Value("${predict.imagefilepath}")
    private String imageFilePath;

    /**
     * 神经网络
     */
    private MultiLayerNetwork net;

    /**
     * bean实例化成功就加载模型
     */
    @PostConstruct
    private void loadModel() {
        log.info("load model from [{}]", modelPath);

        // 加载模型
        try {
            net = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));
            log.info("module summary\n{}", net.summary());
        } catch (Exception exception) {
            log.error("loadModel error", exception);
        }
    }

    @Override
    public int predict(MultipartFile file, boolean isNeedRevert) throws Exception {
        log.info("start predict, file [{}], isNeedRevert [{}]", file.getOriginalFilename(), isNeedRevert);

        // 先存文件
        String rawFileName = ImageFileUtil.save(imageFilePath, file);

        if (null==rawFileName) {
            return RLT_INVALID;
        }

        // 反色处理后的文件名
        String revertFileName = null;

        // 调整大小后的文件名
        String resizeFileName;

        // 是否需要反色处理
        if (isNeedRevert) {
            // 把原始文件做反色处理,返回结果是反色处理后的新文件
            revertFileName = ImageFileUtil.colorRevert(imageFilePath, rawFileName);

            // 把反色处理后调整为28*28大小的文件
            resizeFileName = ImageFileUtil.resize(imageFilePath, revertFileName);
        } else {
            // 直接把原始文件调整为28*28大小的文件
            resizeFileName = ImageFileUtil.resize(imageFilePath, rawFileName);
        }

        // 现在已经得到了结果反色和调整大小处理过后的文件,
        // 那么原始文件和反色处理过的文件就可以删除了
        ImageFileUtil.clear(imageFilePath, rawFileName, revertFileName);

        // 取出该黑白图片的特征
        INDArray features = ImageFileUtil.getGrayImageFeatures(imageFilePath, resizeFileName);

        // 将特征传给模型去识别
        return net.predict(features)[0];
    }
}
  • 上述代码中,有两需要注意:
  • loadModel方法在bean初始化时会执行,里面通过ModelSerializer.restoreMultiLayerNetwork完成模型文件加载
  • 真正的识别操作其实就是MultiLayerNetwork.predict方法,一步而已,何其简单
  • 然后是web接口层,对外提供两个接口:
代码语言:javascript
复制
package com.bolingcavalry.predictnumber.controller;

import com.bolingcavalry.predictnumber.service.PredictService;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
public class PredictController {

    final PredictService predictService;

    public PredictController(PredictService predictService) {
        this.predictService = predictService;
    }

    @PostMapping("/predict-with-black-background")
    @ResponseBody
    public int predictWithBlackBackground(@RequestParam("file") MultipartFile file) throws Exception {
        // 训练模型的时候,用的数字是白字黑底,
        // 因此如果上传白字黑底的图片,可以直接拿去识别,而无需反色处理
        return predictService.predict(file, false);
    }

    @PostMapping("/predict-with-white-background")
    @ResponseBody
    public int predictWithWhiteBackground(@RequestParam("file") MultipartFile file) throws Exception {
        // 训练模型的时候,用的数字是白字黑底,
        // 因此如果上传黑字白底的图片,就需要做反色处理,
        // 反色之后就是白字黑底了,可以拿去识别
        return predictService.predict(file, true);
    }
}
  • 最后是启动类:
代码语言:javascript
复制
package com.bolingcavalry.predictnumber;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class PredictNumberApplication {
    public static void main(String[] args) {
        SpringApplication.run(PredictNumberApplication.class, args);
    }
}

制作docker镜像

  • 将SpringBoot应用做成docker镜像有很多方法,这里选用的是SpringBoot官方推荐的方式:
  • 先把Dockerfile文件写好,放在predict-number-image目录下,可见只是简单的文件复制操作,然后指定启动命令:
代码语言:javascript
复制
# 8-jdk-alpine版本启动的时候会crash
FROM openjdk:8u292-jdk

# 创建目录
RUN mkdir -p /app/images && mkdir -p /app/model

# 指定镜像的内容的来源位置
ARG DEPENDENCY=target/dependency

# 复制内容到镜像
COPY ${DEPENDENCY}/BOOT-INF/lib /app/lib
COPY ${DEPENDENCY}/META-INF /app/META-INF
COPY ${DEPENDENCY}/BOOT-INF/classes /app

# 指定启动命令
ENTRYPOINT ["java","-cp","app:app/lib/*","com.bolingcavalry.predictnumber.PredictNumberApplication"]
  • 接下来准备Dockerfile中所需的那些文件,在父工程目录下执行mvn clean package -U,这是个纯粹的maven操作,和docker没有任何关系
  • 进入predict-number-image目录,执行以下命令,作用是从jar文件中提取class、配置文件、依赖库等内容到target/dependency目录:
代码语言:javascript
复制
mkdir -p target/dependency && (cd target/dependency; jar -xf ../*.jar)
  • 最后,在Dockerfile文件所在目录执行命令docker build -t bolingcavalry/dl4j-model-app:0.0.3 .(命令的最后有个点,不要漏了),即可完成镜像制作
  • 如果您有hub.docker.com的账号,还可以通过docker push命令把镜像推送到中央仓库,让更多的人用到:
  • 最后,再来回顾一下《三分钟体验:SpringBoot用深度学习模型识别数字》一文中启动docker容器的命令,如下可见,通过两个-v参数,将宿主机的目录映射到容器中,因此,容器中的/app/images和/app/model可以保持不变,只要能保证宿主机的目录映射正确即可:
代码语言:javascript
复制
docker run \
--rm \
-p 18080:8080 \
-v /home/will/temp/202106/29/images:/app/images \
-v /home/will/temp/202106/29/model:/app/model \
bolingcavalry/dl4j-model-app:0.0.3
  • 有关SpringBoot官方推荐的docker镜像制作的更多信息,请参考《SpringBoot(2.4)应用制作Docker镜像(Gradle版官方方案)》
  • 至此,SpringBoot用深度学习模型识别数字的开发实战就完成了,如果您是一位对深度学习感兴趣的java程序员,相信本文能给您带来一些参考,更多深度学习的实战请关注欣宸的《DL4J实战》系列原创;
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-07-03 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 本篇概览
  • 全流程概览
  • 源码下载
  • 训练实战
  • SpringBoot应用设计(第一个关键点)
  • SpringBoot应用设计(第二个关键点)
  • SpringBoot应用设计(流程设计)
  • 使用模型(编码)
  • 制作docker镜像
相关产品与服务
容器镜像服务
容器镜像服务(Tencent Container Registry,TCR)为您提供安全独享、高性能的容器镜像托管分发服务。您可同时在全球多个地域创建独享实例,以实现容器镜像的就近拉取,降低拉取时间,节约带宽成本。TCR 提供细颗粒度的权限管理及访问控制,保障您的数据安全。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档