Pytorch:如何从扁平化网络中取消扁平化/恢复网络?

2024-10-01 17:40:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用以下功能展平网络:

#############################################################################
# Flattening the NET
#############################################################################
def flattenNetwork(net):
    flatNet = []
    shapes = []
    for param in net.parameters():
        #if its WEIGHTS
        curr_shape = param.cpu().data.numpy().shape
        shapes.append(curr_shape)
        if len(curr_shape) == 2:
            param = param.cpu().data.numpy().reshape(curr_shape[0]*curr_shape[1])
            flatNet.append(param)
        elif len(curr_shape) == 4:
            param = param.cpu().data.numpy().reshape(curr_shape[0]*curr_shape[1]*curr_shape[2]*curr_shape[3])
            flatNet.append(param)
        else:
            param = param.cpu().data.numpy().reshape(curr_shape[0])
            flatNet.append(param)
    finalNet = []
    for obj in flatNet:
        for x in obj:
            finalNet.append(x)
    finalNet = np.array(finalNet)
    return finalNet,shapes

上述函数将所有权重作为网络的numpy列向量finalNetshapes(列表)返回。我想看看权重修改对预测准确性的影响。所以,我改变了重量。如何将修改后的权重向量复制回原始网络?请帮忙。多谢各位


Tags: in网络numpyfordatanetparamcpu
1条回答
网友
1楼 · 发布于 2024-10-01 17:40:57

模型定义(其forward函数)和参数配置(称为模型状态,可以使用^{}作为字典轻松访问)之间存在差异

您可以获取模型的状态,就像您在实现flattenNetwork中所做的那样。但是,对于几乎所有的模型,恢复此操作(如果您只有权重和层形状)是不可能的

现在,假设您仍然可以访问net。我的建议是直接使用net.state_dict(),修改它,然后用^{}加载权重字典。这样,您将避免自己处理序列化模型参数的问题

相关问题 更多 >

    热门问题