sparknlp_jsl.llm.llm_loader#

Module Contents#

Classes#

LLMLoader

Base class for :py:class:`Model`s that wrap Java/Scala implementations.

class LLMLoader(spark)#

Bases: pyspark.ml.wrapper.JavaModel

Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before param mix-ins, because this sets the UID from the Java model.
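A minimal usage sketch, assuming an active SparkSession ``spark`` started with Spark NLP for Healthcare; the pretrained model name below is a placeholder, and ``generate`` is assumed to return the completion text:

>>> from sparknlp_jsl.llm.llm_loader import LLMLoader
>>> # Placeholder model name; substitute a model available in clinical/models
>>> llm_loader = LLMLoader(spark).pretrained("jsl_medm_q8_v1", "en", "clinical/models")
>>> answer = llm_loader.generate("Summarize the key symptoms of type 2 diabetes.")
>>> print(answer)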

uid = ''#
clear(param: pyspark.ml.param.Param) None#

Clears a param from the param map if it has been explicitly set.

copy(extra: pyspark.ml._typing.ParamMap | None = None) JP#

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters:

extra (dict, optional) – Extra parameters to copy to the new instance

Returns:

Copy of this instance

Return type:

JavaParams

encodeModel(model_path, output_model_path, metadata)#
explainParam(param: str | Param) str#

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() str#

Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) pyspark.ml._typing.ParamMap#

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters:

extra (dict, optional) – extra param values

Returns:

merged param map

Return type:

dict
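
An illustration of the precedence rule (default param values < user-supplied values < extra), shown with a standard PySpark transformer since the same Params machinery applies here:

>>> from pyspark.ml.feature import Tokenizer
>>> tok = Tokenizer(inputCol="text", outputCol="words")      # user-supplied outputCol
>>> merged = tok.extractParamMap({tok.outputCol: "tokens"})  # extra value wins
>>> merged[tok.outputCol]
'tokens'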

generate(prompt)#
getMetadata(param)#
getMetadataEntry(param)#
getOrDefault(param: str) Any#
getOrDefault(param: Param[T]) T

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getParam(paramName: str) Param#

Gets a param by its name.

hasDefault(param: str | Param[Any]) bool#

Checks whether a param has a default value.

hasParam(paramName: str) bool#

Tests whether this instance contains a param with a given (string) name.

isDefined(param: str | Param[Any]) bool#

Checks whether a param is explicitly set by user or has a default value.

isSet(param: str | Param[Any]) bool#

Checks whether a param is explicitly set by user.

load(model_path, n_gpu_layers=0)#
loadGGUF(model_path, system_prompt='', n_gpu_layers=0)#
pretrained(name, lang='en', remote_loc='clinical/models')#
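
A sketch of loading a local GGUF file instead of a pretrained model, assuming ``loadGGUF`` configures the loader in place; the file path and layer count are illustrative:

>>> llm_loader = LLMLoader(spark)
>>> # Placeholder path to a llama.cpp-compatible GGUF file; offload 32 layers to GPU
>>> llm_loader.loadGGUF("/models/clinical-llm-q4_k_m.gguf",
...                     system_prompt="You are a helpful clinical assistant.",
...                     n_gpu_layers=32)
>>> print(llm_loader.generate("List common contraindications for metformin."))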
set(param: Param, value: Any) None#

Sets a parameter in the embedded param map.

setCachePrompt(cachePrompt)#

Set whether to remember the prompt to avoid reprocessing it

setDynamicTemperatureExponent(dynatempExponent)#

Set the dynamic temperature exponent (default: 1.0)

setDynamicTemperatureRange(dynatempRange)#

Set the dynamic temperature range (default: 0.0, 0.0 = disabled)

setFrequencyPenalty(frequencyPenalty)#

Set the repetition alpha frequency penalty (default: 0.0, 0.0 = disabled)

setGrammar(grammar)#

Set BNF-like grammar to constrain generations (see samples in grammars/ dir)
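
For illustration, assuming the grammar string follows llama.cpp's GBNF conventions, a constraint limiting output to a yes/no answer could look like:

>>> # Hypothetical GBNF grammar restricting generations to "yes" or "no"
>>> llm_loader.setGrammar('root ::= ("yes" | "no")')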

setIgnoreEos(ignoreEos)#

Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)

setInputPrefix(inputPrefix)#

Set a prefix for infilling (default: empty)

setInputSuffix(inputSuffix)#

Set a suffix for infilling (default: empty)

setMinKeep(minKeep)#

Set the minimum number of tokens the samplers should return (0 = disabled)

setMinP(minP)#

Set min-p sampling (default: 0.1, 0.0 = disabled)

setMiroStatEta(mirostatEta)#

Set the MiroStat learning rate, parameter eta (default: 0.1)

setMiroStatTau(mirostatTau)#

Set the MiroStat target entropy, parameter tau (default: 5.0)

setNKeep(nKeep)#

Set the number of tokens to keep from the initial prompt (default: 0, -1 = all)

setNPredict(nPredict)#

Set the number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)

setNProbs(nProbs)#

Set the number of top token probabilities to output, if greater than 0.

setPenalizeNl(penalizeNl)#

Set whether to penalize newline tokens

setPenaltyPrompt(penaltyPrompt)#

Override which part of the prompt is penalized for repetition. For example, if the original prompt is "Alice: Hello!" and penaltyPrompt is "Hello!", only the latter is penalized if repeated. See ggerganov/llama.cpp#3727 for more details.

setPresencePenalty(presencePenalty)#

Set the repetition alpha presence penalty (default: 0.0, 0.0 = disabled)

setRepeatLastN(repeatLastN)#

Set the last n tokens to consider for penalties (default: 64, 0 = disabled, -1 = ctx_size)

setRepeatPenalty(repeatPenalty)#

Set the penalty of repeated sequences of tokens (default: 1.0, 1.0 = disabled)

setSeed(seed)#

Set the RNG seed (default: -1, use random seed for < 0)

setStopStrings(stopStrings)#

Set the strings that stop token generation when encountered

setTemperature(temperature)#

Set the temperature (default: 0.8)

setTfsZ(tfsZ)#

Set tail free sampling, parameter z (default: 1.0, 1.0 = disabled)

setTopK(topK)#

Set top-k sampling (default: 40, 0 = disabled)

setTopP(topP)#

Set top-p sampling (default: 0.9, 1.0 = disabled)

setTypicalP(typicalP)#

Set locally typical sampling, parameter p (default: 1.0, 1.0 = disabled)

setUseChatTemplate(useChatTemplate)#

Set whether or not generate should apply a chat template (default: false)
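
A sketch combining several of the sampling setters above, assuming they follow the usual Spark NLP convention of returning the loader itself so that calls can be chained; the values are illustrative, not recommendations:

>>> llm_loader.setTemperature(0.2) \
...     .setTopK(40) \
...     .setTopP(0.9) \
...     .setRepeatPenalty(1.1) \
...     .setNPredict(256) \
...     .setStopStrings(["</s>", "###"]) \
...     .setSeed(42) \
...     .setUseChatTemplate(True)
>>> print(llm_loader.generate("Explain the mechanism of action of beta blockers."))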

transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) pyspark.sql.dataframe.DataFrame#

Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters:
  • dataset (pyspark.sql.DataFrame) – input dataset

  • params (dict, optional) – an optional param map that overrides embedded params.

Returns:

transformed dataset

Return type:

pyspark.sql.DataFrame
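
A sketch of batch inference with ``transform``, assuming the loader reads prompts from the input DataFrame; the column name and schema below are illustrative and should match what the loaded model expects:

>>> prompts_df = spark.createDataFrame(
...     [("What are the first-line treatments for hypertension?",),
...      ("Summarize the side effects of ibuprofen.",)],
...     ["prompt"])   # illustrative input column name
>>> result_df = llm_loader.transform(prompts_df)
>>> result_df.show(truncate=False)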