PyTorch 自动求导机制详解:从 backward 到计算图

PyTorch 的自动求导(Autograd)是其核心特性之一,也是深度学习框架最重要的功能。理解 Autograd 的工作原理,对于调试模型、排查梯度问题、优化训练过程至关重要。

计算图的概念

PyTorch 使用动态计算图(Dynamic Computation Graph)来记录算子操作,从而支持自动求导。

什么是计算图

计算图是一种有向无环图(DAG),节点表示变量(张量),边表示运算:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x * y + x**2

# 隐式构建的计算图:
# x ──► ◄── y
# \ /
# mul
# │
# x ──► mul2 ◄── x
# │
# add ──► z

在这个例子中,z 是关于 xy 的函数:z = x*y + x²

叶子 节点与非叶子节点

  • 叶子节点requires_grad=True 且由用户直接创建的张量
  • 非叶子节点:由运算自动生成的张量
1
2
3
4
5
6
7
a = torch.tensor([1.0], requires_grad=True)  # leaf
b = a * 2 # non-leaf
c = b * 2 # non-leaf

print(a.is_leaf) # True
print(b.is_leaf) # False
print(c.is_leaf) # False

backward() 只在叶子节点上填充梯度,非叶子节点的梯度在计算完成后会被清空以节省内存。

require_grad 与叶子节点

张量的 requires_grad 属性

requires_grad 决定张量是否参与求导:

1
2
3
4
5
6
7
8
# 默认不追踪梯度
x = torch.tensor([1.0]) # requires_grad=False

# 追踪梯度
x = torch.tensor([1.0], requires_grad=True)

# 或者后续修改
x.requires_grad_(True)

梯度的追踪范围

所有参与运算的张量(requires_grad=True)都会记录操作:

1
2
3
4
5
6
7
8
9
10
11
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)

# z 依赖于 x 和 y
z = x * y + x**2

# 计算 ∂z/∂x 和 ∂z/∂y
z.backward()

print(x.grad) # dy/dx = y + 2x = 2 + 2 = 4
print(y.grad) # dz/dy = x = 1

梯度的关闭

1
2
3
4
5
6
7
8
9
# 方式1:创建时指定
x = torch.tensor([1.0], requires_grad=False)

# 方式2:临时关闭(上下文管理器)
with torch.no_grad():
y = x * 2 # 这个运算不追踪梯度

# 方式3:detach()
y = x.detach() # 返回一个共享内存但 requires_grad=False 的张量

backward() 的调用机制

backward() 是触发自动求导的核心方法。

基本用法

1
2
3
4
5
6
7
8
9
10
11
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 # y = 4
y.backward() # dy/dx = 2x = 4
print(x.grad) # tensor([4.])

# 对于标量,不需要传参数
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 2 # z = 16
z.backward() # dz/dz = 1(隐式),然后链式法则
print(x.grad) # dz/dx = dz/dy * dy/dx = 2y * 2x = 2*4*2*2 = 32

非标量输出

当输出不是标量时,需要传入梯度参数(与输出同形的张量):

1
2
3
4
5
6
7
8
9
# 输出是向量,不能直接 backward
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2 # y = [2, 4, 6]

# y.backward() # 错误!

# 正确做法:传入等长的梯度向量(通常全 1)
y.backward(torch.ones_like(y))
print(x.grad) # tensor([2., 2., 2.])

实际场景:loss.backward() 通常在损失是标量时使用;如果损失是向量(如多个样本的损失),需要在优化器中先做平均或求和。

backward 的流程

1
2
3
4
5
6
7
8
9
10
11
12
13
# backward 内部简化逻辑
def backward(self, gradient=None):
# 1. 如果没指定 gradient,默认全 1
if gradient is None:
gradient = torch.ones_like(self)

# 2. 从当前节点开始,沿着计算图反向传播
for node in reversed(topological_order):
# 计算当前节点的梯度
grad = gradient * node.grad_fn(node.inputs)
# 累加到叶子节点
if node.is_leaf:
node.grad += grad

梯度累积与清零

梯度是累加的

PyTorch 默认累加梯度,而不是覆盖:

1
2
3
4
5
6
7
8
9
10
11
x = torch.tensor([1.0], requires_grad=True)

# 第一次前向
y1 = x * 2
y1.backward()
print(x.grad) # tensor([2.])

# 第二次前向(不清零梯度)
y2 = x * 3
y2.backward()
print(x.grad) # tensor([5.]) # 累加:2 + 3 = 5

正确的梯度清零

在每个训练步骤开始前,需要清零梯度:

1
2
3
4
5
6
7
8
# 错误方式
model.zero_grad() # 清零所有参数的梯度

# 或者
x.grad = None # 只清零 x 的梯度

# 不要这样
x.grad = 0 # 这是赋值,不是清零

典型的训练循环:

1
2
3
4
5
6
for data, target in dataloader:
optimizer.zero_grad() # Step 1: 清零梯度
output = model(data) # Step 2: 前向传播
loss = criterion(output, target)
loss.backward() # Step 3: 反向传播
optimizer.step() # Step 4: 更新参数

retain_graph 参数

backward() 默认在执行后释放计算图。如果需要多次求导,需要保留计算图:

1
2
3
4
5
6
7
8
9
10
11
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
z = y ** 2

# 第一次求导
z.backward(retain_graph=True) # 保留计算图
print(x.grad) # dz/dx = 4*x³ = 4

# 第二次求导(梯度会累加)
z.backward()
print(x.grad) # 8 = 4 + 4

hook 函数

PyTorch 提供两种 hook 用于检查和修改前向/反向传播过程。

register_hook

在张量上注册梯度 hook:

1
2
3
4
5
6
7
8
9
10
x = torch.tensor([1.0], requires_grad=True)
y = x * 2

# 注册 hook
hook_handle = y.register_hook(lambda grad: print(f"梯度: {grad}"))

y.backward()
# 输出: 梯度: tensor([1.])

hook_handle.remove() # 移除 hook

常用场景:梯度裁剪

1
2
3
4
5
6
7
8
# 梯度裁剪:限制梯度范数
for p in model.parameters():
torch.nn.utils.clip_grad_norm_(p, max_norm=1.0)

# 梯度清零后打印
for p in model.parameters():
if p.grad is not None:
print(f"{p.name}: grad_norm = {p.grad.norm()}")

前向 hook 与反向 hook

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 前向 hook
def forward_hook(module, input, output):
print(f"前向传播: {module.__class__.__name__}")
print(f" 输入: {[i.shape for i in input]}")
print(f" 输出: {output.shape}")

# 反向 hook
def backward_hook(module, grad_input, grad_output):
print(f"反向传播: {module.__class__.__name__}")
print(f" grad_output: {grad_output}")

# 注册到层
layer = torch.nn.Linear(10, 5)
layer.register_forward_hook(forward_hook)
layer.register_full_backward_hook(backward_hook)

自定义 autograd 函数

对于没有内置导数的自定义运算,可以通过继承 Function 实现自动求导:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch.autograd import Function

class StepFunction(Function):
"""阶跃函数:x > 0 返回 1,否则返回 0"""

@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x) # 保存反向传播需要的张量
return (x > 0).float()

@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
# 阶跃函数的梯度是 delta 函数,这里返回 0(近似)
return grad_output * 0

# 使用自定义函数
x = torch.tensor([1.0, -1.0, 2.0], requires_grad=True)
y = StepFunction.apply(x)
print(y) # tensor([1., 0., 1.])

y.sum().backward()
print(x.grad) # tensor([0., 0., 0.])

常见求导问题排查

问题1:梯度为 None

1
2
3
4
5
6
7
x = torch.tensor([1.0])  # requires_grad=False(默认)
y = x * 2
y.backward()
print(x.grad) # None

# 解决:确保 requires_grad=True
x = torch.tensor([1.0], requires_grad=True)

问题2:非叶子节点梯度不可访问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
x = torch.tensor([1.0], requires_grad=True)
y = x * 2
z = y * 2

print(y.grad) # None(默认不保留非叶子节点梯度)

# 解决1:使用 retain_grad()
y.retain_grad()
z.backward()
print(y.grad) # tensor([2.])

# 解决2:将 y 设置为叶子节点
y = x * 2
y = y.detach().requires_grad_()

问题3:in-place 操作导致的问题

1
2
3
4
5
6
7
8
9
10
x = torch.tensor([1.0], requires_grad=True)

# 错误:in-place 修改叶子节点
# x += 1 # 报错!

# 正确:使用新的张量
x = x + 1

# 或者使用 data 属性(不推荐)
x.data += 1

问题4:梯度爆炸/消失

1
2
3
4
5
6
7
8
9
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 梯度消失检测
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm()
if grad_norm > 10:
print(f"警告:{name} 梯度范数过大: {grad_norm}")

总结

PyTorch 自动求导的核心要点:

  1. 计算图:动态构建有向无环图,记录运算过程
  2. 叶子节点:用户创建且 requires_grad=True 的张量
  3. **backward()**:从输出反向传播计算梯度
  4. 梯度清零:每个训练步骤前必须清零
  5. hook 函数:检查和修改前向/反向传播
  6. 自定义 autograd:通过 Function 类扩展

理解这些机制,能够更好地调试模型和排查问题。