在python中,我有一个迭代器返回一个固定范围[0, N]
内的无限索引字符串,称为Sampler
。实际上我有一个列表,它们所做的只是返回范围[0, N_0], [N_0, N_1], ..., [N_{n-1}, N_n].
内的索引
我现在想做的是首先根据迭代器范围的长度选择其中一个迭代器,因此我有一个weights
列表[N_0, N_1 - N_0, ...]
,我选择其中一个:
iterator_idx = random.choices(range(len(weights)), weights=weights/weights.sum())[0]
接下来,我要做的是创建一个迭代器,它随机选择一个迭代器并选择一批M
样本
class BatchSampler:
def __init__(self, M):
self.M = M
self.weights = [weight_list]
self.samplers = [list_of_iterators]
]
self._batch_samplers = [
self.batch_sampler(sampler) for sampler in self.samplers
]
def batch_sampler(self, sampler):
batch = []
for batch_idx in sampler:
batch.append(batch_idx)
if len(batch) == self.M:
yield batch
if len(batch) > 0:
yield batch
def __iter__(self):
# First select one of the datasets.
iterator_idx = random.choices(
range(len(self.weights)), weights=self.weights / self.weights.sum()
)[0]
return self._batch_samplers[iterator_idx]
问题是iter()
似乎只被调用一次,因此只选择了第一次iterator_idx
。显然这是错误的。。。解决这个问题的办法是什么?
当pytorch中有多个数据集,但只希望从其中一个数据集中采样批次时,可能会出现这种情况
在我看来,您似乎想要定义自己的容器类型。
我将尝试提供一些标准方法的示例
(希望不会遗漏太多细节)
您应该能够重用这些简单示例中的一个,
进入你自己的班级
使用just _ugetItem(支持索引和循环):
object.__getitem__
使用迭代器协议:
object.__iter__
Iterator Types
使用发电机:
Generator Types
generator
generator iterator
6.2.9. Yield expressions
相关问题 更多 >
编程相关推荐