# Your dict
dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
# concrete query
query_list = ['a', 'c']
# unpack key and value lists
key, value = list(zip(*dict_.items()))
# map query list to list -> [0, 2]
query_list = [i for i, s in enumerate(key) if s in query_list]
# query as tensor
query = tf.placeholder(tf.int32, shape=[None])
# convert value list to tensor
vl_tf = tf.constant(value)
# get value
my_vl = tf.gather(vl_tf, query)
# session run
sess = tf.InteractiveSession()
sess.run(my_vl, feed_dict={query:query_list})
如果你想用新的TF 2.0代码运行它,默认情况下会启用紧急执行。下面是快速代码片段。
输出:
您可能会发现
tensorflow.contrib.lookup
有帮助: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.pyhttps://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable
特别是,您可以:
gather可以帮助您,但它只获取list的值。可以将字典转换为键和值列表,然后应用tf.gather。示例:
相关问题 更多 >
编程相关推荐