Sentence Transformers for Sentence Similarity
In this article, we will take a look at the history leading up to the creation of Sentence Transformers, the shortcomings of past architectures across various Natural Language Processing (NLP) tasks (mainly sentence similarity) and how Sentence Transformers tackle these problems.
Introduction
data:image/s3,"s3://crabby-images/9553e/9553ea11d0810d1293d7906c945dde7886f3651b" alt="sentence-transformers-history"
Recurrent Networks
data:image/s3,"s3://crabby-images/61e7d/61e7d1ab3a4dfafc890f9cc37f2d127fc93f8257" alt="recurrent-networks"
Recurrent Neural Networks (RNNs) are versatile, but for language problems they fall short in two important ways.
Disadvantages:
- Slow to train and slow at inference
    - The input words are processed one at a time, sequentially, so longer sentences simply take longer.
- Do not truly understand context
    - RNNs only learn about a word based on the words that came before it. In reality, the context of a word depends on the sentence as a whole.
    - Bidirectional Long Short-Term Memory networks (Bi-LSTMs) try to address this, but even here the left-to-right and right-to-left contexts are learned separately and then concatenated, so some of the true context is lost.
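The toy sketch below makes both limitations concrete. It is a minimal Elman-style RNN with a scalar hidden state and made-up weights (`w_x`, `w_h` and `b` are illustrative assumptions, not values from any real model). Each hidden state needs the previous one, so the loop cannot run in parallel across words. The state at position t also depends only on the words up to t.

```python
import math
from typing import List


def rnn_forward(inputs: List[float], w_x: float, w_h: float, b: float) -> List[float]:
    """Run a toy scalar RNN over a sequence of (scalar) word inputs.

    h_t = tanh(w_x * x_t + w_h * h_{t-1} + b)

    Step t cannot begin until step t-1 has finished, and h_t only ever
    sees inputs[0..t], i.e. the words to its left.
    """
    h = 0.0  # initial hidden state
    states = []
    for x in inputs:  # strictly one word at a time, left to right
        h = math.tanh(w_x * x + w_h * h + b)
        states.append(h)
    return states
```

If you change the last word, every earlier state stays the same. The model's picture of the first word never benefits from what comes after it.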
Transformer Networks
data:image/s3,"s3://crabby-images/680aa/680aa419dd36c150a0b73a00fcfdbad4044b23f0" alt="transformer-networks"
data:image/s3,"s3://crabby-images/bd59f/bd59fb05e283a3eb271e15d0d7006a49ceecf4f5" alt="transformer-networks-2"
For English-to-French translation, we pass the entire English sentence into the encoder at once and get all of the corresponding word vectors simultaneously. These word vectors encode the meaning of each word. They are better than the vectors an RNN produces because attention lets them capture bidirectional context.
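The sketch below shows why attention captures context from both directions in a single step. It is a minimal scaled dot-product self-attention. To keep it short, it assumes identity query/key/value projections and a single head. Real Transformers use learned projection matrices and multiple heads. Each output vector is a softmax-weighted average over *all* input vectors, so each word's representation depends on the words to its left and to its right.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(vectors: List[List[float]]) -> List[List[float]]:
    """Single-head scaled dot-product self-attention (Q = K = V = inputs).

    For every position, score it against every position (including those
    to its right), normalise the scores with softmax, and take the weighted
    average of all input vectors.
    """
    d = len(vectors[0])
    outputs = []
    for q in vectors:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in vectors]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, vectors)) for j in range(d)])
    return outputs
```

Unlike the RNN loop, the outer loop here has no dependency between iterations, so all positions can be computed in parallel.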
We then pass these vectors into the decoder, along with the French words generated so far, to generate the next French word. We keep feeding the generated French words back into the decoder until it produces the end-of-sentence token.
Transformers work well for sequence-to-sequence problems. For specific natural language problems like question answering and text summarization, however, even Transformers have drawbacks that stem from one fact: language is complicated.
Disadvantages:
- Need a lot of data
- Architecture may not be complex enough
    - Transformers may not be complex enough to learn the patterns needed to solve these language problems. After all, Transformers weren’t designed to be language models, so the word representations they generate can still be improved.
BERT Networks
BERT was introduced to extend the capabilities of the Transformer. It was built on the idea that different Natural Language Processing (NLP) problems all rely on the same fundamental understanding of language.
data:image/s3,"s3://crabby-images/99424/994246b0e4039adf79f1c960f75b7b80069ab491" alt="bert-networks"
Phases
BERT undergoes two phases of training:
- Pre-Training: Understand language
- Fine-Tuning: Understand language-specific tasks
Advantages over Transformers
- Needing a lot of data → Fine-tuning does not require obscene amounts of data.
- Architecture may not be complex enough → BERT is a stack of Transformer encoders, hence its name, Bidirectional Encoder Representations from Transformers:
    - Bidirectional: BERT understands the context of each word by looking in both directions via attention.
    - Encoder & Transformers: BERT is essentially a stack of the encoder part of the Transformer.
    - Representations: Because BERT is pre-trained as a language model, it learns better word representations. This means the output word vectors from BERT better encapsulate the meaning of the words in a sentence.
data:image/s3,"s3://crabby-images/c63ab/c63abd17fe8db66cca29e2a84a752c89a2c4152e" alt="nlp-tasks"
The big takeaway here is that BERT can now solve a host of complex, language-specific problems, with one notable exception.
Imagine you’re a Data Scientist at Quora, a question-and-answer site, and you want to design a system that finds questions related to the one currently being asked. How would we solve this with BERT?
<figure markdown="span">
data:image/s3,"s3://crabby-images/58f67/58f6739838c4f451f463fbe4253a3ca11d8b94c2" alt="quora"{ width=500 }
<figcaption>You're a Data Scientist trying to design a system that finds questions related to the one currently being asked</figcaption>
</figure>
Goal: Determine questions similar to the one being asked.
data:image/s3,"s3://crabby-images/72d3a/72d3a61eb022f96655b81ca04b83b60771bebfef" alt="quora"
Steps:
- First, take the question being asked and a question that was asked in the past, and pass both of them into BERT together
- BERT generates word vectors
- Pipe these word vectors into a feed-forward layer whose output is a single neuron corresponding to the similarity score
- Repeat the steps for every question on the platform to compute the pairwise similarity
- Select the highest similarity scores and the corresponding questions will be the most similar and relevant to the question that is being asked
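The steps above can be sketched as a ranking loop. `score_pair` is a hypothetical stand-in for "BERT plus a feed-forward head". `toy_score` is a word-overlap (Jaccard) scorer used only so the example runs. It is not a neural model. The point to notice is that the pair must be scored jointly, so every comparison costs one full forward pass and nothing can be precomputed.

```python
from typing import Callable, List, Tuple


def rank_with_cross_encoder(
    new_question: str,
    past_questions: List[str],
    score_pair: Callable[[str, str], float],
    top_k: int = 3,
) -> Tuple[List[Tuple[str, float]], int]:
    """Score (new_question, q) for every past question q and keep the best.

    Returns the top_k (question, score) pairs and the number of model calls.
    """
    calls = 0
    scored = []
    for q in past_questions:
        scored.append((q, score_pair(new_question, q)))  # one forward pass per pair
        calls += 1
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k], calls


def toy_score(a: str, b: str) -> float:
    """Hypothetical placeholder scorer: Jaccard overlap of lowercase words."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    union = wa | wb
    return len(wa & wb) / len(union) if union else 0.0
```

The number of calls grows linearly with the number of stored questions, and the whole loop reruns for every new question.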
However, there is a big issue here. If there are 100 million questions on the platform, we’d have to run the forward pass of BERT 100 million times every single time a new question comes in. This is not viable!
So the next question to ask is: how do we make BERT work for this goal?
Sentence Transformers
Pass 1: High-Level Idea
data:image/s3,"s3://crabby-images/192b7/192b71ac2a5b4d093162f21cfb4eaa3847a0db37" alt="quora"
Steps:
- Pass the new question into BERT to get a single vector that represents its meaning.
- Compare this vector to the vectors of all other questions using a similarity metric (e.g. cosine similarity).
- Return the nearest neighbours as the most related questions to the new question.
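The comparison step can be sketched as cosine similarity plus an exact (brute-force) top-k search. The sketch assumes the question vectors have already been produced by the encoder. The 2-dimensional vectors in the usage example are made up purely for illustration.

```python
import heapq
import math
from typing import Dict, List, Tuple


def cosine_similarity(u: List[float], v: List[float]) -> float:
    """cos(u, v) = (u . v) / (|u| |v|); assumes neither vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nearest_neighbours(
    query_vec: List[float], stored: Dict[str, List[float]], k: int = 2
) -> List[Tuple[str, float]]:
    """Return the k stored items with the highest cosine similarity to the query."""
    scores = ((item_id, cosine_similarity(query_vec, vec)) for item_id, vec in stored.items())
    return heapq.nlargest(k, scores, key=lambda item: item[1])
```

For example, with `stored = {"q1": [1.0, 0.0], "q2": [0.0, 1.0], "q3": [0.9, 0.1]}` and the query `[1.0, 0.05]`, the two nearest neighbours are `q1` and `q3`. Only the query needs a forward pass. The stored vectors are reused.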
Therefore, each new question needs only a single forward pass of the BERT model, not the 100 million mentioned before. This is a big win because computing simple similarity metrics between vectors is much cheaper than passing every question on the platform through a complex model each time a decision is needed.
Pass 2: Sentence Transformers
In the first pass, a new question is passed into BERT to get a single vector that represents the question. However, BERT only gives us word vectors. To get a single vector, we need to aggregate these word vectors by passing them through a pooling layer.
The most straightforward way of doing this is to take the average of these vectors. This is known as mean pooling. Another way is to take the maximum value across every dimension of the embedding. This is known as max pooling.
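Mean pooling and max pooling can be sketched in a few lines. One assumption here is not in the text above. Batched inputs are usually padded to a common length, so the sketch takes an attention mask (1 for real tokens, 0 for padding) and ignores padded positions.

```python
from typing import List


def _kept(token_vecs: List[List[float]], mask: List[int]) -> List[List[float]]:
    """Keep only real tokens (mask == 1), dropping padding positions."""
    kept = [vec for vec, m in zip(token_vecs, mask) if m == 1]
    if not kept:
        raise ValueError("mask must keep at least one token")
    return kept


def mean_pooling(token_vecs: List[List[float]], mask: List[int]) -> List[float]:
    """Average the token vectors dimension by dimension."""
    kept = _kept(token_vecs, mask)
    return [sum(col) / len(kept) for col in zip(*kept)]


def max_pooling(token_vecs: List[List[float]], mask: List[int]) -> List[float]:
    """Take the maximum value in each embedding dimension."""
    kept = _kept(token_vecs, mask)
    return [max(col) for col in zip(*kept)]
```

With token vectors `[[1.0, 4.0], [3.0, 2.0], [100.0, 100.0]]` and mask `[1, 1, 0]`, mean pooling gives `[2.0, 3.0]` and max pooling gives `[3.0, 4.0]`. The padded third vector is ignored.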
data:image/s3,"s3://crabby-images/16d60/16d604ea34346f1a8210362ac168ea80f9cacdc6" alt="quora"
The diagram above shows the simplest form of a Sentence Transformer, but the output vector it generates is of extremely poor quality. It is so poor that you might be better off simply averaging GloVe embeddings, without using BERT at all.
data:image/s3,"s3://crabby-images/e41fe/e41fe4b80841fecdeb03c29aaf59cf120e1c197f" alt="quora"
How to get sentence vectors with meaning?
For BERT to create sentence vectors that actually carry meaning, we need to train it further (fine-tune it) on sentence-level tasks, which are covered in Pass 3 below.
Once we fine-tune BERT on one or all of these tasks, the generated sentence vector becomes a good representation of the sentence. In other words, it encodes the meaning of the sentence very well.
This is important because it means that the closer two vectors are in terms of distance, the more similar their meanings are.
Info
In our Quora setting, we would pass every existing question through the Sentence Transformer once and store the resulting vectors for future use. When a new question comes in, we pass only that question through the Sentence Transformer to get its sentence vector. We then find the stored questions with the highest cosine similarity and surface them as related questions. Nearest neighbours can be found with techniques such as:
- ANNOY (Approximate Nearest Neighbours)
- k-NN search in Elasticsearch
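Many nearest-neighbour tools rank by Euclidean distance rather than cosine similarity. A short derivation shows that, once the sentence vectors are normalised to unit length, the two give the same ranking. Normalising the vectors is an assumption about how they are stored, not something stated above.

\[
\lVert u - v \rVert^{2} = \lVert u \rVert^{2} + \lVert v \rVert^{2} - 2\,u^{\top}v = 2 - 2\cos(u, v) \qquad \text{if } \lVert u \rVert = \lVert v \rVert = 1
\]

So \(\lVert u - v \rVert = \sqrt{2\,(1 - \cos(u, v))}\), which shrinks as cosine similarity grows. The nearest vector by Euclidean distance is therefore also the one with the highest cosine similarity. Annoy's documentation describes its `angular` metric in exactly this form, as \(\sqrt{2(1 - \cos(u, v))}\).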
Pass 3: Sentence Transformers Training
BERT is good at word representations, but we want a Sentence Transformer that is good at sentence representations. To get one, we fine-tune BERT on any or all of three sentence-related tasks:
- Natural Language Inference (NLI)
- Semantic Textual Similarity (STS)
- Triplet Dataset
Natural Language Inference (NLI)
data:image/s3,"s3://crabby-images/df495/df4951746c3a4399763f1f3ef9dfb94f58a79c81" alt="quora"
NLI is a task that takes in two sentences and determines whether sentence 1 entails sentence 2, contradicts it, or neither (neutral). See some examples below:
Examples
Entailment
- Sentence 1: “Say hello to me!”
- Sentence 2: “Greet me!”
Neutral
- Sentence 1: “Say hello to me!”
- Sentence 2: “Two people greeting and playing together.”
Contradiction
- Sentence 1: “Say hello to me!”
- Sentence 2: “You’re ignoring me!”
This allows BERT to understand the meaning of sentences as a whole. NLI training uses a Siamese network. “Siamese” means twins: two identical Sentence Transformer networks that share the same weights, connected in this fashion.
data:image/s3,"s3://crabby-images/a53b2/a53b21f60c2bd0002a9095d6870565b8fef5b954" alt="quora"
To compare two sentences, we pass each one through one of the twin BERT networks to get word representations. The word vectors for each sentence are mean-pooled into a sentence vector. We then concatenate the two sentence vectors with their element-wise absolute difference. The result feeds a softmax classifier that predicts one of three classes: entailment, contradiction or neutral.
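The classification head described above can be sketched as follows. The sentence vectors `u` and `v` are assumed to come out of the twin encoders already pooled. The weights and bias are arbitrary placeholders here; in training they are learned together with BERT. The label order is an illustrative assumption.

```python
import math
from typing import List

LABELS = ["entailment", "neutral", "contradiction"]  # assumed label order


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def nli_features(u: List[float], v: List[float]) -> List[float]:
    """Concatenate (u, v, |u - v|) into one vector of length 3n."""
    return u + v + [abs(a - b) for a, b in zip(u, v)]


def nli_head(u: List[float], v: List[float],
             weights: List[List[float]], bias: List[float]) -> List[float]:
    """Linear layer (one row of length 3n per class) followed by softmax."""
    features = nli_features(u, v)
    logits = [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weights, bias)]
    return softmax(logits)


def cross_entropy(probs: List[float], gold_index: int) -> float:
    """Training loss for one example: -log p(gold label)."""
    return -math.log(probs[gold_index])
```

During training, the cross-entropy loss is backpropagated through the head and both (shared) encoders. At inference time the head is discarded.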
data:image/s3,"s3://crabby-images/b29c4/b29c4a9f24c54c8ee70ec923a5bfb66de51c218a" alt="quora"
Note
The mean pooling and concatenation may look arbitrary, but they were chosen because they gave better results than the other strategies tried, such as max pooling or using only the absolute difference between the vectors.
At inference time, we only need a single Sentence Transformer: we pass in a question and get back the corresponding sentence vector. This vector is the sentence representation that encodes the meaning of the sentence (very well, hopefully).
Semantic Textual Similarity (STS)
STS is another task we can use to fine-tune BERT to understand sentences: given two sentences, output a score for how similar they are.
data:image/s3,"s3://crabby-images/218a7/218a7a851503660c323e80c5855c595ef74b342b" alt="quora"
Just like NLI, this is also trained with a Siamese network. During training, we pass the two sentences through the twin Sentence Transformers to get their sentence vectors. We then compute the cosine similarity between these vectors to get a value between \(-1\) and \(1\). This value is compared to an actual labelled similarity rating on a scale of \(1\) to \(5\), which is normalized to be comparable to the output score. Minimizing the squared difference between the two trains the model.
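The STS objective can be sketched as a mean squared error between the cosine similarity and the rescaled gold rating. The exact normalization is an assumption. The sketch linearly maps the 1 to 5 rating onto \([0, 1]\) with \((s - 1)/4\). The sentence-transformers library's `CosineSimilarityLoss` follows the same idea: by default it applies mean squared error between the cosine score and a label.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(u: List[float], v: List[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def normalise_label(score: float, low: float = 1.0, high: float = 5.0) -> float:
    """Assumed scheme: linearly rescale a gold rating from [low, high] to [0, 1]."""
    return (score - low) / (high - low)


def sts_mse_loss(pairs: Sequence[Tuple[List[float], List[float]]],
                 gold_scores: Sequence[float]) -> float:
    """Mean squared error between cos(u, v) and the rescaled gold rating."""
    errors = [
        (cosine_similarity(u, v) - normalise_label(g)) ** 2
        for (u, v), g in zip(pairs, gold_scores)
    ]
    return sum(errors) / len(errors)
```

A pair of identical vectors rated 5 and a pair of orthogonal vectors rated 1 give a loss of 0. Swapping those two ratings gives a loss of 1.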
data:image/s3,"s3://crabby-images/17401/174016a938ed61ec05876e40d92692d1e580e171" alt="quora"
Triplet Dataset
A third way to train Sentence Transformers is with a dataset of sentence triplets. The main sentence is called the “anchor”, the second is a sentence “related” to the anchor, and the third is a sentence “unrelated” to the anchor.
data:image/s3,"s3://crabby-images/44e95/44e951195562ab156864df7b2138cb571010c253" alt="quora"
We can quickly build this type of dataset from a Wikipedia page. Choose a sentence as the “anchor”, take the next sentence in the same paragraph as the “related” sentence, and pick a sentence from another paragraph as the “unrelated” sentence. See the screenshot below for an example.
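The recipe above can be sketched as a small function. It assumes the page has already been fetched and split into paragraphs, each given as a list of sentences. Fetching and sentence splitting are out of scope, and the function name and `seed` parameter are hypothetical.

```python
import random
from typing import List, Tuple


def make_triplets(paragraphs: List[List[str]], seed: int = 0) -> List[Tuple[str, str, str]]:
    """Build (anchor, related, unrelated) triplets from one article.

    anchor    : a sentence in some paragraph
    related   : the sentence immediately after it in the same paragraph
    unrelated : a randomly chosen sentence from a different paragraph
    """
    rng = random.Random(seed)  # seeded so the output is reproducible
    triplets = []
    for p_idx, paragraph in enumerate(paragraphs):
        others = [s for j, p in enumerate(paragraphs) if j != p_idx for s in p]
        if not others:
            continue  # a negative needs at least one other paragraph
        for i in range(len(paragraph) - 1):
            anchor, related = paragraph[i], paragraph[i + 1]
            triplets.append((anchor, related, rng.choice(others)))
    return triplets
```

A paragraph with n sentences yields n - 1 triplets. An article with a single paragraph yields none, because there is nowhere to draw an unrelated sentence from.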
data:image/s3,"s3://crabby-images/3be74/3be74fcb40793519da4f32c3ae167c12158892d0" alt="quora"
The network is a triplet (rather than a Siamese pair, or twins) of the exact same Sentence Transformer architecture. During training, we pass each sentence through a Sentence Transformer to get three sentence vectors: \(S_{a}\), \(S_{+}\) and \(S_{-}\).
We want the distance between the anchor and the related sentence to be small and the distance between the anchor and the unrelated sentence to be large. This pulls sentences with similar meanings together in the embedding space and pushes unrelated ones apart.
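One common way to express this goal is the triplet margin loss \(\max(\lVert S_{a} - S_{+} \rVert - \lVert S_{a} - S_{-} \rVert + \epsilon, 0)\), which uses Euclidean distance. The Sentence-BERT paper uses this form with a margin of \(\epsilon = 1\). The sketch below is a minimal version for a single triplet.

```python
import math
from typing import List


def euclidean(u: List[float], v: List[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(s_a: List[float], s_pos: List[float], s_neg: List[float],
                 margin: float = 1.0) -> float:
    """max(||s_a - s_pos|| - ||s_a - s_neg|| + margin, 0).

    The loss is zero once the unrelated sentence is at least `margin`
    farther from the anchor than the related sentence is.
    """
    return max(euclidean(s_a, s_pos) - euclidean(s_a, s_neg) + margin, 0.0)
```

When the related sentence sits on the anchor and the unrelated one is 5 units away, the loss is 0. Swapping them gives 5 - 0 + 1 = 6.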
data:image/s3,"s3://crabby-images/caa2f/caa2f1494c377014136bf831f46da28b2c4aeb9c" alt="quora"
Conclusion
Regardless of which task is chosen for training the Sentence Transformer, at inference time we should be able to pass in sentences and generate sentence representation vectors that encode their meaning very well.
Pass 4: Sentence Transformers Inference
Going back to our Data Science job at Quora, how do we recommend similar questions? Before any new questions are asked, we pass every existing question through the fine-tuned Sentence Transformer to get the corresponding sentence vectors. These vectors are good sentence representations (assuming fine-tuning went well). They all live in what is known as the embedding space, as seen earlier.
When a new question comes in, we pass it through our Sentence Transformer to get its sentence representation, or sentence embedding. We then compute the cosine similarity between the new question and every candidate question. Finally, we return the closest questions as the related questions.
For small datasets, we can compute the cosine similarity between a new question and every other question. This becomes increasingly expensive when there are hundreds of millions of questions, which is very common on a platform like Quora.
data:image/s3,"s3://crabby-images/edf32/edf32722c2c87f593d4f141d914c11cf0e2e7361" alt="quora"
To solve this issue, there are a couple of algorithms we can use. Spotify uses an Approximate Nearest Neighbours algorithm called ANNOY to recommend music to you. In this case, songs are embedded into vectors.
Another way to quickly compute the nearest neighbours is through AWS, which offers an efficient implementation of the k-Nearest Neighbours (k-NN) algorithm.
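To give a feel for how approximate search avoids comparing against every vector, here is a toy sketch of random-hyperplane locality-sensitive hashing. This is a swapped-in, simpler technique: it is *not* Annoy's algorithm (Annoy builds a forest of random-projection trees) and not the AWS implementation. The class and its parameters (`n_planes`, `seed`) are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


class RandomHyperplaneIndex:
    """Toy approximate nearest-neighbour index (random-hyperplane LSH).

    Each vector is hashed to a bit string. Bit i records which side of
    random hyperplane i (through the origin) the vector lies on. Vectors
    separated by a small angle tend to share a bucket, so a query is only
    compared against its own bucket instead of the whole collection.
    """

    def __init__(self, dim: int, n_planes: int = 8, seed: int = 0):
        rng = random.Random(seed)
        self.planes = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_planes)]
        self.buckets: Dict[str, List[str]] = defaultdict(list)

    def _hash(self, vec: Sequence[float]) -> str:
        return "".join(
            "1" if sum(p * x for p, x in zip(plane, vec)) >= 0 else "0"
            for plane in self.planes
        )

    def add(self, item_id: str, vec: Sequence[float]) -> None:
        self.buckets[self._hash(vec)].append(item_id)

    def candidates(self, vec: Sequence[float]) -> List[str]:
        """Items sharing the query's bucket; re-rank them with exact cosine similarity."""
        return list(self.buckets.get(self._hash(vec), []))
```

The search is approximate: a true neighbour can land in a different bucket. Practical LSH setups therefore use several independent hash tables to improve recall and then re-rank the candidate set exactly.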
Summary
Recurrent Neural Networks
Advantages:
- Able to deal with Sequence-to-Sequence problems
Disadvantages:
- Slow to train and slow at inference
- Do not truly understand context
Transformers
Advantages:
- Replace recurrent units with attention units, addressing past concerns
- Solve sequence-to-sequence problems
Disadvantages:
- Not necessarily complex enough to understand language
BERT
Advantages:
- Stack of Transformer encoders
- Complex enough to solve a host of NLP problems
Disadvantages:
- Not good with sentence similarity tasks
Sentence BERT
Advantages:
- Fine-tunes BERT on sentence-level tasks (NLI, STS, triplets), addressing past concerns