sparknlp_jsl.annotator.ner.pretrained_zero_shot_multi_task
Module Contents
Classes
PretrainedZeroShotMultiTask | Zero-shot multi-task information extraction.
- class PretrainedZeroShotMultiTask(classname='com.johnsnowlabs.nlp.annotators.ner.PretrainedZeroShotMultiTask', java_model=None)
Bases: sparknlp_jsl.common.AnnotatorModelInternal, sparknlp_jsl.common.HasBatchedAnnotate, sparknlp_jsl.common.HasEngine
Zero-shot multi-task information extraction.
Performs up to four extraction tasks simultaneously from a single document in a single forward pass:
Named entity extraction — spans of text matching a given type
Relation extraction — (head, tail) span pairs for a given relation type
Classification — document-level or sentence-level label assignment
Structured extraction — structured records with typed fields extracted from text
All tasks are defined via a compact "::"-delimited DSL and can be combined freely. Tasks are zero-shot: no fine-tuning is needed.
DSL syntax
Entities: each entry is a string of the form "name", "name::dtype", "name::description", or "name::dtype::description", where dtype is "list" (default) or "str" (single best span).
Relations: each entry is "relation_name" or "relation_name::description".
Classifications: each entry is a (taskSpec, [labelSpec, ...]) tuple, where taskSpec is "task_name" (single-label) or "task_name::multi" (multi-label) and each labelSpec is "label" or "label::description".
Structures: each entry is a (structureName, [fieldSpec, ...]) tuple. Fields support "field_name", "field_name::dtype", "field_name::description", "field_name::dtype::description", or "field_name::[choice1|choice2]".
Output (all tasks share one output column):
Entities → annotatorType = "chunk", result = span text, metadata has entity, confidence, sentence
Classifications → annotatorType = "category", result = label, metadata has confidence, task, sentence
Relations → annotatorType = "category", result = relation name, metadata has chunk1, chunk2, entity1, entity2, entity1_begin, entity1_end, entity2_begin, entity2_end, chunk1_confidence, chunk2_confidence, sentence
Structures → annotatorType = "struct", result = structure name, metadata has one key per field (JSON-encoded) plus instance_idx, sentence
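Because all four tasks write into one output column, downstream code usually has to tell the annotation kinds apart. The following is a minimal pure-Python sketch, assuming each annotation has already been collected into a plain dict with annotatorType, result and metadata keys (for example via Row.asDict()) and that metadata values are strings. The helper names task_of and decode_struct are hypothetical, and the routing rule (chunk1/chunk2 keys mark a relation, a task key marks a classification) is inferred from the metadata keys listed above rather than taken from the library.

```python
import json


def task_of(annotation: dict) -> str:
    """Guess which task produced an annotation from its type and metadata keys."""
    kind = annotation.get("annotatorType")
    meta = annotation.get("metadata", {})
    if kind == "chunk":
        return "entity"
    if kind == "struct":
        return "structure"
    if kind == "category":
        # Relations name both ends of the pair; classifications name their task.
        if "chunk1" in meta and "chunk2" in meta:
            return "relation"
        if "task" in meta:
            return "classification"
    return "unknown"


STRUCT_BOOKKEEPING_KEYS = {"instance_idx", "sentence"}


def decode_struct(annotation: dict) -> dict:
    """Map each field of a struct annotation to its JSON-decoded value."""
    fields = {}
    for key, value in annotation.get("metadata", {}).items():
        if key in STRUCT_BOOKKEEPING_KEYS:
            continue
        try:
            fields[key] = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            # Keep the raw value if it is not valid JSON.
            fields[key] = value
    return fields


# Made-up annotations in the shape described above (values are illustrative only).
annotations = [
    {"annotatorType": "chunk", "result": "Acme Corp",
     "metadata": {"entity": "company", "confidence": "0.91", "sentence": "0"}},
    {"annotatorType": "category", "result": "positive",
     "metadata": {"task": "sentiment", "confidence": "0.88", "sentence": "0"}},
    {"annotatorType": "struct", "result": "invoice",
     "metadata": {"vendor": '"Acme Corp"', "items": '["bolts", "nuts"]',
                  "instance_idx": "0", "sentence": "0"}},
]

for ann in annotations:
    print(task_of(ann), ann["result"])
print(decode_struct(annotations[2]))
```

Falling back to the raw string keeps decode_struct usable even if some metadata value turns out not to be JSON.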
Input Annotation types | Output Annotation type
DOCUMENT | multi
- Parameters:
entities (list[str]) – Entity specifications in DSL format.
entityThreshold (float) – Minimum confidence for entity spans (default: 0.5).
classifications (list[tuple[str, list[str]]]) – Classification task specifications in DSL format.
classificationThreshold (float) – Minimum confidence for classification labels (default: 0.5).
relations (list[str]) – Relation specifications in DSL format.
relationThreshold (float) – Minimum confidence for relation spans (default: 0.5).
structures (list[tuple[str, list[str]]]) – Structure specifications in DSL format.
structureThreshold (float) – Minimum confidence for structure field spans (default: 0.5).
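The model applies these thresholds itself. If a stricter cutoff is wanted downstream without re-running the pipeline, the collected annotations can simply be filtered again. Below is a minimal sketch, assuming annotations were collected into plain dicts whose metadata values are strings. It reads the confidence key listed for entities and classifications; for relations it uses the smaller of chunk1_confidence and chunk2_confidence, which is an illustrative assumption, not documented behavior. Annotations without any confidence key (such as structures) are kept. The helper names passes and refilter are hypothetical.

```python
def passes(annotation: dict, min_confidence: float) -> bool:
    """Return True if the annotation meets a (stricter) confidence cutoff."""
    meta = annotation.get("metadata", {})
    if "confidence" in meta:
        # Entities and classifications carry a single confidence value.
        return float(meta["confidence"]) >= min_confidence
    if "chunk1_confidence" in meta and "chunk2_confidence" in meta:
        # Assumption: a relation is only as strong as its weaker endpoint.
        weakest = min(float(meta["chunk1_confidence"]),
                      float(meta["chunk2_confidence"]))
        return weakest >= min_confidence
    # No confidence information: leave the annotation untouched.
    return True


def refilter(annotations: list, min_confidence: float) -> list:
    return [a for a in annotations if passes(a, min_confidence)]


# Made-up annotations for illustration.
sample = [
    {"annotatorType": "chunk", "result": "Ann",
     "metadata": {"entity": "person", "confidence": "0.93"}},
    {"annotatorType": "chunk", "result": "Widget X",
     "metadata": {"entity": "product", "confidence": "0.55"}},
    {"annotatorType": "category", "result": "works_for",
     "metadata": {"chunk1": "Ann", "chunk2": "Acme",
                  "chunk1_confidence": "0.93", "chunk2_confidence": "0.61"}},
]

kept = refilter(sample, 0.7)
print([a["result"] for a in kept])
```

Filtering again can only remove annotations; lowering a threshold below the value used at inference time requires re-running the model with the setter.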
Examples
>>> from pyspark.ml import Pipeline
>>> from sparknlp.base import DocumentAssembler
>>> from sparknlp_jsl.annotator import PretrainedZeroShotMultiTask
>>> document_assembler = DocumentAssembler() \
...     .setInputCol("text") \
...     .setOutputCol("document")
>>> zero_shot = PretrainedZeroShotMultiTask.pretrained() \
...     .setInputCols(["document"]) \
...     .setOutputCol("extractions") \
...     .setEntities(["person", "company::str", "product::List of products"]) \
...     .setClassifications([
...         ("sentiment", ["positive", "negative", "neutral"]),
...         ("topic::multi", ["finance", "technology"]),
...     ]) \
...     .setRelations(["works_for", "founded"]) \
...     .setStructures([
...         ("invoice", ["vendor::str", "amount::str", "items::list"]),
...     ]) \
...     .setEntityThreshold(0.5) \
...     .setRelationThreshold(0.6)
>>> pipeline = Pipeline(stages=[document_assembler, zero_shot])
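When entity specifications come from a configuration file, it can be convenient to generate the DSL strings instead of writing them by hand. A minimal sketch with a hypothetical helper entity_spec that joins name, optional dtype and optional description with "::", following the entity grammar described above. It refuses a bare description of "list" or "str", since "name::str" would be read as a dtype; that guard is a design choice of the sketch, not a documented rule. The config below reproduces the entity list used in the example.

```python
from typing import Optional

DTYPES = ("list", "str")


def entity_spec(name: str, dtype: Optional[str] = None,
                description: Optional[str] = None) -> str:
    """Join the parts of an entity spec with "::" (hypothetical helper)."""
    if not name or "::" in name:
        raise ValueError(f"invalid entity name: {name!r}")
    if dtype is not None and dtype not in DTYPES:
        raise ValueError(f"dtype must be 'list' or 'str', got {dtype!r}")
    if dtype is None and description in DTYPES:
        # "name::str" would be interpreted as a dtype, not a description.
        raise ValueError(f"ambiguous description {description!r}; pass a dtype too")
    return "::".join(part for part in (name, dtype, description) if part is not None)


config = [
    {"name": "person"},
    {"name": "company", "dtype": "str"},
    {"name": "product", "description": "List of products"},
]
entities = [entity_spec(**item) for item in config]
print(entities)
```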
- batchSize
- classificationThreshold: sparknlp_jsl.common.Param
- classifications: sparknlp_jsl.common.Param
- entities: sparknlp_jsl.common.Param
- entityThreshold: sparknlp_jsl.common.Param
- getter_attrs = []
- inputAnnotatorTypes
- inputCols
- lazyAnnotator
- name = 'PretrainedZeroShotMultiTask'
- optionalInputAnnotatorTypes = []
- outputAnnotatorType = 'multi'
- outputCol
- relationThreshold: sparknlp_jsl.common.Param
- relations: sparknlp_jsl.common.Param
- skipLPInputColsValidation = True
- structureThreshold: sparknlp_jsl.common.Param
- structures: sparknlp_jsl.common.Param
- uid = ''
- clear(param: pyspark.ml.param.Param) → None
Clears a param from the param map if it has been explicitly set.
- copy(extra: pyspark.ml._typing.ParamMap | None = None) → JP
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.
- Parameters:
extra (dict, optional) – Extra parameters to copy to the new instance
- Returns:
Copy of this instance
- Return type:
JavaParams
- explainParam(param: str | Param) → str
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
- explainParams() → str
Returns the documentation of all params with their optional default values and user-supplied values.
- extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) → pyspark.ml._typing.ParamMap
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
- Parameters:
extra (dict, optional) – extra param values
- Returns:
merged param map
- Return type:
dict
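The merge order described for extractParamMap (defaults, then user-supplied values, then extra) behaves like successive dictionary updates in which later sources win. A plain-Python sketch of that precedence, using string keys in place of real Param objects (an illustrative simplification) and the documented 0.5 defaults of the four thresholds:

```python
# Documented defaults for the four thresholds (all 0.5).
defaults = {
    "entityThreshold": 0.5,
    "classificationThreshold": 0.5,
    "relationThreshold": 0.5,
    "structureThreshold": 0.5,
}
user_supplied = {"relationThreshold": 0.6}                  # e.g. set via setRelationThreshold(0.6)
extra = {"relationThreshold": 0.8, "entityThreshold": 0.7}  # passed at call time

# Later mappings win: default values < user-supplied values < extra.
merged = {**defaults, **user_supplied, **extra}
print(merged)
```

With the real API the keys of extra are Param objects such as the entityThreshold attribute listed above, not strings.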
- getBatchSize()
Gets current batch size.
- Returns:
Current batch size
- Return type:
int
- getClassificationThreshold() → float
Return classification confidence threshold.
- getClassifications() → list
Return classification task specifications.
- getEntities() → list
Return entity specifications.
- getEntityThreshold() → float
Return entity confidence threshold.
- getInputCols()
Gets current column names of input annotations.
- getLazyAnnotator()
Gets whether Annotator should be evaluated lazily in a RecursivePipeline.
- getOrDefault(param: str) → Any
- getOrDefault(param: Param[T]) → T
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
- getOutputCol()
Gets output column name of annotations.
- getParam(paramName: str) → Param
Gets a param by its name.
- getParamValue(paramName)
Gets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- getRelationThreshold() → float
Return relation confidence threshold.
- getRelations() → list
Return relation specifications.
- getStructureThreshold() → float
Return structure confidence threshold.
- getStructures() → list
Return structure specifications.
- hasDefault(param: str | Param[Any]) → bool
Checks whether a param has a default value.
- hasParam(paramName: str) → bool
Tests whether this instance contains a param with a given (string) name.
- inputColsValidation(value)
- isDefined(param: str | Param[Any]) → bool
Checks whether a param is explicitly set by user or has a default value.
- isSet(param: str | Param[Any]) → bool
Checks whether a param is explicitly set by user.
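The difference between hasDefault, isSet, isDefined, clear and getOrDefault is easy to blur. The toy class below is a hypothetical stand-in, not the PySpark Params implementation, that mirrors the documented semantics: isDefined is true when a param is explicitly set or has a default, clear removes only the user-supplied value, and getOrDefault prefers the user-supplied value and raises when neither exists.

```python
class ToyParams:
    """Hypothetical stand-in mimicking the documented Params semantics."""

    def __init__(self, defaults: dict):
        self._defaults = dict(defaults)
        self._user = {}

    def set(self, name: str, value) -> None:
        self._user[name] = value

    def clear(self, name: str) -> None:
        # Only the explicit value is dropped; the default stays.
        self._user.pop(name, None)

    def hasDefault(self, name: str) -> bool:
        return name in self._defaults

    def isSet(self, name: str) -> bool:
        return name in self._user

    def isDefined(self, name: str) -> bool:
        return self.isSet(name) or self.hasDefault(name)

    def getOrDefault(self, name: str):
        if name in self._user:
            return self._user[name]
        if name in self._defaults:
            return self._defaults[name]
        raise KeyError(f"{name} is neither set nor has a default")


p = ToyParams({"entityThreshold": 0.5})
p.set("entityThreshold", 0.8)
print(p.isSet("entityThreshold"), p.getOrDefault("entityThreshold"))
p.clear("entityThreshold")
print(p.isSet("entityThreshold"), p.getOrDefault("entityThreshold"))
```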
- classmethod load(path: str) → RL
Reads an ML instance from the input path, a shortcut of read().load(path).
- static loadSavedModel(folder, spark_session)
Load a locally saved PretrainedZeroShotMultiTask model.
- Parameters:
folder (str) – Path to the saved model directory.
spark_session (pyspark.sql.SparkSession) – The current SparkSession.
- Return type:
- static pretrained(name='zeroshot_multitask_base', lang='en', remote_loc='clinical/models')
Download a pre-trained PretrainedZeroShotMultiTask model.
- Parameters:
name (str) – Name of the pre-trained model, by default "zeroshot_multitask_base".
lang (str) – Language of the pre-trained model, by default "en".
remote_loc (str) – Remote location, by default "clinical/models".
- Return type:
- classmethod read()
Returns an MLReader instance for this class.
- save(path: str) → None
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
- set(param: Param, value: Any) → None
Sets a parameter in the embedded param map.
- setBatchSize(v)
Sets batch size.
- Parameters:
v (int) – Batch size
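setBatchSize controls how many inputs the annotator groups together per inference call. As a rough illustration only (the real batching happens inside the Java-side annotator and may differ, for example in how Spark partitions are handled), grouping a sequence of documents into fixed-size batches looks like this; the function name batch_items is hypothetical.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def batch_items(items: Iterable[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive lists of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


docs = ["doc one", "doc two", "doc three", "doc four", "doc five"]
print([len(b) for b in batch_items(docs, 2)])
```

Larger batches generally trade memory for throughput; the last batch may be smaller than the requested size.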
- setClassificationThreshold(value: float)
Set minimum confidence threshold for classification (default: 0.5).
- Parameters:
value (float) – Confidence threshold in [0, 1].
- setClassifications(classification_labels: list)
Set classification task specifications in DSL format.
- Parameters:
classification_labels (list[tuple[str, list[str]]]) – Each element is a (taskSpec, [labelSpec, ...]) pair. taskSpec is "task_name" or "task_name::multi". Each labelSpec is "label" or "label::description".
Examples
>>> annotator.setClassifications([
...     ("sentiment", ["positive", "negative", "neutral"]),
...     ("topic::multi", ["finance::Financial content", "technology"]),
... ])
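Task and label specs can be checked before they reach the annotator. A minimal sketch of a hypothetical validator parse_classification that splits a (taskSpec, [labelSpec, ...]) pair into a task name, a multi-label flag and (label, description) pairs, following the rules stated above. It rejects any task modifier other than "multi" and empty label lists; both checks are choices of the sketch, not documented library behavior.

```python
from typing import List, Optional, Tuple

Label = Tuple[str, Optional[str]]


def parse_classification(spec: Tuple[str, List[str]]) -> Tuple[str, bool, List[Label]]:
    """Split a classification spec into (task_name, multi_label, labels)."""
    task_spec, label_specs = spec
    task_name, sep, suffix = task_spec.partition("::")
    if sep and suffix != "multi":
        raise ValueError(f"unknown task modifier {suffix!r} in {task_spec!r}")
    if not label_specs:
        raise ValueError(f"task {task_name!r} has no labels")
    labels: List[Label] = []
    for label_spec in label_specs:
        # "label::description" -> ("label", "description"); "label" -> ("label", None)
        label, label_sep, description = label_spec.partition("::")
        labels.append((label, description if label_sep else None))
    return task_name, suffix == "multi", labels


specs = [
    ("sentiment", ["positive", "negative", "neutral"]),
    ("topic::multi", ["finance::Financial content", "technology"]),
]
for spec in specs:
    print(parse_classification(spec))
```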
- setEntities(entities: list)
Set entity specifications in DSL format.
- Parameters:
entities (list[str]) – Each string is "name", "name::dtype", "name::description", or "name::dtype::description".
- setEntityThreshold(value: float)
Set minimum confidence threshold for entity extraction (default: 0.5).
- Parameters:
value (float) – Confidence threshold in [0, 1].
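The entity specs passed to setEntities share one shape for "name::dtype" and "name::description", so anything reading them has to decide what a second segment means. A minimal sketch of a hypothetical parser that treats the second segment as a dtype only when it is exactly "list" or "str", defaulting dtype to "list" as described in the DSL syntax. This disambiguation rule is an assumption about how the DSL is interpreted, not taken from the library source.

```python
from typing import NamedTuple, Optional

DTYPES = ("list", "str")


class EntitySpec(NamedTuple):
    name: str
    dtype: str
    description: Optional[str]


def parse_entity_spec(spec: str) -> EntitySpec:
    """Parse "name", "name::dtype", "name::description" or "name::dtype::description"."""
    parts = spec.split("::", 2)
    name = parts[0]
    if not name:
        raise ValueError(f"empty entity name in {spec!r}")
    if len(parts) == 1:
        return EntitySpec(name, "list", None)
    if len(parts) == 2:
        # Assumption: a second segment is a dtype only if it is exactly "list"/"str".
        if parts[1] in DTYPES:
            return EntitySpec(name, parts[1], None)
        return EntitySpec(name, "list", parts[1])
    dtype, description = parts[1], parts[2]
    if dtype not in DTYPES:
        raise ValueError(f"dtype must be 'list' or 'str' in {spec!r}")
    return EntitySpec(name, dtype, description)


for spec in ["person", "company::str", "product::List of products", "amount::str::Total due"]:
    print(parse_entity_spec(spec))
```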
- setForceInputTypeValidation(etfm)
- setInputCols(*value)
Sets column names of input annotations.
- Parameters:
*value (List[str]) – Input columns for the annotator
- setLazyAnnotator(value)
Sets whether Annotator should be evaluated lazily in a RecursivePipeline.
- Parameters:
value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline
- setOutputCol(value)
Sets output column name of annotations.
- Parameters:
value (str) – Name of output column
- setParamValue(paramName)
Sets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- setParams()
- setRelationThreshold(value: float)
Set minimum confidence threshold for relation extraction (default: 0.5).
- Parameters:
value (float) – Confidence threshold in [0, 1].
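Relations that clear relationThreshold are emitted as category annotations whose metadata names both ends of the pair (chunk1, chunk2 and their offsets). A minimal sketch, assuming annotations were collected as plain dicts with annotatorType, result and string-valued metadata, that converts them into (head, relation, tail) triples with character spans. The helper name relation_triples is hypothetical, and reading chunk1 as the head and chunk2 as the tail is an assumption, not confirmed by this page.

```python
from typing import Dict, List, NamedTuple, Tuple


class Triple(NamedTuple):
    head: str
    relation: str
    tail: str
    head_span: Tuple[int, int]
    tail_span: Tuple[int, int]


def relation_triples(annotations: List[Dict]) -> List[Triple]:
    """Keep only relation annotations and turn them into triples."""
    triples = []
    for ann in annotations:
        meta = ann.get("metadata", {})
        # Classifications are also "category" annotations but lack chunk1/chunk2.
        if ann.get("annotatorType") != "category" or "chunk1" not in meta:
            continue
        triples.append(Triple(
            head=meta["chunk1"],
            relation=ann["result"],
            tail=meta["chunk2"],
            head_span=(int(meta["entity1_begin"]), int(meta["entity1_end"])),
            tail_span=(int(meta["entity2_begin"]), int(meta["entity2_end"])),
        ))
    return triples


# Made-up annotations for illustration.
anns = [
    {"annotatorType": "category", "result": "works_for",
     "metadata": {"chunk1": "Ann Lee", "chunk2": "Acme Corp",
                  "entity1": "person", "entity2": "company",
                  "entity1_begin": "0", "entity1_end": "6",
                  "entity2_begin": "18", "entity2_end": "26"}},
    {"annotatorType": "category", "result": "positive",
     "metadata": {"task": "sentiment", "confidence": "0.9"}},
]
print(relation_triples(anns))
```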
- setRelations(relations: list)
Set relation specifications in DSL format.
- Parameters:
relations (list[str]) – Each string is "relation_name" or "relation_name::description".
- setStructureThreshold(value: float)
Set minimum confidence threshold for structure field extraction (default: 0.5).
- Parameters:
value (float) – Confidence threshold in [0, 1].
- setStructures(structures: list)
Set structure specifications in DSL format.
- Parameters:
structures (list[tuple[str, list[str]]]) – Each element is a (structureName, [fieldSpec, ...]) pair. Field specs support "field", "field::dtype", "field::description", "field::dtype::description", or "field::[choice1|choice2]".
Examples
>>> annotator.setStructures([
...     ("product_info", [
...         "name::str",
...         "price::str::Price including currency symbol",
...         "features::list",
...         "availability::[in_stock|pre_order|sold_out]",
...     ]),
... ])
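Field specs add one form that entity specs lack: an enumerated choice list. A minimal sketch of a hypothetical parser that recognizes "field::[a|b|c]" and otherwise applies the name/dtype/description rules, treating a second segment as a dtype only if it is "list" or "str" (an assumption about the DSL). The default dtype for fields is not stated on this page, so the sketch assumes "list" by analogy with entities and leaves dtype as None for choice fields.

```python
from typing import Dict

DTYPES = ("list", "str")


def parse_field_spec(spec: str) -> Dict[str, object]:
    """Split a structure field spec into name, dtype, description and choices."""
    name, sep, rest = spec.partition("::")
    if not name:
        raise ValueError(f"empty field name in {spec!r}")
    # Assumed default dtype "list", by analogy with entity specs.
    field: Dict[str, object] = {"name": name, "dtype": "list",
                                "description": None, "choices": None}
    if not sep:
        return field
    if rest.startswith("[") and rest.endswith("]"):
        choices = [c.strip() for c in rest[1:-1].split("|") if c.strip()]
        if not choices:
            raise ValueError(f"empty choice list in {spec!r}")
        field["dtype"] = None  # dtype of choice fields is not documented here
        field["choices"] = choices
        return field
    head, sep2, tail = rest.partition("::")
    if sep2:
        if head not in DTYPES:
            raise ValueError(f"dtype must be 'list' or 'str' in {spec!r}")
        field["dtype"], field["description"] = head, tail
    elif head in DTYPES:
        field["dtype"] = head
    else:
        field["description"] = head
    return field


fields = [
    "name::str",
    "price::str::Price including currency symbol",
    "features::list",
    "availability::[in_stock|pre_order|sold_out]",
]
for f in fields:
    print(parse_field_spec(f))
```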
- transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) → pyspark.sql.dataframe.DataFrame
Transforms the input dataset with optional parameters.
New in version 1.3.0.
- Parameters:
dataset (pyspark.sql.DataFrame) – input dataset
params (dict, optional) – an optional param map that overrides embedded params.
- Returns:
transformed dataset
- Return type:
pyspark.sql.DataFrame
- write() → JavaMLWriter
Returns an MLWriter instance for this ML instance.