使用字符串十的Tensorflow字典查找

2024-05-18 18:38:01 发布

您现在位置:Python中文网/ 问答频道 /正文

有没有什么方法可以基于Tensorflow中的字符串张量执行字典查找?

在普通的Python中,我会做一些

value = dictionary[key]

是的。现在我想在Tensorflow运行时做同样的事情,当我的key作为一个字符串张量时。有点像

value_tensor = tf.dict_lookup(string_tensor)

会很好的。


Tags: 方法key字符串stringdictionary字典valuetf
3条回答

如果你想用新的TF 2.0代码运行它,默认情况下会启用紧急执行。下面是快速代码片段。

import tensorflow as tf

# build a lookup table
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3]),
        values=tf.constant([10, 11, 12, 13]),
    ),
    default_value=tf.constant(-1),
    name="class_weight"
)

# now let us do a lookup
input_tensor = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
out = table.lookup(input_tensor)
print(out)

输出:

tf.Tensor([10 10 11 11 12 12 13 13], shape=(8,), dtype=int32)

您可能会发现tensorflow.contrib.lookup有帮助: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py

https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable

特别是,您可以:

table = tf.contrib.lookup.HashTable(
  tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1
)
out = table.lookup(input_tensor)
table.init.run()
print out.eval()

gather可以帮助您,但它只获取list的值。可以将字典转换为键和值列表,然后应用tf.gather。示例:

# Your dict
dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
# concrete query
query_list = ['a', 'c']

# unpack key and value lists
key, value = list(zip(*dict_.items()))
# map query list to list -> [0, 2]
query_list = [i for i, s in enumerate(key) if s in query_list]

# query as tensor
query = tf.placeholder(tf.int32, shape=[None])
# convert value list to tensor
vl_tf = tf.constant(value)
# get value
my_vl = tf.gather(vl_tf, query)

# session run
sess = tf.InteractiveSession()
sess.run(my_vl, feed_dict={query:query_list})

相关问题 更多 >

    热门问题