当前位置: 移动技术网 > IT编程>脚本编程>Python > pytorch 01 关于分割任务中 onehot 编码转换的问题

pytorch 01 关于分割任务中 onehot 编码转换的问题

2020年07月30日  | 移动技术网IT编程  | 我要评论
在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。代码:一:当维度为 N 1 *one-hot后 N C *def make_one_hot(input, num_classes): """Convert .

在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。

例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。

 

代码:

一:当维度为 N  1 *
one-hot后 N C *

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, torch.LongTensor(input), 1)

    return result
二:当维度为 1 * 
one_hot后 N *
def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[0] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(0, torch.LongTensor(input), 1)

    return result

* 代表图像大小 例如 224 x 224

 

本文地址:https://blog.csdn.net/wwwww_bw/article/details/107643179

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

相关文章:

验证码:
移动技术网