Does BERT encode the length of a sentence in the norm of its [CLS] embedding?
I just want to explore a simple and meaningless question: does BERT encode the length of a sentence in the norm of its [CLS] embedding?
I am interested in comparing the number of tokens per sentence with the norm of the embedding of the [CLS] token. For that, I will use the Hugging Face models together with some simple auxiliary functions in TensorFlow, and the bookcorpus dataset.
!pip install --upgrade datasets transformers
Let’s instantiate the BERT model, its tokenizer and some auxiliary functions to process a large batch of sentences and extract only the [CLS] token.
    # Consume the tokens in batches. Split in Python,
    # we don't need multiprocessing because GPU is the bottleneck.
    for i in range(0, len(sentences), batch_size):
        results.append(_cls({k: v[i:i+batch_size] for k, v in tokens.items()}))

    return tf.concat(results, axis=0)
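The fragment above is only the tail of the helper, so here is a minimal sketch of the same batching pattern in isolation. It uses NumPy instead of TensorFlow and a made-up `fake_cls` function in place of BERT, so that it runs without downloading a model. The names `encode_in_batches` and `fake_cls` are hypothetical and not part of the original code.

```python
import numpy as np


def encode_in_batches(tokens, batch_size, encode_fn):
    """Apply encode_fn to consecutive slices of `tokens` and stack the results.

    tokens: dict mapping names such as "input_ids" to arrays of shape (n, seq_len).
    encode_fn: stand-in for the model call; must return one row per sentence.
    """
    n = len(next(iter(tokens.values())))
    results = []
    for i in range(0, n, batch_size):
        # Slice every entry of the dict with the same window.
        batch = {k: v[i:i + batch_size] for k, v in tokens.items()}
        results.append(encode_fn(batch))
    return np.concatenate(results, axis=0)


def fake_cls(batch):
    # Hypothetical stand-in for BERT: a made-up 4-dimensional vector per sentence.
    ids = batch["input_ids"].astype(float)
    return np.stack(
        [ids.sum(axis=1), ids.mean(axis=1), ids.max(axis=1), ids.min(axis=1)],
        axis=1,
    )


# Toy input: 7 "sentences" of 5 token ids each.
tokens = {
    "input_ids": np.arange(35).reshape(7, 5),
    "attention_mask": np.ones((7, 5), dtype=int),
}
batched = encode_in_batches(tokens, batch_size=3, encode_fn=fake_cls)
```

Since each sentence is encoded independently here, the batched result matches a single call on the whole input. The last batch is simply shorter when the number of sentences is not a multiple of `batch_size`.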
Let’s load 100,000 sentences randomly sampled from bookcorpus, encode each one with BERT, and retrieve the norm of its [CLS] embedding.
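To make the two quantities concrete, here is a toy sketch with hand-written arrays instead of real tokenizer and model output. It assumes the norm in question is the Euclidean (L2) norm, which the post does not state explicitly. The 3-dimensional vectors are made up for readability.

```python
import numpy as np

# Toy attention mask for 3 padded sentences: 1 marks a real token, 0 marks padding.
attention_mask = np.array([
    [1, 1, 1, 1, 1, 0, 0],  # [CLS] + 3 word pieces + [SEP]
    [1, 1, 1, 1, 1, 1, 1],  # [CLS] + 5 word pieces + [SEP]
    [1, 1, 1, 0, 0, 0, 0],  # [CLS] + 1 word piece  + [SEP]
])

# Sentence length in tokens, not counting the two special tokens.
token_length = attention_mask.sum(axis=-1) - 2

# Made-up [CLS] vectors, one row per sentence.
cls_embeddings = np.array([
    [3.0, 4.0, 0.0],
    [1.0, 2.0, 2.0],
    [0.0, 0.0, 5.0],
])

# One L2 norm per sentence (assumption: "norm" means the Euclidean norm).
cls_norm = np.linalg.norm(cls_embeddings, axis=-1)
```

With these inputs, `token_length` is `[3, 5, 1]` and `cls_norm` is `[5.0, 3.0, 5.0]`, so each sentence ends up with one length and one norm to compare.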
from datasets import load_dataset

# Load the book corpus dataset.
bookcorpus = load_dataset('bookcorpus')

# Get 100000 sentences.
sentences = bookcorpus['train'].shuffle()[:100000]['text']

# Get the token length and the CLS token norm
# (subtract the two special tokens, [CLS] and [SEP]).
token_length = tf.reduce_sum(
    tokenizer(sentences, padding=True, return_tensors="tf")['attention_mask'],
    axis=-1
).numpy() - 2

plt.ion()
plt.clf()
plt.scatter(X, Y)
plt.plot(X, Y_pred, color='red')
plt.grid()
plt.xlabel("Sentence length (in tokens)")
plt.ylabel("Norm of the [CLS] embedding")
plt.figtext(.8, .8, f"$R^2$ = {r2_score(Y, Y_pred):0.4f}", size='xx-large')
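[Figure: scatter plot of the norm of the [CLS] embedding against sentence length (in tokens), with the linear fit in red and the $R^2$ value annotated.]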
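The plotting code relies on `X`, `Y` and `Y_pred`, whose definitions are not shown. A minimal sketch of how a fitted line and its $R^2$ could be computed, assuming an ordinary least-squares fit (the post does not say which regression was used). The data below is synthetic and only there to make the sketch runnable; it is not the bookcorpus result. `linear_fit_r2` is a hypothetical helper name.

```python
import numpy as np


def linear_fit_r2(x, y):
    """Fit y = slope * x + intercept by least squares; return predictions and R^2."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    slope, intercept = np.polyfit(x, y, deg=1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)    # residual sum of squares
    ss_tot = np.sum((y - y.mean()) ** 2)  # total sum of squares
    return y_pred, 1.0 - ss_res / ss_tot


# Synthetic data (made up): norms drawn independently of the lengths.
rng = np.random.default_rng(0)
X = rng.integers(1, 60, size=500)
Y = 14.5 + rng.normal(0.0, 0.8, size=500)
Y_pred, r2 = linear_fit_r2(X, Y)

# Sanity check on a perfectly linear relation, where R^2 should be 1.
X_lin = np.arange(10)
Y_lin = 2.0 * X_lin + 1.0
_, r2_lin = linear_fit_r2(X_lin, Y_lin)
```

The value `1 - ss_res / ss_tot` is the same coefficient of determination that `r2_score(Y, Y_pred)` from scikit-learn reports. An $R^2$ close to 0 means the line explains almost none of the variance in the norm.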
In general, the norm of the [CLS] embedding is not a reliable proxy for the length of the encoded sentence. You could argue that if the norm is above 16 or below 13, the encoded sentence is most likely short (below 20 tokens). Nevertheless, in the great bulk of vectors with norms between 13 and 16, the sentence length varies greatly, and there is no reliable way to extract it from the norm of the [CLS] embedding.
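The claim about the tails could be checked numerically rather than by eye. A small sketch, assuming arrays of norms and token lengths are available; `tail_short_fraction` is a hypothetical helper, and the six values below are made up purely to exercise it.

```python
import numpy as np


def tail_short_fraction(norms, lengths, low=13.0, high=16.0, max_len=20):
    """Fraction of sentences shorter than max_len among those whose norm
    falls outside the [low, high] band."""
    norms = np.asarray(norms, dtype=float)
    lengths = np.asarray(lengths)
    in_tail = (norms < low) | (norms > high)
    if not in_tail.any():
        return float("nan")  # no sentence falls in the tails
    return float(np.mean(lengths[in_tail] < max_len))


# Made-up values, not measurements.
norms = np.array([12.5, 16.4, 14.0, 15.2, 12.9, 16.1])
lengths = np.array([5, 8, 35, 12, 25, 3])
fraction = tail_short_fraction(norms, lengths)
```

A fraction close to 1 on the real data would support the "most likely short" reading of the tails. It would say nothing about sentences with norms between 13 and 16.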