1. 模型的保存与加载
方法一:把模型对象、state_dict 和优化器状态一起存进一个 checkpoint 字典:
save model:
model = Mymodel()  # replace Mymodel with your own model class
# Bundle the model object, its weights and the optimizer state into a single
# dict and write it to one file.
# NOTE: optimizer and ckpt_file are assumed to come from the surrounding
# training code — they are not defined in this snippet.
checkpoint = dict(
    model=model,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
torch.save(checkpoint, ckpt_file)
load model:
def load_checkpoint(filepath, map_location=None):
    """Rebuild a frozen, eval-mode model from a method-1 checkpoint file.

    The checkpoint is the dict saved above: the pickled model object under
    ``'model'`` and its weights under ``'state_dict'``.

    :param filepath: path of the file written with ``torch.save(checkpoint, ...)``.
    :param map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to open a
        GPU-saved checkpoint on a CPU-only machine. ``None`` keeps the saved
        devices.
    :return: the model, with ``requires_grad=False`` on every parameter and
        switched to evaluation mode.
    """
    # The checkpoint contains a whole nn.Module object, which cannot be
    # unpickled with weights_only=True (torch.load's default since PyTorch 2.6).
    # weights_only=False runs arbitrary pickle code: only load trusted files.
    checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    # Freeze every weight: the returned model is meant for inference only.
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.eval()
    return model
# Demo: rebuild the model from 'checkpoint.pth' and print it.
checkpoint_path = 'checkpoint.pth'
model = load_checkpoint(checkpoint_path)
print(model)
方法二:直接保存/加载整个模型对象(基于 pickle,加载时必须能找到模型类的定义):
torch.save(model, PATH)
# Model class must be defined somewhere
# (unpickling looks the class up by its module path, so it must be importable).
# weights_only=False is required to unpickle a whole nn.Module: since
# PyTorch 2.6 torch.load defaults to weights_only=True and rejects it.
# This executes arbitrary pickle code — only load files you trust.
model = torch.load(PATH, weights_only=False)
model.eval()
方法三:把超参数(model_params.json)和 state_dict(network.pt)打包成一个 zip 文件:
import io
import json
import shutil
import zipfile
from pathlib import Path

import torch
def load_model(filepath, map_location=None):
    """Read back an archive containing ``model_params.json`` and ``network.pt``.

    :param filepath: path of the ``.zip`` archive, e.g. "./model/models.zip".
    :param map_location: forwarded to ``torch.load`` (e.g. ``'cpu'``).
    :return: tuple ``(loaded_params, saved_state_dict)``: the object decoded
        from ``model_params.json`` and the object stored in ``network.pt``.
    :raises KeyError: if either file is missing from the archive.
    """
    with zipfile.ZipFile(filepath) as z:
        # Look both members up front so a half-written archive fails with a
        # clear message instead of a bare KeyError on the member name.
        try:
            params_info = z.getinfo("model_params.json")
            network_info = z.getinfo("network.pt")  # or a .pkl file
        except KeyError as err:
            raise KeyError(
                f"{filepath} is missing at least one component: {err}"
            ) from err
        with z.open(params_info) as f:
            loaded_params = json.load(f)
        with z.open(network_info) as f:
            # Buffer the member in memory so torch.load gets a seekable stream.
            buffer = io.BytesIO(f.read())
    saved_state_dict = torch.load(buffer, map_location=map_location)
    return loaded_params, saved_state_dict
def save_model(path, *, state_dict=None, params=None):
    """
    Save a model as a zip archive made of two distinct files.

    The archive holds ``model_params.json`` (from *params*) and ``network.pt``
    (from *state_dict*). This is the layout a matching loader reads back. Each
    file is written only when its argument is given.

    WARNING: the folder *path* is removed after archiving. Anything already
    inside it is archived too, and is then deleted.

    :param path: folder to build the archive from, e.g. "/model/models";
        the archive is written next to it as "<path>.zip".
    :param state_dict: object passed to ``torch.save``, normally
        ``model.state_dict()``.
    :param params: JSON-serializable dict of hyper-parameters.
    :return: path of the created ``.zip`` archive.
    """
    # Path() drops a trailing slash; otherwise a base name like "models/" can
    # yield "models/.zip" inside the very folder rmtree deletes below.
    folder = Path(path)
    # Create folder
    folder.mkdir(parents=True, exist_ok=True)
    if params is not None:
        with open(folder / "model_params.json", "w", encoding="utf-8") as f:
            json.dump(params, f)
    if state_dict is not None:
        torch.save(state_dict, folder / "network.pt")
    archive = shutil.make_archive(str(folder), "zip", str(folder))
    print(archive)
    shutil.rmtree(folder)
    return archive
# Round-trip demo: pack ./model/models into ./model/models.zip, then read it back.
# NOTE(review): save_model is called with only a path here. The archive will
# not contain model_params.json / network.pt unless they end up in that folder.
# In that case load_model cannot find them and will raise KeyError.
models_dir = "./model/models"
save_model(models_dir)
load_model(models_dir + ".zip")
2. tensorboardX
在服务器上使用 tensorboardX:先执行 `source activate tensorflow` 激活环境,确认能够 `import tensorflow`,然后运行 `tensorboard --logdir yourlogdir`。
参考:
1. Kaggle:save and load models;
2. tensorboardX 教程;
3. tensorboardX GitHub;
4. PyTorch:Loading weights from DataParallel models;
5. PyTorch 官网:Saving and Loading Models;
6. PyTorch load model。