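1. Preface
This post is pure filler: I haven't had much to write about lately, so I'm putting this one together to make up the numbers. Readers unfamiliar with Java may want to skip it. As everyone knows, Java is an old and versatile language, used for everything from web development to Android apps, from crawlers to machine learning. In this installment we'll try a little machine learning in Java.
2. Choosing a framework
The framework chosen for this installment is DeepLearning4j (the "4j" stands for "for Java"), currently one of the mainstream machine learning frameworks for Java. Since I don't have a decent Linux machine on hand, I'll skip using Docker to set up the environment (as if I knew how to use it anyway) and go straight to Maven.
3. Setting up the environment
Create a new Maven project, then edit pom.xml. For convenience, my entire pom.xml is pasted below.
My JDK is 1.8; if you use a different JDK version, remember to change the java.version property below accordingly.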
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org</groupId>
<artifactId>dl4j-tutorials</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<!-- Change the nd4j.backend property to nd4j-cuda-9.2 to use CUDA GPUs -->
<!-- For CPU, use nd4j-native-platform -->
<!-- For GPU, use nd4j-cuda-9.2-platform -->
<nd4j.backend>nd4j-native</nd4j.backend>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<shadedClassifier>bin</shadedClassifier>
<java.version>1.8</java.version>
<nd4j.version>1.0.0-beta</nd4j.version>
<dl4j.version>1.0.0-beta</dl4j.version>
<datavec.version>1.0.0-beta</datavec.version>
<arbiter.version>1.0.0-beta</arbiter.version>
<rl4j.version>1.0.0-beta</rl4j.version>
<!-- For Spark examples: change the _1 to _2 to switch between Spark 1 and Spark 2 -->
<dl4j.spark.version>1.0.0-beta2_spark_1</dl4j.spark.version>
<datavec.spark.version>1.0.0-beta2_spark_1</datavec.spark.version>
<jcommander.version>1.27</jcommander.version>
<!-- Scala binary version: DL4J's Spark and UI functionality are released with both Scala 2.10 and 2.11 support -->
<scala.binary.version>2.11</scala.binary.version>
<guava.version>19.0</guava.version>
<logback.version>1.1.7</logback.version>
<jfreechart.version>1.0.13</jfreechart.version>
<jcommon.version>1.0.23</jcommon.version>
<maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
<exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
<maven.minimum.version>3.3.1</maven.minimum.version>
</properties>
<dependencyManagement>
<dependencies>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- deeplearning4j-ui is used for HistogramIterationListener + visualization: see http://deeplearning4j.org/visualization -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<!-- datavec-data-codec: used only in video example for loading video data -->
<dependency>
<artifactId>datavec-data-codec</artifactId>
<groupId>org.datavec</groupId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>jfree</groupId>
<artifactId>jfreechart</artifactId>
<version>${jfreechart.version}</version>
</dependency>
<dependency>
<groupId>org.jfree</groupId>
<artifactId>jcommon</artifactId>
<version>${jcommon.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.3.6</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-cuda-9.1 -->
<!-- https://mvnrepository.com/artifact/org.nd4j/nd4j-cuda-8.0 -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>${exec-maven-plugin.version}</version>
<executions>
<execution>
<goals>
<goal>exec</goal>
</goals>
</execution>
</executions>
<configuration>
<executable>java</executable>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<configuration>
<shadedArtifactAttached>true</shadedArtifactAttached>
<shadedClassifierName>${shadedClassifier}</shadedClassifierName>
<createDependencyReducedPom>true</createDependencyReducedPom>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>org/datanucleus/**</exclude>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
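Then comes the long wait while Maven downloads everything automatically...
4. Dataset
I originally wanted to run something more fun, but partly out of laziness, and partly because my computer has little disk space left (who knows what is sitting on my hard drive) and not much compute either, I settled for the age-old MNIST and treated it as a hello world. I've put the dataset on my cloud drive:
Link: https://pan.baidu.com/s/1YeeK9kyzXL6BaqfjVMMK0A
Extraction code: 0dyi
Feel free to download it. After downloading and extracting, the directory structure looks like this: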
├─testing
│ ├─0
│ ├─1
│ ├─2
│ ├─3
│ ├─4
│ ├─5
│ ├─6
│ ├─7
│ ├─8
│ └─9
└─training
  ├─0
  ├─1
  ├─2
  ├─3
  ├─4
  ├─5
  ├─6
  ├─7
  ├─8
  └─9
5. Reading the data
int nChannels = 1; // single-channel (grayscale) images
int outputNum = 10; // ten output classes
int batchSize = 64;
Random randNumGen = new Random(123); // same seed as in the complete listing in section 8
File trainData = new File( "E:/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); // collect the image files under the directory
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use the parent directory name as the class label
ImageRecordReader trainRR = new ImageRecordReader(28, 28, nChannels, labelMaker); // build the image record reader
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum); // build a training set that dl4j can consume
// Scale pixel values from the 0-255 range into the 0-1 range
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(trainIter);
trainIter.setPreProcessor(scaler);
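The two ideas in this step (taking the label from the parent directory name, and squeezing 0-255 pixel values into 0-1) can be sketched in plain Java. This is only an illustration of the concepts, not DL4J's actual implementation; the class name DataPrepSketch, both method names, and the file name example.png are made up for the example.

```java
import java.io.File;

final class DataPrepSketch {
    // The label is the name of the directory that directly contains the image,
    // e.g. ".../training/7/example.png" is labelled "7".
    static String labelFromParent(File image) {
        File parent = image.getParentFile();
        if (parent == null) {
            throw new IllegalArgumentException("File has no parent directory: " + image);
        }
        return parent.getName();
    }

    // Linearly maps an 8-bit pixel value (0..255) into the range [min, max].
    static double scalePixel(int pixel, double min, double max) {
        if (pixel < 0 || pixel > 255) {
            throw new IllegalArgumentException("Pixel out of range: " + pixel);
        }
        return pixel / 255.0 * (max - min) + min;
    }

    public static void main(String[] args) {
        // Hypothetical file name; only the directory layout comes from the post.
        File sample = new File("E:/mnist_png/training/7/example.png");
        System.out.println("label = " + labelFromParent(sample));
        System.out.println("white pixel scaled = " + scalePixel(255, 0.0, 1.0));
    }
}
```

Because the label is read from the directory layout, no separate label file is needed; the folder structure shown in section 4 is the whole annotation.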
6. Building the model
In dl4j a network model is usually built by chaining builder calls, which is actually quite easy to follow.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Nesterovs(0.01,0.9))
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
.nIn(nChannels)
/**
* nOut is the number of filters used during training
*/
.nOut(20)
/**
* Padding: in practice there is no need to compute padding by hand
* SAME -> ConvolutionMode.Same
* VALID -> ConvolutionMode.Truncate
*/
.activation(Activation.TANH)
.build())
.layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
.kernelSize(2,2)
.build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
//Note that nIn need not be specified in later layers
.nOut(50)
.activation(Activation.TANH)
.build())
.layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
.kernelSize(2,2)
.build())
.layer(4, new DenseLayer.Builder().activation(Activation.SIGMOID)
.nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(28,28,1)) //See note below
.backprop(true).pretrain(false).build();
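To see why nIn can be omitted after the first layer, it helps to trace the feature-map sizes by hand. The sketch below does that arithmetic in plain Java under stated assumptions: no padding, convolution stride 1, and pooling stride 2 (the values the complete listing in section 8 sets explicitly). The class LeNetShapes and its methods are hypothetical helpers, not part of DL4J.

```java
final class LeNetShapes {
    // Output size along one spatial dimension without padding:
    // floor((in - kernel) / stride) + 1.
    static int outputSize(int in, int kernel, int stride) {
        if (kernel > in || stride < 1) {
            throw new IllegalArgumentException("Invalid kernel/stride for input size " + in);
        }
        return (in - kernel) / stride + 1;
    }

    // Number of values fed into the dense layer for a 28x28 input.
    static int flattenedSize() {
        int size = 28;
        size = outputSize(size, 5, 1); // layer 0: 5x5 convolution -> 24
        size = outputSize(size, 2, 2); // layer 1: 2x2 max pooling -> 12
        size = outputSize(size, 5, 1); // layer 2: 5x5 convolution -> 8
        size = outputSize(size, 2, 2); // layer 3: 2x2 max pooling -> 4
        int filters = 50;              // nOut of the second convolution layer
        return size * size * filters;
    }

    public static void main(String[] args) {
        System.out.println("dense layer inputs = " + flattenedSize());
    }
}
```

Under these assumptions the dense layer receives 4 x 4 x 50 = 800 values per image, which it maps to 500 units before the 10-way softmax output.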
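But be careful: in dl4j the network architecture is specified through a MultiLayerConfiguration, which means that no actual network has been created yet. To get a real network, this configuration has to be handed to a MultiLayerNetwork, which is then initialized.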
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
Now training can begin. Before that, though, we can attach a listener to the model so that it prints the current score at regular iteration intervals.
model.setListeners(new ScoreIterationListener(10)); //Print score every 10 iterations
model.fit(trainIter);
7. Test results
In the same way, we build the test set using the method from section 5 and evaluate the trained model on it.
File testData = new File("E:/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(28, 28, nChannels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(scaler);
Evaluation eval = model.evaluate(testIter);
log.info(eval.stats());
The result is:
========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0.9731
Precision: 0.9728
Recall: 0.9731
F1 Score: 0.9729
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
966 0 1 0 1 5 4 1 2 0 | 0 = 0
0 1125 3 1 0 1 2 0 3 0 | 1 = 1
3 0 996 5 4 4 2 9 9 0 | 2 = 2
1 0 5 971 0 15 0 9 7 2 | 3 = 3
1 0 1 0 961 0 5 1 2 11 | 4 = 4
1 1 0 5 0 877 5 1 1 1 | 5 = 5
7 4 0 0 4 7 936 0 0 0 | 6 = 6
1 6 14 1 1 0 0 997 1 7 | 7 = 7
3 0 1 4 6 6 5 6 938 5 | 8 = 8
3 6 0 4 18 8 1 5 0 964 | 9 = 9
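As a sanity check, the reported accuracy can be recomputed from the confusion matrix: the diagonal holds the correctly classified images, and the ten rows (actual classes) sum to 10,000 test images. The plain-Java sketch below copies the matrix from the output above; the class ConfusionMatrixCheck and its methods are hypothetical names for this example and do not use DL4J's Evaluation class.

```java
import java.util.Locale;

final class ConfusionMatrixCheck {
    // Rows are actual digits, columns are predicted digits (copied from the output above).
    static final int[][] MATRIX = {
        {966, 0, 1, 0, 1, 5, 4, 1, 2, 0},
        {0, 1125, 3, 1, 0, 1, 2, 0, 3, 0},
        {3, 0, 996, 5, 4, 4, 2, 9, 9, 0},
        {1, 0, 5, 971, 0, 15, 0, 9, 7, 2},
        {1, 0, 1, 0, 961, 0, 5, 1, 2, 11},
        {1, 1, 0, 5, 0, 877, 5, 1, 1, 1},
        {7, 4, 0, 0, 4, 7, 936, 0, 0, 0},
        {1, 6, 14, 1, 1, 0, 0, 997, 1, 7},
        {3, 0, 1, 4, 6, 6, 5, 6, 938, 5},
        {3, 6, 0, 4, 18, 8, 1, 5, 0, 964}
    };

    // Accuracy = correctly classified samples (diagonal) / all samples.
    static double accuracy(int[][] m) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[i].length; j++) {
                total += m[i][j];
                if (i == j) {
                    correct += m[i][j];
                }
            }
        }
        if (total == 0) {
            throw new IllegalArgumentException("Confusion matrix is empty");
        }
        return (double) correct / total;
    }

    // Recall for one class = correct predictions for that class / samples of that class.
    static double recall(int[][] m, int cls) {
        long rowSum = 0;
        for (int v : m[cls]) {
            rowSum += v;
        }
        if (rowSum == 0) {
            throw new IllegalArgumentException("No samples for class " + cls);
        }
        return (double) m[cls][cls] / rowSum;
    }

    public static void main(String[] args) {
        System.out.println(String.format(Locale.ROOT, "accuracy = %.4f", accuracy(MATRIX)));
        System.out.println(String.format(Locale.ROOT, "recall(0) = %.4f", recall(MATRIX, 0)));
    }
}
```

The diagonal sums to 9,731 out of 10,000, which matches the reported accuracy of 0.9731. The matrix also shows where the model struggles, for instance 18 images of 9 predicted as 4.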
8. Complete code
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* Created by agibsonccc on 9/16/15.
*/
public class Mnist {
private static final Logger log = LoggerFactory.getLogger(Mnist.class);
public static void main(String[] args) throws Exception {
int nChannels = 1; // Number of input channels
int outputNum = 10; // The number of possible outcomes
int batchSize = 64; // Test batch size
int nEpochs = 2; // Number of training epochs
int seed = 123; // Seed for the random number generator
Nd4j.ENFORCE_NUMERICAL_STABILITY = false;
/*
Create an iterator using the batch size for one iteration
*/
Random randNumGen = new Random(seed);
log.info("Load data....");
File trainData = new File( "E:/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use the parent directory name as the class label
ImageRecordReader trainRR = new ImageRecordReader(28, 28, nChannels, labelMaker); // build the image record reader
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
// Scale pixel values from the 0-255 range into the 0-1 range
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(trainIter);
trainIter.setPreProcessor(scaler);
/*
Construct the neural network
*/
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Nesterovs(0.01,0.9))
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
.nIn(nChannels)
.stride(1, 1)
/**
* nOut is the number of filters used during training
*/
.nOut(20)
/**
* Padding: in practice there is no need to compute padding by hand
* SAME -> ConvolutionMode.Same
* VALID -> ConvolutionMode.Truncate
*/
.activation(Activation.IDENTITY)
.build())
.layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
//Note that nIn need not be specified in later layers
.stride(1, 1)
.nOut(50)
.activation(Activation.IDENTITY)
.build())
.layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(4, new DenseLayer.Builder().activation(Activation.SIGMOID)
.nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(28,28,1)) //See note below
.backprop(true).pretrain(false).build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
File testData = new File("E:/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(28, 28, nChannels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(scaler);
log.info("Train model....");
model.setListeners(new ScoreIterationListener(10)); //Print score every 10 iterations
for( int i=0; i<nEpochs; i++ ) {
model.fit(trainIter);
log.info("*** Completed epoch {} ***", i);
log.info("Evaluate model....");
Evaluation eval = model.evaluate(testIter);
log.info(eval.stats());
testIter.reset();
}
log.info("****************Example finished********************");
}
}
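9. Closing remarks
For now, Python remains the mainstream choice for machine learning; frameworks such as TensorFlow, Theano, Keras, PyTorch, and Caffe are widely adopted and well documented. DeepLearning4j is still somewhat niche, and beyond the official documentation there is relatively little material available. Because dl4j is Java-based it can run on mobile devices, but mobile deep learning frameworks have also proliferated in recent years; Xiaomi's MACE in particular has become popular with many mobile developers. This post was only a first try of dl4j; I'll study it further if the need arises.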