PyTorch 模型训练流程实战:数据加载、训练循环、验证调优

本文从实战角度,梳理 PyTorch 完整的模型训练流程。从数据加载、训练循环、验证调优到模型保存,结合代码模板,帮助你快速搭建可靠的训练管道。

Dataset 与 DataLoader

数据加载是训练的第一步。PyTorch 提供了 Dataset 和 DataLoader 两个核心抽象。

自定义 Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
def __init__(self, data_path, transform=None):
# 加载数据路径或直接加载数据到内存
self.data = self._load_data(data_path)
self.transform = transform

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
sample = self.data[idx]
if self.transform:
sample = self.transform(sample)
return sample

def _load_data(self, path):
# 实际项目中从文件、数据库等加载
pass

# 使用
dataset = MyDataset('data/train')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

DataLoader 参数详解

1
2
3
4
5
6
7
8
9
DataLoader(
dataset,
batch_size=32, # samples per batch
shuffle=True, # reshuffle the data at the start of every epoch
num_workers=4, # worker processes for parallel loading (>0 enables multiprocessing)
pin_memory=True, # page-locked host memory for faster CPU-to-GPU copies
drop_last=False, # whether to drop the final incomplete batch
batch_sampler=None, # custom batch sampler (when set, overrides the batching options)
)

num_workers 经验值

数据集大小 num_workers
小 (< 1GB) 0-2
中 (1-10GB) 4-8
大 (> 10GB) 8-16

pin_memory:启用后,CPU 数据先拷贝到固定内存,再传输到 GPU,能显著加速。

常见数据增强

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torchvision.transforms as transforms

# ImageNet channel statistics, shared by the train and test pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),       # random crop + resize
    transforms.RandomHorizontalFlip(),       # horizontal flip
    transforms.ColorJitter(brightness=0.2),  # color jitter
    transforms.ToTensor(),                   # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Bug fix: the original referenced undefined names `mean` and `std` here.
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

训练循环完整代码模板

核心组件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

class Trainer:
    """Generic supervised-classification training loop.

    Tracks per-epoch train/val loss and accuracy and checkpoints the model
    whenever the validation loss improves.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device, epochs):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.epochs = epochs
        self.best_val_loss = float('inf')
        self.history = {'train': [], 'val': []}

    def train_epoch(self):
        """Run one training epoch; return (mean loss, accuracy in %)."""
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Running statistics
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            # Progress-bar readout
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        return total_loss / len(self.train_loader), 100. * correct / total

    def validate(self):
        """Evaluate on the validation set; return (mean loss, accuracy in %)."""
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

        return total_loss / len(self.val_loader), 100. * correct / total

    def save_checkpoint(self, path):
        """Persist model/optimizer state plus bookkeeping to *path*.

        Bug fix: train() called self.save_checkpoint() but the method was
        never defined, raising AttributeError on the first improved epoch.
        """
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'history': self.history,
        }, path)

    def train(self):
        """Full training loop over self.epochs epochs."""
        for epoch in range(1, self.epochs + 1):
            print(f'\nEpoch {epoch}/{self.epochs}')

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            # Record history
            self.history['train'].append({'loss': train_loss, 'acc': train_acc})
            self.history['val'].append({'loss': val_loss, 'acc': val_acc})

            print(f'Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%')

            # Keep the best model by validation loss.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.save_checkpoint('best_model.pth')

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Hyperparameters the original snippet used without defining.
num_classes = 10
epochs = 20

# Data
train_dataset = MyDataset('data/train', transform=train_transform)
val_dataset = MyDataset('data/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=4)

# Model: ImageNet-pretrained ResNet-18 with a fresh classification head.
# `pretrained=True` is deprecated in torchvision; use the `weights` argument.
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# NOTE(review): this scheduler is created but never stepped — the Trainer
# above does not call scheduler.step(); wire it in before relying on it.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# Train
trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device, epochs)
trainer.train()

验证集与测试集处理

划分数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from sklearn.model_selection import train_test_split
import numpy as np
import math

def split_dataset(data, labels, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Split (data, labels) into train / val / test subsets.

    Returns:
        (train_data, val_data, test_data, train_labels, val_labels, test_labels)

    Raises:
        ValueError: if the three ratios do not sum to 1.
    """
    # Use isclose + ValueError: exact float equality is fragile (e.g.
    # 0.7 + 0.15 + 0.15 != 1.0 in binary floats) and `assert` is stripped
    # when Python runs with -O.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError('train_ratio + val_ratio + test_ratio must equal 1.0')

    # First split off the training portion.
    train_data, temp_data, train_labels, temp_labels = train_test_split(
        data, labels, test_size=(val_ratio + test_ratio), random_state=42
    )

    # Then split the remainder into validation and test.
    val_size = val_ratio / (val_ratio + test_ratio)
    val_data, test_data, val_labels, test_labels = train_test_split(
        temp_data, temp_labels, test_size=1 - val_size, random_state=42
    )

    return train_data, val_data, test_data, train_labels, val_labels, test_labels

测试集评估

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def test_evaluate(model, test_loader, device):
    """Evaluate *model* on *test_loader* and print classification metrics.

    Returns:
        (all_preds, all_targets, all_probs) as Python lists so callers can
        compute further metrics — the original collected them but returned None.
    """
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            probs = torch.softmax(output, dim=1)

            all_preds.extend(output.argmax(1).cpu().numpy())
            all_targets.extend(target.numpy())
            all_probs.extend(probs.cpu().numpy())

    # Imported lazily so prediction still works without sklearn installed
    # until the metrics are actually needed.
    from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

    print("Classification Report:")
    print(classification_report(all_targets, all_preds))

    print("Confusion Matrix:")
    print(confusion_matrix(all_targets, all_preds))

    # Multi-class one-vs-rest AUC. roc_auc_score raises ValueError when a
    # class is missing from the test set, so report instead of crashing.
    try:
        auc = roc_auc_score(all_targets, all_probs, multi_class='ovr')
        print(f"mAUC: {auc:.4f}")
    except ValueError as exc:
        print(f"mAUC unavailable: {exc}")

    return all_preds, all_targets, all_probs

早停策略实现

早停(Early Stopping)防止过拟合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class EarlyStopping:
    """Stop training when a monitored metric stops improving.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        min_delta: minimum change that counts as an improvement.
        mode: 'min' if lower is better (e.g. loss), 'max' if higher is better.
    """

    def __init__(self, patience=7, min_delta=0, mode='min'):
        if mode not in ('min', 'max'):
            # Previously any unknown mode silently behaved like 'max'.
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score):
        """Record *score*; return True when training should stop."""
        if self.best_score is None:
            # First observation becomes the baseline.
            self.best_score = score
            return False

        if self.mode == 'min':
            improved = score < (self.best_score - self.min_delta)
        else:
            improved = score > (self.best_score + self.min_delta)

        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                return True

        return False

集成到训练流程:

1
2
3
4
5
6
7
8
9
10
11
12
13
class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device, epochs):
        # ... other initialization as before ...
        # The `...` placeholder in the original signature was a SyntaxError;
        # the parameter list is spelled out so the snippet actually parses.
        self.early_stopping = EarlyStopping(patience=10, mode='min')

    def train(self):
        for epoch in range(1, self.epochs + 1):
            # ... run training and validation, producing val_loss ...
            val_loss, _ = self.validate()

            should_stop = self.early_stopping(val_loss)
            if should_stop:
                print(f"早停触发,epoch {epoch} 停止训练")
                break

模型保存与加载

两种保存方式的区别

1
2
3
4
5
6
7
8
# Option 1: save the whole model object (not recommended — this pickles the
# class itself, so loading depends on the exact source code layout)
torch.save(model, 'model.pth')
model = torch.load('model.pth')

# Option 2: save only the parameters via state_dict (recommended)
torch.save(model.state_dict(), 'model.pth')
model = MyModel()
model.load_state_dict(torch.load('model.pth'))

checkpoint 保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def save_checkpoint(self, path, epoch, **kwargs):
    """Save the full training state (model, optimizer, bookkeeping) to *path*.

    Any extra keyword arguments (e.g. scheduler state) are stored alongside.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'best_val_loss': self.best_val_loss,
        'history': self.history,
        **kwargs
    }
    torch.save(checkpoint, path)

def load_checkpoint(self, path):
    """Restore the state saved by save_checkpoint; return the stored epoch."""
    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
    checkpoint = torch.load(path, map_location=self.device)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.best_val_loss = checkpoint['best_val_loss']
    self.history = checkpoint['history']
    return checkpoint['epoch']

混合精度训练基础

torch.cuda.amp 提供自动混合精度,显著减少显存占用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torch.cuda.amp import autocast, GradScaler

class TrainerWithAMP:
    """Training-loop variant using automatic mixed precision (CUDA).

    NOTE(review): torch.cuda.amp is deprecated in recent PyTorch in favor of
    torch.amp with device_type='cuda' — confirm against your torch version.
    """

    def __init__(self, model, train_loader, criterion, optimizer, device):
        # The original `def __init__(self, ...)` was a SyntaxError; the
        # parameters are spelled out so the snippet actually runs.
        self.model = model.to(device)
        self.train_loader = train_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.scaler = GradScaler()  # scales the loss to avoid FP16 underflow

    def train_epoch(self):
        self.model.train()
        for data, target in self.train_loader:
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass runs with automatic FP16/FP32 casting.
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)

            # Scaled backward, then unscale-and-step via the scaler.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

使用 torchrun 或 python -m torch.distributed.launch 启用多卡训练:

1
torchrun --nproc_per_node=4 train.py

实战:图像分类训练完整流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def main():
    """End-to-end CIFAR-10 fine-tuning example using the Trainer above."""
    import torchvision

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=train_transform
    )
    val_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True,
        transform=test_transform
    )

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=256, num_workers=4)

    # Model: ImageNet-pretrained ResNet-18, head replaced for 10 classes.
    model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, 10)

    # Training configuration
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    # NOTE(review): OneCycleLR must be stepped once per *batch*, but the
    # Trainer shown in this article never calls scheduler.step() — wire it
    # into train_epoch before relying on the schedule.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-2, epochs=20, steps_per_epoch=len(train_loader)
    )

    # Train
    trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device, epochs=20)
    trainer.train()

    # Evaluate on the held-out set
    test_evaluate(model, val_loader, device)

if __name__ == '__main__':
    main()

总结

完整的 PyTorch 训练流程要点:

组件 作用 关键点
Dataset 数据抽象 __len__, __getitem__
DataLoader 批次加载 batch_size, num_workers, pin_memory
训练循环 前向→loss→反向→更新 zero_grad() 在 backward 前
验证 评估模型 model.eval(), torch.no_grad()
早停 防止过拟合 patience 合理设置
Checkpoint 保存训练状态 保存 optimizer, epoch 等
AMP 节省显存 需要 CUDA 支持

掌握这套模板,能够快速搭建各种深度学习任务的训练管道。