dgs.utils.state.collate_tensors

dgs.utils.state.collate_tensors(batch: list[torch.Tensor], *_args, **_kwargs) → torch.Tensor

Collate a batch of tensors into a single one.

Uses torch.cat() if the first dimension has size one; otherwise uses torch.stack().
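The cat-or-stack rule can be illustrated with a minimal sketch. It is not the library's implementation: the name `collate_tensors_sketch` is hypothetical, and it assumes that the size-one check is made on the first tensor in the batch and that both calls join along dimension 0.

```python
import torch


def collate_tensors_sketch(batch: list[torch.Tensor], *_args, **_kwargs) -> torch.Tensor:
    """Hypothetical illustration of the documented cat-or-stack rule.

    Assumption: only the first tensor's leading dimension is inspected.
    Extra positional and keyword arguments are accepted and ignored,
    mirroring the documented signature.
    """
    if len(batch) == 0:
        raise ValueError("batch must contain at least one tensor")
    first = batch[0]
    if first.ndim > 0 and first.shape[0] == 1:
        # Tensors already carry a leading dimension of size one:
        # concatenate along it, so N tensors of shape [1, ...] give [N, ...].
        return torch.cat(batch, dim=0)
    # Otherwise insert a new leading dimension: N tensors of shape [...] give [N, ...].
    return torch.stack(batch, dim=0)


# Usage
print(collate_tensors_sketch([torch.zeros(1, 3), torch.ones(1, 3)]).shape)  # torch.Size([2, 3])
print(collate_tensors_sketch([torch.zeros(4), torch.ones(4)]).shape)        # torch.Size([2, 4])
```

One consequence of this rule is that inputs of shape (1, 3) are always concatenated to (N, 3), never stacked to (N, 1, 3), so both branches produce a result whose first dimension equals the batch length.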