Tensorflow数据集是否会在不同时期之间进行洗牌,并在洗牌后进行数据集转换?

2024-09-30 01:32:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在开发一个TensorFlow管道,在这里我把一堆信号加载到一个数据集中,我洗牌这些信号,然后在信号上加窗,然后批处理和重复。此数据集用于训练特斯拉斯使用模型.拟合函数调用。非常重要的是,信号的窗口不被洗牌,这就是为什么这是数据集转换的顺序。在

我想知道信号的顺序是否会在两个时代之间被洗牌?我发现dataset.shuffle().batch().repeat()会在不同时期之间对数据集进行无序处理,但这对我的应用程序不起作用,因为我需要在洗牌之后进行窗口化和其他转换。在

我使用的是TensorFlow版本1.13.1。在

#... some pre-processing on the signals 
signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size)  ## will this shuffle be repeated??
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()

model.fit(dataset, ...)

编辑:我感兴趣的行为是,我希望信号的顺序会随着每个时代的变化而重新洗牌。如果我有3个信号

^{pr2}$

然后输出如下所示:

tf.Tensor([signal0,signal2,signal1],...) # equivalent to tf.Tensor([window0_0,window0_1,window2_0,window1_0,window1_1,window1_2])
tf.Tensor([signal1,signal0,signal2],...) # equivalent to tf.Tensor([window1_0,window1_1,window1_2,window0_0,window0_1,window2_0]) 

在哪里变换数据集.map(windowing).shuffle().batch().repeat()会产生类似这样的结果(我对此不感兴趣)

tf.Tensor([window0_1,window1_1,window2_0,window1_0,window0_0,window1_2])
tf.Tensor([window0_0,window1_2,window0_1,window2_0,window1_1,window1_0]) 

Tags: 数据信号顺序tftensorflowbatchdatasettensor
2条回答

经过一番调查,我意识到,是的,shuffle在每个epoch之后都被调用,即使在shuffle之后和批处理之前还有其他转换。我不确定这对管道意味着什么(比如,我不确定窗口是否在每个纪元中都被调用,并且减慢了处理速度),但是我创建了一个jupyter笔记本,在那里我创建了一个小版本的管道

signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size)  
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()

创建了迭代器

^{pr2}$

绘制了几个时代的信号

next_ = iterator.get_next()
for i in range(10):  # 10 epochs
    full_signal = []
    for j in range(29):  # 29 events for this epoch
        next_ = iterator.get_next()
        full_signal = np.concatenate((full_signal, next_[0][0]), axis=None)

    fig = plt.figure(figsize=(18, 5))
    plt.plot(full_signal)

看到信号看起来总是以不同的顺序排列,这意味着它们在每一个时代之后都会被重新洗牌。在

如果有人有更详细的答案,在哪里他们可以解释这是如何与DatasetAPI编译的,或者如果他们可以澄清这些转换的顺序是否减慢了管道,我将非常感谢!在

您可以将可选参数传递给.shuffle(),以防止每个历元的重新洗牌。在

所以,如果我有这样的数据集:

def gen():
  yield 1
  yield 2
  yield 3

ds = tf.data.Dataset.from_generator(gen, output_shapes=(), output_types=tf.int32)

然后做:

^{pr2}$

输出:

tf.Tensor([3 2 1], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([2 1 3], shape=(3,), dtype=int32)
tf.Tensor([3 1 2], shape=(3,), dtype=int32)
tf.Tensor([2 3 1], shape=(3,), dtype=int32)

每一个纪元重新排列我的3个元素。这是我理解你想要避免的行为。在

相反,如果我愿意:

shuffled_and_batched = ds.shuffle(3, reshuffle_each_iteration=False).batch(3).repeat()

然后得到输出:

tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)
tf.Tensor([1 3 2], shape=(3,), dtype=int32)

顺序洗牌一次,然后重复使用每个历元。在

相关问题 更多 >

    热门问题