dgs.utils.state.collate_states

dgs.utils.state.collate_states(batch: list[State] | State) → State

Collate function for multiple State objects. It flattens or squeezes the shapes of the collated data and preserves the tv_tensors classes.

The default collate function alters some of the dimensions and strips the custom tv_tensors classes. Therefore, this function uses custom collate functions for the tv_tensors classes. It also uses a custom collate function for plain torch tensors, which stacks the tensors if their first dimension is not 1 and concatenates them otherwise.
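The rule for plain torch tensors can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name `collate_tensors` is hypothetical, and it assumes the decision is made from the shape of the first tensor in the batch. Tensors whose first dimension is 1 are concatenated along dim 0, so the singleton dimensions merge into a single batch dimension. All other tensors, including 0-d tensors, are stacked along a new leading dimension.

```python
import torch


def collate_tensors(batch: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical sketch of the tensor collate rule described above.

    Assumption: the choice between cat and stack is based on the first
    tensor in the batch.
    """
    first = batch[0]
    if first.ndim > 0 and first.shape[0] == 1:
        # Leading dimension of size 1: concatenate so [1, ...] x B -> [B, ...].
        return torch.cat(batch, dim=0)
    # Otherwise (including 0-d tensors): stack to add a new batch dimension.
    return torch.stack(batch, dim=0)


# Three per-box tensors of shape [1, 4] become one tensor of shape [3, 4].
boxes = collate_tensors([torch.zeros(1, 4) for _ in range(3)])
# Two image crops of shape [3, 8, 8] become one tensor of shape [2, 3, 8, 8].
crops = collate_tensors([torch.zeros(3, 8, 8) for _ in range(2)])
```

Concatenating in the size-1 case avoids the extra `[B, 1, ...]` dimension that stacking would produce. A collate function such as `collate_states` is typically passed to `torch.utils.data.DataLoader` through its `collate_fn` argument.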

Parameters:

batch – A list of State objects, each containing the data of a single bounding box.

Returns:

A single State object containing the batched data of all the bounding boxes.