tf.data.Dataset筛选?

2024-10-06 07:17:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我知道可以用^{}过滤数据集:

d = tf.data.Dataset.from_tensor_slices([1, 2, 3])

d = d.filter(lambda x: x < 3)  # ==> [1, 2]

# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)

d = d.filter(filter_fn)  # ==> [1]

如果我想做“批量筛选”呢?我的意思是,给定一批字符串['str1', 'str2', 'str3', 'str4'],如何生成一个数据集,该数据集能够返回一个经过过滤的数据集,该数据集会吐出一批与这些字符串相对应的值:[val1_respects_str1, val2_respects_str2, val3_respects_str3, val4_respects_str4]?你知道吗


Tags: 数据字符串fromdatatfmathequalfilter
1条回答
网友
1楼 · 发布于 2024-10-06 07:17:05

你想要的不是一个过滤器而是一个地图。map函数将计算映射值,例如:

d = tf.data.Dataset.from_tensor_slices(list(range(100)))

def map_fn(x):
  return x*2  # if you need arbitrary python logic, use tf.py_function to wrap it

d = d.shuffle(100).batch(10).map(map_fn)

相关问题 更多 >