我正在用Python编写一个“批处理”过程(不使用任何框架)。
项目配置位于config.ini
文件中
[db]
db_uri = mysql+pymysql://root:password@localhost:3306/manage
我有另一个文件config.test
要在测试期间交换
[db]
db_uri = sqlite://
我有一个简单的test_sample.py
# tests/test_sample.py
import pytest
import shutil
import os
import batch
import batch_manage.utils.getconfig as getconfig_class
class TestClass():
def setup_method(self, method):
""" Rename the config """
shutil.copyfile("config.ini", "config.bak")
os.remove('config.ini')
shutil.copyfile("config.test", "config.ini")
def teardown_method(self, method):
""" Replace the config """
shutil.copyfile("config.bak", "config.ini")
os.remove('config.bak')
def test_can_get_all_data_from_table(self):
conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')
assert db_uri == "sqlite://"
# This pass! ok!
people = batch.get_all_people()
assert len(people) == 0
# This fails, because counts the records in production database
db_uri
assert是ok(在测试时是sqlite而不是mysql),但len不是0,而是42(mysql数据库中的记录数)
我怀疑SqlAlchemy ORM的会话有问题。我做了几次尝试,没有可能覆盖/删除它
代码的其余部分非常简单:
# batch_manage/models/base.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import batch_manage.utils.getconfig as getconfig_class
conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')
engine = create_engine(db_uri)
Session = sessionmaker(bind=engine)
Base = declarative_base()
# batch_manage/models/persone.py
from sqlalchemy import Column, String, Integer, Date
from batch_manage.models.base import Base
class Persone(Base):
__tablename__ = "persone"
idpersona = Column(Integer, primary_key=True)
nome = Column(String)
created_at = Column(Date)
def __init__(self, nome, created_at):
self.nome = nome
self.created_at = created_at
以及batch.py
本身
# batch.py
import click
from batch_manage.models.base import Session
from batch_manage.models.persone import Persone
def get_all_people():
""" Get all people from database """
session = Session()
people = session.query(Persone).all()
return people
@click.command()
def batch():
click.echo("------------------------------")
click.echo("Running Batch")
click.echo("------------------------------")
people = get_all_people()
for item in people:
print(f"Persona con ID {item.idpersona} creata il {item.created_at}")
if __name__ == '__main__':
batch()
目前,我通过以下方式进行了测试:
def test_can_get_all_data_from_table(self):
conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')
assert db_uri == "sqlite://"
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
engine = create_engine(db_uri)
Session = sessionmaker(bind=engine)
session = Session()
people = batch.get_all_people(session)
assert len(people) == 0
和get_all_people
方法
def get_all_people(session = None):
""" Get all people from database """
if session is None:
session = Session()
people = session.query(Persone).all()
return people
但这种解决方案并不优雅,而且会降低代码覆盖率,因为if路径没有遵循
因此,如果我正确地遵循了您的代码,那么看起来您是在设置测试之前导入ORM内容。以下是您当前的操作顺序:
batch.py
已导入李>models/base.py
文件的顶级模块代码中,配置要使用的数据库李>因此,对于解决方案:
在测试本身中导入所有模块
如果您只是想更改操作顺序,在进入测试之前不要导入代码。无论如何,这通常是一种很好的测试实践:
这可能会解决您眼前的问题,但可能有一个更优雅的解决方案
不要立即配置数据库
我不知道您是否正在使用Flask,但不管怎样,the Flask testing documentation都有一些关于如何设置测试数据库的好说明。导入模块后,您需要配置数据库URL
例如:
请注意,我还没有定义引擎。我可以在运行时这样做
在主代码中,在向用户提供内容之前,您需要调用
setup_engine
。在您的测试环境中,您可以调用自己的setup_engine
,它绑定到测试环境相关问题 更多 >
编程相关推荐