我正在尝试拆分CIFAR10的训练数据,所以最后5000的训练集用于验证。我的代码
size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),但是最后的打印语句打印50000和10000。如果不是45000和5000,当我打印train_idx和val_idx时,它会打印正确的值(0:4499,45000:49999),我的代码有什么问题吗?
发布于 2020-10-15 22:51:06
我无法复制您的结果,当我执行您的代码时,print语句输出的数值是相同的两倍:train_CIFAR10中的元素数。因此,我猜您在复制代码时犯了错误,而valid_dataloader实际上被赋予了CIFAR10_test (或类似的东西)作为参数。在下面的文章中,我将假设是这样的,并且您的打印输出(50000, 50000),这就是Pytorch的CIFAR10数据集的训练部分的大小。
然后它是完全预期的,而不是它不应该输出(45000,5000)。您正在询问train_dataloader.dataset和valid_dataloader.dataset的长度,即基础数据集的长度。对于两个加载器,此数据集为CIFAR10_training。因此,您将获得此数据集的两倍大小(即50000)。
您也不能请求len(train_dataloader),因为这样会产生数据集中的批数(大约是45000/batch_size)。
如果您需要知道分块的大小,则必须计算取样器的长度:
print(len(train_dataloader.sampler), len(valid_dataloader.sampler))除此之外,您的代码很好,您正在正确地分割数据。
https://stackoverflow.com/questions/64379339
复制相似问题