DeepLearning4j Hello World: MNIST

This article shows how to do machine learning in Java with the DeepLearning4j framework: setting up the environment, preparing the MNIST dataset, and building and training a convolutional neural network.


1. Introduction

Honestly, this post is just me coasting; I haven't had much to write about lately, so here is a filler piece. Readers unfamiliar with Java may want to skip it. As everyone knows, Java is an old and remarkably versatile language: from web development to Android apps, from crawlers to machine learning, it can do almost anything. In this post we'll try a little machine learning in Java.

2. Choosing a Framework

The framework we'll use is DeepLearning4j, where the "4j" means "for Java"; DL4J is currently one of the mainstream machine learning frameworks in the Java world. Since I don't have a decent Linux box at hand, I'll skip setting up the environment with Docker (as if I knew how to use it) and go straight to Maven.

3. Environment Setup

Create a new Maven project, then edit its pom.xml. For convenience, my entire pom.xml is pasted below.
My JDK is 1.8; if you use a different JDK version, make sure to adjust the java.version property below accordingly.

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org</groupId>
    <artifactId>dl4j-tutorials</artifactId>
    <version>1.0-SNAPSHOT</version>


    <properties>
        <!-- Change the nd4j.backend property below to switch between CPU and CUDA GPU backends -->
        <!-- CPU: nd4j-native-platform -->
        <!-- GPU: nd4j-cuda-9.2-platform -->
        <nd4j.backend>nd4j-native-platform</nd4j.backend>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <shadedClassifier>bin</shadedClassifier>


        <java.version>1.8</java.version>
        <nd4j.version>1.0.0-beta</nd4j.version>
        <dl4j.version>1.0.0-beta</dl4j.version>
        <datavec.version>1.0.0-beta</datavec.version>
        <arbiter.version>1.0.0-beta</arbiter.version>
        <rl4j.version>1.0.0-beta</rl4j.version>

        <!-- For Spark examples: change the _1 to _2 to switch between Spark 1 and Spark 2 -->
        <dl4j.spark.version>1.0.0-beta2_spark_1</dl4j.spark.version>
        <datavec.spark.version>1.0.0-beta2_spark_1</datavec.spark.version>

        <jcommander.version>1.27</jcommander.version>

        <!-- Scala binary version: DL4J's Spark and UI functionality are released with both Scala 2.10 and 2.11 support -->
        <scala.binary.version>2.11</scala.binary.version>

        <guava.version>19.0</guava.version>
        <logback.version>1.1.7</logback.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <jcommon.version>1.0.23</jcommon.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
    </properties>

    <dependencyManagement>
        <dependencies>

        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <!-- deeplearning4j-ui is used for HistogramIterationListener + visualization: see http://deeplearning4j.org/visualization -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>${guava.version}</version>
        </dependency>

        <!-- datavec-data-codec: used only in video example for loading video data -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-codec</artifactId>
            <version>${datavec.version}</version>
        </dependency>

        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.6</version>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <!-- ND4J backend, selected via the nd4j.backend property above -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>exec-maven-plugin</artifactId>
                <version>${exec-maven-plugin.version}</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>exec</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <executable>java</executable>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>${maven-shade-plugin.version}</version>
                <configuration>
                    <shadedArtifactAttached>true</shadedArtifactAttached>
                    <shadedClassifierName>${shadedClassifier}</shadedClassifierName>
                    <createDependencyReducedPom>true</createDependencyReducedPom>
                    <filters>
                        <filter>
                            <artifact>*:*</artifact>
                            <excludes>
                                <exclude>org/datanucleus/**</exclude>
                                <exclude>META-INF/*.SF</exclude>
                                <exclude>META-INF/*.DSA</exclude>
                                <exclude>META-INF/*.RSA</exclude>
                            </excludes>
                        </filter>
                    </filters>
                </configuration>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                                    <resource>reference.conf</resource>
                                </transformer>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.5.1</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

Then comes the long wait while Maven downloads everything…
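Once the dependencies are in place, mvn package will also produce a self-contained dl4j-tutorials-1.0-SNAPSHOT-bin.jar via the shade plugin configured above, which is convenient for running the example outside the IDE.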

4. The Dataset

I originally wanted to run something more interesting, but partly out of laziness, and partly because my machine is low on disk space (who knows what's filling my drive) and short on compute, I settled for good old MNIST; think of it as running hello world. I've put the dataset (as PNG files) on my cloud drive:
Link: https://pan.baidu.com/s/1YeeK9kyzXL6BaqfjVMMK0A
Extraction code: 0dyi
Feel free to download it. After downloading and extracting, the file structure looks like this:

├─testing
│  ├─0
│  ├─1
│  ├─2
│  ├─3
│  ├─4
│  ├─5
│  ├─6
│  ├─7
│  ├─8
│  └─9
└─training
    ├─0
    ├─1
    ├─2
    ├─3
    ├─4
    ├─5
    ├─6
    ├─7
    ├─8
    └─9
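Each digit's images live in a subdirectory named after its class. This layout matters: in the next section, the parent directory name is used directly as the label.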

5. Loading the Data

int nChannels = 1; // single-channel (grayscale) images
int outputNum = 10; // ten output classes, one per digit
int batchSize = 64;
Random randNumGen = new Random(123); // seeded RNG so the file shuffling is reproducible
File trainData = new File("E:/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); // collect the image files under the directory
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use the parent directory name as the class label
ImageRecordReader trainRR = new ImageRecordReader(28, 28, nChannels, labelMaker); // reads images as 28x28 single-channel
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum); // wrap into a DataSetIterator DL4J can consume

// scale pixel values from [0, 255] down to [0, 1]
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(trainIter);
trainIter.setPreProcessor(scaler);
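Before training, it's worth sanity-checking the iterator by peeking at one batch. A quick sketch (DataSet is org.nd4j.linalg.dataset.DataSet):

DataSet batch = trainIter.next(); // pull one preprocessed batch
System.out.println(java.util.Arrays.toString(batch.getFeatures().shape())); // e.g. [64, 1, 28, 28]
System.out.println(java.util.Arrays.toString(batch.getLabels().shape()));   // e.g. [64, 10], one-hot labels
trainIter.reset(); // rewind so training starts from the first batch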

6. Building the Model

In DL4J, a network is typically configured through a chain of builder calls, which is quite easy to follow.

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
	.weightInit(WeightInit.XAVIER)
	.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
	.updater(new Nesterovs(0.01, 0.9))
	.list()
	.layer(0, new ConvolutionLayer.Builder(5, 5)
	      // nIn and nOut specify depth: nIn is nChannels and nOut is the number of filters to apply
	      .nIn(nChannels)
	      .stride(1, 1)
	      // nOut is the number of filters, i.e. the number of output feature maps
	      .nOut(20)
	      // There is no need to compute padding by hand; DL4J exposes it as a convolution mode:
	      // SAME  -> ConvolutionMode.Same
	      // VALID -> ConvolutionMode.Truncate
	      .activation(Activation.IDENTITY)
	      .build())
	.layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
	      .kernelSize(2, 2)
	      .stride(2, 2)
	      .build())
	.layer(2, new ConvolutionLayer.Builder(5, 5)
	      // Note that nIn need not be specified in later layers
	      .stride(1, 1)
	      .nOut(50)
	      .activation(Activation.IDENTITY)
	      .build())
	.layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
	      .kernelSize(2, 2)
	      .stride(2, 2)
	      .build())
	.layer(4, new DenseLayer.Builder().activation(Activation.SIGMOID)
	      .nOut(500).build())
	.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
	      .nOut(outputNum)
	      .activation(Activation.SOFTMAX)
	      .build())
	.setInputType(InputType.convolutionalFlat(28, 28, 1)) // lets DL4J insert the input reshaping preprocessors automatically
	.backprop(true).pretrain(false).build();
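A quick sanity check of the shapes with these settings (5x5 kernels, stride 1, no padding for the convolutions; 2x2 stride-2 max pooling): a 28x28 input gives (28-5)+1 = 24, so 24x24x20 feature maps after the first convolution; pooling halves that to 12x12x20; the second convolution gives (12-5)+1 = 8, so 8x8x50; pooling again gives 4x4x50 = 800 values, which feed the 500-unit dense layer and finally the 10-way softmax output.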

Be careful though: MultiLayerConfiguration is only a description of the network; at this point no actual network exists yet. To get a real network, the config has to be handed to a MultiLayerNetwork and initialized.

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
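After init() you can print a per-layer summary to double-check the architecture and parameter counts:

System.out.println(model.summary()); // one row per layer with nIn/nOut and parameter counts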

Now training can begin. Before it starts, though, we can attach a listener to the model so that it prints the current score at regular intervals.

model.setListeners(new ScoreIterationListener(10)); //Print score every 10 iterations
model.fit(trainIter);
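Since the pom already includes deeplearning4j-ui, you can optionally watch the training curves in a browser rather than the console. A minimal sketch, set up before calling fit (UIServer, StatsListener and InMemoryStatsStorage live in the org.deeplearning4j.ui.api, org.deeplearning4j.ui.stats and org.deeplearning4j.ui.storage packages, StatsStorage in org.deeplearning4j.api.storage; the UI is served at http://localhost:9000 by default):

UIServer uiServer = UIServer.getInstance(); // starts the web UI
StatsStorage statsStorage = new InMemoryStatsStorage(); // keep training stats in memory
uiServer.attach(statsStorage); // route stored stats to the UI
model.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(10));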

7. Test Results

We build the test set the same way as in section 5 and let the trained model run on it:

File testData = new File("E:/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ImageRecordReader testRR = new ImageRecordReader(28, 28, nChannels, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
testIter.setPreProcessor(scaler);

Evaluation eval = model.evaluate(testIter);
log.info(eval.stats());

The output looks like this:

========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0.9731
 Precision:       0.9728
 Recall:          0.9731
 F1 Score:        0.9729
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  966    0    1    0    1    5    4    1    2    0 | 0 = 0
    0 1125    3    1    0    1    2    0    3    0 | 1 = 1
    3    0  996    5    4    4    2    9    9    0 | 2 = 2
    1    0    5  971    0   15    0    9    7    2 | 3 = 3
    1    0    1    0  961    0    5    1    2   11 | 4 = 4
    1    1    0    5    0  877    5    1    1    1 | 5 = 5
    7    4    0    0    4    7  936    0    0    0 | 6 = 6
    1    6   14    1    1    0    0  997    1    7 | 7 = 7
    3    0    1    4    6    6    5    6  938    5 | 8 = 8
    3    6    0    4   18    8    1    5    0  964 | 9 = 9
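Rows of the matrix are the true digits and columns the predictions; for example, the 15 in the row for 3 (column 5) means fifteen images of a 3 were misclassified as a 5.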

8. Complete Code

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.Random;

/**
 * Created by agibsonccc on 9/16/15.
 */
public class Mnist {
    private static final Logger log = LoggerFactory.getLogger(Mnist.class);

    public static void main(String[] args) throws Exception {
        int nChannels = 1; // Number of input channels
        int outputNum = 10; // The number of possible outcomes
        int batchSize = 64; // Batch size for both training and testing
        int nEpochs = 2; // Number of training epochs
        int seed = 123; // RNG seed for reproducible file shuffling

        Nd4j.ENFORCE_NUMERICAL_STABILITY = false;
        /*
            Create an iterator using the batch size for one iteration
         */
        Random randNumGen = new Random(seed);

        log.info("Load data....");
        File trainData = new File("E:/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use the parent directory name as the class label
        ImageRecordReader trainRR = new ImageRecordReader(28, 28, nChannels, labelMaker); // reads images as 28x28 single-channel
        trainRR.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);

        // scale pixel values from [0, 255] down to [0, 1]
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        scaler.fit(trainIter);
        trainIter.setPreProcessor(scaler);

        /*
            Construct the neural network
         */
        log.info("Build model....");

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Nesterovs(0.01,0.9))
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        // nIn and nOut specify depth: nIn is nChannels and nOut is the number of filters to apply
                        .nIn(nChannels)
                        .stride(1, 1)
                        // nOut is the number of filters, i.e. the number of output feature maps
                        .nOut(20)
                        // There is no need to compute padding by hand; DL4J exposes it as a convolution mode:
                        // SAME  -> ConvolutionMode.Same
                        // VALID -> ConvolutionMode.Truncate
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        //Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.SIGMOID)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28,28,1)) // lets DL4J insert the input reshaping preprocessors automatically
                .backprop(true).pretrain(false).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        File testData = new File("E:/mnist_png/testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(28, 28, nChannels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
        testIter.setPreProcessor(scaler);


        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(10)); //Print score every 10 iterations
        for( int i=0; i<nEpochs; i++ ) {
            model.fit(trainIter);
            log.info("*** Completed epoch {} ***", i);

            log.info("Evaluate model....");
            Evaluation eval = model.evaluate(testIter);
            log.info(eval.stats());
            testIter.reset();
        }
        log.info("****************Example finished********************");

    }
}
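One thing the example never does is persist the network. If you want to keep the trained model, here is a minimal sketch using org.deeplearning4j.util.ModelSerializer (the file name is arbitrary):

File modelFile = new File("mnist-model.zip");
ModelSerializer.writeModel(model, modelFile, true); // true = also save the updater state, so training can resume later
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile); // load it back for inference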

9. Closing Remarks

For now, Python remains the mainstream choice for machine learning: frameworks such as TensorFlow, Theano, Keras, PyTorch and Caffe are well established and richly documented, while DeepLearning4j is still somewhat niche, with relatively little material beyond the official docs. Being Java-based, DL4J can be deployed to mobile devices for on-device computation, but dedicated mobile deep learning frameworks have been springing up in recent years as well; Xiaomi's MACE in particular has won over many mobile developers. This post was just a first taste of DL4J; I'll dig deeper if the need arises.
