"""
Module for creating submission files for |PT21|_ .
References:
https://github.com/anDoer/PoseTrack21/blob/main/doc/dataset_structure.md
https://github.com/leonid-pishchulin/poseval
Notes:
The structure of the PT21 submission file is similar to the structure of the inputs::
{
"images": [
{
"file_name": "images/train/000001_bonn_train/000000.jpg",
"id": 10000010000,
"frame_id": 10000010000
},
],
"annotations": [
{
"bbox": [x1, y1, w, h],
"image_id": 10000010000,
"keypoints": [x1, y1, vis1, ..., x17, y17, vis17],
"scores": [s1, ..., s17],
"person_id": 1024,
"track_id": 0
},
]
}
Additionally, note that the visibilities are ignored during evaluation.
"""
import torch as t
from torchvision import tv_tensors as tvte
from torchvision.transforms.v2.functional import convert_bounding_box_format
from dgs.models.submission.submission import SubmissionFile
from dgs.utils.constants import PT21_CATEGORIES
from dgs.utils.files import write_json
from dgs.utils.state import State
from dgs.utils.types import Config, NodePath
[docs]
class PoseTrack21Submission(SubmissionFile):
"""Class for creating and appending to a |PT21|_ -style submission file."""
data: dict[str, list[any]]
[docs]
def __init__(self, config: Config, path: NodePath) -> None:
super().__init__(config=config, path=path)
# add the categories to the json data and create the empty lists for the images and annotations
self.clear()
[docs]
def append(self, s: State, *_args, **_kwargs) -> None:
"""Given data, append to the created |PT21| submission file."""
self.data["images"].append(self.get_image_data(s))
self.data["annotations"] += self.get_anno_data(s)
[docs]
def save(self) -> None:
"""Save the submission data in a file."""
try:
write_json(obj=self.data, filepath=self.fp)
except TypeError as e:
self.logger.exception(f"data: {self.data}")
raise TypeError from e
[docs]
@staticmethod
def get_image_data(s: State) -> dict[str, any]:
"""Given a :class:`.State`, extract data for the 'images' used in the submission file."""
# validate the image data
for key in ["filepath", "image_id", "frame_id"]:
if key not in s:
raise KeyError(f"Expected key '{key}' to be in State. Got {s}")
if isinstance(s[key], str):
# str -> tuple of str, this will always be correct, add at least one value for later usage
s[key] = (s[key] for _ in range(max(1, s.B)))
elif s.B > 1:
if (l := len(s[key])) != s.B:
raise ValueError(f"Expected '{key}' ({l}) to have the same length as the State ({s.B}).")
if any(s[key][i] != s[key][0] for i in range(1, s.B)):
raise ValueError(f"State has different {key}s, expected all {key}s to match. got: '{s[key]}'.")
elif (l := len(s[key])) != 1:
raise ValueError(f"Expected '{key}' ({l}) to have a length of exactly 1.")
# get the file_name in the PT21 directory
file_name = f".{s.filepath[0].split('PoseTrack21')[-1]}"
# get the image data
image_data = {
"file_name": file_name,
"id": int(s["image_id"][0].item() if isinstance(s["image_id"], t.Tensor) else s["image_id"][0]),
"image_id": int(s["image_id"][0].item() if isinstance(s["image_id"], t.Tensor) else s["image_id"][0]),
"frame_id": int(s["frame_id"][0].item() if isinstance(s["frame_id"], t.Tensor) else s["frame_id"][0]),
}
return image_data
[docs]
@staticmethod
def get_anno_data(s: State) -> list[dict[str, any]]:
"""Given a :class:`.State`, extract data for the 'annotations' list used in the submission file."""
if s.B == 0:
return []
# validate the annotation data
for key in ["person_id", "pred_tid", "bbox", "keypoints", "joint_weight"]:
if key not in s:
raise KeyError(f"Expected key '{key}' to be in State.")
if (l := len(s[key])) != s.B:
raise ValueError(f"Expected '{key}' ({l}) to have the same length as the State ({s.B}).")
# get the annotation data
anno_data = []
if s.bbox.format != tvte.BoundingBoxFormat.XYWH:
s.bbox = convert_bounding_box_format(s.bbox, new_format=tvte.BoundingBoxFormat.XYWH)
assert s.bbox.format == tvte.BoundingBoxFormat.XYWH, f"got format: {s.bbox.format}"
for i in range(s.B):
kps = t.cat([s.keypoints[i], s.joint_weight[i]], dim=-1)
scores: list[float]
if "scores" in s:
if isinstance(s["scores"], t.Tensor):
scores = s["scores"][i].to(dtype=t.float32).flatten().tolist()
else:
scores = [float(score) for score in s["scores"]]
else:
scores = [0.0 for _ in range(17)]
anno_data.append(
{
"bboxes": s.bbox[i].flatten().tolist(),
"keypoints": kps.flatten().tolist(),
"scores": scores,
"score": (
float(sum(scores) / len(scores))
if "score" not in s
else s["score"][i].item() if isinstance(s["score"][i], t.Tensor) else s["score"][i]
),
"image_id": int(
s["image_id"][i].item() if isinstance(s["image_id"], t.Tensor) else s["image_id"][i]
),
"person_id": int(s.person_id[i].item()),
"track_id": int(s["pred_tid"][i].item()),
}
)
return anno_data
[docs]
def clear(self) -> None:
"""Clear the submission data."""
self.data = {
"images": [],
"annotations": [],
"categories": PT21_CATEGORIES,
}