在 PyTorch 中实现 Pixel RNN 的 Row LSTM 前向传递

2024-10-03 09:15:40 发布

您现在位置:Python中文网/ 问答频道 /正文

标题已经说明了一切。我试图在 PyTorch 中实现 Pixel RNN 的 Row LSTM 前向传递中的门计算。我特别好奇第一行的门是如何计算的,因为此时还没有可用的上一个单元状态(cell state)或隐藏状态(hidden state)。

我知道有一个与此非常相似的问题,但被接受的答案没有给出代码,而代码正是我想要的。

此外,我已经实现了自己的代码,但训练速度非常慢,所以我在寻找一个更快的解决方案。

class RLSTM(nn.Module):
    """Row-LSTM style layer built from two (1, 3) convolutions.

    ``input_to_state`` is applied once to the whole image and
    ``state_to_state`` is applied to the previous slice's hidden state; the
    two results are summed per gate.  The per-slice recurrence itself is
    delegated to ``RowLSTMCell``, which is defined outside this class.

    Args:
        ch: Number of input channels; also the hidden size ``h``.  Both
            convolutions map ``ch`` channels to ``4 * ch`` (one chunk of
            ``ch`` channels per gate).
    """

    def __init__(self, ch):
        super(RLSTM, self).__init__()
        self.ch = ch
        # Place the convolutions on the GPU when one exists (as the original
        # unconditional ``.cuda()`` did) but stay usable on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_to_state = torch.nn.Conv2d(
            self.ch, 4 * self.ch, kernel_size=(1, 3), padding=(0, 1)
        ).to(device)
        self.state_to_state = torch.nn.Conv2d(
            self.ch, 4 * self.ch, kernel_size=(1, 3), padding=(0, 1)
        ).to(device)
        # Cells of the most recently processed sample; reset by RowLSTM().
        self.cell_list = []

    def forward(self, image):
        """Run ``RowLSTM`` on every sample of the batch and re-stack them.

        Args:
            image: Batched input; it is split into size-1 chunks along
                dim 0 and each chunk is processed independently.

        Returns:
            The per-sample outputs concatenated along dim 0.

        Side effects:
            Increments the module-level counter ``total``, which must be
            defined by the surrounding module.
        """
        global total
        samples = image.split(1, 0)
        trans = torch.cat([self.RowLSTM(sample) for sample in samples], 0)
        total += 1
        # No explicit device transfer: the result already lives on the same
        # device as the input and the convolution weights.
        return trans

    def RowLSTM(self, image):
        """Process one sample slice by slice and return its hidden states.

        Args:
            image: A single sample; dim 1 is read as the channel count and
                dim 2 as the number of slices ``n``.  The code assumes a
                square ``1 x ch x n x n`` input (per the original comments).

        Returns:
            The hidden states of all ``n`` cells, each with a leading axis
            added, concatenated along dim 3.
        """
        self.cell_list = []
        n = image.size()[2]
        ch = image.size()[1]

        # Input-to-state contribution for every slice, computed in one
        # convolution and split into four gate tensors per slice.
        isgates = self.splitIS(self.input_to_state(image))

        for i in range(n):
            if i == 0:
                # There is no previous state for the first slice, so its cell
                # is seeded directly from the input-to-state gates.
                # NOTE(review): these gates are the raw convolution outputs;
                # unlike later slices they are not passed through sigmoid.
                # Confirm whether that is intended.
                ig, og, fg, gg = isgates[0][:4]
                cell = RowLSTMCell(
                    0, ig, og, fg, gg,
                    torch.zeros(ch, n, 1, device=image.device),
                    torch.zeros(ch, n, 1, device=image.device),
                )
                cell.c = ig * gg
                cell.h = torch.tanh(cell.c) * og
            else:
                cell_prev = self.cell_list[i - 1]
                hid_prev = cell_prev.getHiddenState()
                # NOTE(review): if hid_prev is h x n x 1 as for the first
                # cell, the (1, 3) kernel here runs over an axis of length 1;
                # confirm the slicing axis is the intended one.
                ssgates = self.splitSS(self.state_to_state(hid_prev.unsqueeze(0)))
                ig, og, fg, gg = self.addGates(isgates, ssgates, i)
                cell = RowLSTMCell(cell_prev, ig, og, fg, gg, 0, 0)
                cell.compute()
            self.cell_list.append(cell)

        # Stack every cell's hidden state back into a single tensor.
        hidden_layers = [torch.unsqueeze(cell.h, 0) for cell in self.cell_list]
        return torch.cat(hidden_layers, 3)

    def splitIS(self, tensor):
        """Split the input-to-state convolution output into per-slice gates.

        Args:
            tensor: ``1 x 4h x n x n`` tensor.

        Returns:
            Dict mapping slice index ``i`` (``0 .. n-1``) to a list of four
            ``h x n x 1`` tensors, obtained by slicing the last axis and then
            chunking the channel axis into four equal parts.
        """
        size = tensor.size()
        out_ft = size[1]
        num = size[2]
        # Integer division: Tensor.split needs an int chunk size, and ``/``
        # yields a float on Python 3.
        hh = out_ft // 4
        # Drop only the batch axis; a bare squeeze() would also remove
        # spatial axes of length 1.
        tensor = tensor.squeeze(0)

        slices = tensor.split(1, 2)  # n tensors of 4h x n x 1
        return {i: list(slices[i].split(hh, 0)) for i in range(num)}

    def splitSS(self, tensor):
        """Split the state-to-state convolution output into four gates.

        Args:
            tensor: ``1 x 4h x n x 1`` tensor.

        Returns:
            List of four ``h x n x 1`` tensors (channel axis chunked in four).
        """
        # Integer division for the same reason as in splitIS.
        hh = tensor.size()[1] // 4
        return list(tensor.squeeze(0).split(hh, 0))

    def addGates(self, i2s, s2s, key):
        """Sum matching gates from both convolutions and apply sigmoid.

        Args:
            i2s: Mapping ``{slice index: [i, o, f, g]}`` as built by splitIS.
            s2s: List ``[i, o, f, g]`` as built by splitSS.
            key: Slice index selecting the entry of ``i2s`` to use.

        Returns:
            List of four tensors ``sigmoid(i2s[key][k] + s2s[k])``.
        """
        # NOTE(review): sigmoid is applied to all four gates, including g;
        # confirm against the intended gate equations.
        return [torch.sigmoid(i2s[key][k] + s2s[k]) for k in range(4)]

这是每个单元(cell)对应的类:

^{pr2}$

Tags: oftoimageselfforsizecelltorch