A Summary of Using DataLoaders and Datasets
A dataset is defined through the torch.utils.data.Dataset class, and torch.utils.data.DataLoader works together with a Dataset to define how the data is loaded.
Main parameters of torch.utils.data.DataLoader:
- dataset: the dataset to load from, typically yielding (data, label) tuples
- shuffle: whether to shuffle the data each epoch
- sampler: a sampler that traverses the data in some order (sequential, random, etc.); shuffle=True is just shorthand for a random sampler, as the sketch after this list shows
- collate_fn (callable, optional): assembles a list of samples into a batch of tensors (sometimes needs to be customized)
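As a quick illustration of the shuffle/sampler relationship (a minimal sketch with a made-up TensorDataset, not from the original post):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.zeros(100))

# shuffle=True is shorthand for drawing indices with a RandomSampler;
# shuffle and a user-supplied sampler are mutually exclusive arguments
loader_a = DataLoader(dataset, batch_size=32, shuffle=True)
loader_b = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))
```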
Taking the Fashion-MNIST dataset as an example, the basic usage flow is:
```python
import torch
import torchvision
from torchvision import transforms

root, batch_size = './data', 256   # assumed setup: cache directory and batch size
transform = transforms.ToTensor()

mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)

train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)

for batch_features, batch_labels in train_iter:
    pass  # batch_features: (batch_size, 1, 28, 28), batch_labels: (batch_size,)
```
Basic iteration mechanism
- Indices are drawn by the sampler (sequential sampling by default); the elements at those indices are fetched from the dataset via __getitem__; the loader then calls its own collate_fn to convert the raw samples into a batch of tensors, which is returned.
```python
# Excerpt from the single-process branch of the DataLoader iterator (older PyTorch source):
def __next__(self):
    if self.num_workers == 0:
        indices = next(self.sample_iter)  # draw one batch of indices from the sampler
        # self.dataset[i] invokes the dataset's __getitem__(self, idx)
        batch = self.collate_fn([self.dataset[i] for i in indices])
        if self.pin_memory:
            batch = _utils.pin_memory.pin_memory_batch(batch)
        return batch
```
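The same fetch step can be replayed by hand. The sketch below (illustrative, not from the original post) wires a sampler, __getitem__, and the default collate function together:

```python
import torch
from torch.utils.data import TensorDataset, SequentialSampler, BatchSampler
from torch.utils.data.dataloader import default_collate

# toy dataset: 10 samples with 3 features each, plus integer labels
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))

batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
for indices in batch_sampler:                    # sampler yields batches of indices
    samples = [dataset[i] for i in indices]      # __getitem__ per index
    features, labels = default_collate(samples)  # stack into batched tensors
    print(features.shape, labels.shape)          # e.g. torch.Size([4, 3]) torch.Size([4])
```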
1. Custom datasets
1.1 Subclassing the Dataset class
You need to override the __init__(), __len__(), and __getitem__() methods. Take the negative-sampling dataset used in word2vec as an example:
```python
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return (self.centers[index], self.contexts[index], self.negatives[index])

    def __len__(self):
        return len(self.centers)
```
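A quick sanity check with made-up toy lists (hypothetical stand-ins for real word2vec preprocessing output):

```python
all_centers = [1, 2]                  # center-word indices
all_contexts = [[2, 3], [1]]          # context words for each center
all_negatives = [[4, 5, 6], [7, 8]]   # sampled negatives for each center

dataset = MyDataset(all_centers, all_contexts, all_negatives)
print(len(dataset))  # 2
print(dataset[0])    # (1, [2, 3], [4, 5, 6])
```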
1.2 Defining a collate_fn
The default collate_fn requires every training sample to have the same length; variable-length samples can be handled with a custom collate_fn:
```python
def my_collate_fn(data):
    max_len = 500  # pad or truncate every sample to this length
    center_words, context_negatives, masks, labels = [], [], [], []
    for center, contexts, negatives in data:
        cur_len = len(contexts) + len(negatives)
        center_words.append([center])
        if cur_len > max_len:   # truncate
            context_negatives.append((contexts + negatives)[0:max_len])
            masks.append([1] * max_len)
        else:                   # pad with zeros
            context_negatives.append(contexts + negatives + [0] * (max_len - cur_len))
            masks.append([1] * cur_len + [0] * (max_len - cur_len))
        # 1 marks a context word, 0 marks a negative word or padding
        labels.append([1] * len(contexts) + [0] * (max_len - len(contexts)))
    return (torch.tensor(center_words), torch.tensor(context_negatives),
            torch.tensor(masks), torch.tensor(labels))

train_iter = Data.DataLoader(MyDataset(all_centers, all_contexts, all_negatives),
                             collate_fn=my_collate_fn, batch_size=64, shuffle=True)
```
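Iterating over train_iter then yields four tensors per batch; a quick shape check:

```python
for centers, context_negatives, masks, labels in train_iter:
    # shapes: (batch, 1), (batch, 500), (batch, 500), (batch, 500)
    print(centers.shape, context_negatives.shape, masks.shape, labels.shape)
    break
```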
2. Custom samplers
To be added.
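As a placeholder sketch (an invented example, not from the original post): a custom sampler subclasses torch.utils.data.Sampler and implements __iter__(), which yields indices, plus (usually) __len__().

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class ReverseSampler(Sampler):
    """Yield dataset indices from last to first."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

dataset = TensorDataset(torch.arange(6))
loader = DataLoader(dataset, batch_size=2, sampler=ReverseSampler(dataset))
for (batch,) in loader:
    print(batch)  # tensor([5, 4]), then tensor([3, 2]), then tensor([1, 0])
```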