首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在deeplearning4j中初始化自定义权重

在deeplearning4j中初始化自定义权重
EN

Stack Overflow用户
提问于 2017-03-15 10:16:32
回答 2查看 1.3K关注 0票数 2

我试图用https://www.youtube.com/watch?v=Fp9kzoAxsA4库来实现类似于GANN (遗传算法神经网络)这样的工具。

遗传学习变量:

  • 基因:生物神经网络权重
  • 健身:移动的总距离。

每个生物的神经网络层:

  • 输入层:5个传感器,如果传感器方向有墙,则为1,如果没有,则为0

  • 输出层:映射到生物角度的线性输出。

这是我的生物对象的createBrain方法:

代码语言:javascript
复制
private void createBrain() {
    Layer inputLayer = new DenseLayer.Builder()
            // 5 eye sensors
            .nIn(5)
            .nOut(5)
            // How do I initialize custom weights using creature genes (this.genes)?
            // .weightInit(WeightInit.ZERO)
            .activation(Activation.RELU)
            .build();

    Layer outputLayer = new OutputLayer.Builder()
            .nIn(5)
            .nOut(1)
            .activation(Activation.IDENTITY)
            .lossFunction(LossFunctions.LossFunction.MSE)
            .build();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(6)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(0.006)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0,inputLayer)
            .layer(1, outputLayer)
            .pretrain(false).backprop(true)
            .build();

    this.brain = new MultiLayerNetwork(conf);
    this.brain.init();
}

如果它能帮上忙的话,我把它推到了回购https://github.com/kareem3d/GeneticNeuralNetwork

这是生物类https://github.com/kareem3d/GeneticNeuralNetwork/blob/master/src/main/java/com/mycompany/gaan/Creature.java

我是一个机器学习的学生,所以如果你看到任何明显的错误,请告诉我,谢谢:)

EN

回答 2

Stack Overflow用户

发布于 2017-05-10 08:37:05

我不知道您是否可以在层配置中设置权重(我在API文档中看不到),但是您可以在初始化模型之后获得并设置网络参数。

要将它们单独设置为层,您可以遵循以下示例;

代码语言:javascript
复制
    Iterator paramap_iterator = convolutionalEncoder.paramTable().entrySet().iterator();

    while(paramap_iterator.hasNext()) {
        Map.Entry<String, INDArray> me = (Map.Entry<String, INDArray>) paramap_iterator.next();
        System.out.println(me.getKey());//print key
        System.out.println(Arrays.toString(me.getValue().shape()));//print shape of INDArray
        convolutionalEncoder.setParam(me.getKey(), Nd4j.rand(me.getValue().shape()));//set some random values
    }

如果您想一次设置网络的所有参数,例如可以使用setParams()params()

代码语言:javascript
复制
INDArray all_params = convolutionalEncoder.params();
convolutionalEncoder.setParams(Nd4j.rand(all_params.shape()));//set random values with the same shape

您可以查看API以获得更多信息;https://deeplearning4j.org/doc/org/deeplearning4j/nn/api/Model.html#params--

票数 3
EN

Stack Overflow用户

发布于 2019-02-02 10:35:23

对我来说很管用:

代码语言:javascript
复制
    int inputNum = 4;
    int outputNum = 3;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .layer(new EmbeddingLayer.Builder()
                    .nIn(inputNum) // Number of input datapoints.
                    .nOut(8) // Number of output datapoints.
                    .activation(Activation.RELU) // Activation function.
                    .weightInit(WeightInit.XAVIER) // Weight initialization.
                    .build())
            .list()
            .layer(new DenseLayer.Builder()
                    .nIn(inputNum) // Number of input datapoints.
                    .nOut(8) // Number of output datapoints.
                    .activation(Activation.RELU) // Activation function.
                    .weightInit(WeightInit.XAVIER) // Weight initialization.
                    .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(8)
                    .nOut(outputNum)
                    .activation(Activation.SOFTMAX)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .pretrain(false).backprop(false)
            .build();

    MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(conf);
    multiLayerNetwork.init();

    Map<String, INDArray> paramTable = multiLayerNetwork.paramTable();
    Set<String> keys = paramTable.keySet();
    Iterator<String> it = keys.iterator();

    while (it.hasNext()) {
        String key = it.next();
        INDArray values = paramTable.get(key);
        System.out.print(key+" ");//print keys
        System.out.println(Arrays.toString(values.shape()));//print shape of INDArray
        System.out.println(values);
        multiLayerNetwork.setParam(key, Nd4j.rand(values.shape()));//set some random values
    }
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42806761

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档