Source code for dgs.utils.track

"""
Classes and helpers for Track, Tracks, and other tracking related objects.
"""

from collections import deque, UserDict
from copy import deepcopy
from enum import Enum

import torch as t

from dgs.utils.config import DEF_VAL
from dgs.utils.state import collate_states, State

TrackID = int
"""The ID of any given track is a positive integer."""


[docs] class TrackStatus(Enum): """Enumerate for handling the status of a :class:`.Track`. A track can be deleted and re-activated. If a track was 'Inactive' and becomes 'Active' again its status is simply 'Active'. """ New = 0 Active = 1 Inactive = 2 Reactivated = 3 Removed = 4
[docs] class TrackStatistics: """Data object to save and analyze the statistics of some tracks.""" # active new: list[TrackID] reactivated: list[TrackID] found: list[TrackID] still_active: list[TrackID] # inactive lost: list[TrackID] still_inactive: list[TrackID] # removed removed: list[TrackID]
[docs] def __init__(self) -> None: self.clear()
# ######### # # Functions # # ######### #
[docs] def print(self, logger, frame_idx: int) -> None: # pragma: no cover """Print the current Track statistics. Debug only.""" logger.debug(f"===========Frame{frame_idx}==========") logger.debug( f"Active: {self.active} of which {self.new} are new, " f"{self.found} are re-found, and {self.reactivated} are reactivated." ) logger.debug(f"Inactive: {self.inactive} of which {self.lost} are lost.") logger.debug(f"Removed: {self.removed}")
[docs] def clear(self) -> None: """Clear the current Track statistics. Mostly used for tests.""" # active self.new = [] self.reactivated = [] self.found = [] self.still_active = [] # inactive self.lost = [] self.still_inactive = [] # removed self.removed = []
# ########## # # Properties # # ########## # @property def active(self) -> set[TrackID]: return set(self.new + self.reactivated + self.found + self.still_active) @property def inactive(self) -> set[TrackID]: return set(self.lost + self.still_inactive) # ######### # # Number of # # ######### # @property def nof_active(self) -> int: return len(self.active) @property def nof_found(self) -> int: return len(self.found) @property def nof_inactive(self) -> int: return len(self.inactive) @property def nof_lost(self) -> int: return len(self.lost) @property def nof_new(self) -> int: return len(self.new) @property def nof_reactivated(self) -> int: return len(self.reactivated) @property def nof_removed(self) -> int: return len(self.removed) @property def nof_still_active(self) -> int: return len(self.still_active) @property def nof_still_inactive(self) -> int: return len(self.still_inactive)
[docs] class Track: """A Track is a single (de-)queue containing multiple :class:`.State` s, keeping the last N states. Args: N: The max length of this track. states: A list of :class:`.State` objects, describing the initial values of this track. Default None. tid: The Track ID of this object. Default -1. """ _N: int """Maximum number of states in this Track.""" _states: deque """The deque of the current states with a max length of _N.""" _id: TrackID """The Track-ID of this Track.""" _status: TrackStatus """The status of this Track.""" _start_frame: int """The number describing the first frame this Track was visible.""" _nof_active: int = 0 """The number of frames this track has been active."""
[docs] def __init__(self, N: int, curr_frame: int, states: list[State] = None, tid: int = -1) -> None: # max nof states if N <= 0: raise ValueError(f"N must be greater than 0 but got '{N}'.") self._N = N # track-id self.id = tid # status and frame management self._status = TrackStatus.New self._start_frame = curr_frame self._nof_active = len(states) if states is not None else 0 # already existing states if states is not None and len(states) and any(state.B != 1 for state in states): raise ValueError(f"The batch size of all the States '{[state.B for state in states]}' must be 1.") self._states = deque(iterable=states if states else [], maxlen=N)
def __repr__(self) -> str: return f"Track-{self.id}-{len(self)}-{self._start_frame}" def __getitem__(self, index: TrackID) -> State: return self._states[index] def __len__(self) -> int: return len(self._states) def __eq__(self, other: "Track") -> bool: """Return whether another Track is equal to self.""" if not isinstance(other, Track): return False variable_equality: bool = ( self.N == other.N and self.id == other.id and self.status == other.status and self.nof_active == other.nof_active and self._start_frame == other._start_frame ) if len(self) == 0 and len(other) == 0: return variable_equality return ( variable_equality and len(self._states) == len(other._states) and all(s == other[i] for i, s in enumerate(self._states)) ) # ########## # # Properties # # ########## # @property def status(self) -> TrackStatus: return self._status @property def nof_active(self) -> int: return self._nof_active @nof_active.setter def nof_active(self, value: int) -> None: self._nof_active = value @property def id(self) -> TrackID: return self._id @id.setter def id(self, value: TrackID): if isinstance(value, t.Tensor) and (value.ndim == 0 or (value.ndim == 1 and len(value) == 1)): self._id = int(value.item()) elif isinstance(value, int): self._id = value else: raise NotImplementedError(f"unknown type for ID, expected int but got '{type(value)}' - '{value}'") @property def N(self) -> int: """Get the max length of this Track.""" return self._N @property def device(self) -> t.device: """Get the device of every tensor in this Track.""" if len(self) == 0: raise ValueError("Can not get the device of an empty Track.") device = self._states[-1].device assert all(state.device == device for state in self._states), "Not all tensors are on the same device" return device # ############## # # State handling # # ############## #
[docs] def append(self, state: State) -> None: """Append a new state to the Track.""" if state.B != 1: raise ValueError(f"A Track should only get a State with the a batch size of 1, but got {state.B}.") if len(self._states) > 0: self._states[-1].clean() self._states.append(state) self.set_active() self._nof_active += 1
[docs] def get_all(self) -> State: """Get all the states from the Track and stack them into a single :class:`State`.""" if len(self) == 0: raise ValueError("Can not stack the items of an empty Track.") return collate_states(list(self._states))
# ############### # # Status handling # # ############### #
[docs] def set_active(self) -> None: self._status = TrackStatus.Active
[docs] def set_inactive(self) -> None: """Set the status of this Track to inactive and clean up older states.""" self._status = TrackStatus.Inactive self._nof_active = 0 # clean everything except last state for i in range(len(self) - 1): self._states[i].clean("all")
[docs] def set_removed(self) -> None: self._status = TrackStatus.Removed self._nof_active = 0 self._id = -1 # unset tID
[docs] def set_reactivated(self, tid: TrackID) -> None: self._status = TrackStatus.Reactivated self._nof_active = 0 self._id = tid
[docs] def set_status(self, status: TrackStatus, tid: TrackID = 0) -> None: """Set the status of this Track.""" if status == TrackStatus.Active: self.set_active() elif status == TrackStatus.Inactive: self.set_inactive() elif status == TrackStatus.Removed: self.set_removed() elif status == TrackStatus.Reactivated: self.set_reactivated(tid) elif status == TrackStatus.New: self._status = TrackStatus.New else: raise ValueError(f"Unknown TrackStatus {status}") # pragma: no cover
[docs] def age(self, curr_frame: int) -> int: """Get the age of this track (in frames). The age does not account for frames where the track has been deleted. """ return curr_frame - self._start_frame
# ####### # # Utility # # ####### #
[docs] def to(self, *args, **kwargs) -> "Track": """Call ``.to()`` like you do with any other ``torch.Tensor``.""" for i, state in enumerate(self._states): self._states[i] = state.to(*args, **kwargs) return self
[docs] def clear(self) -> None: """Clear all the states from the Track.""" self._states.clear() self._nof_active = 0
[docs] def copy(self) -> "Track": """Return a (deep) copy of self.""" track = Track( N=self.N, curr_frame=self._start_frame, states=[s.copy() for s in self._states], tid=self.id, ) track.nof_active = self._nof_active track.set_status(status=self._status, tid=self.id) return track
[docs] class Tracks(UserDict): """Multiple Track objects stored as a dictionary, where the Track is the value and the key is this tracks' unique ID. """ # pylint: disable=too-many-public-methods _N: int """The maximum number of frames in each track.""" data: dict[TrackID, Track] """All the Tracks that are currently tracked, including inactive Tracks as mapping 'Track-ID -> Track'.""" inactive: dict[TrackID, int] """All the inactive Tracks as 'Track-ID -> number of inactive frames / steps'.""" inactivity_threshold: int """The number of steps a Track can be inactive before deleting it.""" removed: dict[TrackID, Track] """All the Tracks that have been removed, to be able to reactivate them.""" _curr_frame: int """The number of the current frame."""
[docs] def __init__(self, N: int, thresh: int = None, start_frame: int = 0) -> None: super().__init__() # set N - the maximum length of every track if N <= 0: raise ValueError(f"N must be greater than 0 but got '{N}'") self._N = N # set the inactivity threshold if thresh is None: self.inactivity_threshold = DEF_VAL["tracks"]["inactivity_threshold"] elif not isinstance(thresh, int): raise TypeError(f"Threshold is expected to be int or None, but got {thresh}") elif thresh < 0: raise ValueError(f"Threshold must be positive, got {thresh}.") else: self.inactivity_threshold = thresh self.reset() # make sure to set the initial current frame after resetting self._curr_frame = start_frame
def __len__(self) -> int: """Get the length of data. If you want the number of active or inactive Tracks, use :meth:`.nof_active` and :meth:`.nof_inactive` respectively. If you want the age, use :meth:`.age`. """ return len(self.data) def __eq__(self, other: "Tracks") -> bool: """Check the equality of two Tracks. This method does not validate whether the removed Tracks are equal. """ if not isinstance(other, Tracks): return False return ( self.inactive == other.inactive and self.inactivity_threshold == other.inactivity_threshold and self._curr_frame == other._curr_frame and set(self.data.keys()) == set(other.data.keys()) and all(t == other.data[k] for k, t in self.data.items()) ) def __getitem__(self, key: TrackID) -> Track: """Given the Track-ID return the Track.""" return self.data[key] def __repr__(self) -> str: return f"Tracks-{self.age}-{self.ids_active}" # ########## # # Properties # # ########## # @property def N(self) -> int: return self._N @property def age(self) -> int: return self._curr_frame @property def ages(self) -> dict[int, int]: """Get the age of all the tracks (in frames).""" return {i: t.age(self._curr_frame) for i, t in self.data.items()} @property def ids(self) -> set[TrackID]: """Get all the track-IDs in this object.""" return set(int(k) for k in self.data.keys()) @property def ids_active(self) -> set[TrackID]: """Get all the track-IDs currently active.""" return self.ids - self.ids_inactive @property def ids_inactive(self) -> set[TrackID]: """Get all the track-IDs currently inactive.""" return set(int(k) for k in self.inactive.keys()) @property def ids_removed(self) -> set[TrackID]: """Get all the track-IDs that have been deleted.""" return set(int(k) for k in self.removed.keys()) @property def nof_active(self) -> int: """Get the number of active Tracks.""" return len(self.data) - len(self.inactive) @property def nof_inactive(self) -> int: """Get the number of inactive Tracks.""" return len(self.inactive) @property def nof_removed(self) -> int: """Get the number of Tracks that have been removed.""" return len(self.removed) # ######################## # # State and Track Handling # # ######################## #
[docs] def remove_tid(self, tid: TrackID) -> None: """Given a Track-ID, remove the track associated with it from this object.""" if tid not in self.data: raise KeyError(f"Track-ID {tid} can not be deleted, because it is not present in Tracks.") self.data[tid].set_removed() self.removed[tid] = self.data.pop(tid) self.inactive.pop(tid, None)
[docs] def remove_tids(self, tids: list[TrackID]) -> None: for tid in tids: self.remove_tid(tid)
[docs] def is_active(self, tid: TrackID) -> bool: """Return whether the given Track-ID is active.""" return tid in self.data and tid not in self.inactive
[docs] def is_inactive(self, tid: TrackID) -> bool: """Return whether the given Track-ID is inactive.""" return tid in self.data and tid in self.inactive
[docs] def is_removed(self, tid: TrackID) -> bool: """Return whether the given Track-ID has been removed.""" return tid not in self.data and tid in self.removed
[docs] def add(self, tracks: dict[TrackID, State], new: list[State]) -> list[TrackID]: """Given tracks with existing Track-IDs update those and create new Tracks for States without Track-IDs. Additionally, mark Track-IDs that are not in either of the inputs as unseen and therefore as inactive for one more step. Returns: The Track-IDs of the new_tracks in the same order as provided. """ inactive_ids = self.ids - set(int(k) for k in tracks.keys()) # get the next free ID and create track(s) new_tids = self.add_empty_tracks(len(new)) # add the new state to the new tracks for tid, new_state in zip(new_tids, new): if new_state.B != 0: self._update_track(tid=tid, add_state=new_state) else: inactive_ids.add(tid) # add state to Track and remove track from inactive if present for tid, new_state in tracks.items(): if new_state.B != 0: self._update_track(tid=tid, add_state=new_state) else: inactive_ids.add(tid) self._handle_inactive(tids=inactive_ids) # step to the next frame self._next_frame() return new_tids
def _next_frame(self) -> None: self._curr_frame += 1
[docs] def get_states(self) -> tuple[list[State], list[TrackID]]: """Get the last state of **every** track in this object as a :class:`State`.""" states: list[State] = [] tids: list[TrackID] = [] for tid, track in self.data.items(): states.append(track[-1]) tids.append(tid) return states, tids
[docs] def get_active_states(self) -> list[State]: """Get the last state of every **active** track in this object as a :class:`State`.""" states: list[State] = [] for tid, track in self.data.items(): # make sure that track ID is set in returned states if "track_id" not in track[-1]: track[-1]["track_id"] = t.tensor([tid] * track[-1].B, dtype=t.long, device=track[-1].device).flatten() # don't add empty states if tid in self.inactive: continue states.append(track[-1]) return states
[docs] def add_empty_tracks(self, n: int = 1) -> list[TrackID]: """Given a Track, compute the next track-ID, and save this track in data using this ID. Args: n: The number of new Tracks to add. Returns: tids: The track-IDs of the added tracks. """ tids = [] for _ in range(n): tid = self._get_next_id() self.data[tid] = Track(N=self._N, curr_frame=self._curr_frame, tid=tid) tids.append(tid) return tids
[docs] def reactivate_track(self, tid: TrackID) -> None: """Given the track-ID of a previously removed track, reactivate it.""" if tid not in self.removed: raise KeyError(f"Track-ID {tid} not present in removed Tracks.") self.data[tid] = self.removed.pop(tid) self.data[tid].set_reactivated(tid)
# todo should the states of the track be removed / cleared ? def _update_track(self, tid: TrackID, add_state: State) -> None: """Use the track-ID to update a track given an additional :class:`State` for the :class:`Track`. Will additionally remove the tid from the inactive Tracks. Returns: Whether this track has been reactivated with this update. """ if tid not in self.data.keys(): if tid not in self.removed.keys(): raise KeyError(f"Track-ID {tid} neither present in the current or previously removed Tracks.") # reactivate previously removed track self.reactivate_track(tid) elif tid in self.inactive: # update inactive self.inactive.pop(tid) # append state to track self.data[tid].append(state=add_state) # add track id to state self.data[tid][-1]["pred_tid"] = t.tensor( [tid] * self.data[tid][-1].B, dtype=t.long, device=add_state.device ).flatten() def _handle_inactive(self, tids: set[TrackID]) -> None: """Given the Track-IDs of the Tracks that haven't been seen this step, update the inactivity tracker. Create the counter for inactive Track-IDs and update existing counters. Additionally, remove tracks that have been inactive for too long. """ for tid in tids: if tid in self.inactive.keys(): self.inactive[tid] += 1 if self.inactive[tid] >= self.inactivity_threshold: self.remove_tid(tid) else: self.inactive[tid] = 1 self.data[tid].set_inactive() def _get_next_id(self) -> TrackID: """Get the next free track-ID.""" if len(self.data) == 0: return 0 return max(self.data.keys()) + 1 # ####### # # Utility # # ####### #
[docs] def reset(self) -> None: """Reset this object.""" self.data = {} self.inactive = {} self.removed = {} self._curr_frame = 0
[docs] def reset_deleted(self) -> None: """Fully remove the deleted Tracks.""" self.removed = {}
[docs] def copy(self) -> "Tracks": """Return a (deep) copy of this object.""" new_t = Tracks(N=self.N, thresh=self.inactivity_threshold) new_t.data = {i: t.copy() for i, t in self.data.items()} new_t.inactive = deepcopy(self.inactive) new_t.removed = deepcopy(self.removed) return new_t
[docs] def to(self, *args, **kwargs) -> "Tracks": """Create function similar to :func:`torch.Tensor.to` .""" self.data = {i: t.to(*args, **kwargs) for i, t in self.data.items()} self.removed = {i: t.to(*args, **kwargs) for i, t in self.removed.items()} return self