diff --git a/references/wifi_densepose_pytorch.py b/references/wifi_densepose_pytorch.py index 4d3475c6..844bdc89 100644 --- a/references/wifi_densepose_pytorch.py +++ b/references/wifi_densepose_pytorch.py @@ -441,7 +441,7 @@ class WiFiDensePoseTrainer: }, path) def load_model(self, path): - checkpoint = torch.load(path) + checkpoint = torch.load(path, map_location=self.device, weights_only=True) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])