Pipeline
data:image/s3,"s3://crabby-images/538f8/538f83821790886e2d34cf2cf301b8375835f5d2" alt="概览图"
# 示例运行命令如下
python tools/run.py fit --config configs/cityscapes_darkzurich/refign_daformer.yaml --trainer.gpus [0] --trainer.precision 16
模型的入口,即run.py
其实是实例化了一个参数解析器,Lightning自己改进python原始的argparse,即LightningCLI
,这个参数解析器既可以从命令行,也可以使用yaml获取模型、数据集、trainer的参数。
fit
是训练+验证的子命令,还有validate
、test
、predict
,用来分离不同的训练阶段。整体的逻辑大概是LightningCLI
解析参数后,框架根据参数实例化trainer
,trainer
再根据fit
还是validate
等执行对应的训练逻辑,包括数据的处理和加载,模型的前向传播、反向传播、梯度更新等,最后利用Logger
来记录试验结果,利用Callback
来执行回调函数如EarlyStopping
等。
入口-LightningCLI
python内置的ArgumentParser使得代码的运行可以带有命令行参数,但是在大型项目中很不方便,因为参数会非常多,无论是在python代码里配置参数,还是在命令行指定参数都很复杂
from argparse import ArgumentParser
parser = ArgumentParser()
# Trainer arguments
parser.add_argument("--devices", type=int, default=2)
# Hyperparameters for the model
parser.add_argument("--layer_1_dim", type=int, default=128)
# Parse the user inputs and defaults (returns a argparse.Namespace)
args = parser.parse_args()
# Use the parsed arguments in your program
trainer = Trainer(devices=args.devices)
model = MyModel(layer_1_dim=args.layer_1_dim)
# 运行
python trainer.py --layer_1_dim 64 --devices 1
为此,pytorch Lightning提供了LightningCLI,使得一切参数可以用,简单实现CLI的方式
# main.py
from lightning.pytorch.cli import LightningCLI
# simple demo classes for your convenience
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
def cli_main():
cli = LightningCLI(DemoModel, BoringDataModule)
# note: don't call fit!!
if __name__ == "__main__":
cli_main()
# note: it is good practice to implement the CLI in a function and call it in the main if block
可以查看帮助python main.py --help
python main.py [subcommand] --help
subcommand指fit、validate、test、predict,根据你的训练需求来执行,这个命令能查看Lightning module和lightningDataModule的参数
With the Lightning CLI enabled, you can now change the parameters without touching your code:python main.py fit --model.learning_rate 0.1
,当然也可以使用config来覆盖
[!TIP]
The options that become available in the CLI are the
__init__
parameters of theLightningModule
andLightningDataModule
classes. Thus, to make hyperparameters configurable, just add them to your class’s__init__
. It is highly recommended that these parameters are described in the docstring so that the CLI shows them in the help. Also, the parameters should have accurate type hints so that the CLI can fail early and give understandable error messages when incorrect values are given.
使用cli混合不同的模型和数据集
这个要替换的是如下的切换模型或者数据集的逻辑:
# choose model
if args.model == "gan":
model = GAN(args.feat_dim)
elif args.model == "transformer":
model = Transformer(args.feat_dim)
# choose datamodule
if args.data == "MNIST":
datamodule = MNIST()
elif args.data == "imagenet":
datamodule = Imagenet()
# mix them!
trainer.fit(model, datamodule)
通过cli只需这样,省略掉model_class
这个参数,DataModule类似:
# main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class Model1(DemoModel):
def configure_optimizers(self):
print("⚡", "using Model1", "⚡")
return super().configure_optimizers()
class Model2(DemoModel):
def configure_optimizers(self):
print("⚡", "using Model2", "⚡")
return super().configure_optimizers()
cli = LightningCLI(datamodule_class=BoringDataModule)
[!TIP]
Instead of omitting the
datamodule_class
parameter, you can give a base class andsubclass_mode_data=True
. This will make the CLI only accept data modules that are a subclass of the given base class.
# use Model1
python main.py fit --data FakeDataset1
# use Model2
python main.py fit --data FakeDataset2
使用cli混合不同的优化器和lr_scheduler
Any custom subclass of torch.optim.Optimizer
can be used as an optimizer:
# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class LitAdam(torch.optim.Adam):
def step(self, closure):
print("⚡", "using LitAdam", "⚡")
super().step(closure)
class FancyAdam(torch.optim.Adam):
def step(self, closure):
print("⚡", "using FancyAdam", "⚡")
super().step(closure)
cli = LightningCLI(DemoModel, BoringDataModule)
学习率策略
# main.py
import torch
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
def step(self):
print("⚡", "using LitLRScheduler", "⚡")
super().step()
cli = LightningCLI(DemoModel, BoringDataModule)
接受Standard learning rate schedulers from torch.optim.lr_scheduler
,必须先指定optimizer才能指定学习率策略
Classes from any package
在前面的部分中,要选择的自定义类是在运行 LightningCLI 类的同一 python 文件中定义的。 要仅使用类名从任何包中选择类,请导入相应的包:
from lightning.pytorch.cli import LightningCLI
import my_code.models # noqa: F401
import my_code.data_modules # noqa: F401
import my_code.optimizers # noqa: F401
cli = LightningCLI()
Help for specific classes
当接受多个模型或数据集时,CLI 的主要帮助不包括其特定参数,需要额外指定,例如python main.py fit --model.help Model1
Control it all via YAML
随着项目变得越来越复杂,可配置选项的数量变得非常大,使得通过单独的命令行参数进行控制变得不方便。 为了解决这个问题,使用 LightningCLI 实现的 CLI 始终支持从配置文件接收输入。 配置文件使用的默认格式是 YAML。
python main.py fit --config config.yaml --trainer.max_epochs 100
单独的参数会覆盖config里面的设置
Lightning会自动在日志里保存配置参数,从而帮助模型的可复现性。这是通过SaveConfigCallback
实现的,该回调函数被自动加入Trainer
,可通过参数设置调节其行为。
python main.py fit --config lightning_logs/version_7/config.yaml
从头编写yaml配置可能很复杂,可以使用python main.py fit --print_config
打印配置再进行修改。对于混合模型,需要额外指定,如下python main.py fit --model DemoModel --print_config
配置项可以是简单的 Python 对象(例如 int 和 str),也可以是由 class_path
和init_args
参数组成的复杂对象。class_path
是指项目类的完整导入路径,而init_args
是要传递给类构造函数的参数。 例如:
# model.py
class MyModel(L.LightningModule):
def __init__(self, criterion: torch.nn.Module):
self.criterion = criterion
# config.yaml
model:
class_path: model.MyModel
init_args:
criterion:
class_path: torch.nn.CrossEntropyLoss
init_args:
reduction: mean
...
LightningCLI 在底层使用jsonargparse
来解析配置文件并自动创建对象,无需额外创建对象,也无需在cli.py
去导入对应的类了。
使用多个配置文件,按顺序解析。
# config_1.yaml
trainer:
num_epochs: 10
...
# config_2.yaml
trainer:
num_epochs: 20
...
# python main.py fit --config config_1.yaml --config config_2.yaml
# The value from the last config will be used
使用参数组
# trainer.yaml
num_epochs: 10
# model.yaml
out_dim: 7
# data.yaml
data_dir: ./data
# $ python main.py fit --trainer trainer.yaml --model model.yaml --data data.yaml [...]
Multiple models and/or datasets
CLI 可以编写为由导入路径和初始化参数指定模型和/或数据模块。 例如,使用实现为以下形式的工具:
cli = LightningCLI(MyModelBaseClass, MyDataModuleBaseClass, subclass_mode_model=True, subclass_mode_data=True)
可能配置如下:
model:
class_path: mycode.mymodels.MyModel
init_args:
decoder_layers:
- 2
- 4
encoder_layers: 12
data:
class_path: mycode.mydatamodules.MyDataModule
init_args:
...
trainer:
callbacks:
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
patience: 5
...
仅允许使用属于 MyModelBaseClass
子类的模型类,同样,仅允许使用 MyDataModuleBaseClass 的
子类。 如果给出 LightningModule
和 LightningDataModule
基类,则 CLI 将允许任何Lightning模块和数据模块。
Models with multiple submodules
许多用例需要有多个模块,每个模块都有自己的可配置选项。 使用 LightningCLI 处理此问题的一种可能方法是实现一个将每个子模块作为初始化参数的模块。 这称为依赖注入
,这是改进代码库解耦的好方法。
class MyMainModel(LightningModule):
def __init__(self, encoder: nn.Module, decoder: nn.Module):
"""Example encoder-decoder submodules model
Args:
encoder: Instance of a module for encoding
decoder: Instance of a module for decoding
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
配置如下:
#If the CLI is implemented as LightningCLI(MyMainModel)
model:
encoder:
class_path: mycode.myencoders.MyEncoder
init_args:
...
decoder:
class_path: mycode.mydecoders.MyDecoder
init_args:
...
# If subclass_mode_model=True
model:
class_path:mymodel.MyMainModel
init_args:
encoder:
class_path: mycode.myencoders.MyEncoder
init_args:
...
decoder:
class_path: mycode.mydecoders.MyDecoder
init_args:
...
Fixed optimizer and scheduler
在某些情况下,可能希望固定优化器和/或学习率调度器,而不是允许多个选择。为此,你可以通过子类化CLI来手动添加特定类的参数。
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(torch.optim.Adam)
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
# config.yaml 直接指定,移除了model
optimizer:
lr: 0.01
lr_scheduler:
gamma: 0.2
model:
...
trainer:
...
上面的写法是指定了哪个优化器,也可以用类似基类的方式加class_path
和init_args
class ConditioningLightningCLI(LightningCLI):
# OPTIMIZER_REGISTRY.classes就是获取被注册过的类
# nested_key是配置文件中最上层的命名空间的名字
# 向parse传递额外的参数
# link到model的optimizer_init和lr_scheduler_init
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(
OPTIMIZER_REGISTRY.classes, nested_key="optimizer", link_to="model.init_args.optimizer_init")
parser.add_lr_scheduler_args(
LR_SCHEDULER_REGISTRY.classes, nested_key="lr_scheduler", link_to="model.init_args.lr_scheduler_init")
@MODEL_REGISTRY
class DomainAdaptationSegmentationModel(pl.LightningModule):
def __init__(self,
optimizer_init: dict,
lr_scheduler_init: dict,
backbone: nn.Module,
···
# config.yaml
model:
class_path: models.DomainAdaptationSegmentationModel
init_args:
backbone,# 不用指定optimizer_init和lr_scheduler_init
···
optimizer:
class_path: torch.optim.AdamW
init_args:
···
lr_scheduler:
class_path: helpers.lr_scheduler.LinearWarmupPolynomialLR
init_args:
···
Data Processing-DataModule
pytorch数据处理包含五个步骤:
1.Download / tokenize / process.
- 这个步骤涉及从数据源获取原始数据,并进行必要的预处理,比如分词、解码等。
2.Clean and (maybe) save to disk.
- 清洗数据可能包括去除噪声、处理缺失值、标准化等。
- 清洗后的数据可以保存到磁盘,以便后续快速加载,而不是每次都重新处理。
3.Load inside Dataet
Dataset
是PyTorch中用于封装数据的类,它定义了如何从数据集中获取单个样本。- 在
LightningDataModule
中,你需要定义自己的Dataset
类,并在setup
方法中实例化它
4.Apply transforms(rotate, tokenize, etc…).
- PyTorch提供了
torchvision.transforms
(针对图像)和torchtext.transforms
(针对文本)等模块来简化这些操作。 - 在
LightningDataModule
中,你可以在train_dataloader
、val_dataloader
和test_dataloader
方法中应用这些转换。
5.Wrap inside a DataLoader.
DataLoader
是PyTorch中用于加载数据并提供批处理、多线程/多进程数据加载等功能的类。- 在
LightningDataModule
中,你需要定义train_dataloader
、val_dataloader
和test_dataloader
方法,分别返回用于训练、验证和测试的数据加载器。
经典的DataModule实现
import lightning as L
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(L.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def prepare_data(self):
# download data
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
# Assign test dataset for use in dataloader(s)
if stage == "test":
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict":
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)
prepare_data(self)
使用多个进程(分布式设置)下载和保存数据将导致数据损坏。 Lightning 确保仅在CPU上的单个进程中调用prepare_data()
,因此您可以安全地在其中添加下载逻辑。 在多节点训练的情况下,该钩子的执行取决于prepare_data_per_node
。 setup()
在prepare_data
之后调用,中间有一个屏障,确保一旦数据准备好并可供使用,所有进程都会继续进行设置。
所以该步骤适合:1下载数据2如nlp中的分词3保存到本地,如将分好的词保存到本地
setup(self, stage: Optional[str] = None)
执行在每个GPU上运行的数据操作。同时接受 stage
参数,用来分离设置逻辑如trainer.{fit,validate,test,predict}
.
train_dataloader(self)
使用train_dataloader()
方法生成训练数据加载器。 通常只需包装在设置中定义的数据集即可。这是Trainer的fit()
方法使用的数据加载器。
val_dataloader(self)
同上,是Trainer的fit()
和validate()
方法使用的数据加载器。
test_dataloader(self)
同上,是Trainer的test()
方法使用的数据加载器。
predict_dataloader(self)
同上,是Trainer的predict()
方法使用的数据加载器。
transfer_batch_to_device(self, batch, device, dataloader_id)
有了dataloader,框架使用该函数将对应数据送入device。如果DataLoader 返回自定义数据结构中的张量,重写这个钩子函数。默认的钩子函数(无需重写)支持的的数据类型包括: torch.Tensor
或任何实现 .to(...)
方法的,list dict tuple
。 对于其他任何数据类型,您需要定义如何将数据移动到目标设备(CPU、GPU、TPU……)。
最终dataloader提供给模型的也是打包好的一个batch。
def transfer_batch_to_device(self, batch, device, dataloader_idx):
#batch (Any) – A batch of data that needs to be transferred to a new device.
#device (device) – The target device as defined in PyTorch.
#dataloader_idx (int) – The index of the dataloader to which the batch belongs.
if isinstance(batch, CustomBatch):
# move all tensors in your custom data structure to the device
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
elif dataloader_idx == 0:
# skip device transfer for the first dataloader or anything you wish
pass
else:
batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
return batch
on_before_batch_transfer(self, batch, dataloader_idx)
Override to alter or apply batch augmentations to your batch before it is transferred to the device.
You can use self.trainer.training/testing/validating/predicting
so that you can add different logic as per your requirement.
def on_before_batch_transfer(self, batch, dataloader_idx):
batch['x'] = transforms(batch['x'])
return batch
使用datamodule的方式,trainer会自动依次调用prepare_data,setup,train_dataloader
等函数,trainer会免除复杂的训练逻辑,不使用trainer也可以手动执行dm的各个方法来自己写训练逻辑
dm = MNISTDataModule()
model = Model()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)
Model-Lightning Module
首先需要定义自己的网络模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning as L
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
def forward(self, x):
return self.l1(x)
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
def forward(self, x):
return self.l1(x)
然后可以使用LightningModule
来定义模块的交互、配置优化器,使用该模块的优点是减少了training loop的复杂训练逻辑
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def training_step(self, batch, batch_idx):
# 训练阶段模块如何交互
# training_step defines the train loop.
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
return loss #这个需要返回loss
def validation_step(self, batch, batch_idx):
# this is the validation loop
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
val_loss = F.mse_loss(x_hat, x)
self.log("val_loss", val_loss) # 记录loss就行
def test_step(self, batch, batch_idx):
# this is the test loop
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
test_loss = F.mse_loss(x_hat, x)
self.log("test_loss", test_loss) # 记录loss就行
def configure_optimizers(self):
# 定义优化器
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
实际使用时利用Trainer
来负责处理所有工程问题,并将所需的所有复杂性抽象化以实现规模扩展。
# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())
# train model
trainer = L.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
# Trainer代替的就是以下逻辑
autoencoder = LitAutoEncoder(Encoder(), Decoder())
optimizer = autoencoder.configure_optimizers()
for batch_idx, batch in enumerate(train_loader):# 自动循环
loss = autoencoder.training_step(batch, batch_idx)
loss.backward() #自动反向传播
optimizer.step() # 自动更新参数
optimizer.zero_grad() # 自动梯度清0
# test the model
trainer.test(model, dataloaders=DataLoader(test_set))
# train with both splits(训练加验证)
trainer = L.Trainer()
trainer.fit(model, train_loader, valid_loader) #把两个dataloader都加入
上述是模型的训练部分(训练、验证、测试),对于推理部分,简单方式如下:
model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)
上面的推理逻辑还是很复杂,LightningModule
提供predict_step
来解决:
class MyModel(LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self(batch)
# 推理如下
data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)
还可以添加更复杂的预处理或后处理逻辑,例如:
class LitMCdropoutModel(L.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# take average of `self.mc_iteration` iterations
pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
pred = torch.vstack(pred).mean(dim=0)
return pred
分布式推理
通过使用Lightning
的predict_step
,可以使用分布式推理通过BasePredictionWriter
.
import torch
from lightning.pytorch.callbacks import BasePredictionWriter
class CustomWriter(BasePredictionWriter):
def __init__(self, output_dir, write_interval):
super().__init__(write_interval)
self.output_dir = output_dir
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
# this will create N (num processes) files in `output_dir` each containing
# the predictions of it's respective rank
torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))
# optionally, you can also save `batch_indices` to get the information about the data index
# from your prediction data
torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
Coming up next
之后将学习Lightning
关于Trainer
以及如callback
、调试方法、模型优化、可视化等涉及训练逻辑和优化方法的相关技术。
Citation
About Me
个人博客:月源
知乎文章:月源
公众号:月源的算法仙蛊屋