Theano和tensorflow conv2D产生不同的输出

2024-06-25 23:07:45 发布

您现在位置:Python中文网/ 问答频道 /正文

下面的代码是用theano编写的,并产生一些输出

import theano
import tensorflow as tf
from theano import tensor as T
from theano.tensor.nnet import conv2d

import numpy as np

np.random.seed(1)
wt_th = np.random.uniform(low=-0.5, high=0.5, size=(16,3,5,5)).astype(np.float32)

np.random.seed(1)
inp = np.random.rand(128,3,128,128).astype(np.float32)

np.random.seed(1)
bias = np.random.uniform(low=0, high=10, size=(16)).astype(np.float32)

# instantiate 4D tensor for input
input = T.tensor4(name='input')

# initialize shared variable for weights.
w_shp = (16, 3, 5, 5)
W = theano.shared(wt_th , name ='W')

b_shp = (16,)
b = theano.shared(bias, name ='b')

# build symbolic expression that computes the convolution of input with filters in w
conv_out = conv2d(input, W)

output = T.nnet.relu(conv_out + b.dimshuffle('x', 0, 'x', 'x'))

# create theano function to compute filtered images
f = theano.function([input], output)
res = f(inp)
print res[0][0][0]

[2.8236432 2.094213 1.4916432 3.525494 2.3700824 1.8851945 2.2574215] 3.3974087 1.9719648 1.3346338 0.7322583 1.4527869 2.9211016 2.4242344 2.6613848 3.1885512 2.8935843 3.7721367 1.0875871 1.8844371 3.6890957 2.1210446 3.4621592 2.2298138 2.1788187 3.1571674 2.080009 1.4983883 3.3549118 1.8853223 2.0242834 2.8072758 4.2562714 3.6012995 2.2535224 3.87668 3.3886926 3.697033 3.3373523 2.2016246 3.5874677 3.0154514 2.434566 3.6492867 2.2965183 2.6377907 2.2562 2.8330164 2.1103406 2.9778543 2.3738375 3.1129453 1.277472 1.1789091 2.4199317 2.619667 2.5976152 1.0020001 2.562955 1.6254797 1.9258347 1.5564928 3.5225492 3.1682463 2.179951 3.2768161 2.2703805 2.0199404 2.4948874 2.9022932 3.0263028 2.264034 1.9042997 1.6110027 3.6300693 1.899374 2.9140353 2.8552768 2.7125297 2.7972744 2.0619967 3.8458047 3.140479 1.6845248 3.844461 3.8562043 2.5270283 2.4488764 2.7029114 1.8886952 3.034019 3.1078124 1.9806297 4.573 2.769538 2.6645966 3.501518 2.2144883 1.8297508 3.3294327 2.7242799 2.187298 2.5060043 1.9938259 3.914175 3.7276266 2.6536622 2.896241 2.821738 1.592206 1.8782039 2.648998 2.284129 3.4120197 2.6911411 3.2339904 2.5738459 2.8637185 1.8006318 3.1124763 2.1838622 2.6475391 1.7801914 2.5641136]

W的格式为(num\u output\u channel,num\u input\u channel,height,widht)

输入格式为(批量大小、通道数、高度、宽度)

现在我已经编写了一个函数来将权重从ano格式转换为tensorflow格式,并将输入形状更改为tf格式 但是下面的代码产生的输出与上面的不同

def convert_filter(wts):
    wts = np.moveaxis(wts, 0, 3)
    wts = np.moveaxis(wts, 0, 2)
    return wts

def convert_input(inp):
    inp = np.moveaxis(inp,1,3)
    return inp

input_shape = [128,128, 3]

X = tf.placeholder(shape = [None] + input_shape, dtype=tf.float32, name='X')

wt_tf = convert_filter(wt_th).astype(np.float32)

conv_kernel_1 = tf.nn.conv2d(X, wt_tf, [1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=False)

bias_layer_1 = tf.nn.bias_add(conv_kernel_1, bias)

act_out = tf.nn.relu(bias_layer_1)

inp_tf = convert_input(inp)

with tf.Session() as sess:
    out = sess.run(act_out, feed_dict={X:inp_tf})
    out = np.moveaxis(out, 3, 1)
print out[0][0][0]

[2.67197418 2.15853548 2.06719136 3.01160574 3.22447252 3.077492] 2.52125549 3.08207083 2.29633474 1.86849833 2.03281307 2.28387547 2.67936897 2.48002243 2.31078005 3.56169009 3.12560081 2.61774731 1.82814527 3.23375154 3.25905514 2.39252329 3.13444471 2.00132608 2.41169739 1.86714172 3.01640558 2.51328039 2.07797813 1.77424145 1.8954494 2.98585939 2.98480368 2.57455826 2.36318088 3.88532543 2.38877392 2.86067486 2.78855133 2.63732243 2.63163185 2.79659152 1.98354578 2.77975321 2.12787509 2.71589994 3.44908381 2.02305984 3.04079533 2.60647154 2.14657426 2.74537277 3.07799053 2.49051762 4.77739191 3.12529612 2.30980444 2.31344223 2.02293968 3.04298592 3.4453795 3.58379078 3.32912683 3.26278138 1.48381591 2.32841253 1.97166562 3.04377413 3.12559581 2.27840328 2.93908429 0.96808767 3.17380023 1.60673594 2.59704685 3.98458505 1.25713849 1.90271974 1.82997131 2.93574715 2.14195251 3.26882362 2.09072447 2.07539392 3.77434778 1.82215333 3.30864692 1.52123737 2.29328823 1.36722493 3.34969425 2.55285358 3.15811181 4.44630671 2.7549541 2.83824682 2.50485158 2.45610046 1.5423398 3.12460995 2.38987827 0.983325 2.64392757 3.11031628 1.41283321 2.58364391 2.17403984 3.19049454 2.83069992 1.04926252 2.93791962 2.37773943 3.51300693 3.02249169 2.59249544 1.81437802 3.34520817 3.04475498 4.02190208 3.84745455 2.45946741 2.06334138 2.11823249 2.95765638]

为什么输出不同?你知道吗


Tags: importinputtfas格式nprandomtheano