## @package sparse_lookup
# Module caffe2.python.layers.sparse_lookup
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.helpers.arg_scope import get_current_scope
from caffe2.python import schema
from caffe2.python.layers.layers import (
    get_categorical_limit,
    get_key,
    IdList,
    IdScoreList,
    LayerPsParam,
    ModelLayer,
)
import collections
import functools
import math
import numpy as np
import operator


def get_sparse_lookup_predictor_version(version):
    """Validate and return the predictor-side sparse lookup version.

    The function is an identity on its argument; its purpose is to assert
    that ``version`` is one of the representations the predictor net can be
    built with: 'fp32', 'fp16', 'uint8rowwise' or 'fused_uint8rowwise'.

    Raises:
        AssertionError: if ``version`` is not one of the values above.
    """
    predictor_versions = {
        'fp32',
        'fp16',
        'uint8rowwise',
        'fused_uint8rowwise',
    }
    assert version in predictor_versions, (
        "Unexpected version of sparse_lookup layer {0}".format(version))
    return version


def get_sparse_lookup_trainer_version(version):
    """Validate and return the trainer-side sparse lookup version.

    The function is an identity on its argument; it only asserts that
    ``version`` is one of 'fp32' or 'fp16'.

    Raises:
        AssertionError: if ``version`` is neither 'fp32' nor 'fp16'.
    """
    trainer_versions = {'fp32', 'fp16'}
    assert version in trainer_versions, (
        "Unexpected version of sparse_lookup layer {0}".format(version))
    return version


def _is_id_list(input_record):
    """Return whether ``input_record`` has the same schema as ``IdList``."""
    matches_id_list = schema.equal_schemas(input_record, IdList)
    return matches_id_list


def _is_id_score_list(input_record):
    """Return whether ``input_record`` has the same schema as ``IdScoreList``.

    The comparison is made with ``check_field_types=False``, i.e. field
    types are not taken into account.
    """
    matches_id_score_list = schema.equal_schemas(
        input_record, IdScoreList, check_field_types=False)
    return matches_id_score_list


class SparseLookup(ModelLayer):
    """Embedding lookup layer over an IdList or IdScoreList input record.

    The layer owns a weight parameter ``w`` of shape
    ``[categorical_limit] + inner_shape`` plus a ``scale_bias`` parameter
    that is handed to the non-fused 8-bit rowwise operators. Depending on
    ``reducer`` the rows addressed by the input ids are pooled with a
    ``SparseLengths*`` operator, gathered unchanged (``'None'``), or reduced
    with a ``SortedSegmentRange*`` operator.
    """

    # Reducers accepted when the input record is an IdList.
    _id_list_supported_reducers = [
        'LogMeanExp', 'LogSumExp', 'Max', 'Mean', 'Sum',
        'WeightedSum', 'WeightedMean', 'Sqrt', 'None']

    # Reducers accepted when the input record is an IdScoreList.
    _id_score_list_supported_reducers = [
        'PositionWeighted', 'RecencyWeighted', 'Mean', 'Sum', 'WeightedSum',
        'WeightedMean', 'None'
    ]

    def __init__(self, model, input_record, inner_shape, reducer,
                 weight_init=None, weight_optim=None,
                 name='sparse_lookup', regularizer=None, **kwargs):
        """Create the layer parameters and the output schema.

        Args:
            model: layer model helper; forwarded to ``ModelLayer`` and used
                for ``model.NoOptim``.
            input_record: IdList or IdScoreList record carrying the ids.
            inner_shape: int, list or tuple giving the shape of a single
                embedding row.
            reducer: name of the pooling strategy (see the
                ``_id_*_supported_reducers`` lists).
            weight_init: optional initializer for ``w``; when falsy the
                default from ``_get_default_init_op`` is used.
            weight_optim: optimizer passed to ``create_param`` for ``w``.
            name: layer name forwarded to ``ModelLayer``.
            regularizer: regularizer passed to ``create_param`` for ``w``.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: on an unexpected ``inner_shape`` type, on a
                weighted-by-score reducer used without an IdScoreList, or on
                a non-positive categorical limit.
            NotImplementedError: if the input record is neither an IdList
                nor an IdScoreList.
        """
        super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

        # TODO Add some asserts about input type
        if isinstance(inner_shape, int):
            inner_shape = [inner_shape]
        assert isinstance(inner_shape, (list, tuple)),\
            "Unexpected type for inner_shape, expected list or tuple, got {0}".\
            format(type(inner_shape))

        if reducer == "PositionWeighted":
            assert _is_id_score_list(self.input_record), (
                "PositionWeighted only support IdScoreList, but got {} " +
                "please use PositionWeighted layer to convert IdList " +
                "to IdScoreList").format(repr(self.input_record))
            self.external_weights = input_record.values()

        elif reducer == "RecencyWeighted":
            assert _is_id_score_list(self.input_record), (
                "RecencyWeighted only supports IdScoreList.")
            self.external_weights = input_record.values()
        self.reducer = reducer

        input_dim = get_categorical_limit(input_record)
        assert input_dim > 0, (
            "{} should have categorical limit > 0, but got {}".format(
                get_key(input_record)(), input_dim))

        self.input_dim = input_dim
        # list() so that a tuple inner_shape (explicitly allowed above) does
        # not fail with "can only concatenate list (not tuple) to list".
        self.shape = [input_dim] + list(inner_shape)

        # Must run after self.input_dim is set; also records
        # self.trainer_version as a side effect.
        default_init_op = self._get_default_init_op()

        self.weight_init = weight_init or default_init_op

        if _is_id_list(self.input_record):
            sparse_key = self.input_record.items()
        elif _is_id_score_list(self.input_record):
            sparse_key = self.input_record.keys()
        else:
            raise NotImplementedError()

        if self.input_record.lengths.metadata:
            avg_length = self.input_record.lengths.metadata.expected_value
        else:
            avg_length = None

        self.w = self.create_param(
            param_name='w',
            shape=self.shape,
            initializer=self.weight_init,
            optimizer=weight_optim,
            ps_param=LayerPsParam(
                sparse_key=sparse_key,
                average_length=avg_length),
            regularizer=regularizer
        )

        self.scale_bias_init = ('ConstantFill', {'value': 0.0})

        # Created with an empty shape and no optimizer; it is only consumed
        # by the 'uint8rowwise' code paths below.
        # NOTE(review): presumably filled in by a quantization step outside
        # this layer -- confirm.
        self.scale_bias = self.create_param(
            param_name='scale_bias',
            shape=[],
            initializer=self.scale_bias_init,
            optimizer=model.NoOptim,
        )

        self.output_schema = schema.Scalar(
            (np.float32, inner_shape),
            self.get_next_blob_reference('output'),
        )

    def get_memory_usage(self):
        """Return the number of elements of ``w`` multiplied by 4.

        NOTE(review): the factor 4 presumably stands for bytes per fp32
        element -- confirm for fp16 training.
        """
        return functools.reduce(operator.mul, self.shape) * 4

    def get_fp16_compatible_parameters(self):
        """Return the parameters that may be stored as fp16 (only ``w``)."""
        return [self.w]

    def support_8bit(self):
        """Return True if ``w`` is eligible for 8-bit rowwise quantization."""
        # Rowwise quantization makes sense only if the shape is a 2D matrix
        # with second dimension >= 8
        return len(self.shape) == 2 and self.shape[1] >= 8

    def get_8bits_compatible_parameters(self, fused=True):
        """Return the parameters involved in 8-bit rowwise quantization.

        Args:
            fused: when True the returned tuple carries only ``w``; otherwise
                it carries ``w`` and ``scale_bias``.

        Returns:
            An empty list if ``support_8bit()`` is False, otherwise a
            one-element list holding a ``RowwiseQuantized8BitsWeight``
            namedtuple.
        """
        if not self.support_8bit():
            return []
        if fused:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w'
            )
            return [RowwiseQuantized8BitsWeight(self.w)]
        else:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w, scale_bias'
            )
            return [RowwiseQuantized8BitsWeight(self.w, self.scale_bias)]

    def _get_default_init_op(self):
        """Pick the default initializer for ``w`` and record the trainer version.

        The trainer version is read from the current arg scope (defaulting
        to 'fp32') and stored on ``self.trainer_version``. The initializer is
        a uniform fill in ``[-sqrt(1/input_dim), sqrt(1/input_dim)]``.

        Raises:
            NotImplementedError: for a trainer version other than 'fp32' or
                'fp16'.
        """
        scale = math.sqrt(1.0 / self.input_dim)

        cur_scope = get_current_scope()
        trainer_version = get_sparse_lookup_trainer_version(
            **cur_scope.get(get_sparse_lookup_trainer_version.__name__,
                            {'version': 'fp32'}))

        if trainer_version == 'fp32':
            default_weight_init = ('UniformFill', {'min': -scale, 'max': scale})
        elif trainer_version == 'fp16':
            default_weight_init = ("Float16UniformFill", {'min': -scale, 'max': scale})
        else:
            raise NotImplementedError(
                "Train version {} is not currently supported".format(trainer_version)
            )

        self.trainer_version = trainer_version

        return default_weight_init

    def _gather_wrapper(self, net, version, in_indices, out):
        """Gather the rows of ``w`` at ``in_indices`` into ``out`` as fp32.

        Raises:
            NotImplementedError: for an unknown ``version``.
        """
        # Gather can work on all kinds of input data types, and output
        # data with the same type. Convert the output of Gather to float,
        # because the follow-up Ops expect fp32.
        if version == 'fp32':
            return net.Gather([self.w, in_indices], out)
        elif version == 'fp16':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')

            return net.HalfToFloat(gathered_w, out)
        elif version == 'uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            gathered_scale_bias = net.Gather(
                [self.scale_bias, in_indices],
                'gathered_scale_bias'
            )

            return net.Rowwise8BitQuantizedToFloat(
                [gathered_w, gathered_scale_bias], out)
        elif version == 'fused_uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.Fused8BitRowwiseQuantizedToFloat(gathered_w, out)
        else:
            # Raising a bare string is itself a TypeError; raise a real
            # exception so the message reaches the caller.
            raise NotImplementedError(
                "Unsupported version of operators in SparseLookup " +
                "layer: {0}".format(version))

    def _sparse_lengths_reducer(self, in_indices, reducer, net, version):
        """Add the unweighted ``SparseLengths<reducer>`` op for ``version``.

        Shared by the IdList and IdScoreList 'Sum'/'Mean' paths so that both
        build the operator inputs the same way.

        Raises:
            NotImplementedError: for an unknown ``version``.
        """
        op_input = [
            self.w,
            in_indices,
            self.input_record.lengths()
        ]
        layer_name = 'SparseLengths' + reducer

        if version in ['fp32', 'fp16']:
            # SparseLengths* Ops will accept either fp16 or fp32 embedding
            # matrix and output fp32 pooled embedding
            net.__getattr__(layer_name)(
                op_input,
                self.output_schema.field_blobs(),
            )
        elif version == 'uint8rowwise':
            # The non-fused 8-bit op takes scale_bias as a trailing input.
            op_input.append(self.scale_bias)
            net.__getattr__(layer_name + '8BitsRowwise')(
                op_input, self.output_schema.field_blobs())
        elif version == 'fused_uint8rowwise':
            net.__getattr__(layer_name + 'Fused8BitRowwise')(
                op_input, self.output_schema.field_blobs())
        else:
            raise NotImplementedError(
                "Unsupported version of operator in SparseLookUp " +
                "layer: {0}".format(version))

    def _sparse_lengths_weighted_reducer(
            self, in_indices, weights, reducer,
            net, version, grad_on_weights=0):
        """Add the weighted ``SparseLengths<reducer>`` op for ``version``.

        Args:
            in_indices: blob with the ids to look up.
            weights: blob with one weight per id.
            reducer: operator suffix, e.g. 'WeightedSum'.
            net: net to add the operator to.
            version: one of 'fp32', 'fp16', 'uint8rowwise',
                'fused_uint8rowwise'.
            grad_on_weights: forwarded to the fp32/fp16 operator.

        Raises:
            NotImplementedError: for an unknown ``version``.
        """
        op_input = [
            self.w,
            weights,
            in_indices,
            self.input_record.lengths()
        ]
        layer_name = 'SparseLengths' + reducer

        if version in ['fp32', 'fp16']:
            # SparseLengths* Ops will accept either fp16 or fp32 embedding
            # matrix and output fp32 pooled embedding
            # A special case here is that we need FP16 engine for
            # SparseLengthsWeightedSum when FP16 embeddings are used for
            # correct backward updates
            if reducer == "WeightedSum" and version == "fp16":
                net.SparseLengthsWeightedSum(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                    engine='FP16',
                )
            else:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                )
        elif version == 'uint8rowwise':
            op_input.append(self.scale_bias)
            net.__getattr__(layer_name + '8BitsRowwise')(
                op_input, self.output_schema.field_blobs())
        elif version == 'fused_uint8rowwise':
            net.__getattr__(layer_name + 'Fused8BitRowwise')(
                op_input, self.output_schema.field_blobs())
        else:
            raise NotImplementedError(
                "Unsupported version of operator in SparseLookUp " +
                "layer: {0}".format(version))

    # deal with sparse features of id_list type
    def _add_ops_id_list(self, net, version):
        """Add the lookup operators for an IdList input record.

        Raises:
            AssertionError: if ``self.reducer`` is not supported for IdList.
            NotImplementedError: for an unknown ``version``.
        """
        assert self.reducer in self._id_list_supported_reducers, (
            "Unsupported reducer: {} for ID_LIST".format(self.reducer)
        )
        if self.reducer in ['Sum', 'Mean', 'WeightedSum', 'WeightedMean']:
            # For id list features, the behaviors of 'Sum' and
            # 'WeightedSum' are identical, since we can regard the weight on each
            # id as 1. Similarly, for 'Mean' and 'WeightedMean'.
            if self.reducer == 'WeightedSum':
                self.reducer = 'Sum'
            elif self.reducer == 'WeightedMean':
                self.reducer = 'Mean'

            self._sparse_lengths_reducer(
                self.input_record.items(), self.reducer, net, version)

        elif self.reducer == 'Sqrt':
            sqrt_weight = net.LengthsToWeights(
                [self.input_record.lengths()],
                [net.NextScopedBlob('lengths_sqrt')],
                power=0.5,
            )
            self._sparse_lengths_weighted_reducer(
                self.input_record.items(),
                sqrt_weight,
                'WeightedSum', net, version)

        elif self.reducer == 'None':
            # Gather operator will gather the embedding for each id of
            # each IdList.
            self._gather_wrapper(net, version, self.input_record.items(),
                                 self.output_schema.field_blobs())

        else:
            # Remaining reducers (LogMeanExp, LogSumExp, Max) go through
            # Gather followed by a SortedSegmentRange* reduction.
            table_rows = self._gather_wrapper(
                net, version, self.input_record.items(), 'table_rows')

            segment_ids = net.LengthsToSegmentIds(
                self.input_record.lengths(),
                net.NextScopedBlob(self.input_record.lengths() + '_sid'))
            net.__getattr__('SortedSegmentRange' + self.reducer)(
                [table_rows, segment_ids],
                self.output_schema.field_blobs(),
            )

    # deal with sparse features of id_score_list type
    def _add_ops_id_score_list(self, net, version):
        """Add the lookup operators for an IdScoreList input record.

        Raises:
            AssertionError: if ``self.reducer`` is not supported for
                IdScoreList.
            NotImplementedError: for an unknown ``version``, or for an
                unsupported reducer when assertions are disabled.
        """
        assert self.reducer in self._id_score_list_supported_reducers, (
            "Unsupported reducer: {} for ID_SCORE_LIST".format(self.reducer)
        )
        if self.reducer in ['WeightedSum', 'WeightedMean']:
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.input_record.values(),
                self.reducer, net, version)

        elif self.reducer in ['Sum', 'Mean']:
            # Scores are ignored here; only the keys are pooled.
            self._sparse_lengths_reducer(
                self.input_record.keys(), self.reducer, net, version)

        elif self.reducer in ['PositionWeighted', 'RecencyWeighted']:
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.external_weights,
                'WeightedSum', net, version, grad_on_weights=1)

        elif self.reducer == 'None':
            # Gather operator will gather the embedding for each id of
            # each IdList.
            self._gather_wrapper(net, version, self.input_record.keys(),
                                 self.output_schema.field_blobs())
        else:
            # Only reachable when assertions are stripped (python -O).
            raise NotImplementedError(
                "Only Sum, Mean, None are supported for IdScoreList input. " +
                "Trying to create with {}".format(self.reducer))

    def _add_ops(self, net, version='fp32'):
        """Dispatch to the IdList or IdScoreList operator builder.

        Raises:
            NotImplementedError: if the input record is neither an IdList
                nor an IdScoreList.
        """
        if _is_id_list(self.input_record):
            self._add_ops_id_list(net, version=version)
        elif _is_id_score_list(self.input_record):
            self._add_ops_id_score_list(net, version=version)
        else:
            raise NotImplementedError(
                "Unsupported input type {0}".format(self.input_record))

    def add_train_ops(self, net):
        """Add the operators to a training net using the trainer version."""
        self._add_ops(net, self.trainer_version)

    def add_ops(self, net):
        """Add the operators to a net using the predictor version.

        The version is read from the current arg scope (default 'fp32').
        8-bit versions fall back to 'fp32' when ``support_8bit()`` is False.
        """
        cur_scope = get_current_scope()
        version = get_sparse_lookup_predictor_version(
            **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
                            {'version': 'fp32'}))

        # TODO(amalevich): Layer should not be responsible for decision about
        # quantization.
        if not self.support_8bit() and version in {'uint8rowwise',
                                                   'fused_uint8rowwise'}:
            version = 'fp32'

        self._add_ops(net, version)
