# -*- encoding: utf-8 -*-
# @Time : 2020/11/26
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/11/16
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
import torch
def sort_for_packed_sequence(lengths: torch.Tensor):
    """
    Sort sequence lengths in descending order and return the index tensors
    needed to apply that ordering and to undo it afterwards.

    NOTE(review): judging by the name, this is intended to prepare a batch for
    ``torch.nn.utils.rnn.pack_padded_sequence`` -- confirm against the callers.

    :param lengths: 1D array of lengths
    :return: sorted_lengths (lengths in descending order), sorted_idx (indices to sort,
        i.e. ``lengths[sorted_idx] == sorted_lengths``), rev_idx (indices to retrieve
        original order, i.e. ``sorted_lengths[rev_idx] == lengths``)
    """
    # A single sort yields both the ordered values and the permutation that
    # produced them, so there is no need to argsort and then index again.
    sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
    # Arg-sorting a permutation gives its inverse: position i of rev_idx is
    # where original element i ended up in the sorted order.
    rev_idx = torch.argsort(sorted_idx)
    return sorted_lengths, sorted_idx, rev_idx