Source code for facemap.pose.transforms

"""
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Atika Syeda.
"""
"""
Facemap functions for: 
- bounding box: (suggested ROI) for UNet input images
- image preprocessing 
- image augmentation
"""
import numpy as np
import torch
from scipy import ndimage
from torch.nn import functional as F

from . import pose_helper_functions as pose_utils


def preprocess_img(im, bbox, add_padding, resize, device=None):
    """
    Prepare an image for the pose model: normalize, crop, optionally pad and resize.

    Steps, in order:
        1. A 2D input gets a leading singleton axis.
        2. If ``device`` is given, the array is converted to a float32 tensor on
           that device (the tensor is pinned first when the device is not the CPU).
        3. ``pose_utils.normalize99`` is applied.
        4. The image is cropped with ``crop_image(im, bbox)``.
        5. If ``add_padding`` is set, the crop is padded with ``pad_img_to_square``.
        6. If ``resize`` is set, the result is bilinearly interpolated to 256x256.

    Parameters
    -------------
    im: ND-array
        image of size [(Lz) x Ly x Lx]
    bbox: tuple of size (4,)
        bounding box, forwarded unchanged to ``crop_image`` and
        ``pad_img_to_square``
    add_padding: bool
        whether to pad the cropped image to a square
    resize: bool
        whether to resize the image to 256x256
    device: torch.device, optional
        target device; when None the input is not converted to a tensor
        NOTE(review): the resize step calls ``F.interpolate``, which presumably
        needs a tensor -- confirm callers always pass a device when resize=True.

    Returns
    --------------
    im: ND-array
        preprocessed image
    postpad_shape: tuple of size (2,)
        last two dimensions of the image after padding and before resizing
    pads: tuple of size (4,)
        padding values as returned by ``pad_img_to_square``
        (pad_y_top, pad_y_bottom, pad_x_left, pad_x_right), or (0, 0, 0, 0)
        when ``add_padding`` is False
    """
    # Guarantee a leading axis so later steps never see a bare 2D image.
    if im.ndim == 2:
        im = im[np.newaxis, ...]

    # Move the data onto the requested device as float32.
    if device is not None:
        tensor = torch.from_numpy(im)
        if device.type != "cpu":
            tensor = tensor.pin_memory()
        im = tensor.to(device, dtype=torch.float32)

    # Normalize, then crop to the bounding box.
    im = crop_image(pose_utils.normalize99(im, device=device), bbox)

    # Optional square padding; default to "no padding applied".
    pads = (0, 0, 0, 0)
    if add_padding:
        im, pads = pad_img_to_square(im, bbox)

    # Record the shape before any resizing so callers can undo the resize.
    postpad_shape = im.shape[-2:]
    if resize:
        im = F.interpolate(im, size=(256, 256), mode="bilinear")
    return im, postpad_shape, pads
def randomize_bbox_coordinates(bbox, im_shape, random_factor_range=(0.1, 0.3)):
    """
    Expand a bounding box by a random fraction of its size on each side,
    then clamp the result to the image bounds.

    Each of the four edges gets its own random factor from
    ``get_random_factor(random_factor_range)``, drawn in the order
    x1, x2, y1, y2.

    Parameters
    ----------
    bbox: tuple of size (4,)
        bounding box positions in order x1, x2, y1, y2
    im_shape: tuple of size (2,)
        image shape; the upper x bound is clamped to ``im_shape[0]`` and the
        upper y bound to ``im_shape[1]``
        NOTE(review): confirm this matches the intended (Ly, Lx) ordering.
    random_factor_range: tuple of size (2,)
        (min, max) range of the random expansion factor

    Returns
    -------
    bbox: ND-array of size (4,)
        randomized bounding box positions in order x1, x2, y1, y2
    """
    x1, x2, y1, y2 = np.array(bbox)
    x_extent = x2 - x1
    y_extent = y2 - y1

    # Push every edge outwards; int() truncates toward zero.
    new_x1 = int(x1 - x_extent * get_random_factor(random_factor_range))
    new_x2 = int(x2 + x_extent * get_random_factor(random_factor_range))
    new_y1 = int(y1 - y_extent * get_random_factor(random_factor_range))
    new_y2 = int(y2 + y_extent * get_random_factor(random_factor_range))

    # Keep the expanded box inside the image.
    new_x1 = max(0, new_x1)
    new_x2 = min(im_shape[0], new_x2)
    new_y1 = max(0, new_y1)
    new_y2 = min(im_shape[1], new_y2)

    return np.array([new_x1, new_x2, new_y1, new_y2])
def get_random_factor(factor_range):
    """
    Draw one random factor from a range using ``np.random.uniform``.

    Parameters
    ----------
    factor_range: tuple of size (2,)
        factor range in order min, max

    Returns
    -------
    factor: float
        random factor
    """
    lower, upper = factor_range[0], factor_range[1]
    return np.random.uniform(lower, upper)
def get_cropped_imgs(imgs, bbox):
    """
    Crop every image and channel of a batch to the same bounding box.

    The bounding box values are rounded to the nearest integer. The first
    pair (x1, x2) slices axis 2 of ``imgs`` and the second pair (y1, y2)
    slices axis 3. The result is copied into a newly allocated array with
    numpy's default float dtype.

    Parameters
    -------------
    imgs: ND-array
        images of size [batch_size x nchan x Ly x Lx]
    bbox: tuple of size (4,)
        bounding box positions in order x1, x2, y1, y2

    Returns
    --------------
    cropped_imgs: ND-array
        images of size [batch_size x nchan x (x2-x1) x (y2-y1)]
    """
    x1, x2, y1, y2 = np.round(bbox).astype(int)
    n_images, n_channels = imgs.shape[0], imgs.shape[1]
    cropped_imgs = np.empty((n_images, n_channels, x2 - x1, y2 - y1))
    # A single slice assignment copies every (image, channel) crop at once.
    cropped_imgs[...] = imgs[:, :, x1:x2, y1:y2]
    return cropped_imgs
def pad_keypoints(keypoints, pad_h, pad_w):
    """
    Shift keypoints by padding offsets, modifying the input in place.

    ``pad_h`` is added to column 0 and ``pad_w`` to column 1.

    Parameters
    ----------
    keypoints : ND-array
        keypoints of size [N x 2]
    pad_h : int
        offset added to column 0
    pad_w : int
        offset added to column 1

    Returns
    -------
    keypoints : ND-array
        the same (mutated) keypoints object, of size [N x 2]
    """
    for column, offset in enumerate((pad_h, pad_w)):
        keypoints[:, column] += offset
    return keypoints
def pad_img_to_square(img, bbox=None):
    """
    Zero-pad an image so that its last two dimensions become equal.

    The two extents (dx, dy) come from ``bbox`` when it is given
    (x2 - x1, y2 - y1), otherwise from ``img.shape[-2:]``. If they are
    already equal the image is returned untouched. When the difference
    between them is odd, the target size is the larger extent plus one, so
    the larger side also receives one element of padding.

    The "y" padding is applied to the last axis of ``img`` and the "x"
    padding to the second-to-last axis, which matches
    ``dx, dy = img.shape[-2:]``.

    Parameters
    ----------
    img : tensor
        image of size [h x w], [c x h x w] or [b x c x h x w]
    bbox: tuple of size (4,), optional
        bounding box positions in order x1, x2, y1, y2 used for cropping image

    Returns
    -------
    img : tensor
        padded image
    pads : tuple of size (4,)
        (pad_y_top, pad_y_bottom, pad_x_left, pad_x_right)
    """
    if bbox is None:
        dx, dy = img.shape[-2:]
    else:
        x1, x2, y1, y2 = bbox
        dx, dy = x2 - x1, y2 - y1

    # Already square: nothing to do.
    if dx == dy:
        return img, (0, 0, 0, 0)

    target = max(dx, dy)
    # An odd shortfall cannot be split evenly, so grow the target by one.
    if abs(min(dx, dy) - target) % 2 != 0:
        target += 1

    def _split(extent):
        """Return (before, after) padding that brings ``extent`` up to ``target``."""
        if extent < target:
            total = abs(extent - target)
            before = total // 2
            return before, total - before
        return 0, 0

    pad_x_left, pad_x_right = _split(dx)
    pad_y_top, pad_y_bottom = _split(dy)
    applied = (pad_y_top, pad_y_bottom, pad_x_left, pad_x_right)

    # F.pad lists pads from the last axis backwards; leading axes get zeros.
    if img.ndim > 3:
        leading = (0, 0, 0, 0)
    elif img.ndim == 3:
        leading = (0, 0)
    else:
        leading = ()

    img = F.pad(img, applied + leading, mode="constant", value=0)
    return img, applied
def resize_keypoints(keypoints, desired_shape, original_shape):
    """
    Scale keypoint coordinates from one image shape to another.

    ``keypoints[:, 0]`` is multiplied by the ratio of the second shape
    entries and ``keypoints[:, 1]`` by the ratio of the first shape entries.
    The two results are stacked along dim 1 into a new tensor.

    Parameters
    ----------
    keypoints: tensor
        keypoints whose axis 1 holds (x, y), e.g. size [N x 2]
        NOTE(review): the original docstring said [batch x N x 2], but the
        indexing is on axis 1 -- confirm the layout callers actually pass.
    desired_shape: tuple of size (2,)
        desired shape of image, indexed as (y, x)
    original_shape: tuple of size (2,)
        original shape of image, indexed as (y, x)

    Returns
    -------
    keypoints: tensor
        rescaled keypoints (a new tensor; the input is not modified)
    """
    x_ratio = desired_shape[1] / original_shape[1]
    y_ratio = desired_shape[0] / original_shape[0]
    scaled_x = keypoints[:, 0] * x_ratio
    scaled_y = keypoints[:, 1] * y_ratio
    return torch.stack([scaled_x, scaled_y], dim=1)
def resize_image(im, resize_shape):
    """
    Resize an image tensor to a given height and width.

    Uses bilinear interpolation with ``align_corners=True``. A 3D input gets
    one leading singleton axis and a 2D input gets two, so that
    ``F.interpolate`` receives a 4D tensor. Afterwards a size-1 axis 0 is
    squeezed away, so a 2D input comes back as [1 x h x w].

    Parameters
    ----------
    im : tensor
        image of size [Ly x Lx], [c x Ly x Lx] or [b x c x Ly x Lx]
    resize_shape : tuple of size (2,)
        desired (height, width) of the image

    Returns
    -------
    im : tensor
        resized image whose last two dimensions are (h, w)
    """
    h, w = resize_shape
    # Number of leading singleton axes needed to reach a 4D tensor.
    missing_axes = {3: 1, 2: 2}.get(im.ndim, 0)
    for _ in range(missing_axes):
        im = torch.unsqueeze(im, dim=0)
    resized = F.interpolate(im, size=(h, w), mode="bilinear", align_corners=True)
    return resized.squeeze(dim=0)
def get_crop_resize_params(img, x_dims, y_dims, xy=(256, 256)):
    """
    Compute crop bounds for an image and report whether a resize is needed.

    For each axis, a crop larger than the desired size sets the resize flag
    and leaves the bounds unchanged. Otherwise the bounds are widened
    symmetrically to the desired size, with the extra element going to the
    upper bound when the difference is odd. The bounds are then shifted and
    clamped against the image shape, and the upper bounds are raised to at
    least the desired size.

    Parameters
    ----------
    img: ND-array
        image; ``img.shape[0]`` limits x and ``img.shape[1]`` limits y
    x_dims: tuple of size (2,)
        min, max x position
    y_dims: tuple of size (2,)
        min, max y position
    xy: tuple of size (2,)
        desired size; ``xy[0]`` applies to y and ``xy[1]`` applies to x

    Returns
    -------
    x1: int
        x start position
    x2: int
        x stop position
    y1: int
        y start position
    y2: int
        y stop position
    resize: bool
        whether the cropped image must be resized
    """

    def _fit(start, stop, target):
        """Return (start, stop, too_large) for one axis."""
        if abs(stop - start) > target:
            # Crop exceeds the target: keep bounds, ask for a resize.
            return start, stop, True
        deficit = abs(abs(stop - start) - target)
        half = deficit // 2
        # An odd deficit gives its leftover element to the upper bound.
        leftover = 0 if deficit % 2 == 0 else 1
        return start - half, stop + half + leftover, False

    x1, x2 = int(x_dims[0]), int(x_dims[1])
    y1, y2 = int(y_dims[0]), int(y_dims[1])

    y1, y2, y_too_large = _fit(y1, y2, xy[0])
    x1, x2, x_too_large = _fit(x1, x2, xy[1])
    resize = y_too_large or x_too_large

    # Slide the window back when it overshoots the image limits.
    if y2 > img.shape[1]:
        y1 -= y2 - img.shape[1]
    if x2 > img.shape[0]:
        x1 -= x2 - img.shape[0]

    # Clamp to the image, then enforce the minimum upper bound.
    y2, x2 = min(y2, img.shape[1]), min(x2, img.shape[0])
    y1, x1 = max(0, y1), max(0, x1)
    y2, x2 = max(y2, xy[0]), max(x2, xy[1])
    return x1, x2, y1, y2, resize
def crop_image(im, bbox=None):
    """
    Crop the last two axes of an image to a bounding box.

    The first two ``bbox`` values slice the second-to-last axis and the last
    two values slice the last axis. With ``bbox=None`` the image is returned
    unchanged.

    Parameters
    ----------
    im : ND-array
        image with 2, 3 or 4 dimensions, e.g. [(Lz) x Ly x Lx]
    bbox : tuple of size (4,), optional
        bounding box, unpacked as (start, stop) for the second-to-last axis
        followed by (start, stop) for the last axis
        NOTE(review): other functions in this module describe bbox as
        x1, x2, y1, y2 -- confirm which axis "x" refers to.

    Returns
    -------
    im : ND-array
        cropped image with the same number of dimensions as the input

    Raises
    ------
    ValueError
        if ``im`` does not have 2, 3 or 4 dimensions
    """
    if bbox is None:
        return im
    y1, y2, x1, x2 = bbox
    if im.ndim not in (2, 3, 4):
        raise ValueError("Cannot handle image with ndim=" + str(im.ndim))
    # Ellipsis keeps every leading axis and slices only the last two.
    return im[..., y1:y2, x1:x2]
def adjust_keypoints(xlabels, ylabels, crop_xy, padding, current_size, desired_size):
    """
    Map model-output keypoints back onto the original image.

    Applies, in order: ``rescale_keypoints`` (current_size -> desired_size),
    ``adjust_keypoints_for_padding`` (removes the padding offset) and finally
    adds the crop origin.

    Parameters
    -------------
    xlabels: ND-array
        x coordinates of keypoints
    ylabels: ND-array
        y coordinates of keypoints
    crop_xy: tuple of size (2,)
        initial coordinates of bounding box (x1, y1) used for cropping
    padding: tuple of size (4,)
        padding values, forwarded to ``adjust_keypoints_for_padding``
    current_size: tuple of size (2,)
        size the keypoints currently refer to, forwarded to ``rescale_keypoints``
    desired_size: tuple of size (2,)
        size to rescale the keypoints to, forwarded to ``rescale_keypoints``

    Returns
    --------------
    xlabels: ND-array
        x coordinates of keypoints
    ylabels: ND-array
        y coordinates of keypoints
    """
    # Undo the resize, then undo the padding.
    xlabels, ylabels = adjust_keypoints_for_padding(
        *rescale_keypoints(xlabels, ylabels, current_size, desired_size), padding
    )
    # Undo the crop by adding back the bounding box origin.
    x_origin, y_origin = crop_xy[0], crop_xy[1]
    xlabels += x_origin
    ylabels += y_origin
    return xlabels, ylabels
def rescale_keypoints(xlabels, ylabels, current_size, desired_size):
    """
    Rescale keypoint coordinates from one image size to another.

    The multiplication uses ``*=``, so array inputs are modified in place.

    Parameters
    -------------
    xlabels: ND-array
        x coordinates of keypoints
    ylabels: ND-array
        y coordinates of keypoints
    current_size: tuple of size (2,)
        current size of image (h, w)
    desired_size: tuple of size (2,)
        desired size of image (h, w)

    Returns
    --------------
    xlabels: ND-array
        x coordinates of keypoints, scaled by desired w / current w
    ylabels: ND-array
        y coordinates of keypoints, scaled by desired h / current h
    """
    width_ratio = desired_size[1] / current_size[1]
    xlabels *= width_ratio
    height_ratio = desired_size[0] / current_size[0]
    ylabels *= height_ratio
    return xlabels, ylabels
def adjust_keypoints_for_padding(xlabels, ylabels, pads):
    """
    Remove the padding offset from keypoint coordinates.

    Only the first and third padding values are used: ``pads[0]`` is
    subtracted from the x coordinates and ``pads[2]`` from the y coordinates.
    The subtraction uses ``-=``, so array inputs are modified in place.

    Parameters
    -------------
    xlabels: ND-array
        x coordinates of keypoints
    ylabels: ND-array
        y coordinates of keypoints
    pads: tuple of size (4,)
        padding values in order (pad_y_top, pad_y_bottom, pad_x_left,
        pad_x_right), as returned by ``pad_img_to_square``

    Returns
    --------------
    xlabels: ND-array
        x coordinates of keypoints with ``pad_y_top`` subtracted
    ylabels: ND-array
        y coordinates of keypoints with ``pad_x_left`` subtracted
    """
    # Unpacking enforces exactly four values; the trailing pads are unused.
    x_offset, _, y_offset, _ = pads
    xlabels -= x_offset
    ylabels -= y_offset
    return xlabels, ylabels
def adjust_bbox(prev_bbox, img_yx, div=16, extra=1):
    """
    Enlarge a bounding box and move it towards a square shape.

    The box is first rounded, then each dimension is padded so that its
    length becomes a multiple of ``div`` plus ``extra * div``. Afterwards the
    shorter dimension is widened by the (even-rounded) difference to the
    longer one, clamped to the image on that dimension only.

    Parameters
    -------------
    prev_bbox: tuple of size (4,)
        bounding box positions in order x1, x2, y1, y2
    img_yx: tuple of size (2,)
        image size for y and x dimensions
    div: int
        bounding box side lengths are padded up to a multiple of this value
    extra: int
        number of additional ``div``-sized margins added to each dimension

    Returns
    --------------
    bbox: tuple of size (4,)
        bounding box positions in order x1, x2, y1, y2
    """

    def _margins(length):
        """Return (lower, upper) padding for a side of the given length."""
        remainder = int(div * np.ceil(length / div) - length)
        base = extra * div // 2
        lower = base + remainder // 2
        upper = base + remainder - remainder // 2
        return lower, upper

    x1, x2, y1, y2 = np.round(prev_bbox)

    # Pad both dimensions to the next multiple of div plus the extra margin.
    xpad1, xpad2 = _margins(x2 - x1)
    ypad1, ypad2 = _margins(y2 - y1)
    x1, x2 = x1 - xpad1, x2 + xpad2
    y1, y2 = y1 - ypad1, y2 + ypad2

    img_y, img_x = img_yx[0], img_yx[1]
    xdim = min(x2 - x1, img_x)
    ydim = min(y2 - y1, img_y)

    # Widen the shorter dimension, keeping it inside the image.
    if xdim > ydim:
        gap = xdim - ydim
        if gap % 2 != 0:
            gap += 1
        y1 = max(0, y1 - gap // 2)
        y2 = min(y2 + gap // 2, img_y)
    else:
        gap = ydim - xdim
        if gap % 2 != 0:
            gap += 1
        x1 = max(0, x1 - gap // 2)
        x2 = min(x2 + gap // 2, img_x)

    return (x1, x2, y1, y2)
def augment_data(
    image,
    keypoints,
    scale=False,
    scale_range=0.5,
    rotation=False,
    rotation_range=10,
    flip=True,
    contrast_adjust=True,
):
    """
    Augments data by randomly scaling, rotating, flipping, and adjusting contrast.

    The steps run in a fixed order (scale, rotation, flip, contrast) and each
    draws from the global numpy random state, so results depend on the
    current ``np.random`` seed.

    Parameters
    ----------
    image: ND-array
        image of size nchan x Ly x Lx
    keypoints: ND-array
        keypoints of size nkeypoints x 2
    scale: bool
        whether to apply the scaling step; the step multiplies both the image
        values and the keypoints by the same random factor
    scale_range: float
        width of the scaling factor interval; clamped to [0, 2], giving a
        factor in [1 - scale_range / 2, 1 + scale_range / 2)
    rotation: bool
        whether to rotate the image
        NOTE(review): this path calls ``rotate_points``, which is not defined
        in the visible source -- confirm it exists before enabling rotation.
    rotation_range: float
        width of the rotation angle interval; the angle is drawn from
        [-rotation_range / 2, rotation_range / 2)
    flip: bool
        whether to flip the image (applied with probability 0.5)
    contrast_adjust: bool
        whether to adjust contrast of image (applied with probability 0.5)

    Returns
    -------
    image: ND-array
        augmented image
    keypoints: ND-array
        augmented keypoints of size nkeypoints x 2
    """
    if scale:
        # Clamp the range, then draw a factor centred on 1.
        scale_range = max(0, min(2, float(scale_range)))
        scale_factor = (np.random.rand() - 0.5) * scale_range + 1
        # Singleton axes are dropped from the image before it is multiplied.
        image = image.squeeze() * scale_factor
        keypoints = keypoints * scale_factor
    if rotation:
        # Angle is centred on zero and spans rotation_range.
        theta = np.random.rand() * rotation_range - rotation_range / 2
        print("rotating by {}".format(theta))
        # reshape=True lets scipy change the output shape to fit the rotated image.
        image = ndimage.rotate(image, theta, axes=(-2, -1), reshape=True)
        keypoints = rotate_points(keypoints, theta)  # TODO: Add rotation function
    if flip and np.random.rand() > 0.5:
        # Mirror column 0 of the keypoints about the hard-coded value 256.
        # NOTE(review): presumably assumes a 256-wide image -- confirm.
        keypoints[:, 0] = 256 - keypoints[:, 0]
        # Rotate by 180 degrees in the plane of the last and first axes.
        image = ndimage.rotate(image, 180, axes=(-1, 0), reshape=True)
    if contrast_adjust and np.random.rand() > 0.5:
        image = pose_utils.randomly_adjust_contrast(image)
    return image, keypoints