第四十二章 快速进入PyTorch(工具)

本文详细介绍了如何使用PyTorchLightning框架组织代码,包括安装、定义LightningModule、配置Trainer、自动与手动优化、部署预测、数据处理、日志记录、回调机制、DataModules和调试工具。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

一、简介

二、安装 PyTorch Lightning

三、定义 LightningModule

3.1 SYSTEM VS MODEL

3.2 FORWARD vs TRAINING_STEP

三、配置 Lightning Trainer

四、基本特性

4.1 Manual vs automatic optimization

4.1.1 自动优化 (Automatic optimization)

4.1.1 手动优化 (Manual optimization)

4.2 Predict or Deploy

4.2.1 选项一 —— 子模型 (Sub-models)

4.2.2 选项二 —— 前馈 (Forward)

4.2.2 选项三 —— 生产 (Production)

4.3 Using CPUs/GPUs/TPUs

4.4 Checkpoints

4.5 Data flow

4.6 Logging

4.7 Optional extensions

4.7.1 回调 (Callbacks)

4.7.2 LightningDataModules

4.8 Debugging

五、其他炫酷特性


项目地址:https://github.com/PyTorchLightning/pytorch-lightning


一、简介

本指南将展示如何分两步将 PyTorch 代码组织成 Lightning。

使用 PyTorch Lightning 组织代码,可以使代码:

  • 保留所有灵活性(这全是纯 PyTorch),但去除了大量样板(boilerplate)
  • 将研究代码与工程解耦,更具可读性
  • 更容易复现
  • 通过自动化大多数训练循环和棘手的工程设计,减少了出错的可能性
  • 可扩展到任何硬件而无需更改模型

二、安装 PyTorch Lightning

pip 安装:

pip install pytorch-lightning

或 conda 安装:

conda install pytorch-lightning -c conda-forge

或在 conda 虚拟环境下安装:

conda activate my_env

pip install pytorch-lightning

在新源文件中导入以下将用到的依赖:

import os

import torch

from torch import nn

import torch.nn.functional as F

from torchvision.datasets import MNIST

from torchvision import transforms

from torch.utils.data import DataLoader

import pytorch_lightning as pl

from torch.utils.data import random_split

三、定义 LightningModule

class LitAutoEncoder(pl.LightningModule):

 

    def __init__(self):

        super().__init__()

        self.encoder = nn.Sequential(

            nn.Linear(28*28, 64),

            nn.ReLU(),

            nn.Linear(64, 3)

        )

        self.decoder = nn.Sequential(

            nn.Linear(3, 64),

            nn.ReLU(),

            nn.Linear(64, 28*28)

        )

 

    def forward(self, x):

        # in lightning, forward defines the prediction/inference actions

        embedding = self.encoder(x)

        return embedding

 

    def training_step(self, batch, batch_idx):

        # training_step defined the train loop.

        # It is independent of forward

        x, y = batch

        x = x.view(x.size(0), -1)

        z = self.encoder(x)

        x_hat = self.decoder(z)

        loss = F.mse_loss(x_hat, x)

        # Logging to TensorBoard by default

        self.log('train_loss', loss)

        return loss

 

    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        return optimizer

3.1 SYSTEM VS MODEL

注意,LightningModule 定义了一个系统 (system) 而不仅仅是一个模型 (model):

img

关于系统 (system) 的例子还有:

在内部,LightningModule 仍只是一个 torch.nn.Module,它将所有研究代码分组到一个文件中以使其自成一体:

  • The Train loop
  • The Validation loop
  • The Test loop
  • The Model or system of Models
  • The Optimizer

可以通过覆盖 Available Callback hooks 中找到的 20+ 个 hooks 中的任意一个,来自定义任何训练部分 (如反向传播):

class LitAutoEncoder(pl.LightningModule):

 

    def backward(self, loss, optimizer, optimizer_idx):

        loss.backward()

3.2 FORWARD vs TRAINING_STEP

在 Lightning 中,我们将训练与推理分开。training_step 定义了完整的训练循环。鼓励用户用 forward 定义推理行为。

例如,在这种情况下,可以定义自动编码器以充当嵌入提取器 (embedding extractor):

def forward(self, x):

    embeddings = self.encoder(x)

    return embeddings

当然,没有什么可以阻止你在 training_step 中使用 forward:

def training_step(self, batch, batch_idx):

    ...

    z = self(x)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小酒馆燃着灯

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值