Source code for dgs.models.dataset.keypoint_rcnn

"""
Use :func:`.keypointrcnn_resnet50_fpn` to predict the key points and bounding boxes of each image.

References:
    https://pytorch.org/vision/0.17/models/generated/torchvision.models.detection.keypointrcnn_resnet50_fpn.html
"""

import os
from abc import ABC
from typing import Any, Union

import torch as t
from imagesize import imagesize
from torch import nn
from torchvision import tv_tensors as tvte
from torchvision.io import VideoReader
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.ops import box_iou
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import to_dtype
from tqdm import tqdm

from dgs.models.dataset.dataset import BaseDataset, ImageDataset, VideoDataset
from dgs.utils.config import DEF_VAL
from dgs.utils.constants import IMAGE_FORMATS, VIDEO_FORMATS
from dgs.utils.files import is_dir, is_file, read_json
from dgs.utils.image import create_mask_from_polygons, CustomToAspect, load_image
from dgs.utils.state import EMPTY_STATE, State
from dgs.utils.types import Config, FilePath, FilePaths, Image, Images, ImgShape, NodePath, Validations
from dgs.utils.utils import extract_crops_from_images

rcnn_validations: Validations = {
    "data_path": [("any", [str, ("all", [list, ("forall", str)])])],
    # optional
    "score_threshold": ["optional", float, ("within", (0.0, 1.0))],
    "iou_threshold": ["optional", float, ("within", (0.0, 1.0))],
    "force_reshape": ["optional", bool],
    "image_mode": ["optional", str, ("in", CustomToAspect.modes)],
    "image_size": ["optional", tuple, ("len", 2), ("forall", [int, ("gt", 0)])],
    "crop_mode": ["optional", str, ("in", CustomToAspect.modes)],
    "crop_size": ["optional", tuple, ("len", 2), ("forall", [int, ("gt", 0)])],
    "bbox_min_size": ["optional", float, ("gte", 1.0)],
    "mask_path": ["optional", str],
    "weights": ["optional", ("instance", KeypointRCNN_ResNet50_FPN_Weights)],
}
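
# For illustration only: a parameter node that would satisfy ``rcnn_validations``.
# The concrete values below are hypothetical; only ``data_path`` is required,
# every other key falls back to its ``DEF_VAL`` default.
_EXAMPLE_RCNN_PARAMS: Config = {
    "data_path": "./images/",  # a directory, a single image file, or a list of image filepaths
    "score_threshold": 0.3,  # float between 0.0 and 1.0
    "iou_threshold": 0.6,  # float between 0.0 and 1.0
    "crop_size": (256, 192),  # tuple of two ints > 0
    "weights": KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
}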


class KeypointRCNNBackbone(BaseDataset, nn.Module, ABC):
    """Base class for the torchvision Keypoint RCNN backbone model.
    This class sets up the RCNN model and validates and sets the module's basic parameters.

    Params
    ------

    data_path (FilePath):
        A single path or a list of paths.
        The path is either a directory, a single image file, or a list of image filepaths.

    Optional Params
    ---------------

    score_threshold (float, optional):
        Detections with a score lower than the threshold will be ignored.
        Default ``DEF_VAL.dataset.kprcnn.score_threshold``.
    iou_threshold (float, optional):
        Bounding boxes with an IoU above this threshold will be ignored.
        Default ``DEF_VAL.dataset.kprcnn.iou_threshold``.
    force_reshape (bool, optional):
        Whether to force reshape all the input images.
        Change the size and mode via the ``image_mode`` and ``image_size`` parameters, iff ``force_reshape`` is `True`.
        Default ``DEF_VAL.images.force_reshape``.
    image_size (:obj:`ImgShape`, optional):
        The size the loaded image should have, iff ``force_reshape`` is `True`.
        Default ``DEF_VAL.images.image_size``.
    image_mode (str, optional):
        The mode to use when loading the image, iff ``force_reshape`` is `True`.
        Default ``DEF_VAL.images.image_mode``.
    crop_size (:obj:`ImgShape`, optional):
        The size the image crop should have.
        Default ``DEF_VAL.images.crop_size``.
    crop_mode (str, optional):
        The mode to use when cropping the image.
        Default ``DEF_VAL.images.crop_mode``.
    bbox_min_size (float, optional):
        The minimum side length a bounding box should have, in pixels.
        Smaller detections will be discarded.
        Works in addition to the ``score_threshold`` parameter.
        If you do not want to discard smaller bounding boxes, make sure to set ``bbox_min_size`` to ``1.0``.
        The size of the bounding boxes is relative to the original image.
        Default ``DEF_VAL.images.bbox_min_size``.
    weights (KeypointRCNN_ResNet50_FPN_Weights, optional):
        The weights to load for the model.
        Default ``KeypointRCNN_ResNet50_FPN_Weights.COCO_V1``.
    """

    model: nn.Module

    def __init__(self, config: Config, path: NodePath) -> None:
        BaseDataset.__init__(self, config=config, path=path)
        nn.Module.__init__(self)

        self.validate_params(rcnn_validations)

        self.score_threshold: float = self.params.get(
            "score_threshold", DEF_VAL["dataset"]["kprcnn"]["score_threshold"]
        )

        self.logger.debug("Loading Keypoint-RCNN Model")
        weights = self.params.get("weights", KeypointRCNN_ResNet50_FPN_Weights.COCO_V1)
        model = keypointrcnn_resnet50_fpn(weights=weights, progress=True)
        self.register_module("model", self.configure_torch_module(module=model, train=False))

        self.img_id: t.Tensor = t.tensor(1, dtype=t.long, device=self.device)

        bbox_min_size: float = float(self.params.get("bbox_min_size", DEF_VAL["images"]["bbox_min_size"]))
        self.bbox_cleaner = v2.Compose(
            [
                v2.ClampBoundingBoxes(),
                v2.SanitizeBoundingBoxes(
                    min_size=bbox_min_size,
                    labels_getter=lambda y: (y["keypoints"], y["scores"], y["keypoints_scores"], y["labels"]),
                ),
            ]
        )

        self.iou_threshold: float = self.params.get("iou_threshold", DEF_VAL["dataset"]["kprcnn"]["iou_threshold"])

        # image loading params
        self.force_reshape: bool = self.params.get("force_reshape", DEF_VAL["images"]["force_reshape"])
        self.image_size: ImgShape = self.params.get("image_size", DEF_VAL["images"]["image_size"])
        self.image_mode: str = self.params.get("image_mode", DEF_VAL["images"]["image_mode"])

    @t.no_grad()
    def images_to_states(self, images: Images) -> list[State]:
        """Given a list of images, use the key-point-RCNN model to predict key points and bounding boxes,
        then create a :class:`State` containing the available information.

        Notes:
            Does not add the original image to the new State, to reduce memory / GPU usage.
            With the filepath given in the state, the image can be reloaded if required.
        """
        # make sure all images are float
        images = [tvte.Image(to_dtype(img, dtype=t.float32, scale=True)) for img in images]

        # predicts a list of {boxes: XYXY[N], labels: Int64[N], scores: [N], keypoints: Float[N,J,(x|y|vis)]}
        # every image in images can have multiple predictions
        outputs: list[dict[str, t.Tensor]] = self.model.forward(images)

        states: list[State] = []

        canvas_size = (max(i.shape[-2] for i in images), max(i.shape[-1] for i in images))

        for output, image in zip(outputs, images):  # get the output for every image independently
            # bbox given in XYXY format
            output["boxes"] = tvte.BoundingBoxes(output["boxes"], format="XYXY", canvas_size=canvas_size)

            # first sanitize and clamp the bboxes, while cleaning up the respective other data as well
            sanitized = self.bbox_cleaner(output)

            scores = sanitized["scores"]  # score of each instance

            # Get the sanitized bboxes and compute the IoU.
            bbox = tvte.BoundingBoxes(sanitized["boxes"], format="XYXY", canvas_size=canvas_size)
            iou = box_iou(bbox, bbox).tril(diagonal=-1)  # lower tri excluding diag

            # Filter the bboxes using an IoU threshold.
            # Additionally, use only the indices where the score ('certainty') is bigger than the given
            # score_threshold. Because the output of KeypointRCNN is sorted by score,
            # using the lower triangular matrix will remove the lower score.
            indices = t.logical_and(
                t.logical_not(t.any(iou > self.iou_threshold, dim=1)),  # iou smaller than
                scores > self.score_threshold,  # score > thresh
            )

            # get final bbox and B after double sanitizing
            bbox = tvte.BoundingBoxes(sanitized["boxes"][indices], format="XYXY", canvas_size=canvas_size)
            B: int = int(t.count_nonzero(indices).item())

            data = {
                "validate": False,
                "image_id": t.ones(max(B, 1), device=self.device, dtype=t.long) * self.img_id,
                "frame_id": t.ones(max(B, 1), device=self.device, dtype=t.long) * self.img_id,
            }
            self.img_id += t.tensor(1, dtype=t.long, device=self.device)  # increment counter

            # skip if there aren't any detections
            if B == 0:
                es = EMPTY_STATE.copy()
                es.update(data)
                states.append(es)
                continue

            # keypoints in [x,y,v] format -> kp, vis
            kps, vis = (
                sanitized["keypoints"][indices]
                .to(device=self.device, dtype=self.precision)
                .reshape((-1, 17, 3))
                .split([2, 1], dim=-1)
            )
            assert kps.shape[-2:] == (17, 2), kps.shape[-2:]

            crops, loc_kps = extract_crops_from_images(
                imgs=[tvte.Image(image.unsqueeze(0)) for _ in range(B)],
                bboxes=bbox,
                kps=kps,
                crop_size=self.params.get("crop_size", DEF_VAL["images"]["crop_size"]),
                crop_mode=self.params.get("crop_mode", DEF_VAL["images"]["crop_mode"]),
            )
            assert loc_kps is not None

            data = dict(
                data,
                **{
                    "skeleton_name": tuple("coco" for _ in range(B)),
                    "scores": sanitized["keypoints_scores"][indices, :],  # B x 17
                    "score": scores[indices],
                    "bbox": bbox,
                    "image_crop": crops,
                    "keypoints": kps,
                    "keypoints_local": loc_kps,
                    "joint_weight": vis,
                    "person_id": t.ones(B, device=self.device, dtype=t.long) * -1,  # set as -1
                },
            )

            states.append(State(**data))
        return states

    def terminate(self) -> None:  # pragma: no cover
        if hasattr(self, "model"):
            del self.model

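
# Two illustrative sketches of the filtering steps used above. They are
# standalone helpers for documentation purposes only and are not part of the
# dgs API; names, shapes, and threshold values are made up.


def _example_bbox_cleaning() -> dict[str, t.Tensor]:
    """A minimal sketch of what ``bbox_cleaner`` does: ``SanitizeBoundingBoxes``
    removes degenerate or too-small boxes and, via ``labels_getter``, drops the
    matching entries of every tensor tied to the boxes."""
    cleaner = v2.Compose(
        [
            v2.ClampBoundingBoxes(),
            v2.SanitizeBoundingBoxes(min_size=50.0, labels_getter=lambda y: y["scores"]),
        ]
    )
    sample = {
        "boxes": tvte.BoundingBoxes(
            t.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 10.0, 10.0]]),  # second box is smaller than min_size
            format="XYXY",
            canvas_size=(200, 200),
        ),
        "scores": t.tensor([0.9, 0.8]),
    }
    return cleaner(sample)  # only the first box and its score survive


def _example_iou_suppression() -> t.Tensor:
    """A minimal sketch of the IoU filter in ``images_to_states``: because the RCNN
    output is sorted by descending score, keeping only the strictly lower triangle of
    the pairwise IoU matrix compares every box against higher-scoring boxes only, so
    the suppression always keeps the best box of each overlapping cluster."""
    boxes = tvte.BoundingBoxes(
        t.tensor(
            [
                [0.0, 0.0, 100.0, 100.0],  # score 0.9 -> kept
                [5.0, 5.0, 105.0, 105.0],  # score 0.8, high IoU with the first box -> dropped
                [200.0, 200.0, 300.0, 300.0],  # score 0.7, no overlap -> kept
            ]
        ),
        format="XYXY",
        canvas_size=(400, 400),
    )
    iou = box_iou(boxes, boxes).tril(diagonal=-1)  # lower triangle, diagonal excluded
    return t.logical_not(t.any(iou > 0.5, dim=1))  # tensor([True, False, True])
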
# pylint: disable=too-many-ancestors
class KeypointRCNNImageBackbone(KeypointRCNNBackbone, ImageDataset):
    """Predicts 17 key points (like COCO) for every detection in each image.

    Optional Params
    ---------------

    mask_path (str, optional):
        The path to a PT21 json file containing the ``ignore_regions``.
        Note that currently only PT21 ignore regions are supported.

    References:
        https://pytorch.org/vision/0.17/models/generated/torchvision.models.detection.keypointrcnn_resnet50_fpn.html
    """

    __doc__ += KeypointRCNNBackbone.__doc__

    data: list[FilePath]
    masks: list[Union[tvte.Mask, None]]

    def __init__(self, config: Config, path: NodePath) -> None:
        KeypointRCNNBackbone.__init__(self, config=config, path=path)
        ImageDataset.__init__(self, config=config, path=path)

        # load data - data_path is either a directory, a single image file, or a list of image filepaths
        self.data = []
        data_path: Any = self.params["data_path"]
        if isinstance(data_path, list):
            self.data = sorted(data_path)
        elif isinstance(data_path, str):
            data_path: FilePath = self.get_path_in_dataset(data_path)
            if is_file(data_path):
                # single image
                if data_path.lower().endswith(IMAGE_FORMATS):
                    self.data = [data_path]
                # video file
                elif data_path.lower().endswith(VIDEO_FORMATS):
                    raise TypeError(f"Got a video file, but this is an image dataset. File: {data_path}")
                else:
                    raise NotImplementedError(f"Unknown file type. Got '{data_path}'")
            elif is_dir(data_path):
                # directory of images
                self.data = [
                    os.path.normpath(os.path.join(data_path, child_path))
                    for child_path in tqdm(sorted(os.listdir(data_path)), desc="Loading images", leave=False)
                    if child_path.lower().endswith(IMAGE_FORMATS)
                ]
            else:
                raise NotImplementedError(f"String is neither file nor directory. Got '{data_path}'.")
        else:
            raise NotImplementedError(
                f"Unknown path object, expected filepath, dirpath, or list of filepaths. Got {type(data_path)}"
            )

        # FIXME: what about other masking types?
        if "mask_path" in self.params and self.force_reshape:
            self.masks = [
                self.transform_resize_image()(
                    {
                        "image": tvte.Image(
                            create_mask_from_polygons(
                                img_size=imagesize.get(self.get_path_in_dataset(img["file_name"]))[::-1],
                                polygons_x=img["ignore_regions_x"],
                                polygons_y=img["ignore_regions_y"],
                                device=self.device,
                            ).unsqueeze(0)
                        ),
                        "box": tvte.BoundingBoxes(
                            t.ones((1, 4), dtype=t.float32),
                            canvas_size=self.image_size,
                            format="XYWH",
                            dtype=t.float32,
                        ),
                        "keypoints": t.ones((1, 15, 2)),
                        "output_size": self.image_size,
                        "mode": self.image_mode,
                    }
                )["image"]
                .squeeze(0)
                .to(dtype=t.bool)
                for img in read_json(self.params["mask_path"])["images"]
            ]
        elif "mask_path" in self.params:
            self.masks = [
                create_mask_from_polygons(
                    img_size=imagesize.get(self.get_path_in_dataset(img["file_name"]))[::-1],
                    polygons_x=img["ignore_regions_x"],
                    polygons_y=img["ignore_regions_y"],
                    device=self.device,
                )
                for img in read_json(self.params["mask_path"])["images"]
            ]
        else:
            self.masks = [None for _ in range(len(self.data))]

    def arbitrary_to_ds(self, a: FilePath, idx: int) -> list[State]:
        """Given a filepath, predict the bounding boxes and key points of the respective image.
        Return a State containing all the available information.
        Because the filepath is known, the image is not saved in the State, to reduce the space-overhead on the GPU.

        Args:
            a: A single path to an image file.
            idx: The index of the file path within ``self.data``.

        Returns:
            A list containing one single :class:`.State` that describes zero or more detections of the given image.
        """
        img = load_image(
            filepath=a,
            force_reshape=self.force_reshape,
            output_size=self.image_size,
            mode=self.image_mode,
            device=self.device,
            dtype=t.float32,
        ).squeeze(0)

        if self.masks[idx] is not None:
            # Get the mask with the same size as the image.
            # True, where the image should be ignored.
            mask = self.masks[idx]
            assert not self.force_reshape or mask.shape == t.Size(self.image_size), (mask.shape, self.image_size)
        else:
            mask = t.zeros(img.shape[-2:], device=self.device, dtype=t.bool)

        # create the image by using the unmasked area of the image and the masked area of a black image
        masked_img = tvte.Image(img * t.bitwise_not(mask) + t.zeros_like(img) * mask, device=self.device)

        # the torch model expects a list of 3D images
        states = self.images_to_states(images=[masked_img])

        for state in states:
            state.filepath = tuple(a for _ in range(max(state.B, 1)))

        return states

    def __getitems__(self, indices: list[int]) -> list[State]:
        """Get a batch of predictions from the dataset.
        It is expected that all images have the same shape.

        Returns:
            A list containing one :class:`.State` per image / index.
            Every State describes zero or more detections of the respective image.
        """
        fps: FilePaths = tuple(self.data[idx] for idx in indices)
        masks = [self.masks[idx] for idx in indices]
        images = load_image(
            fps,
            force_reshape=self.force_reshape,
            output_size=self.image_size,
            mode=self.image_mode,
            device=self.device,
            dtype=t.float32,
        )
        assert not self.force_reshape or all(
            mask.shape == t.Size(self.image_size) for mask in masks if mask is not None
        )

        # the torch model expects a list of 3D images
        masked_images = [
            (img * t.bitwise_not(mask) + t.zeros_like(img) * mask).squeeze(0) if mask is not None else img.squeeze(0)
            for img, mask in zip(images.split(1, dim=0), masks)
        ]

        states = self.images_to_states(images=masked_images)

        for fp, state in zip(fps, states):
            state.filepath = tuple(fp for _ in range(max(state.B, 1)))

        return states

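
# A minimal usage sketch for the image backbone. Illustrative only: the "kprcnn"
# node name, the directory, and the exact set of globally required config keys
# (e.g. "device") are assumptions and may differ in a real dgs configuration.
def _example_image_backbone_usage() -> list[State]:
    cfg: Config = {
        "device": "cuda" if t.cuda.is_available() else "cpu",
        "kprcnn": {
            "data_path": "./data/sequence_01/",  # directory containing the images
            "score_threshold": 0.3,
            "iou_threshold": 0.6,
        },
    }
    ds = KeypointRCNNImageBackbone(config=cfg, path=["kprcnn"])
    # __getitems__ accepts a batch of indices and returns one State per image
    return ds.__getitems__([0, 1, 2, 3])
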
# pylint: disable=too-many-ancestors
class KeypointRCNNVideoBackbone(KeypointRCNNBackbone, VideoDataset):
    """A dataset that gets the path to a single video file and predicts the
    bounding boxes and key points of every frame of the video.

    Predicts 17 key points (like COCO).

    References:
        https://pytorch.org/vision/0.17/models/generated/torchvision.models.detection.keypointrcnn_resnet50_fpn.html
    """

    __doc__ += KeypointRCNNBackbone.__doc__

    data: VideoReader

    def __init__(self, config: Config, path: NodePath) -> None:
        KeypointRCNNBackbone.__init__(self, config=config, path=path)
        VideoDataset.__init__(self, config=config, path=path)

        # the data has already been loaded in the VideoDataset
        # the model and the thresholds have been loaded in the KeypointRCNNBackbone

    def arbitrary_to_ds(self, a: Image, idx: int) -> list[State]:
        """Given a frame of a video, return the resulting state after running the RCNN model."""
        if not isinstance(a, t.Tensor):
            raise NotImplementedError

        # the torch RCNN model expects a list of 3D images
        if self.force_reshape:
            # reshape the images iff requested to the desired size
            images = [
                self.transform_resize_image()(
                    {
                        "image": tvte.Image(a),
                        "box": tvte.BoundingBoxes(
                            t.ones((1, 4), dtype=t.float32),
                            canvas_size=self.image_size,
                            format="XYWH",
                            dtype=t.float32,
                        ),
                        "keypoints": t.ones((1, 15, 2)),
                        "output_size": self.image_size,
                        "mode": self.image_mode,
                    }
                )["image"],
            ]
        else:
            images = [a]

        states = self.images_to_states(images=images)

        for img, state in zip(images, states):
            state.image = [tvte.Image(img.unsqueeze(0)) for _ in range(state.B)]

        return states
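
# A minimal usage sketch for the video backbone. Illustrative only: the node
# name, the file path, and the required global config keys are assumptions.
def _example_video_backbone_usage() -> list[State]:
    cfg: Config = {
        "device": "cpu",
        "video": {
            "data_path": "./data/clip.mp4",  # a single video file
            "force_reshape": True,
            "image_size": (480, 640),  # (H, W) every frame gets resized to
        },
    }
    ds = KeypointRCNNVideoBackbone(config=cfg, path=["video"])
    frame = next(iter(ds.data))["data"]  # torchvision's VideoReader yields dicts with a "data" tensor
    return ds.arbitrary_to_ds(frame, idx=0)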