名称 | 链接 | 备注 |
---|---|---|
项目主页 | https://github.com/zq2599/blog_demos | 该项目在GitHub上的主页 |
git仓库地址(https) | https://github.com/zq2599/blog_demos.git | 该项目源码的仓库地址,https协议 |
git仓库地址(ssh) | git@github.com:zq2599/blog_demos.git | 该项目源码的仓库地址,ssh协议 |
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven descriptor of the predict-number-image module: a Spring Boot web application,
     child of dlfj-tutorials, repackaged as an executable jar. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>dlfj-tutorials</artifactId>
<groupId>com.bolingcavalry</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>predict-number-image</artifactId>
<packaging>jar</packaging>
<!-- Configuration needed because spring-boot-starter-parent is not the parent of this module:
     the Spring Boot dependency versions are imported as a BOM instead. -->
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-dependencies</artifactId>
<version>${springboot.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>com.bolingcavalry</groupId>
<artifactId>commons</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<!-- nd4j-native-platform must be used here; otherwise the container fails at startup with:
     no jnind4jcpu in java.library.path -->
<!--<artifactId>${nd4j.backend}</artifactId>-->
<artifactId>nd4j-native-platform</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<!-- When the parent project is not Spring Boot, the plugin has to be configured this way
     (explicit main class and repackage goal) to produce a proper runnable jar. -->
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<mainClass>com.bolingcavalry.predictnumber.PredictNumberApplication</mainClass>
</configuration>
<executions>
<execution>
<goals>
<goal>repackage</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
# Maximum total size of one multipart upload request
spring.servlet.multipart.max-request-size=1024MB
# Maximum size of a single uploaded file
spring.servlet.multipart.max-file-size=10MB
# Directory in which uploaded image files are stored and processed.
# The trailing slash is required: the code builds paths as <directory> + <file name>.
predict.imagefilepath=/app/images/
# Location of the model file (the file name is used as-is; do not "correct" its spelling here
# without renaming the file itself)
predict.modelpath=/app/model/minist-model.zip
package com.bolingcavalry.commons.utils;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
@Slf4j
public class ImageFileUtil {
/**
* 调整后的文件宽度
*/
public static final int RESIZE_WIDTH = 28;
/**
* 调整后的文件高度
*/
public static final int RESIZE_HEIGHT = 28;
/**
* 将上传的文件存在服务器上
* @param base 要处理的文件所在的目录
* @param file 要处理的文件
* @return
*/
public static String save(String base, MultipartFile file) {
// 检查是否为空
if (file.isEmpty()) {
log.error("invalid file");
return null;
}
// 文件名来自原始文件
String fileName = file.getOriginalFilename();
// 要保存的位置
File dest = new File(base + fileName);
// 开始保存
try {
file.transferTo(dest);
} catch (IOException e) {
log.error("upload fail", e);
return null;
}
return fileName;
}
/**
* 将图片转为28*28像素
* @param base 处理文件的目录
* @param fileName 待调整的文件名
* @return
*/
public static String resize(String base, String fileName) {
// 新文件名是原文件名在加个随机数后缀,而且扩展名固定为png
String resizeFileName = fileName.substring(0, fileName.lastIndexOf(".")) + "-" + UUID.randomUUID() + ".png";
log.info("start resize, from [{}] to [{}]", fileName, resizeFileName);
try {
// 读原始文件
BufferedImage bufferedImage = ImageIO.read(new File(base + fileName));
// 缩放后的实例
Image image = bufferedImage.getScaledInstance(RESIZE_WIDTH, RESIZE_HEIGHT, Image.SCALE_SMOOTH);
BufferedImage resizeBufferedImage = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
Graphics graphics = resizeBufferedImage.getGraphics();
// 绘图
graphics.drawImage(image, 0, 0, null);
graphics.dispose();
// 转换后的图片写文件
ImageIO.write(resizeBufferedImage, "png", new File(base + resizeFileName));
} catch (Exception exception) {
log.info("resize error from [{}] to [{}], {}", fileName, resizeFileName, exception);
resizeFileName = null;
}
log.info("finish resize, from [{}] to [{}]", fileName, resizeFileName);
return resizeFileName;
}
/**
* 将RGB转为int数字
* @param alpha
* @param red
* @param green
* @param blue
* @return
*/
private static int colorToRGB(int alpha, int red, int green, int blue) {
int pixel = 0;
pixel += alpha;
pixel = pixel << 8;
pixel += red;
pixel = pixel << 8;
pixel += green;
pixel = pixel << 8;
pixel += blue;
return pixel;
}
/**
* 反色处理
* @param base 处理文件的目录
* @param src 用于处理的源文件
* @return 反色处理后的新文件
* @throws IOException
*/
public static String colorRevert(String base, String src) throws IOException {
int color, r, g, b, pixel;
// 读原始文件
BufferedImage srcImage = ImageIO.read(new File(base + src));
// 修改后的文件
BufferedImage destImage = new BufferedImage(srcImage.getWidth(), srcImage.getHeight(), srcImage.getType());
for (int i=0; i<srcImage.getWidth(); i++) {
for (int j=0; j<srcImage.getHeight(); j++) {
color = srcImage.getRGB(i, j);
r = (color >> 16) & 0xff;
g = (color >> 8) & 0xff;
b = color & 0xff;
pixel = colorToRGB(255, 0xff - r, 0xff - g, 0xff - b);
destImage.setRGB(i, j, pixel);
}
}
// 反射文件的名字
String revertFileName = src.substring(0, src.lastIndexOf(".")) + "-revert.png";
// 转换后的图片写文件
ImageIO.write(destImage, "png", new File(base + revertFileName));
return revertFileName;
}
/**
* 取黑白图片的特征
* @param base
* @param fileName
* @return
* @throws Exception
*/
public static INDArray getGrayImageFeatures(String base, String fileName) throws Exception {
log.info("start getImageFeatures [{}]", base + fileName);
// 和训练模型时一样的设置
ImageRecordReader imageRecordReader = new ImageRecordReader(RESIZE_HEIGHT, RESIZE_WIDTH, 1);
FileSplit fileSplit = new FileSplit(new File(base + fileName),
NativeImageLoader.ALLOWED_FORMATS);
imageRecordReader.initialize(fileSplit);
DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(imageRecordReader, 1);
dataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));
// 取特征
return dataSetIterator.next().getFeatures();
}
/**
* 批量清理文件
* @param base 处理文件的目录
* @param fileNames 待清理文件集合
*/
public static void clear(String base, String...fileNames) {
for (String fileName : fileNames) {
if (null==fileName) {
continue;
}
File file = new File(base + fileName);
if (file.exists()) {
file.delete();
}
}
}
}
package com.bolingcavalry.predictnumber.service;
import org.springframework.web.multipart.MultipartFile;
/**
 * Service that recognizes the number shown in an uploaded image.
 */
public interface PredictService {
/**
 * Takes the uploaded image, converts it and recognizes it as a number.
 * @param file the uploaded file
 * @param isNeedRevert whether the colors of the image have to be inverted before recognition
 * @return the recognized number
 * @throws Exception if processing the image or running the recognition fails
 */
int predict(MultipartFile file, boolean isNeedRevert) throws Exception ;
}
package com.bolingcavalry.predictnumber.service.impl;
import com.bolingcavalry.commons.utils.ImageFileUtil;
import com.bolingcavalry.predictnumber.service.PredictService;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.PostConstruct;
import java.io.File;
@Service
@Slf4j
public class PredictServiceImpl implements PredictService {
/**
* -1表示识别失败
*/
private static final int RLT_INVALID = -1;
/**
* 模型文件的位置
*/
@Value("${predict.modelpath}")
private String modelPath;
/**
* 处理图片文件的目录
*/
@Value("${predict.imagefilepath}")
private String imageFilePath;
/**
* 神经网络
*/
private MultiLayerNetwork net;
/**
* bean实例化成功就加载模型
*/
@PostConstruct
private void loadModel() {
log.info("load model from [{}]", modelPath);
// 加载模型
try {
net = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath));
log.info("module summary\n{}", net.summary());
} catch (Exception exception) {
log.error("loadModel error", exception);
}
}
@Override
public int predict(MultipartFile file, boolean isNeedRevert) throws Exception {
log.info("start predict, file [{}], isNeedRevert [{}]", file.getOriginalFilename(), isNeedRevert);
// 先存文件
String rawFileName = ImageFileUtil.save(imageFilePath, file);
if (null==rawFileName) {
return RLT_INVALID;
}
// 反色处理后的文件名
String revertFileName = null;
// 调整大小后的文件名
String resizeFileName;
// 是否需要反色处理
if (isNeedRevert) {
// 把原始文件做反色处理,返回结果是反色处理后的新文件
revertFileName = ImageFileUtil.colorRevert(imageFilePath, rawFileName);
// 把反色处理后调整为28*28大小的文件
resizeFileName = ImageFileUtil.resize(imageFilePath, revertFileName);
} else {
// 直接把原始文件调整为28*28大小的文件
resizeFileName = ImageFileUtil.resize(imageFilePath, rawFileName);
}
// 现在已经得到了结果反色和调整大小处理过后的文件,
// 那么原始文件和反色处理过的文件就可以删除了
ImageFileUtil.clear(imageFilePath, rawFileName, revertFileName);
// 取出该黑白图片的特征
INDArray features = ImageFileUtil.getGrayImageFeatures(imageFilePath, resizeFileName);
// 将特征传给模型去识别
return net.predict(features)[0];
}
}
package com.bolingcavalry.predictnumber.controller;
import com.bolingcavalry.predictnumber.service.PredictService;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
/**
 * REST endpoints that accept an uploaded image and return the number recognized in it.
 */
@RestController
public class PredictController {

    private final PredictService predictService;

    /**
     * @param predictService service that performs the recognition
     */
    public PredictController(PredictService predictService) {
        this.predictService = predictService;
    }

    /**
     * Recognizes an image showing a white digit on a black background.
     *
     * @param file the uploaded image, sent as the multipart field {@code file}
     * @return the result of {@link PredictService#predict}
     * @throws Exception if the service fails
     */
    @PostMapping("/predict-with-black-background")
    public int predictWithBlackBackground(@RequestParam("file") MultipartFile file) throws Exception {
        // The model was trained with white digits on a black background, so such an
        // image can be recognized as it is, without color inversion
        return predictService.predict(file, false);
    }

    /**
     * Recognizes an image showing a black digit on a white background.
     *
     * @param file the uploaded image, sent as the multipart field {@code file}
     * @return the result of {@link PredictService#predict}
     * @throws Exception if the service fails
     */
    @PostMapping("/predict-with-white-background")
    public int predictWithWhiteBackground(@RequestParam("file") MultipartFile file) throws Exception {
        // The model was trained with white digits on a black background, so a black-on-white
        // image has to be color-inverted first; after that it matches the training data
        return predictService.predict(file, true);
    }
}
package com.bolingcavalry.predictnumber;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point of the application.
 */
@SpringBootApplication
public class PredictNumberApplication {
/**
 * Starts the Spring application.
 * @param args command-line arguments, passed through to {@code SpringApplication.run}
 */
public static void main(String[] args) {
SpringApplication.run(PredictNumberApplication.class, args);
}
}
# Base image; the 8-jdk-alpine variant crashes at startup
FROM openjdk:8u292-jdk
# Create the directories for the image files and the model
# (the same paths as predict.imagefilepath and predict.modelpath in the application properties)
RUN mkdir -p /app/images && mkdir -p /app/model
# Where the unpacked jar contents are located in the build context
ARG DEPENDENCY=target/dependency
# Copy the unpacked contents into the image: libraries, metadata and application classes
COPY ${DEPENDENCY}/BOOT-INF/lib /app/lib
COPY ${DEPENDENCY}/META-INF /app/META-INF
COPY ${DEPENDENCY}/BOOT-INF/classes /app
# Startup command: run the main class directly from the unpacked classes and libraries.
# NOTE(review): the classpath entries are relative and no WORKDIR is set, so this presumably
# relies on the working directory of the base image being / - confirm.
ENTRYPOINT ["java","-cp","app:app/lib/*","com.bolingcavalry.predictnumber.PredictNumberApplication"]
# Unpack the built jar into target/dependency, the directory the Dockerfile copies from.
# "cd ... &&" makes sure jar only runs when the directory change succeeded, so a failed cd
# cannot extract the archive into the wrong directory.
mkdir -p target/dependency && (cd target/dependency && jar -xf ../*.jar)
# Start the application container:
#   --rm  remove the container once it exits
#   -p    publish container port 8080 on host port 18080
#   -v    mount the host directories onto /app/images and /app/model inside the container
docker run \
--rm \
-p 18080:8080 \
-v /home/will/temp/202106/29/images:/app/images \
-v /home/will/temp/202106/29/model:/app/model \
bolingcavalry/dl4j-model-app:0.0.3