sparknlp_jsl.annotator.RelationExtractionModel#

class sparknlp_jsl.annotator.RelationExtractionModel(classname='com.johnsnowlabs.nlp.annotators.re.RelationExtractionModel', java_model=None)[source]#

Bases: AnnotatorModel

Extracts and classifies instances of relations between named entities. This is the trained model produced by RelationExtractionApproach; pretrained versions can be loaded with pretrained(), as shown in the example below. For the available configuration options, see the Parameters section.

Input Annotation types

Output Annotation type

WORD_EMBEDDINGS, POS, CHUNK, DEPENDENCY

CATEGORY

Parameters:
predictionThreshold

Minimal activation of the target unit to encode a new relation instance

relationPairs

List of dash-separated pairs of named entities (“ENTITY1-ENTITY2”, e.g. “Biomarker-RelativeDay”), which will be processed

relationPairsCaseSensitive

Determines whether relation pairs are case sensitive

maxSyntacticDistance

Maximal syntactic distance, as threshold (Default: 0)

Examples

>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.common import *
>>> from sparknlp.annotator import *
>>> from sparknlp.training import *
>>> import sparknlp_jsl
>>> from sparknlp_jsl.base import *
>>> from sparknlp_jsl.annotator import *
>>> from pyspark.ml import Pipeline
>>> documentAssembler = DocumentAssembler() \
...   .setInputCol("text") \
...   .setOutputCol("document")
...
>>> tokenizer = Tokenizer() \
...   .setInputCols(["document"]) \
...   .setOutputCol("tokens")
...
>>> embedder = WordEmbeddingsModel \
...   .pretrained("embeddings_clinical", "en", "clinical/models") \
...   .setInputCols(["document", "tokens"]) \
...   .setOutputCol("embeddings")
...
>>> posTagger = PerceptronModel \
...   .pretrained("pos_clinical", "en", "clinical/models") \
...   .setInputCols(["document", "tokens"]) \
...   .setOutputCol("posTags")
...
>>> nerTagger = MedicalNerModel \
...   .pretrained("ner_events_clinical", "en", "clinical/models") \
...   .setInputCols(["document", "tokens", "embeddings"]) \
...   .setOutputCol("ner_tags")
...
>>> nerConverter = NerConverter() \
...   .setInputCols(["document", "tokens", "ner_tags"]) \
...   .setOutputCol("nerChunks")
...
>>> dependencyParser = DependencyParserModel \
...   .pretrained("dependency_conllu", "en") \
...   .setInputCols(["document", "posTags", "tokens"]) \
...   .setOutputCol("dependencies")
...
>>> relationPairs = [
...   "direction-external_body_part_or_region",
...   "external_body_part_or_region-direction",
...   "direction-internal_organ_or_component",
...   "internal_organ_or_component-direction"
... ]
...
>>> re_model = RelationExtractionModel.pretrained("re_bodypart_directions", "en", "clinical/models") \
...     .setInputCols(["embeddings", "posTags", "nerChunks", "dependencies"]) \
...     .setOutputCol("relations") \
...     .setRelationPairs(relationPairs) \
...     .setMaxSyntacticDistance(4) \
...     .setPredictionThreshold(0.9)
...
>>> pipeline = Pipeline(stages=[
...     documentAssembler,
...     tokenizer,
...     embedder,
...     posTagger,
...     nerTagger,
...     nerConverter,
...     dependencyParser,
...     re_model])
>>> data = spark.createDataFrame([["MRI demonstrated infarction in the upper brain stem , left cerebellum and  right basil ganglia"]]).toDF("text")
>>> result = pipeline.fit(data).transform(data)
...
>>> # Show results
>>> result.selectExpr("explode(relations) as relations") \
...   .select(
...      "relations.metadata.chunk1",
...      "relations.metadata.entity1",
...      "relations.metadata.chunk2",
...      "relations.metadata.entity2",
...      "relations.result"
...   ) \
...   .where("result != 0") \
...   .show(truncate=False)
+------+---------+-------------+---------------------------+------+
|chunk1|entity1  |chunk2       |entity2                    |result|
+------+---------+-------------+---------------------------+------+
|upper |Direction|brain stem   |Internal_organ_or_component|1     |
|left  |Direction|cerebellum   |Internal_organ_or_component|1     |
|right |Direction|basil ganglia|Internal_organ_or_component|1     |
+------+---------+-------------+---------------------------+------+

Methods

__init__([classname, java_model])

Initialize this instance with a Java model object.

clear(param)

Clears a param from the param map if it has been explicitly set.

copy([extra])

Creates a copy of this instance with the same uid and some extra params.

explainParam(param)

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()

Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap([extra])

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

getClasses()

Returns labels used to train this model

getInputCols()

Gets current column names of input annotations.

getLazyAnnotator()

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param)

Gets the value of a param in the user-supplied param map or its default value.

getOutputCol()

Gets output column name of annotations.

getParam(paramName)

Gets a param by its name.

getParamValue(paramName)

Gets the value of a parameter.

hasDefault(param)

Checks whether a param has a default value.

hasParam(paramName)

Tests whether this instance contains a param with a given (string) name.

isDefined(param)

Checks whether a param is explicitly set by user or has a default value.

isSet(param)

Checks whether a param is explicitly set by user.

load(path)

Reads an ML instance from the input path, a shortcut of read().load(path).

pretrained(name[, lang, remote_loc])

Downloads and loads a pretrained model.

read()

Returns an MLReader instance for this class.

save(path)

Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)

Sets a parameter in the embedded param map.

setCustomLabels(labels)

Sets custom relation labels

setInputCols(*value)

Sets column names of input annotations.

setLazyAnnotator(value)

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

setMaxSyntacticDistance(distance)

Sets maximal syntactic distance, as threshold (Default: 0)

setOutputCol(value)

Sets output column name of annotations.

setParamValue(paramName)

Sets the value of a parameter.

setParams()

setPredictionThreshold(threshold)

Sets the minimal activation of the target unit needed to encode a new relation instance.

setRelationPairs(pairs)

Sets the list of dash-separated pairs of named entities ("ENTITY1-ENTITY2", e.g. "Biomarker-RelativeDay") to be processed.

setRelationPairsCaseSensitive(value)

Sets whether relation pairs are case sensitive.

transform(dataset[, params])

Transforms the input dataset with optional parameters.

write()

Returns an MLWriter instance for this ML instance.

Attributes

classes

customLabels

getter_attrs

inputCols

lazyAnnotator

maxSyntacticDistance

name

outputCol

params

Returns all params ordered by name.

predictionThreshold

relationPairs

relationPairsCaseSensitive

clear(param)#

Clears a param from the param map if it has been explicitly set.

copy(extra=None)#

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then make a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters:

extra – Extra parameters to copy to the new instance

Returns:

Copy of this instance

explainParam(param)#

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()#

Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)#

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters:

extra – extra param values

Returns:

merged param map

getClasses()[source]#

Returns labels used to train this model

getInputCols()#

Gets current column names of input annotations.

getLazyAnnotator()#

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param)#

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()#

Gets output column name of annotations.

getParam(paramName)#

Gets a param by its name.

getParamValue(paramName)#

Gets the value of a parameter.

Parameters:
paramNamestr

Name of the parameter

hasDefault(param)#

Checks whether a param has a default value.

hasParam(paramName)#

Tests whether this instance contains a param with a given (string) name.

isDefined(param)#

Checks whether a param is explicitly set by user or has a default value.

isSet(param)#

Checks whether a param is explicitly set by user.

classmethod load(path)#

Reads an ML instance from the input path, a shortcut of read().load(path).

property params#

Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

classmethod read()#

Returns an MLReader instance for this class.

save(path)#

Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)#

Sets a parameter in the embedded param map.

setCustomLabels(labels)[source]#

Sets custom relation labels

Parameters:
labelsdict[str, str]

Dictionary which maps old to new labels
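As a hedged illustration (the label names below are invented for this sketch, not taken from any pretrained model), such a dictionary conceptually remaps the raw labels the model emits:

```python
# Hypothetical sketch of the remapping applied after setCustomLabels:
# labels found in the mapping are renamed, others pass through unchanged.
# This is an illustration of the idea, not the library's implementation.
def apply_custom_labels(raw_labels, mapping):
    return [mapping.get(label, label) for label in raw_labels]

mapping = {"1": "has_direction", "0": "no_relation"}  # assumed label names
apply_custom_labels(["1", "0", "1"], mapping)
# → ["has_direction", "no_relation", "has_direction"]
```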

setInputCols(*value)#

Sets column names of input annotations.

Parameters:
*valuestr

Input columns for the annotator

setLazyAnnotator(value)#

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

Parameters:
valuebool

Whether Annotator should be evaluated lazily in a RecursivePipeline

setMaxSyntacticDistance(distance)[source]#

Sets maximal syntactic distance, as threshold (Default: 0)

Parameters:
distanceint

Maximal syntactic distance, as threshold (Default: 0)
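The setter itself only stores the threshold; conceptually, candidate entity pairs whose shortest path through the dependency tree is longer than this distance are skipped. A minimal sketch of that idea (not the library's implementation):

```python
from collections import deque

# Minimal sketch (not the library's code): the syntactic distance between
# two tokens is the length of the shortest path connecting them in the
# (undirected) dependency tree.
def syntactic_distance(edges, start, goal):
    graph = {}
    for head, dep in edges:
        graph.setdefault(head, set()).add(dep)
        graph.setdefault(dep, set()).add(head)
    seen, queue = {start}, deque([(start, 0)])
    while queue:
        node, dist = queue.popleft()
        if node == goal:
            return dist
        for nxt in graph.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return None  # tokens are not connected

# Toy dependency edges: "upper" and "brain" both attach to "stem".
edges = [("stem", "upper"), ("stem", "brain")]
syntactic_distance(edges, "upper", "brain")  # → 2 (upper -> stem -> brain)
```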

setOutputCol(value)#

Sets output column name of annotations.

Parameters:
valuestr

Name of output column

setParamValue(paramName)#

Sets the value of a parameter.

Parameters:
paramNamestr

Name of the parameter

setPredictionThreshold(threshold)[source]#

Sets the minimal activation of the target unit needed to encode a new relation instance.

Parameters:
thresholdfloat

Minimal activation of the target unit to encode a new relation instance
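Conceptually (a sketch under assumed data shapes, not the library's code), the threshold discards candidate relations whose activation falls below it:

```python
# Hypothetical sketch: only candidates whose activation reaches the
# prediction threshold are emitted as relation instances.
def filter_by_threshold(candidates, threshold):
    # candidates: list of (label, activation) pairs
    return [(label, score) for label, score in candidates if score >= threshold]

filter_by_threshold([("1", 0.95), ("1", 0.42)], 0.9)  # → [("1", 0.95)]
```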

setRelationPairs(pairs)[source]#

Sets the list of dash-separated pairs of named entities (“ENTITY1-ENTITY2”, e.g. “Biomarker-RelativeDay”) which will be processed.

Parameters:
pairslist[str]

List of dash-separated pairs of named entities (“ENTITY1-ENTITY2”, e.g. “Biomarker-RelativeDay”), which will be processed

setRelationPairsCaseSensitive(value)[source]#

Sets whether relation pairs are case sensitive.

Parameters:
valuebool

Whether relation pairs are case sensitive
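A sketch of how these two settings interact (illustration only, not the library's implementation): each candidate entity pair is checked against the configured dash-separated pairs, optionally ignoring case:

```python
# Illustration only: how relationPairs and relationPairsCaseSensitive
# conceptually gate which entity pairs get classified.
def pair_allowed(entity1, entity2, relation_pairs, case_sensitive=False):
    candidate = f"{entity1}-{entity2}"
    if case_sensitive:
        return candidate in relation_pairs
    return candidate.lower() in {p.lower() for p in relation_pairs}

pairs = ["Direction-Internal_organ_or_component"]
pair_allowed("direction", "internal_organ_or_component", pairs)        # → True
pair_allowed("direction", "internal_organ_or_component", pairs, True)  # → False
```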

transform(dataset, params=None)#

Transforms the input dataset with optional parameters.

Parameters:
  • dataset – input dataset, which is an instance of pyspark.sql.DataFrame

  • params – an optional param map that overrides embedded params.

Returns:

transformed dataset

New in version 1.3.0.

uid#

A unique id for the object.

write()#

Returns an MLWriter instance for this ML instance.