diff --git a/src/Connection.cpp b/src/Connection.cpp index 010262e..39aa361 100644 --- a/src/Connection.cpp +++ b/src/Connection.cpp @@ -97,7 +97,7 @@ size_t Connection::Write(const uint8_t *data, size_t length, bool doPush, bool c // Try to send all the data const bool push = doPush || closeAfterSending; - u8_t flag = NETCONN_COPY | (push ? NETCONN_MORE : 0); + u8_t flag = NETCONN_COPY | (push ? 0 : NETCONN_MORE); size_t total = 0; size_t written = 0; @@ -107,18 +107,23 @@ size_t Connection::Write(const uint8_t *data, size_t length, bool doPush, bool c written = 0; rc = netconn_write_partly(conn, data + total, length - total, flag, &written); + // Note: ERR_MEM is not handled here because lwIP's netconn layer retries + // internally and never propagates ERR_MEM to the application layer. if (rc != ERR_OK && rc != ERR_WOULDBLOCK) { break; } + if (rc == ERR_WOULDBLOCK && written == 0) { + break; // send buffer full and no progress after timeout, avoid spinning + } } if (rc != ERR_OK) { - if (rc == ERR_RST || rc == ERR_CLSD) + if (rc == ERR_RST || rc == ERR_CLSD || rc == ERR_ABRT || rc == ERR_CONN) { SetState(ConnState::otherEndClosed); } - else + else if (rc != ERR_WOULDBLOCK) { // We failed to write the data. See above for possible mitigations. For now we just terminate the connection. debugPrintfAlways("Write fail len=%u err=%d\n", total, (int)rc); @@ -133,7 +138,7 @@ size_t Connection::Write(const uint8_t *data, size_t length, bool doPush, bool c Close(); } - return length; + return total; } size_t Connection::CanWrite() const @@ -162,9 +167,9 @@ void Connection::Poll() rc = netconn_recv_tcp_pbuf_flags(conn, &data, NETCONN_NOAUTORCVD); } - if (rc != ERR_WOULDBLOCK) + if (rc != ERR_WOULDBLOCK && rc != ERR_TIMEOUT) { - if (rc == ERR_RST || rc == ERR_CLSD || rc == ERR_CONN) + if (rc == ERR_RST || rc == ERR_CLSD || rc == ERR_CONN || rc == ERR_ABRT) { // Pend setting the state to other end closed if there is data to be read. // Otherwise, set it immediately. This is to avoid a case when a socket in RRF @@ -194,16 +199,18 @@ void Connection::Poll() } else if (state == ConnState::closePending) { - // We're about to close this connection and we're still waiting for the remaining data to be acknowledged - if (conn->pcb.tcp && !conn->pcb.tcp->unacked) + // The other end may have closed the connection with RST, which causes lwIP + // to free the PCB. Detect this and close immediately instead of waiting for + // the acknowledgement timer to expire. + if (!conn->pcb.tcp || (!conn->pcb.tcp->unsent && !conn->pcb.tcp->unacked)) { - // All data has been received, close this connection next time SetState(ConnState::closeReady); } else if (millis() - closeTimer >= MaxAckTime) { - // The acknowledgement timer has expired, abort this connection - Terminate(false); + // The acknowledgement timer has expired. The close was already initiated + // by RRF, so go straight to free rather than aborted to avoid a round-trip. + Terminate(true); } } else { } @@ -237,7 +244,10 @@ void Connection::Close() } FreePbuf(); SetState(ConnState::free); - listener->Notify(); + if (listener) + { + listener->Notify(); + } break; case ConnState::closePending: // we already asked to close @@ -303,7 +313,10 @@ void Connection::Terminate(bool external) } FreePbuf(); SetState((external) ? ConnState::free : ConnState::aborted); - listener->Notify(); + if (external && listener) + { + listener->Notify(); + } } void Connection::Accept(Listener *listener, struct netconn* conn, uint8_t protocol) @@ -422,7 +435,8 @@ void Connection::Report() { connectedSockets |= (1 << i); } - else if (Connection::Get(i).GetState() == ConnState::otherEndClosed) + else if (Connection::Get(i).GetState() == ConnState::otherEndClosed + || Connection::Get(i).GetState() == ConnState::aborted) { otherEndClosedSockets |= (1 << i); } diff --git a/src/Listener.cpp b/src/Listener.cpp index a63728d..01fb0d8 100644 --- a/src/Listener.cpp +++ b/src/Listener.cpp @@ -127,13 +127,14 @@ bool Listener::Start(uint16_t port, uint32_t ip, int protocol, int maxConns) void Listener::Stop() { - netconn_close(conn); - netconn_delete(conn); + struct netconn *savedConn = conn; + netconn_close(savedConn); + netconn_delete(savedConn); for (int i = 0; i < MaxConnections; i++) { Listener *listener = listeners[i]; - if (listener && listener->conn == conn) + if (listener && listener->conn == savedConn) { delete listener; listeners[i] = nullptr; @@ -234,6 +235,7 @@ void Listener::Notify() if (listener->protocol == protocolFtpData) { debugPrintf("accept conn, stop listen on port %u\n", listener->port); + c->listener = nullptr; // clear before Stop() deletes the listener listener->Stop(); // don't listen for further connections } }