Detect Diseases in Medical Text

Description

This model was trained with BertForTokenClassification from the transformers library and imported into Spark NLP. It detects disease entities in medical text.

Predicted Entities

Disease

How to use

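# Python: assumes an active Spark session (`spark`) with Spark NLP and Healthcare NLP installed.
from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import SentenceDetectorDLModel, Tokenizer, NerConverter
from sparknlp_jsl.annotator import MedicalBertForTokenClassifier

document_assembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")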

sentence_detector = SentenceDetectorDLModel.pretrained("sentence_detector_dl_healthcare", "en", "clinical/models")\
    .setInputCols(["document"])\
    .setOutputCol("sentence")

tokenizer = Tokenizer()\
    .setInputCols(["sentence"])\
    .setOutputCol("token")

ner_model = MedicalBertForTokenClassifier.pretrained("bert_token_classifier_ner_ncbi_disease", "en", "clinical/models")\
    .setInputCols(["sentence", "token"])\
    .setOutputCol("ner")\
    .setCaseSensitive(True)\
    .setMaxSentenceLength(512)

ner_converter = NerConverter()\
    .setInputCols(["sentence", "token", "ner"])\
    .setOutputCol("ner_chunk")

pipeline = Pipeline(stages=[
    document_assembler,
    sentence_detector,
    tokenizer,
    ner_model,
    ner_converter
])

data = spark.createDataFrame([["""Kniest dysplasia is a moderately severe type II collagenopathy, characterized by short trunk and limbs, kyphoscoliosis, midface hypoplasia, severe myopia, and hearing loss."""]]).toDF("text")

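result = pipeline.fit(data).transform(data)

// Scala: assumes an active Spark session (`spark`).
// MedicalBertForTokenClassifier comes from the licensed Healthcare NLP library.
import spark.implicits._
import org.apache.spark.ml.Pipeline
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotator._

val document_assembler = new DocumentAssembler()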
    .setInputCol("text")
    .setOutputCol("document")

val sentence_detector = SentenceDetectorDLModel.pretrained("sentence_detector_dl_healthcare", "en", "clinical/models")
    .setInputCols(Array("document"))
    .setOutputCol("sentence")

val tokenizer = new Tokenizer()
    .setInputCols(Array("sentence"))
    .setOutputCol("token")

val ner_model = MedicalBertForTokenClassifier.pretrained("bert_token_classifier_ner_ncbi_disease", "en", "clinical/models")
    .setInputCols(Array("sentence", "token"))
    .setOutputCol("ner")
    .setCaseSensitive(true)
    .setMaxSentenceLength(512)

val ner_converter = new NerConverter()
    .setInputCols(Array("sentence", "token", "ner"))
    .setOutputCol("ner_chunk")

val pipeline = new Pipeline().setStages(Array(
    document_assembler,
    sentence_detector,
    tokenizer,
    ner_model,
    ner_converter))

val data = Seq("""Kniest dysplasia is a moderately severe type II collagenopathy, characterized by short trunk and limbs, kyphoscoliosis, midface hypoplasia, severe myopia, and hearing loss.""").toDS.toDF("text")

val result = pipeline.fit(data).transform(data)

# NLU (Python)
import nlu
nlu.load("en.classify.token_bert.ncbi_disease").predict("""Kniest dysplasia is a moderately severe type II collagenopathy, characterized by short trunk and limbs, kyphoscoliosis, midface hypoplasia, severe myopia, and hearing loss.""")

Results

+----------------------+-------+
|ner_chunk             |label  |
+----------------------+-------+
|Kniest dysplasia      |Disease|
|type II collagenopathy|Disease|
|kyphoscoliosis        |Disease|
|midface hypoplasia    |Disease|
|myopia                |Disease|
|hearing loss          |Disease|
+----------------------+-------+
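
The model itself emits one tag per token (the benchmark below reports B-Disease and I-Disease), and the converter stage then groups those tags into the chunks listed above. The following is a simplified sketch of that grouping in plain Python, not Spark NLP's implementation. The helper name `merge_iob` and the token-level tags are hypothetical, hand-written to be consistent with the first two rows of the table.

```python
from typing import List, Optional, Tuple


def merge_iob(tokens: List[str], tags: List[str]) -> List[Tuple[str, str]]:
    """Merge IOB2-tagged tokens into (chunk_text, label) pairs."""
    chunks: List[Tuple[str, str]] = []
    current_tokens: List[str] = []
    current_label: Optional[str] = None

    for token, tag in zip(tokens, tags):
        prefix, _, label = tag.partition("-")  # "B-Disease" -> ("B", "-", "Disease")
        if prefix == "B" or (prefix == "I" and label != current_label):
            # A B- tag (or a stray I- tag) starts a new chunk.
            if current_tokens:
                chunks.append((" ".join(current_tokens), current_label))
            current_tokens, current_label = [token], label
        elif prefix == "I":
            # Continuation of the chunk that is currently open.
            current_tokens.append(token)
        else:
            # "O" closes any open chunk.
            if current_tokens:
                chunks.append((" ".join(current_tokens), current_label))
            current_tokens, current_label = [], None

    if current_tokens:
        chunks.append((" ".join(current_tokens), current_label))
    return chunks


# Illustrative tags only (assumption), not actual model output.
tokens = ["Kniest", "dysplasia", "is", "a", "moderately", "severe",
          "type", "II", "collagenopathy"]
tags = ["B-Disease", "I-Disease", "O", "O", "O", "O",
        "B-Disease", "I-Disease", "I-Disease"]

chunks = merge_iob(tokens, tags)
for text, label in chunks:
    print(f"{text} | {label}")
```

Joining tokens with a single space is a simplification; a real converter would use character offsets to recover the original text span.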

Model Information

Model Name: bert_token_classifier_ner_ncbi_disease
Compatibility: Healthcare NLP 4.0.0+
License: Licensed
Edition: Official
Input Labels: [sentence, token]
Output Labels: [ner]
Language: en
Size: 404.2 MB
Case sensitive: true
Max sentence length: 512

References

https://github.com/cambridgeltl/MTL-Bioinformatics-2016

Benchmarking

 label         precision  recall  f1-score  support 
 B-Disease     0.8392     0.9406  0.8870    960     
 I-Disease     0.8752     0.9356  0.9044    1087    
 micro-avg     0.8579     0.9380  0.8961    2047    
 macro-avg     0.8572     0.9381  0.8957    2047    
 weighted-avg  0.8583     0.9380  0.8963    2047
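
As a sanity check, the f1-score column and the average rows can be approximately recomputed from the two per-label rows. The sketch below uses only numbers copied from the table. The true-positive and predicted counts used for the micro average are not reported, so they are reconstructed from precision, recall, and support, which is an assumption. Because the inputs are rounded to four decimals, recomputed values can differ from the table in the last digit.

```python
# Per-label rows copied from the benchmark table: (precision, recall, support).
rows = {
    "B-Disease": (0.8392, 0.9406, 960),
    "I-Disease": (0.8752, 0.9356, 1087),
}


def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


per_label_f1 = {label: f1(p, r) for label, (p, r, _) in rows.items()}
support = sum(s for _, _, s in rows.values())

# Macro average: unweighted mean over labels.
macro = {
    "precision": sum(p for p, _, _ in rows.values()) / len(rows),
    "recall": sum(r for _, r, _ in rows.values()) / len(rows),
    "f1": sum(per_label_f1.values()) / len(rows),
}

# Weighted average: mean over labels, weighted by support.
weighted = {
    "precision": sum(p * s for p, _, s in rows.values()) / support,
    "recall": sum(r * s for _, r, s in rows.values()) / support,
    "f1": sum(per_label_f1[l] * s for l, (_, _, s) in rows.items()) / support,
}

# Micro average: pool counts over labels before computing the ratios.
# Counts are reconstructed (assumption): TP ~ recall * support,
# predicted ~ TP / precision.
tp = {l: round(r * s) for l, (_, r, s) in rows.items()}
predicted = {l: round(tp[l] / p) for l, (p, _, _) in rows.items()}
micro_p = sum(tp.values()) / sum(predicted.values())
micro_r = sum(tp.values()) / support
micro = {"precision": micro_p, "recall": micro_r, "f1": f1(micro_p, micro_r)}

for name, avg in (("micro", micro), ("macro", macro), ("weighted", weighted)):
    print(name, {k: round(v, 4) for k, v in avg.items()})
```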