  1. 当使用共享的LSTM层(stateful=True)训练模型时,该层的各个并行调用在训练期间是否也会更新同一份状态?
  2. 如果我的观察是正确的,是否有办法在共享LSTM权重的同时,为每个并行调用独立地保存状态?




import keras
import keras.backend as K
import numpy as np

# Experiment: feed three slices of one input through a single shared,
# stateful LSTM layer and compare predicting the whole sequence in one call
# against predicting it in two consecutive halves.
nOut = 3
xShape = (3, 50, 4)
inShape = (xShape[0], None, xShape[2])
batchInShape = (1, ) + inShape
x = np.random.randn(*xShape)
halfSteps = xShape[1] // 2  # index at which the sequence is cut in two

# construct network
xIn = keras.layers.Input(shape=inShape, batch_shape=batchInShape)

# shared LSTM layer: one set of weights, applied to every slice below
sharedLSTM = keras.layers.LSTM(
    units=nOut,
    stateful=True,
    return_sequences=True,
    return_state=False,
)

# split the input on the first (non-batch) axis
x1 = keras.layers.Lambda(lambda x: x[:, 0, :, :])(xIn)
x2 = keras.layers.Lambda(lambda x: x[:, 1, :, :])(xIn)
x3 = keras.layers.Lambda(lambda x: x[:, 2, :, :])(xIn)

# pass each slice through the same LSTM instance
z1 = sharedLSTM(x1)
z2 = sharedLSTM(x2)
z3 = sharedLSTM(x3)

# add a singleton dimension so the three outputs can be joined on axis 1
y1 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z1)
y2 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z2)
y3 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z3)

# combine the outputs
y = keras.layers.Concatenate(axis=1)([y1, y2, y3])

model = keras.models.Model(inputs=xIn, outputs=y)
model.compile(loss='mse', optimizer='adam')

# no need to train, since we're interested only what is happening mechanically

# reset to a known state and predict for full input
model.reset_states()
aFull = model.predict(x[np.newaxis, :, :, :])

# reset to a known state and predict for the same input, but in two pieces;
# without this reset the first half would start from the state left behind
# by the full-sequence prediction above
model.reset_states()
a1 = model.predict(x[np.newaxis, :, :halfSteps, :])
a2 = model.predict(x[np.newaxis, :, halfSteps:, :])
# combine the pieces along the time axis
aSplit = np.concatenate((a1, a2), axis=2)

fullDiff = np.sum(np.abs(aFull - aSplit))
firstHalfDiff = np.sum(np.abs(aFull[:, :, :halfSteps, :] - aSplit[:, :, :halfSteps, :]))
secondHalfDiff = np.sum(np.abs(aFull[:, :, halfSteps:, :] - aSplit[:, :, halfSteps:, :]))
print('full diff: {}, first half diff: {}, second half diff: {}'.format(str(fullDiff), str(firstHalfDiff), str(secondHalfDiff)))

更新:上述行为是在以Tensorflow 1.14和1.15作为后端的Keras中观察到的。使用TF 2.0运行相同的代码(相应调整导入语句)会改变结果:a1不再与aFull的前半部分相同。不过,在实例化该层时设置stateful=False,仍然可以让a1与aFull的前半部分一致。


更新2:看来此前也有其他人需要同样的功能:参见Keras的GitHub上一个已关闭但未得到回答的问题。


import torch
import numpy as np

class sharedLSTM(torch.nn.Module):
    """Apply one weight-shared LSTM to several bands, each with its own state.

    The input is sliced along axis 1 into ``nBands`` bands. Every band is run
    through the same ``torch.nn.LSTM`` instance, but the (h, c) state pair of
    each band is stored separately and carried over to the next ``forward``
    call until ``resetStates`` is invoked.
    """

    def __init__(self, batchSz, nBands, nDims, outDim):
        """Create the shared LSTM and one zero (h, c) state pair per band.

        Args:
            batchSz: batch size used for the shape of the stored states.
            nBands: number of bands (size of input axis 1).
            nDims: number of input features per time step.
            outDim: hidden size of the LSTM.
        """
        super(sharedLSTM, self).__init__()
        self.internalLSTM = torch.nn.LSTM(input_size=nDims, hidden_size=outDim, num_layers=1, bias=True, batch_first=True)
        self.nBands = nBands
        # One independent (h_0, c_0) pair per band, each (1, batchSz, outDim).
        self.allStates = [
            (torch.zeros(1, batchSz, outDim), torch.zeros(1, batchSz, outDim))
            for _ in range(nBands)
        ]

    def forward(self, x):
        """Run every band through the shared LSTM, continuing its own state.

        Args:
            x: tensor indexed as (batchSz, nBands, nSteps, nFeats).

        Returns:
            Tensor of shape (batchSz, nBands, nSteps, outDim).
        """
        allOut = list()
        for bandIdx in range(self.nBands):
            thisSlice = x[:, bandIdx, :, :]  # (batchSz, nSteps, nFeats)
            # Start from this band's stored state and remember where it ended.
            # NOTE(review): the stored state stays attached to the autograd
            # graph when gradients are enabled; detach it here if the module
            # is ever trained across consecutive calls.
            thisY, self.allStates[bandIdx] = self.internalLSTM(thisSlice, self.allStates[bandIdx])
            allOut.append(thisY[:, None, :, :])  # => (batchSz, 1, nSteps, outDim)

        # Join the per-band outputs along the band axis.
        return torch.cat(allOut, dim=1)

    def resetStates(self):
        """Set the stored (h, c) state of every band back to zeros."""
        # Fresh tensors instead of in-place zeroing: the stored states may be
        # LSTM outputs that are part of an autograd graph.
        self.allStates = [
            (torch.zeros_like(h), torch.zeros_like(c))
            for h, c in self.allStates
        ]

# Driver: check that predicting a sequence in two consecutive halves gives the
# same result as predicting it in one pass, once the states have been reset.
batchSz = 5
nBands = 3
nFeats = 4
nOutDims = 2
net = sharedLSTM(batchSz, nBands, nFeats, nOutDims)
net = net.float()

N = 20
x = torch.from_numpy(np.random.rand(batchSz, nBands, N, nFeats)).float()
x1 = x[:, :, :N//2, :]
x2 = x[:, :, N//2:, :]

# Inference only, so no autograd graph is needed.
with torch.no_grad():
    aa = net(x)
    # Return to the zero state so the first half starts where the full pass
    # started; the second half then continues from the first half's state.
    net.resetStates()
    a1 = net(x1)
    a2 = net(x2)

print('(with reset) first half abs diff: {}'.format(str(torch.sum(torch.abs(a1 - aa[:,:,:N//2,:])).detach().numpy())))
print('(with reset) second half abs diff: {}'.format(str(torch.sum(torch.abs(a2 - aa[:,:,N//2:,:])).detach().numpy())))



import keras
import numpy as np

class sharedLSTM(keras.Model):
    """Apply one weight-shared Keras LSTM to several bands with separate states.

    The input is sliced along axis 2 into ``nBands`` bands. Every band goes
    through the same LSTM layer; the state returned for a band is kept in
    ``self.allStates`` and passed back as ``initial_state`` the next time that
    band is processed.

    NOTE(review): keeping the states as Python attributes only carries them
    over between calls when ``call`` is actually re-executed per batch (eager
    execution); under graph tracing the hand-off may not persist -- confirm
    for the Keras/TensorFlow version in use.
    """

    def __init__(self, batchSz, nBands, nDims, outDim):
        """Create the shared LSTM, the band slicers and the per-band state slots.

        Args:
            batchSz: fixed batch size given to the stateful LSTM.
            nBands: number of bands (size of input axis 2).
            nDims: number of input features per time step.
            outDim: number of LSTM units.
        """
        super(sharedLSTM, self).__init__()
        self.internalLSTM = keras.layers.LSTM(units=outDim, stateful=True, return_sequences=True, return_state=True, batch_input_shape=(batchSz, None, nDims))
        allSlicers = list()
        for bandIdx in range(nBands):
            # 'arguments' binds the band index per layer (avoids late binding).
            allSlicers.append(keras.layers.Lambda(lambda x, b: x[:, :, b, :], arguments={'b': bandIdx}))

        # One slot per band; None means "start from the zero state".
        self.allStates = [None] * nBands
        self.allSlicers = allSlicers
        self.Concat = keras.layers.Lambda(lambda x: keras.backend.concatenate(x, axis=2))

        self.nBands = nBands

    def call(self, x):
        """Run every band through the shared LSTM, continuing its own state.

        Args:
            x: tensor indexed as (batch, steps, band, features).

        Returns:
            The per-band output sequences joined along a band axis (axis 2).
        """
        allOut = list()
        for bandIdx in range(self.nBands):
            thisSlice = self.allSlicers[bandIdx](x)
            thisState = self.allStates[bandIdx]
            if thisState is None:
                # Passing an explicit zero state keeps the band from picking
                # up the layer's own internal state, which is shared by all
                # bands because the layer is stateful.
                thisState = self.internalLSTM.get_initial_state(thisSlice)

            thisY, *thisState = self.internalLSTM(thisSlice, initial_state=thisState)
            self.allStates[bandIdx] = thisState
            # Insert the band axis so the concatenation yields one entry per
            # band on axis 2.
            allOut.append(keras.backend.expand_dims(thisY, axis=2))

        y = self.Concat(allOut)
        return y

    def reset_states(self):
        """Reset the LSTM's internal state and forget every per-band state."""
        super(sharedLSTM, self).reset_states()
        self.allStates = [None] * self.nBands

# Driver: check that predicting a sequence in two consecutive halves gives the
# same result as predicting it in one pass, once the states have been reset.
batchSz = 1
nBands = 3
nFeats = 4
nOutDims = 2
N = 20

model = sharedLSTM(batchSz, nBands, nFeats, nOutDims)
model.compile(optimizer='SGD', loss='mae')

x = np.random.rand(batchSz, N, nBands, nFeats)
x1 = x[:, :N//2, :, :]
x2 = x[:, N//2:, :, :]

aa = model.predict(x)

# Return to a known state so the first half starts where the full pass
# started; the second half then continues from the first half's state.
model.reset_states()
a1 = model.predict(x1)
a2 = model.predict(x2)

print('(with reset) first half abs diff: {}'.format(str(np.sum(np.abs(a1 - aa[:,:N//2,:,:])))))
print('(with reset) second half abs diff: {}'.format(str(np.sum(np.abs(a2 - aa[:,N//2:,:,:])))))


