import tensorflow as tf
import tensorflow.keras.backend as K

import abc
from typing import Optional, Callable, List
from functools import reduce

from common import common


class WordsSubtokenMetricBase(tf.metrics.Metric):
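    """Base class for word-level metrics computed over `subtokens_delimiter`-separated subtokens.

    Accumulates true-positive, false-positive and false-negative subtoken counts
    across batches; subclasses implement `result()` to derive precision, recall or F1.
    """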
    FilterType = Callable[[tf.Tensor, tf.Tensor], tf.Tensor]

    def __init__(self,
                 index_to_word_table: Optional[tf.lookup.StaticHashTable] = None,
                 topk_predicted_words: Optional[tf.Tensor] = None,
                 predicted_words_filters: Optional[List[FilterType]] = None,
                 subtokens_delimiter: str = '|', name=None, dtype=None):
        super(WordsSubtokenMetricBase, self).__init__(name=name, dtype=dtype)
        self.tp = self.add_weight('true_positives', shape=(), initializer=tf.zeros_initializer)
        self.fp = self.add_weight('false_positives', shape=(), initializer=tf.zeros_initializer)
        self.fn = self.add_weight('false_negatives', shape=(), initializer=tf.zeros_initializer)
        self.index_to_word_table = index_to_word_table
        self.topk_predicted_words = topk_predicted_words
        self.predicted_words_filters = predicted_words_filters
        self.subtokens_delimiter = subtokens_delimiter

    def _get_true_target_word_string(self, true_target_word):
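        """Looks up the target word's string by index when a lookup table is given; otherwise passes it through."""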
        if self.index_to_word_table is None:
            return true_target_word
        true_target_word_index = tf.cast(true_target_word, dtype=self.index_to_word_table.key_dtype)
        return self.index_to_word_table.lookup(true_target_word_index)

    def update_state(self, true_target_word, predictions, sample_weight=None):
        """Accumulates true positive, false positive and false negative statistics."""
        if sample_weight is not None:
            raise NotImplementedError("WordsSubtokenMetricBase with non-None `sample_weight` is not implemented.")

        # For each example in the batch we have:
        #     (i)  one ground-truth target word;
        #     (ii) one predicted word (the argmax of y_hat).
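        # (Here, `true_target_word` is assumed to hold word indices or raw word strings,
        #  and `predictions` the predicted word strings, one row per example.)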

        topk_predicted_words = predictions if self.topk_predicted_words is None else self.topk_predicted_words
        assert topk_predicted_words is not None
        predicted_word = self._get_prediction_from_topk(topk_predicted_words)

        true_target_word_string = self._get_true_target_word_string(true_target_word)
        true_target_word_string = tf.reshape(true_target_word_string, [-1])

        # We split each word into subtokens
        true_target_subwords = tf.compat.v1.string_split(true_target_word_string, sep=self.subtokens_delimiter)
        prediction_subwords = tf.compat.v1.string_split(predicted_word, sep=self.subtokens_delimiter)
        true_target_subwords = tf.sparse.to_dense(true_target_subwords, default_value='<PAD>')
        prediction_subwords = tf.sparse.to_dense(prediction_subwords, default_value='<PAD>')
        true_target_subwords_mask = tf.not_equal(true_target_subwords, '<PAD>')
        prediction_subwords_mask = tf.not_equal(prediction_subwords, '<PAD>')
        # Now shapes of true_target_subwords & prediction_subwords are (batch, subtokens).

        # We use broadcasting to test each list's subtokens for membership in the other, preserving duplicates.
        true_target_subwords = tf.expand_dims(true_target_subwords, -1)
        prediction_subwords = tf.expand_dims(prediction_subwords, -1)
        # Now shapes of true_target_subwords & prediction_subwords are (batch, subtokens, 1).
        true_target_subwords__in__prediction_subwords = \
            tf.reduce_any(tf.equal(true_target_subwords, tf.transpose(prediction_subwords, perm=[0, 2, 1])), axis=2)
        prediction_subwords__in__true_target_subwords = \
            tf.reduce_any(tf.equal(prediction_subwords, tf.transpose(true_target_subwords, perm=[0, 2, 1])), axis=2)
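        # Worked example (illustrative values): with true subtokens ['get', 'name', '<PAD>']
        # and predicted subtokens ['get', 'id'], the memberships above evaluate to
        #   prediction_subwords__in__true_target_subwords -> [True, False]
        #   true_target_subwords__in__prediction_subwords -> [True, False, False]
        # so, after masking out '<PAD>' below, 'get' is a TP, 'id' a FP and 'name' a FN.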

        # Count predicted subwords that exist in the ground-truth word (true positives).
        batch_true_positive = tf.reduce_sum(tf.cast(
            tf.logical_and(prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
        # Count predicted subwords that don't exist in the ground-truth word (false positives).
        batch_false_positive = tf.reduce_sum(tf.cast(
            tf.logical_and(~prediction_subwords__in__true_target_subwords, prediction_subwords_mask), tf.float32))
        # Count ground-truth subwords that don't exist in the predicted word (false negatives).
        batch_false_negative = tf.reduce_sum(tf.cast(
            tf.logical_and(~true_target_subwords__in__prediction_subwords, true_target_subwords_mask), tf.float32))

        self.tp.assign_add(batch_true_positive)
        self.fp.assign_add(batch_false_positive)
        self.fn.assign_add(batch_false_negative)

    def _get_prediction_from_topk(self, topk_predicted_words):
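        """Selects, for each example, the first top-k predicted word that passes all filters."""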
        # Apply the given filters (if any) to mask out illegal predicted words.
        masks = []
        if self.predicted_words_filters is not None:
            masks = [fltr(topk_predicted_words) for fltr in self.predicted_words_filters]
        if masks:
            legal_predicted_target_words_mask = reduce(tf.logical_and, masks)
        else:
            # Use `tf.ones` (not `tf.ones_like`) so the all-true mask also works for string-dtype inputs.
            legal_predicted_target_words_mask = tf.ones(tf.shape(topk_predicted_words), dtype=tf.bool)

        # The first legal predicted word is our prediction.
        first_legal_predicted_target_word_mask = common.tf_get_first_true(legal_predicted_target_words_mask)
        first_legal_predicted_target_word_idx = tf.where(first_legal_predicted_target_word_mask)
        first_legal_predicted_word_string = tf.gather_nd(topk_predicted_words,
                                                         first_legal_predicted_target_word_idx)

        prediction = tf.reshape(first_legal_predicted_word_string, [-1])
        return prediction

    @abc.abstractmethod
    def result(self):
        ...

    def reset_states(self):
        for v in self.variables:
            K.set_value(v, 0)


class WordsSubtokenPrecisionMetric(WordsSubtokenMetricBase):
    def result(self):
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        return precision


class WordsSubtokenRecallMetric(WordsSubtokenMetricBase):
    def result(self):
        recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
        return recall


class WordsSubtokenF1Metric(WordsSubtokenMetricBase):
    def result(self):
        recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        # `divide_no_nan` already yields 0 for a zero denominator, so no epsilon term is needed.
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)
        return f1
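

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only). It assumes top-k predictions are already
    # available as a (batch, k) string tensor of '|'-delimited words, that no word filters
    # are needed, and that `common.tf_get_first_true` selects the first True entry per row.
    f1_metric = WordsSubtokenF1Metric(name='subtoken_f1')
    true_words = tf.constant(['get|name', 'set|id'])
    topk_predictions = tf.constant([['get|id'], ['set|id']])
    f1_metric.update_state(true_words, topk_predictions)
    # Over this batch: TP=3, FP=1, FN=1, hence precision = recall = F1 = 0.75.
    print(f1_metric.result().numpy())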