Pytorch变量类型转换及One_hot编码表示

本文详细介绍了PyTorch中张量的基本操作,包括张量的生成、类型查看及转换、One-hot编码的实现方法,并对expand()与repeat()两个函数的功能进行了对比解析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

生成张量
y = torch.empty(3, dtype=torch.long).random_(5)

y = torch.Tensor(2,3).random_(10)

y = torch.randn(3,4).random_(10)
查看类型
y.type
y.dtype
类型转化
tensor.long()/int()/float()
long(),int(),float() 实现类型的转化

One_hot编码表示
def one_hot(y):
    '''
    y: (N)的一维tensor,值为每个样本的类别
    out:
        y_onehot: 转换为one_hot 编码格式
    '''
    y = y.view(-1, 1)
    # y_onehot = torch.FloatTensor(3, 5)
    # y_onehot.zero_()

    y_onehot = torch.zeros(3,5)  # 等价于上面
    y_onehot.scatter_(1, y, 1)
    return y_onehot

y = torch.empty(3, dtype=torch.long).random_(5) #标签

res = one_hot(y)  # 转化为One_hot类型

# One_hot类型标签转化为整数型列表的两种方法

h = torch.argmax(res,dim=1)
_,h1 = res.max(dim=1)


expand()函数

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。

a=torch.randn(1,1,3,768)

print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)

print(b.shape) #torch.Size([2, 1, 3, 768])

c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

repeat()函数

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Jeremy_lf

你的鼓励是我的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值