pytorch - 如何根据标签拆分 pytorch dataset?

我需要按类别拆分 CIFAR10 dataset 以便我可以为每个类别创建一个具有相同数量的样本的较小样本。

我怎样才能最好地做到这一点?

回答1

import numpy
sorted_by_value = [0]*10
for i in range(10):
  sorted_by_value[i] =(train.data[numpy.where(numpy.array(train.targets) == i)])
  numpy.random.shuffle(sorted_by_value[i])

对于任何 dataset,您只需将 10 替换为类别数,就可以了。

相似文章

最新文章