With an attention mechanism, the data feeding the decoder is now the entire encoded sequence, solving the information bottleneck of the previous architecture.
Attention can be implemented in many flavors.
To provide flexibility, my implementation used the abstract Query-Key-Value notation described in the Attention Is All You Need paper.
A benefit of this approach is the resulting modularity of the translation system.
The Encoder provides both the Value and the Key matrices, the latter being a transformation of the former, while the Decoder provides the Query, which is the vector of features of the token being decoded.
The Attention module will return a weighted average of the Value matrix (the attention vector), which will be used to enhance the vector of features during decoding.
Multiple variations on the Query-Key-Value paradigm are possible.
Bilinear and MLP attention have been implemented in addition to the dot attention illustrated below:

```r
attn <- attn_dot(value = value, query_key_size = num_hidden, scale = T)
init <- attn$init()
attend <- attn$attend
attention <- attend(query = query, key = init$key, value = init$value, attn_init = init)
```

The above corresponds to the actual MXNet graph for the dot attention, where the batch size is 128 (the last dimension).
For each token to be decoded, a query is the reprojection of the 512-length representation of that token.
The value is the full encoding of the source sequence.
It is itself reprojected to form a key; the dot product between the key and the query yields the weighting scheme applied to the value matrix. The resulting 512-length vector is called the context vector, which is then appended to the original token encoding to compute the score associated with each word of the target vocabulary.
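To make the weighting concrete, here is a minimal numeric sketch of scaled dot attention in base R, independent of the MXNet symbols above; the dimensions and the random reprojection are illustrative only.

```r
num_hidden <- 512   # feature size of each encoded token
src_len    <- 10    # number of tokens in the source sequence

value <- matrix(rnorm(src_len * num_hidden), nrow = src_len)   # encoded source sequence
key   <- value %*% matrix(rnorm(num_hidden * num_hidden) / sqrt(num_hidden),
                          nrow = num_hidden)                    # reprojection of the value
query <- rnorm(num_hidden)                                      # features of the token being decoded

scores  <- as.vector(key %*% query) / sqrt(num_hidden)  # one score per source position (scaled dot product)
weights <- exp(scores - max(scores))
weights <- weights / sum(weights)                        # softmax over source positions

context <- as.vector(t(value) %*% weights)               # weighted average of the value rows: the context vector
length(context)                                          # 512
```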
The final component of the model is the softmax loss function.
It normalizes the above scores into a probability distribution and uses the cross-entropy loss function to derive the head gradient to propagate.
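For reference, the normalization and loss for a single decoding step reduce to the following; this base R sketch uses a made-up vocabulary size and score vector.

```r
vocab_size <- 5
scores <- rnorm(vocab_size)   # unnormalized scores over the target vocabulary
target <- 3                   # index of the true next token

probs <- exp(scores - max(scores))
probs <- probs / sum(probs)   # softmax: a probability distribution over the vocabulary

loss <- -log(probs[target])                       # cross-entropy for this step
grad <- probs; grad[target] <- grad[target] - 1   # head gradient of the loss w.r.t. the scores
```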
Training

Thanks to a modular encoder-attention-decoder design, the complete model can be built in a straightforward fashion.
A remaining subtlety is that, during training, the decoder takes advantage of a teacher.
That is, at each step, the true previous token is fed rather than the predicted one.
Such information is not available when performing inference.
A second decoder is therefore built which uses the most likely word at inference (the argmax over the predictions) rather than the true label.
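The two decoders differ only in which token is fed back at each step. Below is a minimal sketch of that difference, where decode_step(), bos_id, and max_len are hypothetical stand-ins for one step of the RNN decoder with attention, the beginning-of-sentence id, and the maximum output length; none of these are the project's actual names.

```r
vocab_size    <- 5
bos_id        <- 1
max_len       <- 3
target_tokens <- c(2, 4, 1)   # made-up reference translation (token ids)

# Hypothetical one-step decoder: previous token id and hidden state in,
# scores over the target vocabulary and updated state out.
decode_step <- function(prev_token, state) {
  list(scores = rnorm(vocab_size), state = state)
}

# Teacher forcing (training): the true previous token is fed at each step.
state <- NULL
for (t in seq_along(target_tokens)) {
  prev  <- if (t == 1) bos_id else target_tokens[t - 1]
  out   <- decode_step(prev, state)
  state <- out$state
}

# Greedy argmax (inference): the model's own previous prediction is fed back.
state <- NULL
prev  <- bos_id
predicted <- integer(0)
for (t in seq_len(max_len)) {
  out       <- decode_step(prev, state)
  prev      <- which.max(out$scores)
  predicted <- c(predicted, prev)
  state     <- out$state
}
```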
The hyper-parameters for training were kept fairly vanilla: an Adam optimizer with a decreasing learning rate:

```r
initializer <- mx.init.Xavier(rnd_type = "uniform", factor_type = "in", magnitude = 2.5)

lr_scheduler <- mx.lr_scheduler.FactorScheduler(step = 5000, factor_val = 0.9,
                                                stop_factor_lr = 5e-5)

optimizer <- mx.opt.create("adam", learning.rate = 5e-4, beta1 = 0.9, beta2 = 0.999,
                           epsilon = 1e-8, wd = 1e-8, clip_gradient = 1, rescale.grad = 1,
                           lr_scheduler = lr_scheduler)
```

The model was then trained for 8 epochs, taking about a full day on a V100 GPU.
```r
model <- mx.model.buckets(symbol = decode_teacher,
                          train.data = iter_train, eval.data = iter_eval,
                          num.round = 12, ctx = ctx, verbose = TRUE,
                          metric = mx.metric.Perplexity,
                          optimizer = optimizer, initializer = initializer,
                          batch.end.callback = batch.end.callback,
                          epoch.end.callback = epoch.end.callback)

mx.model.save(model = model, prefix = "models/en_fr_cnn_rnn_teacher", iteration = 8)
mx.symbol.save(symbol = decode_argmax, filename = "models/en_fr_cnn_rnn_argmax.json")
```
Perplexity is used as the evaluation metric to track the progress of the training.

Inference

To have a comparable assessment of the translation quality, the model can be benchmarked against the official WMT test set.
To do so, the sacreBLEU library comes in handy:

```
sacrebleu --test-set wmt15 --language-pair en-fr --echo src > wmt15-en-fr.src
```

When performing inference on a new dataset, it's crucial to apply the same preprocessing as for the training data. Luckily, very few transformations were applied in our scenario, making this step easily replicable on the wmt15-en-fr.src data.
Obviously, the same dictionary must be applied as well, so the one developed for training will be used during the preprocessing step rather than building a new one on the fly.
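As a sketch of that lookup, assume the dictionary is a named integer vector mapping tokens to ids with an "<unk>" entry for out-of-vocabulary words; the helper name and dictionary layout below are assumptions for illustration, not the project's actual code.

```r
# Hypothetical helper: map tokens to ids using the training-time dictionary.
tokens_to_ids <- function(tokens, dic) {
  ids <- dic[tokens]
  ids[is.na(ids)] <- dic["<unk>"]   # unseen tokens fall back to the unknown id
  unname(ids)
}

toy_dic <- c("<unk>" = 1L, "i" = 2L, "love" = 3L, "french" = 4L)  # toy dictionary
tokens_to_ids(c("i", "love", "marmots"), toy_dic)                 # 2 3 1
```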
The inference model is obtained by combining the argmax structure with the weights learned during training with the teacher.
```r
model <- mx.model.load(prefix = "models/model_wmt15_en_fr_cnn_rnn_teacher_v2", iteration = 12)
sym_infer <- mx.symbol.load(file.name = "models/model_wmt15_en_fr_cnn_rnn_argmax_v2.json")

model_infer <- list(symbol = sym_infer,
                    arg.params = model$arg.params,
                    aux.params = model$aux.params)
model_infer <- structure(model_infer, class = "MXFeedForwardModel")
```

The inference can then be applied on the test data, stored as a text file, ready to be evaluated by sacreBLEU:

```
cat wmt15_en_fr_cnn_rnn.txt | sacrebleu -t wmt15 -l en-fr
```

The resulting performance summary should look similar to:

```
BLEU+case.mixed+lang.en-fr+numrefs.1+smooth.exp+test.wmt15+tok.13a+version.1.2.12 = 28.2 61.0/36.2/23.8/16.1 (BP = 0.930 ratio = 0.933 hyp_len = 26090 ref_len = 27975)
```

indicating that we achieve a BLEU score of 28.2.
Test sentences can also be submitted for translation, validating the soundness of the model:

```r
> infer_helper(infer_seq = "I'd love to learn French!", model = model_infer,
               source_dic = source_dic, target_dic = target_dic, seq_len = seq_len)
[1] "J'aimerais apprendre le français!"
```

Improving

To reach state-of-the-art performance, some additional tricks can be considered.
Tokenization: best performing systems typically use more sophisticated tokenization schemes, notably BPE which creates sub-word splits.
Positional embedding: in addition to the token ids that are used as input data, features that represent the position of the token within the sentence can be added.
It can either be a single position indicator (absolute or relative) or a more complex collection of sin/cos waves as used in the transformer model (a sketch of the sinusoidal variant follows below).
Model ensembling: average the predictions of a few models.
Beam search: rather than feeding only the single most likely token forward, the top N candidates and their associated next steps are expanded, and the path with the highest overall likelihood is kept.
This partially circumvents limitations of the greedy argmax decoding.
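As an illustration of the sinusoidal option mentioned above, here is a minimal base R sketch of transformer-style positional encodings; the function name and dimensions are made up for the example.

```r
# Sinusoidal positional encoding as in the transformer paper:
# PE[pos, 2i]   = sin(pos / 10000^(2i / d))
# PE[pos, 2i+1] = cos(pos / 10000^(2i / d))
positional_encoding <- function(seq_len, d) {
  pos   <- 0:(seq_len - 1)
  i     <- 0:(d / 2 - 1)
  angle <- outer(pos, 1 / 10000^(2 * i / d))   # seq_len x d/2 matrix of angles
  pe <- matrix(0, nrow = seq_len, ncol = d)
  pe[, seq(1, d, by = 2)] <- sin(angle)        # even feature dimensions
  pe[, seq(2, d, by = 2)] <- cos(angle)        # odd feature dimensions
  pe
}

# These positional features would be added to the token embeddings before encoding.
pe <- positional_encoding(seq_len = 50, d = 512)
dim(pe)  # 50 x 512
```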
Many of those features are integrated in the comprehensive Sockeye library built on top of MXNet.