Text-to-SQL Generation (Custom_DB_Schema_Single_Table_Augmented)

Description

This is a state-of-the-art (SOTA) model for generating SQL queries from natural-language questions and custom database schemas that contain a single table. It is based on a large LLM that John Snow Labs fine-tuned on an augmented dataset of single-table schemas.
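Because the model was fine-tuned on single-table schemas, it can be worth checking the `{table_name: [column names]}` dict passed to `setSchema()` before loading the 3.0 GB model. The sketch below is a hypothetical pre-flight helper (`check_single_table_schema` is not part of Healthcare NLP). It only inspects the dict's shape and does not call any Spark NLP API.

```python
from typing import Dict, List


def check_single_table_schema(schema: Dict[str, List[str]]) -> str:
    """Return the table name if `schema` describes exactly one table.

    Hypothetical helper: validates the dict shape only; no Spark NLP calls.
    """
    if len(schema) != 1:
        raise ValueError(f"expected exactly one table, got {len(schema)}")
    table, columns = next(iter(schema.items()))
    if not columns:
        raise ValueError(f"table {table!r} has no columns")
    if len(set(columns)) != len(columns):
        raise ValueError(f"table {table!r} has duplicate column names")
    return table


if __name__ == "__main__":
    query_schema = {
        "medical_treatment": ["patient_id", "patient_name", "age", "gender", "diagnosis"],
    }
    print(check_single_table_schema(query_schema))  # prints: medical_treatment
```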

How to use

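Python

from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp_jsl.annotator import Text2SQL

# Assumes an active Spark session `spark` with the licensed Healthcare NLP library loaded.

question = "What is the average age of male patients with 'Diabetes'?"

# Single-table schema: {table_name: [column names]}
query_schema = {
    "medical_treatment": ["patient_id", "patient_name", "age", "gender", "diagnosis", "treatment", "doctor_name", "hospital_name", "admission_date", "discharge_date"]
}

# Wrap the raw question text in a "document" annotation
document_assembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# Pretrained Text2SQL model, conditioned on the schema above;
# setMaxNewTokens caps the length of the generated query
text2sql_with_schema_single_table_augmented = Text2SQL.pretrained("text2sql_with_schema_single_table_augmented", "en", "clinical/models")\
    .setMaxNewTokens(200)\
    .setSchema(query_schema)\
    .setInputCols(["document"])\
    .setOutputCol("sql")

pipeline = Pipeline(stages=[document_assembler, text2sql_with_schema_single_table_augmented])

data = spark.createDataFrame([[question]]).toDF("text")

# Print the generated SQL without truncation
pipeline.fit(data)\
    .transform(data)\
    .select("sql.result")\
    .show(truncate=False)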
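
Scala

import com.johnsnowlabs.nlp.base.DocumentAssembler
import org.apache.spark.ml.Pipeline
// Text2SQL ships with the licensed Healthcare NLP library; import it from that package.
// Assumes a SparkSession named `spark` is in scope (needed for toDF below).
import spark.implicits._

val question = "What is the average age of male patients with 'Diabetes'?"

// Single-table schema: table name -> column names
val query_schema: Map[String, List[String]] = Map(
  "medical_treatment" -> List("patient_id", "patient_name", "age", "gender", "diagnosis", "treatment", "doctor_name", "hospital_name", "admission_date", "discharge_date")
)

val document_assembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// `pretrained` is a factory method on the companion object, so no `new`
val text2sql_with_schema_single_table_augmented = Text2SQL.pretrained("text2sql_with_schema_single_table_augmented", "en", "clinical/models")
  .setMaxNewTokens(200)
  .setSchema(query_schema)
  .setInputCols(Array("document"))
  .setOutputCol("sql")

val pipeline = new Pipeline().setStages(Array(document_assembler, text2sql_with_schema_single_table_augmented))

val data = Seq(question).toDF("text")

val result = pipeline.fit(data).transform(data)
result.select("sql.result").show(false)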

Results

[SELECT AVG(age) FROM medical_treatment WHERE gender = 'male' AND diagnosis = 'diabetes']
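A generated query can be sanity-checked before it touches real data. The sketch below assumes SQLite through Python's standard `sqlite3` module (not any particular production database): it builds an in-memory table from the same single-table schema, loads a few rows, and runs the query shown above. The helper `run_on_toy_table` and the toy rows are hypothetical and for illustration only.

```python
import sqlite3
from typing import Dict, List, Sequence, Tuple

# Schema and generated query copied from the example in this card.
QUERY_SCHEMA: Dict[str, List[str]] = {
    "medical_treatment": [
        "patient_id", "patient_name", "age", "gender", "diagnosis", "treatment",
        "doctor_name", "hospital_name", "admission_date", "discharge_date",
    ]
}
GENERATED_SQL = (
    "SELECT AVG(age) FROM medical_treatment "
    "WHERE gender = 'male' AND diagnosis = 'diabetes'"
)

# Hypothetical toy rows: one value per schema column, in schema order.
TOY_ROWS: List[Tuple] = [
    (1, "Patient A", 50, "male", "diabetes", "insulin", "Dr. X", "Hospital 1", "2023-01-01", "2023-01-05"),
    (2, "Patient B", 70, "male", "diabetes", "metformin", "Dr. Y", "Hospital 1", "2023-02-01", "2023-02-03"),
    (3, "Patient C", 40, "female", "diabetes", "insulin", "Dr. X", "Hospital 2", "2023-03-01", "2023-03-04"),
    (4, "Patient D", 90, "male", "Diabetes", "insulin", "Dr. Z", "Hospital 2", "2023-04-01", "2023-04-02"),
]


def run_on_toy_table(schema: Dict[str, List[str]], sql: str, rows: Sequence[Tuple]) -> List[Tuple]:
    """Create the single table in an in-memory SQLite DB, load `rows`, run `sql`."""
    (table, columns), = schema.items()  # raises ValueError unless there is exactly one table
    conn = sqlite3.connect(":memory:")
    try:
        # Columns are declared without types, which SQLite accepts.
        col_defs = ", ".join(f'"{c}"' for c in columns)
        conn.execute(f'CREATE TABLE "{table}" ({col_defs})')
        placeholders = ", ".join("?" for _ in columns)
        conn.executemany(f'INSERT INTO "{table}" VALUES ({placeholders})', rows)
        return conn.execute(sql).fetchall()
    finally:
        conn.close()


if __name__ == "__main__":
    print(run_on_toy_table(QUERY_SCHEMA, GENERATED_SQL, TOY_ROWS))  # [(60.0,)]
```

With SQLite's default BINARY collation, `=` compares strings case-sensitively, so the row stored as `'Diabetes'` is skipped and the average covers only patients A and B. The model lowercased the value quoted as `'Diabetes'` in the question, so if stored values are not normalized, wrapping the column in `LOWER()` or using `COLLATE NOCASE` may be needed. A column that is missing from the schema makes `sqlite3` raise `OperationalError` ("no such column"), which gives a cheap check for hallucinated columns.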

Model Information

Model Name: text2sql_with_schema_single_table_augmented
Compatibility: Healthcare NLP 5.1.0+
License: Licensed
Edition: Official
Language: en
Size: 3.0 GB