diff --git a/imoog/__init__.py b/imoog/__init__.py
index 7eaebcc..665e2cd 100644
--- a/imoog/__init__.py
+++ b/imoog/__init__.py
@@ -1 +1 @@
-# dummy __init__
\ No newline at end of file
+# dummy __init__
diff --git a/imoog/app.py b/imoog/app.py
index 6df7d1c..32828fd 100644
--- a/imoog/app.py
+++ b/imoog/app.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import importlib
-from typing import Tuple
+from typing import Tuple, Type
 
 from starlette.applications import Starlette
-from starlette.routing import Route, Mount
+from starlette.routing import Route
 from starlette.middleware.cors import CORSMiddleware
 from starlette.middleware.trustedhost import TrustedHostMiddleware
 from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
@@ -15,7 +15,7 @@
     delete_file
 )
 
 from imoog import settings
-
+from imoog.database.drivers import Driver
 routes = [
     Route(
@@ -36,7 +36,7 @@
 ]
 
 app = Starlette(routes=routes)
-app.add_middleware(CORSMiddleware, allow_origins = settings.CORS_ALLOWED_ORIGINS)
+app.add_middleware(CORSMiddleware, allow_origins=settings.CORS_ALLOWED_ORIGINS)
 app.add_middleware(TrustedHostMiddleware, allowed_hosts=settings.ALLOWED_HOSTS)
 
 if settings.ENFORCE_SECURE_SCHEME is True:
@@ -45,33 +45,45 @@
 app.image_cache = None
 app.db_driver = None
 
+
 def _check_driver() -> Tuple[type, str]:
     _driver_path = settings.DATABASE_DRIVERS["driver"]
     package = importlib.import_module(_driver_path)
-    driver = package._DRIVER
-    _type = package._DRIVER_TYPE
-    return (driver, _type)
+
+    # noinspection PyUnresolvedReferences
+    driver: Type[Driver] = package.DB_DRIVER
+    # noinspection PyUnresolvedReferences
+    driver_type: str = package.DB_DRIVER_TYPE
+
+    return driver, driver_type
+
 
 def _check_cache_driver() -> Tuple[type, str]:
-    _driver_path = settings.CACHE_DRIVERS["driver"]
+    _driver_path: str = settings.CACHE_DRIVERS.get("driver")
     package = importlib.import_module(_driver_path)
-    driver = package._DRIVER
-    _type = package._DRIVER_TYPE
-    return (driver, _type)
-
+
+    # noinspection PyUnresolvedReferences
+    driver = package.C_DRIVER
+    # noinspection PyUnresolvedReferences
+    _type = package.C_DRIVER_TYPE
+
+    return driver, _type
+
 
 @app.on_event("startup")
 async def on_startup():
     # connect to databases and ready caches.
     driver_class, _ = _check_driver()
     config = settings.DATABASE_DRIVERS["config"]
+
     driver = driver_class()
     await driver.connect(**config)
+
     cache_driver_class, _ = _check_cache_driver()
     cache_config = settings.CACHE_DRIVERS["config"]
     cache_driver = cache_driver_class()
-    cache_config["max_cache_size"] = settings.MAX_CACHE_SIZE # we pass this into
+    cache_config["max_cache_size"] = settings.MAX_CACHE_SIZE  # we pass this into
     # the connect function of the cache driver regardless whether its the
     # memory cache driver or not.
     await cache_driver.connect(**cache_config)
 
@@ -82,9 +94,11 @@ async def on_startup():
 
 @app.on_event("shutdown")
 async def on_shutdown():
+    # noinspection PyBroadException
     try:
         await app.db_driver.cleanup()
-        await app.cache_driver.cleanup()
+        await app.image_cache.cleanup()
+
     except Exception:
         # disregard any errors that occur
         # within the driver cleanup
diff --git a/imoog/cache/basecache.py b/imoog/cache/basecache.py
index a37c1d5..37206ac 100644
--- a/imoog/cache/basecache.py
+++ b/imoog/cache/basecache.py
@@ -1,7 +1,8 @@
 class Cache:
     # a base cache handler
+
     def __init__(self):
-        self._connection = None # this is similar
+        self._connection = None  # this is similar
         # to the database handlers, except this can also just
         # be a regular dictionary
 
diff --git a/imoog/cache/memorycache.py b/imoog/cache/memorycache.py
index eceb451..01b2045 100644
--- a/imoog/cache/memorycache.py
+++ b/imoog/cache/memorycache.py
@@ -1,15 +1,17 @@
-from typing import Dict
+from typing import Dict, Union
 
 from .basecache import Cache
 
 
 class InMemoryCache(Cache):
+    _max_size: int
+
     async def connect(self, **kwargs):
         # to keep it consistent
         # we will indeed have an async function
         # to just register a dict
-        self._connection: Dict[str, bytes] = {}
-        max_size = kwargs["max_cache_size"]
+        self._connection: Dict[str, Union[bytes, str]] = {}
+        max_size = kwargs.get("max_cache_size")
         self._max_size = max_size
         return self._connection
 
@@ -17,11 +19,12 @@ async def connect(self, **kwargs):
     async def get(self, key: str):
         mime = self._connection.get(key + "__mime__")
         image = self._connection.get(key)
-        R = (image, mime)
-        if None in R:
+        r = (image, mime)
+
+        if None in r:
             return None
-
-        return R
+
+        return r
 
     async def delete(self, key: str) -> bool:
         try:
@@ -29,25 +32,27 @@ async def delete(self, key: str) -> bool:
             self._connection.pop(key + "__mime__")
             # if the first pop fails, the second one will
             # fail too as its dependent on the first one to exist
+
         except KeyError:
-            return False # failed
+            return False  # failed
 
-        return True # success
+        return True  # success
 
     async def set(
-        self,
-        key: str,
-        image: bytes,
-        mime: str
+            self,
+            key: str,
+            image: bytes,
+            mime: str
     ):
         self._connection[key] = image
         self._connection[key + "__mime__"] = mime
-
+
     async def cleanup(self):
         # believe it or not
         # there is a cleanup operation for a dictionary
         # we'll free the memory by clearing the dictionary
         self._connection.clear()
 
-_DRIVER = InMemoryCache
-_DRIVER_TYPE = "MEMORY"
\ No newline at end of file
+
+C_DRIVER = InMemoryCache
+C_DRIVER_TYPE = "MEMORY"
diff --git a/imoog/cache/rediscache.py b/imoog/cache/rediscache.py
index e0b7604..5d2c377 100644
--- a/imoog/cache/rediscache.py
+++ b/imoog/cache/rediscache.py
@@ -25,12 +25,12 @@ async def get(self, image: str):
             self._connection.get(image),
             self._connection.get(image + "__mime__")
         ]
-        R = await asyncio.gather(*futures)
+        r = await asyncio.gather(*futures)
 
-        if None in R:
+        if None in r:
             return None
 
-        return tuple(R)
+        return tuple(r)
 
     async def delete(self, image: str) -> bool:
         futures = [
@@ -38,10 +38,13 @@ async def delete(self, image: str) -> bool:
             self._connection.delete(image + "__mime__")
         ]
 
+        # noinspection PyBroadException
         try:
             await asyncio.gather(*futures)
+
         except Exception:
             return False
+
         else:
             return True
 
@@ -60,5 +63,6 @@ async def set(
     async def cleanup(self):
         return await self._connection.close()
 
-_DRIVER = RedisCache
-_DRIVER_TYPE = "REDIS"
+
+C_DRIVER = RedisCache
+C_DRIVER_TYPE = "REDIS"
diff --git a/imoog/database/drivers.py b/imoog/database/drivers.py
index 1bf2623..01061bc 100644
--- a/imoog/database/drivers.py
+++ b/imoog/database/drivers.py
@@ -1,10 +1,8 @@
 import zlib
-from typing import Any
-
 
 class Driver:
     def __init__(self):
-        self._connection = None # this connection instance
+        self.pool = None  # this connection instance
         # is filled in the connect method.
         self.identifier = None # this is custom per database driver.
         # this attribute will be None until the connect method is called.
diff --git a/imoog/database/mongo.py b/imoog/database/mongo.py
index 873938c..155a670 100644
--- a/imoog/database/mongo.py
+++ b/imoog/database/mongo.py
@@ -12,10 +12,14 @@
     AsyncIOMotorClient
 )
 
-from .drivers import Driver # import the base driver impl
+from .drivers import Driver  # import the base driver impl
 
 
 class MongoDriver(Driver):
+    TABLE_NAME: str
+    _connection: AsyncIOMotorCollection
+    _parent_client: AsyncIOMotorClient
+
     async def connect(self, **kwargs):
         self.identifier = "mongo"
 
@@ -37,14 +41,14 @@ async def insert(
         image: bytes,
         name: str,
         mime: str
-    ):
+    ) -> int:
         insert = {
             "_id": name,
             "image": image,
             "mime": mime
         }
         await self._connection.insert_one(insert)
-        return 0
+        return 0  # again, why?
 
     async def fetch(
         self,
@@ -52,10 +56,12 @@ async def fetch(
     ) -> Tuple[bytes, str]:
         query = {"_id": name}
         result = await self._connection.find_one(query)
+
         image = result["image"]
         mime = result["mime"]
         decompressed = self.decompress(image)
-        return (decompressed, mime)
+
+        return decompressed, mime
 
     async def delete(
         self,
@@ -63,19 +69,23 @@ async def delete(
     ) -> bool:
         # Returns whether the delete succeeded or failed
         query = {"_id": name}
+        # noinspection PyBroadException
         try:
             await self._connection.delete_one(query)
+
         except Exception:
             return False
+
         else:
             return True
 
     async def fetch_all(self) -> Tuple[List[Mapping[str, Any]], str]:
-        documents = await self._connection.find({}).to_list(length=99999999999999999999) # big number
-        return (documents, "_id")
+        documents = await self._connection.find({}).to_list(length=99999999999999999999)  # big number
+        return documents, "_id"
 
     async def cleanup(self):
         return self._parent_client.close()
 
-_DRIVER = MongoDriver
-_DRIVER_TYPE = "MONGO"
\ No newline at end of file
+
+DB_DRIVER = MongoDriver
+DB_DRIVER_TYPE = "MONGO"
diff --git a/imoog/database/postgres.py b/imoog/database/postgres.py
index 0861783..14fa1e6 100644
--- a/imoog/database/postgres.py
+++ b/imoog/database/postgres.py
@@ -4,47 +4,73 @@
     Tuple,
     Mapping,
     List,
-    Any
+    Any,
+    Dict,
+    TypeVar,
+    Iterable,
+    Optional,
+    Union
 )
 
 import asyncpg
 
-from .drivers import Driver # import the base driver impl
+from .drivers import Driver  # import the base driver impl
+
+
+T = TypeVar("T")
 
 
 class PostgresDriver(Driver):
+    TABLE_NAME: str
+    db_pool: asyncpg.Pool
+
+    @staticmethod
+    def save_unpack(values: Optional[Iterable[T]], value_count: int = 0) -> Union[Tuple[T], Tuple[None]]:
+        # pad (or pass through) a row so callers can always unpack value_count items
+        if values is None:
+            return (None, )*value_count
+
+        values = tuple(values)
+
+        if len(values) < value_count:
+            return values + (None, )*(value_count - len(values))
+
+        return values
+
     async def connect(self, **kwargs):
         self.identifier = "postgres"
-        connection_uri = kwargs["connection_uri"]
-        max_size = kwargs["max_size"]
-        min_size = kwargs["min_size"]
-        table_name = kwargs["table_name"]
+        connection_uri: str = kwargs.get("connection_uri")
+        max_size: int = kwargs.get("max_size")
+        min_size: int = kwargs.get("min_size")
+        table_name: str = kwargs.get("table_name")
+        other_kwargs: Dict[str, Any] = kwargs.get("kwargs", {})
 
-        self._table_name: str = table_name
+        self.TABLE_NAME: str = table_name
 
         pool = await asyncpg.create_pool(
             connection_uri,
             min_size=min_size,
-            max_size=max_size
+            max_size=max_size,
+            **other_kwargs
         )
-        self._connection = pool
+        self.db_pool = pool
 
         # Creating the table in psql on connect
         # if it doesn't exist.
-        async with self._connection.acquire() as conn:
+        async with self.db_pool.acquire() as cursor:
             query = (
                 f"CREATE TABLE IF NOT EXISTS {table_name}("
-                "name TEXT PRIMARY KEY,"
-                "image BYTEA,"
-                "mime TEXT"
-                ")"
+                "name TEXT PRIMARY KEY NOT NULL UNIQUE,"
+                "image BYTEA NOT NULL,"
+                "mime TEXT NOT NULL"
+                ");"
             )
-            await conn.execute(query)
+            await cursor.execute(query)
 
-        return self._connection
+        return self.db_pool
 
     async def insert(
         self,
@@ -52,62 +78,60 @@ async def insert(
         name: str,
         mime: str
     ):
-        table_name = self._table_name
-        async with self._connection.acquire() as conn:
+        async with self.db_pool.acquire() as cursor:
             query = (
-                f"INSERT INTO {table_name} (name, image, mime) VALUES ($1, $2, $3)"
-            ) # this isn't vulnerable to SQL injection, as we have HARD-CODED values
+                f"INSERT INTO {self.TABLE_NAME} (name, image, mime) VALUES ($1, $2, $3);"
+            )  # this isn't vulnerable to SQL injection, as we have HARD-CODED values
             # controlled by YOU. So if you mess up, this isn't on us.
-            await conn.execute(query, name, image, mime)
+            await cursor.execute(query, name, image, mime)
 
-        return 0
+        return 0  # why?
 
     async def fetch(
         self,
         name: str
     ) -> Tuple[bytes, str]:
-        table_name = self._table_name
-
-        async with self._connection.acquire() as conn:
+        async with self.db_pool.acquire() as cursor:
             query = (
-                f"SELECT image, mime FROM {table_name} "
+                f"SELECT image, mime FROM {self.TABLE_NAME} "
                 "WHERE name = $1"
             )
-            row = await conn.fetchrow(query, name)
-
-            image = row["image"]
-            mime = row["mime"]
+            row: Tuple[bytes, str] = await cursor.fetchrow(query, name)
+
+            image, mime = self.save_unpack(row, 2)
             decompressed = self.decompress(image)
-            return (decompressed, mime)
+
+            return decompressed, mime
 
     async def delete(
         self,
         name: str
     ) -> bool:
-        table_name = self._table_name
-
+        # noinspection PyBroadException
         try:
-            async with self._connection.acquire() as conn:
+            async with self.db_pool.acquire() as conn:
                 query = (
-                    f"DELETE FROM {table_name} "
+                    f"DELETE FROM {self.TABLE_NAME} "
                     "WHERE name = $1"
                 )
                 await conn.execute(query, name)
-        except Exception:
+
+        except Exception:
             return False
+
         else:
             return True
 
     async def fetch_all(self) -> List[Mapping[str, Any]]:
-        table_name = self._table_name
+        async with self.db_pool.acquire() as conn:
+            rows: List[Tuple] = await conn.fetch(f"SELECT * FROM {self.TABLE_NAME}")
 
-        async with self._connection.acquire() as conn:
-            rows = await conn.fetch(f"SELECT * FROM {table_name}")
-
-        return (rows, "name")
+        # noinspection PyTypeChecker
+        return rows, "name"  # why return this?
 
     async def cleanup(self):
-        return await self._connection.close()
+        return await self.db_pool.close()
+
 
-_DRIVER = PostgresDriver
-_DRIVER_TYPE = "POSTGRES"
+DB_DRIVER = PostgresDriver
+DB_DRIVER_TYPE = "POSTGRES"
diff --git a/imoog/opengraph.py b/imoog/opengraph.py
index c0b8f35..9e4db59 100644
--- a/imoog/opengraph.py
+++ b/imoog/opengraph.py
@@ -3,11 +3,11 @@
 BASE_TAG = r'<meta property="{property}" content="{content}">'
 
 
-def generate_opengraph_tag(property: str, content: str) -> str:
-    property = "og:" + property
+def generate_opengraph_tag(prop: str, content: str) -> str:
+    prop = "og:" + prop
     tag = BASE_TAG.format(
-        property=property,
+        property=prop,
         content=content
     )
 
@@ -16,8 +16,8 @@ def generate_opengraph_tag(property: str, content: str) -> str:
 
 def generate_tags_from_dict(properties: Dict[str, str]) -> List[str]:
     tags = []
-    for property, content in properties.items():
-        tag = generate_opengraph_tag(property, content)
+    for prop, content in properties.items():
+        tag = generate_opengraph_tag(prop, content)
         tags.append(tag)
 
-    return tags
\ No newline at end of file
+    return tags
diff --git a/imoog/views.py b/imoog/views.py
index 4484bad..f52e663 100644
--- a/imoog/views.py
+++ b/imoog/views.py
@@ -43,23 +43,28 @@ async def upload_file(request: Request) -> JSONResponse:
                 "message": "Invalid auth."
             }
             return JSONResponse(content, status_code=401)
-
+
+    # noinspection PyTypeChecker
     form: MultiDict = await request.form()
     _file: UploadFile = form.get("file")
+
     if _file is None:
         print("\033[91mReceived an upload request that has not given us a 'file' to upload\033[0m")
         content = {
             "message": "No 'file' given."
         }
         return JSONResponse(content, status_code=400)
-
+
     await _file.seek(0)
     image = await _file.read()
     mime = _file.content_type
     compressed = zlib.compress(image, level=int(COMPRESSION_LEVEL))
-    name = ''.join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(FILE_NAME_LENGTH))
+    name = ''.join(
+        random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(FILE_NAME_LENGTH))
+
     await request.app.db_driver.insert(image=compressed, name=name, mime=mime)
-    await request.app.image_cache.set(key=name, image=image, mime=mime) # we insert the UNCOMPRESSED image into the cache, to avoid
+    await request.app.image_cache.set(key=name, image=image,
+                                      mime=mime)  # we insert the UNCOMPRESSED image into the cache, to avoid
     # having to decompress later. The whole point of the cache is to retrieve the value without any
     # extra processing required.
 
@@ -73,11 +78,12 @@ async def upload_file(request: Request) -> JSONResponse:
         "file_id": name,
         "file_ext": file_ext
     }
-    return JSONResponse(content, status_code=200)
+    return JSONResponse(content)
+
 
 async def deliver_file(request: Request) -> Response:
     file_id: str = request.path_params["name"]
-    file_id = file_id.split(".")[0] # if a file extension has been provided, we split on the '.',
+    file_id = file_id.split(".")[0]  # if a file extension has been provided, we split on the '.',
     # and return the file name.
     # possible mime is just for opengraph attributes
 
@@ -93,7 +99,7 @@ async def deliver_file(request: Request) -> Response:
         media_url = _urljoin(str(request.base_url), DELIVER_ENDPOINT + file_id)
         media_url += "?opengraph_pass=yes"
-
+
         media_tag = generate_opengraph_tag(media_property, media_url)
         common_tags = generate_tags_from_dict(OPENGRAPH_PROPERTIES)
         common_tags.append(media_tag)
 
@@ -104,26 +110,26 @@ async def deliver_file(request: Request) -> Response:
         )
         return HTMLResponse(
             og_html,
-            status_code=200,
             media_type="text/html"
         )
 
     cache_result = await request.app.image_cache.get(file_id)
+
     if cache_result is None:
-        image, mime = await request.app.db_driver.fetch(file_id) # this will decompress it for us too.
+        image, mime = await request.app.db_driver.fetch(file_id)  # this will decompress it for us too.
     else:
        image, mime = cache_result
 
     if image is None:
         # return an empty response with a 404, or a custom status code
         # which would've been provided in settings.py
-        return Response(None, status_code=NOT_FOUND_STATUS_CODE)
+        return Response(status_code=NOT_FOUND_STATUS_CODE)
 
     return Response(
         image,
-        status_code=200,
         media_type=mime
-    ) # return a 200 response with the correct mime type. Example: image/png
+    )  # return a 200 response with the correct mime type. Example: image/png
+
 
 async def delete_file(request: Request):
     if REQUIRE_AUTH_FOR_DELETE is True:
@@ -132,10 +138,10 @@ async def delete_file(request: Request):
             content = {
                 "message": "Invalid auth."
             }
-            return JSONResponse(content, status_code=401)
-
+            return JSONResponse(content, status_code=401)  # this was really critical before i fixed it
+
     file_id: str = request.path_params["name"]
-    file_id = file_id.split(".")[0] # if a file extension has been provided, we split on the '.',
+    file_id = file_id.split(".")[0]  # if a file extension has been provided, we split on the '.',
     # and return the file name.
 
     # delete file from database first
@@ -150,7 +156,5 @@ async def delete_file(request: Request):
         status_code = 500
 
     return Response(
-        None,
-        status_code=status_code,
-        media_type=None
+        status_code=status_code
     )