sparknlp_jsl._tf_graph_builders.graph_builders.GenericClassifierTFGraphBuilder

class sparknlp_jsl._tf_graph_builders.graph_builders.GenericClassifierTFGraphBuilder(build_params)[source]

Bases: TFGraphBuilder

Class that creates the TensorFlow (TF) graphs used by GenericClassifierApproach.

Examples

>>> import pandas as pd
>>> from pyspark.ml import Pipeline
>>> from sparknlp_jsl.training import tf_graph
>>> from sparknlp_jsl.base import *
>>> from sparknlp.annotator import *
>>> from sparknlp_jsl.annotator import *
>>> dataframe = pd.read_csv('petfinder-mini.csv')
>>> DL_params = {"input_dim": 302,"output_dim": 2,"hidden_layers": [300, 200, 100], "hidden_act": "tanh",'hidden_act_l2':1,'batch_norm':1}
>>> tf_graph.build("generic_classifier",build_params=DL_params, model_location="/content/gc_graph", model_filename="auto")
>>> gen_clf = GenericClassifierApproach() \
...    .setLabelColumn("target") \
...    .setInputCols(["features"]) \
...    .setOutputCol("prediction") \
...    .setModelFile('/content/gc_graph/gcl.302.2.pb') \
...    .setEpochsNumber(50) \
...    .setBatchSize(100) \
...    .setFeatureScaling("zscore") \
...    .setFixImbalance(True) \
...    .setLearningRate(0.001) \
...    .setOutputLogsPath("logs") \
...    .setValidationSplit(0.2)
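>>> # features_asm is assumed to be a previously defined stage (e.g. a
>>> # FeaturesAssembler) whose output column is "features"; it is not shown here.
>>> clf_Pipeline = Pipeline(stages=[features_asm,gen_clf])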

Methods

__init__(build_params)

build(model_location, model_filename)

check_build_params()

get_build_param(build_param)

get_build_params()

get_build_params_with_defaults()

get_model_build_param_explanations()

get_model_build_params()

get_model_filename()

supports_auto_file_name()