import torch
source = [torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 5, 6])]
source_len = [torch.tensor([3]), torch.tensor([4]), torch.tensor([5])]
source_txt = ["A B C", "A B C D", "A B C D E"]
target = [torch.tensor([1, 2, 3, 1]), torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 5, 6, 2])]
target_len = [torch.tensor([4]), torch.tensor([3]), torch.tensor([6])]
target_txt = ["E F G H", "E F G", "E F G H I J"]
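Here each tensor in source corresponds to the tensor at the same position in target.
I want to do a kind of "in-batch sort" that orders everything by source length, descending: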
# result should be:
source = [torch.tensor([1, 2, 3, 5, 6]), torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3])]
source_len = [torch.tensor([5]), torch.tensor([4]), torch.tensor([3])]
source_txt = ["A B C D E", "A B C D", "A B C"]
target = [torch.tensor([1, 2, 3, 5, 6, 2]), torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 1])]
target_len = [torch.tensor([6]), torch.tensor([3]), torch.tensor([4])]
target_txt = ["E F G H I J", "E F G", "E F G H"]
Answer 1
You can simply use the second value returned by torch.sort() (the indices of the sorted elements), like this:
source_len_sorted, idx = torch.sort(torch.as_tensor(source_len), descending=True)
source_len_sorted = list(source_len_sorted)
source_sorted = [source[i] for i in idx]
source_txt_sorted = [source_txt[i] for i in idx]
target_sorted = [target[i] for i in idx]
target_len_sorted = [target_len[i] for i in idx]
target_txt_sorted = [target_txt[i] for i in idx]
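Since the same index order is applied to every list, the repetition can be folded into a small helper. The following is only a sketch: the function name sort_batch_by_length is hypothetical (not part of PyTorch), and it assumes each length is stored as a one-element tensor, as in the question.

```python
import torch

def sort_batch_by_length(lengths, *fields):
    """Hypothetical helper: reorder parallel lists by descending length."""
    # Flatten the one-element length tensors into a single 1-D tensor.
    flat = torch.cat([n.reshape(1) for n in lengths])
    _, idx = torch.sort(flat, descending=True)
    order = idx.tolist()  # plain Python ints for indexing lists
    return [[field[i] for i in order] for field in fields]

source = [torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 5, 6])]
source_len = [torch.tensor([3]), torch.tensor([4]), torch.tensor([5])]
source_txt = ["A B C", "A B C D", "A B C D E"]
target_txt = ["E F G H", "E F G", "E F G H I J"]

src_sorted, len_sorted, src_txt_sorted, tgt_txt_sorted = sort_batch_by_length(
    source_len, source, source_len, source_txt, target_txt
)
print(src_txt_sorted)  # ['A B C D E', 'A B C D', 'A B C']
print(tgt_txt_sorted)  # ['E F G H I J', 'E F G', 'E F G H']
```

Passing source_len again as one of the fields keeps the sorted lengths in their original one-element tensor form, which matches the expected result in the question.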
If the values are stored in Torch tensors rather than lists, you can index them directly with idx:
source_sorted = source[idx]
and so on for the other tensors.
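To illustrate the tensor case, here is a minimal sketch. It assumes the variable-length sequences have first been padded into one batch tensor with torch.nn.utils.rnn.pad_sequence; that padding step is my assumption and is not part of the original question.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

source = [torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 5, 6])]
source_len = torch.tensor([3, 4, 5])

# Pad to a single (batch, max_len) tensor; shorter rows are filled with 0.
padded = pad_sequence(source, batch_first=True)

source_len_sorted, idx = torch.sort(source_len, descending=True)
padded_sorted = padded[idx]  # reorder the rows with the index tensor

print(source_len_sorted.tolist())  # [5, 4, 3]
print(padded_sorted.tolist())      # [[1, 2, 3, 5, 6], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]]
```

Indexing with idx along the first dimension reorders whole rows, so the same idx can be reused on any other tensor whose first dimension is the batch.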