TensorFlow 2.0的可解释性回调
tf-explain的Python项目详细描述
tf解释
tf explain将可解释性方法实现为对ease神经网络理解的tensorflow 2.0回调。
安装
tf explain作为alpha版本在pypi上可用。要安装它:
virtualenv venv -p python3.6 pip install tf-explain
tf-explain与tensorflow 2兼容。它没有声明为依赖项 以便在CPU和GPU版本之间进行选择。除了上一次安装之外,请运行:
# For CPU version pip install tensorflow==2.0.0-beta1 # For GPU version pip install tensorflow-gpu==2.0.0-beta1
可用方法
激活可视化
Visualize how a given input comes out of a specific activation layer
fromtf_explain.callbacks.activations_visualizationimportActivationsVisualizationCallbackmodel=[...]callbacks=[ActivationsVisualizationCallback(validation_data=(x_val,y_val),layers_name=["activation_1"],output_dir=output_dir,),]model.fit(x_train,y_train,batch_size=2,epochs=2,callbacks=callbacks)
遮挡敏感度
Visualize how parts of the image affects neural network's confidence by occluding parts iteratively
fromtf_explain.callbacks.occlusion_sensitivityimportOcclusionSensitivityCallbackmodel=[...]callbacks=[OcclusionSensitivityCallback(validation_data=(x_val,y_val),class_index=0,patch_size=4,output_dir=output_dir,),]model.fit(x_train,y_train,batch_size=2,epochs=2,callbacks=callbacks)
tabby类的遮挡敏感度(条纹将tabby cat与其他imagenet cat类区分开来)
梯度凸轮
Visualize how parts of the image affects neural network's output by looking into the activation maps
从Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
fromtf_explain.callbacks.grad_camimportGradCAMCallbackmodel=[...]callbacks=[GradCAMCallback(validation_data=(x_val,y_val),layer_name="activation_1",class_index=0,output_dir=output_dir,)]model.fit(x_train,y_train,batch_size=2,epochs=2,callbacks=callbacks)
平滑梯度
Visualize stabilized gradients on the inputs towards the decision
来自SmoothGrad: removing noise by adding noise
fromtf_explain.callbacks.smoothgradimportSmoothGradCallbackmodel=[...]callbacks=[SmoothGradCallback(validation_data=(x_val,y_val),class_index=0,num_samples=20,noise=1.,output_dir=output_dir,)]model.fit(x_train,y_train,batch_size=2,epochs=2,callbacks=callbacks)
可视化结果
使用回调时,将在logs
目录中创建输出文件。
您可以使用以下命令在tensorboard中看到它们:tensorboard --logdir logs
路线图
- []子类化API支持
- []其他方法
- []自动生成的API文档和文档测试