Does BERT encode the length of a sentence in the norm of its [CLS] embedding?

I just want to explore a simple and admittedly meaningless question: Does BERT encode the length of a sentence in the norm of its [CLS] embedding?

I am interested in comparing the number of tokens per sentence with the norm of the embedding of the [CLS] token. For that, I will use the Hugging Face models together with some simple auxiliary functions in TensorFlow, and the BookCorpus dataset.

!pip install --upgrade datasets transformers

Let’s instantiate the BERT model, its tokenizer and some auxiliary functions to process a large batch of sentences and extract only the [CLS] token.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

# Load BERT and the tokenizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFAutoModel.from_pretrained("bert-base-uncased")

@tf.function(experimental_relax_shapes=True)
def _cls(x):
    # Return only the CLS token.

    # Get the max seq_length for this batch.
    max_seq_length = tf.reduce_max(
        tf.reduce_sum(x['attention_mask'], axis=1)
    )

    x_ = {}
    for k, v in x.items():
        x_[k] = v[:,:max_seq_length]

    return model(**x_).last_hidden_state[:, 0]

# Compute the CLS token in batches.
def cls(sentences, batch_size=8):

    tokens = tokenizer(sentences, padding=True, return_tensors="tf")

    results = []

    # Consume the tokens in batches. Split in Python,
    # we don't need multiprocessing because GPU is the bottleneck.
    for i in range(0, len(sentences), batch_size):
        results.append(_cls({k: v[i:i+batch_size] for k, v in tokens.items()}))

    return tf.concat(results, axis=0)
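The trimming step in `_cls` works because every row of a batch only needs as many columns as its longest real sequence. The minimal NumPy sketch below shows that slicing on a toy batch and then picks position 0 of a stand-in hidden-state tensor, which is where the [CLS] vector sits. The token ids 7, 8 and 9 and the random "hidden states" are made up for illustration, and the slicing assumes right padding (the default for the BERT tokenizer).

```python
import numpy as np

def trim_to_longest(batch):
    """Drop the padding columns that every row of the batch has.

    Mirrors the slicing in `_cls`: the longest sequence in the batch is the
    maximum row sum of the attention mask (assumes right padding).
    """
    max_len = int(batch["attention_mask"].sum(axis=1).max())
    return {k: v[:, :max_len] for k, v in batch.items()}

# Toy batch of 2 sentences padded to length 6.
# 101 = [CLS], 102 = [SEP], 0 = [PAD]; ids 7, 8, 9 are made-up word tokens.
batch = {
    "input_ids": np.array([[101, 7, 8, 102, 0, 0],
                           [101, 9, 102, 0, 0, 0]]),
    "attention_mask": np.array([[1, 1, 1, 1, 0, 0],
                                [1, 1, 1, 0, 0, 0]]),
}
trimmed = trim_to_longest(batch)
print(trimmed["input_ids"].shape)  # (2, 4)

# Random stand-in for last_hidden_state with shape (batch, seq_len, hidden).
hidden = np.random.default_rng(0).normal(size=(2, 4, 768))
cls_vectors = hidden[:, 0]  # position 0 holds the [CLS] embedding
print(cls_vectors.shape)  # (2, 768)
```

Trimming per batch rather than once for the whole corpus keeps batches of short sentences from paying for the longest sentence in the dataset.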

Let's load 100,000 sentences randomly sampled from BookCorpus, encode each one with BERT, and retrieve the norm of its [CLS] embedding.

from datasets import load_dataset

# Load the book corpus dataset.
bookcorpus = load_dataset('bookcorpus')

# Get 100000 sentences.
sentences = bookcorpus['train'].shuffle()[:100000]['text']

# Get the token length (subtracting the two special tokens, [CLS] and [SEP]) and the norm of the [CLS] embedding.
token_length = tf.reduce_sum(
    tokenizer(sentences, padding=True, return_tensors="tf")['attention_mask'],
    axis=-1
).numpy() - 2

cls_norm = tf.norm(cls(sentences), axis=-1).numpy()
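Both quantities are simple row-wise reductions. Here is a toy NumPy version, assuming each tokenized sentence contains exactly one [CLS] and one [SEP] and nothing was truncated. The attention mask and the 2-dimensional "[CLS] vectors" are invented values, not BERT outputs. `tf.norm(..., axis=-1)` uses the Euclidean norm by default, which is what `np.linalg.norm` computes here.

```python
import numpy as np

# Toy attention mask for 3 sentences after padding=True (1 = real token).
attention_mask = np.array([
    [1, 1, 1, 1, 1, 0],  # [CLS] w1 w2 w3 [SEP] [PAD]
    [1, 1, 1, 0, 0, 0],  # [CLS] w1 [SEP] [PAD] [PAD] [PAD]
    [1, 1, 1, 1, 1, 1],  # [CLS] w1 w2 w3 w4 [SEP]
])

# Real tokens per row, minus the two special tokens [CLS] and [SEP].
token_length = attention_mask.sum(axis=-1) - 2
print(token_length)  # [3 1 4]

# Euclidean (L2) norm of each toy [CLS] vector.
cls = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]])
cls_norm = np.linalg.norm(cls, axis=-1)
print(cls_norm)  # [5. 1. 0.]
```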

Let's use a scatter plot to check whether there is any connection between the two variables. I will add a linear regression on top of it as well.

import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

X = token_length[token_length < 80].reshape(-1, 1)
Y = cls_norm[token_length < 80].reshape(-1, 1)

linear_regressor = LinearRegression()
linear_regressor.fit(X, Y)

Y_pred = linear_regressor.predict(X)

plt.ion()
plt.clf()
plt.scatter(X, Y)
plt.plot(X, Y_pred, color='red')
plt.grid()
plt.xlabel("Sentence length (in tokens)")
plt.ylabel("Norm of the [CLS] embedding")
plt.figtext(.8, .8, f"$R^2$ = {r2_score(Y, Y_pred):0.4f}", size='xx-large')

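[Figure: scatter plot of sentence length (in tokens) against the norm of the [CLS] embedding, with the fitted linear regression line in red and its $R^2$ value.]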
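The $R^2$ printed on the plot is the coefficient of determination, $R^2 = 1 - SS_{res}/SS_{tot}$, which is what `r2_score` computes. The sketch below computes it by hand with NumPy on synthetic data (a weak linear trend plus noise, not the BookCorpus measurements) and checks a useful fact: for a least-squares line with one feature and an intercept, $R^2$ equals the squared Pearson correlation, so a low $R^2$ here means the two variables are only weakly linearly correlated.

```python
import numpy as np

def r_squared(y, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y, y_pred = np.ravel(y), np.ravel(y_pred)
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Synthetic stand-in data, NOT the bookcorpus measurements.
rng = np.random.default_rng(0)
x = rng.integers(1, 80, size=500).astype(float)         # "sentence length"
y = 14.5 + 0.01 * x + rng.normal(scale=1.0, size=500)   # "[CLS] norm"

# Ordinary least-squares line; polyfit returns [slope, intercept] for deg=1.
slope, intercept = np.polyfit(x, y, deg=1)
y_pred = slope * x + intercept

r2 = r_squared(y, y_pred)
# With one feature and an intercept, R^2 equals the squared Pearson r.
print(r2, np.corrcoef(x, y)[0, 1] ** 2)
```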

In general, the norm of the [CLS] embedding is not a reliable proxy for the length of the encoded sentence. One could argue that if the norm of the [CLS] embedding is above 16 or below 13, the encoded sentence is most likely short (fewer than 20 tokens). However, for the bulk of the vectors, whose norms lie between 13 and 16, the sentence length varies widely, and there is no reliable way to recover it from the norm of the [CLS] embedding.
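One way to put numbers on this reading of the plot is to split the sentences into the three norm bands mentioned above and measure, within each band, the fraction of sentences shorter than 20 tokens. `short_fraction_by_band` is a hypothetical helper written only for this illustration; the toy arrays below are invented and are not results. Applied to the `token_length` and `cls_norm` arrays computed earlier, it would show how much more informative the outer bands are than the middle one.

```python
import numpy as np

def short_fraction_by_band(token_length, cls_norm, low=13.0, high=16.0, short=20):
    """Hypothetical helper: per norm band, return (count, fraction of
    sentences with fewer than `short` tokens)."""
    token_length = np.asarray(token_length)
    cls_norm = np.asarray(cls_norm)
    bands = {
        f"norm < {low}": cls_norm < low,
        f"{low} <= norm <= {high}": (cls_norm >= low) & (cls_norm <= high),
        f"norm > {high}": cls_norm > high,
    }
    result = {}
    for name, mask in bands.items():
        n = int(mask.sum())
        # Mean of a boolean array = fraction of True values; NaN for empty bands.
        frac = float((token_length[mask] < short).mean()) if n else float("nan")
        result[name] = (n, frac)
    return result

# Toy values for illustration only.
token_length = np.array([5, 12, 40, 8, 30, 15, 3])
cls_norm = np.array([12.1, 12.8, 14.0, 14.5, 15.2, 16.4, 17.0])
for band, (n, frac) in short_fraction_by_band(token_length, cls_norm).items():
    print(f"{band}: n={n}, short fraction={frac:.2f}")
```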
