sparknlp.annotator.ContextSpellCheckerModel

class sparknlp.annotator.ContextSpellCheckerModel(classname='com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel', java_model=None)[source]

Bases: sparknlp.common.AnnotatorModel

Implements a deep-learning-based Noisy Channel Model Spell Algorithm. Correction candidates are extracted by combining context information and word information.

Spell checking is a sequence-to-sequence mapping problem. Given an input sequence, potentially containing a certain number of errors, ContextSpellChecker will rank correction sequences according to three criteria (a configuration sketch follows the list):

  1. Different correction candidates for each word — word level.

  2. The surrounding text of each word, i.e., its context — sentence level.

  3. The relative cost of different correction candidates, according to the character-level edit operations required to produce them — subword level.
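
Each of these levels maps to the model's parameters: setWordMaxDistance() and setMaxCandidates() shape the word-level candidates, setMaxWindowLen() the sentence-level context window, and setTradeoff() the subword-level balance between word-error cost and the language model. A minimal configuration sketch (the values are purely illustrative, not recommended settings):

>>> spellChecker = ContextSpellCheckerModel.pretrained() \
...     .setInputCols(["token"]) \
...     .setOutputCol("checked") \
...     .setWordMaxDistance(3) \
...     .setMaxCandidates(6) \
...     .setMaxWindowLen(5) \
...     .setTradeoff(12.0)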

This is the instantiated model of the ContextSpellCheckerApproach. For training your own model, please see the documentation of that class.

Pretrained models can be loaded with pretrained() of the companion object:

>>> spellChecker = ContextSpellCheckerModel.pretrained() \
...     .setInputCols(["token"]) \
...     .setOutputCol("checked")

The default model is "spellcheck_dl" if no name is provided. For available pretrained models please see the Models Hub.

For extended examples of usage, see the Spark NLP Workshop.

Input Annotation types: TOKEN

Output Annotation type: TOKEN

Parameters
wordMaxDistance

Maximum distance for the generated candidates for every word.

maxCandidates

Maximum number of candidates for every word.

caseStrategy

What case combinations to try when generating candidates.

errorThreshold

Threshold perplexity for a word to be considered as an error.

tradeoff

Tradeoff between the cost of a word error and a transition in the language model.

weightedDistPath

The path to the file containing the weights for the Levenshtein distance.

maxWindowLen

Maximum size for the window used to remember history prior to every correction.

gamma

Controls the influence of individual word frequency in the decision.

correctSymbols

Whether to correct special symbols or skip spell checking for them.

compareLowcase

If true, tokens are compared in lower case against the vocabulary.

configProtoBytes

ConfigProto from tensorflow, serialized into byte array.

References

For an in-depth explanation of the module, see the article Applying Context Aware Spell Checking in Spark NLP.

Examples

>>> import sparknlp
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from pyspark.ml import Pipeline
>>> documentAssembler = DocumentAssembler() \
...     .setInputCol("text") \
...     .setOutputCol("doc")
>>> tokenizer = Tokenizer() \
...     .setInputCols(["doc"]) \
...     .setOutputCol("token")
>>> spellChecker = ContextSpellCheckerModel \
...     .pretrained() \
...     .setTradeoff(12.0) \
...     .setInputCols("token") \
...     .setOutputCol("checked")
>>> pipeline = Pipeline().setStages([
...     documentAssembler,
...     tokenizer,
...     spellChecker
... ])
>>> data = spark.createDataFrame([["It was a cold , dreary day and the country was white with smow ."]]).toDF("text")
>>> result = pipeline.fit(data).transform(data)
>>> result.select("checked.result").show(truncate=False)
+--------------------------------------------------------------------------------+
|result                                                                          |
+--------------------------------------------------------------------------------+
|[It, was, a, cold, ,, dreary, day, and, the, country, was, white, with, snow, .]|
+--------------------------------------------------------------------------------+

Methods

__init__([classname, java_model])

Initialize this instance with a Java model object.

clear(param)

Clears a param from the param map if it has been explicitly set.

copy([extra])

Creates a copy of this instance with the same uid and some extra params.

explainParam(param)

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()

Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap([extra])

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

getInputCols()

Gets current column names of input annotations.

getLazyAnnotator()

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param)

Gets the value of a param in the user-supplied param map or its default value.

getOutputCol()

Gets output column name of annotations.

getParam(paramName)

Gets a param by its name.

getParamValue(paramName)

Gets the value of a parameter.

getWordClasses()

Gets the classes of words to be corrected.

hasDefault(param)

Checks whether a param has a default value.

hasParam(paramName)

Tests whether this instance contains a param with a given (string) name.

isDefined(param)

Checks whether a param is explicitly set by user or has a default value.

isSet(param)

Checks whether a param is explicitly set by user.

load(path)

Reads an ML instance from the input path, a shortcut of read().load(path).

pretrained([name, lang, remote_loc])

Downloads and loads a pretrained model.

read()

Returns an MLReader instance for this class.

save(path)

Save this ML instance to the given path, a shortcut of 'write().save(path)'.

set(param, value)

Sets a parameter in the embedded param map.

setCaseStrategy(strategy)

Sets what case combinations to try when generating candidates.

setCompareLowcase(value)

Sets whether to compare tokens in lower case with vocabulary.

setConfigProtoBytes(b)

Sets configProto from tensorflow, serialized into byte array.

setCorrectSymbols(value)

Sets whether to correct special symbols or skip spell checking for them.

setErrorThreshold(threshold)

Sets threshold perplexity for a word to be considered as an error.

setGamma(g)

Sets the influence of individual word frequency in the decision.

setInputCols(*value)

Sets column names of input annotations.

setLazyAnnotator(value)

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

setMaxCandidates(candidates)

Sets maximum number of candidates for every word.

setMaxWindowLen(length)

Sets the maximum size for the window used to remember history prior to every correction.

setOutputCol(value)

Sets output column name of annotations.

setParamValue(paramName)

Sets the value of a parameter.

setParams()

setTradeoff(alpha)

Sets tradeoff between the cost of a word error and a transition in the language model.

setWeights(weights)

Sets weights of each word for Levenshtein distance.

setWordMaxDistance(dist)

Sets maximum distance for the generated candidates for every word.

transform(dataset[, params])

Transforms the input dataset with optional parameters.

updateRegexClass(label, regex)

Updates an existing class of words to correct, based on a regex.

updateVocabClass(label, vocab[, append])

Updates an existing class of words to correct, based on a vocabulary.

write()

Returns an MLWriter instance for this ML instance.

Attributes

caseStrategy

compareLowcase

configProtoBytes

correctSymbols

errorThreshold

gamma

getter_attrs

inputCols

lazyAnnotator

maxCandidates

maxWindowLen

name

outputCol

params

Returns all params ordered by name.

tradeoff

weightedDistPath

wordMaxDistance

clear(param)

Clears a param from the param map if it has been explicitly set.

copy(extra=None)

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.

Parameters

extra – Extra parameters to copy to the new instance

Returns

Copy of this instance

explainParam(param)

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams()

Returns the documentation of all params with their optionally default values and user-supplied values.

extractParamMap(extra=None)

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters

extra – extra param values

Returns

merged param map

getInputCols()

Gets current column names of input annotations.

getLazyAnnotator()

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getOrDefault(param)

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()

Gets output column name of annotations.

getParam(paramName)

Gets a param by its name.

getParamValue(paramName)

Gets the value of a parameter.

Parameters
paramName : str

Name of the parameter

getWordClasses()[source]

Gets the classes of words to be corrected.

Returns
List[str]

Classes of words to be corrected
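
A minimal usage sketch; the labels returned depend on the loaded model:

>>> spellChecker = ContextSpellCheckerModel.pretrained()
>>> wordClasses = spellChecker.getWordClasses()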

hasDefault(param)

Checks whether a param has a default value.

hasParam(paramName)

Tests whether this instance contains a param with a given (string) name.

isDefined(param)

Checks whether a param is explicitly set by user or has a default value.

isSet(param)

Checks whether a param is explicitly set by user.

classmethod load(path)

Reads an ML instance from the input path, a shortcut of read().load(path).

property params

Returns all params ordered by name. The default implementation uses dir() to get all attributes of type Param.

static pretrained(name='spellcheck_dl', lang='en', remote_loc=None)[source]

Downloads and loads a pretrained model.

Parameters
name : str, optional

Name of the pretrained model, by default “spellcheck_dl”

lang : str, optional

Language of the pretrained model, by default “en”

remote_loc : str, optional

Optional remote address of the resource, by default None. Will use Spark NLP's repositories otherwise.

Returns
ContextSpellCheckerModel

The restored model
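
For example, the default English model can be requested explicitly by name and language:

>>> spellChecker = ContextSpellCheckerModel.pretrained("spellcheck_dl", lang="en") \
...     .setInputCols(["token"]) \
...     .setOutputCol("checked")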

classmethod read()

Returns an MLReader instance for this class.

save(path)

Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param, value)

Sets a parameter in the embedded param map.

setCaseStrategy(strategy)[source]

Sets what case combinations to try when generating candidates.

Parameters
strategy : int

Case combinations to try when generating candidates.

setCompareLowcase(value)[source]

Sets whether to compare tokens in lower case with vocabulary.

Parameters
value : bool

Whether to compare tokens in lower case with vocabulary.

setConfigProtoBytes(b)[source]

Sets configProto from tensorflow, serialized into byte array.

Parameters
b : List[str]

ConfigProto from tensorflow, serialized into byte array

setCorrectSymbols(value)[source]

Sets whether to correct special symbols or skip spell checking for them.

Parameters
value : bool

Whether to correct special symbols or skip spell checking for them

setErrorThreshold(threshold)[source]

Sets threshold perplexity for a word to be considered as an error.

Parameters
threshold : float

Threshold perplexity for a word to be considered as an error

setGamma(g)[source]

Sets the influence of individual word frequency in the decision.

Parameters
g : float

Controls the influence of individual word frequency in the decision.

setInputCols(*value)

Sets column names of input annotations.

Parameters
*value : str

Input columns for the annotator

setLazyAnnotator(value)

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

Parameters
value : bool

Whether Annotator should be evaluated lazily in a RecursivePipeline

setMaxCandidates(candidates)[source]

Sets maximum number of candidates for every word.

Parameters
candidates : int

Maximum number of candidates for every word.

setMaxWindowLen(length)[source]

Sets the maximum size for the window used to remember history prior to every correction.

Parameters
length : int

Maximum size for the window used to remember history prior to every correction

setOutputCol(value)

Sets output column name of annotations.

Parameters
value : str

Name of output column

setParamValue(paramName)

Sets the value of a parameter.

Parameters
paramName : str

Name of the parameter

setTradeoff(alpha)[source]

Sets tradeoff between the cost of a word error and a transition in the language model.

Parameters
alpha : float

Tradeoff between the cost of a word error and a transition in the language model

setWeights(weights)[source]

Sets weights of each word for Levenshtein distance.

Parameters
weights : Dict[str, float]

Weights for the Levenshtein distance as a mapping.

setWordMaxDistance(dist)[source]

Sets maximum distance for the generated candidates for every word.

Parameters
dist : int

Maximum distance for the generated candidates for every word.

transform(dataset, params=None)

Transforms the input dataset with optional parameters.

Parameters
  • dataset – input dataset, which is an instance of pyspark.sql.DataFrame

  • params – an optional param map that overrides embedded params.

Returns

transformed dataset

New in version 1.3.0.

uid

A unique id for the object.

updateRegexClass(label, regex)[source]

Updates an existing class of words to correct, based on a regex.

Parameters
label : str

Label of the class

regex : str

Regex to parse the class
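
A hedged sketch: "_DATE_" is used as a hypothetical class label and the pattern is purely illustrative; the label must refer to a class already defined in the model (see getWordClasses()):

>>> spellChecker = ContextSpellCheckerModel.pretrained()
>>> spellChecker.updateRegexClass("_DATE_", "([0-9]{2})/([0-9]{2})/([0-9]{4})")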

updateVocabClass(label, vocab, append=True)[source]

Updates an existing class of words to correct, based on a vocabulary.

Parameters
label : str

Label of the class

vocab : List[str]

Vocabulary as a list

append : bool, optional

Whether to append to the existing vocabulary, by default True
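
A hedged sketch: "_NAME_" is used as a hypothetical class label and the names are purely illustrative; the label must refer to a class already defined in the model:

>>> spellChecker = ContextSpellCheckerModel.pretrained()
>>> spellChecker.updateVocabClass("_NAME_", ["Aurelio", "Xiomara"], append=True)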

write()

Returns an MLWriter instance for this ML instance.