-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
tuyucheng
committed
Dec 27, 2024
1 parent
54e6576
commit 69b09e8
Showing
227 changed files
with
9,014 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<artifactId>deeplearning4j</artifactId> | ||
<packaging>jar</packaging> | ||
<name>deeplearning4j</name> | ||
|
||
<parent> | ||
<groupId>cn.tuyucheng.taketoday</groupId> | ||
<artifactId>taketoday-tutorial4j</artifactId> | ||
<version>1.0.0</version> | ||
</parent> | ||
|
||
<dependencies> | ||
<dependency> | ||
<groupId>org.nd4j</groupId> | ||
<artifactId>nd4j-api</artifactId> | ||
<version>${dl4j.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.nd4j</groupId> | ||
<artifactId>nd4j-native-platform</artifactId> | ||
<version>${dl4j.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.deeplearning4j</groupId> | ||
<artifactId>deeplearning4j-core</artifactId> | ||
<version>${dl4j.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.deeplearning4j</groupId> | ||
<artifactId>deeplearning4j-nn</artifactId> | ||
<version>${dl4j.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.slf4j</groupId> | ||
<artifactId>slf4j-api</artifactId> | ||
<version>${org.slf4j.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.slf4j</groupId> | ||
<artifactId>slf4j-log4j12</artifactId> | ||
<version>${org.slf4j.version}</version> | ||
</dependency> | ||
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api --> | ||
<dependency> | ||
<groupId>org.datavec</groupId> | ||
<artifactId>datavec-api</artifactId> | ||
<version>${dl4j.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.apache.httpcomponents</groupId> | ||
<artifactId>httpclient</artifactId> | ||
<version>${httpclient.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.projectlombok</groupId> | ||
<artifactId>lombok</artifactId> | ||
<version>${lombok.version}</version> | ||
<scope>provided</scope> | ||
</dependency> | ||
</dependencies> | ||
|
||
<properties> | ||
<dl4j.version>0.9.1</dl4j.version> <!-- Latest non beta version --> | ||
<httpclient.version>4.3.5</httpclient.version> | ||
</properties> | ||
</project> |
78 changes: 78 additions & 0 deletions
78
deeplearning4j/src/main/java/cn/tuyucheng/taketoday/deeplearning4j/IrisClassifier.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package cn.tuyucheng.taketoday.deeplearning4j; | ||
|
||
import org.datavec.api.records.reader.RecordReader; | ||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||
import org.datavec.api.split.FileSplit; | ||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | ||
import org.deeplearning4j.eval.Evaluation; | ||
import org.deeplearning4j.nn.conf.BackpropType; | ||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | ||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||
import org.deeplearning4j.nn.conf.layers.DenseLayer; | ||
import org.deeplearning4j.nn.conf.layers.OutputLayer; | ||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||
import org.deeplearning4j.nn.weights.WeightInit; | ||
import org.nd4j.linalg.activations.Activation; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.dataset.DataSet; | ||
import org.nd4j.linalg.dataset.SplitTestAndTrain; | ||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | ||
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | ||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; | ||
import org.nd4j.linalg.io.ClassPathResource; | ||
import org.nd4j.linalg.lossfunctions.LossFunctions; | ||
|
||
import java.io.IOException; | ||
|
||
public class IrisClassifier { | ||
|
||
private static final int CLASSES_COUNT = 3; | ||
private static final int FEATURES_COUNT = 4; | ||
|
||
public static void main(String[] args) throws IOException, InterruptedException { | ||
DataSet allData; | ||
try (RecordReader recordReader = new CSVRecordReader(0, ',')) { | ||
recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); | ||
|
||
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT); | ||
allData = iterator.next(); | ||
} | ||
|
||
allData.shuffle(42); | ||
|
||
DataNormalization normalizer = new NormalizerStandardize(); | ||
normalizer.fit(allData); | ||
normalizer.transform(allData); | ||
|
||
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); | ||
DataSet trainingData = testAndTrain.getTrain(); | ||
DataSet testData = testAndTrain.getTest(); | ||
|
||
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() | ||
.iterations(1000) | ||
.activation(Activation.TANH) | ||
.weightInit(WeightInit.XAVIER) | ||
.regularization(true) | ||
.learningRate(0.1).l2(0.0001) | ||
.list() | ||
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3) | ||
.build()) | ||
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3) | ||
.build()) | ||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | ||
.activation(Activation.SOFTMAX) | ||
.nIn(3).nOut(CLASSES_COUNT).build()) | ||
.backpropType(BackpropType.Standard).pretrain(false) | ||
.build(); | ||
|
||
MultiLayerNetwork model = new MultiLayerNetwork(configuration); | ||
model.init(); | ||
model.fit(trainingData); | ||
|
||
INDArray output = model.output(testData.getFeatures()); | ||
|
||
Evaluation eval = new Evaluation(CLASSES_COUNT); | ||
eval.eval(testData.getLabels(), output); | ||
System.out.println(eval.stats()); | ||
} | ||
} |
47 changes: 47 additions & 0 deletions
47
...arning4j/src/main/java/cn/tuyucheng/taketoday/deeplearning4j/cnn/CifarDataSetService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package cn.tuyucheng.taketoday.deeplearning4j.cnn; | ||
|
||
import lombok.Getter; | ||
import org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator; | ||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | ||
|
||
import java.util.List; | ||
|
||
@Getter | ||
class CifarDataSetService implements IDataSetService { | ||
|
||
private final InputType inputType = InputType.convolutional(32, 32, 3); | ||
private final int trainImagesNum = 512; | ||
private final int testImagesNum = 128; | ||
private final int trainBatch = 16; | ||
private final int testBatch = 8; | ||
|
||
private final CifarDataSetIterator trainIterator; | ||
|
||
private final CifarDataSetIterator testIterator; | ||
|
||
CifarDataSetService() { | ||
trainIterator = new CifarDataSetIterator(trainBatch, trainImagesNum, true); | ||
testIterator = new CifarDataSetIterator(testBatch, testImagesNum, false); | ||
} | ||
|
||
@Override | ||
public DataSetIterator trainIterator() { | ||
return trainIterator; | ||
} | ||
|
||
@Override | ||
public DataSetIterator testIterator() { | ||
return testIterator; | ||
} | ||
|
||
@Override | ||
public InputType inputType() { | ||
return inputType; | ||
} | ||
|
||
@Override | ||
public List<String> labels() { | ||
return trainIterator.getLabels(); | ||
} | ||
} |
17 changes: 17 additions & 0 deletions
17
deeplearning4j/src/main/java/cn/tuyucheng/taketoday/deeplearning4j/cnn/CnnExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package cn.tuyucheng.taketoday.deeplearning4j.cnn; | ||
|
||
import lombok.extern.slf4j.Slf4j; | ||
import org.deeplearning4j.eval.Evaluation; | ||
|
||
@Slf4j | ||
class CnnExample { | ||
|
||
public static void main(String... args) { | ||
CnnModel network = new CnnModel(new CifarDataSetService(), new CnnModelProperties()); | ||
|
||
network.train(); | ||
Evaluation evaluation = network.evaluate(); | ||
|
||
LOGGER.info(evaluation.stats()); | ||
} | ||
} |
118 changes: 118 additions & 0 deletions
118
deeplearning4j/src/main/java/cn/tuyucheng/taketoday/deeplearning4j/cnn/CnnModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package cn.tuyucheng.taketoday.deeplearning4j.cnn; | ||
|
||
import lombok.extern.slf4j.Slf4j; | ||
import org.deeplearning4j.eval.Evaluation; | ||
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | ||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | ||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | ||
import org.deeplearning4j.nn.conf.layers.OutputLayer; | ||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; | ||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||
import org.deeplearning4j.nn.weights.WeightInit; | ||
import org.nd4j.linalg.activations.Activation; | ||
import org.nd4j.linalg.lossfunctions.LossFunctions; | ||
|
||
import java.util.stream.IntStream; | ||
|
||
@Slf4j | ||
class CnnModel { | ||
|
||
private final IDataSetService dataSetService; | ||
|
||
private final MultiLayerNetwork network; | ||
|
||
private final CnnModelProperties properties; | ||
|
||
CnnModel(IDataSetService dataSetService, CnnModelProperties properties) { | ||
this.dataSetService = dataSetService; | ||
this.properties = properties; | ||
|
||
MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() | ||
.seed(1611) | ||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||
.learningRate(properties.getLearningRate()) | ||
.regularization(true) | ||
.updater(properties.getOptimizer()) | ||
.list() | ||
.layer(0, conv5x5()) | ||
.layer(1, pooling2x2Stride2()) | ||
.layer(2, conv3x3Stride1Padding2()) | ||
.layer(3, pooling2x2Stride1()) | ||
.layer(4, conv3x3Stride1Padding1()) | ||
.layer(5, pooling2x2Stride1()) | ||
.layer(6, dense()) | ||
.pretrain(false) | ||
.backprop(true) | ||
.setInputType(dataSetService.inputType()) | ||
.build(); | ||
|
||
network = new MultiLayerNetwork(configuration); | ||
} | ||
|
||
void train() { | ||
network.init(); | ||
int epochsNum = properties.getEpochsNum(); | ||
IntStream.range(1, epochsNum + 1).forEach(epoch -> { | ||
LOGGER.info("Epoch {} / {}", epoch, epochsNum); | ||
network.fit(dataSetService.trainIterator()); | ||
}); | ||
} | ||
|
||
Evaluation evaluate() { | ||
return network.evaluate(dataSetService.testIterator()); | ||
} | ||
|
||
private ConvolutionLayer conv5x5() { | ||
return new ConvolutionLayer.Builder(5, 5) | ||
.nIn(3) | ||
.nOut(16) | ||
.stride(1, 1) | ||
.padding(1, 1) | ||
.weightInit(WeightInit.XAVIER_UNIFORM) | ||
.activation(Activation.RELU) | ||
.build(); | ||
} | ||
|
||
private SubsamplingLayer pooling2x2Stride2() { | ||
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | ||
.kernelSize(2, 2) | ||
.stride(2, 2) | ||
.build(); | ||
} | ||
|
||
private ConvolutionLayer conv3x3Stride1Padding2() { | ||
return new ConvolutionLayer.Builder(3, 3) | ||
.nOut(32) | ||
.stride(1, 1) | ||
.padding(2, 2) | ||
.weightInit(WeightInit.XAVIER_UNIFORM) | ||
.activation(Activation.RELU) | ||
.build(); | ||
} | ||
|
||
private SubsamplingLayer pooling2x2Stride1() { | ||
return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | ||
.kernelSize(2, 2) | ||
.stride(1, 1) | ||
.build(); | ||
} | ||
|
||
private ConvolutionLayer conv3x3Stride1Padding1() { | ||
return new ConvolutionLayer.Builder(3, 3) | ||
.nOut(64) | ||
.stride(1, 1) | ||
.padding(1, 1) | ||
.weightInit(WeightInit.XAVIER_UNIFORM) | ||
.activation(Activation.RELU) | ||
.build(); | ||
} | ||
|
||
private OutputLayer dense() { | ||
return new OutputLayer.Builder(LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR) | ||
.activation(Activation.SOFTMAX) | ||
.weightInit(WeightInit.XAVIER_UNIFORM) | ||
.nOut(dataSetService.labels().size() - 1) | ||
.build(); | ||
} | ||
} |
13 changes: 13 additions & 0 deletions
13
...earning4j/src/main/java/cn/tuyucheng/taketoday/deeplearning4j/cnn/CnnModelProperties.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package cn.tuyucheng.taketoday.deeplearning4j.cnn; | ||
|
||
import lombok.Value; | ||
import org.deeplearning4j.nn.conf.Updater; | ||
|
||
@Value | ||
class CnnModelProperties { | ||
private final int epochsNum = 512; | ||
|
||
private final double learningRate = 0.001; | ||
|
||
private final Updater optimizer = Updater.ADAM; | ||
} |
16 changes: 16 additions & 0 deletions
16
deeplearning4j/src/main/java/cn/tuyucheng/taketoday/deeplearning4j/cnn/IDataSetService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package cn.tuyucheng.taketoday.deeplearning4j.cnn; | ||
|
||
import org.deeplearning4j.nn.conf.inputs.InputType; | ||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | ||
|
||
import java.util.List; | ||
|
||
interface IDataSetService { | ||
DataSetIterator trainIterator(); | ||
|
||
DataSetIterator testIterator(); | ||
|
||
InputType inputType(); | ||
|
||
List<String> labels(); | ||
} |
Oops, something went wrong.