# Dataset that returns each sample together with its index
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class LoadDataset(Dataset):
    """Map-style dataset that yields each sample together with its index.

    Returning the index alongside the data lets a consumer relate every item
    of a shuffled batch back to its position in the stored array.
    """

    def __init__(self, data):
        """Store the sample container.

        Args:
            data: Object exposing ``shape`` and integer indexing along its
                first axis (presumably a NumPy array; confirm against
                callers).
        """
        self.x = data

    def __len__(self):
        """Return the number of samples, i.e. the size of the first axis."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, index)`` for position ``idx``.

        Args:
            idx: Integer position of the requested sample.

        Returns:
            A tuple ``(sample, index)`` where ``sample`` is a ``float32``
            tensor holding a copy of ``self.x[idx]`` and ``index`` is ``idx``
            as a 0-dim ``int64`` tensor.
        """
        # torch.tensor always copies, so the returned tensor never aliases
        # the stored data, and the conversion to float32 happens in one step.
        sample = torch.tensor(self.x[idx], dtype=torch.float32)
        # Pin the dtype: NumPy's default integer width is platform dependent
        # (e.g. 32-bit on Windows with NumPy 1.x), which would otherwise make
        # the index tensor's dtype vary between machines.
        index = torch.as_tensor(idx, dtype=torch.int64)
        return sample, index
# Read the feature matrix from the text file and wrap it so that every
# sample is served together with its index.
x_train = np.loadtxt(fname="./data_s/CITE/cite.txt")
dataset = LoadDataset(data=x_train)

# During model training: hand the indexed dataset to a DataLoader, which
# shuffles the samples and yields them in batches of 256.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=True,
)
# Pre-training loop: for every epoch, run one optimisation step per batch
# using the mean-squared error between the model's second output and its
# input as the loss.
for epoch in range(args.pretrain_epoch):
    # Each batch is (samples, indices); the indices are not needed here.
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.to(device)  # move the batch onto the training device
        # The model returns a pair; only x_bar enters the loss below.
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        # Clear gradients accumulated by the previous step before
        # back-propagating the current loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
with tor