从学习变量得到的期望张量流模型大小

2024-05-19 22:25:51 发布

您现在位置:Python中文网/ 问答频道 /正文

当训练卷积神经网络用于图像分类任务时,我们通常希望我们的算法学习将给定图像转换为其正确标签的滤波器(和偏差)。我有一些模型,我试图在模型大小、操作数量、精确度等方面进行比较。但是,从tensorflow输出的模型的大小,具体地说是型号.ckpt.数据存储图中所有变量值的文件不是我所期望的。事实上,它似乎大了三倍。在

为了直接回答这个问题,我将把我的问题建立在this朱皮特笔记本上。以下是定义变量(权重和偏差)的部分:

# Store layers weight & bias
weights = {
# 5x5 conv, 1 input, 32 outputs
'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32]),dtype=tf.float32),
# 5x5 conv, 32 inputs, 64 outputs
'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64]),dtype=tf.float32),
# fully connected, 7*7*64 inputs, 1024 outputs
'wd1': tf.Variable(tf.random_normal([7*7*64, 1024]),dtype=tf.float32),
# 1024 inputs, 10 outputs (class prediction)
'out': tf.Variable(tf.random_normal([1024, num_classes]),dtype=tf.float32)
}

biases = {
'bc1': tf.Variable(tf.random_normal([32]),dtype=tf.float32),
'bc2': tf.Variable(tf.random_normal([64]),dtype=tf.float32),
'bd1': tf.Variable(tf.random_normal([1024]),dtype=tf.float32),
'out': tf.Variable(tf.random_normal([num_classes]),dtype=tf.float32)
}

为了在培训过程结束时保存模型,我添加了几行:

^{pr2}$

把所有这些变量加起来,我们会得到一个型号.ckpt.数据大小为12.45Mb的文件(我通过计算模型学习的浮点元素的数量,然后将该值转换为兆字节,就得到了这个值)。但是!保存的.data文件为39.3Mb。为什么会这样?在

我在一个更复杂的网络(ResNet的变体)中采用了相同的方法,我的预期是模型.数据大小也比实际的.data文件小约3倍。在

所有这些变量的数据类型都是float32。在


Tags: 文件数据模型图像数量tfrandomoutputs
1条回答
网友
1楼 · 发布于 2024-05-19 22:25:51

Adding up all those variables we would expect to get a model.ckpt.data file of size 12.45Mb

传统上,大多数模型参数都在第一个完全连通的层中,在本例中wd1。仅计算其大小即可得出:

7*7*128 * 1024 * 4 = 25690112

。。。或25.6Mb。注意4系数,因为变量dtype=tf.float32,即每个参数4字节。其他层也会影响模型大小,但不会太大。在

如您所见,您的估计值12.45Mb有点偏离(您是否使用每个参数16位?)。检查点还存储一些常规信息,因此开销大约为25%,这仍然很大,但不是300%。在

[更新]

所讨论的模型实际上具有形状为[7*7*64, 1024]的FC1层,如前所述。所以上面计算的大小应该是12.5Mb。这让我更仔细地查看保存的检查点。在

在检查之后,我注意到了我最初忽略的其他大变量:

^{pr2}$

Variable_2正好是wd1,但是Adam优化器还有2个副本。这些变量由the Adam optimizer创建,它们被称为,并为所有可训练变量保存m和{}累加器。现在总尺寸是合理的。在

您可以运行以下代码来计算图形变量-37.47Mb的总大小:

var_sizes = [np.product(list(map(int, v.shape))) * v.dtype.size
             for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]
print(sum(var_sizes) / (1024 ** 2), 'MB')

开销其实很小。额外的大小是由于优化器。在

相关问题 更多 >