很遗憾,我遇到以下运行时错误:
该错误出现在最后一批中的第1个历元(因此所有其他批次都会运行), 我不知道是什么原因导致我的代码中出现错误。下面是我的函数的代码片段
def gradient_penalty(critic, real, fake, device):
BATCH_SIZE, C, H, W = real.shape
epsilon = torch.rand(size = (BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
# generate tensor filles only with ones
x = torch.ones(size = (BATCH_SIZE, C, H, W), dtype = int)
# interpolate images
interpolated_images = real * epsilon + fake * (x - epsilon)
变量real
代表图像,其形状为(128, 3, 64, 64)
。
我需要承认,我没有找到具体的错误消息,我。E张量的形状哪里不重合
任何帮助都将不胜感激
使用} 时,可以放弃未完成的批处理:
drop_last
参数实例化^{然而,这似乎有点激进,因为数据集中的128元素将被浪费
相关问题 更多 >
编程相关推荐