sparknlp_jsl.annotator.assertion.bert_assertion_classifier
Contains Class for BertAssertionClassifier
Module Contents#
Classes#
BertAssertionClassifier extracts the assertion status from text by analyzing both the extracted entities and their surrounding context.
- class BertAssertionClassifier(classname='com.johnsnowlabs.nlp.annotators.assertion.BertAssertionClassifier', java_model=None)#
Bases: sparknlp_jsl.common.AnnotatorModelInternal, sparknlp_jsl.annotator.white_black_list_params.WhiteBlackListParams, sparknlp_jsl.common.HasEngine
BertAssertionClassifier extracts the assertion status from text by analyzing both the extracted entities and their surrounding context.
This classifier leverages pre-trained BERT models fine-tuned on biomedical text (e.g., BioBERT) and applies a sequence classification/regression head (a linear layer on the pooled output) to support multi-class document classification.
Key features:
Accepts DOCUMENT and CHUNK type inputs and produces ASSERTION type annotations.
Emphasizes entity context by marking target entities with special tokens (e.g., [entity]), allowing the model to better focus on them.
Utilizes a transformer-based architecture (BERT for Sequence Classification) to achieve accurate assertion status prediction.
Input Example:
This annotator preprocesses the input text to emphasize the target entities as follows: [CLS] Patient with [entity] severe fever [entity].
Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To see which models are compatible and how to import them, see Import Transformers into Spark NLP 🚀.
Input Annotation types: DOCUMENT, CHUNK
Output Annotation type: ASSERTION
- Parameters:
configProtoBytes – ConfigProto from tensorflow, serialized into byte array.
classificationCaseSensitive – Whether to use case sensitive classification. Default is True.
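A minimal configuration sketch, assuming upstream stages already produce DOCUMENT annotations in a "sentence" column and CHUNK annotations in a "ner_chunk" column (these column names, and the "assertion" output column, are illustrative; the full pipeline is shown in the Examples below):

from sparknlp_jsl.annotator import BertAssertionClassifier

# Consumes DOCUMENT + CHUNK annotations and emits ASSERTION annotations.
assertion_classifier = BertAssertionClassifier.pretrained() \
    .setInputCols(["sentence", "ner_chunk"]) \
    .setOutputCol("assertion") \
    .setClassificationCaseSensitive(True)  # default is True (case sensitive classification)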
Examples
>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from sparknlp_jsl.annotator import *
>>> from pyspark.ml import Pipeline
>>> document_assembler = DocumentAssembler() \
...     .setInputCol("text") \
...     .setOutputCol("document")
>>> sentence_detector = SentenceDetector() \
...     .setInputCols(["document"]) \
...     .setOutputCol("sentence")
>>> tokenizer = Tokenizer() \
...     .setInputCols(["sentence"]) \
...     .setOutputCol("token")
>>> word_embeddings = WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models") \
...     .setInputCols(["sentence", "token"]) \
...     .setOutputCol("embeddings")
>>> clinical_ner = MedicalNerModel.pretrained("ner_clinical", "en", "clinical/models") \
...     .setInputCols(["sentence", "token", "embeddings"]) \
...     .setOutputCol("ner")
>>> ner_converter = NerConverterInternal() \
...     .setInputCols(["sentence", "token", "ner"]) \
...     .setOutputCol("ner_chunk")
>>> clinical_assertion = BertAssertionClassifier.pretrained() \
...     .setInputCols(["sentence", "ner_chunk"]) \
...     .setOutputCol("assertion")
>>> pipeline = Pipeline().setStages([
...     document_assembler,
...     sentence_detector,
...     tokenizer,
...     word_embeddings,
...     clinical_ner,
...     ner_converter,
...     clinical_assertion
... ])
>>> text = (
...     "Patient with severe fever and sore throat. He shows no stomach pain and he maintained on an epidural." + \
...     "and PCA for pain control. He also became short of breath with climbing a flight of stairs. After CT, " + \
...     "lung tumor located at the right lower lobe. Father with Alzheimer."
... )
>>> data = spark.createDataFrame([[text]]).toDF("text")
>>> result_df = pipeline.fit(data).transform(data)
>>> result_df.selectExpr("explode(assertion) as result").show(100, False)
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|result                                                                                                                                                                       |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|{assertion, 13, 24, present, {assertion_source -> assertion, chunk -> 0, ner_chunk -> severe fever, confidence -> 0.9996883, ner_label -> PROBLEM, sentence -> 0}, []}       |
|{assertion, 30, 40, present, {assertion_source -> assertion, chunk -> 1, ner_chunk -> sore throat, confidence -> 0.999676, ner_label -> PROBLEM, sentence -> 0}, []}         |
|{assertion, 55, 66, absent, {assertion_source -> assertion, chunk -> 2, ner_chunk -> stomach pain, confidence -> 0.9989444, ner_label -> PROBLEM, sentence -> 1}, []}        |
|{assertion, 89, 99, present, {assertion_source -> assertion, chunk -> 3, ner_chunk -> an epidural, confidence -> 0.99903834, ner_label -> TREATMENT, sentence -> 1}, []}     |
|{assertion, 114, 116, present, {assertion_source -> assertion, chunk -> 4, ner_chunk -> PCA, confidence -> 0.99900436, ner_label -> TREATMENT, sentence -> 1}, []}           |
|{assertion, 122, 133, present, {assertion_source -> assertion, chunk -> 5, ner_chunk -> pain control, confidence -> 0.9993321, ner_label -> PROBLEM, sentence -> 1}, []}     |
|{assertion, 151, 165, present, {assertion_source -> assertion, chunk -> 6, ner_chunk -> short of breath, confidence -> 0.9997882, ner_label -> PROBLEM, sentence -> 2}, []}  |
|{assertion, 207, 208, present, {assertion_source -> assertion, chunk -> 7, ner_chunk -> CT, confidence -> 0.9996158, ner_label -> TEST, sentence -> 3}, []}                  |
|{assertion, 220, 229, present, {assertion_source -> assertion, chunk -> 8, ner_chunk -> lung tumor, confidence -> 0.9997308, ner_label -> PROBLEM, sentence -> 3}, []}       |
|{assertion, 276, 284, present, {assertion_source -> assertion, chunk -> 9, ner_chunk -> Alzheimer, confidence -> 0.98367596, ner_label -> PROBLEM, sentence -> 4}, []}       |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
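The exploded annotations can then be flattened into plain columns with standard Spark SQL expressions; a short follow-up sketch using the result_df from the example above (metadata keys follow the output shown):

result_df.selectExpr("explode(assertion) as a").selectExpr(
    "a.metadata['ner_chunk'] as ner_chunk",
    "a.metadata['ner_label'] as ner_label",
    "a.result as assertion",
    "a.metadata['confidence'] as confidence",
).show(truncate=False)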
- blackList#
- caseSensitive#
- classificationCaseSensitive#
- configProtoBytes#
- engine#
- getter_attrs = []#
- inputAnnotatorTypes#
- inputCols#
- lazyAnnotator#
- name = 'BertAssertionClassifier'#
- optionalInputAnnotatorTypes = []#
- outputAnnotatorType = 'assertion'#
- outputCol#
- skipLPInputColsValidation = True#
- uid = ''#
- whiteList#
- clear(param: pyspark.ml.param.Param) → None #
Clears a param from the param map if it has been explicitly set.
- copy(extra: pyspark.ml._typing.ParamMap | None = None) → JP #
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
- Parameters:
extra (dict, optional) – Extra parameters to copy to the new instance
- Returns:
Copy of this instance
- Return type:
JavaParams
- explainParam(param: str | Param) → str #
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
- explainParams() → str #
Returns the documentation of all params with their optionally default values and user-supplied values.
- extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) → pyspark.ml._typing.ParamMap #
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
- Parameters:
extra (dict, optional) – extra param values
- Returns:
merged param map
- Return type:
dict
- getClasses()#
Returns labels used to train this model
- getEngine()#
- Returns:
Deep Learning engine used for this model
- Return type:
str
- getInputCols()#
Gets current column names of input annotations.
- getLazyAnnotator()#
Gets whether Annotator should be evaluated lazily in a RecursivePipeline.
- getOrDefault(param: str) → Any #
- getOrDefault(param: Param[T]) → T
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
- getOutputCol()#
Gets output column name of annotations.
- getParam(paramName: str) → Param #
Gets a param by its name.
- getParamValue(paramName)#
Gets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- hasDefault(param: str | Param[Any]) → bool #
Checks whether a param has a default value.
- hasParam(paramName: str) → bool #
Tests whether this instance contains a param with a given (string) name.
- inputColsValidation(value)#
- isDefined(param: str | Param[Any]) → bool #
Checks whether a param is explicitly set by user or has a default value.
- isSet(param: str | Param[Any]) → bool #
Checks whether a param is explicitly set by user.
- classmethod load(path: str) → RL #
Reads an ML instance from the input path, a shortcut of read().load(path).
- static loadSavedModel(folder, spark_session)#
Loads a locally saved model.
- Parameters:
folder (str) – Folder of the saved model
spark_session (pyspark.sql.SparkSession) – The current SparkSession
- Returns:
The restored model
- Return type:
BertAssertionClassifier
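A usage sketch, assuming a locally saved or exported model folder (the path is a placeholder, not a shipped artifact), with spark being the active SparkSession:

from sparknlp_jsl.annotator import BertAssertionClassifier

# Restore a model previously saved/exported to local disk.
local_assertion = BertAssertionClassifier.loadSavedModel("/path/to/saved_model", spark) \
    .setInputCols(["sentence", "ner_chunk"]) \
    .setOutputCol("assertion")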
- static pretrained(name='assertion_bert_classifier_jsl_slim', lang='en', remote_loc='clinical/models')#
Download a pre-trained BertAssertionClassifier.
- Parameters:
name (str) – Name of the pre-trained model.
lang (str) – Language of the pre-trained model. Default is "en".
remote_loc (str) – Remote location of the pre-trained model. Default is "clinical/models".
- Returns:
A pre-trained BertAssertionClassifier.
- Return type:
BertAssertionClassifier
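For reference, calling pretrained with its default arguments spelled out is equivalent to BertAssertionClassifier.pretrained():

clinical_assertion = BertAssertionClassifier.pretrained(
    name="assertion_bert_classifier_jsl_slim",
    lang="en",
    remote_loc="clinical/models",
)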
- classmethod read()#
Returns an MLReader instance for this class.
- save(path: str) → None #
Save this ML instance to the given path, a shortcut of "write().save(path)".
- set(param: Param, value: Any) → None #
Sets a parameter in the embedded param map.
- setBlackList(value)#
Sets the list of entities to ignore; if defined, all other entities will be processed. Do not include IOB prefixes on the labels.
- Parameters:
value (List[str]) – If defined, list of entities to ignore. The rest will be processed. Do not include IOB prefix on labels.
- setCaseSensitive(value)#
Determines whether the definitions of the white listed and black listed entities are case sensitive or not.
- Parameters:
value (bool) – Whether white listed and black listed entities are case sensitive or not. Default: True.
- setClassificationCaseSensitive(value)#
Sets whether to use case sensitive classification. Default is True.
- Parameters:
value (bool) – Whether to use case sensitive classification
- setConfigProtoBytes(b)#
Sets configProto from tensorflow, serialized into byte array.
- Parameters:
b (List[str]) – ConfigProto from tensorflow, serialized into byte array
- setDenyList(value)#
Sets the deny list of entities to ignore; if defined, all other entities will be processed. Do not include IOB prefixes on the labels.
- Parameters:
value (List[str]) – If defined, list of entities to ignore. The rest will be processed. Do not include IOB prefix on labels.
- setForceInputTypeValidation(etfm)#
- setInputCols(*value)#
Sets column names of input annotations.
- Parameters:
*value (List[str]) – Input columns for the annotator
- setLazyAnnotator(value)#
Sets whether Annotator should be evaluated lazily in a RecursivePipeline.
- Parameters:
value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline
- setOutputCol(value)#
Sets output column name of annotations.
- Parameters:
value (str) – Name of output column
- setParamValue(paramName)#
Sets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- setParams()#
- setWhiteList(value)#
Sets the white list of entities to process; if defined, all other entities will be ignored. Do not include IOB prefixes on the labels.
- Parameters:
value (List[str]) – If defined, list of entities to process. The rest will be ignored. Do not include IOB prefix on labels.
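A sketch combining the entity-filtering setters, assuming NER labels such as PROBLEM and TREATMENT as produced by the clinical NER model in the Examples above:

# Only chunks whose NER label is in the white list are classified;
# label matching is case-insensitive here because of setCaseSensitive(False).
filtered_assertion = BertAssertionClassifier.pretrained() \
    .setInputCols(["sentence", "ner_chunk"]) \
    .setOutputCol("assertion") \
    .setWhiteList(["PROBLEM", "TREATMENT"]) \
    .setCaseSensitive(False)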
- transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) → pyspark.sql.dataframe.DataFrame #
Transforms the input dataset with optional parameters.
New in version 1.3.0.
- Parameters:
dataset (pyspark.sql.DataFrame) – input dataset
params (dict, optional) – an optional param map that overrides embedded params.
- Returns:
transformed dataset
- Return type:
pyspark.sql.DataFrame
- write() → JavaMLWriter #
Returns an MLWriter instance for this ML instance.