class TensorsDataset(torch.utils.data.Dataset): ''' A simple loading dataset - loads the tensor that are passed in input. This is the same as torch.utils.data.TensorDataset except that you can add transformations to your data and target tensor. Target tensor can also be None, in which case it is not returned. ''' def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None): if target_tensor is not None: assert data_tensor.size(0) == target_tensor.size(0) self.data_tensor = data_tensor self.target_tensor = target_tensor if transforms is None: transforms = [] if target_transforms is None: target_transforms = [] if not isinstance(transforms, list): transforms = [transforms] if not isinstance(target_transforms, list): target_transforms = [target_transforms] self.transforms = transforms self.target_transforms = target_transforms def __getitem__(self, index): data_tensor = self.data_tensor[index] for transform in self.transforms: data_tensor = transform(data_tensor) if self.target_tensor is None: return data_tensor target_tensor = self.target_tensor[index] for transform in self.target_transforms: target_tensor = transform(target_tensor) return data_tensor, target_tensor def __len__(self): return self.data_tensor.size(0)
重新定义Pytorch中的TensorDataset,可实现transforms
原创
©著作权归作者所有:来自51CTO博客作者marsggbo的原创作品,请联系作者获取转载授权,否则将追究法律责任
提问和评论都可以,用心的回复会被更多人看到
评论
发布评论
相关文章
-
重新定义DPU——中科驭数2024产品发布会,6月19日诚邀莅临!
共同见证DPU发展的重要时刻!
数据中心 基础设施 DPU -
重新定义Python学习!
作为生产力工具,Python是当今极为流行的编程语言。Python编程逐渐成为一项通用能力,从小学生到各个行业的从业人员都在学Python。Python确实能够
python 学习 开发语言 Python 数据