dgs.utils.state.collate_states
- dgs.utils.state.collate_states(batch: list[State] | State) → State
Collate function for multiple States. It flattens / squeezes the shapes and keeps the tv_tensors classes.
The default collate function stacks every tensor along a new leading dimension, which messes up some of the dimensions (e.g., N tensors of shape [1, 4] become [N, 1, 4]), and it removes custom tv_tensor classes. Therefore, custom collate functions are added for the tv_tensors classes. Additionally, a custom collate function for plain torch tensors stacks the tensors only if their first dimension is not 1, and concatenates them otherwise.
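The tensor rule described above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the dgs implementation: the helper name `collate_tensors_sketch` is hypothetical, it checks the first dimension of every tensor in the batch (not only the first one), and it leaves out the tv_tensors handling entirely.

```python
import torch


def collate_tensors_sketch(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper, not part of dgs: collate same-shaped tensors.

    If every tensor has a first dimension of 1 (e.g. one box of shape [1, 4]),
    the tensors are concatenated along dim 0, so N of them give [N, 4]
    instead of [N, 1, 4]. Otherwise (including 0-dim scalars) they are
    stacked along a new leading dim 0.
    """
    if all(t.ndim > 0 and t.shape[0] == 1 for t in tensors):
        return torch.cat(tensors, dim=0)
    return torch.stack(tensors, dim=0)


# Three single-box tensors of shape [1, 4] -> concatenated to [3, 4]
boxes = collate_tensors_sketch([torch.rand(1, 4) for _ in range(3)])
print(boxes.shape)  # torch.Size([3, 4])

# Two image crops of shape [3, 8, 8] -> stacked to [2, 3, 8, 8]
crops = collate_tensors_sketch([torch.rand(3, 8, 8) for _ in range(2)])
print(crops.shape)  # torch.Size([2, 3, 8, 8])
```

Concatenating along dimension 0 gives each per-box value a single batch dimension of size N, without the extra singleton axis that plain stacking would introduce.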
- Parameters:
batch – A list of State objects (or a single State), where each State contains the data belonging to a single bounding box.
- Returns:
One single State object containing a batch of the data belonging to the bounding boxes.