生成张量
y = torch.empty(3, dtype=torch.long).random_(5)
y = torch.Tensor(2,3).random_(10)
y = torch.randn(3,4).random_(10)
查看类型
y.type
y.dtype
类型转化
tensor.long()/int()/float()
long(),int(),float() 实现类型的转化
One_hot编码表示
def one_hot(y):
'''
y: (N)的一维tensor,值为每个样本的类别
out:
y_onehot: 转换为one_hot 编码格式
'''
y = y.view(-1, 1)
# y_onehot = torch.FloatTensor(3, 5)
# y_onehot.zero_()
y_onehot = torch.zeros(3,5) # 等价于上面
y_onehot.scatter_(1, y, 1)
return y_onehot
y = torch.empty(3, dtype=torch.long).random_(5) #标签
res = one_hot(y) # 转化为One_hot类型
# One_hot类型标签转化为整数型列表的两种方法
h = torch.argmax(res,dim=1)
_,h1 = res.max(dim=1)
expand()函数
这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。
a=torch.randn(1,1,3,768)
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])
repeat()函数
沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。
a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])