sparknlp_jsl.annotator.classification.few_shot_classifier#
Module Contents#
Classes#
- FewShotClassifierApproach – FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings.
- FewShotClassifierModel – FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings.
- class FewShotClassifierApproach(classname='com.johnsnowlabs.nlp.annotators.classification.FewShotClassifierApproach')#
Bases:
sparknlp_jsl.annotator.classification.generic_log_reg_classifier.GenericLogRegClassifierApproach
FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings and the output is category annotations with labels and corresponding confidence scores between 0 and 1.
Input Annotation types: SENTENCE_EMBEDDINGS
Output Annotation type: CATEGORY
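A minimal training sketch, assuming a licensed Spark NLP for Healthcare session, a DataFrame training_data with text and label columns, and an illustrative embedding model name (any annotator that emits SENTENCE_EMBEDDINGS can feed the classifier):

```python
from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BertSentenceEmbeddings
from sparknlp_jsl.annotator import FewShotClassifierApproach

document_assembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

# Illustrative embedding model; any SENTENCE_EMBEDDINGS producer works here.
embeddings = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128", "en") \
    .setInputCols(["document"]) \
    .setOutputCol("sentence_embeddings")

few_shot = FewShotClassifierApproach() \
    .setInputCols(["sentence_embeddings"]) \
    .setOutputCol("prediction") \
    .setLabelCol("label") \
    .setBatchSize(8) \
    .setEpochsNumber(10) \
    .setLearningRate(0.001)

pipeline = Pipeline(stages=[document_assembler, embeddings, few_shot])
model = pipeline.fit(training_data)  # training_data: DataFrame with text/label columns
```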
- batchSize#
- datasetInfo#
- doExceptionHandling#
- dropout#
- engine#
- epochsN#
- featureScaling#
- fixImbalance#
- getter_attrs = []#
- inputAnnotatorTypes#
- inputCols#
- labelColumn#
- lazyAnnotator#
- learningRate#
- modelFile#
- multiClass#
- name = 'FewShotClassifierApproach'#
- optionalInputAnnotatorTypes = []#
- outputAnnotatorType = 'category'#
- outputCol#
- outputLogsPath#
- skipLPInputColsValidation = True#
- uid = ''#
- validationSplit#
- clear(param: pyspark.ml.param.Param) None #
Clears a param from the param map if it has been explicitly set.
- copy(extra: pyspark.ml._typing.ParamMap | None = None) JP #
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
- Parameters:
extra (dict, optional) – Extra parameters to copy to the new instance
- Returns:
Copy of this instance
- Return type:
JavaParams
- explainParam(param: str | Param) str #
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
- explainParams() str #
Returns the documentation of all params with their optionally default values and user-supplied values.
- extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) pyspark.ml._typing.ParamMap #
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
- Parameters:
extra (dict, optional) – extra param values
- Returns:
merged param map
- Return type:
dict
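A hedged sketch of the precedence ordering, using the batchSize param listed above: an extra value overrides a user-supplied one, which overrides the default.

```python
# Assumes `few_shot` is a FewShotClassifierApproach instance.
few_shot.setBatchSize(16)                                # user-supplied value
pm = few_shot.extractParamMap({few_shot.batchSize: 32})  # extra value wins
assert pm[few_shot.batchSize] == 32
```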
- fit(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = ...) M #
- fit(dataset: pyspark.sql.dataframe.DataFrame, params: List[pyspark.ml._typing.ParamMap] | Tuple[pyspark.ml._typing.ParamMap]) List[M]
Fits a model to the input dataset with optional parameters.
New in version 1.3.0.
- Parameters:
dataset (pyspark.sql.DataFrame) – input dataset.
params (dict or list or tuple, optional) – an optional param map that overrides embedded params. If a list/tuple of param maps is given, this calls fit on each param map and returns a list of models.
- Returns:
fitted model(s)
- Return type:
Transformer or a list of Transformer
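A hedged sketch of both call forms (param names taken from the attribute list above; train_df is assumed to already hold the required SENTENCE_EMBEDDINGS column):

```python
# Single param map: returns one fitted model with epochsN overridden.
model = few_shot.fit(train_df, {few_shot.epochsN: 20})

# List of param maps: returns one fitted model per map.
models = few_shot.fit(train_df, [{few_shot.epochsN: 10}, {few_shot.epochsN: 30}])
```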
- fitMultiple(dataset: pyspark.sql.dataframe.DataFrame, paramMaps: Sequence[pyspark.ml._typing.ParamMap]) Iterator[Tuple[int, M]] #
Fits a model to the input dataset for each param map in paramMaps.
New in version 2.3.0.
- Parameters:
dataset (pyspark.sql.DataFrame) – input dataset.
paramMaps (collections.abc.Sequence) – A Sequence of param maps.
- Returns:
A thread safe iterable which contains one model for each param map. Each call to next(modelIterator) will return (index, model) where model was fit using paramMaps[index]. index values may not be sequential.
- Return type:
_FitMultipleIterator
- getEngine()#
- Returns:
Deep Learning engine used for this model.
- Return type:
str
- getInputCols()#
Gets current column names of input annotations.
- getLazyAnnotator()#
Gets whether Annotator should be evaluated lazily in a RecursivePipeline.
- getOrDefault(param: str) Any #
- getOrDefault(param: Param[T]) T
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
- getOutputCol()#
Gets output column name of annotations.
- getParam(paramName: str) Param #
Gets a param by its name.
- getParamValue(paramName)#
Gets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- hasDefault(param: str | Param[Any]) bool #
Checks whether a param has a default value.
- hasParam(paramName: str) bool #
Tests whether this instance contains a param with a given (string) name.
- inputColsValidation(value)#
- isDefined(param: str | Param[Any]) bool #
Checks whether a param is explicitly set by user or has a default value.
- isSet(param: str | Param[Any]) bool #
Checks whether a param is explicitly set by user.
- classmethod load(path: str) RL #
Reads an ML instance from the input path, a shortcut of read().load(path).
- classmethod read()#
Returns an MLReader instance for this class.
- save(path: str) None #
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
- set(param: Param, value: Any) None #
Sets a parameter in the embedded param map.
- setBatchSize(size: int)#
Sets the size of each batch in the optimization process.
- Parameters:
size (int) – Size for each batch in the optimization process
- setDatasetInfo(info: str)#
Sets descriptive information about the dataset being used.
- Parameters:
info (str) – Descriptive information about the dataset being used.
- setDoExceptionHandling(value: bool)#
If True, exceptions are handled: when data that causes an exception is passed to the model, an error annotation containing the exception message is emitted and processing continues with the next record. This comes with a performance penalty.
- Parameters:
value (bool) – If True, exceptions are handled.
- setDropout(dropout: float)#
Sets the dropout applied at the output of each layer.
- Parameters:
dropout (float) – Dropout at the output of each layer
- setEpochsNumber(epochs: int)#
Sets number of epochs for the optimization process
- Parameters:
epochs (int) – Number of epochs for the optimization process
- setFeatureScaling(feature_scaling: str)#
Sets the feature scaling method. Possible values are ‘zscore’, ‘minmax’, or empty (no scaling).
- Parameters:
feature_scaling (str) – Feature scaling method. Possible values are ‘zscore’, ‘minmax’, or empty (no scaling).
- setFixImbalance(fix_imbalance: bool)#
Sets whether to balance the training set.
- Parameters:
fix_imbalance (bool) – Whether to balance the training set.
- setForceInputTypeValidation(etfm)#
- setInputCols(*value)#
Sets column names of input annotations.
- Parameters:
*value (List[str]) – Input columns for the annotator
- setLabelCol(label_column: str)#
Sets the name of the column containing the labels to predict.
- Parameters:
label_column (str) – Column with the values we are trying to predict.
- setLazyAnnotator(value)#
Sets whether Annotator should be evaluated lazily in a RecursivePipeline.
- Parameters:
value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline
- setLearningRate(learning_rate: float)#
Sets learning rate for the optimization process
- Parameters:
learning_rate (float) – Learning rate for the optimization process
- setModelFile(mode_file: str)#
Sets the file name to load the model from.
- Parameters:
mode_file (str) – File name to load the model from.
- setMultiClass(value: bool)#
Sets whether the model runs in multi-class prediction mode (default: False).
- Parameters:
value (bool) – Whether to return only the label with the highest confidence score or all labels
- setOutputCol(value)#
Sets output column name of annotations.
- Parameters:
value (str) – Name of output column
- setOutputLogsPath(output_logs_path: str)#
Sets the path to the folder where logs will be saved. If no path is specified, no logs are generated.
- Parameters:
output_logs_path (str) – Path to the folder where logs will be saved. If no path is specified, no logs are generated.
- setParamValue(paramName)#
Sets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- setValidationSplit(validation_split: float)#
Sets the validation split: the fraction of the data to use for validation.
- Parameters:
validation_split (float) – Fraction of the data to use for validation.
- write() JavaMLWriter #
Returns an MLWriter instance for this ML instance.
- class FewShotClassifierModel(classname='com.johnsnowlabs.nlp.annotators.classification.FewShotClassifierModel', java_model=None)#
Bases:
sparknlp_jsl.annotator.classification.generic_log_reg_classifier.GenericLogRegClassifierModel
FewShotClassifier is an implementation of SetFit (Tunstall et al., 2022). The input to the model is sentence embeddings and the output is category annotations with labels and corresponding confidence scores between 0 and 1.
Input Annotation types: SENTENCE_EMBEDDINGS
Output Annotation type: CATEGORY
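A minimal inference sketch (the pretrained model name is illustrative, not a guaranteed Models Hub entry; embeddings_df is assumed to already contain a sentence_embeddings column of SENTENCE_EMBEDDINGS annotations):

```python
from sparknlp_jsl.annotator import FewShotClassifierModel

# Hypothetical model name; check the Models Hub for available entries.
classifier = FewShotClassifierModel.pretrained("few_shot_classifier", "en", "clinical/models") \
    .setInputCols(["sentence_embeddings"]) \
    .setOutputCol("prediction")

result = classifier.transform(embeddings_df)
result.selectExpr("explode(prediction) as pred") \
      .select("pred.result", "pred.metadata") \
      .show(truncate=False)
```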
- classes#
- datasetInfo#
- doExceptionHandling#
- featureScaling#
- getter_attrs = []#
- inputAnnotatorTypes#
- inputCols#
- lazyAnnotator#
- multiClass#
- name = 'FewShotClassifierModel'#
- optionalInputAnnotatorTypes = []#
- outputAnnotatorType = 'category'#
- outputCol#
- skipLPInputColsValidation = True#
- uid = ''#
- clear(param: pyspark.ml.param.Param) None #
Clears a param from the param map if it has been explicitly set.
- copy(extra: pyspark.ml._typing.ParamMap | None = None) JP #
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
- Parameters:
extra (dict, optional) – Extra parameters to copy to the new instance
- Returns:
Copy of this instance
- Return type:
JavaParams
- explainParam(param: str | Param) str #
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
- explainParams() str #
Returns the documentation of all params with their optionally default values and user-supplied values.
- extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) pyspark.ml._typing.ParamMap #
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
- Parameters:
extra (dict, optional) – extra param values
- Returns:
merged param map
- Return type:
dict
- getInputCols()#
Gets current column names of input annotations.
- getLazyAnnotator()#
Gets whether Annotator should be evaluated lazily in a RecursivePipeline.
- getOrDefault(param: str) Any #
- getOrDefault(param: Param[T]) T
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
- getOutputCol()#
Gets output column name of annotations.
- getParam(paramName: str) Param #
Gets a param by its name.
- getParamValue(paramName)#
Gets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- hasDefault(param: str | Param[Any]) bool #
Checks whether a param has a default value.
- hasParam(paramName: str) bool #
Tests whether this instance contains a param with a given (string) name.
- inputColsValidation(value)#
- isDefined(param: str | Param[Any]) bool #
Checks whether a param is explicitly set by user or has a default value.
- isSet(param: str | Param[Any]) bool #
Checks whether a param is explicitly set by user.
- classmethod load(path: str) RL #
Reads an ML instance from the input path, a shortcut of read().load(path).
- static pretrained(name, lang='en', remote_loc='clinical/models')#
Downloads and loads a pretrained model.
- Parameters:
name (str, optional) – Name of the pretrained model
lang (str, optional) – Language of the pretrained model, by default “en”
remote_loc (str, optional) – Optional remote address of the resource, by default “clinical/models”. Spark NLP’s repositories are used otherwise.
- Returns:
The restored model
- Return type:
FewShotClassifierModel
- classmethod read()#
Returns an MLReader instance for this class.
- save(path: str) None #
Save this ML instance to the given path, a shortcut of ‘write().save(path)’.
- set(param: Param, value: Any) None #
Sets a parameter in the embedded param map.
- setDatasetInfo(info: str)#
Sets descriptive information about the dataset being used.
- Parameters:
info (str) – Descriptive information about the dataset being used.
- setDoExceptionHandling(value: bool)#
If True, exceptions are handled: when data that causes an exception is passed to the model, an error annotation containing the exception message is emitted and processing continues with the next record. This comes with a performance penalty.
- Parameters:
value (bool) – If True, exceptions are handled.
- setFeatureScaling(feature_scaling: str)#
Sets the feature scaling method. Possible values are ‘zscore’, ‘minmax’, or empty (no scaling).
- Parameters:
feature_scaling (str) – Feature scaling method. Possible values are ‘zscore’, ‘minmax’ or empty (no scaling)
- setForceInputTypeValidation(etfm)#
- setInputCols(*value)#
Sets column names of input annotations.
- Parameters:
*value (List[str]) – Input columns for the annotator
- setLazyAnnotator(value)#
Sets whether Annotator should be evaluated lazily in a RecursivePipeline.
- Parameters:
value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline
- setMultiClass(value: bool)#
Sets whether the model runs in multi-class prediction mode (default: False).
- Parameters:
value (bool) – Whether to return only the label with the highest confidence score or all labels
- setOutputCol(value)#
Sets output column name of annotations.
- Parameters:
value (str) – Name of output column
- setParamValue(paramName)#
Sets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- setParams()#
- transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) pyspark.sql.dataframe.DataFrame #
Transforms the input dataset with optional parameters.
New in version 1.3.0.
- Parameters:
dataset (pyspark.sql.DataFrame) – input dataset.
params (dict, optional) – an optional param map that overrides embedded params.
- Returns:
transformed dataset
- Return type:
pyspark.sql.DataFrame
- write() JavaMLWriter #
Returns an MLWriter instance for this ML instance.