sparknlp_jsl.annotator.classification.few_shot_classifier

Module Contents

Classes

FewShotClassifierApproach

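FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings and the output is category annotations with labels and corresponding confidence scores between 0 and 1.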

FewShotClassifierModel

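FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings and the output is category annotations with labels and corresponding confidence scores between 0 and 1.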

class FewShotClassifierApproach(classname='com.johnsnowlabs.nlp.annotators.classification.FewShotClassifierApproach')

Bases: sparknlp_jsl.annotator.classification.generic_log_reg_classifier.GenericLogRegClassifierApproach

FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings and the output is category annotations with labels and corresponding confidence scores between 0 and 1.

Input Annotation types: SENTENCE_EMBEDDINGS

Output Annotation type: CATEGORY
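The scoring step can be illustrated outside Spark. The class derives from GenericLogRegClassifierApproach, so the sketch below assumes a logistic-regression head with one weight vector per label and a softmax that maps raw scores to confidences between 0 and 1. The helper `predict_confidences` and the toy weights are hypothetical and are not part of the sparknlp_jsl API. The real model's internals may differ.

```python
import math
from typing import Dict, List


def predict_confidences(
    embedding: List[float],
    weights: Dict[str, List[float]],
    biases: Dict[str, float],
) -> Dict[str, float]:
    """Score one sentence embedding against every label (hypothetical helper).

    Assumes a multinomial logistic-regression head: one weight vector and one
    bias per label, with a softmax mapping raw scores to confidences in [0, 1].
    """
    # Linear score per label: dot(w, x) + b
    logits = {
        label: sum(w * x for w, x in zip(w_vec, embedding)) + biases[label]
        for label, w_vec in weights.items()
    }
    # Subtract the largest logit before exponentiating, for numerical stability
    top = max(logits.values())
    exps = {label: math.exp(z - top) for label, z in logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}


# Toy 3-dimensional "sentence embedding" and two labels (made-up numbers)
scores = predict_confidences(
    embedding=[0.2, -0.1, 0.7],
    weights={"positive": [1.0, 0.0, 2.0], "negative": [-1.0, 0.5, -2.0]},
    biases={"positive": 0.0, "negative": 0.0},
)
# "positive" receives the higher confidence; the two values sum to 1
```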

batchSize
datasetInfo
doExceptionHandling
dropout
engine
epochsN
featureScaling
fixImbalance
getter_attrs = []
inputAnnotatorTypes
inputCols
labelColumn
lazyAnnotator
learningRate
modelFile
multiClass
name = 'FewShotClassifierApproach'
optionalInputAnnotatorTypes = []
outputAnnotatorType = 'category'
outputCol
outputLogsPath
skipLPInputColsValidation = True
uid = ''
validationSplit
clear(param: pyspark.ml.param.Param) None

Clears a param from the param map if it has been explicitly set.

copy(extra: pyspark.ml._typing.ParamMap | None = None) JP

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters:

extra (dict, optional) – Extra parameters to copy to the new instance

Returns:

Copy of this instance

Return type:

JavaParams

explainParam(param: str | Param) str

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() str

Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) pyspark.ml._typing.ParamMap

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters:

extra (dict, optional) – extra param values

Returns:

merged param map

Return type:

dict
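The precedence rule (default param values < user-supplied values < extra) behaves like successive dictionary updates. This plain-Python sketch uses string keys and made-up values for readability. PySpark itself keys the param map by `Param` objects, and `merge_param_maps` is a hypothetical name.

```python
from typing import Any, Dict, Optional


def merge_param_maps(
    defaults: Dict[str, Any],
    user_supplied: Dict[str, Any],
    extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Mimic the documented precedence: defaults < user-supplied < extra."""
    merged = dict(defaults)       # lowest precedence
    merged.update(user_supplied)  # explicitly set values override defaults
    merged.update(extra or {})    # values passed at call time win
    return merged


params = merge_param_maps(
    defaults={"batchSize": 8, "epochsN": 30, "learningRate": 0.001},  # made-up values
    user_supplied={"epochsN": 10},
    extra={"learningRate": 0.01},
)
# → {'batchSize': 8, 'epochsN': 10, 'learningRate': 0.01}
```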

fit(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = ...) M
fit(dataset: pyspark.sql.dataframe.DataFrame, params: List[pyspark.ml._typing.ParamMap] | Tuple[pyspark.ml._typing.ParamMap]) List[M]

Fits a model to the input dataset with optional parameters.

New in version 1.3.0.

Parameters:
  • dataset (pyspark.sql.DataFrame) – input dataset.

  • params (dict or list or tuple, optional) – an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.

Returns:

fitted model(s)

Return type:

Transformer or a list of Transformer

fitMultiple(dataset: pyspark.sql.dataframe.DataFrame, paramMaps: Sequence[pyspark.ml._typing.ParamMap]) Iterator[Tuple[int, M]]

Fits a model to the input dataset for each param map in paramMaps.

New in version 2.3.0.

Parameters:
  • dataset (pyspark.sql.DataFrame) – input dataset.

  • paramMaps (collections.abc.Sequence) – A Sequence of param maps.

Returns:

A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.

Return type:

_FitMultipleIterator
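`fitMultiple` may yield indices out of order. A caller that needs the models aligned with `paramMaps` can therefore place each result back by its index. The helper below is a hypothetical sketch, and a list of `(index, value)` tuples stands in for the real iterator.

```python
from typing import Any, Iterable, List, Tuple


def collect_in_order(results: Iterable[Tuple[int, Any]], n: int) -> List[Any]:
    """Place (index, model) pairs back into param-map order (hypothetical helper)."""
    ordered: List[Any] = [None] * n
    for index, model in results:
        # Indices may arrive in any order, so store each model by its index
        ordered[index] = model
    return ordered


# Stand-in for the iterator returned by fitMultiple (strings instead of models)
fake_iterator = iter([(2, "model-c"), (0, "model-a"), (1, "model-b")])
models = collect_in_order(fake_iterator, n=3)
# → ['model-a', 'model-b', 'model-c']
```

With a real estimator, the same helper could be called as `collect_in_order(approach.fitMultiple(df, param_maps), len(param_maps))`, where `approach`, `df`, and `param_maps` are placeholders.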

getEngine()
Returns:

Deep learning engine used for this model.

Return type:

str

getInputCols()

Gets current column names of input annotations.

getLazyAnnotator()

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param: str) Any
getOrDefault(param: Param[T]) T

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()

Gets output column name of annotations.

getParam(paramName: str) Param

Gets a param by its name.

getParamValue(paramName)

Gets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

hasDefault(param: str | Param[Any]) bool

Checks whether a param has a default value.

hasParam(paramName: str) bool

Tests whether this instance contains a param with a given (string) name.

inputColsValidation(value)
isDefined(param: str | Param[Any]) bool

Checks whether a param is explicitly set by user or has a default value.

isSet(param: str | Param[Any]) bool

Checks whether a param is explicitly set by user.

classmethod load(path: str) RL

Reads an ML instance from the input path, a shortcut of read().load(path).

classmethod read()

Returns an MLReader instance for this class.

save(path: str) None

Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param: Param, value: Any) None

Sets a parameter in the embedded param map.

setBatchSize(size: int)

Sets the size of each batch in the optimization process.

Parameters:

size (int) – Size for each batch in the optimization process

setDatasetInfo(info: str)

Sets descriptive information about the dataset being used.

Parameters:

info (str) – Descriptive information about the dataset being used.

setDoExceptionHandling(value: bool)

If True, exceptions are handled. If exception-causing data is passed to the model, an error annotation containing the exception message is emitted, and processing continues with the next input. This comes with a performance penalty.

Parameters:

value (bool) – If True, exceptions are handled.
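The behavior described above (emit an error annotation, then continue with the next input) can be sketched as a try/except around each row. The record layout, the `annotate_all` helper, and the failing stand-in annotator are all hypothetical. The actual error annotation produced by Spark NLP may carry different fields.

```python
from typing import Callable, Dict, List


def annotate_all(
    rows: List[str],
    annotate: Callable[[str], str],
    handle_exceptions: bool,
) -> List[Dict[str, str]]:
    """Hypothetical illustration of the doExceptionHandling behavior."""
    output: List[Dict[str, str]] = []
    for row in rows:
        try:
            output.append({"result": annotate(row)})
        except Exception as exc:
            if not handle_exceptions:
                raise  # without handling, the first bad row stops processing
            # Emit an error record carrying the exception message, then continue
            output.append({"result": "error", "message": str(exc)})
    return output


def fake_classifier(text: str) -> str:
    """Stand-in annotator that fails on empty input."""
    if not text:
        raise ValueError("empty input")
    return "positive"


results = annotate_all(["good", "", "fine"], fake_classifier, handle_exceptions=True)
# → [{'result': 'positive'}, {'result': 'error', 'message': 'empty input'}, {'result': 'positive'}]
```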

setDropout(dropout: float)

Sets the dropout rate.

Parameters:

dropout (float) – Dropout at the output of each layer
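This page does not state how dropout is implemented internally. For orientation, the sketch below shows standard inverted dropout. During training it zeroes each activation with probability `rate` and rescales the survivors. Treat it as a generic illustration rather than the library's code; `apply_dropout` is a hypothetical name.

```python
import random
from typing import List


def apply_dropout(activations: List[float], rate: float, rng: random.Random) -> List[float]:
    """Standard inverted dropout (generic illustration, hypothetical helper).

    Each activation is zeroed with probability `rate`; survivors are scaled by
    1 / (1 - rate) so the expected value of every unit is unchanged.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    keep = 1.0 - rate
    # rng.random() >= rate holds with probability 1 - rate, i.e. the unit is kept
    return [a / keep if rng.random() >= rate else 0.0 for a in activations]


dropped = apply_dropout([0.5, 1.0, 1.5, 2.0], rate=0.25, rng=random.Random(7))
```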

setEpochsNumber(epochs: int)

Sets the number of epochs for the optimization process.

Parameters:

epochs (int) – Number of epochs for the optimization process

setFeatureScaling(feature_scaling: str)

Sets the feature scaling method. Possible values are ‘zscore’, ‘minmax’ or empty (no scaling).

Parameters:

feature_scaling (str) – Feature scaling method. Possible values are ‘zscore’, ‘minmax’ or empty (no scaling).
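Below is a minimal sketch of the two documented scaling options. It assumes both are applied per embedding dimension (column) across the training vectors, and that z-score uses the population standard deviation. This page confirms neither assumption, and `scale_features` is a hypothetical helper.

```python
import statistics
from typing import List


def scale_features(vectors: List[List[float]], method: str) -> List[List[float]]:
    """Scale each embedding dimension across all vectors (hypothetical helper).

    method: 'zscore', 'minmax', or '' (no scaling), mirroring the documented options.
    """
    if method == "":
        return [list(v) for v in vectors]
    scaled_columns = []
    for column in zip(*vectors):  # iterate over dimensions, not rows
        if method == "zscore":
            mean = statistics.fmean(column)
            std = statistics.pstdev(column)
            scaled_columns.append([(x - mean) / std if std else 0.0 for x in column])
        elif method == "minmax":
            low, high = min(column), max(column)
            span = high - low
            scaled_columns.append([(x - low) / span if span else 0.0 for x in column])
        else:
            raise ValueError(f"unknown scaling method: {method!r}")
    # Transpose the scaled columns back into row vectors
    return [list(row) for row in zip(*scaled_columns)]


data = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
minmax = scale_features(data, "minmax")  # → [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]
zscored = scale_features(data, "zscore")  # each column now has mean 0 and std 1
```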

setFixImbalance(fix_imbalance: bool)

Sets a flag indicating whether to balance the training set.

Parameters:

fix_imbalance (bool) – A flag indicating whether to balance the training set.
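This page does not say how the training set is balanced. One common approach is random oversampling of minority classes, sketched below purely as an illustration. The helper name and the data are hypothetical.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, List, Tuple


def oversample(examples: List[Tuple[str, str]], rng: random.Random) -> List[Tuple[str, str]]:
    """Balance (text, label) pairs by randomly duplicating minority-class rows."""
    by_label: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for example in examples:
        by_label[example[1]].append(example)
    target = max(len(group) for group in by_label.values())
    balanced: List[Tuple[str, str]] = []
    for group in by_label.values():
        balanced.extend(group)
        # Draw random duplicates until this class is as large as the biggest one
        balanced.extend(rng.choice(group) for _ in range(target - len(group)))
    return balanced


train = [("chest pain", "cardio"), ("palpitations", "cardio"),
         ("arrhythmia", "cardio"), ("tumor mass", "onco")]
balanced = oversample(train, random.Random(0))
counts = Counter(label for _, label in balanced)
# counts: 3 "cardio" rows and 3 "onco" rows
```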

setForceInputTypeValidation(etfm)
setInputCols(*value)

Sets column names of input annotations.

Parameters:

*value (List[str]) – Input columns for the annotator

setLabelCol(label_column: str)

Sets the name of the column that contains the labels to predict.

Parameters:

label_column (str) – Name of the column containing the values to predict.

setLazyAnnotator(value)

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

Parameters:

value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline

setLearningRate(learning_rate: float)

Sets the learning rate for the optimization process.

Parameters:

learning_rate (float) – Learning rate for the optimization process
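Batch size, number of epochs, and learning rate interact as in any mini-batch gradient descent loop. The sketch below trains a tiny binary logistic regression in plain Python to show that interplay. It is a generic illustration, not the optimizer used by FewShotClassifierApproach, and `train_logreg` is a hypothetical name.

```python
import math
import random
from typing import List, Tuple


def train_logreg(
    data: List[Tuple[List[float], int]],
    epochs: int,
    batch_size: int,
    learning_rate: float,
    seed: int = 0,
) -> Tuple[List[float], float]:
    """Mini-batch gradient descent for binary logistic regression (illustrative).

    epochs: full passes over the data; batch_size: examples per weight update;
    learning_rate: step size of each update.
    """
    rng = random.Random(seed)
    dim = len(data[0][0])
    weights, bias = [0.0] * dim, 0.0
    for _ in range(epochs):
        shuffled = list(data)
        rng.shuffle(shuffled)
        for start in range(0, len(shuffled), batch_size):
            batch = shuffled[start:start + batch_size]
            grad_w, grad_b = [0.0] * dim, 0.0
            for x, y in batch:
                z = sum(w * xi for w, xi in zip(weights, x)) + bias
                error = 1.0 / (1.0 + math.exp(-z)) - y  # sigmoid(z) - label
                grad_w = [g + error * xi for g, xi in zip(grad_w, x)]
                grad_b += error
            # Average the gradient over the batch, then take one step
            weights = [w - learning_rate * g / len(batch) for w, g in zip(weights, grad_w)]
            bias -= learning_rate * grad_b / len(batch)
    return weights, bias


toy = [([-2.0], 0), ([-1.0], 0), ([1.0], 1), ([2.0], 1)]
w, b = train_logreg(toy, epochs=200, batch_size=2, learning_rate=0.5)
```

A smaller batch size means more updates per epoch, and the learning rate scales each of those updates.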

setModelFile(mode_file: str)

Sets the file name to load the model from.

Parameters:

mode_file (str) – File name to load the model from.

setMultiClass(value: bool)

Sets the model in multi-class prediction mode (Default: false).

Parameters:

value (bool) – Whether to return only the label with the highest confidence score or all labels
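The difference between the two modes can be sketched as a selection over per-label confidences. Ranking the labels by confidence in multi-class mode is an assumption made for this example, and `select_labels` is hypothetical.

```python
from typing import Dict, List, Tuple


def select_labels(confidences: Dict[str, float], multi_class: bool) -> List[Tuple[str, float]]:
    """Return every label ranked by confidence, or only the top one (hypothetical helper)."""
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return ranked if multi_class else ranked[:1]


scores = {"cardiology": 0.15, "oncology": 0.70, "neurology": 0.15}
single = select_labels(scores, multi_class=False)  # → [('oncology', 0.7)]
everything = select_labels(scores, multi_class=True)  # oncology first, then the 0.15 labels
```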

setOutputCol(value)

Sets output column name of annotations.

Parameters:

value (str) – Name of output column

setOutputLogsPath(output_logs_path: str)

Sets the path to the folder where logs are saved. If no path is specified, no logs are generated.

Parameters:

output_logs_path (str) – Path to the folder where logs are saved. If no path is specified, no logs are generated.

setParamValue(paramName)

Sets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

setValidationSplit(validation_split: float)

Sets the validation split, i.e. the proportion of the data to use for validation.

Parameters:

validation_split (float) – Proportion of the data to use for validation.
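Assume `validation_split` is the fraction of rows held out (for example, 0.2 for 20%). The split can then be sketched as a seeded shuffle followed by a slice. This page does not document how the library actually selects validation rows, and `split_for_validation` is a hypothetical helper.

```python
import random
from typing import List, Tuple, TypeVar

T = TypeVar("T")


def split_for_validation(
    rows: List[T], validation_split: float, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Return (train_rows, validation_rows) after a seeded shuffle (hypothetical helper)."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be in [0, 1)")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_validation = int(len(shuffled) * validation_split)
    return shuffled[n_validation:], shuffled[:n_validation]


train_rows, validation_rows = split_for_validation(list(range(100)), validation_split=0.2)
# len(train_rows) == 80, len(validation_rows) == 20
```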

write() JavaMLWriter

Returns an MLWriter instance for this ML instance.

class FewShotClassifierModel(classname='com.johnsnowlabs.nlp.annotators.classification.FewShotClassifierModel', java_model=None)

Bases: sparknlp_jsl.annotator.classification.generic_log_reg_classifier.GenericLogRegClassifierModel

FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings and the output is category annotations with labels and corresponding confidence scores between 0 and 1.

Input Annotation types: SENTENCE_EMBEDDINGS

Output Annotation type: CATEGORY

classes
datasetInfo
doExceptionHandling
featureScaling
getter_attrs = []
inputAnnotatorTypes
inputCols
lazyAnnotator
multiClass
name = 'FewShotClassifierModel'
optionalInputAnnotatorTypes = []
outputAnnotatorType = 'category'
outputCol
skipLPInputColsValidation = True
uid = ''
clear(param: pyspark.ml.param.Param) None

Clears a param from the param map if it has been explicitly set.

copy(extra: pyspark.ml._typing.ParamMap | None = None) JP

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters:

extra (dict, optional) – Extra parameters to copy to the new instance

Returns:

Copy of this instance

Return type:

JavaParams

explainParam(param: str | Param) str

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() str

Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) pyspark.ml._typing.ParamMap

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters:

extra (dict, optional) – extra param values

Returns:

merged param map

Return type:

dict

getInputCols()

Gets current column names of input annotations.

getLazyAnnotator()

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param: str) Any
getOrDefault(param: Param[T]) T

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()

Gets output column name of annotations.

getParam(paramName: str) Param

Gets a param by its name.

getParamValue(paramName)

Gets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

hasDefault(param: str | Param[Any]) bool

Checks whether a param has a default value.

hasParam(paramName: str) bool

Tests whether this instance contains a param with a given (string) name.

inputColsValidation(value)
isDefined(param: str | Param[Any]) bool

Checks whether a param is explicitly set by user or has a default value.

isSet(param: str | Param[Any]) bool

Checks whether a param is explicitly set by user.

classmethod load(path: str) RL

Reads an ML instance from the input path, a shortcut of read().load(path).

static pretrained(name, lang='en', remote_loc='clinical/models')

Downloads and loads a pretrained model.

Parameters:
  • name (str) – Name of the pretrained model

  • lang (str, optional) – Language of the pretrained model, by default “en”

  • remote_loc (str, optional) – Optional remote address of the resource, by default “clinical/models”. Will use Spark NLP's repositories otherwise.

Returns:

The restored model

Return type:

FewShotClassifierModel

classmethod read()

Returns an MLReader instance for this class.

save(path: str) None

Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param: Param, value: Any) None

Sets a parameter in the embedded param map.

setDatasetInfo(info: str)

Sets descriptive information about the dataset being used.

Parameters:

info (str) – Descriptive information about the dataset being used.

setDoExceptionHandling(value: bool)

If True, exceptions are handled. If exception-causing data is passed to the model, an error annotation containing the exception message is emitted, and processing continues with the next input. This comes with a performance penalty.

Parameters:

value (bool) – If True, exceptions are handled.

setFeatureScaling(feature_scaling: str)

Sets the feature scaling method. Possible values are ‘zscore’, ‘minmax’ or empty (no scaling).

Parameters:

feature_scaling (str) – Feature scaling method. Possible values are ‘zscore’, ‘minmax’ or empty (no scaling)

setForceInputTypeValidation(etfm)
setInputCols(*value)

Sets column names of input annotations.

Parameters:

*value (List[str]) – Input columns for the annotator

setLazyAnnotator(value)

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

Parameters:

value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline

setMultiClass(value: bool)

Sets the model in multi-class prediction mode (Default: false).

Parameters:

value (bool) – Whether to return only the label with the highest confidence score or all labels

setOutputCol(value)

Sets output column name of annotations.

Parameters:

value (str) – Name of output column

setParamValue(paramName)

Sets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

setParams()
transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) pyspark.sql.dataframe.DataFrame

Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters:
  • dataset (pyspark.sql.DataFrame) – input dataset

  • params (dict, optional) – an optional param map that overrides embedded params.

Returns:

transformed dataset

Return type:

pyspark.sql.DataFrame

write() JavaMLWriter

Returns an MLWriter instance for this ML instance.