Point Transformer MindSpore
This repository is a MindSpore-based implementation of Point Transformer (Zhao et al.), covering the classification task on the ModelNet40 dataset and the segmentation task on the ShapeNet dataset.
Since the official code for the paper has not been released, this repository mainly draws on other unofficial re-implementations, such as qq456cvb, Point-Transformer, among others. For the evaluation metrics, the unofficial re-implementations are used as the baseline for comparison.
Introduction to Point Transformer
The authors draw on the Transformer, which has achieved great success in natural language processing, and apply it to 3D point cloud processing tasks such as classification and semantic segmentation. The Point Transformer architecture is shown in the figure below.
The paper designs the Transformer, TransitionDown, and TransitionUp modules, and combines them with MLPs to achieve strong results on the ModelNet40 and ShapeNet datasets. The Transformer block exchanges information between local feature vectors and produces a new feature vector for every point as output, taking into account both the content of the feature vectors and their layout in 3D space. The key function of the TransitionDown block is to reduce the number of points in the point set as needed, while keeping the sampled points well distributed. The main function of the TransitionUp block is to map features from a subsampled point set back onto the higher-resolution point set: each input feature first passes through a linear layer, followed by batch normalization and ReLU, and is then mapped onto the higher-resolution point set via trilinear interpolation. The structures of the Transformer, TransitionDown, and TransitionUp blocks are shown in the figure below.
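For reference, the vector self-attention computed by the Transformer block can be written as follows (following Zhao et al.; mapping the symbols onto the code later in this README is our reading of this re-implementation):

    y_i = \sum_{x_j \in X(i)} \rho( \gamma( \varphi(x_i) - \psi(x_j) + \delta ) ) \odot ( \alpha(x_j) + \delta ),  with  \delta = \theta(p_i - p_j)

Here X(i) is the kNN neighborhood of point i; \varphi, \psi, and \alpha are pointwise linear projections (w_qs, w_ks, w_vs in the code); \gamma is the attention MLP (fc_gamma); \theta is the positional-encoding MLP (fc_delta); and \rho is a softmax normalization over the neighbors.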
Environment Setup
python = 3.8
MindSpore = 1.7
numpy
tqdm
yaml
argparse
os
sys
Datasets: Overview and Download
The Point Transformer model mainly uses two datasets: ModelNet40 and ShapeNet.
The ModelNet40 dataset can be downloaded via this link.
The directory structure of ModelNet40 is shown below.
|--./ModelNet40/
    |--airplane
        |--airplane_0001.txt
        |--airplane_0002.txt
        |--.......
    |--bathtub
        |--bathtub_0001.txt
        |--bathtub_0002.txt
        |--.......
    |--filelist.txt
    |--modelnet40_shape_names.txt
    |--modelnet40_test.txt
    |--modelnet40_train.txt
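A minimal loading sketch for a single sample, assuming the normal-resampled release of ModelNet40 in which each line of a sample file holds comma-separated x, y, z coordinates followed by the surface normal (the file path is illustrative):

import numpy as np

# Each row: x, y, z, nx, ny, nz (comma-separated) -- an assumption based on
# the normal-resampled release of ModelNet40.
points = np.loadtxt("./ModelNet40/airplane/airplane_0001.txt", delimiter=",").astype(np.float32)
xyz, normals = points[:, :3], points[:, 3:6]
print(xyz.shape, normals.shape)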
The ShapeNet dataset can be downloaded via this link.
The directory structure of ShapeNet is shown below.
|--./shapenetcore_partanno_segmentation_benchmark_v0_normal/
    |--02691156
        |--1a04e3eab45ca15dd86060f189eb133.txt
        |--1a32f10b20170883663e90eaf6b4ca52.txt
        |--.......
    |--02773838
    |--02954340
    |--.......
    |--train_test_split
    |--synsetoffset2category.txt
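Similarly, a hedged loading sketch for one ShapeNet sample, assuming each line stores space-separated x, y, z coordinates, the surface normal, and a per-point part label (the file path is illustrative):

import numpy as np

# Each row: x y z nx ny nz part_label -- an assumption based on the
# *_v0_normal release of the ShapeNet part benchmark.
data = np.loadtxt("./shapenetcore_partanno_segmentation_benchmark_v0_normal/"
                  "02691156/1a04e3eab45ca15dd86060f189eb133.txt").astype(np.float32)
xyz, normals = data[:, :3], data[:, 3:6]
labels = data[:, 6].astype(np.int32)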
Key Module Implementations in MindSpore
Transformer
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
import numpy as np

# Square_distance, index_points, PointNet2SetAbstraction, and
# PointNetFeaturePropagation are helpers defined elsewhere in this repository.

class Transformer(nn.Cell):
    """
    The transformer block integrates the self-attention layer, linear projections
    that reduce dimensionality and accelerate processing, and a residual connection.
    The information aggregation adapts both to the content of the feature vectors
    and their layout in 3D.
    Input:
        xyz(Tensor) : Point coordinates of shape (B, N, C), where B is the batch size,
            N is the number of points, and C is the number of coordinate channels (3).
        features(Tensor) : Point features of shape (B, N, D), where D = d_points.
    Output:
        output(Tensor) : The output features of shape (B, N, D).
        attn(Tensor) : The attention weights.
    Argument:
        d_points : The dimension of the input and output features.
        d_model : The dimension of the internal attention features.
        k : The number of nearest neighbors (KNN).
    Example:
        xyz = mindspore.Tensor(np.random.rand(32, 1024, 3), dtype = mindspore.float32)
        features = mindspore.Tensor(np.random.rand(32, 1024, 32), dtype = mindspore.float32)
        transformer = Transformer(32, 512, 16)
        output, attn = transformer(xyz, features)
        print(output.shape)
        print(attn.shape)
    """
    def __init__(self, d_points, d_model, k):
        super(Transformer, self).__init__()
        self.fc1 = nn.Dense(d_points, d_model, weight_init = "Uniform", bias_init = "Uniform")
        self.fc2 = nn.Dense(d_model, d_points, weight_init = "Uniform", bias_init = "Uniform")
        # Positional-encoding MLP theta(p_i - p_j).
        self.fc_delta = nn.SequentialCell(
            nn.Dense(3, d_model, weight_init = "Uniform", bias_init = "Uniform"),
            nn.ReLU(),
            nn.Dense(d_model, d_model, weight_init = "Uniform", bias_init = "Uniform")
        )
        # Attention MLP gamma.
        self.fc_gamma = nn.SequentialCell(
            nn.Dense(d_model, d_model, weight_init = "Uniform", bias_init = "Uniform"),
            nn.ReLU(),
            nn.Dense(d_model, d_model, weight_init = "Uniform", bias_init = "Uniform")
        )
        # Query, key, and value projections.
        self.w_qs = nn.Dense(d_model, d_model, has_bias=False, weight_init = "Uniform")
        self.w_ks = nn.Dense(d_model, d_model, has_bias=False, weight_init = "Uniform")
        self.w_vs = nn.Dense(d_model, d_model, has_bias=False, weight_init = "Uniform")
        self.k = k
    def construct(self, xyz, features):
        """
        Transformer construct
        """
        # Find the k nearest neighbors of every point.
        dist = Square_distance(xyz, xyz)
        sort = ops.Sort()
        _, knn_idx = sort(dist)
        knn_idx = knn_idx[:, :, :self.k]
        knn_xyz = index_points(xyz, knn_idx)
        pre = features
        x = self.fc1(features)
        # Queries over all points; keys and values gathered at the k neighbors.
        q = self.w_qs(x)
        k = index_points(self.w_ks(x), knn_idx)
        v = index_points(self.w_vs(x), knn_idx)
        # Relative positional encoding delta = theta(p_i - p_j).
        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)
        # Vector attention gamma(q - k + delta), scaled and normalized over the neighbors.
        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
        d = mnp.size(k, axis=-1)
        softmax = nn.Softmax(axis=-2)
        attn = softmax(attn / mnp.sqrt(float(d)))
        # Aggregate (values + positional encoding) with the attention weights.
        einsum = ops.Einsum("bmnf, bmnf->bmf")
        res = einsum((attn, v + pos_enc))
        # Project back to d_points and add the residual connection.
        res = self.fc2(res) + pre
        return res, attn
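The block above relies on the helpers Square_distance and index_points defined elsewhere in this repository. As a rough sketch of their contracts (illustrative re-implementations under our reading, not the repository's exact code):

def Square_distance(src, dst):
    """Pairwise squared Euclidean distances between src (B, N, C) and dst (B, M, C)."""
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    dist = -2.0 * ops.BatchMatMul()(src, dst.transpose(0, 2, 1))
    dist += ops.ExpandDims()(ops.ReduceSum()(src ** 2, -1), 2)
    dist += ops.ExpandDims()(ops.ReduceSum()(dst ** 2, -1), 1)
    return dist

def index_points(points, idx):
    """Gather points[b, idx[b, ...], :] for idx of shape (B, S) or (B, S, K)."""
    b = points.shape[0]
    batch_idx = mnp.broadcast_to(mnp.arange(b).reshape((b,) + (1,) * (idx.ndim - 1)), idx.shape)
    full_idx = ops.Stack(axis=-1)((batch_idx, idx))
    return ops.GatherNd()(points, full_idx)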
TransitionDown
class TransitionDown(nn.Cell):
    """
    TransitionDown Block
    A key function of the TransitionDown block is to reduce the number of points
    in the point set as needed.
    Input:
        xyz(Tensor) : Input point position data of shape [B, N, C].
        points(Tensor) : Input point feature data of shape [B, N, D].
    Output:
        new_xyz: Sampled point position data of shape [B, S, C].
        new_points: Sampled point feature data of shape [B, S, D'].
    Argument:
        k: The number of points to keep after sampling.
        nneighbor: The number of neighbors grouped around each sampled point (KNN).
        channels: Channel list for the PointNet++ set abstraction module
            (channels[0] is the input dimension, channels[1:] the MLP widths).
    Example:
        td = TransitionDown(512, 16, [32, 64, 64])
    """
    def __init__(self, k, nneighbor, channels):
        super(TransitionDown, self).__init__()
        # Delegates to the PointNet++ set abstraction module with kNN grouping.
        self.sa = PointNet2SetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True)
    def construct(self, xyz, points):
        """
        TransitionDown construct
        """
        # Farthest point sampling, kNN grouping, and a shared MLP, all handled
        # inside the PointNet++ set abstraction module.
        after_sa = self.sa(xyz, points)
        return after_sa
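An illustrative call, reusing the imports above. Whether channels[0] must account for concatenated coordinates depends on the repository's PointNet2SetAbstraction, so the channel list here is an assumption:

xyz = mindspore.Tensor(np.random.rand(8, 1024, 3), dtype = mindspore.float32)
feats = mindspore.Tensor(np.random.rand(8, 1024, 32), dtype = mindspore.float32)
td = TransitionDown(256, 16, [32, 64, 64])
new_xyz, new_feats = td(xyz, feats)  # 1024 points sampled down to 256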
TransitionUp
class TransitionUp(nn.Cell):
    """
    TransitionUp is a transition module in the decoder.
    Its main function is to map features from the subsampled point set P2 back to
    the superset P1. To this end, each input feature passes through a linear layer,
    then batch normalization, and then ReLU. Finally, the feature is mapped onto the
    high-resolution point set by trilinear interpolation. The features interpolated
    from the previous decoder stage are summed with the features of the
    corresponding encoder stage.
    Input:
        xyz1(Tensor) : Point coordinates of the low-resolution set.
        point1(Tensor) : Features of the low-resolution set.
        xyz2(Tensor) : Point coordinates of the high-resolution set.
        point2(Tensor) : Features of the high-resolution set.
    Output:
        output(Tensor) : The upsampled features on the high-resolution set.
    Argument:
        dim1: The feature dimension of the first input.
        dim2: The feature dimension of the second input.
        dim_out: The output feature dimension.
    Example:
        xyz1 = mindspore.Tensor(np.random.rand(1, 4, 3), dtype = mindspore.float32)
        point1 = mindspore.Tensor(np.random.rand(1, 4, 512), dtype = mindspore.float32)
        xyz2 = mindspore.Tensor(np.random.rand(1, 16, 3), dtype = mindspore.float32)
        point2 = mindspore.Tensor(np.random.rand(1, 16, 256), dtype = mindspore.float32)
        transition_up = TransitionUp(512, 256, 256)
        output = transition_up(xyz1, point1, xyz2, point2)
        print(output.shape)
    """
    def __init__(self, dim1, dim2, dim_out):
        super(TransitionUp, self).__init__()

        class SwapAxes(nn.Cell):
            """Swap the point and channel axes: (B, N, C) <-> (B, C, N)."""
            def __init__(self):
                super(SwapAxes, self).__init__()
                self.transpose = ops.Transpose()

            def construct(self, x):
                return self.transpose(x, (0, 2, 1))

        self.f1 = nn.Dense(dim1, dim_out, weight_init = "Uniform", bias_init = "Uniform")
        self.f2 = nn.Dense(dim2, dim_out, weight_init = "Uniform", bias_init = "Uniform")
        self.Swap = SwapAxes()
        self.BN1 = nn.BatchNorm2d(dim_out, momentum=0.1, use_batch_statistics = None)
        self.BN2 = nn.BatchNorm2d(dim_out, momentum=0.1, use_batch_statistics = None)
        self.relu = nn.ReLU()
        self.fp = PointNetFeaturePropagation(-1, [])
    def construct(self, xyz1, point1, xyz2, point2):
        """
        TransitionUp construct
        """
        # Linear -> BatchNorm -> ReLU on both feature sets. BatchNorm2d expects
        # (B, C, N, 1), hence the axis swaps and the temporary trailing axis.
        feats1 = self.relu(self.Swap(ops.Squeeze(-1)(self.BN1(ops.ExpandDims()(self.Swap(self.f1(point1)), -1)))))
        feats2 = self.relu(self.Swap(ops.Squeeze(-1)(self.BN2(ops.ExpandDims()(self.Swap(self.f2(point2)), -1)))))
        # Feature propagation works in (B, C, N) layout: interpolate the
        # low-resolution features feats1 (living on xyz1) onto xyz2.
        feats1 = ops.Transpose()(feats1, (0, 2, 1))
        xyz1 = ops.Transpose()(xyz1, (0, 2, 1))
        xyz2 = ops.Transpose()(xyz2, (0, 2, 1))
        interpolated = ops.Transpose()(self.fp(xyz2, xyz1, None, feats1), (0, 2, 1))
        # Summation fusion with the skip connection from the encoder stage.
        return interpolated + feats2
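The interpolation itself is delegated to PointNetFeaturePropagation. As a hedged sketch of what that step computes, following the PointNet++ formulation of inverse-distance weighting over the three nearest low-resolution points (this reuses the helper sketches above and is not the repository's exact code):

def interpolate_features(xyz_dense, xyz_sparse, feats_sparse):
    """Interpolate feats_sparse (B, S, D) from xyz_sparse (B, S, 3) onto xyz_dense (B, N, 3)."""
    dist = Square_distance(xyz_dense, xyz_sparse)                    # (B, N, S)
    dist_sorted, idx = ops.Sort()(dist)
    dist3, idx3 = dist_sorted[:, :, :3], idx[:, :, :3]               # 3 nearest sparse points
    weights = 1.0 / (dist3 + 1e-8)                                   # inverse-distance weights
    weights = weights / ops.ReduceSum(keep_dims=True)(weights, -1)   # normalize over the 3 neighbors
    neighbor_feats = index_points(feats_sparse, idx3)                # (B, N, 3, D)
    return ops.ReduceSum()(neighbor_feats * ops.ExpandDims()(weights, -1), 2)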
Other Referenced Code
Point Transformer uses the SetAbstraction and FeaturePropagation modules proposed in PointNet++ as the main building blocks of the TransitionDown and TransitionUp modules. For the complete code, please refer to: PointTransformer-MindSpore
Acknowledgements
The MindSpore implementation of Point Transformer mainly references the following works:
qq456cvb, Point-Transformer; HelloMaroon, Point-Transformer-classification; pierrefdz, point-transformer; Sharpiless, point-transformer.