sparknlp_jsl.annotator.medical_llm.medical_vision_llm
Contains classes for the MedicalVisionLLM.
Module Contents
Classes
MedicalVisionLLM – Multimodal annotator to generate text completions with large language models.
- class MedicalVisionLLM(classname='com.johnsnowlabs.nlp.annotators.seq2seq.MedicalVisionLLM', java_model=None)
Bases: sparknlp.annotator.AutoGGUFVisionModel
Multimodal annotator to generate text completions with large language models. It supports ingesting images for captioning.
At the moment, only CLIP-based models are supported.
If a parameter is not set, the annotator defaults to the value provided by the model.
This annotator expects a column of annotator type AnnotationImage for the image and a column of type Annotation for the caption. Note that the image bytes in the image annotation must be the raw image bytes, without any preprocessing. We provide the helper function ImageAssembler.loadImagesAsBytes to load the image bytes from a directory.
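"Raw image bytes without preprocessing" means the encoded file content exactly as it is stored on disk. ImageAssembler.loadImagesAsBytes takes care of this inside Spark; the stdlib-only sketch below uses two hypothetical helpers, load_raw_image_bytes and sniff_format (not part of Spark NLP), to show what such bytes look like: the JPEG or PNG signature is still at the start, because nothing has been decoded, resized or normalized.

```python
import tempfile
from pathlib import Path
from typing import Dict

# Signatures found at the start of two common encoded image formats.
MAGIC = {b"\xff\xd8\xff": "jpeg", b"\x89PNG\r\n\x1a\n": "png"}

def load_raw_image_bytes(directory: str) -> Dict[str, bytes]:
    """Hypothetical helper: read each file as-is (no decoding, resizing or
    normalization), so the bytes are still the encoded image file."""
    return {p.name: p.read_bytes()
            for p in sorted(Path(directory).iterdir()) if p.is_file()}

def sniff_format(data: bytes) -> str:
    """Return the format whose signature the bytes start with, if any."""
    for magic, fmt in MAGIC.items():
        if data.startswith(magic):
            return fmt
    return "unknown"

# Demo with a fake file that only contains a JPEG header.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "x.jpeg").write_bytes(b"\xff\xd8\xff\xe0" + b"\x00" * 16)
    images = load_raw_image_bytes(tmp)
    print({name: sniff_format(data) for name, data in images.items()})
# {'x.jpeg': 'jpeg'}
```

If sniff_format reports "unknown" for files that are supposed to be JPEG or PNG images, the bytes were most likely decoded or re-encoded somewhere upstream.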
Pretrained models can be loaded with pretrained of the companion object:
medical_vision_llm = MedicalVisionLLM.pretrained() \
    .setInputCols(["image", "document"]) \
    .setOutputCol("completions")
Input Annotation types | Output Annotation type
IMAGE, DOCUMENT        | DOCUMENT
- Parameters:
nThreads – Set the number of threads to use during generation
nThreadsDraft – Set the number of threads to use during draft generation
nThreadsBatch – Set the number of threads to use during batch and prompt processing
nThreadsBatchDraft – Set the number of threads to use during batch and prompt processing for the draft model
nCtx – Set the size of the prompt context
nBatch – Set the logical batch size for prompt processing (must be >=32 to use BLAS)
nUbatch – Set the physical batch size for prompt processing (must be >=32 to use BLAS)
nDraft – Set the number of tokens to draft for speculative decoding
nChunks – Set the maximal number of chunks to process
nSequences – Set the number of sequences to decode
pSplit – Set the speculative decoding split probability
nGpuLayers – Set the number of layers to store in VRAM (-1 to use the default)
nGpuLayersDraft – Set the number of layers to store in VRAM for the draft model (-1 to use the default)
gpuSplitMode – Set how to split the model across GPUs
mainGpu – Set the main GPU that is used for scratch and small tensors.
tensorSplit – Set how split tensors should be distributed across GPUs
grpAttnN – Set the group-attention factor
grpAttnW – Set the group-attention width
ropeFreqBase – Set the RoPE base frequency, used by NTK-aware scaling
ropeFreqScale – Set the RoPE frequency scaling factor, expands context by a factor of 1/N
yarnExtFactor – Set the YaRN extrapolation mix factor
yarnAttnFactor – Set the YaRN scale sqrt(t) or attention magnitude
yarnBetaFast – Set the YaRN low correction dim or beta
yarnBetaSlow – Set the YaRN high correction dim or alpha
yarnOrigCtx – Set the YaRN original context size of the model
defragmentationThreshold – Set the KV cache defragmentation threshold
numaStrategy – Set optimization strategies that help on some NUMA systems (if available)
ropeScalingType – Set the RoPE frequency scaling method, defaults to linear unless specified by the model
poolingType – Set the pooling type for embeddings, use model default if unspecified
modelDraft – Set the draft model for speculative decoding
modelAlias – Set a model alias
lookupCacheStaticFilePath – Set path to static lookup cache to use for lookup decoding (not updated by generation)
lookupCacheDynamicFilePath – Set path to dynamic lookup cache to use for lookup decoding (updated by generation)
embedding – Whether to load model with embedding support
flashAttention – Whether to enable Flash Attention
inputPrefixBos – Whether to add a prefix BOS token to user inputs, preceding the --in-prefix string
useMmap – Whether to memory-map the model (faster load, but may increase page-outs if not using mlock)
useMlock – Whether to force the system to keep the model in RAM rather than swapping or compressing
noKvOffload – Whether to disable KV offload
systemPrompt – Set a system prompt to use
chatTemplate – The chat template to use
inputPrefix – Set the prompt to start generation with
inputSuffix – Set a suffix for infilling
cachePrompt – Whether to remember the prompt to avoid reprocessing it
nPredict – Set the number of tokens to predict
topK – Set top-k sampling
topP – Set top-p sampling
minP – Set min-p sampling
tfsZ – Set tail free sampling, parameter z
typicalP – Set locally typical sampling, parameter p
temperature – Set the temperature
dynamicTemperatureRange – Set the dynamic temperature range
dynamicTemperatureExponent – Set the dynamic temperature exponent
repeatLastN – Set the last n tokens to consider for penalties
repeatPenalty – Set the penalty of repeated sequences of tokens
frequencyPenalty – Set the repetition alpha frequency penalty
presencePenalty – Set the repetition alpha presence penalty
miroStat – Set MiroStat sampling strategies.
miroStatTau – Set the MiroStat target entropy, parameter tau
miroStatEta – Set the MiroStat learning rate, parameter eta
penalizeNl – Whether to penalize newline tokens
nKeep – Set the number of tokens to keep from the initial prompt
seed – Set the RNG seed
nProbs – If greater than 0, set the number of top-token probabilities to output
minKeep – Set the minimum number of tokens the samplers should return (0 = disabled)
grammar – Set BNF-like grammar to constrain generations
penaltyPrompt – Override which part of the prompt is penalized for repetition.
ignoreEos – Set whether to ignore the end-of-stream token and continue generating (implies --logit-bias 2-inf)
disableTokenIds – Set the token ids to disable in the completion
stopStrings – Set the strings that stop token generation as soon as they appear
samplers – Set which samplers to use for token generation in the given order
useChatTemplate – Set whether generation should apply a chat template
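To make the interplay of topK, topP, minP and temperature concrete, here is a minimal pure-Python sketch over a toy logit dictionary. It assumes the order top-k, then top-p, then min-p, then temperature, which is llama.cpp's default sampler order; the order this annotator actually uses is governed by samplers and may differ. The function names are hypothetical and this is not the annotator's implementation.

```python
import math
from typing import Dict, List, Tuple

def softmax(logits: Dict[int, float], temperature: float = 1.0) -> Dict[int, float]:
    """Convert logits to probabilities; temperature < 1 sharpens, > 1 flattens."""
    m = max(logits.values())  # subtract the max logit for numerical stability
    exps = {t: math.exp((x - m) / temperature) for t, x in logits.items()}
    z = sum(exps.values())
    return {t: e / z for t, e in exps.items()}

def sample_distribution(logits: Dict[int, float], top_k: int, top_p: float,
                        min_p: float, temperature: float) -> Dict[int, float]:
    """Distribution left after top-k, top-p, min-p and then temperature."""
    ranked = sorted(softmax(logits).items(), key=lambda kv: kv[1], reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]                  # keep the k most likely tokens
    z = sum(p for _, p in ranked)
    ranked = [(t, p / z) for t, p in ranked]     # renormalize over survivors
    kept: List[Tuple[int, float]] = []
    cumulative = 0.0
    for token, p in ranked:                      # smallest prefix with mass >= top_p
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break
    threshold = min_p * kept[0][1]               # min-p is relative to the top token
    kept = [(t, p) for t, p in kept if p >= threshold]
    survivors = {t: logits[t] for t, _ in kept}
    if temperature <= 0:                         # treat as greedy decoding
        best = max(survivors, key=survivors.get)
        return {best: 1.0}
    return softmax(survivors, temperature)       # temperature applied last

# Toy logits for four token ids, with the sampling values from the class example.
dist = sample_distribution({0: 5.0, 1: 4.0, 2: 1.0, 3: -2.0},
                           top_k=40, top_p=0.95, min_p=0.05, temperature=0.05)
print(sorted(dist))
# [0, 1]
```

With a temperature as low as 0.05, almost all of the remaining probability mass lands on the single most likely token, so generation is close to greedy.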
Notes
To use GPU inference with this annotator, make sure to use the Spark NLP GPU package and set the number of GPU layers with the setNGpuLayers method.
When using larger models, we recommend adjusting GPU usage with setNCtx and setNGpuLayers according to your hardware to avoid out-of-memory errors.
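setNCtx matters for memory because the key/value (KV) cache grows linearly with the context size; for layers offloaded with setNGpuLayers, this cache typically lives in VRAM unless noKvOffload is set. The back-of-the-envelope sketch below uses the usual transformer-decoder formula with made-up layer and head counts (not those of any MedicalVisionLLM model). It ignores the model weights, the vision encoder and runtime buffers, so read it as a lower bound on the extra memory a larger nCtx needs.

```python
def kv_cache_bytes(n_layers: int, n_ctx: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Approximate KV-cache size: one K and one V vector per layer, per
    context position, per KV head (f16 = 2 bytes per element)."""
    return 2 * n_layers * n_ctx * n_kv_heads * head_dim * bytes_per_elem

# Hypothetical model dimensions, chosen only for illustration.
for n_ctx in (4096, 16384):
    mib = kv_cache_bytes(n_layers=36, n_ctx=n_ctx, n_kv_heads=2, head_dim=128) / 2**20
    print(n_ctx, mib)
# 4096 144.0
# 16384 576.0
```

Quadrupling the context quadruples the cache, which is why nCtx and nGpuLayers usually have to be tuned together on a fixed amount of VRAM.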
Examples
>>> import sparknlp
>>> import sparknlp_jsl
>>> from sparknlp.base import *
>>> from sparknlp.annotator import *
>>> from sparknlp_jsl.annotator import *
>>> from pyspark.ml import Pipeline
>>> from pyspark.sql.functions import lit
>>> documentAssembler = DocumentAssembler() \
...     .setInputCol("caption") \
...     .setOutputCol("caption_document")
>>> imageAssembler = ImageAssembler() \
...     .setInputCol("image") \
...     .setOutputCol("image_assembler")
>>> imagesPath = "IMAGES_PATH"
>>> data = ImageAssembler \
...     .loadImagesAsBytes(spark, imagesPath) \
...     .withColumn("caption", lit("Caption this image."))
>>> model = MedicalVisionLLM.pretrained() \
...     .setInputCols(["caption_document", "image_assembler"]) \
...     .setOutputCol("completions") \
...     .setBatchSize(4) \
...     .setNGpuLayers(99) \
...     .setNCtx(4096) \
...     .setMinKeep(0) \
...     .setMinP(0.05) \
...     .setNPredict(40) \
...     .setNProbs(0) \
...     .setPenalizeNl(False) \
...     .setRepeatLastN(256) \
...     .setRepeatPenalty(1.18) \
...     .setStopStrings(["</s>", "Llama:", "User:"]) \
...     .setTemperature(0.05) \
...     .setTfsZ(1) \
...     .setTypicalP(1) \
...     .setTopK(40) \
...     .setTopP(0.95)
>>> pipeline = Pipeline().setStages([documentAssembler, imageAssembler, model])
>>> pipeline.fit(data).transform(data) \
...     .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "completions.result") \
...     .show(truncate = False)
+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|image_name       |result                                                                                                                                                                                        |
+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|palace.JPEG      |[ The image depicts a large, ornate room with high ceilings and beautifully decorated walls. There are several chairs placed throughout the space, some of which have cushions]               |
|egyptian_cat.jpeg|[ The image features two cats lying on a pink surface, possibly a bed or sofa. One cat is positioned towards the left side of the scene and appears to be sleeping while holding]             |
|hippopotamus.JPEG|[ A large brown hippo is swimming in a body of water, possibly an aquarium. The hippo appears to be enjoying its time in the water and seems relaxed as it floats]                            |
|hen.JPEG         |[ The image features a large chicken standing next to several baby chickens. In total, there are five birds in the scene: one adult and four young ones. They appear to be gathered together] |
|ostrich.JPEG     |[ The image features a large, long-necked bird standing in the grass. It appears to be an ostrich or similar species with its head held high and looking around. In addition to]              |
|junco.JPEG       |[ A small bird with a black head and white chest is standing on the snow. It appears to be looking at something, possibly food or another animal in its vicinity. The scene takes place out]  |
|bluetick.jpg     |[ A dog with a red collar is sitting on the floor, looking at something. The dog appears to be staring into the distance or focusing its attention on an object in front of it.]              |
|chihuahua.jpg    |[ A small brown dog wearing a sweater is sitting on the floor. The dog appears to be looking at something, possibly its owner or another animal in the room. It seems comfortable and relaxed]|
|tractor.JPEG     |[ A man is sitting in the driver's seat of a green tractor, which has yellow wheels and tires. The tractor appears to be parked on top of an empty field with]                                |
|ox.JPEG          |[ A large bull with horns is standing in a grassy field.]                                                                                                                                     |
+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
- batchSize
- cachePrompt
- chatTemplate
- defragmentationThreshold
- disableLog
- disableTokenIds
- dynamicTemperatureExponent
- dynamicTemperatureRange
- embedding
- flashAttention
- frequencyPenalty
- getter_attrs = []
- gpuSplitMode
- grammar
- ignoreEos
- inputAnnotatorTypes
- inputCols
- inputPrefix
- inputSuffix
- lazyAnnotator
- logVerbosity
- mainGpu
- minKeep
- minP
- miroStat
- miroStatEta
- miroStatTau
- modelAlias
- modelDraft
- nBatch
- nCtx
- nDraft
- nGpuLayers
- nGpuLayersDraft
- nKeep
- nPredict
- nProbs
- nThreads
- nThreadsBatch
- nUbatch
- name = 'MedicalVisionLLM'
- noKvOffload
- numaStrategy
- optionalInputAnnotatorTypes = []
- outputAnnotatorType = 'document'
- outputCol
- penalizeNl
- penaltyPrompt
- poolingType
- presencePenalty
- repeatLastN
- repeatPenalty
- ropeFreqBase
- ropeFreqScale
- ropeScalingType
- samplers
- seed
- stopStrings
- systemPrompt
- temperature
- tfsZ
- topK
- topP
- typicalP
- uid = ''
- useChatTemplate
- useMlock
- useMmap
- yarnAttnFactor
- yarnBetaFast
- yarnBetaSlow
- yarnExtFactor
- yarnOrigCtx
- clear(param: pyspark.ml.param.Param) → None
Clears a param from the param map if it has been explicitly set.
- copy(extra: pyspark.ml._typing.ParamMap | None = None) → JP
Creates a copy of this instance with the same uid and some extra params. This implementation first calls Params.copy and then makes a copy of the companion Java pipeline component with extra params. So both the Python wrapper and the Java pipeline component get copied.
- Parameters:
extra (dict, optional) – Extra parameters to copy to the new instance
- Returns:
Copy of this instance
- Return type:
JavaParams
- explainParam(param: str | Param) → str
Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string.
- explainParams() → str
Returns the documentation of all params with their optional default values and user-supplied values.
- extractParamMap(extra: pyspark.ml._typing.ParamMap | None = None) → pyspark.ml._typing.ParamMap
Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra.
- Parameters:
extra (dict, optional) – extra param values
- Returns:
merged param map
- Return type:
dict
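The ordering default param values < user-supplied values < extra amounts to successive dictionary updates. A plain-Python sketch with a hypothetical merge_param_maps helper and made-up values (PySpark keys the real map by Param objects rather than strings):

```python
from typing import Any, Dict, Optional

def merge_param_maps(defaults: Dict[str, Any], user_supplied: Dict[str, Any],
                     extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Flat merge where later sources win: defaults < user-supplied < extra."""
    merged = dict(defaults)        # lowest priority
    merged.update(user_supplied)   # overrides defaults
    if extra:
        merged.update(extra)       # highest priority
    return merged

# Made-up values, for illustration only.
print(merge_param_maps({"nPredict": -1, "temperature": 0.8},
                       {"temperature": 0.05},
                       {"nPredict": 40}))
# {'nPredict': 40, 'temperature': 0.05}
```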
- getBatchSize()
Gets current batch size.
- Returns:
Current batch size
- Return type:
int
- getInputCols()
Gets current column names of input annotations.
- getLazyAnnotator()
Gets whether Annotator should be evaluated lazily in a RecursivePipeline.
- getMetadata()
Gets the metadata of the model
- getOrDefault(param: str) → Any
- getOrDefault(param: Param[T]) → T
Gets the value of a param in the user-supplied param map or its default value. Raises an error if neither is set.
- getOutputCol()
Gets output column name of annotations.
- getParam(paramName: str) → Param
Gets a param by its name.
- getParamValue(paramName)
Gets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- getSystemPrompt()
Get the system prompt.
- hasDefault(param: str | Param[Any]) → bool
Checks whether a param has a default value.
- hasParam(paramName: str) → bool
Tests whether this instance contains a param with a given (string) name.
- inputColsValidation(value)
- isDefined(param: str | Param[Any]) → bool
Checks whether a param is explicitly set by user or has a default value.
- isSet(param: str | Param[Any]) → bool
Checks whether a param is explicitly set by user.
- classmethod load(path: str) → RL
Reads an ML instance from the input path, a shortcut of read().load(path).
- static loadSavedModel(modelPath, mmprojPath, spark_session)
Loads a locally saved model.
- Parameters:
modelPath (str) – Path to the model file
mmprojPath (str) – Path to the multimodal projector (mmproj) file
spark_session (pyspark.sql.SparkSession) – The current SparkSession
- Returns:
The restored model
- Return type:
- static pretrained(name='jsl_meds_vlm_3b_q4_v1', lang='en', remote_loc='clinical/models')
Downloads and loads a pretrained model.
- Parameters:
name (str, optional) – Name of the pretrained model, by default "jsl_meds_vlm_3b_q4_v1"
lang (str, optional) – Language of the pretrained model, by default "en"
remote_loc (str, optional) – Optional remote address of the resource, by default "clinical/models". Will use Spark NLP's repositories otherwise.
- Returns:
The restored model
- Return type:
AutoGGUFVisionModel
- classmethod read()
Returns an MLReader instance for this class.
- save(path: str) → None
Save this ML instance to the given path, a shortcut of 'write().save(path)'.
- set(param: Param, value: Any) → None
Sets a parameter in the embedded param map.
- setBatchSize(v)
Sets batch size.
- Parameters:
v (int) – Batch size
- setCachePrompt(cachePrompt: bool)
Whether to remember the prompt to avoid reprocessing it
- setChatTemplate(chatTemplate: str)
The chat template to use
- setDefragmentationThreshold(defragmentationThreshold: float)
Set the KV cache defragmentation threshold
- setDisableLog(disableLog: bool)
Whether to disable logging
- setDisableTokenIds(disableTokenIds: List[int])
Set the token ids to disable in the completion
- setDynamicTemperatureExponent(dynamicTemperatureExponent: float)
Set the dynamic temperature exponent
- setDynamicTemperatureRange(dynamicTemperatureRange: float)
Set the dynamic temperature range
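The two dynamic temperature settings are easier to read against a concrete rule. Spark NLP's GGUF annotators are backed by llama.cpp, where dynamic temperature picks a value between temperature - range and temperature + range according to the normalized entropy of the token distribution, raised to the exponent. The sketch follows that scheme; that this annotator forwards both values to it unchanged is an assumption, and dynamic_temperature is a hypothetical name.

```python
import math
from typing import List

def dynamic_temperature(logits: List[float], temperature: float,
                        dyn_range: float, exponent: float) -> float:
    """Pick a temperature in [temperature - range, temperature + range]:
    low for confident (low-entropy) distributions, high for flat ones."""
    if dyn_range <= 0 or len(logits) < 2:
        return temperature                 # feature disabled: fixed temperature
    min_t = max(0.0, temperature - dyn_range)
    max_t = temperature + dyn_range
    m = max(logits)                        # softmax, shifted for stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    max_entropy = math.log(len(logits))    # entropy of a uniform distribution
    normalized = entropy / max_entropy     # in [0, 1]
    return min_t + (max_t - min_t) * normalized ** exponent

# A confident distribution gets a lower temperature than a flat one.
print(dynamic_temperature([10.0, 0.0, 0.0], 0.8, 0.5, 1.0))
print(dynamic_temperature([1.0, 1.0, 1.0], 0.8, 0.5, 1.0))
```

An exponent above 1 keeps the temperature near the lower bound until the distribution becomes quite flat; an exponent below 1 raises it sooner.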
- setFlashAttention(flashAttention: bool)
Whether to enable Flash Attention
- setFrequencyPenalty(frequencyPenalty: float)
Set the repetition alpha frequency penalty
- setGpuSplitMode(gpuSplitMode: str)
Set how to split the model across GPUs
- setGrammar(grammar: str)
Set BNF-like grammar to constrain generations
- setIgnoreEos(ignoreEos: bool)
Set whether to ignore the end-of-stream token and continue generating (implies --logit-bias 2-inf)
- setInputCols(*value)
Sets column names of input annotations.
- Parameters:
*value (List[str]) – Input columns for the annotator
- setInputPrefix(inputPrefix: str)
Set the prompt to start generation with
- setInputSuffix(inputSuffix: str)
Set a suffix for infilling
- setLazyAnnotator(value)
Sets whether Annotator should be evaluated lazily in a RecursivePipeline.
- Parameters:
value (bool) – Whether Annotator should be evaluated lazily in a RecursivePipeline
- setLogVerbosity(logVerbosity: int)
Set the log verbosity level
- setMainGpu(mainGpu: int)
Set the main GPU that is used for scratch and small tensors.
- setMinKeep(minKeep: int)
Set the minimum number of tokens the samplers should return (0 = disabled)
- setMinP(minP: float)
Set min-p sampling
- setMiroStat(miroStat: str)
Set MiroStat sampling strategies.
- setMiroStatEta(miroStatEta: float)
Set the MiroStat learning rate, parameter eta
- setMiroStatTau(miroStatTau: float)
Set the MiroStat target entropy, parameter tau
- setModelAlias(modelAlias: str)
Set a model alias
- setModelDraft(modelDraft: str)
Set the draft model for speculative decoding
- setNBatch(nBatch: int)
Set the logical batch size for prompt processing (must be >=32 to use BLAS)
- setNCtx(nCtx: int)
Set the size of the prompt context
- setNDraft(nDraft: int)
Set the number of tokens to draft for speculative decoding
- setNGpuLayers(nGpuLayers: int)
Set the number of layers to store in VRAM (-1 to use the default)
- setNGpuLayersDraft(nGpuLayersDraft: int)
Set the number of layers to store in VRAM for the draft model (-1 to use the default)
- setNKeep(nKeep: int)
Set the number of tokens to keep from the initial prompt
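nKeep comes into play once the context is full. llama.cpp's examples handle a full context by keeping the first n_keep tokens and discarding half of the tokens that follow them, which frees room to keep generating; whether this annotator shifts its context in exactly this way is an assumption. A list-based sketch with a hypothetical shift_context helper:

```python
from typing import List

def shift_context(tokens: List[int], n_ctx: int, n_keep: int) -> List[int]:
    """When the context is full, keep the first n_keep tokens and drop the
    oldest half of the remaining ones."""
    if len(tokens) < n_ctx:
        return tokens                    # still room, nothing to discard
    n_left = len(tokens) - n_keep
    n_discard = n_left // 2
    return tokens[:n_keep] + tokens[n_keep + n_discard:]

ctx = list(range(10))  # pretend token ids 0..9 fill a 10-token context
print(shift_context(ctx, n_ctx=10, n_keep=3))
# [0, 1, 2, 6, 7, 8, 9]
```

Setting nKeep to the length of the system prompt or instructions is the usual way to make sure they survive the shift.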
- setNParallel(nParallel: int)
Sets the number of parallel processes for decoding. This is an alias for setBatchSize.
- setNPredict(nPredict: int)
Set the number of tokens to predict
- setNProbs(nProbs: int)
If greater than 0, set the number of top-token probabilities to output.
- setNThreads(nThreads: int)
Set the number of threads to use during generation
- setNThreadsBatch(nThreadsBatch: int)
Set the number of threads to use during batch and prompt processing
- setNUbatch(nUbatch: int)
Set the physical batch size for prompt processing (must be >=32 to use BLAS)
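nBatch and nUbatch describe two levels of chunking during prompt processing: nBatch is the logical batch (the tokens handed to one decode call) and nUbatch is the physical micro-batch computed at once, so nUbatch should not exceed nBatch. A sketch with a hypothetical split_prompt helper that only shows how a prompt would be partitioned:

```python
from typing import List

def split_prompt(prompt_tokens: List[int], n_batch: int,
                 n_ubatch: int) -> List[List[List[int]]]:
    """Split tokens into logical batches of at most n_batch tokens, each
    further split into physical micro-batches of at most n_ubatch tokens."""
    n_ubatch = min(n_ubatch, n_batch)  # a micro-batch cannot exceed its batch
    batches = []
    for i in range(0, len(prompt_tokens), n_batch):
        logical = prompt_tokens[i:i + n_batch]
        micro = [logical[j:j + n_ubatch] for j in range(0, len(logical), n_ubatch)]
        batches.append(micro)
    return batches

plan = split_prompt(list(range(1000)), n_batch=512, n_ubatch=256)
print([[len(u) for u in b] for b in plan])
# [[256, 256], [256, 232]]
```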
- setNoKvOffload(noKvOffload: bool)
Whether to disable KV offload
- setNumaStrategy(numaStrategy: str)
Set optimization strategies that help on some NUMA systems (if available)
Possible values:
DISABLED: No NUMA optimizations
DISTRIBUTE: spread execution evenly over all nodes
ISOLATE: only spawn threads on CPUs on the node that execution started on
NUMA_CTL: use the CPU map provided by numactl
MIRROR: Mirrors the model across NUMA nodes
- setOutputCol(value)
Sets output column name of annotations.
- Parameters:
value (str) – Name of output column
- setParamValue(paramName)
Sets the value of a parameter.
- Parameters:
paramName (str) – Name of the parameter
- setParams()
- setPenalizeNl(penalizeNl: bool)
Whether to penalize newline tokens
- setPenaltyPrompt(penaltyPrompt: str)
Override which part of the prompt is penalized for repetition.
- setPoolingType(poolingType: str)
Set the pooling type for embeddings, use model default if unspecified
Possible values:
MEAN: Mean Pooling
CLS: CLS Pooling
LAST: Last token pooling
RANK: For reranked models
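The pooling types reduce one embedding per token to a single vector, so they only matter when the model is loaded with embedding support. A sketch of MEAN, CLS (first token) and LAST on plain lists; RANK, used by reranking models, is left out, and pool is a hypothetical name rather than an API of this annotator.

```python
from typing import List

def pool(token_embeddings: List[List[float]], pooling_type: str) -> List[float]:
    """Reduce per-token embeddings (one vector per token) to one vector."""
    if not token_embeddings:
        raise ValueError("need at least one token embedding")
    if pooling_type == "MEAN":
        n = len(token_embeddings)
        return [sum(col) / n for col in zip(*token_embeddings)]
    if pooling_type == "CLS":
        return list(token_embeddings[0])    # first token's vector
    if pooling_type == "LAST":
        return list(token_embeddings[-1])   # last token's vector
    raise ValueError(f"unsupported pooling type: {pooling_type}")

emb = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(pool(emb, "MEAN"), pool(emb, "CLS"), pool(emb, "LAST"))
# [3.0, 4.0] [1.0, 2.0] [5.0, 6.0]
```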
- setPresencePenalty(presencePenalty: float)
Set the repetition alpha presence penalty
- setRepeatLastN(repeatLastN: int)
Set the last n tokens to consider for penalties
- setRepeatPenalty(repeatPenalty: float)
Set the penalty of repeated sequences of tokens
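repeatPenalty, frequencyPenalty, presencePenalty and repeatLastN work together. The sketch follows llama.cpp's formulation, which this annotator is assumed to share: within the last repeatLastN tokens, each repeated token's logit is divided by repeatPenalty when positive (multiplied when negative), then reduced by count × frequencyPenalty + presencePenalty. apply_penalties is a hypothetical name.

```python
from collections import Counter
from typing import Dict, List

def apply_penalties(logits: Dict[int, float], history: List[int],
                    repeat_last_n: int, repeat_penalty: float,
                    frequency_penalty: float, presence_penalty: float) -> Dict[int, float]:
    """Penalize tokens that already appeared among the last repeat_last_n tokens."""
    if repeat_last_n < 0:
        window = history                     # -1: consider the whole history
    elif repeat_last_n == 0:
        window = []                          # 0: penalties disabled
    else:
        window = history[-repeat_last_n:]
    out = dict(logits)
    for token, count in Counter(window).items():
        if token not in out:
            continue
        logit = out[token]
        # The repeat penalty makes the token less likely whatever the logit's sign.
        logit = logit / repeat_penalty if logit > 0 else logit * repeat_penalty
        # The frequency penalty scales with the count; the presence penalty is flat.
        logit -= count * frequency_penalty + presence_penalty
        out[token] = logit
    return out

print(apply_penalties({1: 2.0, 2: -1.0, 3: 0.5}, [1, 1, 2], 64, 2.0, 0.1, 0.5))
```

A repeatPenalty of 1.0 with both alpha penalties at 0 leaves the logits unchanged.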
- setRopeFreqBase(ropeFreqBase: float)
Set the RoPE base frequency, used by NTK-aware scaling
- setRopeFreqScale(ropeFreqScale: float)
Set the RoPE frequency scaling factor, expands context by a factor of 1/N
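"Expands context by a factor of 1/N" follows from how linear RoPE scaling works: each position is multiplied by ropeFreqScale before the rotation angles are computed, so a scale of 1/N maps N times as many positions into the angle range seen in training. The sketch uses the standard RoPE angle pos · scale · base^(-2i/d) with ropeFreqBase as base and a made-up training context of 4096; both helper names are hypothetical.

```python
def rope_angle(pos: int, i: int, dim: int, base: float = 10000.0,
               freq_scale: float = 1.0) -> float:
    """Rotation angle of dimension pair i at position pos, with linear scaling."""
    return pos * freq_scale * base ** (-2 * i / dim)

def effective_context(train_ctx: int, freq_scale: float) -> int:
    """Positions reachable without leaving the trained angle range."""
    if freq_scale <= 0:
        raise ValueError("freq_scale must be positive")
    return int(train_ctx / freq_scale)

# With freq_scale = 0.5, position 8000 gets the angles position 4000 had before.
assert rope_angle(8000, 3, 128, freq_scale=0.5) == rope_angle(4000, 3, 128)
print(effective_context(4096, 0.5), effective_context(4096, 0.25))
# 8192 16384
```

The price is that neighbouring positions become harder to tell apart, which is why quality can drop when the scale is pushed too far.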
- setRopeScalingType(ropeScalingType: str)
Set the RoPE frequency scaling method, defaults to linear unless specified by the model.
Possible values:
NONE: Don't use any scaling
LINEAR: Linear scaling
YARN: YaRN RoPE scaling
- setSamplers(samplers: List[str])
Set which samplers to use for token generation in the given order
- setSeed(seed: int)
Set the RNG seed
- setStopStrings(stopStrings: List[str])
Set the strings that stop token generation as soon as they appear
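One way to apply stop strings to generated text is to cut at the earliest match among all of them. The sketch assumes the stop string itself is removed from the completion, which may not match the annotator's behavior; truncate_at_stop is a hypothetical name, and the stop strings in the usage line are those from the class example. Real streaming generation checks incrementally as tokens arrive, not once on the finished text.

```python
from typing import List, Tuple

def truncate_at_stop(text: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """Cut text at the earliest occurrence of any stop string.
    Returns (text_before_stop, stopped)."""
    cut = -1
    for stop in stop_strings:
        if not stop:
            continue                     # ignore empty stop strings
        idx = text.find(stop)
        if idx != -1 and (cut == -1 or idx < cut):
            cut = idx
    if cut == -1:
        return text, False
    return text[:cut], True

print(truncate_at_stop("A bull in a field.</s>User: next", ["</s>", "Llama:", "User:"]))
# ('A bull in a field.', True)
```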
- setSystemPrompt(systemPrompt)
Set a system prompt to use.
- setTemperature(temperature: float)
Set the temperature
- setTfsZ(tfsZ: float)
Set tail free sampling, parameter z
- setTokenBias(tokenBias: Dict[str, float])
Set the token bias, keyed by token string
- setTokenIdBias(tokenIdBias: Dict[int, float])
Set the token bias, keyed by token id
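Token biases are added to the logits before sampling, and a large enough negative bias effectively bans a token. The sketch treats disableTokenIds as a bias of minus infinity. That is an assumption about how the annotator implements it, although llama.cpp's ignore-EOS option works the same way through --logit-bias. It operates on token ids; setTokenBias would first need the tokenizer to map strings to ids. apply_logit_bias is a hypothetical name.

```python
import math
from typing import Dict, List

def apply_logit_bias(logits: Dict[int, float], token_id_bias: Dict[int, float],
                     disabled_ids: List[int]) -> Dict[int, float]:
    """Add per-token biases, then force disabled tokens to -inf."""
    out = dict(logits)
    for token_id, bias in token_id_bias.items():
        if token_id in out:
            out[token_id] += bias        # positive favors, negative discourages
    for token_id in disabled_ids:
        if token_id in out:
            out[token_id] = -math.inf    # probability becomes 0 after softmax
    return out

print(apply_logit_bias({0: 1.0, 1: 2.0, 2: 3.0}, {0: 5.0}, [2]))
# {0: 6.0, 1: 2.0, 2: -inf}
```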
- setTopK(topK: int)
Set top-k sampling
- setTopP(topP: float)
Set top-p sampling
- setTypicalP(typicalP: float)
Set locally typical sampling, parameter p
- setUseChatTemplate(useChatTemplate: bool)
Set whether generation should apply a chat template
- setUseMlock(useMlock: bool)
Whether to force the system to keep the model in RAM rather than swapping or compressing
- setUseMmap(useMmap: bool)
Whether to memory-map the model (faster load, but may increase page-outs if not using mlock)
- setYarnAttnFactor(yarnAttnFactor: float)
Set the YaRN scale sqrt(t) or attention magnitude
- setYarnBetaFast(yarnBetaFast: float)
Set the YaRN low correction dim or beta
- setYarnBetaSlow(yarnBetaSlow: float)
Set the YaRN high correction dim or alpha
- setYarnExtFactor(yarnExtFactor: float)
Set the YaRN extrapolation mix factor
- setYarnOrigCtx(yarnOrigCtx: int)
Set the YaRN original context size of the model
- transform(dataset: pyspark.sql.dataframe.DataFrame, params: pyspark.ml._typing.ParamMap | None = None) → pyspark.sql.dataframe.DataFrame
Transforms the input dataset with optional parameters.
New in version 1.3.0.
- Parameters:
dataset (pyspark.sql.DataFrame) – input dataset
params (dict, optional) – an optional param map that overrides embedded params.
- Returns:
transformed dataset
- Return type:
pyspark.sql.DataFrame
- write() → JavaMLWriter
Returns an MLWriter instance for this ML instance.