sparknlp_jsl.annotator.medical_llm.medical_llm#

Module Contents#

Classes#

MedicalLLM

MedicalLLM was designed to load and run large language models (LLMs) in GGUF format with scalable performance.

class MedicalLLM(classname='com.johnsnowlabs.nlp.annotators.seq2seq.MedicalLLM', java_model=None)#

Bases: sparknlp.annotator.AutoGGUFModel

MedicalLLM was designed to load and run large language models (LLMs) in GGUF format with scalable performance. Ideal for clinical and healthcare applications, MedicalLLM supports tasks like medical entity extraction, summarization, Q&A, Retrieval Augmented Generation (RAG), and conversational AI. With simple integration into Spark NLP pipelines, it allows for customizable batch sizes, prediction settings, and chat templates. GPU optimization is also available, enhancing its capabilities for high-performance environments. MedicalLLM empowers users to link medical entities and perform complex NLP tasks with efficiency and precision.

Input Annotation types: DOCUMENT

Output Annotation type: DOCUMENT

Notes

To use GPU inference with this annotator, make sure to use the Spark NLP GPU package and set the number of GPU layers with the setNGpuLayers method. When using larger models, we recommend adjusting GPU usage with setNCtx and setNGpuLayers according to your hardware to avoid out-of-memory errors.
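The note above could translate into settings along these lines. This is only a sketch: the layer count and context size are placeholder values, `apply_settings` is a hypothetical helper that is not part of the library, and `StubAnnotator` merely stands in for a real MedicalLLM instance so the snippet runs without Spark.

```python
# Hypothetical helper: calls set<ParamName>(value) for each entry.
def apply_settings(annotator, settings):
    for name, value in settings.items():
        setter = getattr(annotator, "set" + name[0].upper() + name[1:])
        setter(value)
    return annotator


# Placeholder values; tune them to the VRAM actually available.
GPU_SETTINGS = {
    "nGpuLayers": 20,  # offload fewer layers if you hit out-of-memory errors
    "nCtx": 2048,      # a smaller context also reduces memory use
}


class StubAnnotator:
    """Stand-in for MedicalLLM so the sketch runs without Spark."""

    def __init__(self):
        self.values = {}

    def setNGpuLayers(self, n):
        self.values["nGpuLayers"] = n
        return self

    def setNCtx(self, n):
        self.values["nCtx"] = n
        return self


stub = apply_settings(StubAnnotator(), GPU_SETTINGS)
print(stub.values)
```

With a real annotator, the same helper would be called on the object returned by `MedicalLLM.pretrained()` instead of the stub.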

Parameters:
  • nThreads – Set the number of threads to use during generation

  • nThreadsDraft – Set the number of threads to use during draft generation

  • nThreadsBatch – Set the number of threads to use during batch and prompt processing

  • nThreadsBatchDraft – Set the number of threads to use during batch and prompt processing for the draft model

  • nCtx – Set the size of the prompt context

  • nBatch – Set the logical batch size for prompt processing (must be >=32 to use BLAS)

  • nUbatch – Set the physical batch size for prompt processing (must be >=32 to use BLAS)

  • nDraft – Set the number of tokens to draft for speculative decoding

  • nChunks – Set the maximal number of chunks to process

  • nSequences – Set the number of sequences to decode

  • pSplit – Set the speculative decoding split probability

  • nGpuLayers – Set the number of layers to store in VRAM (-1 = use default)

  • nGpuLayersDraft – Set the number of layers to store in VRAM for the draft model (-1 = use default)

  • gpuSplitMode – Set how to split the model across GPUs

  • mainGpu – Set the main GPU that is used for scratch and small tensors.

  • tensorSplit – Set how split tensors should be distributed across GPUs

  • grpAttnN – Set the group-attention factor

  • grpAttnW – Set the group-attention width

  • ropeFreqBase – Set the RoPE base frequency, used by NTK-aware scaling

  • ropeFreqScale – Set the RoPE frequency scaling factor, expands context by a factor of 1/N

  • yarnExtFactor – Set the YaRN extrapolation mix factor

  • yarnAttnFactor – Set the YaRN scale sqrt(t) or attention magnitude

  • yarnBetaFast – Set the YaRN low correction dim or beta

  • yarnBetaSlow – Set the YaRN high correction dim or alpha

  • yarnOrigCtx – Set the YaRN original context size of model

  • defragmentationThreshold – Set the KV cache defragmentation threshold

  • numaStrategy – Set optimization strategies that help on some NUMA systems (if available)

  • ropeScalingType – Set the RoPE frequency scaling method, defaults to linear unless specified by the model

  • poolingType – Set the pooling type for embeddings, use model default if unspecified

  • modelDraft – Set the draft model for speculative decoding

  • modelAlias – Set a model alias

  • lookupCacheStaticFilePath – Set path to static lookup cache to use for lookup decoding (not updated by generation)

  • lookupCacheDynamicFilePath – Set path to dynamic lookup cache to use for lookup decoding (updated by generation)

  • embedding – Whether to load model with embedding support

  • flashAttention – Whether to enable Flash Attention

  • inputPrefixBos – Whether to add prefix BOS to user inputs, preceding the --in-prefix string

  • useMmap – Whether to memory-map the model (faster load, but may increase pageouts if not using mlock)

  • useMlock – Whether to force the system to keep model in RAM rather than swapping or compressing

  • noKvOffload – Whether to disable KV offload

  • systemPrompt – Set a system prompt to use

  • chatTemplate – The chat template to use

  • inputPrefix – Set the prompt to start generation with

  • inputSuffix – Set a suffix for infilling

  • cachePrompt – Whether to remember the prompt to avoid reprocessing it

  • nPredict – Set the number of tokens to predict

  • topK – Set top-k sampling

  • topP – Set top-p sampling

  • minP – Set min-p sampling

  • tfsZ – Set tail free sampling, parameter z

  • typicalP – Set locally typical sampling, parameter p

  • temperature – Set the temperature

  • dynamicTemperatureRange – Set the dynamic temperature range

  • dynamicTemperatureExponent – Set the dynamic temperature exponent

  • repeatLastN – Set the last n tokens to consider for penalties

  • repeatPenalty – Set the penalty of repeated sequences of tokens

  • frequencyPenalty – Set the repetition alpha frequency penalty

  • presencePenalty – Set the repetition alpha presence penalty

  • miroStat – Set MiroStat sampling strategies.

  • miroStatTau – Set the MiroStat target entropy, parameter tau

  • miroStatEta – Set the MiroStat learning rate, parameter eta

  • penalizeNl – Whether to penalize newline tokens

  • nKeep – Set the number of tokens to keep from the initial prompt

  • seed – Set the RNG seed

  • nProbs – Set the number of top-token probabilities to output, if greater than 0

  • minKeep – Set the minimum number of tokens the samplers should return (0 = disabled)

  • grammar – Set BNF-like grammar to constrain generations

  • penaltyPrompt – Override which part of the prompt is penalized for repetition.

  • ignoreEos – Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)

  • disableTokenIds – Set the token ids to disable in the completion

  • stopStrings – Set the strings that stop token generation when they are encountered

  • samplers – Set which samplers to use for token generation in the given order

  • useChatTemplate – Set whether or not generate should apply a chat template
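A few of these parameters might be wired into a pipeline as sketched below. The setter names, module path, and the default model name come from this page; the column names, the chosen values, and the `make_rows` helper are illustrative assumptions. The imports sit inside the function so the snippet loads without Spark; calling `build_pipeline()` requires a licensed Spark NLP for Healthcare session.

```python
def build_pipeline():
    """Sketch of a DocumentAssembler -> MedicalLLM pipeline."""
    # Imported lazily so this module can be loaded without Spark installed.
    from pyspark.ml import Pipeline
    from sparknlp.base import DocumentAssembler
    from sparknlp_jsl.annotator.medical_llm.medical_llm import MedicalLLM

    document_assembler = (
        DocumentAssembler()
        .setInputCol("text")
        .setOutputCol("document")
    )
    medical_llm = (
        MedicalLLM.pretrained("jsl_medm_q8_v1", "en", "clinical/models")
        .setInputCols("document")
        .setOutputCol("completions")  # assumed column name
        .setBatchSize(1)
        .setNPredict(100)             # illustrative token budget
        .setUseChatTemplate(True)
        .setTemperature(0.0)
    )
    return Pipeline(stages=[document_assembler, medical_llm])


def make_rows(prompts):
    """Hypothetical helper: wraps prompts as one-column rows for a DataFrame."""
    return [(prompt,) for prompt in prompts]


rows = make_rows(["List the medications mentioned in this note: ..."])
print(rows)
```

Given an active session `spark`, `df = spark.createDataFrame(rows, ["text"])` followed by `build_pipeline().fit(df).transform(df)` would place the generated text in the `completions` column.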

batchSize#
cachePrompt#
chatTemplate#
defragmentationThreshold#
disableTokenIds#
dynamicTemperatureExponent#
dynamicTemperatureRange#
embedding#
flashAttention#
frequencyPenalty#
getter_attrs = []#
gpuSplitMode#
grammar#
grpAttnN#
grpAttnW#
ignoreEos#
inputAnnotatorTypes#
inputCols#
inputPrefix#
inputPrefixBos#
inputSuffix#
lazyAnnotator#
lookupCacheDynamicFilePath#
lookupCacheStaticFilePath#
mainGpu#
minKeep#
minP#
miroStat#
miroStatEta#
miroStatTau#
modelAlias#
modelDraft#
nBatch#
nChunks#
nCtx#
nDraft#
nGpuLayers#
nGpuLayersDraft#
nKeep#
nPredict#
nProbs#
nSequences#
nThreads#
nThreadsBatch#
nThreadsBatchDraft#
nThreadsDraft#
nUbatch#
name = 'MedicalLLM'#
noKvOffload#
numaStrategy#
optionalInputAnnotatorTypes = []#
outputAnnotatorType = 'document'#
outputCol#
pSplit#
penalizeNl#
penaltyPrompt#
poolingType#
presencePenalty#
repeatLastN#
repeatPenalty#
ropeFreqBase#
ropeFreqScale#
ropeScalingType#
samplers#
seed#
stopStrings#
systemPrompt#
temperature#
tensorSplit#
tfsZ#
topK#
topP#
typicalP#
uid = ''#
useChatTemplate#
useMlock#
useMmap#
yarnAttnFactor#
yarnBetaFast#
yarnBetaSlow#
yarnExtFactor#
yarnOrigCtx#
clear(param: pyspark.ml.param.Param) None#

Clears a param from the param map if it has been explicitly set.

copy(extra: pyspark.ml._typing.ParamMap | None = None) JP#

Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params, so both the Python wrapper and the Java pipeline component get copied.

Parameters:

extra (dict, optional) – Extra parameters to copy to the new instance

Returns:

Copy of this instance

Return type:

JavaParams

explainParam(param: str | Param) str#

Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.

explainParams() str#

Returns the documentation of all params with their optional default values and user-supplied values.

extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) pyspark.ml._typing.ParamMap#

Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.

Parameters:

extra (dict, optional) – extra param values

Returns:

merged param map

Return type:

dict
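The precedence described for extractParamMap can be illustrated with plain dictionaries. This is a toy sketch: `merge_param_maps` is a hypothetical function, real param maps are keyed by Param objects rather than strings, and the values are made up rather than the library's defaults.

```python
def merge_param_maps(defaults, user_supplied, extra=None):
    """Later sources win: defaults < user-supplied < extra."""
    merged = dict(defaults)
    merged.update(user_supplied)
    merged.update(extra or {})
    return merged


# Illustrative values only, not the annotator's actual defaults.
defaults = {"temperature": 0.8, "topK": 40, "seed": 1}
user_supplied = {"temperature": 0.2}
extra = {"topK": 10}

merged = merge_param_maps(defaults, user_supplied, extra)
print(merged)
```

Here `temperature` comes from the user-supplied map, `topK` from `extra`, and `seed` falls back to the default.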

getBatchSize()#

Gets current batch size.

Returns:

Current batch size

Return type:

int

getInputCols()#

Gets current column names of input annotations.

getLazyAnnotator()#

Gets whether Annotator should be evaluated lazily in a RecursivePipeline.

getMetadata()#

Gets the metadata of the model

getOrDefault(param: str) Any#
getOrDefault(param: Param[T]) T

Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.

getOutputCol()#

Gets output column name of annotations.

getParam(paramName: str) Param#

Gets a param by its name.

getParamValue(paramName)#

Gets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

hasDefault(param: str | Param[Any]) bool#

Checks whether a param has a default value.

hasParam(paramName: str) bool#

Tests whether this instance contains a param with a given (string) name.

inputColsValidation(value)#

isDefined(param: str | Param[Any]) bool#

Checks whether a param is explicitly set by user or has a default value.

isSet(param: str | Param[Any]) bool#

Checks whether a param is explicitly set by user.
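The difference between hasDefault, isSet, and isDefined can be shown with a small stand-in class. `ToyParams` is a hypothetical model written for this illustration, not pyspark's implementation; it uses string keys and an invented default value.

```python
class ToyParams:
    """Toy model of the default vs. user-supplied split."""

    def __init__(self, defaults):
        self._defaults = dict(defaults)
        self._user = {}

    def set(self, name, value):
        self._user[name] = value

    def clear(self, name):
        # Removes only an explicitly set value; defaults are untouched.
        self._user.pop(name, None)

    def hasDefault(self, name):
        return name in self._defaults

    def isSet(self, name):
        return name in self._user

    def isDefined(self, name):
        return self.isSet(name) or self.hasDefault(name)

    def getOrDefault(self, name):
        if name in self._user:
            return self._user[name]
        if name in self._defaults:
            return self._defaults[name]
        raise KeyError(name)


params = ToyParams(defaults={"batchSize": 4})  # invented default
params.set("nCtx", 2048)

print(params.isSet("batchSize"), params.isDefined("batchSize"))
print(params.isSet("nCtx"), params.hasDefault("nCtx"))
```

A param with only a default is defined but not set; a param set by the user without a default is set and defined, and clearing it leaves it undefined.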

classmethod load(path: str) RL#

Reads an ML instance from the input path, a shortcut of read().load(path).

static loadSavedModel(folder, spark_session)#

Loads a locally saved model.

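Parameters:
  • folder (str) – Folder of the saved model

  • spark_session (pyspark.sql.SparkSession) – The current SparkSession

Returns: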

The restored model

Return type:

AutoGGUFModel

static pretrained(name='jsl_medm_q8_v1', lang='en', remote_loc='clinical/models')#

Downloads and loads a pretrained model.

Parameters:
  • name (str, optional) – Name of the pretrained model, by default “jsl_medm_q8_v1”

  • lang (str, optional) – Language of the pretrained model, by default “en”

  • remote_loc (str, optional) – Optional remote address of the resource, by default “clinical/models”. Will use Spark NLP's repositories otherwise.

Returns:

The restored model

Return type:

MedicalLLM

classmethod read()#

Returns an MLReader instance for this class.

save(path: str) None#

Save this ML instance to the given path, a shortcut of ‘write().save(path)’.

set(param: Param, value: Any) None#

Sets a parameter in the embedded param map.

setBatchSize(v)#

Sets batch size.

Parameters:

v (int) – Batch size

setCachePrompt(cachePrompt: bool)#

Whether to remember the prompt to avoid reprocessing it

setChatTemplate(chatTemplate: str)#

The chat template to use

setDefragmentationThreshold(defragmentationThreshold: float)#

Set the KV cache defragmentation threshold

setDisableTokenIds(disableTokenIds: List[int])#

Set the token ids to disable in the completion

setDynamicTemperatureExponent(dynamicTemperatureExponent: float)#

Set the dynamic temperature exponent

setDynamicTemperatureRange(dynamicTemperatureRange: float)#

Set the dynamic temperature range

setEmbedding(embedding: bool)#

Whether to load model with embedding support

setFlashAttention(flashAttention: bool)#

Whether to enable Flash Attention

setFrequencyPenalty(frequencyPenalty: float)#

Set the repetition alpha frequency penalty

setGpuSplitMode(gpuSplitMode: str)#

Set how to split the model across GPUs

setGrammar(grammar: bool)#

Set BNF-like grammar to constrain generations. The grammar is passed as a string; the bool annotation in the signature appears to be a typo.

setGrpAttnN(grpAttnN: int)#

Set the group-attention factor

setGrpAttnW(grpAttnW: int)#

Set the group-attention width

setIgnoreEos(ignoreEos: bool)#

Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)

setInputCols(*value)#

Sets column names of input annotations.

Parameters:

*value (List[str]) – Input columns for the annotator

setInputPrefix(inputPrefix: str)#

Set the prompt to start generation with

setInputPrefixBos(inputPrefixBos: bool)#

Whether to add prefix BOS to user inputs, preceding the --in-prefix string

setInputSuffix(inputSuffix: str)#

Set a suffix for infilling

setLazyAnnotator(value)#

Sets whether Annotator should be evaluated lazily in a RecursivePipeline.

Parameters:

value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline

setLookupCacheDynamicFilePath(lookupCacheDynamicFilePath: str)#

Set path to dynamic lookup cache to use for lookup decoding (updated by generation)

setLookupCacheStaticFilePath(lookupCacheStaticFilePath: str)#

Set path to static lookup cache to use for lookup decoding (not updated by generation)

setLoraAdapters(loraAdapters: Dict[str, float])#

Set the LoRA adapters to apply, as a mapping from adapter path to scale

setMainGpu(mainGpu: int)#

Set the main GPU that is used for scratch and small tensors.

setMinKeep(minKeep: int)#

Set the minimum number of tokens the samplers should return (0 = disabled)

setMinP(minP: float)#

Set min-p sampling
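As a rough illustration of what min-p sampling does, the sketch below keeps only tokens whose probability is at least `min_p` times that of the most likely token, then renormalizes. This follows the commonly used definition and is not taken from the annotator's backend; `min_p_filter` and the token probabilities are invented for the example.

```python
def min_p_filter(probs, min_p):
    """Keep tokens with probability >= min_p * max probability, then renormalize."""
    threshold = min_p * max(probs.values())
    kept = {tok: p for tok, p in probs.items() if p >= threshold}
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}


# Made-up next-token distribution.
probs = {"fever": 0.5, "cough": 0.3, "rash": 0.15, "zebra": 0.05}

filtered = min_p_filter(probs, 0.2)  # threshold is 0.2 * 0.5 = 0.1
print(sorted(filtered))
```

With `min_p = 0.2`, the unlikely token falls below the threshold and is dropped, while the remaining three are rescaled to sum to 1.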

setMiroStat(miroStat: str)#

Set MiroStat sampling strategies.

setMiroStatEta(miroStatEta: float)#

Set the MiroStat learning rate, parameter eta

setMiroStatTau(miroStatTau: float)#

Set the MiroStat target entropy, parameter tau

setModelAlias(modelAlias: str)#

Set a model alias

setModelDraft(modelDraft: str)#

Set the draft model for speculative decoding

setNBatch(nBatch: int)#

Set the logical batch size for prompt processing (must be >=32 to use BLAS)

setNChunks(nChunks: int)#

Set the maximal number of chunks to process

setNCtx(nCtx: int)#

Set the size of the prompt context

setNDraft(nDraft: int)#

Set the number of tokens to draft for speculative decoding

setNGpuLayers(nGpuLayers: int)#

Set the number of layers to store in VRAM (-1 = use default)

setNGpuLayersDraft(nGpuLayersDraft: int)#

Set the number of layers to store in VRAM for the draft model (-1 = use default)

setNKeep(nKeep: int)#

Set the number of tokens to keep from the initial prompt

setNPredict(nPredict: int)#

Set the number of tokens to predict

setNProbs(nProbs: int)#

Set the number of top-token probabilities to output, if greater than 0

setNSequences(nSequences: int)#

Set the number of sequences to decode

setNThreads(nThreads: int)#

Set the number of threads to use during generation

setNThreadsBatch(nThreadsBatch: int)#

Set the number of threads to use during batch and prompt processing

setNThreadsBatchDraft(nThreadsBatchDraft: int)#

Set the number of threads to use during batch and prompt processing for the draft model

setNThreadsDraft(nThreadsDraft: int)#

Set the number of threads to use during draft generation

setNUbatch(nUbatch: int)#

Set the physical batch size for prompt processing (must be >=32 to use BLAS)

setNoKvOffload(noKvOffload: bool)#

Whether to disable KV offload

setNumaStrategy(numaStrategy: str)#

Set optimization strategies that help on some NUMA systems (if available)

setOutputCol(value)#

Sets output column name of annotations.

Parameters:

value (str) – Name of output column

setPSplit(pSplit: float)#

Set the speculative decoding split probability

setParamValue(paramName)#

Sets the value of a parameter.

Parameters:

paramName (str) – Name of the parameter

setParams()#

setPenalizeNl(penalizeNl: bool)#

Whether to penalize newline tokens

setPenaltyPrompt(penaltyPrompt: str)#

Override which part of the prompt is penalized for repetition.

setPoolingType(poolingType: bool)#

Set the pooling type for embeddings; the model default is used if unspecified. The value is a string; the bool annotation in the signature appears to be a typo.

setPresencePenalty(presencePenalty: float)#

Set the repetition alpha presence penalty

setRepeatLastN(repeatLastN: int)#

Set the last n tokens to consider for penalties

setRepeatPenalty(repeatPenalty: float)#

Set the penalty of repeated sequences of tokens

setRopeFreqBase(ropeFreqBase: float)#

Set the RoPE base frequency, used by NTK-aware scaling

setRopeFreqScale(ropeFreqScale: float)#

Set the RoPE frequency scaling factor, expands context by a factor of 1/N

setRopeScalingType(ropeScalingType: str)#

Set the RoPE frequency scaling method, defaults to linear unless specified by the model

setSamplers(samplers: List[str])#

Set which samplers to use for token generation in the given order

setSeed(seed: int)#

Set the RNG seed

setStopStrings(stopStrings: List[str])#

Set the strings that stop token generation when they are encountered

setSystemPrompt(systemPrompt: bool)#

Set a system prompt to use. The value is a string; the bool annotation in the signature appears to be a typo.

setTemperature(temperature: float)#

Set the temperature
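For intuition about the temperature setting, the sketch below applies the standard temperature-scaled softmax, dividing each logit by the temperature before normalizing. It illustrates the general technique only; how the backend handles edge cases such as a temperature of 0 is not covered here, and the logits are made up.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Standard softmax over logits divided by a positive temperature."""
    if temperature <= 0:
        raise ValueError("this sketch only handles temperature > 0")
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up scores for three tokens

sharp = softmax_with_temperature(logits, 0.5)
flat = softmax_with_temperature(logits, 2.0)
print(sharp)
print(flat)
```

Lower temperatures concentrate probability on the top token, while higher temperatures spread it more evenly across candidates.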

setTensorSplit(tensorSplit: List[float])#

Set how split tensors should be distributed across GPUs

setTfsZ(tfsZ: float)#

Set tail free sampling, parameter z

setTokenBias(tokenBias: Dict[str, float])#

Set token bias, keyed by token string

setTokenIdBias(tokenIdBias: Dict[int, float])#

Set token id bias

setTopK(topK: int)#

Set top-k sampling

setTopP(topP: float)#

Set top-p sampling
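The top-k and top-p settings above both trim the candidate list before a token is drawn. The sketch below shows the usual definitions: top-k keeps the k most likely tokens, and top-p keeps the smallest set of most likely tokens whose cumulative probability reaches p. It is an illustration under those assumptions, not the backend's code, and it skips the renormalization step; the function names and probabilities are invented.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    return dict(ranked[:k])


def top_p_filter(probs, p):
    """Keep the most probable tokens until their cumulative probability reaches p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept = {}
    cumulative = 0.0
    for token, prob in ranked:
        kept[token] = prob
        cumulative += prob
        if cumulative >= p:
            break
    return kept


# Made-up next-token distribution.
probs = {"a": 0.5, "b": 0.3, "c": 0.15, "d": 0.05}

print(list(top_k_filter(probs, 2)))
print(list(top_p_filter(probs, 0.9)))
```

In this example top-k with k = 2 keeps two tokens regardless of their probabilities, while top-p with p = 0.9 needs three tokens to cover 90% of the mass.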

setTypicalP(typicalP: float)#

Set locally typical sampling, parameter p

setUseChatTemplate(useChatTemplate: bool)#

Set whether generate should apply a chat template

setUseMlock(useMlock: bool)#

Whether to force the system to keep model in RAM rather than swapping or compressing

setUseMmap(useMmap: bool)#

Whether to memory-map the model (faster load, but may increase pageouts if not using mlock)

setYarnAttnFactor(yarnAttnFactor: float)#

Set the YaRN scale sqrt(t) or attention magnitude

setYarnBetaFast(yarnBetaFast: float)#

Set the YaRN low correction dim or beta

setYarnBetaSlow(yarnBetaSlow: float)#

Set the YaRN high correction dim or alpha

setYarnExtFactor(yarnExtFactor: float)#

Set the YaRN extrapolation mix factor

setYarnOrigCtx(yarnOrigCtx: int)#

Set the YaRN original context size of model

transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) pyspark.sql.dataframe.DataFrame#

Transforms the input dataset with optional parameters.

New in version 1.3.0.

Parameters:
  • dataset (pyspark.sql.DataFrame) – input dataset

  • params (dict, optional) – an optional param map that overrides embedded params.

Returns:

transformed dataset

Return type:

pyspark.sql.DataFrame

write() JavaMLWriter#

Returns an MLWriter instance for this ML instance.