sparknlp_jsl._tf_graph_builders.graph_builders.GenericClassifierTFGraphBuilder
- class sparknlp_jsl._tf_graph_builders.graph_builders.GenericClassifierTFGraphBuilder(build_params)
Bases: TFGraphBuilder
Class to create the TF graphs for GenericClassifierApproach.
Examples
>>> import pandas as pd
>>> from pyspark.ml import Pipeline
>>> from sparknlp_jsl.training import tf_graph
>>> from sparknlp_jsl.base import *
>>> from sparknlp.annotator import *
>>> from sparknlp_jsl.annotator import *
>>> dataframe = pd.read_csv('petfinder-mini.csv')
>>> DL_params = {"input_dim": 302, "output_dim": 2, "hidden_layers": [300, 200, 100], "hidden_act": "tanh", 'hidden_act_l2': 1, 'batch_norm': 1}
>>> tf_graph.build("generic_classifier", build_params=DL_params, model_location="/content/gc_graph", model_filename="auto")
>>> gen_clf = GenericClassifierApproach() \
...     .setLabelColumn("target") \
...     .setInputCols(["features"]) \
...     .setOutputCol("prediction") \
...     .setModelFile('/content/gc_graph/gcl.302.2.pb') \
...     .setEpochsNumber(50) \
...     .setBatchSize(100) \
...     .setFeatureScaling("zscore") \
...     .setFixImbalance(True) \
...     .setLearningRate(0.001) \
...     .setOutputLogsPath("logs") \
...     .setValidationSplit(0.2)
>>> # features_asm is assumed to be a stage defined earlier that produces the "features" column
>>> clf_Pipeline = Pipeline(stages=[features_asm, gen_clf])
Methods
__init__(build_params)
build(model_location, model_filename)
check_build_params()
get_build_param(build_param)
get_build_params()
get_build_params_with_defaults()
get_model_build_param_explanations()
get_model_build_params()
get_model_filename()
supports_auto_file_name()