Tensorflow变换字符串张量的每个元素

2024-09-28 22:22:35 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个弦的张量。下面是一些示例字符串

com.abc.display,com.abc.backend,com.xyz.forte,blah
com.pqr,npr.goog

我想做一些预处理,将CSV拆分成它的一部分,然后在点处拆分每个部分,然后创建多个字符串,其中一个字符串是另一个字符串的前缀。此外,必须删除所有blah

例如,给定第一个字符串com.abc.display,com.abc.backend,com.xyz.forte,它将转换为以下字符串的数组/列表

['com', 'com.abc', 'com.abc.display', 'com.abc.backend', 'com.xyz', 'com.xyz.forte']

结果列表没有重复项(这就是为什么com.abc.backend的前缀字符串没有显示为已包含的字符串-comcom.abc

我编写了下面的python函数,在给定一个CSV字符串示例的情况下可以实现上述功能

def expand_meta(meta):
    expanded_subparts = []
    meta_parts = set([x for x in meta.split(',') if x != 'blah'])
    for part in meta_parts:
        subparts = part.split('.')
        for i in range(len(subparts)+1):
            expanded = '.'.join(subparts[:i])
            if expanded:
                expanded_subparts.append(expanded)
    return list(set(expanded_subparts))

在第一个示例中调用此方法

expand_meta('com.abc.display,com.abc.backend,com.xyz.forte,blah')

返回

['com.abc.display',
 'com.abc',
 'com.xyz',
 'com.xyz.forte',
 'com.abc.backend',
 'com']

我知道tensorflow有这个方法。我希望用它来变换张量的每个元素。但是,我得到了以下错误

File "mypreprocess.py", line 152, in expand_meta
    meta_parts = set([x for x in meta.split(',') if x != 'blah'])
AttributeError: 'Tensor' object has no attribute 'split'

因此,我似乎无法使用带有map_fn的常规python函数,因为它期望元素为tensor。我该怎么做我想在这里做的事

(我的Tensorflow版本是1.11.0)


Tags: 字符串incombackend示例fordisplaymeta
1条回答
网友
1楼 · 发布于 2024-09-28 22:22:35

我想这正是你想要的:

import tensorflow as tf

# Function to process a single string
def make_splits(s):
    s = tf.convert_to_tensor(s)
    # Split by comma
    split1 = tf.strings.split([s], ',').values
    # Remove blahs
    split1 = tf.boolean_mask(split1, tf.not_equal(split1, 'blah'))
    # Split by period
    split2 = tf.string_split(split1, '.')
    # Get dense split tensor
    split2_dense = tf.sparse.to_dense(split2, default_value='')
    # Accummulated concatenations
    concats = tf.scan(lambda a, b: tf.string_join([a, b], '.'),
                      tf.transpose(split2_dense))
    # Get relevant concatenations
    out = tf.gather_nd(tf.transpose(concats), split2.indices)
    # Remove duplicates
    return tf.unique(out)[0]

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    # Individual examples
    print(make_splits('com.abc.display,com.abc.backend,com.xyz.forte,blah').eval())
    # [b'com' b'com.abc' b'com.abc.display' b'com.abc.backend' b'com.xyz'
    #  b'com.xyz.forte']
    print(make_splits('com.pqr,npr.goog').eval())
    # [b'com' b'com.pqr' b'npr' b'npr.goog']

    # Apply to multiple strings with a loop
    data = tf.constant([
        'com.abc.display,com.abc.backend,com.xyz.forte,blah',
        'com.pqr,npr.goog'])
    ta = tf.TensorArray(size=data.shape[0], dtype=tf.string,
                        infer_shape=False, element_shape=[None])
    _, ta = tf.while_loop(
        lambda i, ta: i < tf.shape(data)[0],
        lambda i, ta: (i + 1, ta.write(i, make_splits(data[i]))),
        [0, ta])
    out = ta.concat()
    print(out.eval())
    # [b'com' b'com.abc' b'com.abc.display' b'com.abc.backend' b'com.xyz'
    #  b'com.xyz.forte' b'com' b'com.pqr' b'npr' b'npr.goog']

我不确定您是否希望总结果像那样串联起来,或者您希望将tf.unique应用于全局结果,但在任何情况下,想法都是一样的

相关问题 更多 >