前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【机器学习】Tensorflow.js:我在浏览器中实现了迁移学习

【机器学习】Tensorflow.js:我在浏览器中实现了迁移学习

作者头像
前端修罗场
发布2023-10-07 19:31:14
1690
发布2023-10-07 19:31:14
举报
文章被收录于专栏:Web 技术Web 技术

迁移学习是将预训练模型与自定义训练数据相结合的能力。 这意味着你可以利用模型的功能并添加自己的样本,而无需从头开始创建所有内容。

例如,一种算法已经用数千张图像进行了训练以创建图像分类模型,而不是创建自己的图像分类模型,迁移学习允许你将新的自定义图像样本与预先训练的模型相结合以创建新的图像分类器。 这个特性使得拥有一个更加定制化的分类器变得非常快速和容易。

为了提供代码中的示例,让我们重新利用之前的示例并对其进行修改,以便我们可以对新图像进行分类。

请添加图片描述
请添加图片描述

以下是此设置最重要部分的一些代码示例,但如果你需要查看整个代码,可以在本文的最后找到它。

我们仍然需要从导入 Tensorflow.js 和 MobileNet 开始,但是这次我们还需要添加一个 KNN(k-nearest neighbor)分类器:

代码语言:javascript
复制
<!-- Load TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- Load MobileNet -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<!-- Load KNN Classifier -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>

我们需要分类器的原因是(不仅仅是使用 MobileNet 模块)我们正在添加以前从未见过的自定义样本,因此 KNN 分类器将允许我们将所有内容组合在一起并对组合的数据进行预测

然后,我们可以用视频标签替换猫的图像,以使用来自摄像头的图像。

代码语言:javascript
复制
<video autoplay id="webcam" width="227" height="227"></video>

最后,我们需要在页面上添加一些按钮,我们将用作标签来记录一些视频样本并开始预测。

代码语言:javascript
复制
<section>
  <button class="button">Left</button>

  <button class="button">Right</button>

  <button class="test-predictions">Test</button>
</section>

现在,让我们转到 JavaScript 文件,我们将从设置几个重要变量开始:

代码语言:javascript
复制
//要分类的数量
const NUM_CLASSES = 2;
// 分类标签
const classes = ["Left", "Right"];
// Webcam Image size. Must be 227.
const IMAGE_SIZE = 227;
// KNN 的 K 值
const TOPK = 10;

const video = document.getElementById("webcam");

在这个特定的示例中,我们希望能够在我们的头部向左或向右倾斜之间对网络摄像头输入进行分类,因此我们需要两个标记为 leftright 的类。

设置为 227 的图像大小是视频元素的大小(以像素为单位)。 根据 Tensorflow.js 示例,该值需要设置为 227 以匹配用于训练 MobileNet 模型的数据格式。 为了能够对我们的新数据进行分类,后者需要适应相同的格式。

如果你真的需要它更大,这是可能的,但你必须在将数据提供给 KNN 分类器之前转换和调整数据大小。

然后,我们将 K 的值设置为 10。KNN 算法中的 K 值很重要,因为它代表了我们在确定新输入的类别时考虑的实例数。

在这种情况下,10 意味着,在预测一些新数据的标签时,我们将查看训练数据中的 10 个最近邻,以确定如何对新输入进行分类。

最后,我们得到了视频元素。

对于逻辑,让我们从加载模型和分类器开始:

代码语言:javascript
复制
async load() {
    const knn = knnClassifier.create();
    const mobilenetModule = await mobilenet.load();
    console.log("model loaded");
}

然后,让我们访问视频源:

代码语言:javascript
复制
navigator.mediaDevices
  .getUserMedia({ video: true, audio: false })
  .then(stream => {
    video.srcObject = stream;
    video.width = IMAGE_SIZE;
    video.height = IMAGE_SIZE;
  });

接下来,让我们设置一些按钮事件来记录我们的示例数据:

代码语言:javascript
复制
setupButtonEvents() {
    for (let i = 0; i < NUM_CLASSES; i++) {
      let button = document.getElementsByClassName("button")[i];

      button.onmousedown = () => {
        this.training = i;
        this.recordSamples = true;
      };
      button.onmouseup = () => (this.training = -1);
    }
  }

让我们编写我们的函数,它将获取网络摄像头图像样本,重新格式化它们并将它们与 MobileNet 模块结合起来:

代码语言:javascript
复制
// 从视频元素中获取图像数据
const image = tf.browser.fromPixels(video);

let logits;
// 'conv_preds' 是 MobileNet 的 logits 激活。
const infer = () => this.mobilenetModule.infer(image, "conv_preds");

// 如果按住其中一个按钮,则进行训练
if (this.training != -1) {
  logits = infer();

  // 将当前图像添加到分类器
  this.knn.addExample(logits, this.training);
}

最后,一旦我们收集了一些网络摄像头图像,我们就可以使用以下代码测试我们的预测:

代码语言:javascript
复制
logits = infer();
const res = await this.knn.predictClass(logits, TOPK);
const prediction = classes[res.classIndex];

最后,您可以处理我们不再需要的网络摄像头数据:

代码语言:javascript
复制
// 完成后处理图像
image.dispose();
if (logits != null) {
  logits.dispose();
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
视频理解
视频理解是基于腾讯领先的 AI 技术和丰富的内容运营经验,对视频内容输出涵盖人物、场景、物体、事件的高精度、多维度的优质标签内容。通过对视频内容进行细粒度的结构化解析,应用于媒资系统管理、素材检索、内容运营等业务场景中。其中一款产品是媒体智能标签(Intelligent Media Label Detection)
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档