import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Layer
import tensorflow.keras.backend as K
from collections import namedtuple


# Result container produced by `TopKWordPredictionsLayer.call`: `words` holds
# the looked-up words and `scores` the matching top-k prediction scores.
TopKWordPredictionsLayerResult = namedtuple(
    typename='TopKWordPredictionsLayerResult',
    field_names=('words', 'scores'),
)


class TopKWordPredictionsLayer(Layer):
    """Non-trainable layer that turns prediction scores into top-k words.

    For every position along the last axis of the input, the layer selects
    the `top_k` highest scores (via `tf.nn.top_k`) and translates the
    selected indices into words through `index_to_word_table`.

    Args:
        top_k: Number of highest-scoring entries to keep along the last axis.
        index_to_word_table: Lookup table whose keys are indices into the
            last axis of the input and whose values are the matching words.
        **kwargs: Forwarded to `Layer.__init__`. Any caller-supplied `dtype`
            or `trainable` is overridden (`tf.string` and `False`).
    """

    def __init__(self,
                 top_k: int,
                 index_to_word_table: tf.lookup.StaticHashTable,
                 **kwargs):
        # This layer has no weights; force these two settings regardless of
        # what the caller passed.
        kwargs['dtype'] = tf.string
        kwargs['trainable'] = False
        super().__init__(**kwargs)
        self.top_k = top_k
        self.index_to_word_table = index_to_word_table

    def build(self, input_shape):
        """Validate the input shape.

        Raises:
            ValueError: If the input has fewer than 2 dimensions, or if its
                last dimension is statically known and smaller than `top_k`.
        """
        if len(input_shape) < 2:
            raise ValueError("Input shape for TopKWordPredictionsLayer should be of >= 2 dimensions.")
        # The last dimension may be unknown at build time (`None`, or a
        # `Dimension` wrapping `None` under TF1 shape semantics). Comparing
        # `None < int` raises TypeError, so normalise it and only validate
        # when the size is actually known.
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is not None and last_dim < self.top_k:
            raise ValueError("Last dimension of input shape for TopKWordPredictionsLayer should be of >= `top_k`.")
        super().build(input_shape)
        self.trainable = False

    def call(self, y_pred, **kwargs) -> TopKWordPredictionsLayerResult:
        """Return the top-k words and scores for `y_pred`.

        Args:
            y_pred: Tensor of prediction scores; the top-k selection is
                performed along its last axis.
            **kwargs: Accepted for Keras compatibility; unused.

        Returns:
            A `TopKWordPredictionsLayerResult` with `words` (table lookups of
            the selected indices) and `scores` (the selected scores). Both
            follow the ordering produced by `tf.nn.top_k(..., sorted=True)`.
        """
        top_k_pred_scores, top_k_pred_indices = tf.nn.top_k(y_pred, k=self.top_k, sorted=True)
        # `tf.nn.top_k` yields indices in its own integer dtype; cast them to
        # whatever key dtype the lookup table was built with.
        top_k_pred_indices = tf.cast(top_k_pred_indices, dtype=self.index_to_word_table.key_dtype)
        top_k_pred_words = self.index_to_word_table.lookup(top_k_pred_indices)

        return TopKWordPredictionsLayerResult(words=top_k_pred_words, scores=top_k_pred_scores)

    def compute_output_shape(self, input_shape):
        """Return the shapes of the two outputs (words, scores).

        Both outputs share one shape: the input shape with its last
        dimension replaced by `top_k`.
        """
        output_shape = tuple(input_shape[:-1]) + (self.top_k, )
        return output_shape, output_shape