MinkowskiPruning

MinkowskiPruning

class MinkowskiEngine.MinkowskiPruning

MinkowskiEngine.SparseTensor 中删除指定的坐标。

__init__()

初始化内部模块状态,由 nn.Module 和 ScriptModule 共享。

forward(input: MinkowskiSparseTensor.SparseTensor, mask: torch.Tensor)
参数

input (MinkowskiEnigne.SparseTensor): 要从中删除坐标的稀疏张量。

mask (torch.BoolTensor): 指定要保留哪个的掩码向量。 具有 False 的坐标将被删除。

返回

一个 MinkowskiEngine.SparseTensor,其中 C = 对应于 mask == True 的坐标,F = 来自 mask == True 的特征的副本。

示例

>>> # Define inputs
>>> input = SparseTensor(feats, coords=coords)
>>> # Any boolean tensor can be used as the filter
>>> mask = torch.rand(feats.size(0)) < 0.5
>>> pruning = MinkowskiPruning()
>>> output = pruning(input, mask)
training: bool