前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >SwanLab快速上手(开源在线训练可视化工具,Wandb国产平替)

SwanLab快速上手(开源在线训练可视化工具,Wandb国产平替)

原创
作者头像
用户9029617
发布2024-05-20 20:06:06
1220
发布2024-05-20 20:06:06

1. SwanLab介绍

swanlab地址:swanlab.cn

Github地址:https://github.com/SwanHubX/SwanLab

SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能,让您可以快速跟踪ML实验、可视化过程、分享给同伴。

相比于Tensorboard,SwanLab记录的信息更全、使用更方便。相比于Wandb,则访问速度更快,更方便于在国内使用,与主创团队交流更容易。

核心特性列表:

  1. 📊实验指标与超参数跟踪: 极简的代码嵌入您的机器学习pipeline,跟踪记录训练关键指标
    • 自由的超参数与实验配置记录
    • 支持的元数据类型:标量指标、图像、音频、文本、...
    • 支持的图表类型:折线图、媒体图(图像、音频、文本)、...
    • 自动记录:控制台logging、GPU硬件、Git信息、Python解释器、Python库列表、代码目录
  2. ⚡️全面的框架集成: PyTorch、Tensorflow、PyTorch Lightning、🤗HuggingFace Transformers、MMEngine、OpenAI、ZhipuAI、Hydra、...
  3. 📦组织实验: 集中式仪表板,快速管理多个项目与实验,通过整体视图速览训练全局
  4. 🆚比较结果: 通过在线表格与对比图表比较不同实验的超参数和结果,挖掘迭代灵感
  5. 👥在线协作: 您可以与团队进行协作式训练,支持将实验实时同步在一个项目下,您可以在线查看团队的训练记录,基于结果发表看法与建议
  6. ✉️分享结果: 复制和发送持久的URL来共享每个实验,方便地发送给伙伴,或嵌入到在线笔记中
  7. 💻支持自托管: 支持不联网使用,自托管的社区版同样可以查看仪表盘与管理实验

上图:

多个实验指标对比:

多个实验对比
多个实验对比

管理多个项目:

项目管理
项目管理

记录超参数和指标:

超参数与指标记录
超参数与指标记录

表格管理实验:

实验表格
实验表格

查看训练的日志:

查看训练的日志:
查看训练的日志:

2. 快速上手

参考链接:SwanLab快速开始

2.1 安装swanlab

代码语言:bash
复制
pip install swanlab

如果下载太慢,可以使用以下命令从清华源下载:

代码语言:bash
复制
pip install swanlab -i https://pypi.tuna.tsinghua.edu.cn/simple

2.2 登录账号

如果你之前没有注册用SwanLab账号,那么去官网注册一个,然后记一下你的API Key:

在命令行输入:

代码语言:bash
复制
swanlab login

然后将你的API Key粘贴进去,按回车,然后就登录完成了(后面无需再次登录)。

2.3 运行案例程序

代码语言:python
复制
import swanlab
import random

# 初始化一个新的swanlab run类来跟踪这个脚本
swanlab.init(
  # 设置将记录此次运行的项目信息
  project="my-awesome-project",
  
  # 跟踪超参数和运行元数据
  config={
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10
  }
)

# 模拟训练
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
  acc = 1 - 2 ** -epoch - random.random() / epoch - offset
  loss = 2 ** -epoch + random.random() / epoch + offset

  # 向swanlab上传训练指标
  swanlab.log({"acc": acc, "loss": loss})

运行后,你会在最开始看到swanlab链接:

点击链接就可以看到可视化效果,或者访问SwanLab官网,会在你的账号下看到新的实验。

3. 训练一个MNIST手写体识别

代码语言:python
复制
import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import swanlab

# CNN网络构建
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1,28x28
        self.conv1 = nn.Conv2d(1, 10, 5)  # 10, 24x24
        self.conv2 = nn.Conv2d(10, 20, 3)  # 128, 10x10
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)  # 24
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  # 12
        out = self.conv2(out)  # 10
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out


# 捕获并可视化前20张图像
def log_images(loader, num_images=16):
    images_logged = 0
    logged_images = []
    for images, labels in loader:
        # images: batch of images, labels: batch of labels
        for i in range(images.shape[0]):
            if images_logged < num_images:
                # 使用swanlab.Image将图像转换为wandb可视化格式
                logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))
                images_logged += 1
            else:
                break
        if images_logged >= num_images:
            break
    swanlab.log({"MNIST-Preview": logged_images})


if __name__ == "__main__":

    # 初始化swanlab
    run = swanlab.init(
        project="MNIST-example",
        experiment_name="ConvNet",
        description="Train ConvNet on MNIST dataset.",
        config={
            "model": "CNN",
            "optim": "Adam",
            "lr": 0.001,
            "batch_size": 512,
            "num_epochs": 10,
            "train_dataset_num": 55000,
            "val_dataset_num": 5000,
        },
    )

    # 设置训练机、验证集和测试集
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    train_dataset, val_dataset = utils.data.random_split(
        dataset, [run.config.train_dataset_num, run.config.val_dataset_num]
    )

    train_loader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
    val_loader = utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

    # 初始化模型、损失函数和优化器
    model = ConvNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=run.config.lr)

    # (可选)看一下数据集的前16张图像
    log_images(train_loader, 16)

    # 开始训练
    for epoch in range(1, run.config.num_epochs):
        swanlab.log({"train/epoch": epoch})
        # 训练循环
        for iter, batch in enumerate(train_loader):
            x, y = batch
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            print(
                f"Epoch [{epoch}/{run.config.num_epochs}], Iteration [{iter + 1}/{len(train_loader)}], Loss: {loss.item()}"
            )

            if iter % 20 == 0:
                swanlab.log({"train/loss": loss.item()}, step=(epoch - 1) * len(train_loader) + iter)

        # 每4个epoch验证一次
        if epoch % 2 == 0:
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch in val_loader:
                    x, y = batch
                    output = model(x)
                    _, predicted = torch.max(output, 1)
                    total += y.size(0)
                    correct += (predicted == y).sum().item()

            accuracy = correct / total
            swanlab.log({"val/accuracy": accuracy})

效果:

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. SwanLab介绍
  • 2. 快速上手
    • 2.1 安装swanlab
      • 2.2 登录账号
        • 2.3 运行案例程序
        • 3. 训练一个MNIST手写体识别
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档