from torch.utils.data.dataloader import DataLoader
#from dataset import TrainDataset
from dataset import DatasetFromHdf5
from vdsr import Net
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torch
from tqdm import tqdm
import os, argparse
import utils
import matplotlib.pyplot as plt
import h5py
import numpy as np
# parameters setting
parser = argparse.ArgumentParser(description="PyTorch VDSR")
parser.add_argument("--datasets_path", type=str, default=f'datasets/291_aug_train_VDSRx2.h5', help=".h5 file path")
parser.add_argument("--weight_save_path", type=str, default=f'weight/Train_VDSR_model_x2.pth', help="weight file save path")
parser.add_argument("--batch_size", type=int, default=64, help="Train batch size, Default: 64")
parser.add_argument("--num_workers", type=int, default=16, help="Train num workers, Default: 16")
parser.add_argument("--momentum", type=float, default=0.9, help="Momentum, Default: 0.9")
parser.add_argument("--weight_decay", type=float, default=0.0001, help="Weight decay, Default: 1e-4")
parser.add_argument("--init_lr", type=float, default=0.1, help="Learning Rate, Default=0.1")
parser.add_argument("--theta", type=float, default=0.01, help="Clipping Gradients theta, Default: 0.01")
parser.add_argument("--epochs", type=int, default=80, help="Number of epochs to train for, Default: 80")
parser.add_argument("--step", type=int, default=20, help="Every 20 epochs, lr down 10 times, Default: 20")
parser.add_argument("--gamma", type=float, default=0.1, help="lr down times, Default: 0.1")
# 设置所有可以使用的显卡,共计四块,根据自己设备修改
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
device_ids = [0, 1, 2, 3]
opt = parser.parse_args()
#print(opt)
# 加载数据集
#train_dataset = TrainDataset(opt.datasets_path)
train_dataset = DatasetFromHdf5(opt.datasets_path)
# print(train_dataset.data[0].shape)
#print(train_dataset.target.size)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, shuffle=True)
# 模型训练相关
model = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=opt.init_lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[opt.step, opt.step*2, opt.step*3], gamma=opt.gamma)
# 设置GPU
model = nn.DataParallel(model, device_ids=device_ids)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 训练
# TODO 论文中实验部分没有对验证环节说明,但是如果训练中累加psnr,最终现存会溢出。
# best_psnr = 0
# each_psnr = []
for epoch in range(opt.epochs):
epoch_loss = 0
#sum_psnr = 0
print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))
with tqdm(train_dataloader, desc='standing...') as tepoch:
for iteration, (input, target) in enumerate(tepoch):
input, target = input.to(device), target.to(device)
optimizer.zero_grad()
out = model(input)
loss = criterion(out, target)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), opt.theta/optimizer.param_groups[0]["lr"])
optimizer.step()
epoch_loss += loss.item()
tepoch.set_description(f'Epoch [{epoch + 1}/{opt.epochs}]')
#vdsr_psnr = utils.calc_psnr(input, out)
#sum_psnr += vdsr_psnr
if iteration % 100 == 0:
print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(train_dataloader), loss.item()))
# print("===> Epoch[{}]({}/{}): Loss: {:.10f}, PSNR: {:.2f} ".format(epoch, iteration, len(train_dataloader),
# loss.item(), vdsr_psnr))
#avg_psnr = sum_psnr / len(train_dataloader)
# print(f"Average PSNR: {avg_psnr} dB")
# if epoch > 0: # TODO 忽略第一个epoch,因为第一个epoch是从最大往下降的,值偏高
# each_psnr.append(avg_psnr)
# if avg_psnr >= best_psnr and epoch > 0:
# best_psnr = avg_psnr # 用psnr衡量模型,保存最好的
# torch.save(model, r"weight/best_model_VDSR_x2.pth")
print(f"Epoch {epoch}. Training loss: {epoch_loss / len(train_dataloader)}")
scheduler.step()
# 训练完保存模型
if len(device_ids) > 1:
torch.save(model.module.state_dict(), opt.weight_save_path)
else:
torch.save(model.state_dict(), opt.weight_save_path)
torch.save(model, 'weight/VDSR_model_x2_80epoch.pth')
# # 画epoch和psnr的关系
# epochs = range(0, opt.epochs)
#
# # plot画图,设置颜色和图例,legend设置图例样式,右下,字体大小
# plt.plot(epochs, each_psnr, color='red', label='VDSR(trained on T291)')
# plt.legend(loc='lower right', prop={'family':'Times New Roman', 'size': 12})
#
# # 横纵轴文字内容以及字体样式
# plt.xlabel("epoch",fontproperties='Times New Roman', size = 12)
# plt.ylabel("Average test PSNR (dB)",fontproperties='Times New Roman', size=12)
#
# # 横纵轴刻度以及字体样式
# plt.xlim(0, opt.epochs)
# plt.ylim(20, 40)
# plt.yticks(fontproperties='Times New Roman', size=12)
# plt.xticks(fontproperties='Times New Roman', size=12)
#
# # 网格样式
# plt.grid(ls = '--')
# plt.savefig('fig4-a.png', bbox_inches='tight')
# plt.show()
#
# # epoch_loss = 0
# # # 查看训练集第一个批次,看数据集是否封装正确
# # for batch_idx, (data, target) in enumerate(train_dataloader):
# # # data 和 target 是每个批次的数据和标签
# # # 这里只输出第一个批次的数据
# # print("Batch Index:", batch_idx)
# # print("Data shape:", data.shape) # 输出数据形状
# # print("Target shape:", target.shape) # 输出标签形状
# # print("First Batch Data:", data[0]) # 输出第一个批次的数据
# # print("First Batch Target:", target[0]) # 输出第一个批次的标签
# #
# # optimizer.zero_grad()
# # out = model(data)
# # loss = criterion(out, target)
# # loss.backward()
# # optimizer.step()
# # epoch_loss += loss.item()
# # print(loss)
# # print(epoch_loss)
# # print(epoch_loss / len(train_dataloader))
# # break
#
# # if batch_idx == 3:
# # break # 只输出第一个批次
#
# # print(input.shape)
# # print(label.shape)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
1.数据增强:旋转、翻转、缩放(可选)。 2.制作.h5格式的数据集:python实现,与github上的matlab代码相同,统一了编程语言。 3.模型实现:pytorch搭建vdsr网络。 4.训练:参数与论文中完全一致,与github上的参考代码不同。 5.测试:PSNR评估和图像可视化,与Bicubic双三次插值对比。 资源对应文章:https://blog.csdn.net/qq_36584673/article/details/136625210 本文用Pytorch实现了VDSR算法的全部流程。将参考4中的制作.h5数据集的matlab代码改为python,统一了编程语言。这样,在python项目中只需按顺序一步一步运行文件,即可得到最终的结果。 使用说明请见文章,非常详细,手把手教你运行。
资源推荐
资源详情
资源评论





















收起资源包目录











































































共 64 条
- 1


十小大

- 粉丝: 1w+
上传资源 快速赚钱
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助


最新资源
- 永磁同步电机SVPWM弱磁控制仿真Simulink模型研究:前馈弱磁法及其应用 v2.5
- 电力电子领域永磁同步电机SVPWM算法故障诊断与容错控制的Simulink仿真研究 - SVPWM 实用版
- Java语言Post请求的request只可以读取一次的问题解决
- Java多线程:Runnable与Thread的比较
- 电源领域PFM与PWM混合调制LLC全桥谐振变换器闭环仿真模型解析
- 基于Python实现BP神经网络识别手写字体源码
- 基于MATLAB的单相双极性SPWM逆变电路设计与仿真实现
- Comsol纳米摩擦发电机仿真:基于静电场的电极材料电势与电场分布计算
- 电子相册制作平台源码项目说明
- 使用robot_localization实现传感器融合的深入分步教程
- COMSOL模拟中晶界介电特性的电击穿与电树枝发展
- 毕业设计智能电网级联故障建模研究 Matlab完整源码带说明文档
- Comsol流固耦合仿真模型:多物理场计算揭示速度、压力、位移与应力分布
- 土柱单向冻结与冻融循环中水热力三场耦合的COMSOL仿真及隔水层影响研究
- ArcGIS Editor for OSM 10.0-0010.8
- Comsol反应器仿真模型:多物理场耦合下的温度、速度与浓度分布研究 - Comsol
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈



安全验证
文档复制为VIP权益,开通VIP直接复制

- 1
- 2
- 3
前往页