import sys
sys.path.append('/global/project/projectdirs/dasrepo/etalumis/pyprob_distributed/')
sys.path.append('/global/project/projectdirs/dasrepo/etalumis/pytorch_intel_3.7_larry_install/')
import pyprob
import pandas as pd

# Export the loss curves recorded inside a trained pyprob inference network
# to two CSV files (one for validation loss, one for training loss), each
# with columns: traces (cumulative trace count), loss.

# Path to the serialized ".network" file produced by pyprob training.
network_file = "/global/cscratch1/sd/wbhimji/etalumis_data_dec7_2018/networks_voltangpu/sherpa_tau_decay_20190307_210234_traces_40000128.network"
# Output CSV destinations for the validation- and training-loss histories.
output_file_valid = "/global/cscratch1/sd/wbhimji/etalumis_data_dec7_2018/valid_loss_mar8gpu.csv"
output_file_train = "/global/cscratch1/sd/wbhimji/etalumis_data_dec7_2018/train_loss_mar8gpu.csv"


def _write_loss_csv(n_traces, loss, path):
    """Write parallel trace-count / loss sequences to *path* as a CSV.

    Produces two columns ("traces", "loss") with no index column, matching
    the layout both export sections previously built by hand.
    """
    pd.DataFrame(data={"traces": n_traces, "loss": loss}).to_csv(
        path, sep=",", index=False
    )


# Load the trained network. NOTE(review): `_load` and the `_history_*`
# attributes are private pyprob API — no public accessor appears to exist;
# confirm against the pinned pyprob version if it is upgraded.
inference_network = pyprob.nn.InferenceNetworkFeedForward._load(network_file)

# Validation loss first, then training loss — same order as the original script.
_write_loss_csv(
    inference_network._history_valid_loss_trace,
    inference_network._history_valid_loss,
    output_file_valid,
)
_write_loss_csv(
    inference_network._history_train_loss_trace,
    inference_network._history_train_loss,
    output_file_train,
)