# --- Model inputs -----------------------------------------------------------
# Integer word indices, one sequence of SEQ_LEN tokens per sample.
word_seq = Input(shape=(SEQ_LEN,), dtype="int32", name="word_seq")

# Initial query vector. Per the original author's note, for language modeling
# this is a constant vector filled with 0.1 (bottom of page 7 of the first
# linked paper).
query = Input(shape=(EMBED_DIM,), dtype="float32", name="q_input")

# --- Shared layers (instantiated once, reused on every hop) -----------------
# Added_Weights is a custom layer defined elsewhere; the author describes
# these two instances as the "positional encoding" components.
_encoding_shape = (SEQ_LEN, EMBED_DIM)
T_A = Added_Weights(input_dim=_encoding_shape)
T_C = Added_Weights(input_dim=_encoding_shape)

# Emb_A embeds the sequence for the memory (addressing) side, Emb_C for the
# content (output) side; both share the same vocabulary and embedding width.
_embedding_kwargs = dict(
    output_dim=EMBED_DIM,
    input_dim=VOCAB_SIZE,
    input_length=SEQ_LEN,
)
Emb_A = Embedding(name="Emb_A", **_embedding_kwargs)
Emb_C = Embedding(name="Emb_C", **_embedding_kwargs)

# Linear (no non-linearity) projection to EMBED_DIM units, applied between
# hops; weights drawn from N(0, 0.05^2).
int_state_weights = Dense(
    units=EMBED_DIM,
    activation="linear",
    kernel_initializer=RandomNormal(mean=0., stddev=0.05, seed=None),
)

# Multi-hop memory lookup.
#
# The memory-side and content-side encodings depend only on `word_seq` and on
# layers whose weights are shared across hops, so they are computed once here
# instead of being rebuilt identically on every loop iteration.
encoded_memories = T_A(Emb_A(word_seq))
content = T_C(Emb_C(word_seq))

# Each hop uses the previous hop's output as its query; the first hop's query
# is the constant input vector.
layer_output = query

# NOTE(review): this performs NUM_LAYERS - 1 hops, not NUM_LAYERS — confirm
# that this is the intended hop count.
for _ in range(NUM_LAYERS - 1):
    # A fresh Dropout per hop so every hop sees its own dropout mask.
    memories = Dropout(DROPOUT_R)(encoded_memories)

    # Score every memory slot against the current query.
    # Assumes Added_Weights preserves the (SEQ_LEN, EMBED_DIM) shape, so the
    # result has one score per sequence position — TODO confirm.
    mem_relevance = Dot(axes=[1, 2])([layer_output, memories])

    # Linear map of the *previous internal state* (the H * u term of the
    # layer-wise update u_next = H * u + o). The original code fed the raw
    # attention scores into this layer instead of the state.
    weighted_internal_state = int_state_weights(layer_output)

    # Turn the scores into a probability distribution over memory slots.
    mem_relevance = Softmax()(mem_relevance)

    # Weight each piece of content by its probability of being relevant.
    content_relevance = Dot(axes=1)([mem_relevance, content])

    layer_output = Add()([content_relevance, weighted_internal_state])

    layer_output = Dropout(DROPOUT_R)(layer_output)

# Output head: project the final hop's state onto the vocabulary.
# Softmax (rather than ReLU) is used so the outputs form a probability
# distribution over the VOCAB_SIZE classes, which is what the
# categorical cross-entropy loss below expects.
final_output = Dense(
    units=VOCAB_SIZE,
    activation="softmax",
    kernel_initializer=RandomNormal(mean=0., stddev=0.05, seed=None),
)(layer_output)

# The model output is the tensor built above; the original referenced an
# undefined name (`prediction`) here.
model = Model(inputs=[word_seq, query], outputs=final_output)

# Plain SGD with gradient-norm clipping at 50.
model.compile(
    optimizer=SGD(lr=0.01, clipnorm=50.),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for up to 200 epochs; lr_adjust, lr_termination and for_csv are
# callbacks defined elsewhere.
model.fit(
    x=[td_seqs, td_query],
    y=[td_labels],
    batch_size=BATCH_SIZE,
    callbacks=[lr_adjust, lr_termination, for_csv],
    epochs=200,
    verbose=1,
)




