Skip to content

Commit

Permalink
update MLP model
Browse files Browse the repository at this point in the history
  • Loading branch information
phatdatnguyen committed Feb 9, 2024
1 parent 28be974 commit 2314f16
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 83 deletions.
2 changes: 1 addition & 1 deletion AboutBox.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static string AssemblyVersion
{
get
{
return Assembly.GetExecutingAssembly().GetName().Version.ToString();
return "1.0.2";
}
}

Expand Down
2 changes: 1 addition & 1 deletion AboutBox.resx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
<resheader name="reader">System.Resources.ResXResourceReader, System.Windows.Forms, ...</resheader>
<resheader name="writer">System.Resources.ResXResourceWriter, System.Windows.Forms, ...</resheader>
<data name="Name1"><value>this is my long string</value><comment>this is a comment</comment></data>
<data name="Color1" type="System.Drawing.Color, System.Drawing"">Blue</data>
<data name="Color1" type="System.Drawing.Color, System.Drawing">Blue</data>
<data name="Bitmap1" mimetype="application/x-microsoft.net.object.binary.base64">
<value>[base64 mime encoded serialized .NET Framework object]</value>
</data>
Expand Down
10 changes: 9 additions & 1 deletion CustomControls/ModelControls/MLPModelControl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using TorchSharp;
using TorchSharp.Modules;
using System.Data;
using System.Diagnostics;

namespace JadeChem.CustomControls.ModelControls
{
Expand Down Expand Up @@ -57,6 +58,7 @@ public partial class MLPModelControl : UserControl
private readonly List<int> validationEpochs = new();
private readonly List<float> validationLosses = new();

private Stopwatch trainingStopwatch = new Stopwatch();
// Flags
private bool trainWithValidation = true;
private bool logWithTensorboard = false;
Expand Down Expand Up @@ -745,6 +747,7 @@ private void TrainButton_Click(object sender, EventArgs e)
torch.cuda.manual_seed(randomSeed);

trainProgressBar.Visible = true;
trainingStopwatch.Start();
int epochs = (int)epochsNumericUpDown.Value;
int startEpochIndex = trainedEpochs;
int saveInterval = (int)saveIntervalNumericUpDown.Value;
Expand Down Expand Up @@ -777,9 +780,10 @@ private void TrainButton_Click(object sender, EventArgs e)
lrScheduler?.step();

// Update progress
double elapsed = (double)trainingStopwatch.ElapsedMilliseconds / 1000;
trainProgressBar.Value = (int)Math.Ceiling((float)(epochIndex - startEpochIndex + 1) / epochs * 100);
trainProgressBar.Update();
trainProgressLabel.Text = "Epoch " + (epochIndex + 1).ToString() + "/" + (startEpochIndex + epochs).ToString();
trainProgressLabel.Text = "Epoch " + (epochIndex + 1).ToString() + "/" + (startEpochIndex + epochs).ToString() + " (" + Math.Round(elapsed, 4).ToString() + " s)";
trainProgressLabel.Update();
trainLossLabel.Text = "Train loss: " + Math.Round(trainLoss, 4).ToString();
trainLossLabel.Update();
Expand Down Expand Up @@ -814,11 +818,15 @@ private void TrainButton_Click(object sender, EventArgs e)
}
catch (Exception ex)
{
trainingStopwatch.Stop();
trainingStopwatch.Reset();
MessageBox.Show(this, ex.Message, "Error", MessageBoxButtons.OK, MessageBoxIcon.Error);
Cursor = Cursors.Default;
return;
}

trainingStopwatch.Stop();
trainingStopwatch.Reset();
Cursor = Cursors.Default;

// Raise the event
Expand Down
29 changes: 27 additions & 2 deletions Dialogs/VisualizeLossesDialog.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 32 additions & 6 deletions Dialogs/VisualizeLossesDialog.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
using OxyPlot;
using Accord.IO;
using Accord.Math;
using Accord.Statistics.Kernels;
using OxyPlot;
using OxyPlot.Axes;
using OxyPlot.Legends;
using OxyPlot.Series;
using System.Data;

namespace JadeChem.Dialogs
{
Expand All @@ -10,7 +14,7 @@ public partial class VisualizeLossesDialog : Form
#region Fields
private readonly List<int> trainEpochs;
private readonly List<float> trainLosses;
private readonly List<int> validationEpoch;
private readonly List<int> validationEpochs;
private readonly List<float> validationLosses;
private readonly string lossFunctionName;
#endregion
Expand All @@ -22,7 +26,7 @@ public VisualizeLossesDialog(List<int> trainEpochs, List<float> trainLosses, Lis

this.trainEpochs = trainEpochs;
this.trainLosses = trainLosses;
this.validationEpoch = validationEpochs;
this.validationEpochs = validationEpochs;
this.validationLosses = validationLosses;
this.lossFunctionName = lossFunctionName;
}
Expand All @@ -45,12 +49,12 @@ private void VisualizeLossesDialog_Load(object sender, EventArgs e)

plotModel.Series.Add(trainLossesLineSeries);

if (validationEpoch.Count > 0)
if (validationEpochs.Count > 0)
{
LineSeries validationLossesLineSeries = new();
for (int epochIndex = 0; epochIndex < validationEpoch.Count; epochIndex++)
for (int epochIndex = 0; epochIndex < validationEpochs.Count; epochIndex++)
{
double x = validationEpoch[epochIndex];
double x = validationEpochs[epochIndex];
double y = validationLosses[epochIndex];

validationLossesLineSeries.Points.Add(new DataPoint(x, y));
Expand Down Expand Up @@ -87,6 +91,28 @@ private void VisualizeLossesDialog_Load(object sender, EventArgs e)
dataPlotView.Model = plotModel;
}

private void ExportButton_Click(object sender, EventArgs e)
{
if (exportFileDialog.ShowDialog() != DialogResult.OK)
return;

CsvWriter csvWriter = new CsvWriter(exportFileDialog.FileName, ',');

if (trainLosses != null && validationLosses.Count == 0)
{
double[][] lossTable = trainEpochs.ToArray().ToDouble().ToJagged();
lossTable = lossTable.Concatenate(trainLosses.ToArray().ToDouble().ToJagged());
csvWriter.Write(lossTable.ToTable("Epoch", "Train loss"));
}
else if (trainLosses != null && validationLosses != null)
{
double[][] lossTable = trainEpochs.ToArray().ToDouble().ToJagged();
lossTable = lossTable.Concatenate(trainLosses.ToArray().ToDouble().ToJagged());
lossTable = lossTable.Concatenate(validationLosses.ToArray().ToDouble().ToJagged());
csvWriter.Write(lossTable.ToTable("Epoch", "Train loss", "Validation loss"));
}
}

private void CloseButton_Click(object sender, EventArgs e)
{
Close();
Expand Down
5 changes: 4 additions & 1 deletion Dialogs/VisualizeLossesDialog.resx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
<resheader name="reader">System.Resources.ResXResourceReader, System.Windows.Forms, ...</resheader>
<resheader name="writer">System.Resources.ResXResourceWriter, System.Windows.Forms, ...</resheader>
<data name="Name1"><value>this is my long string</value><comment>this is a comment</comment></data>
<data name="Color1" type="System.Drawing.Color, System.Drawing"">Blue</data>
<data name="Color1" type="System.Drawing.Color, System.Drawing">Blue</data>
<data name="Bitmap1" mimetype="application/x-microsoft.net.object.binary.base64">
<value>[base64 mime encoded serialized .NET Framework object]</value>
</data>
Expand Down Expand Up @@ -117,4 +117,7 @@
<resheader name="writer">
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
</resheader>
<metadata name="exportFileDialog.TrayLocation" type="System.Drawing.Point, System.Drawing, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a">
<value>17, 17</value>
</metadata>
</root>
3 changes: 2 additions & 1 deletion JadeChem.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<Description>A GUI software for extracting chemical features and performing supervised machine learning.</Description>
<NeutralLanguage>en-US</NeutralLanguage>
<ApplicationIcon>icon.ico</ApplicationIcon>
<Version>1.0.2</Version>
</PropertyGroup>

<ItemGroup>
Expand All @@ -31,7 +32,7 @@
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OxyPlot.WindowsForms" Version="2.1.2" />
<PackageReference Include="RDKit2DotNetStandard" Version="1.0.32" />
<PackageReference Include="TorchSharp-cuda-windows" Version="0.101.4" />
<PackageReference Include="TorchSharp-cuda-windows" Version="0.101.5" />
</ItemGroup>

<ItemGroup>
Expand Down
9 changes: 5 additions & 4 deletions MainForm.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions MainForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ private void AboutToolStripMenuItem_Click(object sender, EventArgs e)
aboutBox.ShowDialog(this);
}

private void MainForm_Load(object sender, EventArgs e)
{
PredictionTaskForm predictionTaskForm = new()
{
MdiParent = this
};
predictionTaskForm.Show();
}

private void MainForm_MdiChildActivate(object sender, EventArgs e)
{
if (ActiveMdiChild == null)
Expand Down
2 changes: 1 addition & 1 deletion MainForm.resx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
<resheader name="reader">System.Resources.ResXResourceReader, System.Windows.Forms, ...</resheader>
<resheader name="writer">System.Resources.ResXResourceWriter, System.Windows.Forms, ...</resheader>
<data name="Name1"><value>this is my long string</value><comment>this is a comment</comment></data>
<data name="Color1" type="System.Drawing.Color, System.Drawing"">Blue</data>
<data name="Color1" type="System.Drawing.Color, System.Drawing">Blue</data>
<data name="Bitmap1" mimetype="application/x-microsoft.net.object.binary.base64">
<value>[base64 mime encoded serialized .NET Framework object]</value>
</data>
Expand Down
Loading

0 comments on commit 2314f16

Please sign in to comment.