This commit is contained in:
qhy
2026-02-10 22:35:45 +08:00
parent dcbcb2c377
commit b558856e1e
2 changed files with 16 additions and 93 deletions

View File

@@ -691,6 +691,15 @@ class WMAModel(nn.Module):
# Reusable CUDA stream for parallel state_unet / action_unet
self._state_stream = torch.cuda.Stream()
def __getstate__(self):
state = self.__dict__.copy()
state.pop('_state_stream', None)
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._state_stream = torch.cuda.Stream()
def forward(self,
x: Tensor,
x_action: Tensor,