BERT Sequence Classification Base - DBpedia14 (bert_base_sequence_classifier_dbpedia_14)

Description

A BERT model with a sequence classification/regression head on top (a linear layer over the pooled output), e.g. for multi-class document classification tasks.

bert_base_sequence_classifier_dbpedia_14 is a BERT model fine-tuned on the DBpedia14 dataset and ready to use for multi-class text classification: it assigns each input document to one of the 14 classes listed under Predicted Entities. It reaches 0.99 accuracy on the 70,000 examples reported under Benchmarking.

We used TFBertForSequenceClassification to train this model and the BertForSequenceClassification annotator in Spark NLP 🚀 for prediction at scale!

Predicted Entities

Album, Animal, Artist, Athlete, Building, Company, EducationalInstitution, Film, MeanOfTransportation, NaturalPlace, OfficeHolder, Plant, Village, WrittenWork

How to use

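# Python
import sparknlp
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import Tokenizer, BertForSequenceClassification
from pyspark.ml import Pipeline

# Start (or reuse) a Spark session with Spark NLP loaded
spark = sparknlp.start()

# Wrap the raw text column in a document annotation
document_assembler = DocumentAssembler() \
    .setInputCol('text') \
    .setOutputCol('document')

# Split each document into tokens
tokenizer = Tokenizer() \
    .setInputCols(['document']) \
    .setOutputCol('token')

# Load the pretrained DBpedia14 classifier
sequenceClassifier = BertForSequenceClassification \
    .pretrained('bert_base_sequence_classifier_dbpedia_14', 'en') \
    .setInputCols(['token', 'document']) \
    .setOutputCol('class') \
    .setCaseSensitive(True) \
    .setMaxSentenceLength(512)

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    sequenceClassifier
])

example = spark.createDataFrame([['Disney Comics was a comic book publishing company operated by The Walt Disney Company which ran from 1990 to 1993.']]).toDF("text")
result = pipeline.fit(example).transform(example)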

// Scala
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotator._
import org.apache.spark.ml.Pipeline
import spark.implicits._ // required for .toDS

// Wrap the raw text column in a document annotation
val document_assembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// Split each document into tokens
val tokenizer = new Tokenizer()
  .setInputCols("document")
  .setOutputCol("token")

// Load the pretrained DBpedia14 classifier
val sequenceClassifier = BertForSequenceClassification.pretrained("bert_base_sequence_classifier_dbpedia_14", "en")
  .setInputCols("document", "token")
  .setOutputCol("class")
  .setCaseSensitive(true)
  .setMaxSentenceLength(512)

val pipeline = new Pipeline().setStages(Array(document_assembler, tokenizer, sequenceClassifier))

val example = Seq("Disney Comics was a comic book publishing company operated by The Walt Disney Company which ran from 1990 to 1993.").toDS.toDF("text")

val result = pipeline.fit(example).transform(example)

# Python (NLU)
import nlu

nlu.load("en.classify.bert_sequence.dbpedia_14").predict("""Disney Comics was a comic book publishing company operated by The Walt Disney Company which ran from 1990 to 1993.""")
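
Once predictions are collected, a quick sanity check is to tally them against the 14 labels listed under Predicted Entities. The following is a minimal sketch in plain Python, not part of Spark NLP: the helper name `count_predictions` and the sample data are hypothetical, and it assumes each row's prediction arrives as a list of label strings.

```python
from collections import Counter

# The 14 labels listed under "Predicted Entities" in this model card.
DBPEDIA14_LABELS = {
    "Album", "Animal", "Artist", "Athlete", "Building", "Company",
    "EducationalInstitution", "Film", "MeanOfTransportation",
    "NaturalPlace", "OfficeHolder", "Plant", "Village", "WrittenWork",
}


def count_predictions(label_lists):
    """Tally predicted labels; each item is the list of labels for one row.

    Hypothetical helper: raises ValueError on any label outside the 14
    classes this model is documented to predict.
    """
    counts = Counter()
    for labels in label_lists:
        for label in labels:
            if label not in DBPEDIA14_LABELS:
                raise ValueError(f"unexpected label: {label!r}")
            counts[label] += 1
    return counts


# Made-up predictions for three input rows (illustration only, not model output).
sample = [["Company"], ["Film"], ["Company"]]
counts = count_predictions(sample)
print(counts.most_common())
```

With the Python pipeline above, such lists could plausibly be obtained with `[row.result for row in result.select("class.result").collect()]`, since the output column was named `class`.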

Model Information

Model Name: bert_base_sequence_classifier_dbpedia_14
Compatibility: Spark NLP 3.3.2+
License: Open Source
Edition: Official
Input Labels: [token, document]
Output Labels: [class]
Language: en
Case sensitive: true
Max sentence length: 512

Data Source

https://huggingface.co/datasets/dbpedia_14

Benchmarking

class                    precision    recall  f1-score   support

Album                         1.00      1.00      1.00      5004
Animal                        1.00      1.00      1.00      4998
Artist                        0.99      0.99      0.99      5012
Athlete                       1.00      1.00      1.00      5002
Building                      0.99      0.99      0.99      5007
Company                       0.98      0.98      0.98      4999
EducationalInstitution        0.99      0.99      0.99      4998
Film                          0.99      1.00      1.00      4978
MeanOfTransportation          1.00      1.00      1.00      5002
NaturalPlace                  1.00      1.00      1.00      5005
OfficeHolder                  0.99      0.99      0.99      5001
Plant                         1.00      1.00      1.00      4994
Village                       1.00      1.00      1.00      5003
WrittenWork                   0.99      0.99      0.99      4997

accuracy                                          0.99     70000
macro avg                     0.99      0.99      0.99     70000
weighted avg                  0.99      0.99      0.99     70000
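
To make the summary rows easier to interpret, the sketch below recomputes F1, the macro average, and the support-weighted average from the per-class precision, recall, and support values in the table. It is an illustration, not the script used to produce the report. Because the inputs are already rounded to two decimals, the recomputed figures can differ from the published ones in the last digit (for example, Film).

```python
# (precision, recall, support) per class, copied from the Benchmarking table.
REPORT = {
    "Album": (1.00, 1.00, 5004),
    "Animal": (1.00, 1.00, 4998),
    "Artist": (0.99, 0.99, 5012),
    "Athlete": (1.00, 1.00, 5002),
    "Building": (0.99, 0.99, 5007),
    "Company": (0.98, 0.98, 4999),
    "EducationalInstitution": (0.99, 0.99, 4998),
    "Film": (0.99, 1.00, 4978),
    "MeanOfTransportation": (1.00, 1.00, 5002),
    "NaturalPlace": (1.00, 1.00, 5005),
    "OfficeHolder": (0.99, 0.99, 5001),
    "Plant": (1.00, 1.00, 4994),
    "Village": (1.00, 1.00, 5003),
    "WrittenWork": (0.99, 0.99, 4997),
}


def f1(precision, recall):
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


per_class_f1 = {name: f1(p, r) for name, (p, r, _) in REPORT.items()}
total_support = sum(s for _, _, s in REPORT.values())

# Macro average: every class counts equally.
macro_f1 = sum(per_class_f1.values()) / len(per_class_f1)

# Weighted average: each class is weighted by its support.
weighted_f1 = sum(per_class_f1[n] * REPORT[n][2] for n in REPORT) / total_support

print(f"total support: {total_support}")
print(f"macro F1:      {macro_f1:.3f}")
print(f"weighted F1:   {weighted_f1:.3f}")
```

Since every class has close to 5,000 examples, the macro and weighted averages nearly coincide, which matches the identical summary rows in the table.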