sparknlp_jsl.annotator.assertion.bert_assertion_classifier

Contains the class for BertAssertionClassifier.

Module Contents

Classes

BertAssertionClassifier

BertAssertionClassifier extracts the assertion status from text by analyzing both the extracted entities and their surrounding context.

class BertAssertionClassifier(classname='com.johnsnowlabs.nlp.annotators.assertion.BertAssertionClassifier', java_model=None)

Bases: sparknlp_jsl.common.AnnotatorModelInternal, sparknlp_jsl.annotator.white_black_list_params.WhiteBlackListParams, sparknlp_jsl.common.HasEngine

BertAssertionClassifier extracts the assertion status from text by analyzing both the extracted entities and their surrounding context.

This classifier leverages pre-trained BERT models fine-tuned on biomedical text (e.g., BioBERT) and applies a sequence classification/regression head (a linear layer on the pooled output) to support multi-class document classification.
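The classification head described above can be sketched in plain Python: one logit per label is computed from the pooled output, and a softmax turns the logits into probabilities. This is a minimal illustration, not the annotator's code. The pooled vector, weight matrix, bias, and the label "possible" are hypothetical toy values, and treating the top softmax probability as the reported `confidence` (as in the metadata of the example output below) is an assumption.

```python
import math
from typing import List, Tuple


def classify_pooled(pooled: List[float],
                    weights: List[List[float]],
                    bias: List[float],
                    labels: List[str]) -> Tuple[str, float]:
    """Linear layer on a pooled vector followed by softmax (illustrative sketch)."""
    # One logit per label: logit_k = w_k . pooled + b_k
    logits = [sum(w * x for w, x in zip(row, pooled)) + b
              for row, b in zip(weights, bias)]
    # Numerically stable softmax: subtract the max logit before exponentiating
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Pick the most probable label and report its probability as the confidence
    best = max(range(len(labels)), key=lambda k: probs[k])
    return labels[best], probs[best]


# Hypothetical toy values; a BERT-base pooled vector would have 768 dimensions
pooled = [0.5, -1.0, 2.0]
weights = [[1.0, 0.0, 1.0],    # "present"
           [-1.0, 1.0, 0.0],   # "absent"
           [0.0, 0.5, -0.5]]   # "possible" (hypothetical label)
bias = [0.1, 0.0, -0.1]
label, confidence = classify_pooled(pooled, weights, bias, ["present", "absent", "possible"])
print(label, round(confidence, 4))
```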

Key features:

  • Accepts DOCUMENT and CHUNK type inputs and produces ASSERTION type annotations.

  • Emphasizes entity context by marking target entities with special tokens (e.g., [entity]), allowing the model to better focus on them.

  • Utilizes a transformer-based architecture (BERT for Sequence Classification) to achieve accurate assertion status prediction.

Input Example:

This annotator preprocesses the input text to emphasize the target entities as follows:
[CLS] Patient with [entity] severe fever [entity].
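The marking step can be illustrated with a short sketch. It assumes that a chunk's `begin` and `end` character offsets are inclusive (as in the `begin`/`end` fields of the annotations in the example output below) and that the marker is the literal string `[entity]`. The helper `mark_entity` is hypothetical and is not the annotator's internal implementation; adding `[CLS]` and tokenizing are left to the BERT tokenizer.

```python
def mark_entity(text: str, begin: int, end: int, marker: str = "[entity]") -> str:
    """Surround text[begin:end + 1] with marker tokens (hypothetical helper; end is inclusive)."""
    if not (0 <= begin <= end < len(text)):
        raise ValueError("chunk offsets out of range")
    return f"{text[:begin]}{marker} {text[begin:end + 1]} {marker}{text[end + 1:]}"


sentence = "Patient with severe fever."
# "severe fever" spans characters 13..24 (inclusive), as in the first row of the example output
print(mark_entity(sentence, 13, 24))
# -> Patient with [entity] severe fever [entity].
```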

Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To see which models are compatible and how to import them see Import Transformers into Spark NLP 🚀.

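+------------------------+------------------------+
| Input Annotation types | Output Annotation type |
+------------------------+------------------------+
| DOCUMENT, CHUNK        | ASSERTION              |
+------------------------+------------------------+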

Parameters:
  • configProtoBytes – ConfigProto from TensorFlow, serialized into a byte array.

  • classificationCaseSensitive – Whether to use case-sensitive classification. Default is True.

Examples

>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from sparknlp_jsl.annotator import *
>>> from pyspark.ml import Pipeline
>>> document_assembler = DocumentAssembler() \
...     .setInputCol("text") \
...     .setOutputCol("document")
>>> sentence_detector = SentenceDetector() \
...     .setInputCols(["document"]) \
...     .setOutputCol("sentence")
>>> tokenizer = Tokenizer() \
...     .setInputCols(["sentence"]) \
...     .setOutputCol("token")
>>> word_embeddings = WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models") \
...     .setInputCols(["sentence", "token"]) \
...     .setOutputCol("embeddings")
>>> clinical_ner = MedicalNerModel.pretrained("ner_clinical", "en", "clinical/models") \
...     .setInputCols(["sentence", "token", "embeddings"]) \
...     .setOutputCol("ner")
>>> ner_converter = NerConverterInternal() \
...     .setInputCols(["sentence", "token", "ner"]) \
...     .setOutputCol("ner_chunk")
>>> clinical_assertion = BertAssertionClassifier.pretrained() \
...     .setInputCols(["sentence", "ner_chunk"]) \
...     .setOutputCol("assertion")
>>> pipeline = Pipeline().setStages([
...     document_assembler,
...     sentence_detector,
...     tokenizer,
...     word_embeddings,
...     clinical_ner,
...     ner_converter,
...     clinical_assertion
... ])
>>> text = (
...     "Patient with severe fever and sore throat. He shows no stomach pain and he maintained on an epidural "
...     "and PCA for pain control. He also became short of breath with climbing a flight of stairs. After CT, "
...     "lung tumor located at the right lower lobe. Father with Alzheimer."
... )
>>> # Assumes `spark` is an active SparkSession with Spark NLP for Healthcare loaded
>>> data = spark.createDataFrame([[text]]).toDF("text")
>>> result_df = pipeline.fit(data).transform(data)
>>> result_df.selectExpr("explode(assertion) as result").show(100, False)
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|result                                                                                                                                                                     |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|{assertion, 13, 24, present, {assertion_source -> assertion, chunk -> 0, ner_chunk -> severe fever, confidence -> 0.9996883, ner_label -> PROBLEM, sentence -> 0}, []}     |
|{assertion, 30, 40, present, {assertion_source -> assertion, chunk -> 1, ner_chunk -> sore throat, confidence -> 0.999676, ner_label -> PROBLEM, sentence -> 0}, []}       |
|{assertion, 55, 66, absent, {assertion_source -> assertion, chunk -> 2, ner_chunk -> stomach pain, confidence -> 0.9989444, ner_label -> PROBLEM, sentence -> 1}, []}      |
|{assertion, 89, 99, present, {assertion_source -> assertion, chunk -> 3, ner_chunk -> an epidural, confidence -> 0.99903834, ner_label -> TREATMENT, sentence -> 1}, []}   |
|{assertion, 114, 116, present, {assertion_source -> assertion, chunk -> 4, ner_chunk -> PCA, confidence -> 0.99900436, ner_label -> TREATMENT, sentence -> 1}, []}         |
|{assertion, 122, 133, present, {assertion_source -> assertion, chunk -> 5, ner_chunk -> pain control, confidence -> 0.9993321, ner_label -> PROBLEM, sentence -> 1}, []}   |
|{assertion, 151, 165, present, {assertion_source -> assertion, chunk -> 6, ner_chunk -> short of breath, confidence -> 0.9997882, ner_label -> PROBLEM, sentence -> 2}, []}|
|{assertion, 207, 208, present, {assertion_source -> assertion, chunk -> 7, ner_chunk -> CT, confidence -> 0.9996158, ner_label -> TEST, sentence -> 3}, []}                |
|{assertion, 220, 229, present, {assertion_source -> assertion, chunk -> 8, ner_chunk -> lung tumor, confidence -> 0.9997308, ner_label -> PROBLEM, sentence -> 3}, []}     |
|{assertion, 276, 284, present, {assertion_source -> assertion, chunk -> 9, ner_chunk -> Alzheimer, confidence -> 0.98367596, ner_label -> PROBLEM, sentence -> 4}, []}     |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
blackList
caseSensitive
classificationCaseSensitive
configProtoBytes
engine
getter_attrs = []
inputAnnotatorTypes
inputCols
lazyAnnotator
name = 'BertAssertionClassifier'
optionalInputAnnotatorTypes = []
outputAnnotatorType = 'assertion'
outputCol
skipLPInputColsValidation = True
uid = ''
whiteList

clear(param: pyspark.ml.param.Param) -> None

Clears a param from the param map if it has been explicitly set.

copy(extra: pyspark.ml._typing.ParamMap | None = None) -> JP

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters:

extra (dict, optional) – Extra parameters to copy to the new instance

Returns:

Copy of this instance

Return type:

JavaParams

explainParam(param: str | Param) -> str

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() -> str

Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) -> pyspark.ml._typing.ParamMap

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters:

extra (dict, optional) – extra param values

Returns:

merged param map

Return type:

dict
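The precedence rule above (default values < user-supplied values < extra) behaves like a left-to-right dictionary merge in which later sources win. The sketch below models it with plain Python dicts keyed by parameter name. Real PySpark param maps are keyed by `Param` objects, so this is a simplified illustration rather than the PySpark implementation, and the helper `merge_param_maps` is hypothetical.

```python
from typing import Any, Dict, Optional


def merge_param_maps(defaults: Dict[str, Any],
                     user_supplied: Dict[str, Any],
                     extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Flatten param maps with precedence: defaults < user-supplied < extra (hypothetical helper)."""
    merged = dict(defaults)        # start from the default values
    merged.update(user_supplied)   # explicitly set values override defaults
    merged.update(extra or {})     # extra values override everything else
    return merged


defaults = {"caseSensitive": True, "classificationCaseSensitive": True}
user_supplied = {"caseSensitive": False}
extra = {"classificationCaseSensitive": False}
print(merge_param_maps(defaults, user_supplied, extra))
# {'caseSensitive': False, 'classificationCaseSensitive': False}
```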

getClasses()

Returns the labels used to train this model.

getEngine()

Returns:

Deep Learning engine used for this model.

Return type:

str

getInputCols()

Gets current column names of input annotations.

getLazyAnnotator()

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param: str) -> Any
getOrDefault(param: Param[T]) -> T

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()

Gets output column name of annotations.

getParam(paramName: str) -> Param

Gets a param by its name.

getParamValue(paramName)

Gets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

hasDefault(param: str | Param[Any]) -> bool

Checks whether a param has a default value.

hasParam(paramName: str) -> bool

Tests whether this instance contains a param with a given (string) name.

inputColsValidation(value)

isDefined(param: str | Param[Any]) -> bool

Checks whether a param is explicitly set by user or has a default value.

isSet(param: str | Param[Any]) -> bool

Checks whether a param is explicitly set by user.
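The difference between `hasDefault`, `isSet`, `isDefined`, `getOrDefault`, and `clear` can be summarized with a toy model that holds two maps, one for default values and one for user-supplied values. The class `ToyParams` is hypothetical: it only mirrors the semantics documented on this page and is not PySpark's `Params` class. The parameter names and values are placeholders.

```python
from typing import Any, Dict


class ToyParams:
    """Hypothetical stand-in that mimics the documented Params semantics."""

    def __init__(self, defaults: Dict[str, Any]):
        self._defaults = dict(defaults)
        self._user: Dict[str, Any] = {}

    def set(self, name: str, value: Any) -> None:
        self._user[name] = value

    def clear(self, name: str) -> None:
        # Removes only an explicitly set value; a default (if any) survives
        self._user.pop(name, None)

    def hasDefault(self, name: str) -> bool:
        return name in self._defaults

    def isSet(self, name: str) -> bool:
        return name in self._user

    def isDefined(self, name: str) -> bool:
        # Defined means explicitly set or backed by a default value
        return self.isSet(name) or self.hasDefault(name)

    def getOrDefault(self, name: str) -> Any:
        if name in self._user:
            return self._user[name]
        if name in self._defaults:
            return self._defaults[name]
        raise KeyError(f"{name} is neither set nor has a default")


p = ToyParams({"caseSensitive": True})
p.set("whiteList", ["PROBLEM"])
print(p.isSet("caseSensitive"), p.isDefined("caseSensitive"))  # False True
print(p.getOrDefault("whiteList"))                              # ['PROBLEM']
```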

classmethod load(path: str) -> RL

Reads an ML instance from the input path, a shortcut of read().load(path).

static loadSavedModel(folder, spark_session)

Loads a locally saved model.

Parameters:
  • folder (str) – Folder of the saved model.

  • spark_session (pyspark.sql.SparkSession) – The current SparkSession.

Returns:

The restored model

Return type:

BertAssertionClassifier

static pretrained(name='assertion_bert_classifier_jsl_slim', lang='en', remote_loc='clinical/models')

Downloads a pre-trained BertAssertionClassifier.

Parameters:
  • name (str) – Name of the pre-trained model. Default is “assertion_bert_classifier_jsl_slim”.

  • lang (str) – Language of the pre-trained model. Default is “en”.

  • remote_loc (str) – Remote location of the pre-trained model. Default is “clinical/models”.

Returns:

A pre-trained BertAssertionClassifier.

Return type:

BertAssertionClassifier

classmethod read()

Returns an MLReader instance for this class.

save(path: str) -> None

Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param: Param, value: Any) -> None

Sets a parameter in the embedded param map.

setBlackList(value)

Sets the list of entities to ignore. If defined, all other entities are processed. Do not include IOB prefixes on labels.

Parameters:

value (List[str]) – List of entities to ignore. If defined, all other entities are processed. Do not include IOB prefixes on labels.

setCaseSensitive(value)

Determines whether the definitions of the white listed and black listed entities are case sensitive or not.

Parameters:

value (bool) – Whether white listed and black listed entities are case sensitive or not. Default: True.

setClassificationCaseSensitive(value)

Sets whether to use case-sensitive classification. Default is True.

Parameters:

value (bool) – Whether to use case-sensitive classification.

setConfigProtoBytes(b)

Sets the ConfigProto from TensorFlow, serialized into a byte array.

Parameters:

b (List[int]) – ConfigProto from TensorFlow, serialized into a byte array.

setDenyList(value)

Sets the list of entities to ignore. If defined, all other entities are processed. Do not include IOB prefixes on labels.

Parameters:

value (List[str]) – List of entities to ignore. If defined, all other entities are processed. Do not include IOB prefixes on labels.

setForceInputTypeValidation(etfm)

setInputCols(*value)

Sets column names of input annotations.

Parameters:

*value (List[str]) – Input columns for the annotator

setLazyAnnotator(value)

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

Parameters:

value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline

setOutputCol(value)

Sets output column name of annotations.

Parameters:

value (str) – Name of output column

setParamValue(paramName)

Sets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

setParams()

setWhiteList(value)

Sets the list of entities to process. If defined, all other entities are ignored. Do not include IOB prefixes on labels.

Parameters:

value (List[str]) – List of entities to process. If defined, all other entities are ignored. Do not include IOB prefixes on labels.
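How the white list, black list, and `caseSensitive` settings might interact can be sketched as a label filter over NER chunks. This is an illustration under stated assumptions, none of which this page confirms: the white list is checked before the black list when both are set, an empty list counts as "not defined", and case-insensitive matching compares lowercased labels. The function `keep_chunk` is hypothetical; the labels come from the example output above.

```python
from typing import Iterable, List, Optional


def keep_chunk(label: str,
               white_list: Optional[Iterable[str]] = None,
               black_list: Optional[Iterable[str]] = None,
               case_sensitive: bool = True) -> bool:
    """Hypothetical filter: should a chunk with this NER label be processed?"""
    # Assumption: case-insensitive matching lowercases both sides
    norm = (lambda s: s) if case_sensitive else str.lower
    label = norm(label)
    white = {norm(x) for x in white_list or []}
    black = {norm(x) for x in black_list or []}
    if white and label not in white:   # white list defined: process only listed labels
        return False
    if label in black:                 # black list: skip listed labels
        return False
    return True


labels: List[str] = ["PROBLEM", "TREATMENT", "TEST", "PROBLEM"]
print([lab for lab in labels if keep_chunk(lab, white_list=["problem", "test"], case_sensitive=False)])
# ['PROBLEM', 'TEST', 'PROBLEM']
print([lab for lab in labels if keep_chunk(lab, black_list=["TEST"])])
# ['PROBLEM', 'TREATMENT', 'PROBLEM']
```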

transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) -> pyspark.sql.dataframe.DataFrame

Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters:
  • dataset (pyspark.sql.DataFrame) – input dataset

  • params (dict, optional) – an optional param map that overrides embedded params.

Returns:

transformed dataset

Return type:

pyspark.sql.DataFrame

write() -> JavaMLWriter

Returns an MLWriter instance for this ML instance.