torch.no_grad() 是 PyTorch 中的一个上下文管理器,用于在进入该上下文时禁用梯度计算。这在你只关心评估模型,而不是训练模型时非常有用,因为它可以显著减少内存使用并加速计算。
当你在 torch.no_grad() 上下文管理器中执行张量操作时,PyTorch 不会为这些操作计算梯度。这意味着不会在 .grad 属性中累积梯度,并且操作会更快地执行。
使用torch.no_grad()
import torch
# 创建一个需要梯度的张量
x = torch.tensor([1.0], requires_grad=True)
# 使用 no_grad() 上下文管理器
with torch.no_grad():
y = x * 2
y.backward()
print(x.grad)
输出:
RuntimeError Traceback (most recent call last