pytorch - 如何根据标签拆分 pytorch dataset?

我需要按类别拆分 CIFAR10 dataset 以便我可以为每个类别创建一个具有相同数量的样本的较小样本。

我怎样才能最好地做到这一点?

回答1

import numpy
sorted_by_value = [0]*10
for i in range(10):
  sorted_by_value[i] =(train.data[numpy.where(numpy.array(train.targets) == i)])
  numpy.random.shuffle(sorted_by_value[i])

对于任何 dataset,您只需将 10 替换为类别数,就可以了。

相似文章

c - 未分配 realloc 中的双**指针

我必须实现一个聚类算法,在加载数据集后,我去检查每个点可以插入到哪个聚类中。如果无法将点插入到任何集群中,我必须将它们从数据集中移出并将它们插入到保留集中。由于我事先不知道保留集的大小,因此我分配了一...

随机推荐

最新文章