Skip to content

Commit 6c773fe

Browse files
committed
fix a failure in 3.14t test run , closes #9
1 parent 7ef85b0 commit 6c773fe

File tree

4 files changed

+61
-21
lines changed

4 files changed

+61
-21
lines changed

justfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@ set shell := ["bash", "-euo", "pipefail", "-c"]
22

33
tls-test-certs outdir="tests/fixtures/tls":
44
./scripts/generate-test-tls-certs.sh {{outdir}}
5+
6+
test:
7+
uv run python -m unittest discover -s tests

src/stream_transport.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,19 @@ impl WorkerThread {
734734
}
735735

736736
impl StreamTransportCore {
737+
fn close_extra_socket_with_py(&self, py: Python<'_>) {
738+
let socket = self
739+
.state
740+
.lock()
741+
.expect("poisoned transport state")
742+
.extra
743+
.get("socket")
744+
.map(|value| value.clone_ref(py));
745+
if let Some(socket) = socket {
746+
let _ = socket.bind(py).call_method0("close");
747+
}
748+
}
749+
737750
fn register_worker(&self, worker: WorkerThread) {
738751
self.workers
739752
.lock()
@@ -1089,6 +1102,8 @@ impl StreamTransportCore {
10891102
self.call_protocol_method1(py, &callback, &context, context_needs_run, arg)?;
10901103
}
10911104

1105+
self.close_extra_socket_with_py(py);
1106+
10921107
if let Some(server) = server.and_then(|weak| weak.upgrade()) {
10931108
server.connection_lost();
10941109
}
@@ -1133,7 +1148,8 @@ impl StreamTransportCore {
11331148
}
11341149
}
11351150

1136-
fn detach_underlying_stream(&self) {
1151+
fn detach_underlying_stream(&self, py: Python<'_>) {
1152+
self.close_extra_socket_with_py(py);
11371153
let mut state = self.state.lock().expect("poisoned transport state");
11381154
state.detached = true;
11391155
state.closing = true;
@@ -1387,7 +1403,7 @@ impl StreamTransportCore {
13871403
let fd = fd_ops::dup_raw_fd(socket.bind(py).call_method0("fileno")?.extract()?)
13881404
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
13891405

1390-
self.detach_underlying_stream();
1406+
self.detach_underlying_stream(py);
13911407
let _ = self.writer_tx.send(WriterCommand::Stop);
13921408
if let Some(fd) = self.runtime_socket_fd() {
13931409
let _ = self
@@ -1529,6 +1545,15 @@ impl StreamTransportCore {
15291545
}
15301546

15311547
impl ServerCore {
1548+
fn close_python_sockets(&self) {
1549+
let _ = Python::try_attach(|py| -> PyResult<()> {
1550+
for socket in &self.sockets {
1551+
let _ = socket.bind(py).call_method0("close");
1552+
}
1553+
Ok(())
1554+
});
1555+
}
1556+
15321557
fn report_error(&self, err: PyErr, message: &str) {
15331558
let _ = Python::try_attach(|py| -> PyResult<()> {
15341559
let context = PyDict::new(py);
@@ -1582,6 +1607,8 @@ impl ServerCore {
15821607
state.listeners.clear();
15831608
}
15841609

1610+
self.close_python_sockets();
1611+
15851612
for task in self
15861613
.accept_tasks
15871614
.lock()
@@ -1997,16 +2024,7 @@ fn socket_from_owned_raw(fd: fd_ops::RawFd) -> PyResult<Socket> {
19972024
}
19982025

19992026
fn detached_socket_handle(py: Python<'_>, socket_obj: &Py<PyAny>) -> PyResult<fd_ops::RawFd> {
2000-
#[cfg(windows)]
2001-
{
2002-
return socket_obj.call_method0(py, "fileno")?.extract(py);
2003-
}
2004-
2005-
#[cfg(not(windows))]
2006-
{
2007-
let dup = socket_obj.call_method0(py, "dup")?;
2008-
dup.call_method0(py, "detach")?.extract(py)
2009-
}
2027+
socket_obj.call_method0(py, "detach")?.extract(py)
20102028
}
20112029

20122030
fn tcp_family(stream: &StdTcpStream) -> c_int {

tests/test_compat.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def fake_sock_connect(self, sock, address):
8080
self.assertEqual(rsloop.run(main()), [42001, 42003, 42002, 42004])
8181

8282
def test_create_connection_happy_eyeballs_staggers_attempts(self) -> None:
83-
async def main() -> float:
83+
async def main() -> tuple[float, int]:
8484
loop = asyncio.get_running_loop()
8585
done = loop.create_future()
8686

@@ -139,12 +139,16 @@ async def fake_sock_connect(self, sock, address):
139139
)
140140
await asyncio.wait_for(done, 1.0)
141141
transport.close()
142-
return time.monotonic() - started
142+
await asyncio.sleep(0)
143+
socket_fileno = transport.get_extra_info("socket").fileno()
144+
return time.monotonic() - started, socket_fileno
143145
finally:
144146
server.close()
145147
await server.wait_closed()
146148

147-
self.assertLess(rsloop.run(main()), 0.15)
149+
elapsed, socket_fileno = rsloop.run(main())
150+
self.assertLess(elapsed, 0.15)
151+
self.assertEqual(socket_fileno, -1)
148152

149153
def test_shutdown_default_executor_timeout_warns_and_falls_back_to_nowait(
150154
self,
@@ -161,6 +165,7 @@ def shutdown(self, wait):
161165
time.sleep(0.2)
162166

163167
loop.set_default_executor(DummyExecutor())
168+
164169
def capture_warning(message, category=None, stacklevel=1, source=None):
165170
messages.append(str(message))
166171
return None

tests/test_tls.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@ def test_create_default_context_marks_default_verify_paths(self) -> None:
5050
self.assertTrue(context.__dict__.get("_rsloop_use_default_verify_paths"))
5151

5252
def test_create_connection_and_server_tls_round_trip(self) -> None:
53-
async def main() -> str:
53+
async def main() -> tuple[str, tuple[int, ...]]:
5454
loop = asyncio.get_running_loop()
5555
done: asyncio.Future[str] = loop.create_future()
56+
result = ""
57+
server_fds: tuple[int, ...] = ()
5658

5759
class ServerProtocol(asyncio.Protocol):
5860
def connection_made(self, transport: asyncio.BaseTransport) -> None:
@@ -104,17 +106,24 @@ def connection_lost(self, exc: Exception | None) -> None:
104106
await asyncio.wait_for(client_protocol.result, 5.0), "TLS-OK"
105107
)
106108
self.assertEqual(await asyncio.wait_for(done, 5.0), "server-closed")
107-
return "ok"
109+
result = "ok"
108110
finally:
109111
server.close()
110112
await server.wait_closed()
113+
server_fds = tuple(sock.fileno() for sock in server.sockets)
114+
115+
return result, server_fds
111116

112-
self.assertEqual(rsloop.run(main()), "ok")
117+
result, server_fds = rsloop.run(main())
118+
self.assertEqual(result, "ok")
119+
self.assertEqual(server_fds, (-1,))
113120

114121
@unittest.skipIf(os.name == "nt", "Unix sockets are Unix-only")
115122
def test_create_unix_connection_and_server_tls_round_trip(self) -> None:
116-
async def main() -> str:
123+
async def main() -> tuple[str, tuple[int, ...]]:
117124
loop = asyncio.get_running_loop()
125+
result = ""
126+
server_fds: tuple[int, ...] = ()
118127

119128
class ServerProtocol(asyncio.Protocol):
120129
def connection_made(self, transport: asyncio.BaseTransport) -> None:
@@ -155,12 +164,17 @@ def connection_lost(self, exc: Exception | None) -> None:
155164
ssl=client_ctx,
156165
server_hostname="localhost",
157166
)
158-
return await asyncio.wait_for(client_protocol.done, 5.0)
167+
result = await asyncio.wait_for(client_protocol.done, 5.0)
159168
finally:
160169
server.close()
161170
await server.wait_closed()
171+
server_fds = tuple(sock.fileno() for sock in server.sockets)
172+
173+
return result, server_fds
162174

163-
self.assertEqual(rsloop.run(main()), "unix:tls")
175+
result, server_fds = rsloop.run(main())
176+
self.assertEqual(result, "unix:tls")
177+
self.assertEqual(server_fds, (-1,))
164178

165179
def test_connect_accepted_socket_tls_round_trip(self) -> None:
166180
async def main() -> str:

0 commit comments

Comments
 (0)