torch.Tensor.topk提供了一种在一维张量中提取顶k值的有效方法。是否可以将顶部k值限制为non-repetitive
例如,
input = torch.tensor([0.2,0.2,0.1])
k = 2
dim = 0
output[0] = torch.tensor([0.2,0.1])
output[1] = torch.longtensor([0,2])发布于 2021-07-30 09:53:49
您可以在输入张量上应用torch.unique。
>>> input.unique().topk(k=2).values
tensor([0.2000, 0.1000])请注意,此时你将失去指数。
编辑:实际上,torch.unique有一个选项来对结果进行排序(默认情况下该选项是打开的)。
>>> input
tensor([0.0000, 0.3000, 0.2000, 0.2000, 0.1000])
>>> input.unique(return_inverse=True)[1].unique(sorted=False)
tensor([1, 2, 3, 0])https://stackoverflow.com/questions/68588784
复制相似问题