diff --git a/src/click/_termui_impl.py b/src/click/_termui_impl.py index 286271c8a..d3ef3cf84 100644 --- a/src/click/_termui_impl.py +++ b/src/click/_termui_impl.py @@ -32,6 +32,19 @@ V = t.TypeVar("V") + +class _BufferedTextPagerStream(t.Protocol): + buffer: t.BinaryIO + + +def _has_binary_buffer( + stream: t.BinaryIO | t.TextIO, +) -> t.TypeGuard[_BufferedTextPagerStream]: + # TextIO is wider than TextIOWrapper; text-only streams such as StringIO + # are valid TextIO values but do not expose a binary buffer to wrap. + return getattr(stream, "buffer", None) is not None + + if os.name == "nt": BEFORE_BAR = "\r" AFTER_BAR = "\n" @@ -417,18 +430,17 @@ def get_pager_file(color: bool | None = None) -> t.Generator[t.TextIO, None, Non default is autodetection. """ with _pager_contextmanager(color=color) as (stream, encoding, color): - if not isinstance(stream, MaybeStripAnsi): - if hasattr(stream, "buffer"): - # Real TextIO with buffer - unwrap and wrap in MaybeStripAnsi - stream = MaybeStripAnsi(stream.buffer, color=color, encoding=encoding) - elif not getattr(stream, "encoding", None): - # BinaryIO - wrap directly in MaybeStripAnsi - stream = MaybeStripAnsi(stream, color=color, encoding=encoding) - else: - # StringIO - add .color attribute only, no ANSI stripping - stream.color = color # type: ignore[attr-defined] + # Split streams by capabilities rather than the abstract TextIO / + # BinaryIO annotations: buffered text streams can be unwrapped to bytes, + # while text-only streams are yielded as-is. + if _has_binary_buffer(stream): + # Text stream backed by a binary buffer. + stream = MaybeStripAnsi(stream.buffer, color=color, encoding=encoding) + elif isinstance(stream, t.BinaryIO): + # Binary stream + stream = MaybeStripAnsi(stream, color=color, encoding=encoding) try: - yield t.cast(t.TextIO, stream) + yield stream finally: stream.flush() diff --git a/tests/test_termui.py b/tests/test_termui.py index 7aa260084..58870533d 100644 --- a/tests/test_termui.py +++ b/tests/test_termui.py @@ -1,11 +1,16 @@ +import contextlib +import io import platform import shlex +import shutil +import sys import tempfile import time from unittest.mock import patch import pytest +import click import click._termui_impl from click._compat import WIN from click._termui_impl import Editor @@ -620,6 +625,154 @@ def test_pager_shlex_split(pager_env, expected_parts): assert shlex.split(pager_env) == expected_parts +def _get_real_pager_command() -> str: + """Return a platform pager used to exercise the BinaryIO pager branch.""" + pager_name = "more" if WIN else "cat" + pager_path = shutil.which(pager_name) + assert pager_path is not None, f"{pager_name} not available" + return pager_path + + +def _run_get_pager_file_with_real_pager(monkeypatch, capfd, writer, color=False): + """Run through the platform pager backend selected by ``PAGER``.""" + monkeypatch.setattr(click._termui_impl, "isatty", lambda _: True) + monkeypatch.setitem( + click._termui_impl.os.environ, "PAGER", _get_real_pager_command() + ) + + with click.get_pager_file(color=color) as pager: + writer(pager) + + # The real pager writes to the process stdout; stderr should stay quiet. + out, err = capfd.readouterr() + assert err == "" + return out + + +def _write_pager_from_multiple_sites(pager): + pager.write("prefix\n") + click.echo("middle", file=pager) + pager.write("suffix\n") + + +@pytest.mark.parametrize( + ("writer", "color", "expected"), + [ + pytest.param( + _write_pager_from_multiple_sites, + False, + "prefix\nmiddle\nsuffix\n", + id="multiple write sites", + ), + pytest.param( + lambda pager: pager.write("hello\n"), False, "hello\n", id="plain text" + ), + pytest.param( + lambda pager: pager.write(click.style("hello", fg="red") + "\n"), + False, + "hello\n", + id="strip ansi", + ), + pytest.param( + lambda pager: pager.write(click.style("hello", fg="red") + "\n"), + True, + click.style("hello", fg="red") + "\n", + id="preserve ansi", + ), + pytest.param(lambda pager: pager.write(""), False, "", id="empty string"), + ], +) +def test_get_pager_file_with_real_pager_binary_stream( + monkeypatch, capfd, writer, color, expected +): + """A real pager should exercise the BinaryIO branch on Unix and Windows.""" + output = _run_get_pager_file_with_real_pager( + monkeypatch, capfd, writer, color=color + ) + + assert output == expected + + +@pytest.mark.parametrize( + ("color", "expected"), + [ + pytest.param(False, "hello\n", id="strip ansi"), + pytest.param(True, click.style("hello", fg="red") + "\n", id="preserve ansi"), + ], +) +def test_get_pager_file_nullpager_wraps_textio_stream( + monkeypatch, tmp_path, color, expected +): + """When paging falls back to a real TextIO stream, ``.buffer`` is wrapped.""" + pager_out = tmp_path / "pager_out.txt" + + with pager_out.open("w", encoding="utf-8") as text_stream: + monkeypatch.setattr( + click._termui_impl, "_default_text_stdout", lambda: text_stream + ) + monkeypatch.setattr( + click._termui_impl, "isatty", lambda stream: stream is not sys.stdin + ) + + with click.get_pager_file(color=color) as pager: + pager.write(click.style("hello", fg="red") + "\n") + + assert pager_out.read_text(encoding="utf-8") == expected + + +def test_get_pager_file_nullpager_keeps_stringio_stream(monkeypatch): + """The no-stdout fallback should keep a text-only stream and set ``.color``.""" + + created = [] + + def make_stringio(): + stream = io.StringIO() + created.append(stream) + return stream + + monkeypatch.setattr(sys, "stdout", None) + monkeypatch.setattr(click._termui_impl, "StringIO", make_stringio) + monkeypatch.setattr(click._termui_impl, "isatty", lambda _: False) + + styled_text = click.style("hello", fg="red") + + with click.get_pager_file(color=False) as pager: + assert pager is created[0] + pager.write(styled_text) + + assert created[0].getvalue() == styled_text + + +def test_get_pager_file_flushes_stream_on_exception(monkeypatch): + """Exceptions should still flush the yielded stream in ``finally``.""" + + class FlushableTextStream(io.StringIO): + def __init__(self): + super().__init__() + self.color = None + self.flush_calls = 0 + + def flush(self): + self.flush_calls += 1 + + stream = FlushableTextStream() + + @contextlib.contextmanager + def pager_contextmanager(color=None): + yield stream, "utf-8", color + + monkeypatch.setattr( + click._termui_impl, "_pager_contextmanager", pager_contextmanager + ) + + with pytest.raises(RuntimeError, match="boom"): + with click.get_pager_file() as pager: + assert pager is stream + raise RuntimeError("boom") + + assert stream.flush_calls == 1 + + def test_editor_unclosed_quote(): """An unclosed quote in the editor command raises ValueError.""" with pytest.raises(ValueError, match="No closing quotation"):