李沐Softmax回归从零开始实现代码中的关于y和y_hat

原视频:李沐Softmax回归从零开始实现

其中,这段代码令人迷惑。

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

y_hat[[0, 1], y]

视频文字上的注释是:

创建一个数据 y_hat,其中包含2个样本在3个类别上的预测概率,使用 y 作为 y_hat 中概率的索引。

为什么介绍这段代码?因为为了介绍交叉熵。

在之前的课程中提到,对真实 y 进行独热编码。

比如,共有 3 类,则真实输出 y = [ 0 , 0 , 1 ] \bold y = [0, 0, 1] y=[0,0,1],即表示:真实的类别是第3类。

最后发现,交叉熵损失等于 − l o g ( y y ^ ) -log(\hat{y_y}) log(yy^),就是 i = y 真实类别的预测概率 y ^ \hat{y} y^

但是,这里的 y 不表示这个含义。这里的 y 表示 2 个样本的真实类别分别是 0 和 2(类别有 [0, 1, 2])

而之前的独热编码 y 表示为 1 个样本的真实类别:[0, 0, 1]。第 2 个是1,则表示第 2 个为真实类别。所以独热编码的y要写成上述代码的y,可以写成:y = [2]

当把 y 写成独热编码,是为了方便解释:交叉熵损失的预测概率只需要真实类别的预测概率,并对其求-log。

那么,既然如此,代码中的 y 就表示 index,就告诉你哪一个是真实类别的预测概率,那么要计算交叉熵损失就直接根据 index 在 y_hat 里面取就行。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值