Skip to content

Commit

Permalink
Output CSVLogger callback to a tabular file (#1335)
Browse files Browse the repository at this point in the history
* add training history

* update dl train tool

* update

* fix end tag

* fix linting
  • Loading branch information
anuprulez authored Oct 2, 2023
1 parent 3c1e6c7 commit 80417bf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
24 changes: 18 additions & 6 deletions tools/sklearn/keras_train_and_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def main(
infile1,
infile2,
outfile_result,
outfile_history=None,
outfile_object=None,
outfile_y_true=None,
outfile_y_preds=None,
Expand Down Expand Up @@ -215,6 +216,9 @@ def main(
outfile_result : str
File path to save the results, either cv_results or test result.
outfile_history : str, optional
File path to save the training history.
outfile_object : str, optional
File path to save searchCV object.
Expand Down Expand Up @@ -253,9 +257,7 @@ def main(
swapping = params["experiment_schemes"]["hyperparams_swapping"]
swap_params = _eval_swap_params(swapping)
estimator.set_params(**swap_params)

estimator_params = estimator.get_params()

# store read dataframe object
loaded_df = {}

Expand Down Expand Up @@ -448,12 +450,20 @@ def main(
# train and eval
if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
if exp_scheme == "train_val_test":
estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
else:
estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
else:
estimator.fit(X_train, y_train)

history = estimator.fit(X_train, y_train)
if "callbacks" in estimator_params:
for cb in estimator_params["callbacks"]:
if cb["callback_selection"]["callback_type"] == "CSVLogger":
hist_df = pd.DataFrame(history.history)
hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1)
epo_col = hist_df.pop('epoch')
hist_df.insert(0, 'epoch', epo_col)
hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False)
break
if isinstance(estimator, KerasGBatchClassifier):
scores = {}
steps = estimator.prediction_steps
Expand Down Expand Up @@ -526,6 +536,7 @@ def main(
aparser.add_argument("-X", "--infile1", dest="infile1")
aparser.add_argument("-y", "--infile2", dest="infile2")
aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
aparser.add_argument("-hi", "--outfile_history", dest="outfile_history")
aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
Expand All @@ -542,6 +553,7 @@ def main(
args.infile1,
args.infile2,
args.outfile_result,
outfile_history=args.outfile_history,
outfile_object=args.outfile_object,
outfile_y_true=args.outfile_y_true,
outfile_y_preds=args.outfile_y_preds,
Expand Down
10 changes: 8 additions & 2 deletions tools/sklearn/keras_train_and_eval.xml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<expand macro="macro_stdio" />
<version_command>echo "@VERSION@"</version_command>
<command>
<![CDATA[
<![CDATA[
export HDF5_USE_FILE_LOCKING='FALSE';
#if $input_options.selected_input == 'refseq_and_interval'
bgzip -c '$input_options.target_file' > '${target_file.element_identifier}.gz' &&
Expand All @@ -29,6 +29,9 @@
#end if
--infile2 '$input_options.infile2'
--outfile_result '$outfile_result'
#if $save and 'save_csvlogger' in str($save)
--outfile_history '$outfile_history'
#end if
#if $save and 'save_estimator' in str($save)
--outfile_object '$outfile_object'
#end if
Expand All @@ -39,7 +42,6 @@
#if $experiment_schemes.test_split.split_algos.shuffle == 'group'
--groups '$experiment_schemes.test_split.split_algos.groups_selector.infile_g'
#end if
]]>
</command>
<configfiles>
Expand Down Expand Up @@ -81,10 +83,14 @@
<param name="save" type="select" multiple='true' display="checkboxes" label="Save the fitted model" optional="true" help="Evaluation scores will be output by default.">
<option value="save_estimator" selected="true">Fitted estimator</option>
<option value="save_prediction">True labels and prediction results from evaluation for downstream analysis</option>
<option value="save_csvlogger">Display CSVLogger if selected as a callback in the Keras model builder tool</option>
</param>
</inputs>
<outputs>
<data format="tabular" name="outfile_result" />
<data format="tabular" name="outfile_history" label="Deep learning training history log on ${on_string}">
<filter>str(save) and 'save_csvlogger' in str(save)</filter>
</data>
<data format="h5mlm" name="outfile_object" label="Fitted estimator or estimator skeleton on ${on_string}">
<filter>str(save) and 'save_estimator' in str(save)</filter>
</data>
Expand Down

0 comments on commit 80417bf

Please sign in to comment.