Transformer - Feature Preprocessing
flyfish
Raw data
train_data.values
[[ 5.827 2.009 1.599 0.462 4.203 1.34 30.531]
 [ 5.76 2.076 1.492 0.426 4.264 1.401 30.46 ]
 [ 5.76 1.942 1.492 0.391 4.234 1.31 30.038]
 [ 5.76 1.942 1.492 0.426 4.234 1.31 27.013]
 [ 5.693 2.076 1.492 0.426 4.142 1.371 27.787]
 [ 5.492 1.942 1.457 0.391 4.112 1.279 27.717]
 [ 5.358 1.875 1.35 0.355 3.929 1.34 27.646]
 [ 5.157 1.808 1.35 0.32 3.807 1.279 27.084]
 [ 5.157 1.741 1.279 0.355 3.777 1.218 27.787]
 [ 5.157 1.808 1.35 0.426 3.777 1.188 27.506]
 [ 5.157 1.808 1.315 0.391 3.777 1.249 27.857]
 [ 5.157 1.942 1.35 0.426 3.807 1.279 27.013]
 [ 5.09 1.942 1.279 0.391 3.807 1.279 25.044]
 [ 5.224 2.009 1.457 0.533 3.807 1.249 24.551]
 [ 5.291 1.808 1.457 0.426 3.777 1.218 23.566]
 [ 5.358 1.942 1.492 0.462 3.807 1.31 21.526]
 [ 5.358 1.942 1.492 0.462 3.868 1.279 21.948]
 [ 5.492 2.009 1.492 0.462 3.929 1.34 21.456]
 [ 5.492 1.942 1.492 0.426 3.929 1.34 22.792]
 [ 5.492 2.076 1.492 0.497 3.99 1.31 21.034]
 [ 5.626 2.143 1.528 0.533 4.051 1.371 21.174]
 [ 5.961 2.344 1.67 0.604 4.234 1.492 20.823]
 [ 6.162 2.411 1.777 0.604 4.325 1.523 21.174]
 [ 6.631 2.478 1.99 0.746 4.66 1.675 21.174]
 [ 7.167 2.947 2.132 0.782 5.026 1.858 22.792]
 [ 7.502 3.215 2.239 0.888 5.33 1.98 23.848]
 [ 7.703 3.349 2.487 1.031 5.269 1.919 24.34 ]
 ......
Standardize the data with sklearn's fit and transform
train_data = df_data[border1s[0]:border2s[0]]
self.scaler.fit(train_data.values)
data = self.scaler.transform(df_data.values)
The standardized data is what will be used for training
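To make the fit/transform split concrete, here is a minimal sketch that does the same thing by hand with NumPy. It assumes the default StandardScaler behaviour (per-column mean and population standard deviation); the toy array and the `split` index are made up for the example.

```python
import numpy as np

# Toy data: 6 time steps, 2 features (made-up values, not from ETTm1).
data = np.array([[1.0, 10.0],
                 [2.0, 20.0],
                 [3.0, 30.0],
                 [4.0, 40.0],
                 [5.0, 50.0],
                 [6.0, 60.0]])
split = 4  # hypothetical end of the training range

# "fit": the statistics come from the training rows only.
mean = data[:split].mean(axis=0)
std = data[:split].std(axis=0)  # population std (ddof=0)

# "transform": the same statistics are applied to every row.
scaled = (data - mean) / std

print(scaled[:split].mean(axis=0))  # close to 0 on the training rows
print(scaled[:split].std(axis=0))   # close to 1 on the training rows
print(scaled[split:])               # rows outside the training range are not centred
```

Because the statistics come from the training range alone, rows outside it can end up far from zero, which would be consistent with the large values at the end of transform_data.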
transform_data
[[ 0.6156 -1.3896 -0.991 ... 1.1402 -0.9535 2.907 ]
 [ 0.5294 -1.2429 -1.2435 ... 1.2172 -0.6099 2.8853]
 [ 0.5294 -1.5362 -1.2435 ... 1.1794 -1.1224 2.7561]
 ...
 [ 5.6959 5.6479 11.0826 ... -0.82 0.933 1.5291]
 [ 7.1602 6.0879 13.2628 ... -0.897 1.1076 1.5077]
 [ 6.8156 5.3546 14.3529 ... -0.82 1.2765 1.5508]]
The data can be restored to its original scale with inverse_transform
def inverse_transform(self, data):
    return self.scaler.inverse_transform(data)
inverse_transform_data:
[[ 5.827 2.009 1.599 ... 4.203 1.34 30.531]
 [ 5.76 2.076 1.492 ... 4.264 1.401 30.46 ]
 [ 5.76 1.942 1.492 ... 4.234 1.31 30.038]
 ...
 [ 9.779 5.224 6.716 ... 2.65 1.675 26.028]
 [10.918 5.425 7.64 ... 2.589 1.706 25.958]
 [10.65 5.09 8.102 ... 2.65 1.736 26.099]]
......
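The round trip can be checked with a small sketch. The helper names `fit`, `transform` and `inverse_transform` below are hypothetical stand-ins written with NumPy, not the sklearn implementation; the sample rows reuse the first and last columns of a few raw rows printed earlier.

```python
import numpy as np

def fit(train):
    """Return the per-column mean and population std of the training rows."""
    return train.mean(axis=0), train.std(axis=0)

def transform(x, mean, std):
    """Scale each column to zero mean and unit variance."""
    return (x - mean) / std

def inverse_transform(x_scaled, mean, std):
    """Undo the scaling: multiply by std, then add the mean back."""
    return x_scaled * std + mean

raw = np.array([[5.827, 30.531],
                [5.760, 30.460],
                [5.760, 30.038],
                [5.693, 27.787]])

mean, std = fit(raw)
restored = inverse_transform(transform(raw, mean, std), mean, std)
print(np.allclose(restored, raw))
```

This matters later on: model outputs live in the scaled space, so they need the same mean and std to be read back in the original units.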
Configuration
seq_len:24
label_len:12
pred_len:24
set_type:0
features:M
target:OT
scale:True
timeenc:1
freq:h
root_path:./dataset/ETT-small/
data_path:ETTm1.csv
scaler:StandardScaler()
data_x is the training data
data_x:
[[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00 1.1402e+00 -9.5346e-01 2.9070e+00]
 [ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.2172e+00 -6.0995e-01 2.8853e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00 1.1794e+00 -1.1224e+00 2.7561e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 1.1794e+00 -1.1224e+00 1.8305e+00]
 [ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.0633e+00 -7.7889e-01 2.0673e+00]
 [ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00 1.0254e+00 -1.2970e+00 2.0459e+00]
 [ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00 7.9439e-01 -9.5346e-01 2.0242e+00]
 ......
Encoding the time data
See here for details
Original values
df_stamp['date'].values:
['2016-07-01T00:00:00.000000000' '2016-07-01T00:15:00.000000000'
 '2016-07-01T00:30:00.000000000' '2016-07-01T00:45:00.000000000'
 '2016-07-01T01:00:00.000000000' '2016-07-01T01:15:00.000000000'
 '2016-07-01T01:30:00.000000000' '2016-07-01T01:45:00.000000000'
 '2016-07-01T02:00:00.000000000' '2016-07-01T02:15:00.000000000'
 '2016-07-01T02:30:00.000000000' '2016-07-01T02:45:00.000000000'
 '2016-07-01T03:00:00.000000000' '2016-07-01T03:15:00.000000000'
 '2016-07-01T03:30:00.000000000' '2016-07-01T03:45:00.000000000'
 '2016-07-01T04:00:00.000000000' '2016-07-01T04:15:00.000000000'
 '2016-07-01T04:30:00.000000000' '2016-07-01T04:45:00.000000000'
 '2016-07-01T05:00:00.000000000' '2016-07-01T05:15:00.000000000'
 ......
After encoding
data_stamp:
[[-0.5 0.1667 -0.5 -0.0014]
 [-0.5 0.1667 -0.5 -0.0014]
 [-0.5 0.1667 -0.5 -0.0014]
 [-0.5 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 ......
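The four columns look like hour of day, day of week, day of month and day of year, each scaled to [-0.5, 0.5]. The sketch below is an assumption worked backwards from the printed values, not the project's time_features code; `encode_timestamp` is a hypothetical helper using only the standard library.

```python
from datetime import datetime

def encode_timestamp(ts):
    """Map a timestamp to four features, each scaled to [-0.5, 0.5]."""
    hour_of_day = ts.hour / 23.0 - 0.5
    day_of_week = ts.weekday() / 6.0 - 0.5               # Monday = 0
    day_of_month = (ts.day - 1) / 30.0 - 0.5
    day_of_year = (ts.timetuple().tm_yday - 1) / 365.0 - 0.5
    return [hour_of_day, day_of_week, day_of_month, day_of_year]

for stamp in (datetime(2016, 7, 1, 0, 0), datetime(2016, 7, 1, 1, 15)):
    print(stamp, [round(v, 4) for v in encode_timestamp(stamp)])
# 2016-07-01 00:00:00 [-0.5, 0.1667, -0.5, -0.0014]
# 2016-07-01 01:15:00 [-0.4565, 0.1667, -0.5, -0.0014]
```

None of the four features looks at the minute, which would explain why each group of four 15-minute timestamps shares one encoded row when freq is h.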
s_begin: 0
s_end: 24
r_begin: 12
r_end: 48
s_begin: 1
s_end: 25
r_begin: 13
r_end: 49
......
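These indices follow directly from seq_len, label_len and pred_len. A small sketch of the arithmetic, mirroring the index computation in `__getitem__` (the function name `window_bounds` is made up for this example):

```python
def window_bounds(index, seq_len, label_len, pred_len):
    """Return (s_begin, s_end, r_begin, r_end) for one sample."""
    s_begin = index
    s_end = s_begin + seq_len                # encoder window: [s_begin, s_end)
    r_begin = s_end - label_len              # decoder window starts label_len steps earlier
    r_end = r_begin + label_len + pred_len   # and runs pred_len steps past s_end
    return s_begin, s_end, r_begin, r_end

for index in (0, 1):
    print(index, window_bounds(index, seq_len=24, label_len=12, pred_len=24))
# 0 (0, 24, 12, 48)
# 1 (1, 25, 13, 49)
```

The encoder and decoder windows overlap by label_len steps, so the first 12 rows of seq_y are the last 12 rows of seq_x.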
seq_x:
[[ 6.1557e-01 -1.3896e+00 -9.9100e-01 -1.3248e+00 1.1402e+00 -9.5346e-01 2.9070e+00]
 [ 5.2944e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.2172e+00 -6.0995e-01 2.8853e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.5260e+00 1.1794e+00 -1.1224e+00 2.7561e+00]
 [ 5.2944e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 1.1794e+00 -1.1224e+00 1.8305e+00]
 [ 4.4331e-01 -1.2429e+00 -1.2435e+00 -1.4268e+00 1.0633e+00 -7.7889e-01 2.0673e+00]
 [ 1.8492e-01 -1.5362e+00 -1.3260e+00 -1.5260e+00 1.0254e+00 -1.2970e+00 2.0459e+00]
 [ 1.2660e-02 -1.6829e+00 -1.5785e+00 -1.6280e+00 7.9439e-01 -9.5346e-01 2.0242e+00]
 [-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.7271e+00 6.4040e-01 -1.2970e+00 1.8522e+00]
 [-2.4573e-01 -1.9762e+00 -1.7460e+00 -1.6280e+00 6.0253e-01 -1.6405e+00 2.0673e+00]
 [-2.4573e-01 -1.8295e+00 -1.5785e+00 -1.4268e+00 6.0253e-01 -1.8094e+00 1.9814e+00]
 [-2.4573e-01 -1.8295e+00 -1.6611e+00 -1.5260e+00 6.0253e-01 -1.4659e+00 2.0888e+00]
 [-2.4573e-01 -1.5362e+00 -1.5785e+00 -1.4268e+00 6.4040e-01 -1.2970e+00 1.8305e+00]
 [-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00 6.4040e-01 -1.2970e+00 1.2280e+00]
 [-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00 6.4040e-01 -1.4659e+00 1.0771e+00]
 [-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00 6.0253e-01 -1.6405e+00 7.7573e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 6.4040e-01 -1.1224e+00 1.5150e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 7.1740e-01 -1.2970e+00 2.8063e-01]
 [ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00 7.9439e-01 -9.5346e-01 1.3008e-01]
 [ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 7.9439e-01 -9.5346e-01 5.3889e-01]
 [ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00 8.7139e-01 -1.1224e+00 9.5210e-04]
 [ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00 9.4839e-01 -7.7889e-01 4.3791e-02]
 [ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01 1.1794e+00 -9.7496e-02 -6.3613e-02]
 [ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01 1.2942e+00 7.7075e-02 4.3791e-02]
 [ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01 1.7171e+00 9.3304e-01 4.3791e-02]]
seq_y:
[[-3.3186e-01 -1.5362e+00 -1.7460e+00 -1.5260e+00 6.4040e-01 -1.2970e+00 1.2280e+00]
 [-1.5960e-01 -1.3896e+00 -1.3260e+00 -1.1237e+00 6.4040e-01 -1.4659e+00 1.0771e+00]
 [-7.3470e-02 -1.8295e+00 -1.3260e+00 -1.4268e+00 6.0253e-01 -1.6405e+00 7.7573e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 6.4040e-01 -1.1224e+00 1.5150e-01]
 [ 1.2660e-02 -1.5362e+00 -1.2435e+00 -1.3248e+00 7.1740e-01 -1.2970e+00 2.8063e-01]
 [ 1.8492e-01 -1.3896e+00 -1.2435e+00 -1.3248e+00 7.9439e-01 -9.5346e-01 1.3008e-01]
 [ 1.8492e-01 -1.5362e+00 -1.2435e+00 -1.4268e+00 7.9439e-01 -9.5346e-01 5.3889e-01]
 [ 1.8492e-01 -1.2429e+00 -1.2435e+00 -1.2257e+00 8.7139e-01 -1.1224e+00 9.5210e-04]
 [ 3.5718e-01 -1.0962e+00 -1.1585e+00 -1.1237e+00 9.4839e-01 -7.7889e-01 4.3791e-02]
 [ 7.8783e-01 -6.5627e-01 -8.2347e-01 -9.2254e-01 1.1794e+00 -9.7496e-02 -6.3613e-02]
 [ 1.0462e+00 -5.0961e-01 -5.7100e-01 -9.2254e-01 1.2942e+00 7.7075e-02 4.3791e-02]
 [ 1.6491e+00 -3.6295e-01 -6.8426e-02 -5.2023e-01 1.7171e+00 9.3304e-01 4.3791e-02]
 [ 2.3382e+00 6.6367e-01 2.6662e-01 -4.1824e-01 2.1791e+00 1.9636e+00 5.3889e-01]
 [ 2.7688e+00 1.2503e+00 5.1909e-01 -1.1793e-01 2.5628e+00 2.6506e+00 8.6202e-01]
 [ 3.0272e+00 1.5436e+00 1.1043e+00 2.8720e-01 2.4858e+00 2.3071e+00 1.0126e+00]
 [ 2.6827e+00 1.1037e+00 6.8662e-01 8.3219e-02 2.2169e+00 2.1325e+00 6.4660e-01]
 [ 2.6827e+00 1.3970e+00 6.8662e-01 2.8720e-01 2.2561e+00 4.0246e+00 6.4660e-01]
 [ 2.9411e+00 1.3970e+00 6.0168e-01 3.8636e-01 2.5628e+00 3.1630e+00 8.4030e-01]
 [ 2.9411e+00 1.6903e+00 6.8662e-01 4.8835e-01 2.4858e+00 2.3071e+00 8.6202e-01]
 [ 2.8549e+00 1.3970e+00 6.8662e-01 3.8636e-01 2.4479e+00 2.3071e+00 9.0486e-01]
 [ 2.7105e-01 8.1033e-01 1.0217e+00 6.8951e-01 -4.3503e-01 -4.3537e-01 1.9465e-01]
 [-1.5960e-01 5.1701e-01 8.5414e-01 6.8951e-01 -6.2815e-01 -6.0995e-01 3.2378e-01]
 [-2.4573e-01 3.7035e-01 6.8662e-01 6.8951e-01 -5.8902e-01 -4.3537e-01 4.0976e-01]
 [-3.3186e-01 -5.0961e-01 -4.8842e-01 -1.1793e-01 -7.0514e-01 -2.6644e-01 -1.4198e+00]
 [-1.0196e+00 -2.1629e-01 -2.3595e-01 -3.1908e-01 -7.8214e-01 -7.7889e-01 -1.0970e+00]
 [-1.0196e+00 -6.5627e-01 -5.7100e-01 -4.1824e-01 -7.4301e-01 -6.0995e-01 -1.0110e+00]
 [-8.4735e-01 -2.1629e-01 -3.2089e-01 -2.1709e-01 -6.2815e-01 -2.6644e-01 -7.7413e-01]
 [-8.4735e-01 -5.0961e-01 -2.3595e-01 -1.1793e-01 -6.2815e-01 -7.7889e-01 -5.3729e-01]
 [-5.0283e-01 -2.1629e-01 -6.8426e-02 -2.1709e-01 -4.3503e-01 -9.7496e-02 -3.2187e-01]
 [ 9.8790e-02 7.7033e-02 4.3415e-01 -1.1793e-01 -2.0530e-01 2.4601e-01 -7.7413e-01]
 [ 1.8492e-01 7.7033e-02 3.5157e-01 -2.1709e-01 -5.1305e-02 2.4601e-01 -7.5241e-01]
 [ 7.0170e-01 2.2369e-01 8.5414e-01 8.3219e-02 1.4055e-01 2.4601e-01 -4.9415e-01]
 [ 5.2944e-01 -2.1629e-01 4.3415e-01 -2.1709e-01 1.7968e-01 -9.7496e-02 -2.7903e-01]
 [ 3.5718e-01 2.2369e-01 3.5157e-01 -2.1709e-01 6.3558e-02 5.8952e-01 -3.8644e-01]
 [-1.3641e+00 -1.0962e+00 -2.0811e+00 -1.4268e+00 -1.6617e-01 7.6410e-01 -3.8644e-01]
 [-1.3641e+00 -5.0961e-01 -1.0736e+00 -1.4268e+00 -2.4317e-01 4.2059e-01 -2.1447e-01]]
seq_x_mark:
[[-0.5 0.1667 -0.5 -0.0014]
 [-0.5 0.1667 -0.5 -0.0014]
 [-0.5 0.1667 -0.5 -0.0014]
 [-0.5 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.4565 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.413 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]]
seq_y_mark:
[[-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3696 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.3261 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2826 0.1667 -0.5 -0.0014]
 [-0.2391 0.1667 -0.5 -0.0014]
 [-0.2391 0.1667 -0.5 -0.0014]
 [-0.2391 0.1667 -0.5 -0.0014]
 [-0.2391 0.1667 -0.5 -0.0014]
 [-0.1957 0.1667 -0.5 -0.0014]
 [-0.1957 0.1667 -0.5 -0.0014]
 [-0.1957 0.1667 -0.5 -0.0014]
 [-0.1957 0.1667 -0.5 -0.0014]
 [-0.1522 0.1667 -0.5 -0.0014]
 [-0.1522 0.1667 -0.5 -0.0014]
 [-0.1522 0.1667 -0.5 -0.0014]
 [-0.1522 0.1667 -0.5 -0.0014]
 [-0.1087 0.1667 -0.5 -0.0014]
 [-0.1087 0.1667 -0.5 -0.0014]
 [-0.1087 0.1667 -0.5 -0.0014]
 [-0.1087 0.1667 -0.5 -0.0014]
 [-0.0652 0.1667 -0.5 -0.0014]
 [-0.0652 0.1667 -0.5 -0.0014]
 [-0.0652 0.1667 -0.5 -0.0014]
 [-0.0652 0.1667 -0.5 -0.0014]
 [-0.0217 0.1667 -0.5 -0.0014]
 [-0.0217 0.1667 -0.5 -0.0014]
 [-0.0217 0.1667 -0.5 -0.0014]
 [-0.0217 0.1667 -0.5 -0.0014]]
Code
import os

import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

# time_features is the project's own time-encoding helper
# (import path assumed here: utils/timefeatures.py)
from utils.timefeatures import time_features


class Dataset_Custom(Dataset):
    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size [seq_len, label_len, pred_len]
        # info
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        # move the target column to the end
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # print(cols)
        # num_train = int(len(df_raw) * 0.7)
        # print("num_train:",num_train)
        # num_test = int(len(df_raw) * 0.2)
        num_train = int(len(df_raw) * 0.5)
        print("num_train:", num_train)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        # val and test start seq_len rows early so their first sample has a full input window
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # fit on the training range only, then transform the whole series
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
            # --------------------------------------------------------------------
            print("train_data.values", train_data.values)
            print("transform_data", data)
            inverse_transform_data = self.inverse_transform(data)
            print("inverse_transform_data:", inverse_transform_data)
            # --------------------------------------------------------------------
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
            # columns= keyword instead of the positional axis argument,
            # which recent pandas versions no longer accept
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            print("df_stamp['date'].values:", df_stamp['date'].values)
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp
        print("data_stamp:", data_stamp)
        print('\n'.join(['%s:%s' % item for item in self.__dict__.items()]))

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len
        print("s_begin:", s_begin)
        print("s_end:", s_end)
        print("r_begin:", r_begin)
        print("r_end:", r_end)

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        print("seq_x.shape:", seq_x.shape)
        print("seq_y.shape:", seq_y.shape)
        print("seq_x_mark.shape:", seq_x_mark.shape)
        print("seq_y_mark.shape:", seq_y_mark.shape)

        print("seq_x:", seq_x)
        print("seq_y:", seq_y)
        print("seq_x_mark:", seq_x_mark)
        print("seq_y_mark:", seq_y_mark)

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)
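The border arithmetic in `__read_data__` is easier to follow with concrete numbers. The sketch below repeats it for a hypothetical file of 1000 rows; `split_borders` is a made-up helper name, and the 0.5 / 0.2 ratios are the ones used in the code above.

```python
def split_borders(n_rows, seq_len, train_ratio=0.5, test_ratio=0.2):
    """Return (start, end) row ranges for the train, val and test splits."""
    num_train = int(n_rows * train_ratio)
    num_test = int(n_rows * test_ratio)
    num_vali = n_rows - num_train - num_test
    # val and test start seq_len rows early so their first sample
    # still has a full input window
    border1s = [0, num_train - seq_len, n_rows - num_test - seq_len]
    border2s = [num_train, num_train + num_vali, n_rows]
    return list(zip(border1s, border2s))

train, val, test = split_borders(n_rows=1000, seq_len=24)
print(train, val, test)
# (0, 500) (476, 800) (776, 1000)
```

The val and test ranges reach back into the previous split by seq_len rows; those rows only ever serve as encoder input for the first few samples.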
Compare the shapes
In the loop for i, (batch_x, batch_y, batch_x_mark, batch_y_mark), the batches have these shapes:
batch_x: torch.Size([1, 24, 7])
batch_y: torch.Size([1, 36, 7])
batch_x_mark: torch.Size([1, 24, 4])
batch_y_mark: torch.Size([1, 36, 4])
seq_x.shape: (24, 7)
seq_y.shape: (36, 7)
seq_x_mark.shape: (24, 4)
seq_y_mark.shape: (36, 4)
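The shapes can be reproduced without the real dataset. This is a sketch on zero-filled NumPy arrays; `n_rows` is a hypothetical split length, and the batch size of 1 is read off the torch.Size values printed above.

```python
import numpy as np

seq_len, label_len, pred_len = 24, 12, 24
n_rows = 100                          # hypothetical number of rows in the split

data = np.zeros((n_rows, 7))          # 7 feature columns, as with features='M'
data_stamp = np.zeros((n_rows, 4))    # 4 time-feature columns

index = 0
s_begin, s_end = index, index + seq_len
r_begin = s_end - label_len
r_end = r_begin + label_len + pred_len

seq_x = data[s_begin:s_end]
seq_y = data[r_begin:r_end]
seq_x_mark = data_stamp[s_begin:s_end]
seq_y_mark = data_stamp[r_begin:r_end]
print(seq_x.shape, seq_y.shape, seq_x_mark.shape, seq_y_mark.shape)

# Batching only prepends a batch dimension.
batch_x = seq_x[np.newaxis, ...]
print(batch_x.shape)

# Number of samples, as computed in __len__.
n_samples = n_rows - seq_len - pred_len + 1
print(n_samples)
```

seq_y is label_len + pred_len = 36 steps long, which is where the 36 in the batch_y and batch_y_mark shapes comes from.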
Using the data for training
First, batch_x, batch_y, batch_x_mark, and batch_y_mark are passed as arguments:
outputs, batch_y = self._predict(batch_x, batch_y, batch_x_mark, batch_y_mark)
def _predict(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
    # decoder input
    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

    # encoder - decoder
    def _run_model():
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            outputs = outputs[0]
        return outputs
    # (excerpt; the rest of the method is not shown)
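batch_y is turned into dec_inp: the first label_len steps are kept and the last pred_len steps are replaced with zeros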
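Here is a sketch of that step in NumPy rather than torch, with a made-up random batch: `np.zeros_like` and `np.concatenate` stand in for `torch.zeros_like` and `torch.cat`.

```python
import numpy as np

label_len, pred_len = 12, 24
rng = np.random.default_rng(0)
batch_y = rng.normal(size=(1, label_len + pred_len, 7))  # made-up batch

# Zero placeholder for the steps the model has to predict.
placeholder = np.zeros_like(batch_y[:, -pred_len:, :])

# Known prefix (first label_len steps) followed by the zeros.
dec_inp = np.concatenate([batch_y[:, :label_len, :], placeholder], axis=1)

print(dec_inp.shape)                      # (1, 36, 7)
print(dec_inp[0, label_len:].any())       # False: the future part is all zeros
```

The decoder therefore sees real values only for the label_len steps that overlap the encoder window; the true future values in batch_y never enter dec_inp.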
For the Vanilla Transformer model, the inputs map as follows:
x_enc = batch_x
x_mark_enc = batch_x_mark
x_dec = dec_inp
x_mark_dec = batch_y_mark