TensorFlow2.0实现了TabNet的任何配置。
tabnet的Python项目详细描述
Tensorflow 2.0的TabNet
用于论文TabNet: Attentive Interpretable Tabular Learning的Tensorflow 2.0端口,其原始代码库位于https://github.com/google-research/google-research/blob/master/tabnet。在
上面的图像是从本文中获得的,模型分两个阶段建立,一个处理输入特征,另一个构建模型的输出。在
与纸张的区别
与文件和正式实施有两大区别。在
- 在
这个实现在规范化方法中提供了一个选择,在来自论文的正则
Batch Normalization
和{}之间进行选择。在 - 结果表明,本文使用了非常大的批大小来稳定批规范化,并获得了良好的泛化效果。这方面的一个问题是计算成本。在
- 因此,组规范化(将组数设置为1,又称实例规范化)提供了一个与批处理大小无关的合理替代方案。在
- 对于
Instance Normalization
类型行为,可以将num_groups
设置为1,对于Layer Normalization
类型行为,可以设置为-1。在
- 在
这个实现并不严格需要特性列作为输入。在
- 虽然这个模型最初是为tablur数据开发的,但是对于它接受的唯一类型的输入没有硬性要求。在
- 通过传递
feature_columns=None
并显式地指定数据的输入维度(使用num_features
),我们可以从偶数图像数据(在将其展平为一个长向量之后)得到半可解释的结果。在
安装
- 对于最新版本分支
$ pip install --upgrade tabnet
- 主分支机构。在
由于Tensorflow可以与CPU或GPU一起使用,因此可以使用[cpu]
或{
$ pip install tabnet[cpu] $ pip install tabnet[gpu]
使用
可以导入脚本tabnet.py
来生成TabNet
构建块,或者生成TabNetClassification
和{TabNet
模型添加适当的头。如果要定制分类或回归头,建议使用TabNet
作为模型的基础来构建一个新的模型。在
fromtabnetimportTabNet,TabNetClassifiermodel=TabNetClassifier(feature_list,num_classes,...)
堆叠选项卡
常规选项卡可以堆叠到不同的层中,从而降低了可解释性,但提高了模型容量。在
fromtabnetimportStackedTabNetClassifiermodel=TabNetClassifier(feature_list,num_classes,num_layers,...)
由于模型使用自定义对象,因此有必要在仅计算的脚本中导入custom_objects.py
。在
蒙版可视化
TabNet的掩码可以通过使用TabNet类属性获得
feature_selection_masks
:返回中间决策步骤的一个或多个掩码的列表。掩码数=决策步骤数-1aggregate_feature_selection_mask
:返回单个张量,它是该批训练样本上掩码的平均激活。在
这些掩模可以得到TabNet.feature_selection_masks
。由于TabNetClassification
和{TabNet
组成,因此可以得到model.tabnet.*
掩码生成必须处于紧急执行模式
注意:由于自动签名,当使用fit()
或predict()
kerasapi时,模型的输出将
通常是基于图的张量,而不是急切张量。由于掩码是在Model.call()
方法中生成的,
有必要强制模型以急切执行模式而不是图形模式运行。在
因此,有两种方法可以强制模型进入急切模式:
- 获取张量数据样本,并使用此数据直接
call
模型,如下所示:
x,_=next(iter(tf_dataset))# Assuming it generates an (x, y) tuple._=model(x)# This forces eager execution.
- 或者另一个选择是构建一个单独的模型(但是这里您将把
dynamic=True
标志传递给模型构造函数), 加载此模型中的权重和参数,并调用model.predict(x)
。这也应该强制执行模式。在
new_model=TabNetClassification(...,dynamic=True)new_model.load_weights('path/to/weights)')x,_=next(iter(tf_dataset))# Assuming it generates an (x, y) tuple.model.predict(x)
在模型被强制进入急执行模式后,可以在Tensorboard中可视化掩码,如下所示-
writer=tf.summary.create_file_writer("logs/")withwriter.as_default():fori,maskinenumerate(model.tabnet.feature_selection_masks):print("Saving mask {} of shape {}".format(i+1,mask.shape))tf.summary.image('mask_at_iter_{}'.format(i+1),step=0,data=mask,max_outputs=1)writer.flush()agg_mask=model.tabnet.aggregate_feature_selection_maskprint("Saving aggregate mask of shape",agg_mask.shape)tf.summary.image("Aggregate Mask",step=0,data=agg_mask,max_outputs=1)writer.flush()writer.close()
要求
- 在Tensorflow 2.0+(启用V2 compat的1.14+对于1.x可能足够)
- Tensorflow数据集(仅用于评估
train_iris.py
)
- 项目
标签: