首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >CIFAR10数据采集器分离器

CIFAR10数据采集器分离器
EN

Stack Overflow用户
提问于 2020-10-15 20:40:06
回答 1查看 264关注 0票数 0

我正在尝试拆分CIFAR10的训练数据,所以最后5000的训练集用于验证。我的代码

代码语言:javascript
复制
size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
                                          batch_size=config['batch_size'],
                                          shuffle=False,  sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
                                           batch_size=config['batch_size'],
                                           shuffle=False,  sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),

但是最后的打印语句打印50000和10000。如果不是45000和5000,当我打印train_idx和val_idx时,它会打印正确的值(0:4499,45000:49999),我的代码有什么问题吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-10-15 22:51:06

我无法复制您的结果,当我执行您的代码时,print语句输出的数值是相同的两倍:train_CIFAR10中的元素数。因此,我猜您在复制代码时犯了错误,而valid_dataloader实际上被赋予了CIFAR10_test (或类似的东西)作为参数。在下面的文章中,我将假设是这样的,并且您的打印输出(50000, 50000),这就是Pytorch的CIFAR10数据集的训练部分的大小。

然后它是完全预期的,而不是它不应该输出(45000,5000)。您正在询问train_dataloader.datasetvalid_dataloader.dataset的长度,即基础数据集的长度。对于两个加载器,此数据集为CIFAR10_training。因此,您将获得此数据集的两倍大小(即50000)。

您也不能请求len(train_dataloader),因为这样会产生数据集中的批数(大约是45000/batch_size)。

如果您需要知道分块的大小,则必须计算取样器的长度:

代码语言:javascript
复制
print(len(train_dataloader.sampler), len(valid_dataloader.sampler))

除此之外,您的代码很好,您正在正确地分割数据。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64379339

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档