了解如何使用tf.nn.conv2d函数

2024-09-28 22:05:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图理解tf.nn.conv2d()函数是如何工作的。因此,我创建了一个简单的tensorflow程序:

import tensorflow as tf
import numpy as np

Nif = 2
Niy = 4
Nix = 4

Nof = 1
Koy =3
Kox = 3

ifmaps  = np.random.randint(3, size=(Nif, Niy, Nix))
print("ifmaps= ", ifmaps)
weights = np.random.randint(3, size=(Nof, Nif, Koy, Kox))
print("weights = ", weights)
weights = np.reshape(weights, (Koy, Kox, Nif,Nof)) 

ifmaps = tf.constant(ifmaps, dtype=tf.float64, shape=[Nif, Niy, Nix])
weights = tf.constant(weights, dtype=tf.float64, shape=[Koy, Kox, Nif,Nof])

ifmaps_tf = tf.reshape(ifmaps, shape=[-1, Niy, Nix, Nif]) #NHWC
weights_tf = tf.reshape(weights, shape = [Koy, Kox, Nif, Nof])

res = tf.nn.conv2d(ifmaps_tf, weights_tf, strides=[1, 1, 1, 1], padding='VALID') #S=1, no padding
#reshape it to NCHW format
ofmap = tf.reshape(res, shape=[ Nof, 2, 2])

with tf.Session() as sess:
   print("ofmap = ", sess.run(ofmap))

我得到的结果是:

ifmaps=  [[[0 2 1 0]
  [2 0 1 0]
  [0 1 2 1]
  [1 0 2 1]]

 [[0 0 2 1]
  [0 0 1 0]
  [2 1 2 1]
  [0 2 0 2]]]

  weights =  [[[[1 1 1]
   [0 1 0]
   [1 0 2]]

  [[2 2 2]
   [2 2 2]
   [0 1 0]]]]

   ofmap =  [[[17. 21.]
  [20. 16.]]]

ofmaps值不正确!有人能帮我得到我所缺少的吗? 提前谢谢


Tags: tfasnpnixnifprintshapeweights