JSL_MedSQL_T2SQL (t2sql - q16 - v1)

Description

This is a lightweight Text-to-SQL model fine-tuned by John Snow Labs, built specifically for working with medical and healthcare data.

Copy S3 URI

How to use

# Text-to-SQL example using Spark NLP + Healthcare NLP (sparknlp_jsl).
# A table schema and a natural-language question are sent to the MedicalLLM
# annotator, and the generated SQL is read from the "completions" column.
# NOTE: this snippet uses an existing SparkSession named `spark`; it is not
# created here and must already be active.
from sparknlp.base import DocumentAssembler
from sparknlp_jsl.annotator import MedicalLLM
from pyspark.ml import Pipeline

# Convert the raw "text" column into Spark NLP "document" annotations.
document_assembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# Load the pretrained Text-to-SQL model and configure generation:
# batch size, number of tokens to predict (NPredict), use of the model's
# chat template, and sampling temperature (0 here).
medical_llm = MedicalLLM.pretrained("jsl_meds_text2sql_1b_q16_v1", "en", "clinical/models")\
    .setInputCols("document")\
    .setOutputCol("completions")\
    .setBatchSize(1)\
    .setNPredict(100)\
    .setUseChatTemplate(True)\
    .setTemperature(0)

pipeline = Pipeline(stages=[
    document_assembler,
    medical_llm
])

# Prompt layout: a single "### Instruction:" header, the table schema,
# the question, then an open "### Response:" section for the model to fill.
medm_prompt = """### Instruction:
Table: CancerPatients
- patient_id (INT)
- name (VARCHAR)
- age (INT)
- gender (VARCHAR)
- cancer_type (VARCHAR)
- diagnosis_date (DATE)

List the names of patients diagnosed with breast cancer.

### Response:
"""

data = spark.createDataFrame([[medm_prompt]]).toDF("text")

model = pipeline.fit(data)
result = model.transform(data)

# Print the full, untruncated completion (the generated SQL).
result.select("completions").show(truncate=False)

# Same Text-to-SQL example, written against the `johnsnowlabs` convenience
# package (nlp.* / medical.* namespaces).
# NOTE: this snippet uses an existing SparkSession named `spark`; it is not
# created here and must already be active.
from johnsnowlabs import nlp, medical

# Convert the raw "text" column into Spark NLP "document" annotations.
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# The model name must be a string literal; an unquoted identifier here
# would raise a NameError.
medical_llm = medical.MedicalLLM.pretrained("jsl_meds_text2sql_1b_q16_v1", "en", "clinical/models")\
    .setInputCols("document")\
    .setOutputCol("completions")\
    .setBatchSize(1)\
    .setNPredict(100)\
    .setUseChatTemplate(True)\
    .setTemperature(0)

pipeline = nlp.Pipeline(stages=[
    document_assembler,
    medical_llm
])

# Prompt layout: a single "### Instruction:" header, the table schema,
# the question, then an open "### Response:" section for the model to fill.
medm_prompt = """### Instruction:
Table: CancerPatients
- patient_id (INT)
- name (VARCHAR)
- age (INT)
- gender (VARCHAR)
- cancer_type (VARCHAR)
- diagnosis_date (DATE)

List the names of patients diagnosed with breast cancer.

### Response:
"""

data = spark.createDataFrame([[medm_prompt]]).toDF("text")

model = pipeline.fit(data)
result = model.transform(data)

# Print the full, untruncated completion (the generated SQL).
result.select("completions").show(truncate=False)


// Scala version of the Text-to-SQL example.
// Assumes an active SparkSession named `spark` is in scope (it is not created
// here). Its implicits are required for Seq(...).toDF(...) below.
// DocumentAssembler and MedicalLLM must also be imported from Spark NLP /
// Healthcare NLP; those imports are not shown in this snippet.
import org.apache.spark.ml.Pipeline
import spark.implicits._

// Convert the raw "text" column into Spark NLP "document" annotations.
val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// Load the pretrained Text-to-SQL model and configure generation:
// batch size, number of tokens to predict (NPredict), use of the model's
// chat template, and sampling temperature (0 here).
val medicalLLM = MedicalLLM.pretrained("jsl_meds_text2sql_1b_q16_v1", "en", "clinical/models")
  .setInputCols("document")
  .setOutputCol("completions")
  .setBatchSize(1)
  .setNPredict(100)
  .setUseChatTemplate(true)
  .setTemperature(0)

val pipeline = new Pipeline().setStages(Array(
  documentAssembler,
  medicalLLM
))

// Prompt layout: a single "### Instruction:" header, the table schema,
// the question, then an open "### Response:" section for the model to fill.
val medmPrompt = """### Instruction:
Table: CancerPatients
- patient_id (INT)
- name (VARCHAR)
- age (INT)
- gender (VARCHAR)
- cancer_type (VARCHAR)
- diagnosis_date (DATE)

List the names of patients diagnosed with breast cancer.

### Response:
"""

val data = Seq(medmPrompt).toDF("text")

val model = pipeline.fit(data)
val result = model.transform(data)

// Print the full, untruncated completion (the generated SQL).
result.select("completions").show(false)

Results

SELECT name FROM CancerPatients WHERE cancer_type = 'breast cancer'

Model Information

Model Name: jsl_meds_text2sql_1b_q16_v1
Compatibility: Healthcare NLP 6.1.0+
License: Licensed
Edition: Official
Language: en
Size: 1.6 GB