shapeguard是帮助处理tensorflow中的形状的工具。
shapeguard的Python项目详细描述
形状保护罩
shapeguard是帮助处理tensorflow中的形状的工具。
基本用法
importtensorflowastffromshapeguardimportShapeGuardsg=ShapeGuard()img=tf.ones([64,32,32,3])flat_img=tf.ones([64,1024])labels=tf.ones([64])# check shape consistencysg.guard(img,"B, H, W, C")sg.guard(labels,"B, 1")# raises error because of rank mismatchsg.guard(flat_img,"B, H*W*C")# raises error because 1024 != 32*32*3# guard also returns the tensor, so it can be inlinedmean_img=sg.guard(tf.reduce_mean(img,axis=0),"H, W, C")# more readable reshapesflat_img=sg.reshape(img,'B, H*W*C')# evaluate templatesassertsg['H, W*C+1']==[32,97]# attribute access to inferred dimensionsassertsg.B==64
形状模板语法
形状模板mini dsl支持多种指定形状的方法:
- 数字:
"64, 32, 32, 3"
- 命名维度:
"B, width, height2, channels"
- 通配符:
"B, *, *, *"
- 省略号:
"B, ..., 3"
- 加、减、乘、除:
"B*N, W/2, H*(C+1)"
- 动态维度:
"?, H, W, C"
(仅匹配[None, H, W, C]
)
免责声明
这不是官方支持的谷歌产品。