如何将SparseTensorValue转换成numpy数组?

2024-09-29 06:28:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个Tensorflow网络,可以在调用Session().run()后得到图的值。但是,我在将SparseTensorValue转换为其他类型时遇到了一些问题。在

例如,下面的程序创建一个SparseTensorValue。在

>>> import tensorflow as tf
>>> t = tf.Session().run(tf.SparseTensor([[0,1], [0,0], [1,1], [1,0]],[1,2,3,4],[2,2]))
>>> print(t)
SparseTensorValue(indices=array([[0, 1],
       [0, 0],
       [1, 1],
       [1, 0]]), values=array([1, 2, 3, 4], dtype=int32), dense_shape=array([2, 2]))
>>> 

我想要的是将t转换为np.array或{},例如np.array([[2., 1.], [4., 3.]])。在

我目前的情况是

^{pr2}$

有没有更好的方法来进行转换?尤其是,我想消除for循环。在


Tags: runimport程序网络类型sessiontftensorflow
1条回答
网友
1楼 · 发布于 2024-09-29 06:28:12

多亏了hpaulj的提示,我找到了从Tensorflow网站转换的方法。在

tf.Session().run(tf.sparse.to_dense(tf.sparse.reorder(t)))

首先将值重新排序为字典顺序,然后使用to_dense使其密集,最后将张量输入Session().run()。在

相关问题 更多 >