State.to()
Override torch.Tensor.to() so that the call applies to the whole State object instead of to a single tensor.
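The idea of applying to() to a whole container can be sketched as below. This is a minimal illustration, not the dgs implementation. It assumes a State behaves like a dict whose values may be tensors. The class name TensorDict is hypothetical. The sketch also assumes that to() returns a new container rather than modifying the original in place. Any value with a to method, which includes torch.Tensor, receives the same arguments. Other values are passed through unchanged.

```python
from collections import UserDict
from typing import Any


class TensorDict(UserDict):
    """Hypothetical stand-in for a State-like container of tensors (assumption)."""

    def to(self, *args: Any, **kwargs: Any) -> "TensorDict":
        """Forward .to(*args, **kwargs) to every value that supports it.

        Values without a ``to`` method (e.g. strings or ints used as metadata)
        are copied over unchanged. A new container is returned and ``self`` is
        left untouched; whether State works in place is not confirmed here.
        """
        moved = {
            key: value.to(*args, **kwargs) if hasattr(value, "to") else value
            for key, value in self.items()
        }
        return type(self)(moved)
```

The hasattr check is a design choice. It lets a single call such as `TensorDict({"x": torch.zeros(2)}).to("cpu", dtype=torch.float64)` convert every tensor, while non-tensor entries are left as they are.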