item()方法只适用于包含单个元素的张量,并将其值提取为标量。它是针对张量对象的特定方法,用于方便地获取张量中的值。
import torch
tensor = torch.tensor([5])
print(type(tensor)) # 输出<class 'torch.Tensor'>
print(tensor) # 输出tensor([5])
value = tensor.item()
print(value) # 输出5
print(type(value)) # 输出<class 'int'>
注意该方法只适用于包含单个元素的张量,尝试使用 item()方法从一个包含多个元素的张量中提取值会报错。
import torch
tensor = torch.tensor([5,5])
# value = tensor.item() # 会报错
value = tensor[0].item() # 正确方法
print(value)
print(type(value))