Skip to content

Commit bd7cff2

Browse files
committed
fix: notify and cancel windows
Signed-off-by: Cody Edwards <edwards@amazon.com>
1 parent 242474c commit bd7cff2

12 files changed

Lines changed: 191 additions & 144 deletions

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ addopts = [
143143
"--cov-report=html:build/coverage",
144144
"--cov-report=xml:build/coverage/coverage.xml",
145145
"--cov-report=term-missing",
146-
"--numprocesses=auto",
147146
"--timeout=30"
148147
]
149148
markers = [

scripts/windows_service_test.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import socket
44
import logging
5+
import io
56
from threading import Event
67
from typing import Optional
78

@@ -15,6 +16,7 @@
1516
import shlex
1617
import win32con
1718
import win32api
19+
import pytest
1820
from getpass import getpass
1921

2022

@@ -54,32 +56,43 @@ def SvcDoRun(self):
5456
)
5557
code_location = os.environ["CODE_LOCATION"]
5658
pytest_args = os.environ.get("PYTEST_ARGS", None)
59+
log_file_name = os.path.join(code_location, "test.log")
5760

58-
args = ["pytest", os.path.join(code_location, "test")]
61+
args = [os.path.join(code_location, "test"), "-p", "no:xdist"]
5962

6063
if pytest_args:
6164
args.extend(shlex.split(pytest_args, posix=False))
6265

63-
logging.basicConfig(
64-
filename=os.path.join(code_location, "test.log"),
65-
encoding="utf-8",
66-
level=logging.INFO,
67-
filemode="w",
68-
)
69-
process = subprocess.Popen(
70-
args,
71-
stdout=subprocess.PIPE,
72-
stderr=subprocess.STDOUT,
73-
text=True,
74-
cwd=code_location,
75-
)
76-
77-
while True:
78-
output = process.stdout.readline()
79-
if not output and process.poll() is not None:
80-
break
81-
82-
logger.info(output.strip())
66+
# logging.basicConfig(
67+
# filename=log_file_name,
68+
# encoding="utf-8",
69+
# level=logging.INFO,
70+
# filemode="w",
71+
# )
72+
# process = subprocess.Popen(
73+
# args,
74+
# stdout=subprocess.PIPE,
75+
# stderr=subprocess.STDOUT,
76+
# text=True,
77+
# cwd=code_location,
78+
# creationflags=subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.CREATE_NO_WINDOW | subprocess.CREATE_NEW_CONSOLE,
79+
# )
80+
81+
with open(log_file_name, mode="w") as f:
82+
sys.stdout = f
83+
sys.stderr = f
84+
85+
ret = pytest.main(args)
86+
87+
# while True:
88+
# output = process.stdout.readline()
89+
# if not output and process.poll() is not None:
90+
# break
91+
92+
# logger.info(output.strip())
93+
94+
# logger.info(sys.stdout.getvalue())
95+
# logger.error(sys.stderr.getvalue())
8396

8497
servicemanager.LogMsg(
8598
servicemanager.EVENTLOG_INFORMATION_TYPE,

src/openjd/sessions/_scripts/_windows/_signal_win_subprocess.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ def signal_process(pgid: int):
5555
if not kernel32.GenerateConsoleCtrlEvent(CTRL_BREAK_EVENT, pgid):
5656
raise ctypes.WinError()
5757

58-
if not kernel32.FreeConsole():
59-
raise ctypes.WinError()
60-
if not kernel32.AttachConsole(ATTACH_PARENT_PROCESS):
61-
raise ctypes.WinError()
62-
6358

6459
if __name__ == "__main__":
6560
signal_process(int(sys.argv[1]))

src/openjd/sessions/_subprocess.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
from queue import Queue, Empty
1212
from subprocess import DEVNULL, PIPE, STDOUT, Popen, list2cmdline, run
1313
from threading import Event, Thread
14-
from typing import Any
15-
from typing import Callable, Literal, Optional, Sequence, cast
14+
from typing import Callable, Literal, Optional, Sequence, cast, Any
1615

1716
from ._linux._capabilities import try_use_cap_kill
1817
from ._linux._sudo import find_sudo_child_process_group_id
@@ -623,17 +622,22 @@ def _windows_notify_subprocess(self) -> None:
623622
str(WINDOWS_SIGNAL_SUBPROC_SCRIPT_PATH),
624623
str(self._process.pid),
625624
]
626-
result = run(
627-
cmd,
628-
stdout=PIPE,
629-
stderr=STDOUT,
630-
stdin=DEVNULL,
631-
creationflags=CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW,
625+
process = LoggingSubprocess(
626+
logger=self._logger,
627+
args=cmd,
628+
encoding=self._encoding,
629+
user=self._user,
630+
os_env_vars=self._os_env_vars,
631+
working_dir=self._working_dir,
632+
creation_flags=CREATE_NO_WINDOW,
632633
)
633-
if result.returncode != 0:
634+
635+
# Blocking call
636+
process.run()
637+
638+
if process.exit_code != 0:
634639
self._logger.warning(
635-
f"Failed to send signal 'CTRL_BREAK_EVENT' to subprocess {self._process.pid}: %s",
636-
result.stdout.decode("utf-8"),
640+
f"Failed to send signal 'CTRL_BREAK_EVENT' to subprocess {self._process.pid}",
637641
extra=LogExtraInfo(
638642
openjd_log_content=LogContent.PROCESS_CONTROL | LogContent.EXCEPTION_INFO
639643
),

src/openjd/sessions/_win32/_helpers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,17 @@
2525
# Constants
2626
LOGON32_LOGON_INTERACTIVE,
2727
LOGON32_PROVIDER_DEFAULT,
28+
PI_NOUI,
29+
PROFILEINFO,
2830
# Functions
2931
CloseHandle,
3032
CreateEnvironmentBlock,
3133
DestroyEnvironmentBlock,
3234
GetCurrentProcessId,
3335
LogonUserW,
3436
ProcessIdToSessionId,
37+
LoadUserProfileW,
38+
UnloadUserProfile,
3539
)
3640

3741

@@ -166,3 +170,33 @@ def environment_block_from_dict(env: dict[str, str]) -> c_wchar_p:
166170
env_block_str = null_delimited + "\0"
167171

168172
return c_wchar_p(env_block_str)
173+
174+
175+
def load_user_profile(token: HANDLE, username: str) -> PROFILEINFO:
176+
"""
177+
Load the user profile for the given logon token and user name
178+
179+
NOTE: The caller *MUST* call unload_user_profile when finished with the user profile
180+
181+
See: https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-loaduserprofilew
182+
"""
183+
profile_info = PROFILEINFO()
184+
profile_info.dwSize = sizeof(PROFILEINFO)
185+
profile_info.lpUserName = username
186+
profile_info.dwFlags = PI_NOUI
187+
profile_info.lpProfilePath = None
188+
189+
if not LoadUserProfileW(token, byref(profile_info)):
190+
raise WinError()
191+
192+
return profile_info
193+
194+
195+
def unload_user_profile(token: HANDLE, profile_info: PROFILEINFO) -> None:
196+
"""
197+
Unload the user profile for the given token and profile.
198+
199+
See: https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-unloaduserprofile
200+
"""
201+
if not UnloadUserProfile(token, profile_info.hProfile):
202+
raise WinError()

test/openjd/sessions/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from openjd.sessions._win32._helpers import ( # type: ignore
2222
get_current_process_session_id,
2323
logon_user_context,
24+
load_user_profile,
25+
unload_user_profile,
2426
)
2527

2628
TEST_RUNNING_IN_WINDOWS_SESSION_0 = 0 == get_current_process_session_id()
@@ -251,10 +253,10 @@ def windows_user() -> Generator[WindowsSessionUser, None, None]:
251253

252254
if TEST_RUNNING_IN_WINDOWS_SESSION_0:
253255
try:
254-
# Note: We don't load the user profile; it's currently not needed by our tests,
255-
# and we're getting a mysterious crash when unloading it.
256256
with logon_user_context(user, password) as logon_token:
257+
profile_info = load_user_profile(logon_token, user)
257258
yield WindowsSessionUser(user, logon_token=logon_token)
259+
unload_user_profile(logon_token, profile_info)
258260
except OSError as e:
259261
raise Exception(
260262
f"Could not logon as {user}. Check the password that was provided in {WIN_PASS_ENV_VAR}."

test/openjd/sessions/test_runner_base.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_basic_run(self, tmp_path: Path) -> None:
128128
logger=MagicMock(), session_working_directory=tmp_path, callback=callback
129129
) as runner:
130130
# WHEN
131-
runner._run([sys.executable, "-c", "import time; time.sleep(0.25)"])
131+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), "-c", "import time; time.sleep(0.25)"])
132132

133133
# THEN
134134
assert runner.state == ScriptRunnerState.RUNNING
@@ -173,7 +173,7 @@ def test_working_dir_is_cwd(
173173
logger=logger, session_working_directory=tmp_path, startup_directory=tmp_path
174174
) as runner:
175175
# WHEN
176-
runner._run([sys.executable, "-c", "import os; print(os.getcwd())"])
176+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), "-c", "import os; print(os.getcwd())"])
177177
# Wait until the process exits.
178178
while runner.state == ScriptRunnerState.RUNNING:
179179
time.sleep(0.1)
@@ -189,7 +189,7 @@ def test_failing_run(self, tmp_path: Path) -> None:
189189
# GIVEN
190190
with TerminatingRunner(logger=MagicMock(), session_working_directory=tmp_path) as runner:
191191
# WHEN
192-
runner._run([sys.executable, "-c", "import sys; sys.exit(1)"])
192+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), "-c", "import sys; sys.exit(1)"])
193193

194194
# THEN
195195
while runner.state == ScriptRunnerState.RUNNING:
@@ -256,7 +256,7 @@ def test_run_with_env_vars(
256256
# WHEN
257257
runner._run(
258258
[
259-
sys.executable,
259+
sys.executable.lower().replace("pythonservice.exe", "python.exe"),
260260
"-c",
261261
r"import os;print(*(f'{k} = {v}' for k,v in os.environ.items()), sep='\n')",
262262
]
@@ -297,8 +297,8 @@ def test_run_as_posix_user(
297297
# WHEN
298298
runner._run(
299299
[
300-
# Note: Intentionally not `sys.executable`. Reasons:
301-
# 1) This is a cross-account command, and sys.executable may be in a user-specific venv
300+
# Note: Intentionally not `sys.executable.lower().replace("pythonservice.exe", "python.exe")`. Reasons:
301+
# 1) This is a cross-account command, and sys.executable.lower().replace("pythonservice.exe", "python.exe") may be in a user-specific venv
302302
# 2) This test is, generally, intended to be run in a docker container where the system
303303
# python is the correct version that we want to run under.
304304
"python",
@@ -348,8 +348,8 @@ def test_run_as_posix_user_with_env_vars(
348348
# WHEN
349349
runner._run(
350350
[
351-
# Note: Intentionally not `sys.executable`. Reasons:
352-
# 1) This is a cross-account command, and sys.executable may be in a user-specific venv
351+
# Note: Intentionally not `sys.executable.lower().replace("pythonservice.exe", "python.exe")`. Reasons:
352+
# 1) This is a cross-account command, and sys.executable.lower().replace("pythonservice.exe", "python.exe") may be in a user-specific venv
353353
# 2) This test is, generally, intended to be run in a docker container where the system
354354
# python is the correct version that we want to run under.
355355
"python",
@@ -467,8 +467,8 @@ def test_run_as_windows_user_with_env_vars(
467467
# WHEN
468468
runner._run(
469469
[
470-
# Note: Intentionally not `sys.executable`. Reasons:
471-
# 1) This is a cross-account command, and sys.executable may be in a user-specific venv
470+
# Note: Intentionally not `sys.executable.lower().replace("pythonservice.exe", "python.exe")`. Reasons:
471+
# 1) This is a cross-account command, and sys.executable.lower().replace("pythonservice.exe", "python.exe") may be in a user-specific venv
472472
# 2) This test is, generally, intended to be run in a docker container where the system
473473
# python is the correct version that we want to run under.
474474
"python",
@@ -518,8 +518,8 @@ def test_does_not_inherit_env_vars_posix(
518518
# WHEN
519519
runner._run(
520520
[
521-
# Note: Intentionally not `sys.executable`. Reasons:
522-
# 1) This is a cross-account command, and sys.executable may be in a user-specific venv
521+
# Note: Intentionally not `sys.executable.lower().replace("pythonservice.exe", "python.exe")`. Reasons:
522+
# 1) This is a cross-account command, and sys.executable.lower().replace("pythonservice.exe", "python.exe") may be in a user-specific venv
523523
# 2) This test is, generally, intended to be run in a docker container where the system
524524
# python is the correct version that we want to run under.
525525
"python",
@@ -570,8 +570,8 @@ def test_does_not_inherit_env_vars_windows(
570570
) as runner:
571571
# WHEN
572572
py_script = f"import os; v=os.environ.get('{var_name}'); print('NOT_PRESENT' if v is None else v)"
573-
# Use the default 'python' rather than 'sys.executable' since we typically do not have access to
574-
# sys.executable when running with impersonation since it's in a hatch environment for the local user.
573+
# Use the default 'python' rather than 'sys.executable.lower().replace("pythonservice.exe", "python.exe")' since we typically do not have access to
574+
# sys.executable.lower().replace("pythonservice.exe", "python.exe") when running with impersonation since it's in a hatch environment for the local user.
575575
runner._run(["python", "-c", py_script])
576576

577577
# THEN
@@ -598,11 +598,11 @@ def test_cannot_run_twice(self, tmp_path: Path) -> None:
598598
logger=MagicMock(), session_working_directory=tmp_path, callback=callback
599599
) as runner:
600600
# WHEN
601-
runner._run([sys.executable, "-c", "print('hello')"])
601+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), "-c", "print('hello')"])
602602

603603
# THEN
604604
with pytest.raises(RuntimeError):
605-
runner._run([sys.executable, "-c", "print('hello')"])
605+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), "-c", "print('hello')"])
606606

607607
@pytest.mark.usefixtures("message_queue", "queue_handler")
608608
def test_run_action(
@@ -623,7 +623,7 @@ def test_run_action(
623623
python_app_loc = (Path(__file__).parent / "support_files" / "app_20s_run.py").resolve()
624624
symtab = SymbolTable(
625625
source={
626-
"Task.PythonInterpreter": sys.executable,
626+
"Task.PythonInterpreter": sys.executable.lower().replace("pythonservice.exe", "python.exe"),
627627
"Task.ScriptFile": str(python_app_loc),
628628
}
629629
)
@@ -683,7 +683,7 @@ def test_run_action_default_timeout(
683683
python_app_loc = (Path(__file__).parent / "support_files" / "app_20s_run.py").resolve()
684684
symtab = SymbolTable(
685685
source={
686-
"Task.PythonInterpreter": sys.executable,
686+
"Task.PythonInterpreter": sys.executable.lower().replace("pythonservice.exe", "python.exe"),
687687
"Task.ScriptFile": str(python_app_loc),
688688
}
689689
)
@@ -755,7 +755,7 @@ def test_cancel_terminate(
755755
logger=logger, session_working_directory=tmp_path, callback=callback
756756
) as runner:
757757
python_app_loc = (Path(__file__).parent / "support_files" / "app_20s_run.py").resolve()
758-
runner._run([sys.executable, str(python_app_loc)])
758+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), str(python_app_loc)])
759759

760760
# WHEN
761761
runner.cancel()
@@ -789,7 +789,7 @@ def test_run_with_time_limit(
789789
python_app_loc = (Path(__file__).parent / "support_files" / "app_20s_run.py").resolve()
790790

791791
# WHEN
792-
runner._run([sys.executable, str(python_app_loc)], time_limit=timedelta(seconds=1))
792+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), str(python_app_loc)], time_limit=timedelta(seconds=1))
793793

794794
# THEN
795795
# Wait until the process exits. We'll be in CANCELING state between when the timeout is reached
@@ -818,7 +818,7 @@ def test_cancel_notify(
818818
python_app_loc = (
819819
Path(__file__).parent / "support_files" / "app_20s_run_ignore_signal.py"
820820
).resolve()
821-
runner._run([sys.executable, str(python_app_loc)])
821+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), str(python_app_loc)])
822822

823823
# WHEN
824824
secs = 2 if not is_windows() else 5
@@ -971,7 +971,7 @@ def test_cancel_double_cancel_notify(
971971
python_app_loc = (
972972
Path(__file__).parent / "support_files" / "app_20s_run_ignore_signal.py"
973973
).resolve()
974-
runner._run([sys.executable, str(python_app_loc)])
974+
runner._run([sys.executable.lower().replace("pythonservice.exe", "python.exe"), str(python_app_loc)])
975975

976976
# WHEN
977977
secs = 2 if not is_windows() else 5

0 commit comments

Comments
 (0)