Skip to content

Commit

Permalink
refactor application of data_dict transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
David Josef Emmerichs authored and demmerichs committed Jul 31, 2023
1 parent b4dd915 commit 4436d06
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 99 deletions.
106 changes: 58 additions & 48 deletions pcdet/datasets/augmentor/augmentor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ...utils import box_utils


def random_flip_along(dim, gt_boxes, points, return_flip=False, enable=None):
def random_flip_along(dim, return_flip=False, enable=None):
"""
Args:
gt_boxes: (*, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
Expand All @@ -16,19 +16,31 @@ def random_flip_along(dim, gt_boxes, points, return_flip=False, enable=None):
other_dim = 1 - dim
if enable is None:
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])

if enable:
gt_boxes[..., other_dim] = -gt_boxes[..., other_dim]
gt_boxes[..., 6] = -(gt_boxes[..., 6] + np.pi * dim)
points[..., other_dim] = -points[..., other_dim]
def flip_pointlike(points):
points[..., other_dim] = -points[..., other_dim]
return points

def flip_boxlike(boxes):
boxes[..., other_dim] = -boxes[..., other_dim]
boxes[..., 6] = -(boxes[..., 6] + np.pi * dim)

if boxes.shape[-1] > 7:
boxes[..., 7 + other_dim] = -boxes[..., 7 + other_dim]

return boxes

tfs = dict(point=flip_pointlike, box=flip_boxlike)
else:
tfs = dict()

if gt_boxes.shape[-1] > 7:
gt_boxes[..., 7 + other_dim] = -gt_boxes[..., 7 + other_dim]
if return_flip:
return gt_boxes, points, enable
return gt_boxes, points
return tfs, enable
return tfs


def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None):
def global_rotation(rot_range, return_rot=False, noise_rotation=None):
"""
Args:
gt_boxes: (*, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
Expand All @@ -38,61 +50,59 @@ def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotatio
"""
if noise_rotation is None:
noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
gt_boxes[..., 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, ..., 0:3], np.array([noise_rotation]))[0]
gt_boxes[..., 6] += noise_rotation
if gt_boxes.shape[-1] > 7:
gt_boxes[..., 7:9] = common_utils.rotate_points_along_z(
np.concatenate((gt_boxes[..., 7:9], np.zeros((*gt_boxes.shape[:-1], 1))), axis=-1)[np.newaxis, ...],
np.array([noise_rotation])
)[0, ..., 0:2]

def rotate_pointlike(points):
points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
return points

def rotate_boxlike(boxes):
boxes[..., 0:3] = common_utils.rotate_points_along_z(boxes[np.newaxis, ..., 0:3], np.array([noise_rotation]))[0]
boxes[..., 6] += noise_rotation
if boxes.shape[-1] > 7:
boxes[..., 7:9] = common_utils.rotate_points_along_z(
np.concatenate((boxes[..., 7:9], np.zeros((*boxes.shape[:-1], 1))), axis=-1)[np.newaxis, ...],
np.array([noise_rotation])
)[0, ..., 0:2]
return boxes

tfs = dict(point=rotate_pointlike, box=rotate_boxlike)

if return_rot:
return gt_boxes, points, noise_rotation
return gt_boxes, points
return tfs, noise_rotation
return tfs


def global_scaling(gt_boxes, points, scale_range, return_scale=False):
def global_scaling(scale_range, return_scale=False):
"""
Args:
gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
gt_boxes: (*, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C),
scale_range: [min, max]
Returns:
"""
if scale_range[1] - scale_range[0] < 1e-3:
return gt_boxes, points
noise_scale = sum(scale_range) / len(scale_range)
assert noise_scale == 1.0, (noise_scale, scale_range)
noise_scale = np.random.uniform(scale_range[0], scale_range[1])
points[:, :3] *= noise_scale
gt_boxes[:, :6] *= noise_scale
if gt_boxes.shape[1] > 7:
gt_boxes[:, 7:9] *= noise_scale

if return_scale:
return gt_boxes, points, noise_scale
return gt_boxes, points

def global_scaling_with_roi_boxes(gt_boxes, roi_boxes, points, scale_range, return_scale=False):
"""
Args:
gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
points: (M, 3 + C),
scale_range: [min, max]
Returns:
"""
if scale_range[1] - scale_range[0] < 1e-3:
return gt_boxes, points
noise_scale = np.random.uniform(scale_range[0], scale_range[1])
points[:, :3] *= noise_scale
gt_boxes[:, :6] *= noise_scale
if gt_boxes.shape[1] > 7:
gt_boxes[:, 7:9] *= noise_scale
def scale_pointlike(points):
points[:, :3] *= noise_scale
return points

def scale_boxlike(boxes):
boxes[..., :6] *= noise_scale
if boxes.shape[-1] > 7:
boxes[..., 7:9] *= noise_scale
return boxes

roi_boxes[:,:, [0,1,2,3,4,5,7,8]] *= noise_scale
if noise_scale != 1.0:
tfs = dict(point=scale_pointlike, box=scale_boxlike)
else:
tfs = {}

if return_scale:
return gt_boxes,roi_boxes, points, noise_scale
return gt_boxes, roi_boxes, points
return tfs, noise_scale
return tfs


def random_image_flip_horizontal(image, depth_map, gt_boxes, calib):
Expand Down
58 changes: 13 additions & 45 deletions pcdet/datasets/augmentor/data_augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,15 @@ def __setstate__(self, d):
def random_world_flip(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_flip, config=config)
gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
for cur_axis in config['ALONG_AXIS_LIST']:
assert cur_axis in ['x', 'y']
cur_dim = ['x', 'y'].index(cur_axis)
gt_boxes, points, enable = augmentor_utils.random_flip_along(
cur_dim, gt_boxes, points, return_flip=True
tfs, enable = augmentor_utils.random_flip_along(
cur_dim, return_flip=True
)
common_utils.apply_data_transform(data_dict, tfs)
data_dict['flip_%s'%cur_axis] = enable
if 'roi_boxes' in data_dict.keys():
data_dict['roi_boxes'], _ = augmentor_utils.random_flip_along(
cur_dim,
data_dict['roi_boxes'],
np.zeros([0,3]),
enable=enable,
)

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
return data_dict

def random_world_rotation(self, data_dict=None, config=None):
Expand All @@ -82,38 +73,18 @@ def random_world_rotation(self, data_dict=None, config=None):
rot_range = config['WORLD_ROT_ANGLE']
if not isinstance(rot_range, list):
rot_range = [-rot_range, rot_range]
gt_boxes, points, noise_rot = augmentor_utils.global_rotation(
data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True
tfs, noise_rot = augmentor_utils.global_rotation(
rot_range=rot_range, return_rot=True
)
if 'roi_boxes' in data_dict.keys():
data_dict['roi_boxes'], _ = augmentor_utils.global_rotation(
data_dict['roi_boxes'],
np.zeros([0, 3]),
rot_range=rot_range,
noise_rotation=noise_rot,
)

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
common_utils.apply_data_transform(data_dict, tfs)
data_dict['noise_rot'] = noise_rot
return data_dict

def random_world_scaling(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_scaling, config=config)

if 'roi_boxes' in data_dict.keys():
gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes(
data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
)
data_dict['roi_boxes'] = roi_boxes
else:
gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
)

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
tfs, noise_scale = augmentor_utils.global_scaling(scale_range=config['WORLD_SCALE_RANGE'], return_scale=True)
common_utils.apply_data_transform(data_dict, tfs)
data_dict['noise_scale'] = noise_scale
return data_dict

Expand Down Expand Up @@ -147,15 +118,12 @@ def random_world_translation(self, data_dict=None, config=None):
np.random.normal(0, noise_translate_std[2], 1),
], dtype=np.float32).T

gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
points[:, :3] += noise_translate
gt_boxes[:, :3] += noise_translate

if 'roi_boxes' in data_dict.keys():
data_dict['roi_boxes'][:, :, :3] += noise_translate
def translate_locationlike(locations):
locations[..., :3] += noise_translate
return locations

data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points
tfs = dict(point=translate_locationlike, box=translate_locationlike)
common_utils.apply_data_transform(data_dict, tfs)
data_dict['noise_translate'] = noise_translate
return data_dict

Expand Down
13 changes: 7 additions & 6 deletions pcdet/datasets/processor/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def mask_points_and_boxes_outside_range(self, data_dict=None, config=None):

if data_dict.get('points', None) is not None:
mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range)
data_dict['points'] = data_dict['points'][mask]
tfs = dict(point=lambda x: x[mask])
common_utils.apply_data_transform(data_dict, tfs)

if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training:
mask = box_utils.mask_boxes_outside_range_numpy(
Expand All @@ -97,10 +98,9 @@ def shuffle_points(self, data_dict=None, config=None):
return partial(self.shuffle_points, config=config)

if config.SHUFFLE_ENABLED[self.mode]:
points = data_dict['points']
shuffle_idx = np.random.permutation(points.shape[0])
points = points[shuffle_idx]
data_dict['points'] = points
shuffle_idx = np.random.permutation(data_dict['points'].shape[0])
tfs = dict(point=lambda x: x[shuffle_idx])
common_utils.apply_data_transform(data_dict, tfs)

return data_dict

Expand Down Expand Up @@ -208,7 +208,8 @@ def sample_points(self, data_dict=None, config=None):
extra_choice = np.random.choice(choice, num_points - len(points), replace=False)
choice = np.concatenate((choice, extra_choice), axis=0)
np.random.shuffle(choice)
data_dict['points'] = points[choice]
tfs = dict(point=lambda x: x[choice])
common_utils.apply_data_transform(data_dict, tfs)
return data_dict

def calculate_grid_size(self, data_dict=None, config=None):
Expand Down
12 changes: 12 additions & 0 deletions pcdet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def drop_info_with_name(info, name):
return ret_info


def apply_data_transform(data_dict, transforms):
assert set(transforms.keys()).issubset({'point', 'box'})
data_keys = {
'point': ['points'],
'box': ['gt_boxes', 'roi_boxes']
}
for tf_type, tf in transforms.items():
for data_key in data_keys[tf_type]:
if data_key in data_dict:
data_dict[data_key] = tf(data_dict[data_key])


def rotate_points_along_z(points, angle):
"""
Args:
Expand Down

0 comments on commit 4436d06

Please sign in to comment.