"""
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Atika Syeda.
"""
import os
import pickle
import time
from io import StringIO
import h5py
import numpy as np
import torch
from tqdm import tqdm
from facemap import utils
from . import datasets, facemap_network, model_loader
from . import pose_helper_functions as pose_utils
from . import transforms
try:
from . import model_training
except:
train = False
"""
Base class for generating pose estimates.
Contains functions that can be used through CLI or GUI
Currently supports single video processing, whereas multi-view videos recorded simultaneously are processed sequentially.
"""
[docs]class Pose:
"""Pose estimation for single video processing
Parameters
----------
filenames: 2D-list
List of filenames to be processed.
bbox: list
List of bounding boxes used for cropping, one per video, each given as four crop bounds [x1, x2, y1, y2]. If not set, the entire frame of each video is used.
bbox_set: bool
Flag to indicate whether the bounding box has been set. Default is False.
resize: bool
Flag to indicate whether the video needs to be resized
add_padding: bool
Flag to indicate whether the video needs to be padded. Default is False.
gui: object
GUI object.
GUIobject: object
GUI mainwindow object.
net: object
PyTorch model object.
model_name: str
Name of the model to be used for pose estimation. Default is None which uses the pre-trained model.
"""
def __init__(
    self,
    filenames=None,
    bbox=None,
    bbox_set=False,
    resize=False,
    add_padding=False,
    gui=None,
    GUIobject=None,
    net=None,
    model_name=None,
):
    """Collect video details and inference settings for pose estimation.

    When ``gui`` is given, the filenames, batch size and torch device are
    read from it. Otherwise ``filenames`` is used as passed, the batch size
    is 1 and the device is CUDA when available, else CPU.

    Parameters
    ----------
    filenames: 2D-list
        Videos to process; ignored when ``gui`` is given.
    bbox: list, optional
        Per-video bounding boxes. ``None`` (default) starts from a fresh
        empty list that ``pose_prediction_setup`` fills in.
    bbox_set: bool
        Whether ``bbox`` already holds the boxes to use.
    resize: bool
        Whether frames need resizing before inference.
    add_padding: bool
        Whether frames need padding before inference.
    gui: object, optional
        GUI object providing ``filenames``, ``batch_size_spinbox`` and ``device``.
    GUIobject: object, optional
        GUI mainwindow object forwarded to the message helpers.
    net: object, optional
        Network to use; replaced when ``load_model`` is called.
    model_name: str, optional
        Model to load; ``None`` lets ``set_model`` pick one.
    """
    self.gui = gui
    self.GUIobject = GUIobject
    if gui is not None:
        self.filenames = gui.filenames
        self.batch_size = gui.batch_size_spinbox.value()
        self.device = gui.device
    else:
        self.filenames = filenames
        self.batch_size = 1
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.cumframes, self.Ly, self.Lx, self.containers = utils.get_frame_details(
        self.filenames
    )
    self.nframes = self.cumframes[-1]
    self.pose_labels = None
    # pose_prediction_setup() appends to self.bbox in place. A literal `[]`
    # default would be one list shared by every instance, so boxes from an
    # earlier Pose object would leak into later ones. An explicitly passed
    # list is kept as-is so the caller still sees the appended boxes.
    self.bbox = [] if bbox is None else bbox
    self.bbox_set = bbox_set
    self.resize = resize
    self.add_padding = add_padding
    self.net = net
    self.model_name = model_name
    self.bodyparts = [
        "eye(back)",
        "eye(bottom)",
        "eye(front)",
        "eye(top)",
        "lowerlip",
        "mouth",
        "nose(bottom)",
        "nose(r)",
        "nose(tip)",
        "nose(top)",
        "nosebridge",
        "paw",
        "whisker(I)",  # "whisker(c1)",
        "whisker(III)",  # "whisker(d2)",
        "whisker(II)",  # "whisker(d1)",
    ]
def pose_prediction_setup(self):
    """Load the network and, if no bounding box was set, default to full frames.

    For every video a box covering the whole frame is appended to
    ``self.bbox``. Padding is switched on for non-square frames and resizing
    for frames that are not 256x256. A status message is emitted per video.
    """
    self.load_model()
    if self.bbox_set:
        return
    for i, frame_ly in enumerate(self.Ly):
        frame_lx = self.Lx[i]
        # Whole-frame box: starts at 0 and spans Ly, then Lx.
        self.bbox.append([0, frame_ly, 0, frame_lx])
        span_a = frame_ly - 0
        span_b = frame_lx - 0
        if span_a != span_b:
            # Non-square view: pad before inference.
            self.add_padding = True
        if not (span_a == 256 and span_b == 256):
            # Anything other than 256x256 has to be resized.
            self.resize = True
        prompt = "No bbox set. Using entire frame view: {} and resize={}".format(
            self.bbox, self.resize
        )
        utils.update_mainwindow_message(
            MainWindow=self.gui,
            GUIobject=self.GUIobject,
            prompt=prompt,
            hide_progress=True,
        )
    self.bbox_set = True
def set_model(self, model_selected=None):
    """Set model to use for pose estimation and load its weights into ``self.net``.

    Args:
        model_selected (str, optional): Name of an installed model state
            (file name without extension) or path of a trained model weights
            file. ``None`` uses "Base model" without a GUI, or the GUI's
            model combobox selection otherwise. "Base model" resolves to the
            first installed state whose name contains "facemap_model_state".

    Raises:
        ValueError: If no model state file could be resolved.
    """
    if model_selected is None:
        if self.gui is None:
            model_selected = "Base model"
        else:
            model_selected = self.gui.pose_model_combobox.currentText()
    # Get all installed model states
    model_paths = model_loader.get_model_states_paths()
    if len(model_paths) == 0:  # No models found, set default model
        self.model_name = model_loader.get_basemodel_state_path()
    wants_base = model_selected == "Base model"
    for model_path in model_paths:  # Find selected model and update model name
        model = os.path.splitext(os.path.basename(model_path))[0]
        if model == model_selected or (wants_base and "facemap_model_state" in model):
            print("Setting model name to:", model)
            self.model_name = model_path
            break
    else:
        # Nothing matched by name. A path to a weights file can never equal a
        # bare model name, so honour it directly instead of silently keeping
        # whatever self.model_name held before (possibly None).
        if isinstance(model_selected, (str, os.PathLike)) and os.path.isfile(
            model_selected
        ):
            self.model_name = model_selected
    if self.model_name is None:
        raise ValueError(
            "Could not find a model state for: {}".format(model_selected)
        )
    print("Loading model state from:", self.model_name)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only load model files from trusted sources.
    self.net.load_state_dict(
        torch.load(self.model_name, map_location=self.device, weights_only=False)
    )
    self.net.to(self.device)
def load_model(self):
    """
    Build the keypoint network and load its weights.

    The channel configuration is read from the model parameters file, the
    network gets one output per entry in ``self.bodyparts``, and
    ``set_model`` then loads the state for ``self.model_name`` (or picks a
    default when that is None).
    """
    params_path = model_loader.get_model_params_path()
    print("{} set as device".format(self.device))
    print("Loading model parameters from:", params_path)
    utils.update_mainwindow_message(
        MainWindow=self.gui,
        GUIobject=self.GUIobject,
        prompt="Loading model... {}".format(params_path),
    )
    checkpoint = torch.load(
        params_path, map_location=self.device, weights_only=False
    )
    # Bodyparts come from this instance, not from the parameters file.
    self.net = facemap_network.FMnet(
        img_ch=1,
        output_ch=len(self.bodyparts),  # number of outputs from the model
        labels_id=self.bodyparts,
        channels=checkpoint["params"]["channels"],
        kernel=3,
        device=self.device,
    )
    # set_model(None) is the same call as set_model(), so no branching is needed.
    self.set_model(self.model_name)
def train(
    self,
    image_data,
    keypoints_data,
    num_epochs,
    batch_size,
    learning_rate,
    weight_decay,
    bbox,
):
    """
    Train the model
    Parameters
    ----------
    image_data: ND-array
        Array of images of shape (nframes, Ly, Lx)
    keypoints_data: ND-array
        Array of keypoints of shape (nframes, nkeypoints, 2)
    num_epochs: int
        Number of epochs for training
    batch_size: int
        Batch size for training
    learning_rate: float
        Learning rate for training
    weight_decay: float
        Weight decay for training
    bbox: list
        Bounding box forwarded to ``datasets.FacemapDataset``
    Returns
    -------
    model: torch.nn.Module
        Trained/finetuned model
    Raises
    ------
    ImportError
        If the ``model_training`` module cannot be imported.
    """
    # The module-level import of model_training is wrapped in a bare except,
    # so the name may be unbound here. Importing it again surfaces the real
    # import error instead of a NameError.
    from . import model_training

    # Create a dataset object for training
    dataset = datasets.FacemapDataset(
        image_data=image_data,
        keypoints_data=keypoints_data,
        bbox=bbox,
    )
    # Create a dataloader object for training
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )
    # Use preprocessed data to train the model
    self.net = model_training.train(
        dataloader,
        self.net,
        num_epochs,
        learning_rate,
        weight_decay,
        gui=self.gui,
        gui_obj=self.GUIobject,
    )
    print("Model training complete!")
    return self.net
def predict_landmarks(self, video_id, frame_ind=None):
    """
    Predict keypoints/landmarks for frames of a video, in batches of ``self.batch_size``.

    If frame_ind is specified, only those frames are predicted; otherwise
    ``self.cumframes[-1]`` frames are processed. Nothing is written to disk
    here; use ``save_data_to_hdf5`` on the returned predictions.

    Parameters
    ----------
    video_id: int
        Index of video in filenames list to be used for prediction.
    frame_ind: list
        List of frame indices for keypoints/landmarks prediction.

    Returns
    -------
    pred_data: torch.Tensor
        Tensor of shape (total_frames, nbodyparts, 3) holding x, y and likelihood.
    metadata: dict
        Batch size, image size, bbox, frame count, bodyparts and inference speed.
    """
    if frame_ind is None:
        total_frames = self.cumframes[-1]
        frame_ind = np.arange(total_frames)
    else:
        total_frames = len(frame_ind)
    # Create array for storing predictions
    pred_data = torch.zeros(total_frames, len(self.bodyparts), 3)
    self.net.eval()
    # Get bounding box for the video
    y1, _, x1, _ = self.bbox[video_id]
    inference_time = 0
    print("Using params:")
    print("\tbbox:", self.bbox[video_id])
    print("\tbatch size:", self.batch_size)
    print("\tresize:", self.resize)
    print("\tpadding:", self.add_padding)
    # GUI progress is refreshed when the next batch end lands on a multiple
    # of 5% of the total frames; the step is loop-invariant.
    percent_frames = int(np.floor(total_frames * 0.05))
    # FIXME: Plotting keypoints after batch processing is not working properly
    progress_output = StringIO()
    with tqdm(
        total=total_frames, unit="frame", unit_scale=True, file=progress_output
    ) as pbar:
        start = 0
        end = min(self.batch_size, total_frames)
        while start < total_frames:  # for analyzing entire video
            cframes = np.array(frame_ind[start:end])
            imall = utils.get_batch_frames(
                cframes,
                total_frames,
                self.cumframes,
                self.containers,
                video_idx=video_id,
                grayscale=True,
            )
            # Inference time includes: pre-processing, inference, post-processing
            t0 = time.time()
            # Pre-process images
            imall, postpad_shape, pads = transforms.preprocess_img(
                imall,
                self.bbox[video_id],
                self.add_padding,
                self.resize,
                device=self.net.device,
            )
            # Run inference
            xlabels, ylabels, likelihood = pose_utils.predict(
                self.net, imall, smooth=False
            )
            xlabels, ylabels = transforms.adjust_keypoints(
                xlabels,
                ylabels,
                crop_xy=(x1, y1),
                padding=pads,
                current_size=(256, 256),
                desired_size=postpad_shape,
            )
            # Add predictions to array
            pred_data[start:end, :, 0] = xlabels
            pred_data[start:end, :, 1] = ylabels
            pred_data[start:end, :, 2] = likelihood
            # Update progress bar and inference time
            inference_time += time.time() - t0
            # Advance by the frames actually processed so the last, possibly
            # partial, batch does not push the bar past its total.
            pbar.update(end - start)
            start = end
            end = min(end + self.batch_size, total_frames)
            if percent_frames != 0 and end % percent_frames == 0:
                utils.update_mainwindow_progressbar(
                    MainWindow=self.gui,
                    GUIobject=self.GUIobject,
                    s=progress_output,
                    prompt="Pose prediction progress:",
                )
    # No frames processed (or a zero-length timing) would divide by zero.
    inference_speed = total_frames / inference_time if inference_time > 0 else 0.0
    print("Inference speed:", inference_speed, "fps")
    metadata = {
        "batch_size": self.batch_size,
        "image_size": (self.Ly, self.Lx),
        "bbox": self.bbox[video_id],
        "total_frames": total_frames,
        "bodyparts": self.bodyparts,
        "inference_speed": inference_speed,
    }
    return pred_data, metadata
def save_model(self, model_filepath):
    """
    Write the network weights to disk and copy them to the models directory.

    Parameters
    ----------
    model_filepath: str
        Path to save model weights

    Returns
    -------
    str
        The same ``model_filepath`` that was passed in.
    """
    weights = self.net.state_dict()
    torch.save(weights, model_filepath)
    model_loader.copy_to_models_dir(model_filepath)
    return model_filepath
def save_data_to_hdf5(self, data, video_id, selected_frame_ind=None):
    """Write keypoint predictions to ``<videoname>_FacemapPose.h5``.

    The file is laid out as scorer ("Facemap") -> bodypart -> coordinate
    ("x", "y", "likelihood"). It is written to the GUI's save path when a
    GUI is attached, otherwise next to the video.

    Args:
        data (3D-array): Data to save (nframes x nbodyparts x 3)
        video_id (int): Index of the video in ``self.filenames[0]``
        selected_frame_ind (list): Indices of selected frames; all
            ``self.cumframes[-1]`` frames when None

    Returns:
        str: Path of the HDF5 file that was written.
    """
    if selected_frame_ind is None:
        indices = np.arange(self.cumframes[-1])
    else:
        indices = selected_frame_ind
    coord_names = ("x", "y", "likelihood")
    keypoints = {
        bodypart: {
            coord: data[:, part_idx, coord_idx][indices]
            for coord_idx, coord in enumerate(coord_names)
        }
        for part_idx, bodypart in enumerate(self.bodyparts)
    }
    data_dict = {"Facemap": keypoints}
    video_dir, video_file = os.path.split(self.filenames[0][video_id])
    videoname = os.path.splitext(video_file)[0]
    # With a GUI the output goes to its save path, otherwise beside the video.
    basename = video_dir if self.gui is None else self.gui.save_path
    hdf5_filepath = os.path.join(basename, videoname + "_FacemapPose.h5")
    with h5py.File(hdf5_filepath, "w") as f:
        self.save_dict_to_hdf5(f, "", data_dict)
    return hdf5_filepath
[docs] def save_dict_to_hdf5(self, h5file, path, data_dict):
"""
Saves dictionary to an HDF5 file. Adapted from https://github.com/talmolab/sleap/blob/391bc0421fe3820ddd6b5d07e31311d60b129fe3/sleap/util.py#L116
Calls itself recursively if items in dictionary are not
`np.ndarray`, `np.int64`, `np.float64`, `str`, or bytes.
Objects must be iterable.
Args:
h5file: The HDF5 filename object to save the data to.
Assume it is open.
path: The path to group save the dict under.
data_dict: The dict containing data to save.
Raises:
ValueError: If type for item in dict cannot be saved.
Returns:
None
"""
for key, item in list(data_dict.items()):
if item is None:
h5file[path + key] = ""
elif isinstance(item, bool):
h5file[path + key] = int(item)
elif isinstance(item, list):
items_encoded = []
for it in item:
if isinstance(it, str):
items_encoded.append(it.encode("utf8"))
else:
items_encoded.append(it)
h5file[path + key] = np.asarray(items_encoded)
elif isinstance(item, (str)):
h5file[path + key] = item.encode("utf8")
elif isinstance(
item, (np.ndarray, np.int64, np.float64, str, bytes, float)
):
h5file[path + key] = item
elif isinstance(item, dict):
self.save_dict_to_hdf5(h5file, path + key + "/", item)
elif isinstance(item, int):
h5file[path + key] = item
else:
raise ValueError("Cannot save %s type" % type(item))
def run(self):
    """Run pose estimation on every video in ``self.filenames[0]``.

    For each video the predictions are written to an HDF5 file and the
    metadata to a ``_metadata.pkl`` file next to it. With a GUI attached, the
    HDF5 path is also appended to ``gui.poseFilepath``. The elapsed time is
    reported at the end.
    """

    def notify(message):
        # Every status update in this method uses the same keyword set.
        utils.update_mainwindow_message(
            MainWindow=self.gui,
            GUIobject=self.GUIobject,
            prompt=message,
            hide_progress=True,
        )

    start_time = time.time()
    self.pose_prediction_setup()
    for video_id, video_path in enumerate(self.filenames[0]):
        notify("Processing video: {}".format(video_path))
        print("\nProcessing video: {}".format(video_path))
        pred_data, metadata = self.predict_landmarks(video_id)
        # Save the data using h5py
        savepath = self.save_data_to_hdf5(pred_data.cpu().numpy(), video_id)
        notify("Saved pose prediction outputs to: {}".format(savepath))
        print("Saved keypoints:", savepath)
        # Save metadata to a pickle file
        metadata_file = os.path.splitext(savepath)[0] + "_metadata.pkl"
        with open(metadata_file, "wb") as f:
            pickle.dump(metadata, f, pickle.HIGHEST_PROTOCOL)
        print("Saved metadata:", metadata_file)
        if self.gui is not None:
            self.gui.poseFilepath.append(savepath)
    elapsed = time.time() - start_time
    notify("Pose estimation time elapsed: {} seconds".format(elapsed))
def run_subset(self, subset_ind=None):
    """Predict keypoints for a subset of frames of the first video.

    Parameters
    ----------
    subset_ind: array-like, optional
        Frame indices to process. When None, a tenth of ``self.nframes`` is
        drawn at random without replacement.

    Returns
    -------
    tuple
        ``(pred_data, subset_ind, self.bbox)``
    """
    print("Using {} for pose estimation".format(self.model_name))
    if subset_ind is None:
        # Select a random subset of frames (left unsorted)
        n_pick = int(self.nframes / 10)
        subset_ind = np.random.choice(self.nframes, n_pick, replace=False)
    message_kwargs = dict(
        MainWindow=self.gui,
        GUIobject=self.GUIobject,
        hide_progress=True,
    )
    utils.update_mainwindow_message(
        prompt="Processing video: {}".format(self.filenames[0][0]),
        **message_kwargs,
    )
    pred_data, _ = self.predict_landmarks(0, frame_ind=subset_ind)
    utils.update_mainwindow_message(
        prompt="Finished processing subset of video",
        **message_kwargs,
    )
    return pred_data, subset_ind, self.bbox
def save_pose_prediction(self, dataFrame, video_id):
# Save prediction to .h5 file
if self.gui is not None:
basename = self.gui.save_path
_, filename = os.path.split(self.filenames[0][video_id])
videoname, _ = os.path.splitext(filename)
else:
basename, filename = os.path.split(self.filenames[0][video_id])
videoname, _ = os.path.splitext(filename)
poseFilepath = os.path.join(basename, videoname + "_FacemapPose.h5")
dataFrame.to_hdf(poseFilepath, "df_with_missing", mode="w")
return poseFilepath
def plot_pose_estimates(self):
    """Flag the pose as loaded in the GUI, load the keypoints, tick the keypoints checkbox and call ``gui.start()``."""
    window = self.gui
    # Plot labels
    window.is_pose_loaded = True
    window.load_keypoints()
    window.keypoints_checkbox.setChecked(True)
    window.start()