嘘~ 正在从服务器偷取页面 . . .

CV模型训练流程总结-Pytorch Lightning(一)


Pipeline

概览图
# 示例运行命令如下
python tools/run.py fit --config configs/cityscapes_darkzurich/refign_daformer.yaml --trainer.gpus [0] --trainer.precision 16

模型的入口,即run.py其实是实例化了一个参数解析器,Lightning自己改进python原始的argparse,即LightningCLI,这个参数解析器既可以从命令行,也可以使用yaml获取模型、数据集、trainer的参数。

fit是训练+验证的子命令,还有validatetestpredict,用来分离不同的训练阶段。整体的逻辑大概是LightningCLI解析参数后,框架根据参数实例化trainertrainer再根据fit还是validate等执行对应的训练逻辑,包括数据的处理和加载,模型的前向传播、反向传播、梯度更新等,最后利用Logger来记录试验结果,利用Callback来执行回调函数如EarlyStopping等。

入口-LightningCLI

python内置的ArgumentParser使得代码的运行可以带有命令行参数,但是在大型项目中很不方便,因为参数会非常多,无论是在python代码里配置参数,还是在命令行指定参数都很复杂

from argparse import ArgumentParser
parser = ArgumentParser()
# Trainer arguments
parser.add_argument("--devices", type=int, default=2)
# Hyperparameters for the model
parser.add_argument("--layer_1_dim", type=int, default=128)
# Parse the user inputs and defaults (returns a argparse.Namespace)
args = parser.parse_args()
# Use the parsed arguments in your program
trainer = Trainer(devices=args.devices)
model = MyModel(layer_1_dim=args.layer_1_dim)
# 运行
python trainer.py --layer_1_dim 64 --devices 1

为此,pytorch Lightning提供了LightningCLI,使得一切参数可以用,简单实现CLI的方式

# main.py
from lightning.pytorch.cli import LightningCLI
# simple demo classes for your convenience
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
def cli_main():
    cli = LightningCLI(DemoModel, BoringDataModule)
    # note: don't call fit!!
if __name__ == "__main__":
    cli_main()
    # note: it is good practice to implement the CLI in a function and call it in the main if block

可以查看帮助python main.py --help

python main.py [subcommand] --help subcommand指fit、validate、test、predict,根据你的训练需求来执行,这个命令能查看Lightning module和lightningDataModule的参数

With the Lightning CLI enabled, you can now change the parameters without touching your code:python main.py fit --model.learning_rate 0.1,当然也可以使用config来覆盖

[!TIP]

The options that become available in the CLI are the __init__ parameters of the LightningModule and LightningDataModule classes. Thus, to make hyperparameters configurable, just add them to your class’s __init__. It is highly recommended that these parameters are described in the docstring so that the CLI shows them in the help. Also, the parameters should have accurate type hints so that the CLI can fail early and give understandable error messages when incorrect values are given.

使用cli混合不同的模型和数据集

这个要替换的是如下的切换模型或者数据集的逻辑:

# choose model
if args.model == "gan":
    model = GAN(args.feat_dim)
elif args.model == "transformer":
    model = Transformer(args.feat_dim)
# choose datamodule
if args.data == "MNIST":
    datamodule = MNIST()
elif args.data == "imagenet":
    datamodule = Imagenet()
# mix them!
trainer.fit(model, datamodule)

通过cli只需这样,省略掉model_class这个参数,DataModule类似:

# main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class Model1(DemoModel):
    def configure_optimizers(self):
        print("⚡", "using Model1", "⚡")
        return super().configure_optimizers()
class Model2(DemoModel):
    def configure_optimizers(self):
        print("⚡", "using Model2", "⚡")
        return super().configure_optimizers()
cli = LightningCLI(datamodule_class=BoringDataModule)

[!TIP]

Instead of omitting the datamodule_class parameter, you can give a base class and subclass_mode_data=True. This will make the CLI only accept data modules that are a subclass of the given base class.

# use Model1
python main.py fit --data FakeDataset1
# use Model2
python main.py fit --data FakeDataset2

使用cli混合不同的优化器和lr_scheduler

Any custom subclass of torch.optim.Optimizer can be used as an optimizer:

# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class LitAdam(torch.optim.Adam):
    def step(self, closure):
        print("⚡", "using LitAdam", "⚡")
        super().step(closure)
class FancyAdam(torch.optim.Adam):
    def step(self, closure):
        print("⚡", "using FancyAdam", "⚡")
        super().step(closure)
cli = LightningCLI(DemoModel, BoringDataModule)

学习率策略

# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
    def step(self):
        print("⚡", "using LitLRScheduler", "⚡")
        super().step()
cli = LightningCLI(DemoModel, BoringDataModule)

接受Standard learning rate schedulers from torch.optim.lr_scheduler,必须先指定optimizer才能指定学习率策略

Classes from any package

在前面的部分中,要选择的自定义类是在运行 LightningCLI 类的同一 python 文件中定义的。 要仅使用类名从任何包中选择类,请导入相应的包:

from lightning.pytorch.cli import LightningCLI
import my_code.models  # noqa: F401
import my_code.data_modules  # noqa: F401
import my_code.optimizers  # noqa: F401
cli = LightningCLI()

Help for specific classes

当接受多个模型或数据集时,CLI 的主要帮助不包括其特定参数,需要额外指定,例如python main.py fit --model.help Model1

Control it all via YAML

随着项目变得越来越复杂,可配置选项的数量变得非常大,使得通过单独的命令行参数进行控制变得不方便。 为了解决这个问题,使用 LightningCLI 实现的 CLI 始终支持从配置文件接收输入。 配置文件使用的默认格式是 YAML

python main.py fit --config config.yaml --trainer.max_epochs 100单独的参数会覆盖config里面的设置

Lightning会自动在日志里保存配置参数,从而帮助模型的可复现性。这是通过SaveConfigCallback实现的,该回调函数被自动加入Trainer,可通过参数设置调节其行为。

python main.py fit --config lightning_logs/version_7/config.yaml

从头编写yaml配置可能很复杂,可以使用python main.py fit --print_config打印配置再进行修改。对于混合模型,需要额外指定,如下python main.py fit --model DemoModel --print_config

配置项可以是简单的 Python 对象(例如 int 和 str),也可以是由 class_pathinit_args参数组成的复杂对象。class_path是指项目类的完整导入路径,而init_args是要传递给类构造函数的参数。 例如:

# model.py
class MyModel(L.LightningModule):
    def __init__(self, criterion: torch.nn.Module):
        self.criterion = criterion
# config.yaml
model:
  class_path: model.MyModel
  init_args:
    criterion:
      class_path: torch.nn.CrossEntropyLoss
      init_args:
        reduction: mean
    ...

LightningCLI 在底层使用jsonargparse解析配置文件自动创建对象,无需额外创建对象,也无需在cli.py导入对应的类了。

使用多个配置文件,按顺序解析

# config_1.yaml
trainer:
  num_epochs: 10
  ...
# config_2.yaml
trainer:
  num_epochs: 20 
  ...
# python main.py fit --config config_1.yaml --config config_2.yaml 
# The value from the last config will be used

使用参数组

# trainer.yaml
num_epochs: 10
# model.yaml
out_dim: 7
# data.yaml
data_dir: ./data
# $ python main.py fit --trainer trainer.yaml --model model.yaml --data data.yaml [...]

Multiple models and/or datasets

CLI 可以编写为由导入路径和初始化参数指定模型和/或数据模块。 例如,使用实现为以下形式的工具:

cli = LightningCLI(MyModelBaseClass, MyDataModuleBaseClass, subclass_mode_model=True, subclass_mode_data=True)

可能配置如下:

model:
  class_path: mycode.mymodels.MyModel
  init_args:
    decoder_layers:
    - 2
    - 4
    encoder_layers: 12
data:
  class_path: mycode.mydatamodules.MyDataModule
  init_args:
    ...
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        patience: 5
    ...

仅允许使用属于 MyModelBaseClass 子类的模型类,同样,仅允许使用 MyDataModuleBaseClass 的子类。 如果给出 LightningModule LightningDataModule 基类,则 CLI 将允许任何Lightning模块和数据模块。

Models with multiple submodules

许多用例需要有多个模块,每个模块都有自己的可配置选项。 使用 LightningCLI 处理此问题的一种可能方法是实现一个将每个子模块作为初始化参数的模块。 这称为依赖注入,这是改进代码库解耦的好方法。

class MyMainModel(LightningModule):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        """Example encoder-decoder submodules model
        Args:
            encoder: Instance of a module for encoding
            decoder: Instance of a module for decoding
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

配置如下:

#If the CLI is implemented as LightningCLI(MyMainModel)
model:
  encoder:
    class_path: mycode.myencoders.MyEncoder
    init_args:
      ...
  decoder:
    class_path: mycode.mydecoders.MyDecoder
    init_args:
      ...
# If subclass_mode_model=True
model:
    class_path:mymodel.MyMainModel
    init_args:
    encoder:
      class_path: mycode.myencoders.MyEncoder
      init_args:
        ...
    decoder:
      class_path: mycode.mydecoders.MyDecoder
      init_args:
        ...

Fixed optimizer and scheduler

在某些情况下,可能希望固定优化器和/或学习率调度器,而不是允许多个选择。为此,你可以通过子类化CLI来手动添加特定类的参数。

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_optimizer_args(torch.optim.Adam)
        parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
# config.yaml 直接指定,移除了model
optimizer:
  lr: 0.01
lr_scheduler:
  gamma: 0.2
model:
  ...
trainer:
  ...

上面的写法是指定了哪个优化器,也可以用类似基类的方式加class_pathinit_args

class ConditioningLightningCLI(LightningCLI):
    # OPTIMIZER_REGISTRY.classes就是获取被注册过的类
    # nested_key是配置文件中最上层的命名空间的名字
    # 向parse传递额外的参数
    # link到model的optimizer_init和lr_scheduler_init
    def add_arguments_to_parser(self, parser):
        parser.add_optimizer_args(
            OPTIMIZER_REGISTRY.classes, nested_key="optimizer", link_to="model.init_args.optimizer_init")
        parser.add_lr_scheduler_args(
            LR_SCHEDULER_REGISTRY.classes, nested_key="lr_scheduler", link_to="model.init_args.lr_scheduler_init")
@MODEL_REGISTRY
class DomainAdaptationSegmentationModel(pl.LightningModule):
    def __init__(self,
                 optimizer_init: dict,
                 lr_scheduler_init: dict,
                 backbone: nn.Module,
                 ···
# config.yaml
model:
  class_path: models.DomainAdaptationSegmentationModel
  init_args:
       backbone,# 不用指定optimizer_init和lr_scheduler_init
    ···
optimizer:
  class_path: torch.optim.AdamW
  init_args:
        ···
lr_scheduler:
  class_path: helpers.lr_scheduler.LinearWarmupPolynomialLR
  init_args:
    ···

Data Processing-DataModule

pytorch数据处理包含五个步骤:

1.Download / tokenize / process.

  • 这个步骤涉及从数据源获取原始数据,并进行必要的预处理,比如分词、解码等。

2.Clean and (maybe) save to disk.

  • 清洗数据可能包括去除噪声、处理缺失值、标准化等。
  • 清洗后的数据可以保存到磁盘,以便后续快速加载,而不是每次都重新处理。

3.Load inside Dataet

  • Dataset是PyTorch中用于封装数据的类,它定义了如何从数据集中获取单个样本。
  • LightningDataModule中,你需要定义自己的Dataset类,并在setup方法中实例化它

4.Apply transforms(rotate, tokenize, etc…).

  • PyTorch提供了torchvision.transforms(针对图像)和torchtext.transforms(针对文本)等模块来简化这些操作。
  • LightningDataModule中,你可以在train_dataloaderval_dataloadertest_dataloader方法中应用这些转换。

5.Wrap inside a DataLoader.

  • DataLoader是PyTorch中用于加载数据并提供批处理、多线程/多进程数据加载等功能的类。
  • LightningDataModule中,你需要定义train_dataloaderval_dataloadertest_dataloader方法,分别返回用于训练、验证和测试的数据加载器。

经典的DataModule实现

import lightning as L
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    def prepare_data(self):
              # download data
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

prepare_data(self)

使用多个进程(分布式设置)下载和保存数据将导致数据损坏。 Lightning 确保仅在CPU上的单个进程中调用prepare_data(),因此您可以安全地在其中添加下载逻辑。 在多节点训练的情况下,该钩子的执行取决于prepare_data_per_nodesetup()prepare_data之后调用,中间有一个屏障,确保一旦数据准备好并可供使用,所有进程都会继续进行设置。

所以该步骤适合:1下载数据2如nlp中的分词3保存到本地,如将分好的词保存到本地

setup(self, stage: Optional[str] = None)

执行在每个GPU上运行的数据操作。同时接受 stage参数,用来分离设置逻辑如trainer.{fit,validate,test,predict}.

train_dataloader(self)

使用train_dataloader()方法生成训练数据加载器。 通常只需包装在设置中定义的数据集即可。这是Trainer的fit()方法使用的数据加载器。

val_dataloader(self)

同上,是Trainer的fit()validate()方法使用的数据加载器。

test_dataloader(self)

同上,是Trainer的test()方法使用的数据加载器。

predict_dataloader(self)

同上,是Trainer的predict()方法使用的数据加载器。

transfer_batch_to_device(self, batch, device, dataloader_id)

有了dataloader,框架使用该函数将对应数据送入device。如果DataLoader 返回自定义数据结构中的张量,重写这个钩子函数。默认的钩子函数(无需重写)支持的的数据类型包括: torch.Tensor 或任何实现 .to(...)方法的,list dict tuple 。 对于其他任何数据类型,您需要定义如何将数据移动到目标设备(CPU、GPU、TPU……)。

最终dataloader提供给模型的也是打包好的一个batch

def transfer_batch_to_device(self, batch, device, dataloader_idx):
  #batch (Any) – A batch of data that needs to be transferred to a new device.
  #device (device) – The target device as defined in PyTorch.
  #dataloader_idx (int) – The index of the dataloader to which the batch belongs.
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch

on_before_batch_transfer(self, batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

You can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch

使用datamodule的方式,trainer会自动依次调用prepare_data,setup,train_dataloader等函数,trainer会免除复杂的训练逻辑,不使用trainer也可以手动执行dm的各个方法来自己写训练逻辑

dm = MNISTDataModule()
model = Model()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)

Model-Lightning Module

首先需要定义自己的网络模块

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning as L
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

然后可以使用LightningModule来定义模块的交互、配置优化器,使用该模块的优点是减少了training loop的复杂训练逻辑

class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    def training_step(self, batch, batch_idx):
          # 训练阶段模块如何交互
        # training_step defines the train loop.
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss #这个需要返回loss
     def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss) # 记录loss就行
     def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss) # 记录loss就行
    def configure_optimizers(self):
          # 定义优化器
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

实际使用时利用Trainer来负责处理所有工程问题,并将所需的所有复杂性抽象化以实现规模扩展。

# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())
# train model
trainer = L.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
# Trainer代替的就是以下逻辑
autoencoder = LitAutoEncoder(Encoder(), Decoder())
optimizer = autoencoder.configure_optimizers()
for batch_idx, batch in enumerate(train_loader):# 自动循环
    loss = autoencoder.training_step(batch, batch_idx)
    loss.backward() #自动反向传播
    optimizer.step() # 自动更新参数
    optimizer.zero_grad() # 自动梯度清0
# test the model
trainer.test(model, dataloaders=DataLoader(test_set))
# train with both splits(训练加验证)
trainer = L.Trainer()
trainer.fit(model, train_loader, valid_loader) #把两个dataloader都加入

上述是模型的训练部分(训练、验证、测试),对于推理部分,简单方式如下:

model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
    y_hat = model(x)

上面的推理逻辑还是很复杂,LightningModule提供predict_step来解决:

class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)
# 推理如下
data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)

还可以添加更复杂的预处理或后处理逻辑,例如:

class LitMCdropoutModel(L.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration
    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()
        # take average of `self.mc_iteration` iterations
        pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred

分布式推理

通过使用Lightningpredict_step,可以使用分布式推理通过BasePredictionWriter.

import torch
from lightning.pytorch.callbacks import BasePredictionWriter
class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir
    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of it's respective rank
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))
        # optionally, you can also save `batch_indices` to get the information about the data index
        # from your prediction data
        torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)

Coming up next

之后将学习Lightning关于Trainer以及如callback、调试方法、模型优化、可视化等涉及训练逻辑优化方法的相关技术。

Citation

Lightning github地址

官方文档

AEDA代码

Pytorch Lightning 完全攻略

About Me

个人博客:月源

知乎文章:月源

公众号:月源的算法仙蛊屋


文章作者: dch
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 dch !
评论
  目录