python - 根据长度对张量的 list 中的 list 进行排序

import torch

source = [torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 5, 6])]

source_len = [torch.tensor([3]), torch.tensor([4]), torch.tensor([5])]

source_txt = ["A B C", "A B C D", "A B C D E"]


target =  [torch.tensor([1, 2, 3, 1]), torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 5, 6, 2])]

target_len = [torch.tensor([4]), torch.tensor([3]), torch.tensor([6])]

target_txt = ["E F G H", "E F G", "E F G H I J"]

这里源中的每个张量对应于目标中的每个张量。

我想根据源的长度反向实现一种“批次内排序”:

# result should be:

source = [torch.tensor([1, 2, 3, 5, 6]), torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3])]

source_len = [torch.tensor([5]), torch.tensor([4]), torch.tensor([3])]

source_txt = ["A B C D E", "A B C D", "A B C"]


target =  [torch.tensor([1, 2, 3, 5, 6, 2]), torch.tensor([1, 2, 3])torch.tensor([1, 2, 3, 1])]

target_len = [torch.tensor([6]), torch.tensor([3]), torch.tensor([4])]

target_txt = ["E F G H I J", "E F G", "E F G H"]

回答1

您可以简单地使用从 torch.sort() 返回的第二个参数,如下所示:

source_len_sorted ,idx = torch.sort(torch.as_tensor(source_len),descending=True)
source_len_sorted = list(source_len_sorted)
source_sorted = [source[i] for i in idx]
source_txt_sorted = [source_txt[i] for i in idx]

target_sorted =  [target[i] for i in idx]
target_len_sorted = [target_len[i] for i in idx]
target_txt_sorted = [target_txt[i] for i in idx]

如果 values 存储在 Torch 张量而不是 lists 中,则可以通过 idx 对它们进行索引:

source_sorted = source[idx]

等等等等

相似文章

kotlin - 使用 Kotlin 按多个条件排序

我们想按三个不同的标准对对象进行排序,具有较高优先级的标准会覆盖下一个):状态字段时间场排序字段status是枚举并且有一个自定义的sorting,我们通过一个比较器来实现。(未知,打开,关闭,)so...