diff --git a/models/deepmapping.py b/models/deepmapping.py
index 6e1494a..d0e2b60 100644
--- a/models/deepmapping.py
+++ b/models/deepmapping.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 from .networks import LocNetRegKITTI, MLP
-from utils import transform_to_global_KITTI, compose_pose_diff, euler_pose_to_quaternion,quaternion_to_euler_pose, qmul_torch
+from utils import transform_to_global_KITTI, compose_pose_diff, euler_pose_to_quaternion, quaternion_to_euler_pose, qmul_torch, matrix_to_rotation_6d, rotation_6d_to_matrix, euler_pose_to_6d_pose
 
 def get_M_net_inputs_labels(occupied_points, unoccupited_points):
     """
@@ -52,6 +52,8 @@ def __init__(self, n_points, loss_fn, rotation_representation='quaternion', n_sa
         self.rotation = rotation_representation
         if self.rotation == 'quaternion':
             self.loc_net = LocNetRegKITTI(n_points=n_points, out_dims=7) #
+        elif self.rotation == '6d':
+            self.loc_net = LocNetRegKITTI(n_points=n_points, out_dims=9)
         else:
             self.loc_net = LocNetRegKITTI(n_points=n_points, out_dims=6) #
         self.occup_net = MLP(dim)
@@ -60,14 +62,22 @@ def __init__(self, n_points, loss_fn, rotation_representation='quaternion', n_sa
 
     def forward(self, obs_local, sensor_pose, valid_points=None, pairwise_pose=None):
-        # obs_local:
+        # obs_local: (G, N, 3) point clouds in the local sensor frame
+        # sensor_pose: (G, 6) initial global pose of each frame [x, y, z, roll, pitch, yaw]
         G = obs_local.shape[0]
         self.obs_local = obs_local
         if self.rotation == 'quaternion':
             sensor_pose = euler_pose_to_quaternion(sensor_pose)
-        self.obs_initial = transform_to_global_KITTI(
-            sensor_pose, self.obs_local, rotation_representation=self.rotation)
+            # sensor_pose: (G, 7) as [x, y, z, qw, qx, qy, qz]
+        elif self.rotation == '6d':
+            sensor_pose = euler_pose_to_6d_pose(sensor_pose)
+            # sensor_pose: (G, 9) as [x, y, z] + 6D rotation
+
+        self.obs_initial = transform_to_global_KITTI(sensor_pose, self.obs_local, rotation_representation=self.rotation)
+        # obs_initial: (G, N, 3) point clouds transformed to the global frame with the initial pose
         self.l_net_out = self.loc_net(self.obs_initial)
+        # l_net_out: (G, out_dims) pose correction predicted by loc_net
+        print(self.l_net_out.shape)
         if self.rotation == 'quaternion':
             original_shape = list(sensor_pose.shape)
             xyz = self.l_net_out[:,:3]+ sensor_pose[:,:3]
@@ -75,17 +85,27 @@ def forward(self, obs_local, sensor_pose, valid_points=None, pairwise_pose=None)
             self.pose_est = torch.cat((xyz, wxyz), dim=1).view(original_shape)
         elif self.rotation == 'euler_angle':
             self.pose_est = self.l_net_out + sensor_pose
+        elif self.rotation == '6d':
+            original_shape = list(sensor_pose.shape)
+            xyz = self.l_net_out[:, :3] + sensor_pose[:, :3]
+            # compose the predicted rotation correction with the initial rotation in matrix form, then map back to 6D
+            l_net_6d = rotation_6d_to_matrix(self.l_net_out[:, 3:])
+            sensor_6d = rotation_6d_to_matrix(sensor_pose[:, 3:])
+            rotation_6d = torch.matmul(l_net_6d, sensor_6d)
+            rotation_6d = matrix_to_rotation_6d(rotation_6d)
+            self.pose_est = torch.cat((xyz, rotation_6d), dim=1).view(original_shape)
         # l_net_out[:, -1] = 0
         # self.pose_est = cat_pose_KITTI(sensor_pose, self.loc_net(self.obs_initial))
         # self.bs = obs_local.shape[0]
         # self.obs_local = self.obs_local.reshape(self.bs,-1,3)
-        self.obs_global_est = transform_to_global_KITTI(
-            self.pose_est, self.obs_local, rotation_representation=self.rotation)
+        self.obs_global_est = transform_to_global_KITTI(self.pose_est, self.obs_local, rotation_representation=self.rotation)
 
         if self.training:
             self.valid_points = valid_points
             if self.rotation == 'quaternion':
                 pairwise_pose = euler_pose_to_quaternion(pairwise_pose)
+            elif self.rotation == '6d':
+                pairwise_pose = euler_pose_to_6d_pose(pairwise_pose)
+
             if self.loss_fn.__name__ == "pose":
                 self.t_src, self.t_dst, self.r_src, self.r_dst = compose_pose_diff(self.pose_est, pairwise_pose, rotation_representation=self.rotation)
             else:
diff --git a/script/train.py b/script/train.py
index 83ffe75..38a8231 100644
--- a/script/train.py
+++ b/script/train.py
@@ -5,7 +5,8 @@
 import time
 import argparse
 import functools
-print = functools.partial(print,flush=True)
+
+print = functools.partial(print, flush=True)
 
 import numpy as np
 import torch
@@ -22,20 +23,20 @@
 torch.manual_seed(42)
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--name',type=str,default='test',help='experiment name')
-parser.add_argument('-e','--n_epochs',type=int,default=1000,help='number of epochs')
-parser.add_argument('-l','--loss',type=str,default='bce_ch',help='loss function')
-parser.add_argument('-n','--n_samples',type=int,default=35,help='number of sampled unoccupied points along rays')
-parser.add_argument('-v','--voxel_size',type=float,default=1,help='size of downsampling voxel grid')
-parser.add_argument('--lr',type=float,default=1e-4,help='learning rate')
+parser.add_argument('--name', type=str, default='test', help='experiment name')
+parser.add_argument('-e', '--n_epochs', type=int, default=1000, help='number of epochs')
+parser.add_argument('-l', '--loss', type=str, default='bce_ch', help='loss function')
+parser.add_argument('-n', '--n_samples', type=int, default=35, help='number of sampled unoccupied points along rays')
+parser.add_argument('-v', '--voxel_size', type=float, default=1, help='size of downsampling voxel grid')
+parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
 parser.add_argument('--dataset', type=str, default="KITTI", help="Type of dataset to use")
-parser.add_argument('-d','--data_dir',type=str,default='../data/ActiveVisionDataset/',help='dataset path')
-parser.add_argument('-t','--traj',type=str,default='2011_09_30_drive_0018_sync_full',help='trajectory file folder')
-parser.add_argument('-m','--model', type=str, default=None,help='pretrained model name')
-parser.add_argument('-i','--init', type=str, default=None, help='path to initial pose')
+parser.add_argument('-d', '--data_dir', type=str, default='../data/ActiveVisionDataset/', help='dataset path')
+parser.add_argument('-t', '--traj', type=str, default='2011_09_30_drive_0018_sync_full', help='trajectory file folder')
+parser.add_argument('-m', '--model', type=str, default=None, help='pretrained model name')
+parser.add_argument('-i', '--init', type=str, default=None, help='path to initial pose')
 parser.add_argument('-p', '--pairwise', type=str, default=None, help='path to pairwise pose')
-parser.add_argument('--log_interval',type=int,default=10,help='logging interval of saving results')
-parser.add_argument('--group_size',type=int,default=8,help='group size')
+parser.add_argument('--log_interval', type=int, default=10, help='logging interval of saving results')
+parser.add_argument('--group_size', type=int, default=8, help='group size')
 parser.add_argument('--resume', action='store_true', help='If present, restore checkpoint and resume training')
 parser.add_argument('--alpha', type=float, default=0.1, help='weight for chamfer loss')
@@ -46,52 +47,64 @@
 opt = parser.parse_args()
 
-checkpoint_dir = os.path.join('../results/'+opt.dataset,opt.name)
+checkpoint_dir = os.path.join('../results/' + opt.dataset, opt.name)
 if not os.path.exists(checkpoint_dir):
     os.makedirs(checkpoint_dir)
 if not os.path.exists(os.path.join(checkpoint_dir, "pose_ests")):
     os.makedirs(os.path.join(checkpoint_dir, "pose_ests"))
-utils.save_opt(checkpoint_dir,opt)
+utils.save_opt(checkpoint_dir, opt)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# opt.init: INIT=$DATA_DIR/$TRAJ/prior/init_pose.npy
+# init_pose.npy should be an Nx6 numpy array, where N is the number of frames.
+# Each row is the initial pose of a frame represented by x, y, z, roll, pitch, yaw.
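+# Angles are in radians; the geometry utilities interpret them with the "XYZ" Euler convention.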
+# Convert the initial pose into a tensor.
 init_pose_np = np.load(opt.init).astype("float32")
 init_pose = torch.from_numpy(init_pose_np)
 pairwise_pose = np.load(opt.pairwise).astype("float32")
 
 print('loading dataset')
 if opt.dataset == "KITTI":
-    train_dataset = Kitti(opt.data_dir, opt.traj, opt.voxel_size, init_pose=init_pose, group_size=opt.group_size, pairwise_pose=pairwise_pose)
+    train_dataset = Kitti(opt.data_dir, opt.traj, opt.voxel_size, init_pose=init_pose, group_size=opt.group_size,
+                          pairwise_pose=pairwise_pose)
     eval_dataset = KittiEval(train_dataset)
+    # eval_dataset is essentially the same as train_dataset, but it does not include gt_pose.
+
 elif opt.dataset == "NCLT" or "Nebula":
-    train_dataset = Nclt(opt.data_dir, opt.traj, opt.voxel_size, init_pose=init_pose, group_size=opt.group_size, pairwise_pose=pairwise_pose)
+    train_dataset = Nclt(opt.data_dir, opt.traj, opt.voxel_size, init_pose=init_pose, group_size=opt.group_size,
+                         pairwise_pose=pairwise_pose)
     eval_dataset = NcltEval(train_dataset)
 else:
     assert 0, "Unsupported dataset"
+
+
 train_loader = DataLoader(train_dataset, batch_size=None, num_workers=4, shuffle=True)
 eval_loader = DataLoader(eval_dataset, batch_size=64, num_workers=4)
 
-loss_fn = eval('loss.'+opt.loss)
+# look up the loss function by name, e.g. bce_ch or bce_ch_eu
+loss_fn = eval('loss.' + opt.loss)
 
-if opt.rotation not in ['quaternion','euler_angle']:
+if opt.rotation not in ['quaternion', 'euler_angle', '6d']:
     print("Unsupported rotation representation")
-    assert()
+    assert ()
 
 print('creating model')
 model = DeepMapping2(n_points=train_dataset.n_points, loss_fn=loss_fn,
-                     n_samples=opt.n_samples, alpha=opt.alpha, beta=opt.beta, rotation_representation=opt.rotation).to(device)
+                     n_samples=opt.n_samples, alpha=opt.alpha, beta=opt.beta, rotation_representation=opt.rotation).to(
+    device)
 
 if opt.optimizer == "Adam":
-    optimizer = optim.Adam(model.parameters(),lr=opt.lr)
+    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
 elif opt.optimizer == "SGD":
     optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)
 else:
     print("Unsupported optimizer")
-    assert()
+    assert ()
 
 scaler = torch.cuda.amp.GradScaler()
 if opt.model is not None:
-    utils.load_checkpoint(opt.model,model,optimizer)
+    utils.load_checkpoint(opt.model, model, optimizer)
 
 if opt.resume:
     resume_filename = os.path.join(checkpoint_dir, "model_best.pth")
@@ -112,9 +125,8 @@
     ch_loss = 0
     eu_loss = 0
     model.train()
-
     time_start = time.time()
-    for index,(obs, valid_pt, init_global_pose, pairwise_pose) in enumerate(train_loader):
+    for index, (obs, valid_pt, init_global_pose, pairwise_pose) in enumerate(train_loader):
         obs = obs.to(device)
         valid_pt = valid_pt.to(device)
         init_global_pose = init_global_pose.to(device)
@@ -141,11 +153,11 @@
         ch_loss += ch
         if loss == "bce_ch_eu" or loss == "pose":
             eu_loss += eu
-
+
     time_end = time.time()
     # print(model.parameters().grad)
     print("Training time: {:.2f}s".format(time_end - time_start))
-    training_loss_epoch = training_loss/len(train_loader)
+    training_loss_epoch = training_loss / len(train_loader)
     bce_epoch = bce_loss / len(train_loader)
     ch_epoch = ch_loss / len(train_loader)
     eu_epoch = eu_loss / len(train_loader)
@@ -154,12 +166,12 @@
     ch_losses.append(ch_epoch)
     eu_losses.append(eu_epoch)
 
-    print('[{}/{}], training loss: {:.4f}'.format(epoch+1,opt.n_epochs,training_loss_epoch))
+    print('[{}/{}], training loss: {:.4f}'.format(epoch + 1, opt.n_epochs, training_loss_epoch))
 
     obs_global_est_np = []
     pose_est_np = []
     with torch.no_grad():
         model.eval()
-        for index,(obs, init_global_pose) in enumerate(eval_loader):
+        for index, (obs, init_global_pose) in enumerate(eval_loader):
             obs = obs.to(device)
             init_global_pose = init_global_pose.to(device)
             model(obs, init_global_pose)
@@ -168,13 +180,13 @@
             pose_est = model.pose_est
             obs_global_est_np.append(obs_global_est.cpu().detach().numpy())
             pose_est_np.append(pose_est.cpu().detach().numpy())
-
+
     pose_est_np = np.concatenate(pose_est_np)
-    save_name = os.path.join(checkpoint_dir, "pose_ests", str(epoch+1))
-    np.save(save_name,pose_est_np)
+    save_name = os.path.join(checkpoint_dir, "pose_ests", str(epoch + 1))
+    np.save(save_name, pose_est_np)
 
-    utils.plot_global_pose(checkpoint_dir, opt.dataset, epoch+1, rotation_representation=opt.rotation)
+    utils.plot_global_pose(checkpoint_dir, opt.dataset, epoch + 1, rotation_representation=opt.rotation)
 
     try:
         trans_ate, rot_ate = utils.compute_ate(pose_est_np, train_dataset.gt_pose, rotation_representation=opt.rotation)
@@ -193,17 +205,17 @@
     if training_loss_epoch < best_loss:
         print("lowest loss:", training_loss_epoch)
         best_loss = training_loss_epoch
-
+
         # Visulize global point clouds
        obs_global_est_np = np.concatenate(obs_global_est_np)
-        save_name = os.path.join(checkpoint_dir,'obs_global_est.npy')
-        np.save(save_name,obs_global_est_np)
+        save_name = os.path.join(checkpoint_dir, 'obs_global_est.npy')
+        np.save(save_name, obs_global_est_np)
 
         # Save checkpoint
-        save_name = os.path.join(checkpoint_dir,'model_best.pth')
-        utils.save_checkpoint(save_name,model,optimizer,epoch)
+        save_name = os.path.join(checkpoint_dir, 'model_best.pth')
+        utils.save_checkpoint(save_name, model, optimizer, epoch)
 
     print()
 
 training_losses = np.array(training_losses)
-np.save(os.path.join(checkpoint_dir, "loss.npy"), training_losses)
\ No newline at end of file
+np.save(os.path.join(checkpoint_dir, "loss.npy"), training_losses)
diff --git a/utils/geometry_utils.py b/utils/geometry_utils.py
index 7ea60f3..86cc6af 100644
--- a/utils/geometry_utils.py
+++ b/utils/geometry_utils.py
@@ -77,6 +77,27 @@ def euler_pose_to_quaternion(euler_pose):
     quaternion_pose = torch.cat((xyz, quaternion), dim=1)
     return quaternion_pose
 
+def euler_pose_to_6d_pose(euler_pose):
+    """
+    convert euler angles pose to 6d pose.
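+    Uses the continuous 6D rotation representation of Zhou et al. (CVPR 2019).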
+    :param euler_pose: (N, 6) tensor of [x, y, z, roll, pitch, yaw]
+    :return 6d_pose: (N, 9) tensor of [x, y, z] followed by the 6D rotation
+    """
+    xyz = euler_pose[:, :3]
+    e = euler_pose[:,3:]
+    assert e.shape[-1] == 3
+
+    # Convert euler angles to rotation matrix
+    rotation_matrix = euler_angles_to_matrix(e, convention="XYZ")
+
+    # Convert rotation matrix to 6D pose representation
+    six_d = matrix_to_rotation_6d(rotation_matrix)
+
+    # Concatenate xyz and six_d along dimension 1
+    six_d_pose = torch.cat((xyz, six_d), dim=1)
+
+    return six_d_pose
+
 def transform_to_global_KITTI(pose, obs_local, rotation_representation):
     """
     transform obs local coordinate to global corrdinate frame
@@ -92,6 +113,9 @@ def transform_to_global_KITTI(pose, obs_local, rotation_representation):
     elif rotation_representation == "quaternion":
         quat = pose[:, 3:]
         rotation_matrix = quaternion_to_matrix(quat)
+    elif rotation_representation == "6d":
+        sixd = pose[:, 3:]
+        rotation_matrix = rotation_6d_to_matrix(sixd)
     obs_global = torch.bmm(obs_local, rotation_matrix.transpose(1, 2))
     # obs_global[:, :, 0] = obs_global[:, :, 0] + pose[:, [0]]
     # obs_global[:, :, 1] = obs_global[:, :, 1] + pose[:, [1]]
@@ -123,6 +147,9 @@ def compose_pose_diff(pose_est, pairwise, rotation_representation):
     elif rotation_representation == "quaternion":
         rotation_est = quaternion_to_matrix(rpy_est)
         rotation_pairwise = quaternion_to_matrix(rpy_pairwise)
+    elif rotation_representation == "6d":
+        rotation_est = rotation_6d_to_matrix(rpy_est)
+        rotation_pairwise = rotation_6d_to_matrix(rpy_pairwise)
     r_dst = torch.bmm(rotation_est, rotation_pairwise)
     # rpy = matrix_to_euler_angles(rotation, convention="XYZ")
     # dst = torch.concat((xyz, rpy), dim=1)
@@ -223,6 +250,11 @@ def compute_ate(output, target, rotation_representation):
         q = output[:,3:]
         output_quat = q[:, [1, 2, 3, 0]]
         rpy = Rot.from_quat(output_quat).as_euler("XYZ")
+    elif rotation_representation == "6d":
+        r = torch.tensor(output[:,3:])
+        output_r = rotation_6d_to_matrix(r)
+        rpy = Rot.from_matrix(output_r.numpy()).as_euler("XYZ")
+
     yaw_aligned = rpy[:, -1] + rotation[-1]
     yaw_gt = target[:, -1]
     while np.any(yaw_aligned > np.pi):
diff --git a/utils/pytorch3d_utils.py b/utils/pytorch3d_utils.py
index f1060d0..559488b 100755
--- a/utils/pytorch3d_utils.py
+++ b/utils/pytorch3d_utils.py
@@ -249,4 +249,46 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
         ),
         -1,
     )
-    return o.reshape(quaternions.shape[:-1] + (3, 3))
\ No newline at end of file
+    return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+    """
+    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+    using Gram--Schmidt orthogonalization per Section B of [1].
+    Args:
+        d6: 6D rotation representation, of size (*, 6)
+
+    Returns:
+        batch of rotation matrices of size (*, 3, 3)
+
+    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+    On the Continuity of Rotation Representations in Neural Networks.
+    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+    Retrieved from http://arxiv.org/abs/1812.07035
+    """
+
+    a1, a2 = d6[..., :3], d6[..., 3:]
+    b1 = F.normalize(a1, dim=-1)
+    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+    b2 = F.normalize(b2, dim=-1)
+    b3 = torch.cross(b1, b2, dim=-1)
+    return torch.stack((b1, b2, b3), dim=-2)
+
+
+def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+    """
+    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+    by dropping the last row. Note that 6D representation is not unique.
+    Args:
+        matrix: batch of rotation matrices of size (*, 3, 3)
+
+    Returns:
+        6D rotation representation, of size (*, 6)
+
+    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+    On the Continuity of Rotation Representations in Neural Networks.
+    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+    Retrieved from http://arxiv.org/abs/1812.07035
+    """
+    batch_dim = matrix.size()[:-2]
+    return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
diff --git a/utils/vis_utils.py b/utils/vis_utils.py
index 5579c66..9bb441e 100644
--- a/utils/vis_utils.py
+++ b/utils/vis_utils.py
@@ -61,6 +61,12 @@ def plot_global_pose(checkpoint_dir, dataset="kitti", epoch=None, mode=None, rot
         q = q[:, [1, 2, 3, 0]]
         rpy = Rot.from_quat(q).as_euler("XYZ")
         location = np.concatenate((location[:,:3],rpy),axis=1)
+    elif rotation_representation == "6d":
+        rotation_6d = torch.tensor(location[:, 3:])
+        rotation_matrix = rotation_6d_to_matrix(rotation_6d)
+        rpy = Rot.from_matrix(rotation_matrix.numpy()).as_euler("XYZ")
+        location = np.concatenate((location[:, :3], rpy), axis=1)
+
     t = np.arange(location.shape[0]) / location.shape[0]
     # location[:, 0] = location[:, 0] - np.mean(location[:, 0])
     # location[:, 1] = location[:, 1] - np.mean(location[:, 1])