使用 Pytorch-Lightning 的多 GPU¶
目前,MinkowskiEngine 支持通过数据并行化进行多 GPU 训练。在数据并行化中,我们有一组小批量数据,这些数据将被馈送到一组网络副本中。
目前有多个多 GPU 示例,但建议使用 DistributedDataParallel (DDP) 和 Pytorch-lightning 示例。
在本教程中,我们将介绍 pytorch-lightning 多 GPU 示例。 我们将首先介绍如何定义数据集、数据加载器和网络。
数据集¶
让我们创建一个读取点云的虚拟数据集。
class DummyDataset(Dataset):
...
def __getitem__(self, i):
filename = self.filenames[i]
pcd = o3d.io.read_point_cloud(filename)
quantized_coords, feats = ME.utils.sparse_quantize(
np.array(pcd.points, dtype=np.float32),
np.array(pcd.colors, dtype=np.float32),
quantization_size=self.voxel_size,
)
random_labels = torch.zeros(len(feats))
return {
"coordinates": quantized_coords,
"features": feats,
"labels": random_labels,
}
为了在 pytorch 数据加载器中使用它,我们需要一个自定义的整理函数,该函数将所有坐标合并为与 MinkowskiEngine 稀疏张量格式兼容的批处理坐标。
def minkowski_collate_fn(list_data):
r"""
Collation function for MinkowskiEngine.SparseTensor that creates batched
cooordinates given a list of dictionaries.
"""
coordinates_batch, features_batch, labels_batch = ME.utils.sparse_collate(
[d["coordinates"] for d in list_data],
[d["features"] for d in list_data],
[d["labels"] for d in list_data],
dtype=torch.float32,
)
return {
"coordinates": coordinates_batch,
"features": features_batch,
"labels": labels_batch,
}
...
dataset = torch.utils.data.DataLoader(
DummyDataset("train", voxel_size=voxel_size),
batch_size=batch_size,
collate_fn=minkowski_collate_fn,
shuffle=True,
)
网络¶
接下来,我们可以为分割定义一个简单的虚拟网络。
class DummyNetwork(nn.Module):
def __init__(self, in_channels, out_channels, D=3):
nn.Module.__init__(self)
self.net = nn.Sequential(
ME.MinkowskiConvolution(in_channels, 32, 3, dimension=D),
ME.MinkowskiBatchNorm(32),
ME.MinkowskiReLU(),
ME.MinkowskiConvolution(32, 64, 3, stride=2, dimension=D),
ME.MinkowskiBatchNorm(64),
ME.MinkowskiReLU(),
ME.MinkowskiConvolutionTranspose(64, 32, 3, stride=2, dimension=D),
ME.MinkowskiBatchNorm(32),
ME.MinkowskiReLU(),
ME.MinkowskiConvolution(32, out_channels, kernel_size=1, dimension=D),
)
def forward(self, x):
return self.net(x)
Lightning 模块¶
Pytorch lightning 是一个高级的 pytorch 包装器,它简化了许多样板代码。 pytorch lightning 的核心是 LightningModule
,它为训练框架提供了一个包装器。 在本节中,我们提供了一个扩展 LightningModule
的分割训练包装器。
class MinkowskiSegmentationModule(LightningModule):
r"""
Segmentation Module for MinkowskiEngine.
"""
def __init__(
self,
model,
optimizer_name="SGD",
lr=1e-3,
weight_decay=1e-5,
voxel_size=0.05,
batch_size=12,
val_batch_size=6,
train_num_workers=4,
val_num_workers=2,
):
super().__init__()
for name, value in vars().items():
if name != "self":
setattr(self, name, value)
self.criterion = nn.CrossEntropyLoss()
def train_dataloader(self):
return DataLoader(
DummyDataset("train", voxel_size=self.voxel_size),
batch_size=self.batch_size,
collate_fn=minkowski_collate_fn,
shuffle=True,
)
def val_dataloader(self):
return DataLoader(
DummyDataset("val", voxel_size=self.voxel_size),
batch_size=self.val_batch_size,
collate_fn=minkowski_collate_fn,
)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
stensor = ME.SparseTensor(
coordinates=batch["coordinates"], features=batch["features"]
)
# Must clear cache at regular interval
if self.global_step % 10 == 0:
torch.cuda.empty_cache()
return self.criterion(self(stensor).F, batch["labels"].long())
def validation_step(self, batch, batch_idx):
stensor = ME.SparseTensor(
coordinates=batch["coordinates"], features=batch["features"]
)
return self.criterion(self(stensor).F, batch["labels"].long())
def configure_optimizers(self):
return SGD(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
请注意,我们定期清除缓存。 这是因为输入稀疏张量在每次迭代中都有不同的长度,如果当前批次大于已分配的内存,这将导致新的内存分配。 这种重复的内存分配将导致内存不足错误,因此必须定期清除 GPU 缓存。 如果您的 GPU 内存较小,请尝试在更短的时间间隔内清除缓存。
def training_step(self, batch, batch_idx):
...
# Must clear cache at a regular interval
if self.global_step % 10 == 0:
torch.cuda.empty_cache()
return self.criterion(self(stensor).F, batch["labels"].long())
训练¶
创建分割模块后,我们可以使用以下代码训练网络。
pl_module = MinkowskiSegmentationModule(DummyNetwork(3, 20, D=3), lr=args.lr)
trainer = Trainer(max_epochs=args.max_epochs, gpus=num_devices, accelerator="ddp")
trainer.fit(pl_module)
在这里,如果我们设置 num_devices
为可用 GPUS 的数量,pytorch-lightning 将自动使用 pytorch DistributedDataParallel 在所有 GPU 上训练网络。