Python无法模拟继承类的调用

2024-10-02 00:29:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一节主要课

def main(args):
    if type == train_pipeline_type:
        strategy = TrainPipelineStrategy()
    else:
        strategy = TestPipelineStrategy()
    for table in fetch_table_information_by_region(region):
        split_required = DataUtils.load_from_dict(table, "split_required")
        if split_required:
            strategy.split(spark=spark, table_name=table_name,
                           data_loc=filtered_data_location, partition_column=partition_column,
                           split_output_dir= split_output_dir)
            logger.info("Data Split for table : {} completed".format(table_name))

我的TrainPipelineStrategy和TestPipelineStrategy是这样的-

class PipelineTypeStrategy(object):

    def partition_data(self, x):
        # Something

    def prepare_split_data(self, y):
        # Something

    def write_split_data(self, z):
        # Something

    def split(self, p):
        # Something


class TrainPipelineStrategy(PipelineTypeStrategy):
    """"""


class TestPipelineStrategy(PipelineTypeStrategy):

    def write_split_data(self, y):
        # Something else

我的测试用例- 我需要测试在main方法中模拟split功能调用split的次数

以下是我尝试过的-

@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
    def test_split_data_main_split_data_call_count(self, fake_train):
        fake_train_functions = mock.Mock()
        fake_train_functions.split.return_value = None
        fake_train.return_value = fake_train_functions
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        assert fake_train_functions.split.call_count == 10

当我尝试运行我的测试时,它创建了mock,但最终调用了实际的split函数。我做错什么了


Tags: selfdatamaindeftableargstrainfunctions
1条回答
网友
1楼 · 发布于 2024-10-02 00:29:51

这段代码的主要问题是,如果TrainPipelineStrategyPipelineTypeStrategy的嵌套类,那么设置patch的方式就是TrainPipelineStrategyPipelineTypeStrategy的子类

由于TrainPipelineStrategy继承自PipelineTypeStrategy,因此它可以直接访问split,因此您可以修补split,而无需任何引用PipelineTypeStrategy(除非您特别希望修补PipelineTypeStrategy中定义的split的版本)

但是,如果您只想模拟PipelineTypeStrategy类的split方法,那么应该使用patch.object装饰器来模拟split,而不是模拟整个类,因为它更干净一些。举个例子:

class TestClass(unittest.TestCase):
    @patch.object(TrainPipelineStrategy, 'split', return_value=None)
    def test_split_data_main_split_data_call_count(self, mock_split):
        test_args = ["", " x=6"]
        SplitData.main(args=test_args)
        self.assertEqual(mock_split.call_count, 10)

相关问题 更多 >

    热门问题