diff --git a/.gitignore b/.gitignore index 5473d972..7cdd6cde 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,9 @@ dmypy.json # Pyre type checker .pyre/ + +# VScode +.vscode/ +app/.vscode/ + app/routers/stam diff --git a/app/config.py.example b/app/config.py.example index 57cfa211..924b4ed6 100644 --- a/app/config.py.example +++ b/app/config.py.example @@ -1,7 +1,6 @@ import os from fastapi_mail import ConnectionConfig - # flake8: noqa # general @@ -9,6 +8,8 @@ DOMAIN = 'Our-Domain' # DATABASE DEVELOPMENT_DATABASE_STRING = "sqlite:///./dev.db" +# Set the following True if working on PSQL environment or set False otherwise +PSQL_ENVIRONMENT = False # MEDIA MEDIA_DIRECTORY = 'media' diff --git a/app/database/database.py b/app/database/database.py index 631a3593..c0544c0c 100644 --- a/app/database/database.py +++ b/app/database/database.py @@ -6,12 +6,20 @@ from app import config + SQLALCHEMY_DATABASE_URL = os.getenv( "DATABASE_CONNECTION_STRING", config.DEVELOPMENT_DATABASE_STRING) -engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} -) + +def create_env_engine(psql_environment, sqlalchemy_database_url): + if not psql_environment: + return create_engine( + sqlalchemy_database_url, connect_args={"check_same_thread": False}) + + return create_engine(sqlalchemy_database_url) + + +engine = create_env_engine(config.PSQL_ENVIRONMENT, SQLALCHEMY_DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() diff --git a/app/database/models.py b/app/database/models.py index bc3025ba..91fc16b2 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,9 +1,11 @@ from datetime import datetime -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import relationship - +from app.config import PSQL_ENVIRONMENT from app.database.database import Base +from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index, + Integer, String, event) +from sqlalchemy.dialects.postgresql import TSVECTOR +from sqlalchemy.orm import relationship class UserEvent(Base): @@ -54,10 +56,39 @@ class Event(Base): participants = relationship("UserEvent", back_populates="events") + # PostgreSQL + if PSQL_ENVIRONMENT: + events_tsv = Column(TSVECTOR) + __table_args__ = (Index( + 'events_tsv_idx', + 'events_tsv', + postgresql_using='gin'), + ) + def __repr__(self): return f'' +class PSQLEnvironmentError(Exception): + pass + + +# PostgreSQL +if PSQL_ENVIRONMENT: + trigger_snippet = DDL(""" + CREATE TRIGGER ix_events_tsv_update BEFORE INSERT OR UPDATE + ON events + FOR EACH ROW EXECUTE PROCEDURE + tsvector_update_trigger(events_tsv,'pg_catalog.english','title','content') + """) + + event.listen( + Event.__table__, + 'after_create', + trigger_snippet.execute_if(dialect='postgresql') + ) + + class Invitation(Base): __tablename__ = "invitations" diff --git a/app/internal/search.py b/app/internal/search.py new file mode 100644 index 00000000..47a83d4f --- /dev/null +++ b/app/internal/search.py @@ -0,0 +1,42 @@ +from typing import List + +from app.database.database import SessionLocal +from app.database.models import Event +from sqlalchemy.exc import SQLAlchemyError + + +def get_stripped_keywords(keywords: str) -> str: + '''Gets a string of keywords to search for from the user form + and returns a stripped ready-to-db-search keywords string''' + + keywords = " ".join(keywords.split()) + keywords = keywords.replace(" ", ":* & ") + ":*" + return keywords + + +def get_results_by_keywords( + session: SessionLocal, + keywords: str, + owner_id: int + ) -> List[Event]: + """Returns possible results for a search in the 'events' database table + + Args: + keywords (str): search string + owner_id (int): current user id + + Returns: + list: a list of events from the database matching the inserted keywords + + Uses PostgreSQL's built in 'Full-text search' feature + (doesn't work with SQLite)""" + + keywords = get_stripped_keywords(keywords) + + try: + return session.query(Event).filter( + Event.owner_id == owner_id, + Event.events_tsv.match(keywords)).all() + + except (SQLAlchemyError, AttributeError): + return [] diff --git a/app/main.py b/app/main.py index 5247976f..796ade3e 100644 --- a/app/main.py +++ b/app/main.py @@ -1,15 +1,27 @@ from fastapi import FastAPI, Request from fastapi.staticfiles import StaticFiles +from app.config import PSQL_ENVIRONMENT from app.database import models from app.database.database import engine from app.dependencies import ( MEDIA_PATH, STATIC_PATH, templates) -from app.routers import agenda, dayview, event, profile, email, invitation +from app.routers import (agenda, dayview, email, event, invitation, profile, + search) -models.Base.metadata.create_all(bind=engine) +def create_tables(engine, psql_environment): + if 'sqlite' in str(engine.url) and psql_environment: + raise models.PSQLEnvironmentError( + "You're trying to use PSQL features on SQLite env.\n" + "Please set app.config.PSQL_ENVIRONMENT to False " + "and run the app again." + ) + else: + models.Base.metadata.create_all(bind=engine) + +create_tables(engine, PSQL_ENVIRONMENT) app = FastAPI() app.mount("/static", StaticFiles(directory=STATIC_PATH), name="static") app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") @@ -20,6 +32,7 @@ app.include_router(dayview.router) app.include_router(email.router) app.include_router(invitation.router) +app.include_router(search.router) @app.get("/") diff --git a/app/routers/search.py b/app/routers/search.py new file mode 100644 index 00000000..60fee140 --- /dev/null +++ b/app/routers/search.py @@ -0,0 +1,48 @@ +from app.database.database import get_db +from app.dependencies import templates +from app.internal.search import get_results_by_keywords +from fastapi import APIRouter, Depends, Form, Request +from sqlalchemy.orm import Session + + +router = APIRouter() + + +@router.get("/search") +def search(request: Request): + # Made up user details until there's a user login system + current_username = "Chuck Norris" + + return templates.TemplateResponse("search.html", { + "request": request, + "username": current_username + }) + + +@router.post("/search") +async def show_results( + request: Request, + keywords: str = Form(None), + db: Session = Depends(get_db)): + # Made up user details until there's a user login system + current_username = "Chuck Norris" + current_user = 1 + + message = "" + + if not keywords: + message = "Invalid request." + results = None + else: + results = get_results_by_keywords(db, keywords, owner_id=current_user) + if not results: + message = f"No matching results for '{keywords}'." + + return templates.TemplateResponse("search.html", { + "request": request, + "username": current_username, + "message": message, + "results": results, + "keywords": keywords + } + ) diff --git a/app/templates/base.html b/app/templates/base.html index 97360bd6..767f59cc 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -42,6 +42,9 @@ + diff --git a/app/templates/profile.html b/app/templates/profile.html index 9a0ddbda..aecb5164 100644 --- a/app/templates/profile.html +++ b/app/templates/profile.html @@ -246,4 +246,4 @@
{{ user.full_name }}
-{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/app/templates/search.html b/app/templates/search.html new file mode 100644 index 00000000..af5eb5b4 --- /dev/null +++ b/app/templates/search.html @@ -0,0 +1,69 @@ +{% extends "base.html" %} + + +{% block content %} + +
+

Hello, {{ username }}

+
+
+
+ + +
+
+ +
+ +
+ +{% if message %} +
+ {{ message }} +
+{% endif %} + + +{% if results %} +
+
+ Showing results for '{{ keywords }}': +
+ +
+
+ {% for result in results %} + +
+
+ + {{ loop.index }}. {{ result.title }} + +
+
+

+ {{ result.content }} +

+
+ {{ result.date }} +
+
+ + {% endfor %} +
+
+ {% endif %} + + + + + {% endblock %} \ No newline at end of file diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/requirements.txt b/requirements.txt index ba9af619..97a8e2bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,7 @@ packaging==20.8 Pillow==8.1.0 pluggy==0.13.1 priority==1.3.0 +psycopg2==2.8.6 py==1.10.0 pydantic==1.7.3 pyparsing==2.4.7 diff --git a/tests/conftest.py b/tests/conftest.py index a838375d..a5e0defd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import pytest +from app.config import PSQL_ENVIRONMENT +from app.database.database import Base from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from app.database.database import Base pytest_plugins = [ 'tests.user_fixture', @@ -13,11 +14,25 @@ 'smtpdfix', ] -SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///./test.db" +# When testing in a PostgreSQL environment please make sure that: +# - Base string is a PSQL string +# - app.config.PSQL_ENVIRONMENT is set to True + +if PSQL_ENVIRONMENT: + SQLALCHEMY_TEST_DATABASE_URL = ( + "postgresql://postgres:1234" + "@localhost/postgres" + ) + test_engine = create_engine( + SQLALCHEMY_TEST_DATABASE_URL + ) + +else: + SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///./test.db" + test_engine = create_engine( + SQLALCHEMY_TEST_DATABASE_URL, connect_args={"check_same_thread": False} + ) -test_engine = create_engine( - SQLALCHEMY_TEST_DATABASE_URL, connect_args={"check_same_thread": False} -) TestingSessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=test_engine) @@ -33,3 +48,19 @@ def session(): yield session session.close() Base.metadata.drop_all(bind=test_engine) + + +@pytest.fixture +def sqlite_engine(): + SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///./test.db" + sqlite_test_engine = create_engine( + SQLALCHEMY_TEST_DATABASE_URL, connect_args={"check_same_thread": False} + ) + + TestingSession = sessionmaker( + autocommit=False, autoflush=False, bind=sqlite_test_engine) + + yield sqlite_test_engine + session = TestingSession() + session.close() + Base.metadata.drop_all(bind=sqlite_test_engine) diff --git a/tests/test_psql_environment.py b/tests/test_psql_environment.py new file mode 100644 index 00000000..bc2aeab4 --- /dev/null +++ b/tests/test_psql_environment.py @@ -0,0 +1,21 @@ +import pytest +from app.database.database import create_env_engine +from app.database.models import PSQLEnvironmentError +from app.main import create_tables + + +def test_main_create_tables_error(sqlite_engine): + raised_error = False + with pytest.raises(PSQLEnvironmentError): + create_tables(sqlite_engine, True) + raised_error = True + assert raised_error + + +def test_database_create_engine(): + sqlalchemy_database_url = "postgresql://postgres:1234@localhost/postgres" + engine = create_env_engine(True, sqlalchemy_database_url) + assert 'postgres' in str(engine.url) + sqlalchemy_database_url = "sqlite:///./test1.db" + engine = create_env_engine(False, sqlalchemy_database_url) + assert 'sqlite' in str(engine.url) diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 00000000..6a62f844 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,172 @@ +from datetime import datetime + +import pytest +from app.config import PSQL_ENVIRONMENT +from app.database.models import Event, User +from app.internal.search import get_results_by_keywords, get_stripped_keywords +from fastapi import status + + +class TestSearch: + '''Search feature test. Works with PostgreSQL''' + + SEARCH = '/search' + GOOD_KEYWORDS = [ + ({'keywords': 'lov'}, b'test'), + ({'keywords': 'very emotional'}, b'second event'), + ({'keywords': 'event'}, b'My second event'), + ({'keywords': 'event'}, b'My first event'), + ({'keywords': 'jam'}, b'is fun'), + ({'keywords': ' jam '}, b'is fun') + ] + BAD_KEYWORDS = [ + ({'keywords': ''}, b'Invalid'), + ({'keywords': 'ev!@&'}, b'No matching'), + ({'keywords': '[]'}, b'No matching'), + ({'keywords': 'firsttttt'}, b'No matching') + ] + NOT_PSQL_ENV_KEYWORDS = [ + ({'keywords': 'lov'}, b'No matching'), + ({'keywords': 'very emotional'}, b'No matching'), + ({'keywords': 'event'}, b'No matching'), + ({'keywords': 'jam'}, b'No matching'), + ({'keywords': ' jam '}, b'No matching'), + ({'keywords': ''}, b'Invalid') + ] + KEYWORDS_FOR_FUNC = [ + 'lov', + 'very emotional', + 'event', + 'jam', + ' jam ' + ] + + @staticmethod + def create_user(session): + user = User(username='testuser', email='test@abc.com', password='1234') + session.add(user) + session.commit() + return user + + @staticmethod + def add_event(session, title, content, owner_id): + event = Event( + title=title, + content=content, + start=datetime.today(), + end=datetime.today(), + owner_id=owner_id + ) + + session.add(event) + session.commit() + + @staticmethod + def create_data(session): + TestSearch.create_user(session) + events = [ + { + 'title': "My first event", + 'content': 'I am so excited', + 'owner_id': 1 + }, + { + 'title': "My second event", + 'content': 'I am very emotional', + 'owner_id': 1 + }, + { + 'title': "Pick up my nephews", + 'content': 'Very important', + 'owner_id': 1 + }, + { + 'title': "Solve this ticket", + 'content': 'I can do this', + 'owner_id': 1 + }, + { + 'title': "Jam with my friends", + 'content': "Jamming is fun", + 'owner_id': 1 + }, + { + 'title': 'test', + 'content': 'love string', + 'owner_id': 1 + } + ] + + for event in events: + TestSearch.add_event(session, + title=event['title'], + content=event['content'], + owner_id=event['owner_id']) + + @staticmethod + def test_search_page_exists(client): + resp = client.get(TestSearch.SEARCH) + assert resp.status_code == status.HTTP_200_OK + assert b'Search event by keyword' in resp.content + + +@pytest.mark.skipif(not PSQL_ENVIRONMENT, reason="Not PSQL environment") +@pytest.mark.parametrize('data, string', TestSearch.GOOD_KEYWORDS) +def test_search_good_keywords(data, string, client, session): + ts = TestSearch() + ts.create_data(session) + resp = client.post(ts.SEARCH, data=data) + assert string in resp.content + + +@pytest.mark.skipif(not PSQL_ENVIRONMENT, reason="Not PSQL environment") +@pytest.mark.parametrize('data, string', TestSearch.BAD_KEYWORDS) +def test_search_bad_keywords(data, string, client, session): + ts = TestSearch() + ts.create_data(session) + resp = client.post(ts.SEARCH, data=data) + assert string in resp.content + + +@pytest.mark.skipif(PSQL_ENVIRONMENT, reason="PSQL environment") +@pytest.mark.parametrize('data, string', TestSearch.NOT_PSQL_ENV_KEYWORDS) +def test_search_not_psql_env_keywords(data, string, client, session): + ts = TestSearch() + ts.create_data(session) + resp = client.post(ts.SEARCH, data=data) + assert string in resp.content + + +@pytest.mark.skipif(PSQL_ENVIRONMENT, reason="Not PSQL environment") +@pytest.mark.parametrize('input_string', TestSearch.KEYWORDS_FOR_FUNC) +def test_get_results_by_keywords_func(input_string, client, session): + ts = TestSearch() + ts.create_data(session) + assert not get_results_by_keywords(session, input_string, 1) + + +STRIPPED_KEYWORDS = [ + (" love string ", "love:* & string:*"), + ("test ", "test:*"), + ("i am awesome", "i:* & am:* & awesome:*"), + ("a lot of spaces", "a:* & lot:* & of:* & spaces:*") +] + + +@pytest.mark.parametrize('input_string, output_string', STRIPPED_KEYWORDS) +def test_search_stripped_keywords(input_string, output_string): + assert get_stripped_keywords(input_string) == output_string + + +def test_events_tsv_column_exists(): + column_created = True + + try: + Event.events_tsv + except AttributeError: + column_created = False + + if PSQL_ENVIRONMENT: + assert column_created + else: + assert not column_created