0%

DataLoaders和DataSets使用总结

DataLoaders和DataSets使用总结

通过 torch.utils.data.Dataset类定义数据集,通过torch.utils.data.DataLoaderDataset配合定义数据加载方式

torch.utils.data.Dataset主要参数:

  1. dataset 指定构造的数据集,一般是数据+label的元组
  2. shuffle 指定是否打乱
  3. sampler 指定采样器,采取某种方式遍历数据(顺序,随机等)
  4. collate_fn (callable, optional):将数据形成一个batch的tensor(在某些时候需要自定义) 以加载minist数据集为例,基本的使用流程:
1
2
3
4
5
6
7
8
# 首先加载数据集(pytorch现有的数据集) 返回的是DataSet对象
mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
# 然后构造DataLoader对象 返回的是可迭代的DataLoader对象
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)
# 通过for循环遍历进行训练
for batch_features,batch_labels in train_iter:

基本遍历原理

  • 通过sampler抽样index(默认为顺序抽样),根据抽出的index从dataset中获取元素(getitem),调用自身的collate_fn方法,将原始数据转化为batch形式的tensor,返回。
1
2
3
4
5
6
7
8
9
10
# dataloader遍历的next方法
def __next__(self):
if self.num_workers == 0:
indices = next(self.sample_iter) # Sampler
batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
# dataset 以[] 形式访问的方法
def __getitem__(self, idx):

1. 自定义数据集

1.1 继承dataset类

需要重写init(),len(),getitem()方法,以word2vec中负采样数据集生成为例

1
2
3
4
5
6
7
8
9
10
11
12
13
# 自定义数据读取类,可结合dataloader进行数据批量读取
class MyDataset(torch.utils.data.Dataset):
def __init__(self, centers, contexts, negatives):
assert len(centers) == len(contexts) == len(negatives)
self.centers = centers
self.contexts = contexts
self.negatives = negatives
# 返回当前行数据
def __getitem__(self, index):
return (self.centers[index], self.contexts[index], self.negatives[index])
# 返回数据长度
def __len__(self):
return len(self.centers)

1.2 定义collate_fn方法

默认的collate_fn方法,必须保证每个训练数据的长度相同,通过自定义collate_fn方法解决这一问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def my_collate_fn(data):
# 保证所有数据一致的统一长度
max_len = 500
center_words = []
context_negatives = []
masks = []
labels = []
for center, contexts, negatives in data:
cur_len = len(contexts) + len(negatives)
center_words.append([center])
# 对数据进行截断或者延长
if cur_len > max_len:
context_negatives.append((contexts + negatives)[0:max_len])
labels.append([1] * len(contexts) + [0] * (max_len - len(contexts)))
else:
context_negatives.append(contexts + negatives + [0] * (max_len - cur_len))
labels.append([1] * len(contexts) + [0] * (max_len - len(contexts)))
return torch.tensor(center_words), torch.tensor(context_negatives)
# 使用即可
train_iter = Data.DataLoader(MyDataset(all_centers, all_contexts, all_negatives), collate_fn = batchify, batch_size=64, shuffle=True)

2.自定义sampler

待补充