diff --git a/README.rst b/README.rst index 71b10524..4c5fdbc8 100755 --- a/README.rst +++ b/README.rst @@ -26,7 +26,7 @@ You can install this package using pypi: ``pip install flask-mongoengine`` Tests ===== To run the test suite, ensure you are running a local copy of Flask-MongoEngine -and run: ``python setup.py nosetests``. +and simply run: ``pytest``. To run the test suite on every supported versions of Python, PyPy and MongoEngine you can use ``tox``. Ensure tox and each supported Python, PyPy versions are installed in your environment: @@ -38,11 +38,7 @@ Ensure tox and each supported Python, PyPy versions are installed in your enviro # Run the test suites $ tox -To run a single or selected test suits, use the nosetest convention. E.g. - -.. code-block:: shell - - $ python setup.py nosetests --tests tests/example_test.py:ExampleTestClass.example_test_method +To run a single or selected test suits, use pytest `-k` option. Contributing ============ diff --git a/requirements-dev.txt b/requirements-dev.txt index ea811ca2..f8829e3e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,4 +5,3 @@ black pre-commit pytest pytest-cov -nose diff --git a/setup.cfg b/setup.cfg index fd3465ee..c4731f72 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,23 +1,10 @@ -[nosetests] -rednose = 1 -verbosity = 2 -detailed-errors = 1 -cover-erase = 1 -cover-branches = 1 -cover-package = flask_mongoengine -tests = tests - [tool:pytest] addopts = --cov=flask_mongoengine --cov-config=setup.cfg testpaths = tests env_override_existing_values = 1 -filterwarnings = - ignore::UserWarning - ignore::DeprecationWarning - ignore::PendingDeprecationWarning [flake8] ignore=E501,F403,F405,I201,W503,E203 -max-line-length = 90 +max-line-length=90 exclude=build,dist,docs,examples,venv,.tox,.eggs max-complexity=17 diff --git a/setup.py b/setup.py index d99967d9..097d8829 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def get_version(version_tuple): version_line = list(filter(lambda l: l.startswith("VERSION"), open(init)))[0] version = get_version(eval(version_line.split("=")[-1])) -test_requirements = ["coverage", "nose", "pytest", "pytest-cov"] +test_requirements = ["coverage", "pytest", "pytest-cov"] setup( name="flask-mongoengine", diff --git a/tests/__init__.py b/tests/__init__.py index a5eca11d..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,25 +0,0 @@ -import unittest - -import flask -import mongoengine - - -class FlaskMongoEngineTestCase(unittest.TestCase): - """Parent class of all test cases""" - - def setUp(self): - self.app = flask.Flask(__name__) - self.app.config["MONGODB_DB"] = "test_db" - self.app.config["TESTING"] = True - self.ctx = self.app.app_context() - self.ctx.push() - # Mongoengine keep a global state of the connections that must be - # reset before each test. - # Given it doesn't expose any method to get the list of registered - # connections, we have to do the cleaning by hand... - mongoengine.connection._connection_settings.clear() - mongoengine.connection._connections.clear() - mongoengine.connection._dbs.clear() - - def tearDown(self): - self.ctx.pop() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..59dcb280 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +from datetime import datetime + +import mongoengine +import pytest +from flask import Flask + +from flask_mongoengine import MongoEngine + + +@pytest.fixture() +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + + with app.app_context(): + yield app + + mongoengine.connection.disconnect_all() + + +@pytest.fixture() +def db(app): + app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db" + test_db = MongoEngine(app) + db_name = test_db.connection.get_database("flask_mongoengine_test_db").name + + if not db_name.endswith("_test_db"): + raise RuntimeError( + f"DATABASE_URL must point to testing db, not to master db ({db_name})" + ) + + # Clear database before tests, for cases when some test failed before. + test_db.connection.drop_database(db_name) + + yield test_db + + # Clear database after tests, for graceful exit. + test_db.connection.drop_database(db_name) + + +@pytest.fixture() +def todo(db): + class Todo(db.Document): + title = mongoengine.StringField(max_length=60) + text = mongoengine.StringField() + done = mongoengine.BooleanField(default=False) + pub_date = mongoengine.DateTimeField(default=datetime.utcnow) + comments = mongoengine.ListField(mongoengine.StringField()) + comment_count = mongoengine.IntField() + + return Todo diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..ca0ca0ea --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,19 @@ +"""Tests for base MongoEngine class.""" +from flask_mongoengine import MongoEngine +import pytest + + +def test_mongoengine_class__should_raise_type_error__if_config_not_dict(): + """MongoEngine will handle None values, but will pass anything else as app.""" + input_value = "Not dict type" + with pytest.raises(TypeError) as error: + MongoEngine(input_value) + assert str(error.value) == "Invalid Flask application instance" + + +@pytest.mark.parametrize("input_value", [None, "Not dict type"]) +def test_init_app__should_raise_type_error__if_config_not_dict(input_value): + db = MongoEngine() + with pytest.raises(TypeError) as error: + db.init_app(input_value) + assert str(error.value) == "Invalid Flask application instance" diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py index 49f65010..db823340 100644 --- a/tests/test_basic_app.py +++ b/tests/test_basic_app.py @@ -1,75 +1,45 @@ -import datetime - import flask +import pytest from bson import ObjectId -from flask_mongoengine import MongoEngine -from tests import FlaskMongoEngineTestCase - - -class BasicAppTestCase(FlaskMongoEngineTestCase): - def setUp(self): - super(BasicAppTestCase, self).setUp() - db = MongoEngine() - - class Todo(db.Document): - title = db.StringField(max_length=60) - text = db.StringField() - done = db.BooleanField(default=False) - pub_date = db.DateTimeField(default=datetime.datetime.now) - - db.init_app(self.app) - - Todo.drop_collection() - self.Todo = Todo - - @self.app.route("/") - def index(): - return "\n".join(x.title for x in self.Todo.objects) - @self.app.route("/add", methods=["POST"]) - def add(): - form = flask.request.form - todo = self.Todo(title=form["title"], text=form["text"]) - todo.save() - return "added" +@pytest.fixture(autouse=True) +def setup_endpoints(app, todo): + Todo = todo - @self.app.route("/show//") - def show(id): - todo = self.Todo.objects.get_or_404(id=id) - return "\n".join([todo.title, todo.text]) + @app.route("/") + def index(): + return "\n".join(x.title for x in Todo.objects) - self.db = db + @app.route("/add", methods=["POST"]) + def add(): + form = flask.request.form + todo = Todo(title=form["title"], text=form["text"]) + todo.save() + return "added" - def test_connection_default(self): - self.app.config["MONGODB_SETTINGS"] = {} - self.app.config["TESTING"] = True + @app.route("/show//") + def show(id): + todo = Todo.objects.get_or_404(id=id) + return "\n".join([todo.title, todo.text]) - db = MongoEngine() - # Disconnect to drop connection from setup. - db.disconnect() - db.init_app(self.app) - def test_with_id(self): - c = self.app.test_client() - resp = c.get("/show/%s/" % ObjectId()) - self.assertEqual(resp.status_code, 404) +def test_with_id(app, todo): + Todo = todo + client = app.test_client() + response = client.get("/show/%s/" % ObjectId()) + assert response.status_code == 404 - c.post("/add", data={"title": "First Item", "text": "The text"}) + client.post("/add", data={"title": "First Item", "text": "The text"}) - resp = c.get("/show/%s/" % self.Todo.objects.first_or_404().id) - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.data.decode("utf-8"), "First Item\nThe text") + response = client.get("/show/%s/" % Todo.objects.first_or_404().id) + assert response.status_code == 200 + assert response.data.decode("utf-8") == "First Item\nThe text" - def test_basic_insert(self): - c = self.app.test_client() - c.post("/add", data={"title": "First Item", "text": "The text"}) - c.post("/add", data={"title": "2nd Item", "text": "The text"}) - rv = c.get("/") - self.assertEqual(rv.data.decode("utf-8"), "First Item\n2nd Item") - def test_request_context(self): - with self.app.test_request_context(): - todo = self.Todo(title="Test", text="test") - todo.save() - self.assertEqual(self.Todo.objects.count(), 1) +def test_basic_insert(app): + client = app.test_client() + client.post("/add", data={"title": "First Item", "text": "The text"}) + client.post("/add", data={"title": "2nd Item", "text": "The text"}) + response = client.get("/") + assert response.data.decode("utf-8") == "First Item\n2nd Item" diff --git a/tests/test_connection.py b/tests/test_connection.py index 670de5a1..5ebe5f54 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,39 +1,266 @@ import mongoengine import pymongo +import pytest +from mongoengine.connection import ConnectionFailure from mongoengine.context_managers import switch_db -from nose import SkipTest -from nose.tools import assert_raises +from pymongo.database import Database from pymongo.errors import InvalidURI +from pymongo.mongo_client import MongoClient from pymongo.read_preferences import ReadPreference -from flask_mongoengine import MongoEngine -from tests import FlaskMongoEngineTestCase +from flask_mongoengine import MongoEngine, current_mongoengine_instance -class ConnectionTestCase(FlaskMongoEngineTestCase): - def _do_persist(self, db, alias=None): - """Initialize a test Flask application and persist some data in - MongoDB, ultimately asserting that the connection works. - """ - if alias: +def test_connection__should_use_defaults__if_no_settings_provided(app): + """Make sure a simple connection to a standalone MongoDB works.""" + db = MongoEngine() - class Todo(db.Document): - meta = {"db_alias": alias} - title = db.StringField(max_length=60) - text = db.StringField() - done = db.BooleanField(default=False) + # Verify no extension for Mongoengine yet created for app + assert app.extensions == {} + assert current_mongoengine_instance() is None - else: + # Create db connection. Should return None. + assert db.init_app(app) is None - class Todo(db.Document): - title = db.StringField(max_length=60) - text = db.StringField() - done = db.BooleanField(default=False) + # Verify db added to Flask extensions. + assert current_mongoengine_instance() == db - db.init_app(self.app) - Todo.drop_collection() + # Verify db settings passed to pymongo driver. + # Default mongoengine db is 'default', default Flask-Mongoengine db is 'test'. + connection = mongoengine.get_connection() + mongo_engine_db = mongoengine.get_db() + assert isinstance(mongo_engine_db, Database) + assert isinstance(connection, MongoClient) + assert mongo_engine_db.name == "test" + assert connection.HOST == "localhost" + assert connection.PORT == 27017 - # Test persist + +@pytest.mark.parametrize( + ("config_extension"), + [ + { + "MONGODB_SETTINGS": { + "ALIAS": "simple_conn", + "HOST": "localhost", + "PORT": 27017, + "DB": "flask_mongoengine_test_db", + } + }, + { + "MONGODB_HOST": "localhost", + "MONGODB_PORT": 27017, + "MONGODB_DB": "flask_mongoengine_test_db", + "MONGODB_ALIAS": "simple_conn", + }, + ], + ids=("Dict format", "Config variable format"), +) +def test_connection__should_pass_alias__if_provided(app, config_extension): + """Make sure a simple connection pass ALIAS setting variable.""" + db = MongoEngine() + app.config.update(config_extension) + + # Verify no extension for Mongoengine yet created for app + assert app.extensions == {} + assert current_mongoengine_instance() is None + + # Create db connection. Should return None. + assert db.init_app(app) is None + + # Verify db added to Flask extensions. + assert current_mongoengine_instance() == db + + # Verify db settings passed to pymongo driver. + # ALIAS is used to find correct connection. + # As we do not use default alias, default call to mongoengine.get_connection + # should raise. + with pytest.raises(ConnectionFailure): + mongoengine.get_connection() + + connection = mongoengine.get_connection("simple_conn") + mongo_engine_db = mongoengine.get_db("simple_conn") + assert isinstance(mongo_engine_db, Database) + assert isinstance(connection, MongoClient) + assert mongo_engine_db.name == "flask_mongoengine_test_db" + assert connection.HOST == "localhost" + assert connection.PORT == 27017 + + +@pytest.mark.parametrize( + ("config_extension"), + [ + { + "MONGODB_SETTINGS": { + "HOST": "mongodb://localhost:27017/flask_mongoengine_test_db" + } + }, + { + "MONGODB_HOST": "mongodb://localhost:27017/flask_mongoengine_test_db", + "MONGODB_PORT": 27017, + "MONGODB_DB": "should_ignore_it", + }, + ], + ids=("Dict format", "Config variable format"), +) +def test_connection__should_parse_host_uri__if_host_formatted_as_uri( + app, config_extension +): + """Make sure a simple connection pass ALIAS setting variable.""" + db = MongoEngine() + app.config.update(config_extension) + + # Verify no extension for Mongoengine yet created for app + assert app.extensions == {} + assert current_mongoengine_instance() is None + + # Create db connection. Should return None. + assert db.init_app(app) is None + + # Verify db added to Flask extensions. + assert current_mongoengine_instance() == db + + connection = mongoengine.get_connection() + mongo_engine_db = mongoengine.get_db() + assert isinstance(mongo_engine_db, Database) + assert isinstance(connection, MongoClient) + assert mongo_engine_db.name == "flask_mongoengine_test_db" + assert connection.HOST == "localhost" + assert connection.PORT == 27017 + + +@pytest.mark.parametrize( + ("config_extension"), + [ + { + "MONGODB_SETTINGS": { + "HOST": "mongomock://localhost:27017/flask_mongoengine_test_db" + } + }, + { + "MONGODB_SETTINGS": { + "ALIAS": "simple_conn", + "HOST": "localhost", + "PORT": 27017, + "DB": "flask_mongoengine_test_db", + "IS_MOCK": True, + } + }, + {"MONGODB_HOST": "mongomock://localhost:27017/flask_mongoengine_test_db"}, + ], + ids=("Dict format as URI", "Dict format as Param", "Config variable format as URI"), +) +def test_connection__should_parse_mongo_mock_uri__as_uri_and_as_settings( + app, config_extension +): + """Make sure a simple connection pass ALIAS setting variable.""" + db = MongoEngine() + app.config.update(config_extension) + + # Verify no extension for Mongoengine yet created for app + assert app.extensions == {} + assert current_mongoengine_instance() is None + + # Create db connection. Should return None. + + with pytest.raises(RuntimeError) as error: + assert db.init_app(app) is None + + assert str(error.value) == "You need mongomock installed to mock MongoEngine." + + +@pytest.mark.parametrize( + ("config_extension"), + [ + { + "MONGODB_SETTINGS": { + "HOST": "postgre://localhost:27017/flask_mongoengine_test_db" + } + }, + {"MONGODB_HOST": "mysql://localhost:27017/flask_mongoengine_test_db"}, + ], + ids=("Dict format as URI", "Config variable format as URI"), +) +def test_connection__should_raise__if_uri_not_properly_formatted(app, config_extension): + """Make sure a simple connection pass ALIAS setting variable.""" + db = MongoEngine() + app.config.update(config_extension) + + # Verify no extension for Mongoengine yet created for app + assert app.extensions == {} + assert current_mongoengine_instance() is None + + # Create db connection. Should return None. + + with pytest.raises(InvalidURI) as error: + assert db.init_app(app) is None + + assert ( + str(error.value) + == "Invalid URI scheme: URI must begin with 'mongodb://' or 'mongodb+srv://'" + ) + + +def test_connection__should_accept_host_as_list(app): + """Make sure MONGODB_HOST can be a list hosts.""" + db = MongoEngine() + app.config["MONGODB_SETTINGS"] = { + "ALIAS": "host_list", + "HOST": ["localhost:27017"], + "DB": "flask_mongoengine_list_test_db", + } + db.init_app(app) + + connection = mongoengine.get_connection("host_list") + mongo_engine_db = mongoengine.get_db("host_list") + assert isinstance(mongo_engine_db, Database) + assert isinstance(connection, MongoClient) + assert mongo_engine_db.name == "flask_mongoengine_list_test_db" + assert connection.HOST == "localhost" + assert connection.PORT == 27017 + + +def test_multiple_connections(app): + """Make sure establishing multiple connections to a standalone + MongoDB and switching between them works. + """ + db = MongoEngine() + app.config["MONGODB_SETTINGS"] = [ + { + "ALIAS": "default", + "DB": "flask_mongoengine_test_db_1", + "HOST": "localhost", + "PORT": 27017, + }, + { + "ALIAS": "alternative", + "DB": "flask_mongoengine_test_db_2", + "HOST": "localhost", + "PORT": 27017, + }, + ] + + class Todo(db.Document): + title = db.StringField(max_length=60) + + db.init_app(app) + # Drop default collection from init + Todo.drop_collection() + Todo.meta = {"db_alias": "alternative"} + # Drop 'alternative' collection initiated early. + Todo.drop_collection() + + # Make sure init correct and both databases are clean + with switch_db(Todo, "default") as Todo: + doc = Todo.objects().first() + assert doc is None + + with switch_db(Todo, "alternative") as Todo: + doc = Todo.objects().first() + assert doc is None + + # Test saving a doc via the default connection + with switch_db(Todo, "default") as Todo: todo = Todo() todo.text = "Sample" todo.title = "Testing" @@ -41,159 +268,61 @@ class Todo(db.Document): s_todo = todo.save() f_to = Todo.objects().first() - self.assertEqual(s_todo.title, f_to.title) + assert s_todo.title == f_to.title - def test_simple_connection(self): - """Make sure a simple connection to a standalone MongoDB works.""" - db = MongoEngine() - self.app.config["MONGODB_SETTINGS"] = { - "ALIAS": "simple_conn", - "HOST": "localhost", - "PORT": 27017, - "DB": "flask_mongoengine_test_db", - } - self._do_persist(db, alias="simple_conn") - - def test_host_as_uri_string(self): - """Make sure we can connect to a standalone MongoDB if we specify - the host as a MongoDB URI. - """ - db = MongoEngine() - self.app.config[ - "MONGODB_HOST" - ] = "mongodb://localhost:27017/flask_mongoengine_test_db" - self._do_persist(db) - - def test_mongomock_host_as_uri_string(self): - """Make sure we switch to mongomock if we specify the host as a mongomock URI. - """ - if mongoengine.VERSION < (0, 9, 0): - raise SkipTest("Mongomock not supported for mongoengine < 0.9.0") - db = MongoEngine() - self.app.config[ - "MONGODB_HOST" - ] = "mongomock://localhost:27017/flask_mongoengine_test_db" - with assert_raises(RuntimeError) as exc: - self._do_persist(db) - assert str(exc.exception) == "You need mongomock installed to mock MongoEngine." - - def test_mongomock_as_param(self): - """Make sure we switch to mongomock when providing IS_MOCK option. - """ - if mongoengine.VERSION < (0, 9, 0): - raise SkipTest("Mongomock not supported for mongoengine < 0.9.0") - db = MongoEngine() - self.app.config["MONGODB_SETTINGS"] = { - "ALIAS": "simple_conn", - "HOST": "localhost", - "PORT": 27017, - "DB": "flask_mongoengine_test_db", - "IS_MOCK": True, - } - with assert_raises(RuntimeError) as exc: - self._do_persist(db, alias="simple_conn") - assert str(exc.exception) == "You need mongomock installed to mock MongoEngine." - - def test_host_as_list(self): - """Make sure MONGODB_HOST can be a list hosts.""" - db = MongoEngine() - self.app.config["MONGODB_SETTINGS"] = { - "ALIAS": "host_list", - "HOST": ["localhost:27017"], - "DB": "flask_mongoengine_test_db", - } - self._do_persist(db, alias="host_list") - - def test_multiple_connections(self): - """Make sure establishing multiple connections to a standalone - MongoDB and switching between them works. - """ - db = MongoEngine() - self.app.config["MONGODB_SETTINGS"] = [ - { - "ALIAS": "default", - "DB": "flask_mongoengine_test_db_1", - "HOST": "localhost", - "PORT": 27017, - }, - { - "ALIAS": "alternative", - "DB": "flask_mongoengine_test_db_2", - "HOST": "localhost", - "PORT": 27017, - }, - ] - - class Todo(db.Document): - title = db.StringField(max_length=60) - text = db.StringField() - done = db.BooleanField(default=False) - meta = {"db_alias": "alternative"} - - db.init_app(self.app) - Todo.drop_collection() - - # Test saving a doc via the default connection - with switch_db(Todo, "default") as Todo: - todo = Todo() - todo.text = "Sample" - todo.title = "Testing" - todo.done = True - s_todo = todo.save() - - f_to = Todo.objects().first() - self.assertEqual(s_todo.title, f_to.title) - - # Make sure the doc doesn't exist in the alternative db - with switch_db(Todo, "alternative") as Todo: - doc = Todo.objects().first() - self.assertEqual(doc, None) - - # Make sure switching back to the default connection shows the doc - with switch_db(Todo, "default") as Todo: - doc = Todo.objects().first() - self.assertNotEqual(doc, None) - - def test_connection_with_invalid_uri(self): - """Make sure connecting via an invalid URI raises an InvalidURI - exception. - """ - self.app.config["MONGODB_HOST"] = "mongo://localhost" - self.assertRaises(InvalidURI, MongoEngine, self.app) - - def test_ingnored_mongodb_prefix_config(self): - """Config starting by MONGODB_ but not used by flask-mongoengine - should be ignored. - """ - db = MongoEngine() - self.app.config[ - "MONGODB_HOST" - ] = "mongodb://localhost:27017/flask_mongoengine_test_db_prod" - # Invalid host, should trigger exception if used - self.app.config["MONGODB_TEST_HOST"] = "dummy://localhost:27017/test" - self._do_persist(db) - - def test_connection_kwargs(self): - """Make sure additional connection kwargs work.""" - - # Figure out whether to use "MAX_POOL_SIZE" or "MAXPOOLSIZE" based - # on PyMongo version (former was changed to the latter as described - # in https://jira.mongodb.org/browse/PYTHON-854) - # TODO remove once PyMongo < 3.0 support is dropped - if pymongo.version_tuple[0] >= 3: - MAX_POOL_SIZE_KEY = "MAXPOOLSIZE" - else: - MAX_POOL_SIZE_KEY = "MAX_POOL_SIZE" - - self.app.config["MONGODB_SETTINGS"] = { - "ALIAS": "tz_aware_true", - "DB": "flask_mongoengine_testing_tz_aware", - "TZ_AWARE": True, - "READ_PREFERENCE": ReadPreference.SECONDARY, - MAX_POOL_SIZE_KEY: 10, - } - db = MongoEngine() - db.init_app(self.app) - self.assertTrue(db.connection.codec_options.tz_aware) - self.assertEqual(db.connection.max_pool_size, 10) - self.assertEqual(db.connection.read_preference, ReadPreference.SECONDARY) + # Make sure the doc still doesn't exist in the alternative db + with switch_db(Todo, "alternative") as Todo: + doc = Todo.objects().first() + assert doc is None + + # Make sure switching back to the default connection shows the doc + with switch_db(Todo, "default") as Todo: + doc = Todo.objects().first() + assert doc is not None + + +def test_ingnored_mongodb_prefix_config(app): + """Config starting by MONGODB_ but not used by flask-mongoengine + should be ignored. + """ + db = MongoEngine() + app.config[ + "MONGODB_HOST" + ] = "mongodb://localhost:27017/flask_mongoengine_test_db_prod" + # Invalid host, should trigger exception if used + app.config["MONGODB_TEST_HOST"] = "dummy://localhost:27017/test" + db.init_app(app) + + connection = mongoengine.get_connection() + mongo_engine_db = mongoengine.get_db() + assert isinstance(mongo_engine_db, Database) + assert isinstance(connection, MongoClient) + assert mongo_engine_db.name == "flask_mongoengine_test_db_prod" + assert connection.HOST == "localhost" + assert connection.PORT == 27017 + + +def test_connection_kwargs(app): + """Make sure additional connection kwargs work.""" + + # Figure out whether to use "MAX_POOL_SIZE" or "MAXPOOLSIZE" based + # on PyMongo version (former was changed to the latter as described + # in https://jira.mongodb.org/browse/PYTHON-854) + # TODO remove once PyMongo < 3.0 support is dropped + if pymongo.version_tuple[0] >= 3: + MAX_POOL_SIZE_KEY = "MAXPOOLSIZE" + else: + MAX_POOL_SIZE_KEY = "MAX_POOL_SIZE" + + app.config["MONGODB_SETTINGS"] = { + "ALIAS": "tz_aware_true", + "DB": "flask_mongoengine_testing_tz_aware", + "TZ_AWARE": True, + "READ_PREFERENCE": ReadPreference.SECONDARY, + MAX_POOL_SIZE_KEY: 10, + } + db = MongoEngine(app) + + assert db.connection.codec_options.tz_aware + assert db.connection.max_pool_size == 10 + assert db.connection.read_preference == ReadPreference.SECONDARY diff --git a/tests/test_forms.py b/tests/test_forms.py index f7620a1f..4fceaabb 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1,564 +1,530 @@ import datetime import re -import unittest import bson import flask +import pytest import wtforms -from mongoengine import queryset_manager +from mongoengine import NotUniqueError, queryset_manager from werkzeug.datastructures import MultiDict -from flask_mongoengine import MongoEngine from flask_mongoengine.wtf import model_form -from tests import FlaskMongoEngineTestCase -class WTFormsAppTestCase(FlaskMongoEngineTestCase): - def setUp(self): - super(WTFormsAppTestCase, self).setUp() - self.db_name = "test_db" - self.app.config["MONGODB_DB"] = self.db_name - self.app.config["TESTING"] = True - # For Flask-WTF < 0.9 - self.app.config["CSRF_ENABLED"] = False - # For Flask-WTF >= 0.9 - self.app.config["WTF_CSRF_ENABLED"] = False - self.db = MongoEngine() - self.db.init_app(self.app) +def test_binaryfield(app, db): - def tearDown(self): - try: - self.db.connection.drop_database(self.db_name) - except Exception: - self.db.connection.client.drop_database(self.db_name) + with app.test_request_context("/"): - def test_binaryfield(self): + class Binary(db.Document): + binary = db.BinaryField() - with self.app.test_request_context("/"): - db = self.db + BinaryForm = model_form(Binary) + form = BinaryForm(MultiDict({"binary": "1"})) + assert form.validate() + form.save() - class Binary(db.Document): - binary = db.BinaryField() - BinaryForm = model_form(Binary) - form = BinaryForm(MultiDict({"binary": "1"})) - self.assertTrue(form.validate()) - form.save() +def test_choices_coerce(app, db): - def test_choices_coerce(self): + with app.test_request_context("/"): - with self.app.test_request_context("/"): - db = self.db + CHOICES = ((1, "blue"), (2, "red")) - CHOICES = ((1, "blue"), (2, "red")) + class MyChoices(db.Document): + pill = db.IntField(choices=CHOICES) - class MyChoices(db.Document): - pill = db.IntField(choices=CHOICES) + MyChoicesForm = model_form(MyChoices) + form = MyChoicesForm(MultiDict({"pill": "1"})) + assert form.validate() + form.save() + assert MyChoices.objects.first().pill == 1 - MyChoicesForm = model_form(MyChoices) - form = MyChoicesForm(MultiDict({"pill": "1"})) - self.assertTrue(form.validate()) - form.save() - self.assertEqual(MyChoices.objects.first().pill, 1) - def test_list_choices_coerce(self): +def test_list_choices_coerce(app, db): - with self.app.test_request_context("/"): - db = self.db + with app.test_request_context("/"): - CHOICES = ((1, "blue"), (2, "red")) + CHOICES = ((1, "blue"), (2, "red")) - class MyChoices(db.Document): - pill = db.ListField(db.IntField(choices=CHOICES)) + class MyChoices(db.Document): + pill = db.ListField(db.IntField(choices=CHOICES)) - MyChoicesForm = model_form(MyChoices) - form = MyChoicesForm(MultiDict({"pill": "1"})) - self.assertTrue(form.validate()) - form.save() - self.assertEqual(MyChoices.objects.first().pill[0], 1) + MyChoicesForm = model_form(MyChoices) + form = MyChoicesForm(MultiDict({"pill": "1"})) + assert form.validate() + form.save() + assert MyChoices.objects.first().pill[0] == 1 - def test_emailfield(self): - with self.app.test_request_context("/"): - db = self.db +def test_emailfield(app, db): - class Email(db.Document): - email = db.EmailField(required=False) + with app.test_request_context("/"): - EmailForm = model_form(Email) - form = EmailForm(instance=Email()) - self.assertFalse("None" in "%s" % form.email) - self.assertTrue(form.validate()) + class Email(db.Document): + email = db.EmailField(required=False) - form = EmailForm(MultiDict({"email": ""})) - self.assertFalse("None" in "%s" % form.email) - self.assertTrue(form.validate()) + EmailForm = model_form(Email) + form = EmailForm(instance=Email()) + assert "None" not in "%s" % form.email + assert form.validate() - # Ensure required works + form = EmailForm(MultiDict({"email": ""})) + assert "None" not in "%s" % form.email + assert form.validate() - class Email(db.Document): - email = db.EmailField(required=True) + # Ensure required works - EmailForm = model_form(Email) - form = EmailForm(MultiDict({"email": ""})) - self.assertFalse("None" in "%s" % form.email) - self.assertFalse(form.validate()) + class Email(db.Document): + email = db.EmailField(required=True) - def test_model_form(self): - with self.app.test_request_context("/"): - db = self.db + EmailForm = model_form(Email) + form = EmailForm(MultiDict({"email": ""})) + assert "None" not in "%s" % form.email + assert not form.validate() - class BlogPost(db.Document): - meta = {"allow_inheritance": True} - title = db.StringField(required=True, max_length=200) - posted = db.DateTimeField(default=datetime.datetime.now) - tags = db.ListField(db.StringField()) - class TextPost(BlogPost): - email = db.EmailField(required=False) - lead_paragraph = db.StringField(max_length=200) - content = db.StringField(required=True) +def test_model_form(app, db): + with app.test_request_context("/"): - class LinkPost(BlogPost): - url = db.StringField(required=True, max_length=200) - interest = db.DecimalField(required=True) + class BlogPost(db.Document): + meta = {"allow_inheritance": True} + title = db.StringField(required=True, max_length=200) + posted = db.DateTimeField(default=datetime.datetime.now) + tags = db.ListField(db.StringField()) - # Create a text-based post - TextPostForm = model_form( - TextPost, field_args={"lead_paragraph": {"textarea": True}} - ) + class TextPost(BlogPost): + email = db.EmailField(required=False) + lead_paragraph = db.StringField(max_length=200) + content = db.StringField(required=True) + + class LinkPost(BlogPost): + url = db.StringField(required=True, max_length=200) + interest = db.DecimalField(required=True) + + # Create a text-based post + TextPostForm = model_form( + TextPost, field_args={"lead_paragraph": {"textarea": True}} + ) - form = TextPostForm( - MultiDict( - {"title": "Using MongoEngine", "tags": ["mongodb", "mongoengine"]} - ) + form = TextPostForm( + MultiDict( + {"title": "Using MongoEngine", "tags": ["mongodb", "mongoengine"]} ) + ) - self.assertFalse(form.validate()) + assert not form.validate() - form = TextPostForm( - MultiDict( - { - "title": "Using MongoEngine", - "content": "See the tutorial", - "tags": ["mongodb", "mongoengine"], - } - ) + form = TextPostForm( + MultiDict( + { + "title": "Using MongoEngine", + "content": "See the tutorial", + "tags": ["mongodb", "mongoengine"], + } ) + ) - self.assertTrue(form.validate()) - form.save() + assert form.validate() + form.save() - self.assertEqual(form.title.type, "StringField") - self.assertEqual(form.content.type, "TextAreaField") - self.assertEqual(form.lead_paragraph.type, "TextAreaField") + assert form.title.type == "StringField" + assert form.content.type == "TextAreaField" + assert form.lead_paragraph.type == "TextAreaField" - self.assertEqual(BlogPost.objects.first().title, "Using MongoEngine") - self.assertEqual(BlogPost.objects.count(), 1) + assert BlogPost.objects.first().title == "Using MongoEngine" + assert BlogPost.objects.count() == 1 - form = TextPostForm( - MultiDict( - { - "title": "Using Flask-MongoEngine", - "content": "See the tutorial", - "tags": ["flask", "mongodb", "mongoengine"], - } - ) + form = TextPostForm( + MultiDict( + { + "title": "Using Flask-MongoEngine", + "content": "See the tutorial", + "tags": ["flask", "mongodb", "mongoengine"], + } ) - - self.assertTrue(form.validate()) - form.save() - self.assertEqual(BlogPost.objects.count(), 2) - - post = BlogPost.objects(title="Using Flask-MongoEngine").get() - - form = TextPostForm( - MultiDict( - { - "title": "Using Flask-MongoEngine", - "content": "See the tutorial", - "tags-0": "flask", - "tags-1": "mongodb", - "tags-2": "mongoengine", - "tags-3": "flask-mongoengine", - } - ), - instance=post, + ) + + assert form.validate() + form.save() + assert BlogPost.objects.count() == 2 + + post = BlogPost.objects(title="Using Flask-MongoEngine").get() + + form = TextPostForm( + MultiDict( + { + "title": "Using Flask-MongoEngine", + "content": "See the tutorial", + "tags-0": "flask", + "tags-1": "mongodb", + "tags-2": "mongoengine", + "tags-3": "flask-mongoengine", + } + ), + instance=post, + ) + assert form.validate() + form.save() + post = post.reload() + + assert post.tags == ["flask", "mongodb", "mongoengine", "flask-mongoengine"] + + # Create a link post + LinkPostForm = model_form(LinkPost) + + form = LinkPostForm( + MultiDict( + { + "title": "Using Flask-MongoEngine", + "url": "http://flask-mongoengine.org", + "interest": "0", + } ) - self.assertTrue(form.validate()) - form.save() - post = post.reload() + ) + form.validate() + assert form.validate() - self.assertEqual( - post.tags, ["flask", "mongodb", "mongoengine", "flask-mongoengine"] - ) - # Create a link post - LinkPostForm = model_form(LinkPost) - - form = LinkPostForm( - MultiDict( - { - "title": "Using Flask-MongoEngine", - "url": "http://flask-mongoengine.org", - "interest": "0", - } - ) - ) - form.validate() - self.assertTrue(form.validate()) +def test_model_form_only(app, db): + with app.test_request_context("/"): - def test_model_form_only(self): - with self.app.test_request_context("/"): - db = self.db + class BlogPost(db.Document): + title = db.StringField(required=True, max_length=200) + posted = db.DateTimeField(default=datetime.datetime.now) + tags = db.ListField(db.StringField()) - class BlogPost(db.Document): - title = db.StringField(required=True, max_length=200) - posted = db.DateTimeField(default=datetime.datetime.now) - tags = db.ListField(db.StringField()) + BlogPost.drop_collection() - BlogPost.drop_collection() + BlogPostForm = model_form(BlogPost, only=["tags"]) + form = BlogPostForm() + assert hasattr(form, "tags") + assert not hasattr(form, "posted") - BlogPostForm = model_form(BlogPost, only=["tags"]) - form = BlogPostForm() - self.assertTrue(hasattr(form, "tags")) - self.assertFalse(hasattr(form, "posted")) + BlogPostForm = model_form(BlogPost, exclude=["posted"]) + form = BlogPostForm() + assert hasattr(form, "tags") + assert not hasattr(form, "posted") - BlogPostForm = model_form(BlogPost, exclude=["posted"]) - form = BlogPostForm() - self.assertTrue(hasattr(form, "tags")) - self.assertFalse(hasattr(form, "posted")) - def test_model_form_with_custom_query_set(self): - with self.app.test_request_context("/"): - db = self.db +def test_model_form_with_custom_query_set(app, db): + with app.test_request_context("/"): - class Dog(db.Document): - breed = db.StringField() + class Dog(db.Document): + breed = db.StringField() - @queryset_manager - def large_objects(cls, queryset): - return queryset(breed__in=["german sheppard", "wolfhound"]) + @queryset_manager + def large_objects(cls, queryset): + return queryset(breed__in=["german sheppard", "wolfhound"]) - class DogOwner(db.Document): - dog = db.ReferenceField(Dog) + class DogOwner(db.Document): + dog = db.ReferenceField(Dog) - big_dogs = [Dog(breed="german sheppard"), Dog(breed="wolfhound")] - dogs = [Dog(breed="poodle")] + big_dogs - for dog in dogs: - dog.save() + big_dogs = [Dog(breed="german sheppard"), Dog(breed="wolfhound")] + dogs = [Dog(breed="poodle")] + big_dogs + for dog in dogs: + dog.save() - BigDogForm = model_form( - DogOwner, field_args={"dog": {"queryset": Dog.large_objects}} - ) + BigDogForm = model_form( + DogOwner, field_args={"dog": {"queryset": Dog.large_objects}} + ) - form = BigDogForm(dog=big_dogs[0]) - self.assertTrue(form.validate()) - self.assertEqual(big_dogs, [d[1] for d in form.dog.iter_choices()]) + form = BigDogForm(dog=big_dogs[0]) + assert form.validate() + assert big_dogs == [d[1] for d in form.dog.iter_choices()] - def test_modelselectfield(self): - with self.app.test_request_context("/"): - db = self.db - class Dog(db.Document): - name = db.StringField() +def test_modelselectfield(app, db): + with app.test_request_context("/"): - class DogOwner(db.Document): - dog = db.ReferenceField(Dog) + class Dog(db.Document): + name = db.StringField() - DogOwnerForm = model_form( - DogOwner, field_args={"dog": {"allow_blank": True}} - ) + class DogOwner(db.Document): + dog = db.ReferenceField(Dog) - dog = Dog(name="fido") - dog.save() + DogOwnerForm = model_form(DogOwner, field_args={"dog": {"allow_blank": True}}) - form = DogOwnerForm(dog=dog) - self.assertTrue(form.validate()) + dog = Dog(name="fido") + dog.save() - self.assertEqual(wtforms.widgets.Select, type(form.dog.widget)) - self.assertFalse(form.dog.widget.multiple) + form = DogOwnerForm(dog=dog) + assert form.validate() - # Validate the options - should contain a dog (selected) and a - # blank option there should be an extra blank option. - choices = list(form.dog) - self.assertEqual(len(choices), 2) - self.assertFalse(choices[0].checked) - self.assertEqual(choices[0].data, "__None") - self.assertTrue(choices[1].checked) - self.assertEqual(choices[1].data, dog.pk) + assert isinstance(form.dog.widget, wtforms.widgets.Select) + assert not form.dog.widget.multiple - # Validate selecting one item - form = DogOwnerForm(MultiDict({"dog": dog.id})) - self.assertEqual(form.dog.data, dog) + # Validate the options - should contain a dog (selected) and a + # blank option there should be an extra blank option. + choices = list(form.dog) + assert len(choices) == 2 + assert not choices[0].checked + assert choices[0].data == "__None" + assert choices[1].checked + assert choices[1].data == dog.pk - # Validate selecting no item - form = DogOwnerForm(MultiDict({"dog": "__None"}), dog=dog) - self.assertEqual(form.dog.data, None) + # Validate selecting one item + form = DogOwnerForm(MultiDict({"dog": dog.id})) + assert form.dog.data == dog - def test_modelselectfield_multiple(self): - with self.app.test_request_context("/"): - db = self.db + # Validate selecting no item + form = DogOwnerForm(MultiDict({"dog": "__None"}), dog=dog) + assert form.dog.data is None - class Dog(db.Document): - name = db.StringField() - class DogOwner(db.Document): - dogs = db.ListField(db.ReferenceField(Dog)) +def test_modelselectfield_multiple(app, db): + with app.test_request_context("/"): - DogOwnerForm = model_form( - DogOwner, field_args={"dogs": {"allow_blank": True}} - ) + class Dog(db.Document): + name = db.StringField() - dogs = [Dog(name="fido"), Dog(name="rex")] - for dog in dogs: - dog.save() + class DogOwner(db.Document): + dogs = db.ListField(db.ReferenceField(Dog)) - form = DogOwnerForm(dogs=dogs) - self.assertTrue(form.validate()) + DogOwnerForm = model_form(DogOwner, field_args={"dogs": {"allow_blank": True}}) - self.assertEqual(wtforms.widgets.Select, type(form.dogs.widget)) - self.assertTrue(form.dogs.widget.multiple) + dogs = [Dog(name="fido"), Dog(name="rex")] + for dog in dogs: + dog.save() - # Validate the options - both dogs should be selected and - # there should be an extra blank option. - choices = list(form.dogs) - self.assertEqual(len(choices), 3) - self.assertFalse(choices[0].checked) - self.assertEqual(choices[0].data, "__None") - self.assertTrue(choices[1].checked) - self.assertEqual(choices[1].data, dogs[0].pk) - self.assertTrue(choices[2].checked) - self.assertEqual(choices[2].data, dogs[1].pk) + form = DogOwnerForm(dogs=dogs) + assert form.validate() - # Validate selecting two items - form = DogOwnerForm(MultiDict({"dogs": [dog.id for dog in dogs]})) - self.assertEqual(form.dogs.data, dogs) + assert isinstance(form.dogs.widget, wtforms.widgets.Select) + assert form.dogs.widget.multiple - # Validate selecting none actually empties the list - form = DogOwnerForm(MultiDict({"dogs": "__None"}), dogs=dogs) - self.assertEqual(form.dogs.data, None) + # Validate the options - both dogs should be selected and + # there should be an extra blank option. + choices = list(form.dogs) + assert len(choices) == 3 + assert not choices[0].checked + assert choices[0].data == "__None" + assert choices[1].checked + assert choices[1].data == dogs[0].pk + assert choices[2].checked + assert choices[2].data == dogs[1].pk - def test_modelselectfield_multiple_initalvalue_None(self): - with self.app.test_request_context("/"): - db = self.db + # Validate selecting two items + form = DogOwnerForm(MultiDict({"dogs": [dog.id for dog in dogs]})) + assert form.dogs.data == dogs - class Dog(db.Document): - name = db.StringField() + # Validate selecting none actually empties the list + form = DogOwnerForm(MultiDict({"dogs": "__None"}), dogs=dogs) + assert form.dogs.data is None - class DogOwner(db.Document): - dogs = db.ListField(db.ReferenceField(Dog)) - DogOwnerForm = model_form(DogOwner) +def test_modelselectfield_multiple_initalvalue_None(app, db): + with app.test_request_context("/"): - dogs = [Dog(name="fido"), Dog(name="rex")] - for dog in dogs: - dog.save() + class Dog(db.Document): + name = db.StringField() - form = DogOwnerForm(dogs=None) - self.assertTrue(form.validate()) + class DogOwner(db.Document): + dogs = db.ListField(db.ReferenceField(Dog)) - self.assertEqual(wtforms.widgets.Select, type(form.dogs.widget)) - self.assertTrue(form.dogs.widget.multiple) + DogOwnerForm = model_form(DogOwner) - # Validate if both dogs are selected - choices = list(form.dogs) - self.assertEqual(len(choices), 2) - self.assertFalse(choices[0].checked) - self.assertFalse(choices[1].checked) + dogs = [Dog(name="fido"), Dog(name="rex")] + for dog in dogs: + dog.save() - def test_modelradiofield(self): - with self.app.test_request_context("/"): - db = self.db + form = DogOwnerForm(dogs=None) + assert form.validate() - choices = [("male", "Male"), ("female", "Female"), ("other", "Other")] + assert isinstance(form.dogs.widget, wtforms.widgets.Select) + assert form.dogs.widget.multiple - class Poll(db.Document): - answer = db.StringField(choices=choices) + # Validate if both dogs are selected + choices = list(form.dogs) + assert len(choices) == 2 + assert not choices[0].checked + assert not choices[1].checked - PollForm = model_form(Poll, field_args={"answer": {"radio": True}}) - form = PollForm(answer=None) - self.assertTrue(form.validate()) +def test_modelradiofield(app, db): + with app.test_request_context("/"): - self.assertEqual(form.answer.type, "RadioField") - self.assertEqual(form.answer.choices, choices) + choices = [("male", "Male"), ("female", "Female"), ("other", "Other")] - def test_passwordfield(self): - with self.app.test_request_context("/"): - db = self.db + class Poll(db.Document): + answer = db.StringField(choices=choices) - class User(db.Document): - password = db.StringField() + PollForm = model_form(Poll, field_args={"answer": {"radio": True}}) - UserForm = model_form(User, field_args={"password": {"password": True}}) - form = UserForm(password="12345") - self.assertEqual(wtforms.widgets.PasswordInput, type(form.password.widget)) + form = PollForm(answer=None) + assert form.validate() - def test_unique_with(self): + assert form.answer.type == "RadioField" + assert form.answer.choices == choices - with self.app.test_request_context("/"): - db = self.db - class Item(db.Document): - owner_id = db.ObjectIdField(required=True) - owner_item_id = db.StringField(required=True, unique_with="owner_id") +def test_passwordfield(app, db): + with app.test_request_context("/"): - Item.drop_collection() + class User(db.Document): + password = db.StringField() - object_id = bson.ObjectId() + UserForm = model_form(User, field_args={"password": {"password": True}}) + form = UserForm(password="12345") + assert isinstance(form.password.widget, wtforms.widgets.PasswordInput) + + +def test_unique_with(app, db): + + with app.test_request_context("/"): + + class Item(db.Document): + owner_id = db.ObjectIdField(required=True) + owner_item_id = db.StringField(required=True, unique_with="owner_id") + + Item.drop_collection() + + object_id = bson.ObjectId() + Item(owner_id=object_id, owner_item_id="1").save() + + with pytest.raises(NotUniqueError): Item(owner_id=object_id, owner_item_id="1").save() - try: - Item(owner_id=object_id, owner_item_id="1").save() - self.fail("Should have raised duplicate key error") - except Exception: - pass + assert 1 == Item.objects.count() - self.assertEqual(1, Item.objects.count()) - def test_sub_field_args(self): - with self.app.test_request_context("/"): - db = self.db +def test_sub_field_args(app, db): + with app.test_request_context("/"): - class TestModel(db.Document): - lst = db.ListField(db.StringField()) + class TestModel(db.Document): + lst = db.ListField(db.StringField()) - field_args = { - "lst": { - "label": "Custom Label", - "field_args": { - "widget": wtforms.widgets.HiddenInput(), - "label": "Hidden Input", - }, - } + field_args = { + "lst": { + "label": "Custom Label", + "field_args": { + "widget": wtforms.widgets.HiddenInput(), + "label": "Hidden Input", + }, } - CustomForm = model_form(TestModel, field_args=field_args) + } + CustomForm = model_form(TestModel, field_args=field_args) - custom_form = CustomForm(obj=TestModel(lst=["Foo"])) - list_label = flask.render_template_string( - "{{ custom_form.lst.label }}", custom_form=custom_form - ) - self.assertTrue("Custom Label" in list_label) - self.assertTrue("Hidden Input" not in list_label) + custom_form = CustomForm(obj=TestModel(lst=["Foo"])) + list_label = flask.render_template_string( + "{{ custom_form.lst.label }}", custom_form=custom_form + ) + assert "Custom Label" in list_label + assert "Hidden Input" not in list_label - sub_label = flask.render_template_string( - "{{ custom_form.lst }}", custom_form=custom_form - ) - self.assertTrue("Hidden Input" in sub_label) + sub_label = flask.render_template_string( + "{{ custom_form.lst }}", custom_form=custom_form + ) + assert "Hidden Input" in sub_label - def test_modelselectfield_multiple_selected_elements_must_be_retained(self): - with self.app.test_request_context("/"): - db = self.db - class Dog(db.Document): - name = db.StringField() +def test_modelselectfield_multiple_selected_elements_must_be_retained(app, db): + with app.test_request_context("/"): - def __unicode__(self): - return self.name + class Dog(db.Document): + name = db.StringField() - class DogOwner(db.Document): - dogs = db.ListField(db.ReferenceField(Dog)) + def __unicode__(self): + return self.name - DogOwnerForm = model_form(DogOwner) + class DogOwner(db.Document): + dogs = db.ListField(db.ReferenceField(Dog)) - fido = Dog(name="fido").save() - Dog(name="rex").save() + DogOwnerForm = model_form(DogOwner) - dogOwner = DogOwner(dogs=[fido]) - form = DogOwnerForm(obj=dogOwner) - html = form.dogs() + fido = Dog(name="fido").save() + Dog(name="rex").save() - m = re.search("", html) - self.assertTrue(m is not None, "Should have one selected option") - self.assertEqual("fido", m.group(1)) + dogOwner = DogOwner(dogs=[fido]) + form = DogOwnerForm(obj=dogOwner) + html = form.dogs() - def test_model_form_help_text(self): - with self.app.test_request_context("/"): - db = self.db + m = re.search("", html) - class BlogPost(db.Document): - title = db.StringField( - required=True, - help_text="Some imaginative title to set the world on fire", - ) + # Should have one selected option + assert m is not None + assert "fido" == m.group(1) - post = BlogPost(title="hello world").save() - BlogPostForm = model_form(BlogPost) - form = BlogPostForm(instance=post) +def test_model_form_help_text(app, db): + with app.test_request_context("/"): - self.assertEqual( - form.title.description, - "Some imaginative title to set the world on fire", + class BlogPost(db.Document): + title = db.StringField( + required=True, + help_text="Some imaginative title to set the world on fire", ) - def test_shared_field_args(self): - with self.app.test_request_context("/"): - db = self.db + post = BlogPost(title="hello world").save() - class BlogPost(db.Document): - title = db.StringField(required=True) - content = db.StringField(required=False) + BlogPostForm = model_form(BlogPost) + form = BlogPostForm(instance=post) - shared_field_args = { - "title": {"validators": [wtforms.validators.Regexp("test")]} - } + assert ( + form.title.description == "Some imaginative title to set the world on fire" + ) - TitleOnlyForm = model_form( - BlogPost, field_args=shared_field_args, exclude=["content"] - ) - BlogPostForm = model_form(BlogPost, field_args=shared_field_args) - # ensure shared field_args don't create duplicate validators - title_only_form = TitleOnlyForm() - self.assertEqual(len(title_only_form.title.validators), 2) +def test_shared_field_args(app, db): + with app.test_request_context("/"): + + class BlogPost(db.Document): + title = db.StringField(required=True) + content = db.StringField(required=False) + + shared_field_args = { + "title": {"validators": [wtforms.validators.Regexp("test")]} + } + + TitleOnlyForm = model_form( + BlogPost, field_args=shared_field_args, exclude=["content"] + ) + BlogPostForm = model_form(BlogPost, field_args=shared_field_args) + + # ensure shared field_args don't create duplicate validators + title_only_form = TitleOnlyForm() + assert len(title_only_form.title.validators) == 2 - blog_post_form = BlogPostForm() - self.assertEqual(len(blog_post_form.title.validators), 2) + blog_post_form = BlogPostForm() + assert len(blog_post_form.title.validators) == 2 - def test_embedded_model_form(self): - with self.app.test_request_context("/"): - db = self.db - class Content(db.EmbeddedDocument): - text = db.StringField() - lang = db.StringField(max_length=3) +def test_embedded_model_form(app, db): + with app.test_request_context("/"): - class Post(db.Document): - title = db.StringField(max_length=120, required=True) - tags = db.ListField(db.StringField(max_length=30)) - content = db.EmbeddedDocumentField("Content") + class Content(db.EmbeddedDocument): + text = db.StringField() + lang = db.StringField(max_length=3) - PostForm = model_form(Post) - form = PostForm() - self.assertTrue("content-text" in "%s" % form.content.text) + class Post(db.Document): + title = db.StringField(max_length=120, required=True) + tags = db.ListField(db.StringField(max_length=30)) + content = db.EmbeddedDocumentField("Content") - def test_form_label_modifier(self): - with self.app.test_request_context("/"): - db = self.db + PostForm = model_form(Post) + form = PostForm() + assert "content-text" in "%s" % form.content.text - class FoodItem(db.Document): - title = db.StringField() - class FoodStore(db.Document): - title = db.StringField(max_length=120, required=True) - food_items = db.ListField(db.ReferenceField(FoodItem)) +def test_form_label_modifier(app, db): + with app.test_request_context("/"): - def food_items_label_modifier(obj): - return obj.title + class FoodItem(db.Document): + title = db.StringField() - fruit_names = ["banana", "apple", "pear"] + class FoodStore(db.Document): + title = db.StringField(max_length=120, required=True) + food_items = db.ListField(db.ReferenceField(FoodItem)) - food_items = [FoodItem(title=name).save() for name in fruit_names] + def food_items_label_modifier(obj): + return obj.title - FoodStore(title="John's fruits", food_items=food_items).save() + fruit_names = ["banana", "apple", "pear"] - FoodStoreForm = model_form(FoodStore) - form = FoodStoreForm() + food_items = [FoodItem(title=name).save() for name in fruit_names] - assert [obj.label.text for obj in form.food_items] == fruit_names + FoodStore(title="John's fruits", food_items=food_items).save() + FoodStoreForm = model_form(FoodStore) + form = FoodStoreForm() -if __name__ == "__main__": - unittest.main() + assert [obj.label.text for obj in form.food_items] == fruit_names diff --git a/tests/test_json.py b/tests/test_json.py index f8f4900e..4438ca03 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,7 +1,28 @@ import flask +import pytest from flask_mongoengine import MongoEngine -from tests import FlaskMongoEngineTestCase + + +@pytest.fixture() +def extended_db(app): + app.json_encoder = DummyEncoder + app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db" + test_db = MongoEngine(app) + db_name = test_db.connection.get_database("flask_mongoengine_test_db").name + + if not db_name.endswith("_test_db"): + raise RuntimeError( + f"DATABASE_URL must point to testing db, not to master db ({db_name})" + ) + + # Clear database before tests, for cases when some test failed before. + test_db.connection.drop_database(db_name) + + yield test_db + + # Clear database after tests, for graceful exit. + test_db.connection.drop_database(db_name) class DummyEncoder(flask.json.JSONEncoder): @@ -12,29 +33,11 @@ class DummyEncoder(flask.json.JSONEncoder): """ -class JSONAppTestCase(FlaskMongoEngineTestCase): - def dictContains(self, superset, subset): - for k, v in subset.items(): - if not superset[k] == v: - return False - return True - - def assertDictContains(self, superset, subset): - return self.assertTrue(self.dictContains(superset, subset)) - - def setUp(self): - super(JSONAppTestCase, self).setUp() - self.app.config["MONGODB_DB"] = "test_db" - self.app.config["TESTING"] = True - self.app.json_encoder = DummyEncoder - db = MongoEngine() - db.init_app(self.app) - self.db = db - - def test_inheritance(self): - self.assertTrue(issubclass(self.app.json_encoder, DummyEncoder)) - json_encoder_name = self.app.json_encoder.__name__ - - # Since the class is dynamically derrived, must compare class names - # rather than class objects. - self.assertEqual(json_encoder_name, "MongoEngineJSONEncoder") +@pytest.mark.usefixtures("extended_db") +def test_inheritance(app): + assert issubclass(app.json_encoder, DummyEncoder) + json_encoder_name = app.json_encoder.__name__ + + # Since the class is dynamically derrived, must compare class names + # rather than class objects. + assert json_encoder_name == "MongoEngineJSONEncoder" diff --git a/tests/test_json_app.py b/tests/test_json_app.py index ff3eda72..60962a3e 100644 --- a/tests/test_json_app.py +++ b/tests/test_json_app.py @@ -1,84 +1,49 @@ -import datetime - import flask +import pytest from bson import ObjectId -from flask_mongoengine import MongoEngine -from tests import FlaskMongoEngineTestCase - - -class JSONAppTestCase(FlaskMongoEngineTestCase): - def dictContains(self, superset, subset): - for k, v in subset.items(): - if not superset[k] == v: - return False - return True - - def assertDictContains(self, superset, subset): - return self.assertTrue(self.dictContains(superset, subset)) - - def setUp(self): - super(JSONAppTestCase, self).setUp() - self.app.config["MONGODB_DB"] = "test_db" - self.app.config["TESTING"] = True - self.app.config["TEMP_DB"] = True - db = MongoEngine() - - class Todo(db.Document): - title = db.StringField(max_length=60) - text = db.StringField() - done = db.BooleanField(default=False) - pub_date = db.DateTimeField(default=datetime.datetime.now) - db.init_app(self.app) +@pytest.fixture(autouse=True) +def setup_endpoints(app, todo): + Todo = todo - Todo.drop_collection() - self.Todo = Todo + @app.route("/") + def index(): + return flask.jsonify(result=Todo.objects()) - @self.app.route("/") - def index(): - return flask.jsonify(result=self.Todo.objects()) + @app.route("/add", methods=["POST"]) + def add(): + form = flask.request.form + todo = Todo(title=form["title"], text=form["text"]) + todo.save() + return flask.jsonify(result=todo) - @self.app.route("/add", methods=["POST"]) - def add(): - form = flask.request.form - todo = self.Todo(title=form["title"], text=form["text"]) - todo.save() - return flask.jsonify(result=todo) + @app.route("/show//") + def show(id): + return flask.jsonify(result=Todo.objects.get_or_404(id=id)) - @self.app.route("/show//") - def show(id): - return flask.jsonify(result=self.Todo.objects.get_or_404(id=id)) - self.db = db +def test_with_id(app, todo): + Todo = todo + client = app.test_client() + response = client.get("/show/%s/" % ObjectId()) + assert response.status_code == 404 - def test_with_id(self): - c = self.app.test_client() - resp = c.get("/show/%s/" % ObjectId()) - self.assertEqual(resp.status_code, 404) + response = client.post("/add", data={"title": "First Item", "text": "The text"}) + assert response.status_code == 200 - rv = c.post("/add", data={"title": "First Item", "text": "The text"}) - self.assertEqual(rv.status_code, 200) + response = client.get("/show/%s/" % Todo.objects.first().id) + assert response.status_code == 200 - resp = c.get("/show/%s/" % self.Todo.objects.first().id) - self.assertEqual(resp.status_code, 200) - res = flask.json.loads(resp.data).get("result") - self.assertDictContains(res, {"title": "First Item", "text": "The text"}) + result = flask.json.loads(response.data).get("result") + assert ("title", "First Item") in result.items() - def test_basic_insert(self): - c = self.app.test_client() - d1 = {"title": "First Item", "text": "The text"} - d2 = {"title": "2nd Item", "text": "The text"} - c.post("/add", data=d1) - c.post("/add", data=d2) - rv = c.get("/") - result = flask.json.loads(rv.data).get("result") - self.assertEqual(len(result), 2) +def test_basic_insert(app): + client = app.test_client() + client.post("/add", data={"title": "First Item", "text": "The text"}) + client.post("/add", data={"title": "2nd Item", "text": "The text"}) + rv = client.get("/") + result = flask.json.loads(rv.data).get("result") - # ensure each of the objects is one of the two we already - # inserted - for obj in result: - self.assertTrue( - any([self.dictContains(obj, d1), self.dictContains(obj, d2)]) - ) + assert len(result) == 2 diff --git a/tests/test_pagination.py b/tests/test_pagination.py index a390d3d4..605f53a7 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,112 +1,84 @@ -import unittest - +import pytest from werkzeug.exceptions import NotFound -from flask_mongoengine import ListFieldPagination, MongoEngine, Pagination -from tests import FlaskMongoEngineTestCase - - -class PaginationTestCase(FlaskMongoEngineTestCase): - def setUp(self): - super(PaginationTestCase, self).setUp() - self.db_name = "test_db" - self.app.config["MONGODB_DB"] = self.db_name - self.app.config["TESTING"] = True - self.app.config["CSRF_ENABLED"] = False - self.db = MongoEngine() - self.db.init_app(self.app) - - def tearDown(self): - try: - self.db.connection.drop_database(self.db_name) - except Exception: - self.db.connection.client.drop_database(self.db_name) +from flask_mongoengine import ListFieldPagination, Pagination - def test_queryset_paginator(self): - with self.app.test_request_context("/"): - db = self.db - class Post(db.Document): - title = db.StringField(required=True, max_length=200) +def test_queryset_paginator(app, todo): + Todo = todo + for i in range(42): + Todo(title="post: %s" % i).save() - for i in range(42): - Post(title="post: %s" % i).save() + with pytest.raises(NotFound): + Pagination(iterable=Todo.objects, page=0, per_page=10) - self.assertRaises(NotFound, Pagination, Post.objects, 0, 10) - self.assertRaises(NotFound, Pagination, Post.objects, 6, 10) + with pytest.raises(NotFound): + Pagination(iterable=Todo.objects, page=6, per_page=10) - paginator = Pagination(Post.objects, 1, 10) - self._test_paginator(paginator) + paginator = Pagination(Todo.objects, 1, 10) + _test_paginator(paginator) - def test_paginate_plain_list(self): - self.assertRaises(NotFound, Pagination, range(1, 42), 0, 10) - self.assertRaises(NotFound, Pagination, range(1, 42), 6, 10) +def test_paginate_plain_list(): + with pytest.raises(NotFound): + Pagination(iterable=range(1, 42), page=0, per_page=10) - paginator = Pagination(range(1, 42), 1, 10) - self._test_paginator(paginator) + with pytest.raises(NotFound): + Pagination(iterable=range(1, 42), page=6, per_page=10) - def test_list_field_pagination(self): + paginator = Pagination(range(1, 42), 1, 10) + _test_paginator(paginator) - with self.app.test_request_context("/"): - db = self.db - class Post(db.Document): - title = db.StringField(required=True, max_length=200) - comments = db.ListField(db.StringField()) - comment_count = db.IntField() +def test_list_field_pagination(app, todo): + Todo = todo - comments = ["comment: %s" % i for i in range(42)] - post = Post( - title="post has comments", - comments=comments, - comment_count=len(comments), - ).save() + comments = ["comment: %s" % i for i in range(42)] + todo = Todo( + title="todo has comments", comments=comments, comment_count=len(comments), + ).save() - # Check without providing a total - paginator = ListFieldPagination(Post.objects, post.id, "comments", 1, 10) - self._test_paginator(paginator) + # Check without providing a total + paginator = ListFieldPagination(Todo.objects, todo.id, "comments", 1, 10) + _test_paginator(paginator) - # Check with providing a total (saves a query) - paginator = ListFieldPagination( - Post.objects, post.id, "comments", 1, 10, post.comment_count - ) - self._test_paginator(paginator) + # Check with providing a total (saves a query) + paginator = ListFieldPagination( + Todo.objects, todo.id, "comments", 1, 10, todo.comment_count + ) + _test_paginator(paginator) - paginator = post.paginate_field("comments", 1, 10) - self._test_paginator(paginator) + paginator = todo.paginate_field("comments", 1, 10) + _test_paginator(paginator) - def _test_paginator(self, paginator): - self.assertEqual(5, paginator.pages) - self.assertEqual([1, 2, 3, 4, 5], list(paginator.iter_pages())) - for i in [1, 2, 3, 4, 5]: +def _test_paginator(paginator): + assert 5 == paginator.pages + assert [1, 2, 3, 4, 5] == list(paginator.iter_pages()) - if i == 1: - self.assertRaises(NotFound, paginator.prev) - self.assertFalse(paginator.has_prev) - else: - self.assertTrue(paginator.has_prev) + for i in [1, 2, 3, 4, 5]: - if i == 5: - self.assertRaises(NotFound, paginator.next) - self.assertFalse(paginator.has_next) - else: - self.assertTrue(paginator.has_next) + if i == 1: + assert not paginator.has_prev + with pytest.raises(NotFound): + paginator.prev() + else: + assert paginator.has_prev - if i == 3: - self.assertEqual( - [None, 2, 3, 4, None], list(paginator.iter_pages(0, 1, 1, 0)) - ) + if i == 5: + assert not paginator.has_next + with pytest.raises(NotFound): + paginator.next() + else: + assert paginator.has_next - self.assertEqual(i, paginator.page) - self.assertEqual(i - 1, paginator.prev_num) - self.assertEqual(i + 1, paginator.next_num) - - # Paginate to the next page - if i < 5: - paginator = paginator.next() + if i == 3: + assert [None, 2, 3, 4, None] == list(paginator.iter_pages(0, 1, 1, 0)) + assert i == paginator.page + assert i - 1 == paginator.prev_num + assert i + 1 == paginator.next_num -if __name__ == "__main__": - unittest.main() + # Paginate to the next page + if i < 5: + paginator = paginator.next() diff --git a/tests/test_session.py b/tests/test_session.py index 51c0e474..2faa0abd 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,56 +1,40 @@ -import unittest - +import pytest from flask import session -from flask_mongoengine import MongoEngine, MongoEngineSessionInterface -from tests import FlaskMongoEngineTestCase - +from flask_mongoengine import MongoEngineSessionInterface -class SessionTestCase(FlaskMongoEngineTestCase): - def setUp(self): - super(SessionTestCase, self).setUp() - self.db_name = "test_db" - self.app.config["MONGODB_DB"] = self.db_name - self.app.config["TESTING"] = True - db = MongoEngine(self.app) - self.app.session_interface = MongoEngineSessionInterface(db) - @self.app.route("/") - def index(): - session["a"] = "hello session" - return session["a"] +@pytest.fixture(autouse=True) +def setup_endpoints(app, db): - @self.app.route("/check-session") - def check_session(): - return "session: %s" % session["a"] + app.session_interface = MongoEngineSessionInterface(db) - @self.app.route("/check-session-database") - def check_session_database(): - sessions = self.app.session_interface.cls.objects.count() - return "sessions: %s" % sessions + @app.route("/") + def index(): + session["a"] = "hello session" + return session["a"] - self.db = db + @app.route("/check-session") + def check_session(): + return "session: %s" % session["a"] - def tearDown(self): - try: - self.db.connection.drop_database(self.db_name) - except Exception: - self.db.connection.client.drop_database(self.db_name) + @app.route("/check-session-database") + def check_session_database(): + sessions = app.session_interface.cls.objects.count() + return "sessions: %s" % sessions - def test_setting_session(self): - c = self.app.test_client() - resp = c.get("/") - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.data.decode("utf-8"), "hello session") - resp = c.get("/check-session") - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.data.decode("utf-8"), "session: hello session") +def test_setting_session(app): + client = app.test_client() - resp = c.get("/check-session-database") - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.data.decode("utf-8"), "sessions: 1") + response = client.get("/") + assert response.status_code == 200 + assert response.data.decode("utf-8") == "hello session" + response = client.get("/check-session") + assert response.status_code == 200 + assert response.data.decode("utf-8") == "session: hello session" -if __name__ == "__main__": - unittest.main() + response = client.get("/check-session-database") + assert response.status_code == 200 + assert response.data.decode("utf-8") == "sessions: 1" diff --git a/tox.ini b/tox.ini index 6a3a3327..764c28c0 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,6 @@ deps = PyMongo>3.9.0 pytest pytest-cov - nose [testenv:lint] deps =