TVP

# [机器学习|理论&实践] 深度解析迁移学习：知识的精妙转移

6225

### 代码示例

#### 场景：图像分类任务

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms

# 定义数据预处理
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}

# 加载数据集
data_dir = 'path/to/your/dataset'
image_datasets = {x: datasets.ImageFolder(f'{data_dir}/{x}', data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True) for x in ['train', 'val']}

# 加载预训练的ResNet

model = models.resnet18(pretrained=True)

# 固定模型参数
for param in model.parameters():
param.requires_grad = False

# 修改分类层，适应新的任务
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(image_datasets['train'].classes))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# 训练模型
num_epochs = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for inputs, labels in dataloaders['train']:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)

epoch_loss = running_loss / len(image_datasets['train'])
print(f'Epoch {epoch}/{num_epochs}, Loss: {epoch_loss:.4f}')

# 模型评估
model.eval()
corrects = 0
total = 0
with torch.no_grad():
for inputs, labels in dataloaders['val']:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
corrects += (predicted == labels).sum().item()

accuracy = corrects / total
print(f'Validation Accuracy: {accuracy * 100:.2f}%')

### 结论

0 条评论

LV.

• 什么是迁移学习？
• 迁移学习的应用领域
• 1. 计算机视觉
• 2. 自然语言处理
• 3. 医疗影像分析
• 迁移学习的核心思想
• 1. 源领域和目标领域
• 2. 共享知识
• 3. 适应性学习
• 拓展迁移学习的视野
• 代码示例
• 场景：图像分类任务
• 结论
相关产品与服务
NLP 服务
NLP 服务（Natural Language Process，NLP）深度整合了腾讯内部的 NLP 技术，提供多项智能文本处理和文本生成能力，包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
领券