多标签分类与binary_cross_entropy_with_logits

本文介绍了PyTorch中torch.nn.functional.binary_cross_entropy_with_logits和torch.nn.BCEWithLogitsLoss在多标签分类中的应用。二值交叉熵用于计算每个类别的损失,其等价于tf.nn.sigmoid_cross_entropy_with_logits。目标标签可以有多个1,适用于多标签问题。同时,文章对比了二值交叉熵与softmax_cross_entropy_with_logits,后者适用于多分类问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. binary_cross_entropy_with_logits可用于多标签分类

torch.nn.functional.binary_cross_entropy_with_logits等价于torch. nn.BCEWithLogitsLoss

torch.nn.BCELoss+torch.nn.Sigmoid 等价于 torch. nn.BCEWithLogitsLoss

在pytorch中torch.nn.functional.binary_cross_entropy_with_logits和tensorflow中tf.nn.sigmoid_cross_entropy_with_logits,都是二值交叉熵,二者等价。

接受任意形状的输入,target要求与输入形状一致。注意:target的值必须在[0,N-1]之间,其中N为类别数,否则会出现莫名其妙的错误,比如loss为负数。

二值交叉熵的Loss如下:

在这里插入图片描述

其中 l_{i} 可以解释为:预测这个样本为第i个类别的损失

 在这里插入图片描述

 w_{n} 解释为类别的权重,重视某个类别,则加大该类别权重。

from torch import nn
from torch.autograd import Variable
bce_criterion = nn.BCEWithLogitsLoss(weight = None, reduce = False)
y = Variable(torch.tensor([[1,0,0],[0,1,0],[0,0,1],[1,1,0],[0,1,0]],dtype=torch.float64))
logits = Variable(torch.tensor([[12,3,2],[3,10,1],[1,2,5],[4,6.5,1.2],[3,6,1]],dtype=torch.float64))
bce_criterion(logits, y)

binary_cross_entropy_with_logits中的target(标签)的one_hot编码中每一维可以出现多个1,而softmax_cross_entropy_with_logits 中的target的one_hot编码中每一维只能出现一个1

2. softmax_cross_entropy_with_logits 

binary_cross_entropy_with_logits是二分类的交叉熵,实际是多分类softmax_cross_entropy的一种特殊情况

 

from torch import nn
from torch.autograd import Variable
bce_criterion = nn.BCEWithLogitsLoss(weight = None, reduce = False)
y = Variable(torch.tensor([[1,0,0],[0,1,0],[0,0,1],[0,1,0],[0,1,0]],dtype=torch.float64))
logits = Variable(torch.tensor([[12,3,2],[3,10,1],[1,2,5],[4,6.5,1.2],[3,6,1]],dtype=torch.float64))
bce_criterion(logits, y)

target中one_hot编码后每一行只能出现一个1

准确率评价参考:(69条消息) 多标签分类中的损失函数与评估指标_小Aer的博客-CSDN博客_多标签分类损失函数

参考:https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/quick_start/quick_start_multilabel_classification.md

参考:binary_cross_entropy_with_logits-API文档-PaddlePaddle深度学习平台

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值