"""
Modules for computing the similarity between two poses, either via their key points (OKS) or via their bounding boxes (IoU).
"""
import torch as t
from torchvision.ops import box_area, box_iou
from torchvision.transforms.v2 import ConvertBoundingBoxFormat
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat
from dgs.models.similarity.similarity import SimilarityModule
from dgs.utils.config import DEF_VAL
from dgs.utils.constants import OKS_SIGMAS
from dgs.utils.state import State
from dgs.utils.types import Config, NodePath, Validations
# Parameter validations of :class:`ObjectKeypointSimilarity`.
oks_validations: Validations = dict(
    # the key-point format must be one of the formats that have known OKS sigmas
    format=[str, ("in", list(OKS_SIGMAS.keys()))],
    # optional: dimensionality of a single key point
    keypoint_dim=["optional", int, ("within", (1, 3))],
)

# Parameter validations of :class:`IntersectionOverUnion` - there is nothing to validate.
iou_validations: Validations = dict()
class ObjectKeypointSimilarity(SimilarityModule):
    """Compute the object key-point similarity (OKS) between two batches of poses / States.

    Params
    ------

    format (str):
        The key point format, e.g., 'coco', 'coco-whole', ... has to be in OKS_SIGMAS.keys().

    Optional Params
    ---------------

    keypoint_dim (int, optional):
        The dimensionality of the key points. So whether 2D or 3D is expected.
        Default ``DEF_VAL.similarity.oks.kp_dim``.
    """

    def __init__(self, config: Config, path: NodePath):
        super().__init__(config, path)

        self.validate_params(oks_validations)

        # get the per-joint sigmas of the chosen key-point format, shape [J]
        sigma: t.Tensor = OKS_SIGMAS[self.params["format"]].to(device=self.device, dtype=self.precision)
        # With k = 2 * sigma -> shape [J].
        # k is constant and only k^2 is ever required. Therefore, save k^2 once as a buffer.
        self.register_buffer("k2", t.square(t.mul(2, sigma)))
        # Machine epsilon of the working precision, added to the denominator so we never divide by zero.
        self.register_buffer("eps", t.tensor(t.finfo(self.precision).eps, device=self.device, dtype=self.precision))

        # Transform used to convert bounding boxes whose format does not store width and height explicitly.
        self.transf_bbox_to_xyxy = ConvertBoundingBoxFormat("XYXY")

        self.kp_dim: int = self.params.get("keypoint_dim", DEF_VAL["similarity"]["oks"]["kp_dim"])

    def get_data(self, ds: State) -> t.Tensor:
        """Given a :class:`.State`, obtain the detected / predicted key points.

        Returns:
            A FloatTensor of shape ``[B1 x J x 2|3]``, where the last dimension equals ``self.kp_dim``.
            The bounding-box areas required by :meth:`forward` are computed separately by :meth:`get_area`.
        """
        return ds.keypoints.float().view(ds.B, -1, self.kp_dim)

    def get_area(self, ds: State) -> t.Tensor:
        """Given a :class:`.State`, compute the area of every bounding box.

        Returns:
            A FloatTensor of shape ``[B1]`` containing the area of each bounding box,
            in the squared units of the box coordinates.
        """
        bboxes = ds.bbox
        if bboxes.format in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH):
            # Width and height are stored explicitly in the last two columns -> area = w * h.
            # Cast to float like the other branches, so the result dtype does not depend on the box format.
            return (bboxes[:, -2] * bboxes[:, -1]).float()
        if bboxes.format != BoundingBoxFormat.XYXY:
            # Any other format is converted first, because box_area expects 'XYXY'.
            bboxes = self.transf_bbox_to_xyxy(bboxes)
        return box_area(bboxes).float()  # (x2-x1) * (y2-y1)

    def get_target(self, ds: State) -> tuple[t.Tensor, t.Tensor]:
        """Given a :class:`.State` obtain the ground truth key points and the key-point-visibility.

        Both are tensors, the key points are a FloatTensor of shape ``[B2 x J x 2|3]``
        and the visibility is a BoolTensor of shape ``[B2 x J]``.
        """
        kps = ds.keypoints.float().view(ds.B, -1, self.kp_dim)
        vis = ds.cast_joint_weight(dtype=t.bool).squeeze(-1).view(ds.B, -1)
        return kps, vis

    def forward(self, data: State, target: State) -> t.Tensor:
        r"""Compute the object key-point similarity between a ground truth label and detected key points.

        Every detection is compared with every target, so the batch sizes do not have to match,
        but the number of joints and their dimensionality have to.

        Notes:
            Compute the key-point similarity :math:`\mathtt{ks}_i` for every joint between every detection and the
            respective ground truth annotation.

            .. math::
                \mathtt{ks}_i = \exp(-\dfrac{d_i^2}{2s^2k_i^2})

            The object key-point similarity :math:`\mathtt{OKS}` is then computed as the mean of
            :math:`\mathtt{ks}_i` over all labeled joints of the ground truth.

            .. math::
                \mathtt{OKS} = \dfrac{\sum_i \mathtt{ks}_i \cdot \delta (v_i > 0)}{\sum_i \delta (v_i > 0)}

            * :math:`d_i` the euclidean distance between the ground truth and detected key point
            * :math:`k_i` the constant for the key point, computed as :math:`k=2\cdot\sigma`
            * :math:`v_i` the visibility of the key point, with

                * 0 = unlabeled
                * 1 = labeled but not visible
                * 2 = labeled and visible

            * :math:`s` the scale of the object, with :math:`s^2` becoming the area of the bounding box of ``data``

            A target without any labeled joint gets a similarity of 0 with every detection.

        Args:
            data: A :class:`.State` object containing at least the key points and the bounding box. Shape ``N``.
            target: A :class:`.State` containing at least the target key points. Shape ``T``.

        Returns:
            A (Float)Tensor of shape ``[N x T]`` with values in ``[0..1]``.
            If requested, the softmax is computed along the -1 dimension,
            resulting in probability distributions for each value of the input data.

        Raises:
            ValueError: If the key points of ``data`` and ``target`` differ in the number of joints or dimensions,
                or if the number of joints does not match the number of sigmas of the key-point format.
        """
        # predicted key-points as [N x J x D] and bbox areas as [N]
        pred_kps = self.get_data(ds=data)
        bbox_area = self.get_area(ds=data)
        # ground-truth key-points as [T x J x D] and visibility as [T x J]
        gt_kps, gt_vis = self.get_target(ds=target)

        # Validate explicitly, because a size of 1 in J would otherwise broadcast silently below.
        if pred_kps.shape[-2:] != gt_kps.shape[-2:]:
            raise ValueError(
                f"Predicted and ground-truth key points must have the same shape [J x D], "
                f"got {tuple(pred_kps.shape[-2:])} and {tuple(gt_kps.shape[-2:])}."
            )
        if pred_kps.size(-2) != self.k2.numel():
            raise ValueError(
                f"The number of joints ({pred_kps.size(-2)}) does not match the number of sigmas "
                f"({self.k2.numel()}) of the key point format '{self.params['format']}'."
            )

        # Squared Euclidean distance, the sqrt is skipped because only d^2 is required.
        # A regular subtraction fails if N != T and neither is 1, therefore reshape to
        # [N x J x D x 1] and [J x D x T], which broadcast to [N x J x D x T], then sum over D.
        d2 = t.sub(pred_kps.unsqueeze(-1), gt_kps.permute(1, 2, 0)).square().sum(dim=-2)  # -> [N x J x T]

        # The scale is kept as the area s^2, because s itself is never used.
        s2 = bbox_area.flatten()  # [N]

        # Denominator 2 * s^2 * k^2 for every detection and joint via the outer product,
        # plus epsilon to never divide by zero. Unsqueeze to broadcast against d2.
        denom = (2 * t.outer(s2, self.k2) + self.eps).unsqueeze(-1)  # [N x J x 1]
        ks = t.exp(-t.div(d2, denom))  # -> [N x J x T]

        # The count of non-zero visibilities in the ground-truth
        count = t.count_nonzero(gt_vis, dim=-1)  # [T]
        # Only sum the ks of labeled joints, the [J x T] mask broadcasts over N -> [N x T]
        ks_sum = t.where(gt_vis.T, ks, 0).sum(dim=-2)
        # A target without labeled joints results in 0 / 0 = nan, which is mapped to a similarity of 0.
        oks = t.div(ks_sum, count).nan_to_num_(nan=0.0, posinf=0.0)
        return self.softmax(oks)
 
class IntersectionOverUnion(SimilarityModule):
    """Use the bounding-box based intersection-over-union as a similarity metric.

    Params
    ------

    No additional parameters are read by this module.
    """

    def __init__(self, config: Config, path: NodePath):
        super().__init__(config, path)

        # Transform used to convert bounding boxes into the 'XYXY' format expected by box_iou.
        self.bbox_transform = ConvertBoundingBoxFormat("XYXY")

    def _get_xyxy_bboxes(self, ds: State) -> BoundingBoxes:
        """Return the bounding boxes of ``ds``, converted to the 'XYXY' format if necessary."""
        bboxes = ds.bbox
        if bboxes.format != BoundingBoxFormat.XYXY:
            bboxes = self.bbox_transform(bboxes)
        return bboxes

    def get_data(self, ds: State) -> BoundingBoxes:
        """Given a :class:`.State` obtain the detected bounding-boxes as
        :class:`torchvision.tv_tensors.BoundingBoxes` object of size ``[N x 4]``.

        Notes:
            The function :func:`box_iou` expects that the bounding boxes are in the 'XYXY' format,
            so other formats are converted.
        """
        return self._get_xyxy_bboxes(ds)

    def get_target(self, ds: State) -> BoundingBoxes:
        """Given a :class:`.State` obtain the target bounding-boxes as
        :class:`torchvision.tv_tensors.BoundingBoxes` object of size ``[T x 4]``.

        Notes:
            The function :func:`box_iou` expects that the bounding boxes are in the 'XYXY' format,
            so other formats are converted.
        """
        return self._get_xyxy_bboxes(ds)

    def forward(self, data: State, target: State) -> t.Tensor:
        """Given two states containing bounding-boxes, compute the intersection over union between each pair.

        Args:
            data: A :class:`.State` object containing the detected bounding-boxes. Size ``N``
            target: A :class:`.State` object containing the target bounding-boxes. Size ``T``

        Returns:
            A (Float)Tensor of shape ``[N x T]`` with values in ``[0..1]``.
            If requested, the softmax is computed along the -1 dimension,
            resulting in probability distributions for each value of the input data.
        """
        return self.softmax(box_iou(self.get_data(ds=data), self.get_target(ds=target)))