首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用KerasTuner优化神经网络体系结构?

如何使用KerasTuner优化神经网络体系结构?
EN

Stack Overflow用户
提问于 2021-12-30 17:57:07
回答 1查看 193关注 0票数 6

我试图使用KerasTuner自动调整神经网络体系结构,即隐藏层数和每个隐藏层中的节点数。目前,神经网络的体系结构是用一个参数NN_LAYER_SIZES定义的。例如,

代码语言:javascript
复制
NN_LAYER_SIZES = [128, 128, 128, 128]

表示神经网络有4个隐层,每个隐层有128个节点。

KerasTuner有以下超参数类型(https://keras.io/api/keras_tuner/hyperparameters/):

  • Int
  • Float
  • Boolean
  • Choice

这些超参数类型似乎都不适合我的用例。所以我编写了下面的代码来扫描隐藏层的数量和节点的数量。然而,它并没有被认为是一个超参数。

代码语言:javascript
复制
number_of_hidden_layer = hp.Int("layer_number", min_value=2, max_value=5, step=1)
number_of_nodes = hp.Int("node_number", min_value=4, max_value=8, step=1)
NN_LAYER_SIZES = [2**number_of_nodes for _ in range(number of hidden_layer)]

对如何使它正确有任何建议吗?

EN

回答 1

Stack Overflow用户

发布于 2022-01-04 16:54:08

在构建模型时,可以通过迭代来将层数看作一个超参数。这样,您就可以对不同的层数和不同的节点数进行实验:

代码语言:javascript
复制
import tensorflow as tf
import keras_tuner as kt

def model_builder(hp):
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

  units = hp.Int('units', min_value=32, max_value=512, step=32)
  layers = hp.Int('layers', min_value=2, max_value=5, step=1)

  for _ in range(layers):
    model.add(tf.keras.layers.Dense(units=units, activation='relu')) 

  model.add(tf.keras.layers.Dense(10))

  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
  return model

(img_train, label_train), (_, _) = tf.keras.datasets.fashion_mnist.load_data()
img_train = img_train.astype('float32') / 255.0

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3)

tuner.search(img_train, label_train, epochs=50, validation_split=0.2)
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]

model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70535121

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档