我正在做帧生成。对于数据集/train/文件夹(例如1.png)中的每个图像,我生成了一个包含100个图像的序列,并将所有图像保存到一个数据集/frames/train/中,作为(1_1.png…1_100.png),下面是我的文件夹结构示例:
Dataset:
train:
1.png
2.png
3.png
.
.
N.png
frames:
train:
1_1.png
1_2.png
.
.
N_100.png
2_1.png
2_2.png
.
.
N_100.png
我已经创建了自定义数据加载器,将生成的帧堆叠为通道以形成序列,但我的问题是,在创建序列时,我不想让来自图像2的帧与来自1的帧重叠,如何确保不同的帧不重叠
这是我的自定义数据加载器:
class LevelSetDataset(Dataset):
"""
Dataset object for CNN models
Temporal is defined implicitly
as the number of channels
example:
- X dimension
[H, W, C=number_of_timestap(t)]
- Y dimension
[W, W, C =(t+1)]
"""
def __init__(self, input_image_path:str,
target_image_path:str,
threshold:float=0.5,
num_input_steps:int=3,
num_future_steps:int=1,
image_dimension:int=32,
data_transformations=None,
istraining_mode:bool=True
):
self.input_image_path = input_image_path
self.target_image_path = target_image_path
self.threshold = threshold
self.num_input_steps = num_input_steps
self.num_future_steps = num_future_steps
self.image_dimension = image_dimension
self.data_transformations= data_transformations
self.istraining_mode = istraining_mode
# get a list of input filenames as sort them (e.g. 1.png, 2.png,..,N.png)
input_image_fp = sorted(glob(os.path.join(self.input_image_path , "*")),
key=lambda x: int(os.path.basename(x).split('.')[0])
)
# repeat the input image untill it matches the number of segmentation
# step of the target image
self.input_image_fp = [i for i in input_image_fp for _ in range(100)]
# get a list of the target filenames and sort them by the first id and second
# id after the underscore (e.g. 1_1.png, 1_2,..,N_M.png)
self.target_image_fp= sorted(glob(os.path.join(self.target_image_path , "*")),
key=lambda x: (int(os.path.basename(x).split('_')[0]),
int(os.path.basename(x).split('_')[1].split('.')[0]))
)
# check if in training mode
# to apply transformations
if (self.data_transformations is None) and (self.istraining_mode):
self.data_transformations= torchvision.transforms.Compose([
torchvision.transforms.Resize(size=(self.image_dimension,self.image_dimension),
interpolation=Image.BILINEAR),
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.RandomVerticalFlip(p=0.5),
torchvision.transforms.ToTensor()
])
if (self.data_transformations is None) and (not self.istraining_mode):
self.data_transformations== torchvision.transforms.Compose([
torchvision.transforms.Resize(size=(self.image_dimension,self.image_dimension),
interpolation=Image.BILINEAR),
torchvision.transforms.ToTensor()
])
self.transforms = self.data_transformations
def _create_binary_mask(self, x):
x[x>=self.threshold] = 1
x[x <self.threshold] = 0
return x
def _stat_norm(self, x):
norm =torchvision.transforms.Compose([torchvision.transforms.Resize(
size=(self.image_dimension,self.image_dimension),
interpolation=Image.BILINEAR),
torchvision.transforms.ToTensor()])
return norm(x)
def __len__(self):
return len(self.target_image_fp) - (self.num_input_steps+self.num_future_steps)
def __getitem__(self, index):
X = torch.zeros((self.image_dimension, self.image_dimension, self.num_input_steps+1))
for step_idx, step in enumerate(np.arange(index, self.num_input_steps, 1)):
target_image = Image.open(self.target_image_fp[step+self.num_input_steps+self.num_future_steps-1])
target_image = self.transforms(target_image)
target_image = self._create_binary_mask(target_image)
X[:, :, step_idx] = target_image # (t+1)
input_img = Image.open(self.input_image_fp[index]).convert('L')
# input_img = self.transforms(input_img)
input_img = self.transforms(input_img)
X[:, :, 0] = input_img
target_image = Image.open(self.target_image_fp[index+self.num_input_steps+self.num_future_steps-1])
target_image = self.transforms(target_image)
target_image = self._create_binary_mask(target_image)
image_name = self.target_image_fp[index+self.num_input_steps+self.num_future_steps-1].split('/')[-1]
Y = target_image
return X, Y, image_name
目前没有回答
相关问题 更多 >
编程相关推荐