June 25, 2019

Opening the Black Box with Adversarial RNN (Part 1)

Introduction

In recent years, Deep Learning (DL) has been highly successful at solving NLP tasks using state-of-the-art architectures like Recurrent Neural Networks (RNNs) and Transformers. However, architecture isn't everything; we also need a representative dataset to learn from. If our train/test sets are unrepresentative, the algorithm's impressive performance can be misleading, as it will likely fail in the real world. This is a BIG problem in DL, and especially in NLP: it's extremely difficult to gather labeled datasets that capture the diversity and complexity of natural language.

This fact leads us to the following realization: it's not enough to tell the algorithm what to do, we also need to tell it how to do it. In other words, we don't want just any (seemingly) good solution; we want a solution that is complex enough to handle any dataset we might give it, and if it can't, the algorithm's confidence should be low enough for us to recognize that.

For example, during our research at LogMeIn we found that some models tend to search for the most important words in a phrase and map them to the right label. While this yields good results on the test set, these models fail on more complex phrases. To see why, consider the following phrases:

‘I want to show a book’

‘I want to book a show’

While they both share the same important words, their meanings are very different. Let's look at a more applied example. Consider the following phrases, for which we need to find the referred objects in an image:

 

‘hat worn by the man’

 ‘man wearing a hat’

 

While both phrases share the same important words, the red phrase refers to the object bounded by the red box while the green phrase refers to the object bounded by the green box. Thus, even though searching for important words might yield good results on our dataset, we need a model that can understand the actual meaning of the phrase; this will help with the long tail of edge cases. So how can we prevent the model from relying only on the most important words? And what are the most important words in the first place? This is a huge problem, especially considering that these architectures are black boxes, so it's very hard to understand how they do whatever they do.

In this blog post we'll share one of LogMeIn's ongoing research projects, in which we introduce an adversarial RNN (ARNN) game that tackles this problem. This game improves the generalization of our language models and can act as an analytic tool that sheds light on how our models do their magic. In the next part of this series, we'll explain our experimental setting and show concrete results and analysis.

 

The Game

Consider the common solution for NLP tasks, which contains (among other components) an RNN that needs to embed a given phrase in order to minimize some loss. Now imagine we have another player: the Editor. The Editor is a nasty model that really wants the RNN to fail, so it tries to maximize the loss by editing the phrase's words. To keep things fair, when a word gets edited, the Editor replaces it with a unique token, ‘edit’, and receives a penalty for each edited word. Thus, it needs to find a way to maximize the loss while minimizing the number of edited words. As a result, the Editor must learn which words are the most important ones for the RNN, that is, which words the RNN depends on in order to do a good job.

As for the RNN, it can no longer just search for the most important words and rely on them, since they might be edited at any given time; this makes it much harder to overfit. Instead, the RNN needs to figure out what the edited words originally were by using other parts of the phrase. As a result, the generalization capabilities of the language model improve, other words become more important, the Editor starts editing those words instead… and the cycle goes on and on. Until when?

As the game evolves, the RNN learns to use more and more words, so the importance of the words, as perceived by the Editor, becomes uniform; if the penalty for editing a word is big enough, the Editor will eventually stop editing. Nevertheless, the Editor is always watching, forcing the RNN to keep making good use of both syntax and semantics, or the Editor will start editing again. Another option is to set a low penalty that encourages a high number of edited words. In this setting we turn the Editor off for some number of randomly chosen phrases, a number that grows as training progresses.

We can also see this game as a form of data augmentation. The Editor helps make the RNN more resilient to noise by producing new noisy phrases at each training iteration. These phrases are exactly the ones that the RNN doesn't know how to handle, thereby forcing the RNN to generalize better.

 

The Models' Structure

As described above, the RNN embeds some phrase p = (w1, w2, …, wt, …, wn). At each time step t, the RNN takes in the word wt and produces the hidden state ht.

The Editor is implemented by another RNN that takes ht as its input at time step t. It outputs two Q-values (Q as in Quality) for two actions, edit or no_edit, and the action with the highest value is chosen. So how do we edit a word? Let's see it step by step (a minimal code sketch follows the list):

1. The RNN gets the word wt and produces the hidden state ht.
2. The Editor gets ht and approximates two Q-values: Q(ht, ‘edit’) and Q(ht, ‘no_edit’). If we choose not to edit, nothing happens and the RNN continues as usual.
3. If we choose to edit, we ignore ht and instead run the RNN again with the ‘edit’ token in place of wt. This happens on the fly: we don't re-run the entire edited phrase each time we edit a word, we only run the RNN's t'th time step again, with ht-1 and the token ‘edit’ as its inputs.
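To make these steps concrete, here is a minimal PyTorch-style sketch of a single time step of the game. The names, sizes, and the simple linear Editor head are illustrative placeholders, not the exact implementation:

```python
import torch
import torch.nn as nn

EDIT_IDX = 1                          # hypothetical index of the special 'edit' token
embed = nn.Embedding(10000, 128)      # word embeddings (vocabulary size is arbitrary here)
rnn_cell = nn.GRUCell(128, 256)       # the task RNN, unrolled one step at a time
editor = nn.Linear(256, 2)            # Editor head: maps a state to Q(., no_edit), Q(., edit)

def game_step(w_t, h_prev):
    """One time step of the ARNN game for a batch of word indices w_t."""
    # 1. The RNN reads the word w_t and produces the hidden state h_t.
    h_t = rnn_cell(embed(w_t), h_prev)

    # 2. The Editor approximates Q-values for no_edit (0) and edit (1).
    #    The input is detached so the Editor's gradients never reach the RNN:
    #    the two players are optimized separately.
    q_t = editor(h_t.detach())
    action = q_t.argmax(dim=-1)       # deterministic policy: pick the highest Q-value

    # 3. Where the Editor chose to edit, discard h_t and re-run only this time step
    #    with the 'edit' token instead of w_t (on the fly, no re-run of the whole phrase).
    h_edit = rnn_cell(embed(torch.full_like(w_t, EDIT_IDX)), h_prev)
    h_t = torch.where(action.unsqueeze(-1).bool(), h_edit, h_t)

    return h_t, action, q_t
```

In the full setting described below, the Editor's input is the richer state h't, which also attends over the un-edited hidden states; the sketch omits that attention to stay short.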

 

One last thing before we continue: our experiments show that we can improve the Editor's performance by using attention over the un-edited hidden states, i.e., the hidden states produced without editing any of the words. With this approach the Editor can approximate the relative importance of each word.

Therefore, before activating the Editor, we first run the RNN alone, with no edited words (and with no optimization). Then we run the Editor and the RNN together, as described above, except that the Editor also uses attention over the RNN's un-edited hidden states. Thus, to evaluate the Q-values, the Editor now uses the state h't, which combines ht with the un-edited hidden states. After each game iteration we optimize the RNN's and the Editor's parameters separately, so the models are unaware of each other.
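Putting the two passes together, one game iteration might look roughly like the sketch below, assuming a simple phrase-classification task and reusing the hypothetical game_step from above. The attention over the un-edited states is omitted, and the Editor's loss (sketched in the Optimization section) is left as a function call:

```python
classifier = nn.Linear(256, 5)        # hypothetical task head (e.g. 5 intent labels)
rnn_params = (list(embed.parameters()) + list(rnn_cell.parameters())
              + list(classifier.parameters()))
rnn_opt = torch.optim.Adam(rnn_params, lr=1e-3)
editor_opt = torch.optim.Adam(editor.parameters(), lr=1e-3)

def game_iteration(phrase, label):
    """phrase: LongTensor (batch, seq_len) of word indices; label: LongTensor (batch,)."""
    batch, seq_len = phrase.shape

    # Pass 1: the RNN alone, with no edited words and no optimization.
    with torch.no_grad():
        h = torch.zeros(batch, 256)
        unedited_states = []
        for t in range(seq_len):
            h = rnn_cell(embed(phrase[:, t]), h)
            unedited_states.append(h)
        unedited_loss = nn.functional.cross_entropy(classifier(h), label)

    # Pass 2: the RNN and the Editor play together. (In the full method the Editor
    # also attends over unedited_states to form h'_t; omitted here for brevity.)
    h = torch.zeros(batch, 256)
    actions, q_values = [], []
    for t in range(seq_len):
        h, a_t, q_t = game_step(phrase[:, t], h)
        actions.append(a_t)
        q_values.append(q_t)

    # Optimize the two players separately, so the models stay unaware of each other.
    rnn_loss = nn.functional.cross_entropy(classifier(h), label)
    rnn_opt.zero_grad(); rnn_loss.backward(); rnn_opt.step()

    editor_loss = editor_q_loss(q_values, actions, rnn_loss.detach(), unedited_loss)
    editor_opt.zero_grad(); editor_loss.backward(); editor_opt.step()
```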

 

Peeking into the Black Box

One of the most important contributions of the ARNN game is the insight it provides. From the description above it's clear that the Editor doesn't see the phrases or their labels. Actually, the Editor doesn't know anything about the problem the RNN is trying to solve! At each time step, the Editor can only see the hidden states the RNN has produced so far, and it uses them to predict the RNN's future behavior and its loss.

The key point here is that the Editor acts as a Deep Learning model that learns about another Deep Learning model, and we have a clear view into its insights: we just need to examine the Q-values (as we'll soon see).

Now, you may think that this contribution is limited because the Editor also affects the RNN's behavior, which is exactly what it aims to learn, but that's not necessarily true. There's another way to play this game: we can train the RNN as usual, but after each epoch activate the Editor and train only the Editor for a few rounds. Now the Editor has no effect on the RNN's training, yet it still learns how to detect important words. This can be a very powerful analytic tool; for example, it can help us detect the phrases that the RNN overfits and help us understand why.

 

Optimization

While the field of Reinforcement Learning (RL) offers many ways to optimize the Editor's performance, our approach is a variant of Monte Carlo Q-learning. Let's start the discussion by taking a closer look at the Editor.

Given a phrase p = (w1, w2, …, wt, …, wn), at each time step t the Editor gets a state h't, evaluates the Q-value of each action, uses these values to choose the action at, and as a result receives a reward r(h't, at). The reward is zero if the Editor chooses not to edit, and it's negative if it edits the word wt; this is the penalty we discussed earlier. To this we add a final reward rp, which is the RNN's loss that the Editor tries to maximize. Putting this all together, if we use a deterministic policy that chooses the action with the highest Q-value, then we can use the following equation to evaluate the true value of being in state h't and taking the action at:
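Q*(h't, at) = r(h't, at) + r(h't+1, at+1) + … + r(h'n, an) + rp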

In words, at state h't we want to maximize the sum of future rewards. That means we need to minimize the number of edited words (as they give us negative rewards) while maximizing the RNN's loss (which is the final reward). It is common to add a discount factor γ to this equation, where 0 ≤ γ ≤ 1, which gives us:
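Q*(h't, at) = r(h't, at) + γ·r(h't+1, at+1) + γ²·r(h't+2, at+2) + … + γ^(n−t)·r(h'n, an) + γ^(n−t)·rp

(here the final reward rp is treated as arriving together with the last time step)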

 A small γ  will result in giving more weight to the rewards we get early on.

One problem with this idea is that the RNN's loss changes during training, so rp is constantly changing. This makes it very hard for the Editor to predict the Q-value of each action based on experience. To resolve this issue, remember that we first run the RNN without editing any of the words (for the attention mechanism). This gives us the RNN's loss for the un-edited phrase, LUE(p). Then we activate the Editor and get the RNN's loss for the edited phrase, LE(p). With these we calculate the final reward as:
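One form with the properties described next (used here as an illustrative example rather than the exact formula) is the normalized loss difference:

rp = (LE(p) − LUE(p)) / (LE(p) + LUE(p))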

This reward can be maximized by increasing LE(p), and it satisfies -1 ≤ rp ≤ 1. If editing the words improves the RNN's performance, so that LE(p) < LUE(p), we get a negative rp. Note that if the Editor didn't edit any of the words, then LUE(p) = LE(p) and thus Q*(h't, at) = 0. To optimize the Editor's performance, we look for the parameters that maximize the expectation of Q*(h't, at).
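To make the optimization concrete, here is a minimal sketch of an Editor loss in the spirit of Monte Carlo Q-learning, reusing the names from the earlier sketches. The reward formulas inside (the normalized loss difference for rp and the c/n edit penalty) and the regression of Q-values onto Monte Carlo returns are illustrative assumptions, not the exact implementation:

```python
def editor_q_loss(q_values, actions, edited_loss, unedited_loss, c=-0.05, gamma=1.0):
    """Monte Carlo Q-learning loss for the Editor (a sketch under our own assumptions).

    q_values: list of (batch, 2) Q-value tensors, one per time step
    actions:  list of (batch,) chosen actions (0 = no_edit, 1 = edit)
    edited_loss, unedited_loss: the RNN's loss with and without the Editor (detached)
    """
    n = len(q_values)

    # Final reward r_p in [-1, 1]: zero when editing changes nothing, positive when
    # editing hurts the RNN. The exact normalization is an assumption on our part.
    r_p = (edited_loss - unedited_loss) / (edited_loss + unedited_loss + 1e-8)

    # Per-step reward: 0 for no_edit, c / n for edit (harsher for shorter phrases).
    rewards = [a.float() * (c / n) for a in actions]
    rewards[-1] = rewards[-1] + r_p            # r_p arrives with the last step

    # Regress the Q-value of each chosen action onto its Monte Carlo return.
    loss, G = 0.0, torch.zeros_like(rewards[-1])
    for t in reversed(range(n)):
        G = rewards[t] + gamma * G             # return from step t onward
        q_taken = q_values[t].gather(1, actions[t].unsqueeze(1)).squeeze(1)
        loss = loss + nn.functional.mse_loss(q_taken, G)
    return loss / n
```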

Finally, we need to talk about the negative reward. Let n be the number of words in the phrase and c some negative constant; the reward for editing a word in phrase p is:
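For example, scaling the constant by the phrase length gives the behavior described next:

r(h't, ‘edit’) = c / n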

As the phrase becomes shorter, the penalty for editing a word becomes more severe. The length of the phrase can be inferred from the un-edited hidden states over which the Editor attends, and we can also feed it, together with LUE(p), as two additional inputs to the Editor.

 

 

Exploration vs Exploitation

Before we end the first part of this blog post, there's one important point we need to understand. Generally, using a deterministic policy during training is problematic: if the Editor always picks the most valuable action in each state, according to its own policy, how will it know how valuable the state/action pairs that weren't chosen, or were never seen, are? To address this issue, it's common to use a stochastic policy during training in order to explore more states and actions. Indeed, we experimented with different variations of policy gradient and ε-greedy Q-learning algorithms, and found that they produce poor results.

To understand why, remember that the Editor uses the state h't to evaluate the Q-values. However, due to the RNN's training and its interaction with the Editor, the phrase embeddings are constantly changing. It follows that h't changes all the time, and so do the Editor's choices; as a result, the search space grows. If we add a stochastic policy on top of that, the Editor's choices will vary even more, creating a huge variance in the gradients it produces, which in turn resulted in poor performance. Eventually, the best results were obtained with a deterministic on-policy Q-learning that always chooses the action with the highest Q-value.

That's it for now. In the next part we'll show some cool things we can do with this setting: we'll see that the Editor can significantly reduce overfitting and give us extremely important insights about the dataset and about how the RNN learns what it learns.

 

