diff --git a/CHANGELOG.md b/CHANGELOG.md index 431e0470e..a675a6006 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ time reading the [rich documentation](https://rich.readthedocs.io/). - [argparse_example.py](https://github.com/python-cmd2/cmd2/blob/main/examples/argparse_example.py) - [command_sets.py](https://github.com/python-cmd2/cmd2/blob/main/examples/command_sets.py) - [getting_started.py](https://github.com/python-cmd2/cmd2/blob/main/examples/getting_started.py) + - Optimized performance of terminal fixup during command finalization by replacing `stty sane` + with `termios.tcsetattr` - Bug Fixes - Fixed a redirection bug where `cmd2` could unintentionally overwrite an application's diff --git a/cmd2/cmd2.py b/cmd2/cmd2.py index c42ec40c5..a1c257917 100644 --- a/cmd2/cmd2.py +++ b/cmd2/cmd2.py @@ -508,6 +508,19 @@ def __init__( # Commands that will run at the beginning of the command loop self._startup_commands: list[str] = [] + # Store initial termios settings to restore after each command. + # This is a faster way of accomplishing what "stty sane" does. + self._initial_termios_settings = None + if not sys.platform.startswith('win') and self.stdin.isatty(): + try: + import io + import termios + + self._initial_termios_settings = termios.tcgetattr(self.stdin.fileno()) + except (ImportError, io.UnsupportedOperation, termios.error): + # This can happen if termios isn't available or stdin is a pseudo-TTY + self._initial_termios_settings = None + # If a startup script is provided and exists, then execute it in the startup commands if startup_script: startup_script = os.path.abspath(os.path.expanduser(startup_script)) @@ -2822,14 +2835,15 @@ def onecmd_plus_hooks( def _run_cmdfinalization_hooks(self, stop: bool, statement: Statement | None) -> bool: """Run the command finalization hooks.""" - with self.sigint_protection: - if not sys.platform.startswith('win') and self.stdin.isatty(): - # Before the next command runs, fix any terminal problems like those - # caused by certain binary characters having been printed to it. - import subprocess - - proc = subprocess.Popen(['stty', 'sane']) # noqa: S607 - proc.communicate() + if self._initial_termios_settings is not None and self.stdin.isatty(): + import io + import termios + + # Before the next command runs, fix any terminal problems like those + # caused by certain binary characters having been printed to it. + with self.sigint_protection, contextlib.suppress(io.UnsupportedOperation, termios.error): + # This can fail if stdin is a pseudo-TTY, in which case we just ignore it + termios.tcsetattr(self.stdin.fileno(), termios.TCSANOW, self._initial_termios_settings) data = plugin.CommandFinalizationData(stop, statement) for func in self._cmdfinalization_hooks: diff --git a/tests/test_cmd2.py b/tests/test_cmd2.py index 82497b723..2586da8c7 100644 --- a/tests/test_cmd2.py +++ b/tests/test_cmd2.py @@ -1068,16 +1068,70 @@ def test_cmdloop_without_rawinput() -> None: assert out == expected -@pytest.mark.skipif(sys.platform.startswith('win'), reason="stty sane only run on Linux/Mac") -def test_stty_sane(base_app, monkeypatch) -> None: - """Make sure stty sane is run on Linux/Mac after each command if stdin is a terminal""" - with mock.patch('sys.stdin.isatty', mock.MagicMock(name='isatty', return_value=True)): - # Mock out the subprocess.Popen call so we don't actually run stty sane - m = mock.MagicMock(name='Popen') - monkeypatch.setattr("subprocess.Popen", m) +def test_cmdfinalizations_runs(base_app, monkeypatch) -> None: + """Make sure _run_cmdfinalization_hooks is run after each command.""" + with ( + mock.patch('sys.stdin.isatty', mock.MagicMock(name='isatty', return_value=True)), + mock.patch('sys.stdin.fileno', mock.MagicMock(name='fileno', return_value=0)), + ): + monkeypatch.setattr(base_app.stdin, "fileno", lambda: 0) + monkeypatch.setattr(base_app.stdin, "isatty", lambda: True) + + cmd_fin = mock.MagicMock(name='cmdfinalization') + monkeypatch.setattr("cmd2.Cmd._run_cmdfinalization_hooks", cmd_fin) base_app.onecmd_plus_hooks('help') - m.assert_called_once_with(['stty', 'sane']) + cmd_fin.assert_called_once() + + +@pytest.mark.skipif(sys.platform.startswith('win'), reason="termios is not available on Windows") +@pytest.mark.parametrize( + ('is_tty', 'settings_set', 'raised_exception', 'should_call'), + [ + (True, True, None, True), + (True, True, 'termios_error', True), + (True, True, 'unsupported_operation', True), + (False, True, None, False), + (True, False, None, False), + ], +) +def test_restore_termios_settings(base_app, monkeypatch, is_tty, settings_set, raised_exception, should_call): + """Test that terminal settings are restored after a command and that errors are suppressed.""" + import io + import termios # Mock termios since it's imported within the method + + termios_mock = mock.MagicMock() + # The error attribute needs to be the actual exception for isinstance checks + termios_mock.error = termios.error + monkeypatch.setitem(sys.modules, 'termios', termios_mock) + + # Set the exception to be raised by tcsetattr + if raised_exception == 'termios_error': + termios_mock.tcsetattr.side_effect = termios.error("test termios error") + elif raised_exception == 'unsupported_operation': + termios_mock.tcsetattr.side_effect = io.UnsupportedOperation("test io error") + + # Set initial termios settings so the logic will run + if settings_set: + termios_settings = ["dummy settings"] + base_app._initial_termios_settings = termios_settings + else: + base_app._initial_termios_settings = None + termios_settings = None # for the assert + + # Mock stdin to make it look like a TTY + monkeypatch.setattr(base_app.stdin, "isatty", lambda: is_tty) + monkeypatch.setattr(base_app.stdin, "fileno", lambda: 0) + + # Run a command to trigger _run_cmdfinalization_hooks + # This should not raise an exception + base_app.onecmd_plus_hooks('help') + + # Verify that tcsetattr was called with the correct arguments + if should_call: + termios_mock.tcsetattr.assert_called_once_with(0, termios_mock.TCSANOW, termios_settings) + else: + termios_mock.tcsetattr.assert_not_called() def test_sigint_handler(base_app) -> None: