从包含y个总嵌入的Tensorflow变量平均x嵌入的最有效方法

2024-09-30 20:33:39 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有y个全部的嵌入,它们都是用这段代码检索的

embeds = tf.nn.embedding_lookup(embeddings, train_dataset)

所以数据应该是这样的

^{pr2}$

假设,我要取3个嵌入组的平均值。所以有点像

averaged_embeds = [ averageOf(embed45, embed2, embed939) , averageOf(embed3, embed32, embed2), . . . . etc]

所以当评估的时候会像这样

averaged_embeds = [ averagedEmbeds1, averagedEmbeds2, averagedEmbeds3, . . . etc]

最好的办法是什么?在

我的第一个想法是tf.段平均值但据我所知,它只能取每个嵌入中的平均值,而不是一堆嵌入的平均值(如果这是错误的,请告诉我)。在

还有tf.reduce_平均值它可以沿着一个指定的维度求平均值,但它将取所有嵌入项的平均值,而不是某个特定数量的束。在


Tags: 数据代码tfetctrainnnembeddinglookup
2条回答

实际上tf.段平均值可用于平均嵌入,下面是一个示例

%matplotlib inline
from __future__ import print_function
import collections
import math
import numpy as np
import os
import random
import tensorflow as tf
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve

tf.enable_eager_execution() 

train_dataset = [3, 7, 5 ,12 ,19 ,6, 10, 8]

embeddings = tf.get_variable( 'generator', 
    initializer=tf.random_uniform([20, 6], -1.0, 1.0))

embed = tf.nn.embedding_lookup(embeddings, train_dataset)

print(embed)

segments= np.arange(2).repeat(4)
print(segments)

averaged_embeds = tf.segment_mean( embed, segments, name=None)
print(averaged_embeds)

#Use this to confirm that the embeddings were averaged correctly
print( np.mean([   -0.78844213 ,  -0.2852435 , 0.58107734, 0.12990952   ]))

不幸的是,如果嵌入张量有两个以上的维,我还没有找到一个平均值的方法。我尝试过使用segments具有多个维度,但似乎不起作用。到目前为止,我在tf.segment_mean操作前后重塑embed和{}张量。在

您可以使用^{},但这意味着,如果参数num_or_size_splits是标量,那么它应该是输入长度的倍数,或者沿着拆分维度的维数之和应该与输入的长度匹配(对于^{}也是一样)。更好的方法是在不适用这些限制的情况下使用tf.extract_image_patches

# generate batch of inputs
def get_batch(tensor, batch, k):
    return tf.extract_image_patches(tensor, 
                                ksizes=[1, batch, k, 1], 
                                strides=[1, batch, k, 1], 
                                rates=[1, 1, 1, 1], padding='VALID')


embed_dim = 5
batch = 3
x = np.arange(200).reshape(-1, embed_dim)

embeddings = tf.constant(x)
train_dataset = tf.constant([0,1,2,5,6,7])
embeds = tf.nn.embedding_lookup(embeddings, train_dataset)


split = tf.reshape(get_batch(embeds[None,..., None], batch, embed_dim),
                   [-1, batch, embed_dim])
avg = tf.reduce_mean(split, 1)

with tf.Session() as sess:
   print(sess.run(embeds))
   #[[ 0  1  2  3  4]
   # [ 5  6  7  8  9]
   # [10 11 12 13 14]
   # [25 26 27 28 29]
   # [30 31 32 33 34]
   # [35 36 37 38 39]]

   print(sess.run(split))
   #[[[ 0  1  2  3  4]
   # [ 5  6  7  8  9]
   # [10 11 12 13 14]]

   # [[25 26 27 28 29]
   #  [30 31 32 33 34]
   #  [35 36 37 38 39]]]

   print(sess.run(avg))
   #[[ 5  6  7  8  9]
   # [30 31 32 33 34]]

对于三维段,代码更改为:

^{pr2}$

相关问题 更多 >