import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from collections import defaultdict
from torch.autograd import Variable
import torch.utils.hooks


class DistributedDataParallelCPU(Module):
    r"""Implements distributed data parallelism for CPU at the module level.

    This module supports the ``mpi`` and ``gloo`` backends.

    This container parallelizes the application of the given module across
    processes: the module is replicated in every process, and each replica
    handles the portion of the input that its own process passes to
    :meth:`forward` (this module does not split the input itself). During the
    backwards pass, gradients from each node are averaged.

    This module could be used in conjunction with the DistributedSampler
    (see :class:`torch.utils.data.distributed.DistributedSampler`),
    which will load a subset of the original dataset for each node with the
    same batch size. So strong scaling should be configured like this:

    n = 1, batch size = 128

    n = 2, batch size = 64

    n = 4, batch size = 32

    n = 8, batch size = 16

    Creation of this class requires the distributed package to be already
    initialized in the process group mode
    (see :func:`torch.distributed.init_process_group`).

    .. warning::
        Constructor, forward method, and differentiation of the output (or a
        function of the output of this module) is a distributed synchronization
        point. Take that into account in case different nodes might be
        executing different code.

    .. warning::
        This module assumes all parameters are registered in the model by the
        time it is created. No parameters should be added nor removed later.

    .. warning::
        This module assumes all gradients are dense.

    .. warning::
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
        only work if gradients are to be accumulated in ``.grad`` attributes of
        parameters).

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        won't be invoked anymore, unless the hooks are initialized in the
        :meth:`forward` method.

    .. note::
        Parameters are broadcast between nodes in the __init__() function. The
        module performs an all-reduce step on gradients and assumes that they
        will be modified by the optimizer in all nodes in the same way.

    Args:
        module: module to be parallelized

    Example::

        >>> torch.distributed.init_process_group(backend='gloo', world_size=4, init_method='...')
        >>> net = torch.nn.parallel.DistributedDataParallelCPU(model)
    """

    def __init__(self, module):
        super(DistributedDataParallelCPU, self).__init__()
        self.module = module
        # Set by forward() and cleared by the first reduction after backward.
        # Initialising it here means a backward pass through parameters that
        # did not go through forward() (e.g. ``self.module(x)`` called
        # directly) skips the reduction instead of raising AttributeError.
        self.needs_reduction = False
        self.sync_parameters()

        def allreduce_params():
            """Average ``.grad`` of every parameter across all processes.

            Every parameter hook queues this callback, so it may run several
            times per backward pass. Only the first call after a forward()
            does any work, because it clears ``needs_reduction``.
            """
            if not self.needs_reduction:
                return
            self.needs_reduction = False

            # Group gradients per tensor type so each flattened buffer is
            # homogeneous. type(param.data) is torch.Tensor for every dtype
            # in PyTorch >= 0.4, so key on the type string instead
            # (e.g. 'torch.FloatTensor' vs 'torch.DoubleTensor').
            buckets = defaultdict(list)
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    buckets[param.data.type()].append(param)

            # NOTE(review): this issues one all_reduce per bucket. That
            # assumes every process ends up with the same set of parameters
            # having non-None gradients; otherwise the collectives would not
            # match up across processes.
            for bucket in buckets.values():
                grads = [param.grad.data for param in bucket]
                # One collective per bucket instead of one per parameter.
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

        @torch.utils.hooks.unserializable_hook
        def allreduce_hook(*unused):
            # Defer the reduction to an autograd-engine callback that runs
            # once the current backward pass finishes, rather than reducing
            # while other gradients are still being computed.
            Variable._execution_engine.queue_callback(allreduce_params)

        # A single hook object is shared by all trainable parameters.
        for param in self.module.parameters():
            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def sync_parameters(self):
        """Broadcast every parameter of the wrapped module from rank 0.

        After this call all processes hold rank 0's parameter values.
        """
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module and arm the post-backward gradient reduction."""
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)
