从Python中caffe.prototxt模型定义读取网络参数

2024-06-28 20:08:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我想从Python中的.prototxt中定义的caffe网络读取网络参数,因为layer_dict中的layer对象只告诉我它是一个“卷积”层,而不是kernel_sizestrides等在.prototxt文件中定义良好的内容。在

所以假设我有一个model.prototxt像这样:

name: "Model"
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param {
    shape: {
      dim: 64
      dim: 1
      dim: 28
      dim: 28
    }
  }
}
layer {
  name: "conv2d_1"
  type: "Convolution"
  bottom: "data"
  top: "conv2d_1"
  convolution_param {
    num_output: 32
    kernel_size: 3
    stride: 1
    weight_filler {
      type: "gaussian" # initialize the filters from a Gaussian
      std: 0.01        # distribution with stdev 0.01 (default mean: 0)
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

layer {
  name: "dense_1"
  type: "InnerProduct"
  bottom: "conv2d_1"
  top: "out"
  inner_product_param {
    num_output: 1024
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

我发现可以这样解析模型:

^{pr2}$

但我不知道如何从protobuf消息中获取结果对象的字段。在


Tags: 对象name网络layerdatasize定义param
2条回答

您可以遍历这些层并询问它们对应的参数,例如:

for i in range(0, len(net.layer)):
    if net.layer[i].type == 'Convolution':
        net.layer[i].convolution_param.bias_term = True # bias term, for example

可以在中找到适当的*\u param类型原形咖啡馆,例如:

^{pr2}$

Caffe prototxt文件是基于googleprotobuf构建的。为了有问题地访问它们,您需要使用该包。下面是一个示例脚本(source):

from caffe.proto import caffe_pb2
import google.protobuf.text_format as txtf

net = caffe_pb2.NetParameter()

fn = '/tmp/net.prototxt'
with open(fn) as f:
    s = f.read()
    txtf.Merge(s, net)

net.name = 'my new net'
layerNames = [l.name for l in net.layer]
idx = layerNames.index('fc6')
l = net.layer[idx]
l.param[0].lr_mult = 1.3

outFn = '/tmp/newNet.prototxt'
print 'writing', outFn
with open(outFn, 'w') as f:
    f.write(str(net))

相关问题 更多 >