from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from multiprocessing import Process, Manager

import numpy as np
import unittest
import tempfile
import shutil
import logging

from hypothesis import given
import hypothesis.strategies as st

# Module-level logger; bmuf_process uses it to report runs that are skipped
# because the required GPU support is not available.
log = logging.getLogger("parallelize_bmuf_distributed_test")
log.setLevel(logging.INFO)


def bmuf_process(filestore_dir, process_id, shared_results,
                 cpu_device=False, nesterov=False):
    """Run one shard of a two-shard BMUF job and publish the fetched blobs.

    The shard builds a small FC model, parallelizes it with
    ``data_parallel_model.Parallelize_BMUF`` over two devices, runs it, and
    stores the blobs fetched along the way in ``shared_results[process_id]``.

    Args:
        filestore_dir: Directory passed as ``path`` to the
            ``FileStoreHandlerCreate`` operator used for rendezvous.
        process_id: Shard id. Shard 0 uses devices ``[0, 1]``; any other id
            uses devices ``[2, 3]``.
        shared_results: Mapping that receives this shard's result dict under
            the key ``process_id``.
        cpu_device: Forwarded to ``Parallelize_BMUF``; when false the GPU
            device type and the ``gpu`` blob prefix are used.
        nesterov: Forwarded to ``Parallelize_BMUF``.

    Returns:
        None. When GPU mode is requested but neither GPU nor HIP support is
        reported, or fewer than 4 GPU devices are reported, the function logs
        a message and returns without writing to ``shared_results``.
    """
    # We need to import caffe2 in every process to initialize CUDA independently.
    from caffe2.python import core, cnn, data_parallel_model, dyndep, workspace
    from caffe2.proto import caffe2_pb2
    dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    use_gpu = not cpu_device
    if use_gpu:
        if not (workspace.has_gpu_support or workspace.has_hip_support):
            log.info('No GPU support test is Ignored.')
            return
        if workspace.NumGpuDevices() < 4:
            log.info('Not enough GPU support, test IGNORED')
            return

    model = cnn.CNNModelHelper(order="NHWC", name="test")

    if cpu_device:
        device_type, device_prefix = caffe2_pb2.CPU, "cpu"
    else:
        device_type, device_prefix = workspace.GpuDeviceType, "gpu"

    devices = [2, 3] if process_id != 0 else [0, 1]
    # Added to a shard-local device index (0 or 1) when naming blobs to fetch.
    device_offset = 2 if process_id == 1 else 0

    def _model_build_fun(model, loss_scale):
        # Forward pass: FC -> flatten -> sigmoid -> squared L2 -> scaled loss.
        fc_out = model.FC(
            "data", "fc", 16, 1, ("ConstantFill", {}), ("ConstantFill", {})
        )
        flattened = model.FlattenToVec(fc_out, "fc_fl")
        activated = model.Sigmoid(flattened, "sigm")
        distance = model.SquaredL2Distance([activated, "label"], "sq")
        scaled_loss = model.Scale(
            model.AveragedLoss(distance, "loss"), scale=loss_scale)

        # For testing explicit sync
        model.param_init_net.UniformFill([], ["sync_num"], shape=[1])
        return [scaled_loss]

    def _input_builder_fun(model):
        # Inputs are fed directly into the workspace, so nothing is built here.
        return None

    def _param_update_fun(model):
        # Local update: param <- 1.0 * param + LR * grad with a fixed LR.
        iteration = model.Iter("ITER")
        learning_rate = model.net.LearningRate(
            [iteration], "LR", base_lr=(-0.1), policy="fixed")
        one = model.param_init_net.ConstantFill(
            [], "ONE", shape=[1], value=1.0)
        for param in model.GetParams():
            model.WeightedSum(
                [param, one, model.param_to_grad[param], learning_rate],
                param)

    def _feed_inputs():
        # The seed depends on the shard id, so each shard draws its own data.
        np.random.seed(26 + process_id * 10)
        # Each run has same input, independent of number of gpus
        batch_size = 64
        per_device = batch_size // len(devices)
        for _ in range(10):
            full_data = np.random.rand(batch_size, 16)
            full_labels = np.round(full_data[:, 0])
            for position, device_id in enumerate(devices):
                lo = position * per_device
                hi = lo + per_device
                data_slice = full_data[lo:hi, :].astype(np.float32)
                label_slice = full_labels[lo:hi].astype(np.float32)
                with core.DeviceScope(
                        core.DeviceOption(device_type, device_id)):
                    workspace.FeedBlob(
                        "{}_{}/data".format(device_prefix, device_id),
                        data_slice)
                    workspace.FeedBlob(
                        "{}_{}/label".format(device_prefix, device_id),
                        label_slice)

    def _fetch(blob, local_device):
        # Fetch `blob` from this shard's first (0) or second (1) device.
        return workspace.FetchBlob("{}_{}/{}".format(
            device_prefix, local_device + device_offset, blob))

    _feed_inputs()

    create_store = core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"],
        path=filestore_dir
    )
    workspace.RunOperatorOnce(create_store)

    rendezvous = {
        "kv_handler": "store_handler",
        "shard_id": process_id,
        "num_shards": 2,
        "engine": "GLOO",
        "exit_nets": None,
    }

    data_parallel_model.Parallelize_BMUF(
        model,
        _input_builder_fun,
        _model_build_fun,
        _param_update_fun,
        devices=devices,
        rendezvous=rendezvous,
        nesterov=nesterov,
        add_blobs_to_sync=["sync_num"],
        cpu_device=cpu_device
    )

    data_parallel_model.RunInitNet(model)

    # The fc_w_v blob is expected to be all zeros right after init.
    np.testing.assert_equal(
        _fetch("fc_w_v", 0), np.zeros((1, 16), dtype=np.float32))

    # Run the algorithm for one iteration to have non-zero params.
    data_parallel_model.RunNet(model, 1)

    # Save iteration momentum and post local update params
    results = {}
    results.update(
        (key, _fetch(blob, local_device))
        for key, blob, local_device in (
            ('v_b_', "fc_b_v", 0),
            ('v_w_', "fc_w_v", 0),
        )
    )

    workspace.RunNetOnce(model.net)

    results.update(
        (key, _fetch(blob, local_device))
        for key, blob, local_device in (
            ('b_0_', "fc_b", 0),
            ('w_0_', "fc_w", 0),
            ('b_1_', "fc_b", 1),
            ('w_1_', "fc_w", 1),
        )
    )

    # Test sync
    if process_id == 0:
        workspace.FeedBlob(
            device_prefix + "_0/sync_num",
            np.array([2603], dtype=np.float32),
            device_option=core.DeviceOption(device_type, 0))

    # Compute block gradients.
    results.update(
        (key, _fetch(blob, local_device))
        for key, blob, local_device in (
            ('b_g_', "fc_b_g", 0),
            ('w_g_', "fc_w_g", 0),
        )
    )
    workspace.RunNetOnce(model._global_model_param_updates_net)

    #  g_b = (b_0_ + b_1_) / 2 - b_g_
    #  g_w = (w_0_ + w_1_) / 2 - w_g_
    results.update(
        (key, _fetch(blob, local_device))
        for key, blob, local_device in (
            ('v_b', "fc_b_v", 0),
            ('v_w', "fc_w_v", 0),
            ('w_g', "fc_w_g", 0),
            ('b_g', "fc_b_g", 0),
            ('w_0', "fc_w", 0),
            ('b_0', "fc_b", 0),
            ('w_1', "fc_w", 1),
            ('b_1', "fc_b", 1),
        )
    )

    # Test add_blobs_to_sync
    for device_id in devices:
        sync_blob = workspace.FetchBlob(
            "{}_{}/sync_num".format(device_prefix, device_id))
        results['sync_{}'.format(device_id)] = sync_blob[0]

    shared_results[process_id] = results


class DistributedTest(unittest.TestCase):
    """Checks BMUF training sharded across two worker processes.

    Each example spawns two ``bmuf_process`` workers that share a temporary
    file-store directory, then verifies the blobs the workers publish.
    """

    @given(
        cpu_device=st.booleans(),
        nesterov=st.booleans()
    )
    def test_bmuf_distributed(self, cpu_device, nesterov):
        """Run the distributed check for hypothesis-drawn boolean flags."""
        self._test_bmuf_distributed(cpu_device=cpu_device, nesterov=nesterov)

    def _test_bmuf_distributed(self, cpu_device=False, nesterov=False):
        """Spawn two BMUF shards and validate the blobs they publish.

        Args:
            cpu_device: Forwarded to both workers.
            nesterov: Forwarded to both workers; also selects which form of
                the parameter-update check is applied.

        The test returns without checking anything when both workers exit
        cleanly without publishing results, which is what they do when the
        GPU support they need is unavailable.
        """
        filestore_dir = tempfile.mkdtemp()
        shared_results = Manager().dict()
        processes = []
        try:
            for idx in range(2):
                process = Process(
                    target=bmuf_process,
                    args=(filestore_dir, idx, shared_results,
                          cpu_device, nesterov)
                )
                processes.append(process)
                process.start()

            for process in processes:
                process.join()
        finally:
            # Remove the rendezvous directory even if spawning/joining fails.
            shutil.rmtree(filestore_dir)

        if len(shared_results) == 0:
            # An empty result is only a legitimate skip when the workers
            # returned normally; a non-zero exit code means a worker crashed
            # and must not be reported as a passing test.
            for idx, process in enumerate(processes):
                self.assertEqual(
                    0, process.exitcode,
                    "BMUF worker {} exited with code {}".format(
                        idx, process.exitcode))
            return

        # Copy out of the manager proxy once; every proxy lookup would
        # otherwise transfer the whole per-shard dict again.
        results = {pid: shared_results[pid] for pid in shared_results.keys()}
        self.assertEqual(
            [0, 1], sorted(results),
            "expected results from both BMUF workers")
        shard0 = results[0]
        shard1 = results[1]

        w_0 = shard0['w_0']
        w_1 = shard0['w_1']
        b_0 = shard0['b_0']
        b_1 = shard0['b_1']
        # Check parameters are in sync.
        np.testing.assert_equal(w_0, w_1)
        np.testing.assert_equal(w_0, shard1['w_0'])
        np.testing.assert_equal(w_0, shard1['w_1'])
        np.testing.assert_equal(b_0, b_1)
        np.testing.assert_equal(b_0, shard1['b_0'])
        np.testing.assert_equal(b_0, shard1['b_1'])

        w_g_ = shard0['w_g_']
        b_g_ = shard0['b_g_']

        # Average of the four post-local-update copies minus the saved
        # fc_*_g blob of shard 0.
        g_b = (shard0['b_0_'] + shard1['b_0_'] + shard0['b_1_'] +
               shard1['b_1_']) / 4 - b_g_
        g_w = (shard0['w_0_'] + shard1['w_0_'] + shard0['w_1_'] +
               shard1['w_1_']) / 4 - w_g_
        v_b_ = shard0['v_b_']
        v_b = shard0['v_b']
        v_w_ = shard0['v_w_']
        v_w = shard0['v_w']

        # Check add_blobs_to_sync: workers publish one "sync_<device>" entry
        # per device, and shard 0 fed 2603 into its first device before the
        # global update net ran.
        for pid in results:
            for key, value in results[pid].items():
                if key.startswith("sync_"):
                    self.assertEqual(2603, value)

        # Check block gradients are correct.
        np.testing.assert_almost_equal(v_b, 0.75 * v_b_ + g_b)
        np.testing.assert_almost_equal(v_w, 0.75 * v_w_ + g_w)

        # Check params update step
        if nesterov:
            np.testing.assert_equal(w_0, w_g_ + v_w - 0.75 * (v_w - v_w_))
            np.testing.assert_equal(b_0, b_g_ + v_b - 0.75 * (v_b - v_b_))
        else:
            np.testing.assert_equal(w_0, w_g_ + v_w)
            np.testing.assert_equal(b_0, b_g_ + v_b)
