Visualize the outputs of intermediate layers of a Keras model

在 2017-12-29 发布于 人工智能 下以来已有1,733人读过本文 | 0条评论 发表评论

训练完成各种CNN的模型之后,为了查看模型的效果以及模型到底能够从原始图像中抽取什么样的特征,有时候我们需要模型的中间输出结果,并将其图示出来进行查看。在TensorFlow中这很简单,直接将中间层的Tensor作为输出,以图像数据通过feed_dict进行输入,然后使用tf.Session将Tensor的结果run出来图示即可。但Keras对模型的输入输出以及训练过程整合的比较严密,耦合度较高,尤其是使用Keras自带的一些经典模型进行fine-tuning时,想要得到其中间层的结果并将其输出并不容易。

下面我们就来看一下,如何取出Keras模型中的某一层输出,并图示出来进行查看。

1.使用layer别名得到输出并图示

这种方式需要原始模型支持,或者重新对原始模型进行一下定义和复现。

首先我们来看用来得到模型输出数据的函数,如下:

def layer_to_visualize(layer):
  inputs = [K.learning_phase()] + model.inputs

  _convout1_f = K.function(inputs, [layer.output])

  def convout1_f(X):
    # The [0] is to disable the training phase flag
    return _convout1_f([0] + [X])

  convolutions = convout1_f(img_to_visualize)
  convolutions = np.squeeze(convolutions)

  print ('Shape of conv:', convolutions.shape)

  num = convolutions.shape[2]
  n = int(np.ceil(np.sqrt(num)))

  # Visualization of each filter of the layer
  fig = plt.figure()
  for i in range(num):
    ax = fig.add_subplot(n, n, i + 1)
    ax.imshow(convolutions[..., i], cmap='gray')
  plt.show()
  fig.close()

如上的函数,输入为某一层的名字,即可将该层的所有feature maps进行输出。

下面我们进行详细分析:

  1. 第2行使用模型的状态(训练还是验证),以及模型的输入(即接下来输入的图像数据所对应的Tensor),构建一个输入;
  2. 第4行使用backend的function构建一个函数,用于对输入进行处理;
  3. 第6~8行对第4行中构建的函数进行封装;需要注意的是,这里将函数输入中的模型状态直接设置为了0,对应模型处于验证测试状态;
  4. 第10、11行以图像img_to_visualize作为输入,直接得到卷积过后的feature maps,并将得到的numpy array进行squeeze,以便后续处理;
  5. 随后即是将得到的feature maps进行可视化显示了:比如4中得到的对应于feature maps的numpy array的shape为[32, 32, 128],则可以在12×12的grid上进行显示,每个块显示一个32×32的feature map。

定义好了这个函数,在使用时,需要定义的模型进行匹配,如果模型不匹配,则需要重新进行一下定义。模型的定义需要满足如下条件:

model = Sequential()
model.add(Conv2D(64, kernel_size=(3, 3), padding='same', input_shape=(64, 64, 1)))
conv1out = Activation('relu')
model.add(conv1out)
model.add(MaxPool2D())
model.add(Conv2D(128, kernel_size=(3, 3), padding='same'), activation='relu')

...model definition...

model.compile(optimizer='...', loss='...', metrics=['...'])

model.load_weights('*.h5')

即为需要visualize的层定义一个名字,如conv1out;然后即可使用上面定义的函数layer_to_visualize进行可视化:layer_to_visualize(conv1out)。在最后可视化之前,注意到函数中用到的model需要提前定义好,而图像数据img_to_visualize也需要提前加载进去准备好,该数据需要与model的输入Tensor维度匹配。

完整代码请点击查看

2.直接加载模型,使用layer编号得到输出并图示

1中方法固然也比较方便,但对模型有要求;而模型不满足要求时,还需要重新定义,略为复杂。

我们还可以使用另外一种方法,直接将已经保存的模型加载进来,然后一次将整个模型每一层的输出都全部得到,最后根据需求将某些层的输出进行图示即可。当然,此种方法虽然不需要模型图的重定义,但也需要预先对模型比较熟悉,不然无法找到某些层对应的编号,便无法获得其对应的输出。

这种方法首先需要定义可以一次得到所有层的输出的函数:

def get_layer_outputs():
  img_to_visualize = get_image()
  plt.imshow(img_to_visualize[..., 0])
  plt.show()
  img_to_visualize = np.expand_dims(img_to_visualize, 0)
  outputs = [layer.output for layer in model.layers]  # all layer outputs
  comp_graph = [K.function([model.input] + [K.learning_phase()], [output])
                for output in outputs]  # evaluation functions

  # Testing
  layer_outputs_list = [op([img_to_visualize, 1.]) for op in comp_graph]
  layer_outputs = []

  for layer_output in layer_outputs_list:
    print(layer_output[0][0].shape, end='\n-------------------\n')
    layer_outputs.append(layer_output[0][0])

  return layer_outputs

这个函数里,第2~5行中img_to_visualize与1中类似,是输入图像,这几行对图像进行处理的代码也与1中类似,此处不再赘述;

第6~8行构建得到输出的函数,也使用了K.function以及leaning_phase等;

第9行往后到结束,使用构建的函数获得模型每一层的输出。

随后定义函数将得到的某一层的输出数据进行输出显示,如下:

def plot_layer_outputs(layer_number):
  layer_outputs = get_layer_outputs()

  x_max = layer_outputs[layer_number].shape[0]
  y_max = layer_outputs[layer_number].shape[1]
  n = layer_outputs[layer_number].shape[2]

  L = []
  for i in range(n):
    L.append(np.zeros((x_max, y_max)))

  for i in range(n):
    for x in range(x_max):
      for y in range(y_max):
        L[i][x][y] = layer_outputs[layer_number][x][y][i]

  fig = plt.figure()
  for i, c in enumerate(L):
    ax = fig.add_subplot(np.ceil(n**0.5), np.ceil(n**0.5), i + 1)
    ax.imshow(c, cmap='gray')
  plt.show()

这个函数比较简单,我们对此进行一下简单解释。

首先第2行利用上面定义的函数,获得模型所有层的输出;

然后第4~6行得到所取出层的各种参数,比如大小,一共多少个feature maps等等;

第8~15行将对应的输出赋给一个列表,以备后续显示处理;

随后将得到的输出的各个feature maps绘制出来。

有了这两个函数,即可直接使用如plot_layer_outputs(3)来进行调用,绘制模型第3层的输出feature maps了。

完整代码请点击查看

备注

除了这两种方法外,还有一些其他方法可以使用,请参考这个问题及其回答

发表评论

您的昵称 *

您的邮箱 *

您的网站