# MIT License
# Copyright (c) 2023 Yang You
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ._utils import PointNetFeaturePropagation, PointNetSetAbstraction, index_points, square_distance
class TransformerBlock(nn.Module):
def __init__(self, d_points, d_model, k) -> None:
super().__init__()
self.fc1 = nn.Linear(d_points, d_model)
self.fc2 = nn.Linear(d_model, d_points)
        # Positional-encoding MLP (the delta mapping): encodes each relative
        # offset xyz_i - xyz_j into a d_model-dimensional vector.
        self.fc_delta = nn.Sequential(
            nn.Linear(3, d_model),  # offsets are 3-D coordinates, so in_features must be 3, not 2
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        # Attention MLP (the gamma mapping): turns (q - k + pos_enc) into
        # per-channel attention logits for vector attention.
        self.fc_gamma = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        # per-point query/key/value projections (no bias)
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
self.w_ks = nn.Linear(d_model, d_model, bias=False)
self.w_vs = nn.Linear(d_model, d_model, bias=False)
self.k = k
# xyz: b x n x 3, features: b x n x f
def forward(self, xyz, features):
        # k nearest neighbours of every point, found in coordinate space
        dists = square_distance(xyz, xyz)
        knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k
        knn_xyz = index_points(xyz, knn_idx)
        pre = features  # kept for the residual connection
        x = self.fc1(features)
        q = self.w_qs(x)                         # b x n x f
        k = index_points(self.w_ks(x), knn_idx)  # b x n x k x f, gathered at the neighbours
        v = index_points(self.w_vs(x), knn_idx)  # b x n x k x f
pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz) # b x n x k x f
        # Vector attention: the subtraction relation q - k plus the positional
        # encoding yields one attention weight per channel, normalised over the
        # k neighbours (dim=-2).
        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
        attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2)  # b x n x k x f
        res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)  # weighted sum over neighbours
res = self.fc2(res) + pre
return res, attn
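
# Shape sketch for TransformerBlock (illustrative values, not prescribed by this
# file): with xyz (4, 1024, 3), features (4, 1024, 32) and TransformerBlock(32, 512, 16),
# forward returns updated features (4, 1024, 32) and attention weights
# (4, 1024, 16, 512). This is the vector self-attention of Point Transformer
# (Zhao et al., ICCV 2021).
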
class TransitionDown(nn.Module):
def __init__(self, k, nneighbor, channels):
super().__init__()
        # Farthest-point-sample k centroids, group nneighbor kNN points around
        # each, and apply the shared MLP given by `channels` (in-channel = channels[0]).
        self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True)
def forward(self, xyz, points):
return self.sa(xyz, points)
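
# Usage sketch for TransitionDown (illustrative values): TransitionDown(256, 16, [35, 64, 64])
# maps xyz (B, 1024, 3) with 32-dim features to 256 sampled points with 64-dim
# features; the in-channel 35 is 32 feature dims plus 3 for the concatenated xyz.
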
class SwapAxes(nn.Module):
    # Swaps the point and channel axes so nn.BatchNorm1d, which expects (B, C, N),
    # can sit between channel-last nn.Linear layers in a Sequential.
    def forward(self, x):
        return x.transpose(1, 2)

class TransitionUp(nn.Module):
    def __init__(self, dim1, dim2, dim_out):
        super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(dim1, dim_out),
SwapAxes(),
            nn.BatchNorm1d(dim_out),  # BatchNorm1d wants (B, C, N); the surrounding SwapAxes handle the layout
SwapAxes(),
nn.ReLU(),
)
self.fc2 = nn.Sequential(
nn.Linear(dim2, dim_out),
SwapAxes(),
            nn.BatchNorm1d(dim_out),  # BatchNorm1d wants (B, C, N); the surrounding SwapAxes handle the layout
SwapAxes(),
nn.ReLU(),
)
        # No learned MLP in the propagation (-1, []): pure distance-weighted
        # interpolation of features onto the denser point set.
        self.fp = PointNetFeaturePropagation(-1, [])
    # xyz1/points1: coarse level to be upsampled; xyz2/points2: fine-level skip connection
    def forward(self, xyz1, points1, xyz2, points2):
        feats1 = self.fc1(points1)
        feats2 = self.fc2(points2)
        # interpolate the coarse features onto the fine point set, then fuse by addition
        feats1 = self.fp(xyz2.transpose(1, 2), xyz1.transpose(1, 2), None, feats1.transpose(1, 2)).transpose(1, 2)
        return feats1 + feats2
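
# Shape sketch for TransitionUp (illustrative values): TransitionUp(128, 64, 64)
# takes the coarse level xyz1 (B, 256, 3), points1 (B, 256, 128) and the fine
# level xyz2 (B, 1024, 3), points2 (B, 1024, 64), and returns fused features
# (B, 1024, 64) at the fine resolution.
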
class Backbone(nn.Module):
def __init__(self, npoints, nblocks, nneighbor, d_points, transformer_dim):
super().__init__()
        # Typical sources for these arguments: cfg.num_point, cfg.model.nblocks,
        # cfg.model.nneighbor and cfg.input_dim.
self.fc1 = nn.Sequential(
nn.Linear(d_points, 32),
nn.ReLU(),
nn.Linear(32, 32)
)
self.transformer1 = TransformerBlock(32, transformer_dim, nneighbor)
self.transition_downs = nn.ModuleList()
self.transformers = nn.ModuleList()
        for i in range(nblocks):
            channel = 32 * 2 ** (i + 1)
            # Each stage keeps 1/4 of the points and doubles the feature width.
            # The set-abstraction in-channel is channel // 2 feature dims + 3 for
            # the concatenated xyz coordinates (hence + 3, not + 2).
            self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel]))
            self.transformers.append(TransformerBlock(channel, transformer_dim, nneighbor))
self.nblocks = nblocks
    # x: b x n x d_points per-point input features, p: b x n x 3 coordinates
    def forward(self, x, p):
xyz = p
points = self.transformer1(xyz, self.fc1(x))[0]
xyz_and_feats = [(xyz, points)]
for i in range(self.nblocks):
xyz, points = self.transition_downs[i](xyz, points)
points = self.transformers[i](xyz, points)[0]
xyz_and_feats.append((xyz, points))
        # points: features at the coarsest level; xyz_and_feats: (xyz, features) per stage
        return points, xyz_and_feats
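
if __name__ == "__main__":
    # Smoke-test sketch. Assumes the module runs inside its package so the relative
    # `._utils` import resolves (e.g. `python -m <package>.<module>`); all sizes
    # below are illustrative, not values prescribed by this file.
    B, N, D = 2, 1024, 6              # batch, points, per-point input dims
    model = Backbone(npoints=N, nblocks=2, nneighbor=16, d_points=D, transformer_dim=512)
    x = torch.rand(B, N, D)           # per-point features, e.g. xyz + normals
    p = x[..., :3]                    # coordinates
    feats, pyramid = model(x, p)
    print(feats.shape)                # expected (2, 64, 128): N / 4**nblocks points, 32 * 2**nblocks dims
    print(len(pyramid))               # expected 3: nblocks + 1 pyramid levels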