import torch
import torch.nn as nn
import torch.optim as optim
import horovod.torch as hvd  # Horovod replaces the earlier torch.distributed backend
from torch.utils.data import DataLoader
import sys
import time
import os
import shutil
import uuid
import tempfile
import tarfile
import copy
from threading import Thread
from termcolor import colored

from . import Batch, OfflineDataset, SortedTraceSampler, SortedTraceBatchSamplerDistributed, SortedTraceBatchSampler, EmbeddingFeedForward, EmbeddingCNN2D5C, EmbeddingCNN3D5C
from .. import __version__, util, Optimizer, ObserveEmbedding


class InferenceNetwork(nn.Module):
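    """Abstract base class for amortized inference networks trained on model execution traces.

    Subclasses implement _init_layers, _polymorph, _infer_step and _loss. A minimal
    usage sketch (MyInferenceNetwork and dataset are placeholders, not names defined
    in this module):

        net = MyInferenceNetwork(model, observe_embeddings={
            'obs1': {'embedding': ObserveEmbedding.FEEDFORWARD, 'dim': 32, 'depth': 2}})
        net.optimize(num_traces=100000, dataset=dataset, dataset_valid=None, batch_size=64)
    """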
    # observe_embeddings example: {'obs1': {'embedding':ObserveEmbedding.FEEDFORWARD, 'reshape': [10, 10], 'dim': 32, 'depth': 2}}
    def __init__(self, model, observe_embeddings=None, network_type=''):
        super().__init__()
        self._model = model
        self._layers_observe_embedding = nn.ModuleDict()
        self._layers_observe_embedding_final = None
        self._layers_pre_generated = False
        self._layers_initialized = False
        self._observe_embeddings = observe_embeddings or {}  # avoid a shared mutable default argument
        self._observe_embedding_dim = None
        self._infer_observe = None
        self._infer_observe_embedding = {}
        self._optimizer = None

        self._total_train_seconds = 0
        self._total_train_traces = 0
        self._total_train_iterations = 0
        self._loss_initial = None
        self._loss_min = float('inf')
        self._loss_max = None
        self._loss_previous = float('inf')
        self._history_train_loss = []
        self._history_train_loss_trace = []
        self._history_valid_loss = []
        self._history_valid_loss_trace = []
        self._history_num_params = []
        self._history_num_params_trace = []
        self._distributed_train_loss = util.to_tensor(0.)
        self._distributed_valid_loss = util.to_tensor(0.)
        self._distributed_train_loss_min = float('inf')
        self._distributed_valid_loss_min = float('inf')
        self._distributed_history_train_loss = []
        self._distributed_history_train_loss_trace = []
        self._distributed_history_valid_loss = []
        self._distributed_history_valid_loss_trace = []
        self._distributed_filtered_train_loss = util.to_tensor(0.)
        self._distributed_filtered_valid_loss = util.to_tensor(0.)
        self._distributed_history_filtered_train_loss = []
        self._distributed_history_filtered_valid_loss = []
        self._distributed_filtered_train_loss_min = float('inf')
        self._distributed_filtered_valid_loss_min = float('inf')

        self._modified = util.get_time_str()
        self._updates = 0
        self._on_cuda = False
        self._device = torch.device('cpu')
        self._network_type = network_type
        self._optimizer_type = None
        self._learning_rate = None
        self._momentum = None
        self._batch_size = None
        self._distributed_backend = None
        self._distributed_world_size = None


    def _init_layers_observe_embedding(self, observe_embeddings, example_trace):
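        """Create an embedding layer for each named observation (shapes inferred from `example_trace`) plus a final feedforward layer over the concatenated embeddings."""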
        if len(observe_embeddings) == 0:
            raise ValueError('At least one observe embedding is needed to initialize inference network.')
        observe_embedding_total_dim = 0
        for name, value in observe_embeddings.items():
            variable = example_trace.named_variables[name]
            if 'reshape' in value:
                input_shape = torch.Size(value['reshape'])
            else:
                input_shape = variable.value.size()
            if 'dim' in value:
                output_shape = torch.Size([value['dim']])
            else:
                print('Observable {}: embedding dim not specified, using the default 256.'.format(name))
                output_shape = torch.Size([256])
            if 'embedding' in value:
                embedding = value['embedding']
            else:
                print('Observable {}: observe embedding not specified, using the default FEEDFORWARD.'.format(name))
                embedding = ObserveEmbedding.FEEDFORWARD
            if embedding == ObserveEmbedding.FEEDFORWARD:
                if 'depth' in value:
                    depth = value['depth']
                else:
                    print('Observable {}: embedding depth not specified, using the default 2.'.format(name))
                    depth = 2
                layer = EmbeddingFeedForward(input_shape=input_shape, output_shape=output_shape, num_layers=depth)
            elif embedding == ObserveEmbedding.CNN2D5C:
                layer = EmbeddingCNN2D5C(input_shape=input_shape, output_shape=output_shape)
            elif embedding == ObserveEmbedding.CNN3D5C:
                layer = EmbeddingCNN3D5C(input_shape=input_shape, output_shape=output_shape)
            else:
                raise ValueError('Unknown embedding: {}'.format(embedding))
            layer.to(device=util._device)
            self._layers_observe_embedding[name] = layer
            observe_embedding_total_dim += util.prod(output_shape)
        self._observe_embedding_dim = observe_embedding_total_dim
        print('Observe embedding dimension: {}'.format(self._observe_embedding_dim))
        self._layers_observe_embedding_final = EmbeddingFeedForward(input_shape=self._observe_embedding_dim, output_shape=self._observe_embedding_dim, num_layers=2)
        self._layers_observe_embedding_final.to(device=util._device)

    def _embed_observe(self, traces=None):
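        """Embed the named observations of a batch of traces: stack each observed value across traces, apply the per-observable embedding layers, then concatenate and pass through the final combining layer."""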
        embedding = []
        for name, layer in self._layers_observe_embedding.items():
            values = torch.stack([util.to_tensor(trace.named_variables[name].value) for trace in traces]).view(len(traces), -1)
            embedding.append(layer(values))
        embedding = torch.cat(embedding, dim=1)
        embedding = self._layers_observe_embedding_final(embedding)
        return embedding

    def _infer_init(self, observe=None):
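        """Embed the observation dict supplied at inference time and cache the result for subsequent `_infer_step` calls."""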
        self._infer_observe = observe
        embedding = []
        for name, layer in self._layers_observe_embedding.items():
            value = util.to_tensor(observe[name]).view(1, -1)
            embedding.append(layer(value))
        embedding = torch.cat(embedding, dim=1)
        self._infer_observe_embedding = self._layers_observe_embedding_final(embedding)

    def _init_layers(self):
        raise NotImplementedError()

    def _polymorph(self, batch):
        raise NotImplementedError()

    def _infer_step(self, variable, previous_variable=None, proposal_min_train_iterations=None):
        raise NotImplementedError()

    def _loss(self, batch):
        raise NotImplementedError()

    def _save(self, file_name):
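        """Serialize the network (without its model and optimizer) into a gzipped tar archive at `file_name`."""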
        self._modified = util.get_time_str()
        self._updates += 1

        data = {}
        data['pyprob_version'] = __version__
        data['torch_version'] = torch.__version__
        # The following is a temporary workaround for https://github.com/pytorch/pytorch/issues/9981 and can be removed by using dill as the pickler with torch > 0.4.1
        data['inference_network'] = copy.copy(self)
        data['inference_network']._model = None
        data['inference_network']._optimizer = None

        def thread_save():
            tmp_dir = tempfile.mkdtemp(suffix=str(uuid.uuid4()))
            tmp_file_name = os.path.join(tmp_dir, 'pyprob_inference_network')
            torch.save(data, tmp_file_name)
            tar = tarfile.open(file_name, 'w:gz', compresslevel=2)
            tar.add(tmp_file_name, arcname='pyprob_inference_network')
            tar.close()
            shutil.rmtree(tmp_dir)
        t = Thread(target=thread_save)
        t.start()
        t.join()  # joined immediately, so the save completes before returning

    @staticmethod
    def _load(file_name):
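        """Load a network saved by `_save`, warn on version or device mismatches, and move it to the current device."""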
        try:
            tar = tarfile.open(file_name, 'r:gz')
            tmp_dir = tempfile.mkdtemp(suffix=str(uuid.uuid4()))
            tmp_file = os.path.join(tmp_dir, 'pyprob_inference_network')
            tar.extract('pyprob_inference_network', tmp_dir)
            tar.close()
            if util._cuda_enabled:
                data = torch.load(tmp_file)
            else:
                data = torch.load(tmp_file, map_location=lambda storage, loc: storage)
            shutil.rmtree(tmp_dir)  # clean up the temporary directory in both branches
        except Exception:
            raise RuntimeError('Cannot load inference network.')

        if data['pyprob_version'] != __version__:
            print(colored('Warning: different pyprob versions (loaded network: {}, current system: {})'.format(data['pyprob_version'], __version__), 'red', attrs=['bold']))
        if data['torch_version'] != torch.__version__:
            print(colored('Warning: different PyTorch versions (loaded network: {}, current system: {})'.format(data['torch_version'], torch.__version__), 'red', attrs=['bold']))

        ret = data['inference_network']
        if util._cuda_enabled:
            if ret._on_cuda:
                if ret._device != util._device:
                    print(colored('Warning: loading CUDA (device {}) network to CUDA (device {})'.format(ret._device, util._device), 'red', attrs=['bold']))
            else:
                print(colored('Warning: loading CPU network to CUDA (device {})'.format(util._device), 'red', attrs=['bold']))
        else:
            if ret._on_cuda:
                print(colored('Warning: loading CUDA (device {}) network to CPU'.format(ret._device), 'red', attrs=['bold']))
        ret.to(device=util._device)
        return ret

    def to(self, device=None, *args, **kwargs):
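        """Record the target device and CUDA state before delegating to `nn.Module.to`."""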
        self._device = device
        self._on_cuda = 'cuda' in str(device)
        super().to(device=device, *args, **kwargs)

    def _pre_generate_layers(self, dataset, batch_size=64, save_file_name_prefix=None):
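        """Iterate once over `dataset` to polymorph all layers ahead of training, optionally checkpointing whenever new layers appear."""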
        if not self._layers_initialized:
            self._init_layers_observe_embedding(self._observe_embeddings, example_trace=dataset.__getitem__(0))
            self._init_layers()
            self._layers_initialized = True

        self._layers_pre_generated = True
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=lambda x: Batch(x))
        util.progress_bar_init('Layer pre-generation...', len(dataset), 'Traces')
        i = 0
        for i_batch, batch in enumerate(dataloader):
            i += len(batch)
            layers_changed = self._polymorph(batch)
            util.progress_bar_update(i)
            if layers_changed and (save_file_name_prefix is not None):
                file_name = '{}_00000000_pre_generated.network'.format(save_file_name_prefix)
                print('\rSaving to disk...  ', end='\r')
                self._save(file_name)
        util.progress_bar_end('Layer pre-generation complete')

    def _distributed_sync_parameters(self):
        """Broadcast the parameters of rank 0 to all other ranks."""
        hvd.broadcast_parameters(self.state_dict(), root_rank=0)

    # The manual _distributed_zero_grad and _distributed_sync_grad helpers from the
    # torch.distributed implementation are no longer needed: hvd.DistributedOptimizer
    # averages gradients across ranks during optimizer.step().

    def _distributed_update_train_loss(self, loss, world_size, loss_moving_average_window_size):
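        """All-reduce (average) the local training loss across ranks and update the raw and moving-average loss histories and minima."""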
        self._distributed_train_loss = util.to_tensor(float(loss))
        # hvd.allreduce averages across ranks by default
        self._distributed_train_loss = hvd.allreduce(self._distributed_train_loss, name='avg_train_loss').item()
        self._distributed_history_train_loss.append(float(self._distributed_train_loss))
        self._distributed_history_train_loss_trace.append(self._total_train_traces)
        # moving average over the last loss_moving_average_window_size entries, including the current one
        recent_losses = self._distributed_history_train_loss[-loss_moving_average_window_size:]
        self._distributed_filtered_train_loss = sum(recent_losses) / len(recent_losses)
        self._distributed_history_filtered_train_loss.append(self._distributed_filtered_train_loss)
        if float(self._distributed_filtered_train_loss) < self._distributed_filtered_train_loss_min:
            self._distributed_filtered_train_loss_min = float(self._distributed_filtered_train_loss)
        if float(self._distributed_train_loss) < self._distributed_train_loss_min:
            self._distributed_train_loss_min = float(self._distributed_train_loss)

    def _distributed_update_valid_loss(self, loss, world_size, loss_moving_average_window_size, enable_ma_filter):
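        """All-reduce (average) the local validation loss across ranks, update histories and minima, and print the result (optionally moving-average filtered) on rank 0."""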
        self._distributed_valid_loss = util.to_tensor(float(loss))
        # hvd.allreduce averages across ranks by default
        self._distributed_valid_loss = hvd.allreduce(self._distributed_valid_loss, name='avg_valid_loss').item()
        if float(self._distributed_valid_loss) < self._distributed_valid_loss_min:
            self._distributed_valid_loss_min = float(self._distributed_valid_loss)
        self._distributed_history_valid_loss.append(float(self._distributed_valid_loss))
        self._distributed_history_valid_loss_trace.append(self._total_train_traces)
        recent_losses = self._distributed_history_valid_loss[-loss_moving_average_window_size:]
        self._distributed_filtered_valid_loss = sum(recent_losses) / len(recent_losses)
        self._distributed_history_filtered_valid_loss.append(self._distributed_filtered_valid_loss)
        if float(self._distributed_filtered_valid_loss) < self._distributed_filtered_valid_loss_min:
            self._distributed_filtered_valid_loss_min = float(self._distributed_filtered_valid_loss)
        if hvd.rank() == 0:
            if enable_ma_filter:
                print(colored('Filtered distributed mean valid. loss across ranks: {:+.2e}, min. valid. loss: {:+.2e}'.format(self._distributed_filtered_valid_loss, self._distributed_filtered_valid_loss_min), 'yellow', attrs=['bold']))
            else:
                print(colored('Distributed mean valid. loss across ranks: {:+.2e}, min. valid. loss: {:+.2e}'.format(self._distributed_valid_loss, self._distributed_valid_loss_min), 'yellow', attrs=['bold']))

    def optimize(self, num_traces, dataset, dataset_valid, batch_size=64, valid_every=None, optimizer_type=Optimizer.ADAM, learning_rate=0.0001, momentum=0.9, weight_decay=1e-5, save_file_name_prefix=None, save_every_sec=600, distributed_backend=None, distributed_params_sync_every=10000, distributed_loss_update_every=None, dataloader_offline_num_workers=0, stop_with_bad_loss=False, *args, **kwargs):
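        """Train the network on `dataset` for `num_traces` traces, with optional Horovod data-parallel training, periodic validation on `dataset_valid` and periodic checkpointing."""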
        if not self._layers_initialized:
            self._init_layers_observe_embedding(self._observe_embeddings, example_trace=dataset.__getitem__(0))
            self._init_layers()
            self._layers_initialized = True

        # hvd.init() is called in model.py before optimize() is reached
        if distributed_backend is None:
            distributed_world_size = 1
            distributed_rank = 0
        else:
            distributed_world_size = hvd.size()
            distributed_rank = hvd.rank()
            util.init_distributed_print(distributed_rank, distributed_world_size, True)
            if distributed_rank == 0:
                print(colored('Distributed synchronous training', 'yellow', attrs=['bold']))
                print(colored('Distributed backend       : {}'.format(distributed_backend), 'yellow', attrs=['bold']))
                print(colored('Distributed world size    : {}'.format(distributed_world_size), 'yellow', attrs=['bold']))
                print(colored('Distributed minibatch size: {} (global), {} (per node)'.format(batch_size * distributed_world_size, batch_size), 'yellow', attrs=['bold']))
                print(colored('Distributed learning rate : {} (global), {} (base)'.format(learning_rate * distributed_world_size, learning_rate), 'yellow', attrs=['bold']))
                print(colored('Distributed optimizer     : {}'.format(str(optimizer_type)), 'yellow', attrs=['bold']))
            self._distributed_backend = distributed_backend
            self._distributed_world_size = distributed_world_size
        self._distributed_history_train_loss = []
        self._distributed_history_train_loss_trace = []
        self._distributed_history_valid_loss = []
        self._distributed_history_valid_loss_trace = []
        self._distributed_train_loss = util.to_tensor(0.)
        self._distributed_valid_loss = util.to_tensor(0.)
        self._distributed_train_loss_min = float('inf')
        self._distributed_valid_loss_min = float('inf')
        self._distributed_filtered_train_loss = util.to_tensor(0.)
        self._distributed_filtered_valid_loss = util.to_tensor(0.)
        self._distributed_history_filtered_train_loss = []
        self._distributed_history_filtered_valid_loss = []
        self._distributed_filtered_train_loss_min = float('inf')
        self._distributed_filtered_valid_loss_min = float('inf')

        self._optimizer_type = optimizer_type
        self._batch_size = batch_size
        self._learning_rate = learning_rate * distributed_world_size
        self._momentum = momentum
        self.train()
        prev_total_train_seconds = self._total_train_seconds
        time_start = time.time()
        time_loss_min = time.time()
        time_last_batch = time.time()
        if valid_every is None:
            valid_every = max(100, num_traces / 1000)
        if distributed_loss_update_every is None:
            distributed_loss_update_every = valid_every
        last_validation_trace = -valid_every + 1

        # Settings used for distributed training only
        loss_moving_average_window_size = 10
        enable_ma_filter = True
        per_rank_print = False

        epoch = 0
        iteration = 0
        trace = 0
        stop = False
        print('Train. time | Epoch| Trace     | Init. loss| Min. loss | Curr. loss| T.since min | Traces/sec')
        max_print_line_len = 0
        loss_min_str = ''
        time_since_loss_min_str = ''
        last_auto_save_time = time.time() - save_every_sec
        if isinstance(dataset, OfflineDataset) and (distributed_world_size == 1):
            dataloader_epoch_one = DataLoader(dataset, batch_sampler=SortedTraceBatchSampler(dataset, batch_size=batch_size, shuffle=False), num_workers=dataloader_offline_num_workers, collate_fn=lambda x: Batch(x))
            dataloader_epoch_all = DataLoader(dataset, batch_sampler=SortedTraceBatchSampler(dataset, batch_size=batch_size, shuffle=True), num_workers=dataloader_offline_num_workers, collate_fn=lambda x: Batch(x))
        else:
            dataloader_epoch_one = DataLoader(dataset, batch_sampler=SortedTraceBatchSamplerDistributed(dataset, batch_size=batch_size, num_replicas=distributed_world_size, rank=distributed_rank, shuffle=False), num_workers=dataloader_offline_num_workers, collate_fn=lambda x: Batch(x))
            dataloader_epoch_all = dataloader_epoch_one
        if dataset_valid is not None:
            dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, num_workers=0, collate_fn=lambda x: Batch(x))


        # For Horovod offline training: synchronize parameters once before training starts
        self._distributed_sync_parameters()
        if self._layers_pre_generated:
            layers_changed = False
        else:
            # generate layers from an initial batch before constructing the optimizer
            layers_changed = self._polymorph(next(iter(dataloader_epoch_one)))

        if (self._optimizer is None) or layers_changed:
            if optimizer_type == Optimizer.ADAM:
                self._optimizer = optim.Adam(self.parameters(), lr=learning_rate * distributed_world_size, weight_decay=weight_decay)
            else:  # optimizer_type == Optimizer.SGD
                self._optimizer = optim.SGD(self.parameters(), lr=learning_rate * distributed_world_size, momentum=momentum, nesterov=True, weight_decay=weight_decay)
        self._optimizer = hvd.DistributedOptimizer(self._optimizer, named_parameters=self.named_parameters(), compression=hvd.Compression.none)

        while not stop:
            epoch += 1
            dataloader = dataloader_epoch_one if epoch == 1 else dataloader_epoch_all
            for i_batch, batch in enumerate(dataloader):
                # Parameter sync, layer generation and optimizer construction happen once
                # before the training loop (see above); Horovod's DistributedOptimizer
                # handles gradient averaging across ranks.
                self._optimizer.zero_grad()
                success, loss = self._loss(batch)
                if not success:
                    print(colored('Cannot compute loss, skipping batch. Loss: {}'.format(loss), 'red', attrs=['bold']))
                    if stop_with_bad_loss:
                        return
                else:
                    loss.backward()
                    # gradient averaging across ranks happens inside hvd.DistributedOptimizer.step()
                    self._optimizer.step()
                    loss = float(loss)

                    if self._loss_initial is None:
                        self._loss_initial = loss
                        self._loss_max = loss
                    loss_initial_str = '{:+.2e}'.format(self._loss_initial)
                    # loss_max_str = '{:+.3e}'.format(self._loss_max)
                    if loss < self._loss_min:
                        self._loss_min = loss
                        loss_str = colored('{:+.2e}'.format(loss), 'green', attrs=['bold'])
                        loss_min_str = colored('{:+.2e}'.format(self._loss_min), 'green', attrs=['bold'])
                        time_loss_min = time.time()
                        time_since_loss_min_str = colored(util.days_hours_mins_secs_str(0), 'green', attrs=['bold'])
                    elif loss > self._loss_max:
                        self._loss_max = loss
                        loss_str = colored('{:+.2e}'.format(loss), 'red', attrs=['bold'])
                        # loss_max_str = colored('{:+.3e}'.format(self._loss_max), 'red', attrs=['bold'])
                    else:
                        if loss < self._loss_previous:
                            loss_str = colored('{:+.2e}'.format(loss), 'green')
                        elif loss > self._loss_previous:
                            loss_str = colored('{:+.2e}'.format(loss), 'red')
                        else:
                            loss_str = '{:+.2e}'.format(loss)
                        loss_min_str = '{:+.2e}'.format(self._loss_min)
                        # loss_max_str = '{:+.3e}'.format(self._loss_max)
                        time_since_loss_min_str = util.days_hours_mins_secs_str(time.time() - time_loss_min)

                    self._loss_previous = loss
                    self._total_train_iterations += 1
                    trace += batch.size
                    self._total_train_traces += batch.size * distributed_world_size
                    total_train_traces_str = '{:9}'.format('{:,}'.format(self._total_train_traces))
                    epoch_str = '{:4}'.format('{:,}'.format(epoch))
                    self._total_train_seconds = prev_total_train_seconds + (time.time() - time_start)
                    total_training_seconds_str = util.days_hours_mins_secs_str(self._total_train_seconds)
                    traces_per_second_str = '{:,.1f}'.format(batch.size * distributed_world_size / (time.time() - time_last_batch))
                    time_last_batch = time.time()
                    if num_traces is not None:
                        if trace >= num_traces:
                            stop = True

                    self._history_train_loss.append(loss)
                    self._history_train_loss_trace.append(self._total_train_traces)
                    if dataset_valid is not None:
                        if trace - last_validation_trace > valid_every:
                            print('\rComputing validation loss...  ', end='\r')
                            valid_loss = 0
                            with torch.no_grad():
                                for i_batch_valid, batch_valid in enumerate(dataloader_valid):
                                    _, v = self._loss(batch_valid)
                                    valid_loss += v
                            valid_loss = float(valid_loss / len(dataset_valid))
                            self._history_valid_loss.append(valid_loss)
                            self._history_valid_loss_trace.append(self._total_train_traces)
                            last_validation_trace = trace - 1

                            if distributed_world_size > 1:
                                # the distributed train-loss update happens unconditionally below;
                                # here only the validation loss needs to be averaged across ranks
                                self._distributed_update_valid_loss(valid_loss, distributed_world_size, loss_moving_average_window_size, enable_ma_filter)

                    if distributed_world_size > 1:  # and (iteration % distributed_loss_update_every == 0):
                        self._distributed_update_train_loss(loss, distributed_world_size, loss_moving_average_window_size)
                        if not per_rank_print:
                            if enable_ma_filter:
                                loss_str = colored('{:+.2e}'.format(self._distributed_filtered_train_loss), 'yellow')
                                loss_min_str = colored('{:+.2e}'.format(self._distributed_filtered_train_loss_min), 'yellow')
                            else:
                                loss_str = colored('{:+.2e}'.format(self._distributed_train_loss), 'yellow')
                                loss_min_str = colored('{:+.2e}'.format(self._distributed_train_loss_min), 'yellow')

                    if (distributed_rank == 0) and (save_file_name_prefix is not None):
                        if time.time() - last_auto_save_time > save_every_sec:
                            last_auto_save_time = time.time()
                            file_name = '{}_{}_traces_{}.network'.format(save_file_name_prefix, util.get_time_stamp(), self._total_train_traces)
                            print('\rSaving to disk...  ', end='\r')
                            self._save(file_name)

                    print_line = '{} | {} | {} | {} | {} | {} | {} | {}'.format(total_training_seconds_str, epoch_str, total_train_traces_str, loss_initial_str, loss_min_str, loss_str, time_since_loss_min_str, traces_per_second_str)
                    max_print_line_len = max(len(print_line), max_print_line_len)
                    if per_rank_print or (distributed_rank == 0):
                        print(print_line.ljust(max_print_line_len), end='\r')
                    sys.stdout.flush()
                    if stop:
                        break
                iteration += 1

        print()
        if (distributed_rank == 0) and (save_file_name_prefix is not None):
            file_name = '{}_{}_traces_{}.network'.format(save_file_name_prefix, util.get_time_stamp(), self._total_train_traces)
            print('\rSaving to disk...  ', end='\r')
            self._save(file_name)
