Universal sentence encoder for English trained with CMLM (sent_bert_use_cmlm_en_base)


Universal sentence encoder for English trained with a conditional masked language model. The universal sentence encoder family of models maps text into high-dimensional vectors that capture sentence-level semantics. Our English-base (en-base) model is trained using the conditional masked language model described in [1]. The model is intended to be used for text classification, text clustering, semantic textual similarity, etc. It can also be used as modularized input for multimodal tasks with text as a feature. The base model employs a 12-layer BERT transformer architecture.

The model is built on the BERT transformer architecture, which is why it is used through the BertSentenceEmbeddings annotator.

[1] Ziyi Yang, Yinfei Yang, Daniel Cer, Jax Law, and Eric Darve. Universal Sentence Representations Learning with Conditional Masked Language Model. November 2020.


How to use

Python:

from sparknlp.annotator import BertSentenceEmbeddings

embeddings = BertSentenceEmbeddings.pretrained("sent_bert_use_cmlm_en_base", "en") \
      .setInputCols("sentence") \
      .setOutputCol("sentence_embeddings")

Scala:

import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings

val embeddings = BertSentenceEmbeddings.pretrained("sent_bert_use_cmlm_en_base", "en")
  .setInputCols("sentence")
  .setOutputCol("sentence_embeddings")
NLU:

import nlu

text = ["I hate cancer", "Antibiotics aren't painkillers"]
embeddings_df = nlu.load('en.embed_sentence.sent_bert_use_cmlm_en_base').predict(text, output_level='sentence')
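
For context, below is a minimal end-to-end Spark NLP pipeline sketch in Python. It wires the embeddings stage behind DocumentAssembler and SentenceDetector stages and then computes a cosine similarity between the two resulting sentence vectors; the cosine helper is our own illustration, not part of the Spark NLP API:

import numpy as np
import sparknlp
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import SentenceDetector, BertSentenceEmbeddings
from pyspark.ml import Pipeline

spark = sparknlp.start()

document_assembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

sentence_detector = SentenceDetector() \
    .setInputCols(["document"]) \
    .setOutputCol("sentence")

embeddings = BertSentenceEmbeddings.pretrained("sent_bert_use_cmlm_en_base", "en") \
    .setInputCols(["sentence"]) \
    .setOutputCol("sentence_embeddings")

pipeline = Pipeline(stages=[document_assembler, sentence_detector, embeddings])

data = spark.createDataFrame(
    [["I hate cancer"], ["Antibiotics aren't painkillers"]]
).toDF("text")
result = pipeline.fit(data).transform(data)

# One embedding vector per detected sentence.
vectors = [
    row["embedding"]
    for row in result.selectExpr(
        "explode(sentence_embeddings.embeddings) as embedding"
    ).collect()
]

# Illustrative cosine-similarity helper for semantic textual similarity.
def cosine(a, b):
    a, b = np.asarray(a), np.asarray(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine(vectors[0], vectors[1]))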

Model Information

Model Name: sent_bert_use_cmlm_en_base
Compatibility: Spark NLP 3.1.3+
License: Open Source
Edition: Official
Input Labels: [sentence]
Output Labels: [bert]
Language: en
Case sensitive: false

Benchmarking

The sentence embeddings were benchmarked by training a ClassifierDL text classifier on a news dataset with 120K training examples:

            precision    recall  f1-score   support

    Business       0.84      0.90      0.87      1784
    Sci/Tech       0.92      0.85      0.89      2053
      Sports       0.98      0.96      0.97      1952
       World       0.89      0.93      0.91      1811

    accuracy                           0.91      7600
   macro avg       0.91      0.91      0.91      7600
weighted avg       0.91      0.91      0.91      7600
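
As a sketch of how such a benchmark could be set up, the snippet below trains a ClassifierDL on top of these sentence embeddings. The trainDataset DataFrame (with "text" and "category" columns) is a hypothetical stand-in for the news data, and the stages reuse the pipeline sketch above:

from sparknlp.annotator import ClassifierDLApproach
from pyspark.ml import Pipeline

# 'trainDataset' is an assumed DataFrame with a "text" column and a
# "category" label column; document_assembler, sentence_detector and
# embeddings come from the earlier pipeline sketch.
classifier = ClassifierDLApproach() \
    .setInputCols(["sentence_embeddings"]) \
    .setOutputCol("class") \
    .setLabelColumn("category") \
    .setMaxEpochs(10)

train_pipeline = Pipeline(stages=[
    document_assembler, sentence_detector, embeddings, classifier
])
model = train_pipeline.fit(trainDataset)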