In [2]:
import numpy as np
import pandas as pd
import os
import json
import math
from utils import load_data
from utils import load_global_inputs
from utils import basic_hyperparams
from GeoMAN import GeoMAN
from utils import shuffle_data
from utils import get_batch_feed_dict
from utils import get_valid_batch_feed_dict
import torch.optim as optim
from torch import nn

In [3]:
# load hyperparameters
hps = basic_hyperparams()
print(hps)

{'learning_rate': 0.001, 'lambda_l2_reg': 0.001, 'gc_rate': 2.5, 'dropout_rate': 0.3, 'n_stacked_layers': 2, 's_attn_flag': 2, 'ext_flag': True, 'n_sensors': 35, 'n_input_encoder': 19, 'n_steps_encoder': 12, 'n_hidden_encoder': 64, 'n_input_decoder': 1, 'n_external_input': 83, 'n_steps_decoder': 6, 'n_hidden_decoder': 64, 'n_output_decoder': 1}


In [4]:
# read data from different sets
input_path = './sample_data/'
training_data = load_data(
    input_path, 'train', hps['n_steps_encoder'], hps['n_steps_decoder'])
valid_data = load_data(
    input_path, 'eval', hps['n_steps_encoder'], hps['n_steps_decoder'])
global_inpts, global_attn_sts = load_global_inputs(
    input_path, hps['n_steps_encoder'], hps['n_steps_decoder'])
# print dataset info
num_train = len(training_data[0])
num_valid = len(valid_data[0])
print('train samples: {0}'.format(num_train))
print('eval samples: {0}'.format(num_valid))
#[mode_local_inp, global_inp_index, global_attn_index, mode_ext_inp, mode_labels]
print(training_data[0].shape)
print(training_data[1].shape)
print(training_data[2].shape)
print(training_data[3].shape)
print(training_data[4].shape)
#global_inputs, global_attn_states
print(global_inpts.shape)
print(global_attn_sts.shape)

train samples: 100
eval samples: 10
(100, 12, 19)
(100,)
(100,)
(100, 6, 83)
(100, 6)
(500, 35)
(500, 35, 19, 12)


In [5]:
np.random.seed(2017)
model = GeoMAN(hps)

total_epoch = 50
batch_size = 16
lr = 0.0001

optimizer = optim.RMSprop(model.parameters(), lr=lr, momentum=0.9)
def criterion(preds, labels):
    loss_fn = nn.MSELoss()
    loss = 0.0
    
    for ps, ls in zip(preds, labels):
        loss += loss_fn(ps.float(),ls.float())
        
    return loss

for i in range(total_epoch):
    print('----------epoch {}-----------'.format(i))
    training_data = shuffle_data(training_data)
    lossSum = 0
    i += 1
    for j in range(0, num_train, batch_size):
        x = get_batch_feed_dict(j, batch_size, training_data, global_inpts, global_attn_sts)
        preds, labels = model(x)
        loss = criterion(preds, labels)
        lossSum += loss.data.numpy()
        loss.backward(retain_graph=True)
        optimizer.step()
        
    print(lossSum)
#print(preds, labels)

# test
n_split_test = 2
test_loss = 0
test_indexes = np.int64(
    np.linspace(0, num_valid, n_split_test))
rmses=[]
maes=[]
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler(feature_range=(0, 1))
for k in range(n_split_test - 1):
    x = get_valid_batch_feed_dict(k, test_indexes, valid_data, global_inpts, global_attn_sts)
    # re-scale predicted labels
    batch_preds, _ = model(x)
    batch_preds = np.array([bp.data.numpy() for bp in  batch_preds ])
    batch_preds = np.swapaxes(batch_preds, 0, 1)
    batch_preds = np.reshape(batch_preds, [batch_preds.shape[0], -1])
    # re-scale real labels
    batch_labels = valid_data[4]
    batch_labels = batch_labels[test_indexes[k]:test_indexes[k + 1]]
    rmses.append(np.sqrt(np.sum(np.square(batch_labels-batch_preds))/
                         (batch_labels.shape[0]*batch_labels.shape[1])))
    maes.append((np.abs(batch_labels-batch_preds)).mean())

test_rmses = np.asarray(rmses)
test_maes = np.asarray(maes)

print('===============METRIC===============')
print('rmse = {:.6f}'.format(test_rmses.mean()))
print('mae = {:.6f}'.format(test_maes.mean()))

----------epoch 0-----------
34.392651081085205
----------epoch 1-----------
32.09111285209656
----------epoch 2-----------
30.321520805358887
----------epoch 3-----------
28.657651901245117
----------epoch 4-----------
25.779441714286804
----------epoch 5-----------
27.266942381858826
----------epoch 6-----------
25.80989384651184
----------epoch 7-----------
22.023674607276917
----------epoch 8-----------
26.35184931755066
----------epoch 9-----------
22.535444736480713
----------epoch 10-----------
19.636800527572632
----------epoch 11-----------
22.199295163154602
----------epoch 12-----------
21.164698719978333
----------epoch 13-----------
19.719813585281372
----------epoch 14-----------
23.407500863075256
----------epoch 15-----------
20.608052134513855
----------epoch 16-----------
20.279370188713074
----------epoch 17-----------
19.449496746063232
----------epoch 18-----------
19.843530893325806
----------epoch 19-----------
20.214853882789612
----------epoch 20-----------
17.