Text-to-SQL Generation (Custom_DB_Schema_Single_Table_Augmented)

Description

This model is the state-of-the-art (SOTA) model for generating SQL queries from natural-language questions and a custom database schema containing a single table. It is based on a large language model (LLM) fine-tuned by John Snow Labs on an augmented dataset of single-table schemas.

How to use

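Python

from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp_jsl.annotator import *  # Healthcare NLP annotators, including Text2SQL

# Assumes `spark` is an active Spark session started with a valid Healthcare NLP license.

question = "What is the average age of male patients with 'Diabetes'?"

# The schema maps the single table name to its list of column names.
query_schema = {
    "medical_treatment": ["patient_id", "patient_name", "age", "gender", "diagnosis", "treatment", "doctor_name", "hospital_name", "admission_date", "discharge_date"]
}

# Wrap the raw question text in a Spark NLP document annotation.
document_assembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# Load the pretrained model, cap the generated length, and set the target schema.
text2sql_with_schema_single_table_augmented = Text2SQL.pretrained("text2sql_with_schema_single_table_augmented", "en", "clinical/models")\
    .setMaxNewTokens(200)\
    .setSchema(query_schema)\
    .setInputCols(["document"])\
    .setOutputCol("sql")

pipeline = Pipeline(stages=[document_assembler, text2sql_with_schema_single_table_augmented])

data = spark.createDataFrame([[question]]).toDF("text")

pipeline.fit(data)\
    .transform(data)\
    .select("sql.result")\
    .show(truncate=False)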

Scala

import com.johnsnowlabs.nlp.base.DocumentAssembler
import org.apache.spark.ml.Pipeline
// Text2SQL comes from the Healthcare NLP (spark-nlp-jsl) library; import it from that package as well.
// Assumes an active SparkSession named `spark` (needed for toDF below).
import spark.implicits._

val question = "What is the average age of male patients with 'Diabetes'?"
val query_schema: Map[String, List[String]] = Map(
    "medical_treatment" -> List("patient_id", "patient_name", "age", "gender", "diagnosis", "treatment", "doctor_name", "hospital_name", "admission_date", "discharge_date")
)

val document_assembler = new DocumentAssembler()
    .setInputCol("text")
    .setOutputCol("document")

// `pretrained` is a factory method on the companion object, so no `new` is used.
val text2sql_with_schema_single_table_augmented = Text2SQL.pretrained("text2sql_with_schema_single_table_augmented", "en", "clinical/models")
    .setMaxNewTokens(200)
    .setSchema(query_schema)
    .setInputCols(Array("document"))
    .setOutputCol("sql")

val pipeline = new Pipeline().setStages(Array(document_assembler, text2sql_with_schema_single_table_augmented))

val data = Seq(question).toDF("text")

val result = pipeline.fit(data).transform(data)
result.select("sql.result").show(false)
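Because the model was fine-tuned on single-table schemas, it may help to check the shape of the schema dict before passing it to setSchema in the Python example. The sketch below is a minimal, hypothetical helper (`check_single_table_schema` is not a Healthcare NLP API) and says nothing about how setSchema itself validates its input; it only confirms there is exactly one table with a non-empty list of distinct column names.

```python
def check_single_table_schema(schema):
    """Return (table_name, columns) if `schema` describes exactly one table.

    Hypothetical helper, not part of Healthcare NLP: it only checks the shape
    of a {table_name: [column, ...]} dict like the one passed to setSchema.
    """
    if not isinstance(schema, dict) or len(schema) != 1:
        raise ValueError("expected a dict with exactly one table")
    (table, columns), = schema.items()
    if not isinstance(table, str) or not table:
        raise ValueError("table name must be a non-empty string")
    if not columns or not all(isinstance(c, str) and c for c in columns):
        raise ValueError("columns must be a non-empty list of non-empty strings")
    if len(set(columns)) != len(columns):
        raise ValueError("duplicate column names")
    return table, list(columns)


query_schema = {
    "medical_treatment": ["patient_id", "patient_name", "age", "gender", "diagnosis",
                          "treatment", "doctor_name", "hospital_name",
                          "admission_date", "discharge_date"]
}

table, columns = check_single_table_schema(query_schema)
print(table, len(columns))  # medical_treatment 10
```

A multi-table dict such as {"a": ["x"], "b": ["y"]} raises ValueError here, which surfaces the mismatch with this model's single-table training data before any inference runs.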

Results

[SELECT AVG(age) FROM medical_treatment WHERE gender = 'male' AND diagnosis = 'diabetes']
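One way to sanity-check a generated query before running it on real data is to execute it against a throwaway in-memory table built from the same schema. This is a minimal sketch under stated assumptions, not part of Healthcare NLP: `run_against_toy_table` is a hypothetical helper, the rows are made-up toy values, and SQLite (via Python's standard `sqlite3` module) stands in for whatever database the query will actually target, whose SQL dialect may differ.

```python
import sqlite3

# Schema from the example above: one table name mapped to its column names.
query_schema = {
    "medical_treatment": ["patient_id", "patient_name", "age", "gender", "diagnosis",
                          "treatment", "doctor_name", "hospital_name",
                          "admission_date", "discharge_date"]
}

# SQL string copied from the Results section.
generated_sql = ("SELECT AVG(age) FROM medical_treatment "
                 "WHERE gender = 'male' AND diagnosis = 'diabetes'")


def run_against_toy_table(schema, sql, rows):
    """Create an in-memory SQLite table from `schema`, load `rows`, run `sql`."""
    (table, columns), = schema.items()  # exactly one table expected; else ValueError
    conn = sqlite3.connect(":memory:")
    try:
        # Columns are left untyped; SQLite's flexible typing accepts any value.
        conn.execute(f"CREATE TABLE {table} ({', '.join(columns)})")
        placeholders = ", ".join("?" for _ in columns)
        conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
        return conn.execute(sql).fetchall()
    finally:
        conn.close()


# Hypothetical toy rows (not real data), in the column order of the schema.
toy_rows = [
    (1, "A", 50, "male", "diabetes", "insulin", "Dr. X", "H1", "2023-01-01", "2023-01-05"),
    (2, "B", 70, "male", "diabetes", "metformin", "Dr. Y", "H1", "2023-02-01", "2023-02-03"),
    (3, "C", 40, "female", "diabetes", "insulin", "Dr. X", "H2", "2023-03-01", "2023-03-02"),
    (4, "D", 90, "male", "Diabetes", "diet", "Dr. Z", "H2", "2023-04-01", "2023-04-09"),
]

print(run_against_toy_table(query_schema, generated_sql, toy_rows))  # [(60.0,)]
```

Only the first two rows match: the third is female, and the fourth is skipped because SQLite's `=` compares text case-sensitively by default, so 'Diabetes' does not equal 'diabetes'. Since the generated query lowercases the literal from the question, it is worth checking how values are actually stored before trusting the result.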

Model Information

Model Name: text2sql_with_schema_single_table_augmented
Compatibility: Healthcare NLP 5.1.1+
License: Licensed
Edition: Official
Language: en
Size: 3.1 GB