DoodleNet - 用Quickdraw数据集训练的CNN涂鸦分类器
by yining1023
DoodleNet 是一个涂鸦分类器(CNN),对来自Quickdraw数据集的所有345个类别进行了训练。
Github项目地址:
https://github.com/yining1023/doodleNet
这是我使用tensorflow.js和tensorflow进行的关于涂鸦分类器(一种卷积神经网络)的一系列实验。使用的数据来自Quickdraw数据集。
以下是项目清单:
查看网络机器学习第3周了解更多信息以及CNN和迁移学习如何运作。
我用 tfjs 的 layers API 和 tf.js-vis 在浏览器中训练了一个涂有3个类(领结、棒棒糖、彩虹)的涂鸦分类器。代码基于 tf.js 示例 - 训练MNIST。
演示Demo:
https://yining1023.github.io/doodleNet/demo/TrainDoodleClassifier
打开网页后,请等待页面加载数据、训练模型、评估模型。 它将会下载两个文件:myDoodleNet.json 和 myDoodleNet.weights.bin 。如果要自己测试这个模型,你可以加载这两个文件,然后点击 'Load Model - 加载模型' 按钮,然后在画布上画画,点击'Guess'按钮让模型开始猜测画布上你画的是什么。
DoodleNet 对 Quickdraw 数据集中的345个类别进行了训练,每个类有50k张图片。它使用tensorflow进行训练,并在浏览器中移植到tf.js。点击打开训练笔记。
训练笔记主要基于@zaidalyafeai 的100个课程的Sketcher笔记本。我将数据扩展到345个类,并添加了几个层来改善345个类的准确性。
我使用 spell.run 的搭载大容量RAM的远程GPU机器来加载所有数据并训练模型。
演示Demo:
https://yining1023.github.io/doodleNet/demo/DoodleClassifier_345
基于之前的345个类的涂鸦分类器,我添加了KNN分类器,因此人们可以自定义自己的涂鸦类。
演示Demo:
https://yining1023.github.io/doodleNet/demo/DoodleClassifier_KNN
你可以绘制10个以上的圆圈并将它们添加到A类,并绘制10个以上的线条并将它们添加到B类,然后让模型猜测您的新绘图。你也可以定义任何其他类,它不需要是圆形或正方形。
要在本地运行每个示例,请打开终端,输入以下命令:
$ git clone https://github.com/yining1023/doodleNet.git$ cd doodleNet$ python -m SimpleHTTPServer # $ python3 -m http.server (if you are using python 3)
在浏览器中打开 localhost:8000/demo,你会看到如下的目录列表,单击即可查看对应演示。