数据集来源：https://github.com/zhouhaoyi/ETDataset
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import math
import time
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
# Render inline (Jupyter) matplotlib figures in SVG (vector) format.
backend_inline.set_matplotlib_formats('svg')
Transformer类
只使用 Transformer 的 encoder（编码器）：将 encoder 的输出展平后接全连接层，得到多步预测输出。可以把 encoder 理解为只负责对原始输入特征进行变换（特征提取）。
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first input.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold the
    matching ``cos(pos * w_i)``, where ``w_i = 10000^(-2i / d_model)``.
    The table is registered as a buffer of shape ``(max_len, 1, d_model)``.
    This makes it follow the module across devices and into ``state_dict``,
    but it is not a trainable parameter.

    Args:
        d_model: Feature dimension of the inputs. Odd values are supported.
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) column pair: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cos column than frequencies,
        # so trim div_term to the number of odd columns to keep shapes equal.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Args:
            x: Tensor laid out as ``(seq_length, batch_size, d_model)``.
                The sequence axis must come first, and ``seq_length`` must
                not exceed ``max_len``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        return x + self.pe[:x.size(0), :]  # [seq_length, batch_size, d_model]
class TransformerTimeSeriesModel(nn.Module):
def __init__(self, input_size, output_size, seq_length, label_length,
d_model = 256, nhead = 8, num_layers = 2, dropout=0.5):
'''
input_size, output_size, seq_length, label_length分别为输入维度、输出维度、历史时刻步数、多步预测步数
'''
super(TransformerTimeSeriesModel, self).__init__()
self.src_mask = None
self.embedding = nn.Linear(input_size, d_model)
self.pos_coding = PositionalEncoding(d_model)
self.encoder_layer = nn.TransformerEncoderLayer(d_model = d_model, nhead = nhead,
dim_feedforward=4 * d_model, dropout = dropout)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers = num_layers)
self.fc1 = nn.Linear(seq_length * d_model, label_length * d_model)
self.fc2 = nn.Linear(label_length * d_model, label_length * output_size)
self.init_weights()
def forward(self, src):
if self.src_mask is None:
device = src.device
mask = self._generate_square_subsequent_mask(len(src)).to(device)
self.src_mask = mask
src = self.embedding(src)
src = self.pos_coding(src)
en_output = se