我有一节主要课
def main(args):
if type == train_pipeline_type:
strategy = TrainPipelineStrategy()
else:
strategy = TestPipelineStrategy()
for table in fetch_table_information_by_region(region):
split_required = DataUtils.load_from_dict(table, "split_required")
if split_required:
strategy.split(spark=spark, table_name=table_name,
data_loc=filtered_data_location, partition_column=partition_column,
split_output_dir= split_output_dir)
logger.info("Data Split for table : {} completed".format(table_name))
我的TrainPipelineStrategy和TestPipelineStrategy是这样的-
class PipelineTypeStrategy(object):
def partition_data(self, x):
# Something
def prepare_split_data(self, y):
# Something
def write_split_data(self, z):
# Something
def split(self, p):
# Something
class TrainPipelineStrategy(PipelineTypeStrategy):
""""""
class TestPipelineStrategy(PipelineTypeStrategy):
def write_split_data(self, y):
# Something else
我的测试用例- 我需要测试在main方法中模拟split功能调用split的次数
以下是我尝试过的-
@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
def test_split_data_main_split_data_call_count(self, fake_train):
fake_train_functions = mock.Mock()
fake_train_functions.split.return_value = None
fake_train.return_value = fake_train_functions
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
assert fake_train_functions.split.call_count == 10
当我尝试运行我的测试时,它创建了mock,但最终调用了实际的split函数。我做错什么了
这段代码的主要问题是,如果
TrainPipelineStrategy
是PipelineTypeStrategy
的嵌套类,那么设置patch
的方式就是TrainPipelineStrategy
是PipelineTypeStrategy
的子类由于
TrainPipelineStrategy
继承自PipelineTypeStrategy
,因此它可以直接访问split
,因此您可以修补split
,而无需任何引用PipelineTypeStrategy
(除非您特别希望修补PipelineTypeStrategy
中定义的split
的版本)但是,如果您只想模拟
PipelineTypeStrategy
类的split
方法,那么应该使用patch.object
装饰器来模拟split
,而不是模拟整个类,因为它更干净一些。举个例子:相关问题 更多 >
编程相关推荐