diff --git a/Makefile b/Makefile index d4ce561..fb5ea54 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ test_deps: lint: ruff check src - mypy --install-types --non-interactive --check-untyped-defs src + mypy --install-types --non-interactive --check-untyped-defs src test.py test: coverage run --branch --include 'src/*' -m unittest discover -s test -v diff --git a/src/pyotp/contrib/steam.py b/src/pyotp/contrib/steam.py index 1cd2bdd..87c608b 100644 --- a/src/pyotp/contrib/steam.py +++ b/src/pyotp/contrib/steam.py @@ -13,7 +13,12 @@ class Steam(TOTP): """ def __init__( - self, s: str, name: Optional[str] = None, issuer: Optional[str] = None, interval: int = 30, digits: int = 5 + self, + s: str, + name: Optional[str] = None, + issuer: Optional[str] = None, + interval: int = 30, + digits: int = 5, ) -> None: """ :param s: secret in base32 format diff --git a/src/pyotp/totp.py b/src/pyotp/totp.py index 9908d55..4b0aad4 100644 --- a/src/pyotp/totp.py +++ b/src/pyotp/totp.py @@ -2,7 +2,7 @@ import datetime import hashlib import time -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union, overload from . import utils from .otp import OTP @@ -38,7 +38,7 @@ def __init__( self.interval = interval super().__init__(s=s, digits=digits, digest=digest, name=name, issuer=issuer) - def at(self, for_time: Union[int, datetime.datetime], counter_offset: int = 0) -> str: + def at(self, for_time: Union[float, datetime.datetime], counter_offset: int = 0) -> str: """ Accepts either a Unix timestamp integer or a datetime object. @@ -65,25 +65,53 @@ def now(self) -> str: """ return self.generate_otp(self.timecode(datetime.datetime.now())) - def verify(self, otp: str, for_time: Optional[datetime.datetime] = None, valid_window: int = 0) -> bool: + @overload + def verify( + self, + otp: str, + for_time: Optional[Union[datetime.datetime, float]] = None, + valid_window: int = 0, + return_timecode: Literal[False] = False, + ) -> bool: ... + + @overload + def verify( + self, + otp: str, + for_time: Optional[Union[datetime.datetime, float]] = None, + valid_window: int = 0, + return_timecode: Literal[True] = True, + ) -> Literal[False] | int: ... + + def verify( + self, + otp: str, + for_time: Optional[Union[datetime.datetime, float]] = None, + valid_window: int = 0, + return_timecode: bool = False, + ) -> bool | int: """ Verifies the OTP passed in against the current time OTP. :param otp: the OTP to check against :param for_time: Time to check OTP at (defaults to now) :param valid_window: extends the validity to this many counter ticks before and after the current one - :returns: True if verification succeeded, False otherwise + :param return_timecode: if True, on success return the timecode of the OTP (to be used to prevent replay attacks) + :returns: True or the matching timecode if verification succeeded (depending on return_timecode), False otherwise """ if for_time is None: for_time = datetime.datetime.now() + elif not isinstance(for_time, datetime.datetime): + for_time = datetime.datetime.fromtimestamp(int(for_time)) - if valid_window: - for i in range(-valid_window, valid_window + 1): - if utils.strings_equal(str(otp), str(self.at(for_time, i))): + base_timecode = self.timecode(for_time) + for i in range(-valid_window, valid_window + 1): + if utils.strings_equal(str(otp), str(self.generate_otp(base_timecode + i))): + if return_timecode: + return base_timecode + i + else: return True - return False - - return utils.strings_equal(str(otp), str(self.at(for_time))) + return False def provisioning_uri(self, name: Optional[str] = None, issuer_name: Optional[str] = None, **kwargs) -> str: """ diff --git a/test.py b/test.py index 2d33c60..833ce35 100755 --- a/test.py +++ b/test.py @@ -1,11 +1,13 @@ #!/usr/bin/env python import base64 +import contextlib import datetime import hashlib import os import sys -import unittest +import typing +import unittest.mock from urllib.parse import parse_qsl, urlparse from warnings import warn @@ -48,7 +50,9 @@ def test_provisioning_uri(self): self.assertEqual(url.netloc, "hotp") self.assertEqual(url.path, "/mark%40percival") self.assertEqual(dict(parse_qsl(url.query)), {"secret": "wrn3pqx5uqxqvnqr", "counter": "0"}) - self.assertEqual(hotp.provisioning_uri(), pyotp.parse_uri(hotp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(hotp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.HOTP) + self.assertEqual(hotp.provisioning_uri(), parsed_otp.provisioning_uri()) hotp = pyotp.HOTP("wrn3pqx5uqxqvnqr", name="mark@percival", initial_count=12) url = urlparse(hotp.provisioning_uri()) @@ -56,7 +60,9 @@ def test_provisioning_uri(self): self.assertEqual(url.netloc, "hotp") self.assertEqual(url.path, "/mark%40percival") self.assertEqual(dict(parse_qsl(url.query)), {"secret": "wrn3pqx5uqxqvnqr", "counter": "12"}) - self.assertEqual(hotp.provisioning_uri(), pyotp.parse_uri(hotp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(hotp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.HOTP) + self.assertEqual(hotp.provisioning_uri(), parsed_otp.provisioning_uri()) hotp = pyotp.HOTP("wrn3pqx5uqxqvnqr", name="mark@percival", issuer="FooCorp!") url = urlparse(hotp.provisioning_uri()) @@ -66,7 +72,9 @@ def test_provisioning_uri(self): self.assertEqual( dict(parse_qsl(url.query)), {"secret": "wrn3pqx5uqxqvnqr", "counter": "0", "issuer": "FooCorp!"} ) - self.assertEqual(hotp.provisioning_uri(), pyotp.parse_uri(hotp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(hotp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.HOTP) + self.assertEqual(hotp.provisioning_uri(), parsed_otp.provisioning_uri()) key = "c7uxuqhgflpw7oruedmglbrk7u6242vb" hotp = pyotp.HOTP(key, digits=8, digest=hashlib.sha256, name="baco@peperina", issuer="FooCorp") @@ -84,7 +92,9 @@ def test_provisioning_uri(self): "algorithm": "SHA256", }, ) - self.assertEqual(hotp.provisioning_uri(), pyotp.parse_uri(hotp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(hotp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.HOTP) + self.assertEqual(hotp.provisioning_uri(), parsed_otp.provisioning_uri()) hotp = pyotp.HOTP(key, digits=8, name="baco@peperina", issuer="Foo Corp", initial_count=10) url = urlparse(hotp.provisioning_uri()) @@ -95,7 +105,9 @@ def test_provisioning_uri(self): dict(parse_qsl(url.query)), {"secret": "c7uxuqhgflpw7oruedmglbrk7u6242vb", "counter": "10", "issuer": "Foo Corp", "digits": "8"}, ) - self.assertEqual(hotp.provisioning_uri(), pyotp.parse_uri(hotp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(hotp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.HOTP) + self.assertEqual(hotp.provisioning_uri(), parsed_otp.provisioning_uri()) code = pyotp.totp.TOTP("S46SQCPPTCNPROMHWYBDCTBZXV") self.assertEqual(code.provisioning_uri(), "otpauth://totp/Secret?secret=S46SQCPPTCNPROMHWYBDCTBZXV") @@ -141,7 +153,7 @@ class TOTPExampleValuesFromTheRFC(unittest.TestCase): def test_match_rfc(self): for digest, secret in self.RFC_VALUES: - totp = pyotp.TOTP(base64.b32encode(secret), 8, digest) + totp = pyotp.TOTP(base64.b32encode(secret).decode(), 8, digest) for utime, code in self.RFC_VALUES[(digest, secret)]: if utime > sys.maxsize: warn( @@ -162,17 +174,30 @@ def test_match_rfc_digit_length(self): def test_match_google_authenticator_output(self): totp = pyotp.TOTP("wrn3pqx5uqxqvnqr") - with Timecop(1297553958): + with timecop(1297553958): self.assertEqual(totp.now(), "102705") def test_validate_totp(self): totp = pyotp.TOTP("wrn3pqx5uqxqvnqr") - with Timecop(1297553958): + with timecop(1297553958): self.assertTrue(totp.verify("102705")) self.assertTrue(totp.verify("102705")) - with Timecop(1297553958 + 30): + with timecop(1297553958 + 30): self.assertFalse(totp.verify("102705")) + def test_return_timecode_on_verify(self): + totp = pyotp.TOTP("wrn3pqx5uqxqvnqr") + with timecop(1297553958): + timecode1 = totp.verify("102705", valid_window=1, return_timecode=True) + self.assertTrue(isinstance(timecode1, int)) + with timecop(1297553958 + 30): + timecode2 = totp.verify("102705", valid_window=1, return_timecode=True) + self.assertEqual(timecode1, timecode2) + + with timecop(1297553958 + 60): + timecode3 = totp.verify("102705", valid_window=1, return_timecode=True) + self.assertFalse(timecode3) + def test_input_before_epoch(self): totp = pyotp.TOTP("GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ") # -1 and -29.5 round down to 0 (epoch) @@ -183,9 +208,9 @@ def test_input_before_epoch(self): def test_validate_totp_with_digit_length(self): totp = pyotp.TOTP("GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ") - with Timecop(1111111111): + with timecop(1111111111): self.assertTrue(totp.verify("050471")) - with Timecop(1297553958 + 30): + with timecop(1297553958 + 30): self.assertFalse(totp.verify("050471")) def test_provisioning_uri(self): @@ -195,7 +220,9 @@ def test_provisioning_uri(self): self.assertEqual(url.netloc, "totp") self.assertEqual(url.path, "/mark%40percival") self.assertEqual(dict(parse_qsl(url.query)), {"secret": "wrn3pqx5uqxqvnqr"}) - self.assertEqual(totp.provisioning_uri(), pyotp.parse_uri(totp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(totp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.TOTP) + self.assertEqual(totp.provisioning_uri(), parsed_otp.provisioning_uri()) totp = pyotp.TOTP("wrn3pqx5uqxqvnqr", name="mark@percival", issuer="FooCorp!") url = urlparse(totp.provisioning_uri()) @@ -203,7 +230,9 @@ def test_provisioning_uri(self): self.assertEqual(url.netloc, "totp") self.assertEqual(url.path, "/FooCorp%21:mark%40percival") self.assertEqual(dict(parse_qsl(url.query)), {"secret": "wrn3pqx5uqxqvnqr", "issuer": "FooCorp!"}) - self.assertEqual(totp.provisioning_uri(), pyotp.parse_uri(totp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(totp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.TOTP) + self.assertEqual(totp.provisioning_uri(), parsed_otp.provisioning_uri()) key = "c7uxuqhgflpw7oruedmglbrk7u6242vb" totp = pyotp.TOTP(key, digits=8, interval=60, digest=hashlib.sha256, name="baco@peperina", issuer="FooCorp") @@ -221,7 +250,9 @@ def test_provisioning_uri(self): "algorithm": "SHA256", }, ) - self.assertEqual(totp.provisioning_uri(), pyotp.parse_uri(totp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(totp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.TOTP) + self.assertEqual(totp.provisioning_uri(), parsed_otp.provisioning_uri()) totp = pyotp.TOTP(key, digits=8, interval=60, name="baco@peperina", issuer="FooCorp") url = urlparse(totp.provisioning_uri()) @@ -232,7 +263,9 @@ def test_provisioning_uri(self): dict(parse_qsl(url.query)), {"secret": "c7uxuqhgflpw7oruedmglbrk7u6242vb", "issuer": "FooCorp", "digits": "8", "period": "60"}, ) - self.assertEqual(totp.provisioning_uri(), pyotp.parse_uri(totp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(totp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.TOTP) + self.assertEqual(totp.provisioning_uri(), parsed_otp.provisioning_uri()) totp = pyotp.TOTP(key, digits=8, name="baco@peperina", issuer="FooCorp") url = urlparse(totp.provisioning_uri()) @@ -243,7 +276,9 @@ def test_provisioning_uri(self): dict(parse_qsl(url.query)), {"secret": "c7uxuqhgflpw7oruedmglbrk7u6242vb", "issuer": "FooCorp", "digits": "8"}, ) - self.assertEqual(totp.provisioning_uri(), pyotp.parse_uri(totp.provisioning_uri()).provisioning_uri()) + parsed_otp = pyotp.parse_uri(totp.provisioning_uri()) + assert isinstance(parsed_otp, pyotp.TOTP) + self.assertEqual(totp.provisioning_uri(), parsed_otp.provisioning_uri()) def test_random_key_generation(self): self.assertEqual(len(pyotp.random_base32()), 32) @@ -274,30 +309,30 @@ def test_match_examples(self): def test_verify(self): steam = pyotp.contrib.Steam("BASE32SECRET3232") - with Timecop(1662883100): + with timecop(1662883100): self.assertTrue(steam.verify("N3G63")) - with Timecop(1662883100 + 30): + with timecop(1662883100 + 30): self.assertFalse(steam.verify("N3G63")) - with Timecop(946681223): + with timecop(946681223): self.assertTrue(steam.verify("7VP3X")) - with Timecop(946681223 + 30): + with timecop(946681223 + 30): self.assertFalse(steam.verify("7VP3X")) steam = pyotp.contrib.Steam("FMXNK4QEGKVPULRTADY6JIDK5VHUBGZW") - with Timecop(1662884261): + with timecop(1662884261): self.assertTrue(steam.verify("V6WKJ")) - with Timecop(1662884261 + 30): + with timecop(1662884261 + 30): self.assertFalse(steam.verify("V6WKJ")) - with Timecop(946681223): + with timecop(946681223): self.assertTrue(steam.verify("4MK54")) - with Timecop(946681223 + 30): + with timecop(946681223 + 30): self.assertFalse(steam.verify("4MK54")) -class CompareDigestTest(unittest.TestCase): - method = staticmethod(pyotp.utils.compare_digest) +class StringComparisonTest(unittest.TestCase): + method = staticmethod(pyotp.utils.strings_equal) def test_comparisons(self): self.assertTrue(self.method("", "")) @@ -308,10 +343,6 @@ def test_comparisons(self): self.assertFalse(self.method("a", "")) self.assertFalse(self.method("a" * 999 + "b", "a" * 1000)) - -class StringComparisonTest(CompareDigestTest): - method = staticmethod(pyotp.utils.strings_equal) - def test_fullwidth_input(self): self.assertTrue(self.method("xs12345", "xs12345")) @@ -387,6 +418,7 @@ def test_parse_steam(self): @unittest.skipIf(sys.version_info < (3, 6), "Skipping test that requires deterministic dict key enumeration") def test_algorithms(self): otp = pyotp.parse_uri("otpauth://totp?algorithm=SHA1&secret=GEZDGNBV&algorithm=SHA1") + assert isinstance(otp, pyotp.TOTP) self.assertEqual(hashlib.sha1, otp.digest) self.assertEqual(otp.at(0), "734055") self.assertEqual(otp.at(30), "662488") @@ -395,6 +427,7 @@ def test_algorithms(self): self.assertEqual(otp.provisioning_uri(name="n", issuer_name="i"), "otpauth://totp/i:n?secret=GEZDGNBV&issuer=i") otp = pyotp.parse_uri("otpauth://totp?algorithm=SHA1&secret=GEZDGNBV&algorithm=SHA1&period=60") + assert isinstance(otp, pyotp.TOTP) self.assertEqual(hashlib.sha1, otp.digest) self.assertEqual(otp.at(30), "734055") self.assertEqual(otp.at(60), "662488") @@ -403,6 +436,7 @@ def test_algorithms(self): ) otp = pyotp.parse_uri("otpauth://hotp?algorithm=SHA1&secret=GEZDGNBV&algorithm=SHA1") + assert isinstance(otp, pyotp.HOTP) self.assertEqual(hashlib.sha1, otp.digest) self.assertEqual(otp.at(0), "734055") self.assertEqual(otp.at(1), "662488") @@ -412,6 +446,7 @@ def test_algorithms(self): ) otp = pyotp.parse_uri("otpauth://hotp?algorithm=SHA1&secret=GEZDGNBV&algorithm=SHA1&counter=1") + assert isinstance(otp, pyotp.HOTP) self.assertEqual(hashlib.sha1, otp.digest) self.assertEqual(otp.at(0), "662488") self.assertEqual(otp.at(1), "289363") @@ -420,11 +455,13 @@ def test_algorithms(self): ) otp = pyotp.parse_uri("otpauth://totp?algorithm=SHA1&secret=GEZDGNBV&algorithm=SHA256") + assert isinstance(otp, pyotp.TOTP) self.assertEqual(hashlib.sha256, otp.digest) self.assertEqual(otp.at(0), "918961") self.assertEqual(otp.at(9000), "934470") otp = pyotp.parse_uri("otpauth://totp?algorithm=SHA1&secret=GEZDGNBV&algorithm=SHA512") + assert isinstance(otp, pyotp.TOTP) self.assertEqual(hashlib.sha512, otp.digest) self.assertEqual(otp.at(0), "816660") self.assertEqual(otp.at(9000), "524153") @@ -440,7 +477,7 @@ def test_algorithms(self): self.assertEqual(hashlib.sha512, otp.digest) otp = pyotp.parse_uri("otpauth://totp/Steam:?secret=FMXNK4QEGKVPULRTADY6JIDK5VHUBGZW&encoder=steam") - self.assertEqual(type(otp), pyotp.contrib.Steam) + assert isinstance(otp, pyotp.contrib.Steam) self.assertEqual(otp.at(0), "C5V56") self.assertEqual(otp.at(30), "QJY8Y") self.assertEqual(otp.at(60), "R3WQY") @@ -450,7 +487,7 @@ def test_algorithms(self): otp = pyotp.parse_uri( "otpauth://totp/Steam:?secret=FMXNK4QEGKVPULRTADY6JIDK5VHUBGZW&period=15&digits=7&encoder=steam" ) - self.assertEqual(type(otp), pyotp.contrib.Steam) + assert isinstance(otp, pyotp.contrib.Steam) self.assertEqual(otp.at(0), "C5V56") self.assertEqual(otp.at(30), "QJY8Y") self.assertEqual(otp.at(60), "R3WQY") @@ -459,29 +496,15 @@ def test_algorithms(self): pyotp.parse_uri("otpauth://totp?secret=abc&image=foobar") -class Timecop(object): - """ - Half-assed clone of timecop.rb, just enough to pass our tests. - """ - - def __init__(self, freeze_timestamp): - self.freeze_timestamp = freeze_timestamp - - def __enter__(self): - self.real_datetime = datetime.datetime - datetime.datetime = self.frozen_datetime() - - def __exit__(self, type, value, traceback): - datetime.datetime = self.real_datetime - - def frozen_datetime(self): - class FrozenDateTime(datetime.datetime): - @classmethod - def now(cls, **kwargs): - return cls.fromtimestamp(timecop.freeze_timestamp) +@contextlib.contextmanager +def timecop(freeze_timestamp: int) -> typing.Generator[None, None, None]: + class FrozenDateTime(datetime.datetime): + @classmethod + def now(cls, tz: datetime.tzinfo | None = None) -> "FrozenDateTime": + return cls.fromtimestamp(freeze_timestamp, tz=tz) - timecop = self - return FrozenDateTime + with unittest.mock.patch("datetime.datetime", FrozenDateTime): + yield if __name__ == "__main__":