PyTorch 的自动求导(Autograd)是其核心特性之一,也是深度学习框架最重要的功能。理解 Autograd 的工作原理,对于调试模型、排查梯度问题、优化训练过程至关重要。
计算图的概念 PyTorch 使用动态计算图 (Dynamic Computation Graph)来记录算子操作,从而支持自动求导。
什么是计算图 计算图是一种有向无环图(DAG),节点表示变量(张量),边表示运算:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 import torchx = torch.tensor([1.0 ], requires_grad=True ) y = torch.tensor([2.0 ], requires_grad=True ) z = x * y + x**2
在这个例子中,z 是关于 x 和 y 的函数:z = x*y + x²
叶子 节点与非叶子节点
叶子节点 :requires_grad=True 且由用户直接创建的张量
非叶子节点 :由运算自动生成的张量
1 2 3 4 5 6 7 a = torch.tensor([1.0 ], requires_grad=True ) b = a * 2 c = b * 2 print (a.is_leaf) print (b.is_leaf) print (c.is_leaf)
backward() 只在叶子节点上填充梯度,非叶子节点的梯度在计算完成后会被清空以节省内存。
require_grad 与叶子节点 张量的 requires_grad 属性 requires_grad 决定张量是否参与求导:
1 2 3 4 5 6 7 8 x = torch.tensor([1.0 ]) x = torch.tensor([1.0 ], requires_grad=True ) x.requires_grad_(True )
梯度的追踪范围 所有参与运算的张量(requires_grad=True)都会记录操作:
1 2 3 4 5 6 7 8 9 10 11 x = torch.tensor([1.0 ], requires_grad=True ) y = torch.tensor([2.0 ], requires_grad=True ) z = x * y + x**2 z.backward() print (x.grad) print (y.grad)
梯度的关闭 1 2 3 4 5 6 7 8 9 x = torch.tensor([1.0 ], requires_grad=False ) with torch.no_grad(): y = x * 2 y = x.detach()
backward() 的调用机制 backward() 是触发自动求导的核心方法。
基本用法 1 2 3 4 5 6 7 8 9 10 11 x = torch.tensor([2.0 ], requires_grad=True ) y = x ** 2 y.backward() print (x.grad) x = torch.tensor([2.0 ], requires_grad=True ) y = x ** 2 z = y ** 2 z.backward() print (x.grad)
非标量输出 当输出不是标量时,需要传入梯度参数 (与输出同形的张量):
1 2 3 4 5 6 7 8 9 x = torch.tensor([1.0 , 2.0 , 3.0 ], requires_grad=True ) y = x * 2 y.backward(torch.ones_like(y)) print (x.grad)
实际场景:loss.backward() 通常在损失是标量时使用;如果损失是向量(如多个样本的损失),需要在优化器中先做平均或求和。
backward 的流程 1 2 3 4 5 6 7 8 9 10 11 12 13 def backward (self, gradient=None ): if gradient is None : gradient = torch.ones_like(self) for node in reversed (topological_order): grad = gradient * node.grad_fn(node.inputs) if node.is_leaf: node.grad += grad
梯度累积与清零 梯度是累加的 PyTorch 默认累加梯度 ,而不是覆盖:
1 2 3 4 5 6 7 8 9 10 11 x = torch.tensor([1.0 ], requires_grad=True ) y1 = x * 2 y1.backward() print (x.grad) y2 = x * 3 y2.backward() print (x.grad)
正确的梯度清零 在每个训练步骤开始前,需要清零梯度:
1 2 3 4 5 6 7 8 model.zero_grad() x.grad = None x.grad = 0
典型的训练循环:
1 2 3 4 5 6 for data, target in dataloader: optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step()
retain_graph 参数 backward() 默认在执行后释放计算图。如果需要多次求导,需要保留计算图:
1 2 3 4 5 6 7 8 9 10 11 x = torch.tensor([1.0 ], requires_grad=True ) y = x ** 2 z = y ** 2 z.backward(retain_graph=True ) print (x.grad) z.backward() print (x.grad)
hook 函数 PyTorch 提供两种 hook 用于检查和修改前向/反向传播过程。
register_hook 在张量上注册梯度 hook:
1 2 3 4 5 6 7 8 9 10 x = torch.tensor([1.0 ], requires_grad=True ) y = x * 2 hook_handle = y.register_hook(lambda grad: print (f"梯度: {grad} " )) y.backward() hook_handle.remove()
常用场景:梯度裁剪 1 2 3 4 5 6 7 8 for p in model.parameters(): torch.nn.utils.clip_grad_norm_(p, max_norm=1.0 ) for p in model.parameters(): if p.grad is not None : print (f"{p.name} : grad_norm = {p.grad.norm()} " )
前向 hook 与反向 hook 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def forward_hook (module, input , output ): print (f"前向传播: {module.__class__.__name__} " ) print (f" 输入: {[i.shape for i in input ]} " ) print (f" 输出: {output.shape} " ) def backward_hook (module, grad_input, grad_output ): print (f"反向传播: {module.__class__.__name__} " ) print (f" grad_output: {grad_output} " ) layer = torch.nn.Linear(10 , 5 ) layer.register_forward_hook(forward_hook) layer.register_full_backward_hook(backward_hook)
自定义 autograd 函数 对于没有内置导数的自定义运算,可以通过继承 Function 实现自动求导:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 from torch.autograd import Functionclass StepFunction (Function ): """阶跃函数:x > 0 返回 1,否则返回 0""" @staticmethod def forward (ctx, x ): ctx.save_for_backward(x) return (x > 0 ).float () @staticmethod def backward (ctx, grad_output ): x, = ctx.saved_tensors return grad_output * 0 x = torch.tensor([1.0 , -1.0 , 2.0 ], requires_grad=True ) y = StepFunction.apply(x) print (y) y.sum ().backward() print (x.grad)
常见求导问题排查 问题1:梯度为 None 1 2 3 4 5 6 7 x = torch.tensor([1.0 ]) y = x * 2 y.backward() print (x.grad) x = torch.tensor([1.0 ], requires_grad=True )
问题2:非叶子节点梯度不可访问 1 2 3 4 5 6 7 8 9 10 11 12 13 14 x = torch.tensor([1.0 ], requires_grad=True ) y = x * 2 z = y * 2 print (y.grad) y.retain_grad() z.backward() print (y.grad) y = x * 2 y = y.detach().requires_grad_()
问题3:in-place 操作导致的问题 1 2 3 4 5 6 7 8 9 10 x = torch.tensor([1.0 ], requires_grad=True ) x = x + 1 x.data += 1
问题4:梯度爆炸/消失 1 2 3 4 5 6 7 8 9 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0 ) for name, param in model.named_parameters(): if param.grad is not None : grad_norm = param.grad.norm() if grad_norm > 10 : print (f"警告:{name} 梯度范数过大: {grad_norm} " )
总结 PyTorch 自动求导的核心要点:
计算图 :动态构建有向无环图,记录运算过程
叶子节点 :用户创建且 requires_grad=True 的张量
**backward()**:从输出反向传播计算梯度
梯度清零 :每个训练步骤前必须清零
hook 函数 :检查和修改前向/反向传播
自定义 autograd :通过 Function 类扩展
理解这些机制,能够更好地调试模型和排查问题。