使用 Pytorch-Lightning 的多 GPU

目前,MinkowskiEngine 支持通过数据并行化进行多 GPU 训练。在数据并行化中,我们有一组小批量数据,这些数据将被馈送到一组网络副本中。

目前有多个多 GPU 示例,但建议使用 DistributedDataParallel (DDP) 和 Pytorch-lightning 示例。

在本教程中,我们将介绍 pytorch-lightning 多 GPU 示例。 我们将首先介绍如何定义数据集、数据加载器和网络。

数据集

让我们创建一个读取点云的虚拟数据集。

class DummyDataset(Dataset):

    ...

    def __getitem__(self, i):
        filename = self.filenames[i]
        pcd = o3d.io.read_point_cloud(filename)
        quantized_coords, feats = ME.utils.sparse_quantize(
            np.array(pcd.points, dtype=np.float32),
            np.array(pcd.colors, dtype=np.float32),
            quantization_size=self.voxel_size,
        )
        random_labels = torch.zeros(len(feats))
        return {
            "coordinates": quantized_coords,
            "features": feats,
            "labels": random_labels,
        }

为了在 pytorch 数据加载器中使用它,我们需要一个自定义的整理函数,该函数将所有坐标合并为与 MinkowskiEngine 稀疏张量格式兼容的批处理坐标。

def minkowski_collate_fn(list_data):
    r"""
    Collation function for MinkowskiEngine.SparseTensor that creates batched
    cooordinates given a list of dictionaries.
    """
    coordinates_batch, features_batch, labels_batch = ME.utils.sparse_collate(
        [d["coordinates"] for d in list_data],
        [d["features"] for d in list_data],
        [d["labels"] for d in list_data],
        dtype=torch.float32,
    )
    return {
        "coordinates": coordinates_batch,
        "features": features_batch,
        "labels": labels_batch,
    }

...

dataset = torch.utils.data.DataLoader(
       DummyDataset("train", voxel_size=voxel_size),
       batch_size=batch_size,
       collate_fn=minkowski_collate_fn,
       shuffle=True,
    )

网络

接下来,我们可以为分割定义一个简单的虚拟网络。

class DummyNetwork(nn.Module):
    def __init__(self, in_channels, out_channels, D=3):
        nn.Module.__init__(self)
        self.net = nn.Sequential(
            ME.MinkowskiConvolution(in_channels, 32, 3, dimension=D),
            ME.MinkowskiBatchNorm(32),
            ME.MinkowskiReLU(),
            ME.MinkowskiConvolution(32, 64, 3, stride=2, dimension=D),
            ME.MinkowskiBatchNorm(64),
            ME.MinkowskiReLU(),
            ME.MinkowskiConvolutionTranspose(64, 32, 3, stride=2, dimension=D),
            ME.MinkowskiBatchNorm(32),
            ME.MinkowskiReLU(),
            ME.MinkowskiConvolution(32, out_channels, kernel_size=1, dimension=D),
        )

    def forward(self, x):
        return self.net(x)

Lightning 模块

Pytorch lightning 是一个高级的 pytorch 包装器,它简化了许多样板代码。 pytorch lightning 的核心是 LightningModule,它为训练框架提供了一个包装器。 在本节中,我们提供了一个扩展 LightningModule 的分割训练包装器。

class MinkowskiSegmentationModule(LightningModule):
    r"""
    Segmentation Module for MinkowskiEngine.
    """

    def __init__(
        self,
        model,
        optimizer_name="SGD",
        lr=1e-3,
        weight_decay=1e-5,
        voxel_size=0.05,
        batch_size=12,
        val_batch_size=6,
        train_num_workers=4,
        val_num_workers=2,
    ):
        super().__init__()
        for name, value in vars().items():
            if name != "self":
                setattr(self, name, value)

        self.criterion = nn.CrossEntropyLoss()

    def train_dataloader(self):
        return DataLoader(
            DummyDataset("train", voxel_size=self.voxel_size),
            batch_size=self.batch_size,
            collate_fn=minkowski_collate_fn,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            DummyDataset("val", voxel_size=self.voxel_size),
            batch_size=self.val_batch_size,
            collate_fn=minkowski_collate_fn,
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        stensor = ME.SparseTensor(
            coordinates=batch["coordinates"], features=batch["features"]
        )
        # Must clear cache at regular interval
        if self.global_step % 10 == 0:
            torch.cuda.empty_cache()
        return self.criterion(self(stensor).F, batch["labels"].long())

    def validation_step(self, batch, batch_idx):
        stensor = ME.SparseTensor(
            coordinates=batch["coordinates"], features=batch["features"]
        )
        return self.criterion(self(stensor).F, batch["labels"].long())

    def configure_optimizers(self):
        return SGD(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

请注意,我们定期清除缓存。 这是因为输入稀疏张量在每次迭代中都有不同的长度,如果当前批次大于已分配的内存,这将导致新的内存分配。 这种重复的内存分配将导致内存不足错误,因此必须定期清除 GPU 缓存。 如果您的 GPU 内存较小,请尝试在更短的时间间隔内清除缓存。

def training_step(self, batch, batch_idx):
    ...
    # Must clear cache at a regular interval
    if self.global_step % 10 == 0:
        torch.cuda.empty_cache()
    return self.criterion(self(stensor).F, batch["labels"].long())

训练

创建分割模块后,我们可以使用以下代码训练网络。

pl_module = MinkowskiSegmentationModule(DummyNetwork(3, 20, D=3), lr=args.lr)
trainer = Trainer(max_epochs=args.max_epochs, gpus=num_devices, accelerator="ddp")
trainer.fit(pl_module)

在这里,如果我们设置 num_devices 为可用 GPUS 的数量,pytorch-lightning 将自动使用 pytorch DistributedDataParallel 在所有 GPU 上训练网络。