@@ -120,6 +120,24 @@ inline bool operator == (const std::pmr::string& a, const std::string& b)
120120 return std::string_view (a) == std::string_view (b);
121121}
122122
123+ template <typename S, typename proxy_session>
124+ concept supports_stream_rate_limit = requires (S s, proxy_session p)
125+ {
126+ p.stream_rate_limit (s, 1 );
127+ };
128+
129+ template <typename S, typename proxy_session>
130+ concept supports_stream_expires_after = requires (S s, proxy_session p)
131+ {
132+ p.stream_expires_after (s, std::chrono::seconds (0 ));
133+ };
134+
135+ template <typename S>
136+ concept supports_shutdown = requires (S s)
137+ {
138+ s.shutdown (boost::asio::socket_base::shutdown_receive);
139+ };
140+
123141namespace proxy {
124142
125143 namespace net = boost::asio;
@@ -318,7 +336,7 @@ R"x*x*x(<html>
318336 0x16 , // ssl
319337 };
320338
321- inline const std::map<std::string , std::string > global_mimes =
339+ inline const std::map<std::string_view , std::string_view > global_mimes =
322340 {
323341 { " .html" , " text/html; charset=utf-8" },
324342 { " .htm" , " text/html; charset=utf-8" },
@@ -880,14 +898,11 @@ R"x*x*x(<html>
880898 if (!ret)
881899 co_return ;
882900
883- size_t l2r_transferred = 0 ;
884- size_t r2l_transferred = 0 ;
885-
886- co_await (
887- transfer (m_local_socket, m_remote_socket, l2r_transferred)
901+ auto [l2r_transferred, r2l_transferred] = co_await (
902+ transfer (m_local_socket, m_remote_socket)
888903 &&
889- transfer (m_remote_socket, m_local_socket, r2l_transferred )
890- );
904+ transfer (m_remote_socket, m_local_socket)
905+ );
891906
892907 XLOG_DBG << " connection id: "
893908 << m_connection_id
@@ -1819,13 +1834,10 @@ R"x*x*x(<html>
18191834 // 发起数据传输协程.
18201835 if (command == SOCKS_CMD_CONNECT)
18211836 {
1822- size_t l2r_transferred = 0 ;
1823- size_t r2l_transferred = 0 ;
1824-
1825- co_await (
1826- transfer (m_local_socket, m_remote_socket, l2r_transferred)
1837+ auto [l2r_transferred, r2l_transferred ] = co_await (
1838+ transfer (m_local_socket, m_remote_socket)
18271839 &&
1828- transfer (m_remote_socket, m_local_socket, r2l_transferred )
1840+ transfer (m_remote_socket, m_local_socket)
18291841 );
18301842
18311843 XLOG_DBG << " connection id: "
@@ -2265,14 +2277,11 @@ R"x*x*x(<html>
22652277 if (error_code != SOCKS4_REQUEST_GRANTED)
22662278 co_return ;
22672279
2268- size_t l2r_transferred = 0 ;
2269- size_t r2l_transferred = 0 ;
2270-
2271- co_await (
2272- transfer (m_local_socket, m_remote_socket, l2r_transferred)
2280+ auto [l2r_transferred , r2l_transferred ]= co_await (
2281+ transfer (m_local_socket, m_remote_socket)
22732282 &&
2274- transfer (m_remote_socket, m_local_socket, r2l_transferred )
2275- );
2283+ transfer (m_remote_socket, m_local_socket)
2284+ );
22762285
22772286 XLOG_DBG << " connection id: "
22782287 << m_connection_id
@@ -2293,14 +2302,24 @@ R"x*x*x(<html>
22932302 net::awaitable<bool > socks_auth ();
22942303
22952304 template <typename S1, typename S2>
2296- net::awaitable<void > transfer (S1& from, S2& to, size_t & bytes_transferred )
2305+ net::awaitable<std::streamsize > transfer (S1& from, S2& to, std::streamsize bytes_to_be_sent = - 1 )
22972306 {
2298- bytes_transferred = 0 ;
22992307
2300- stream_rate_limit (from, m_option.tcp_rate_limit_ );
2301- stream_rate_limit (to, m_option.tcp_rate_limit_ );
2308+ std::size_t bytes_transferred = 0 ;
23022309
2303- stream_expires_after (from, std::chrono::seconds (m_option.tcp_timeout_ ));
2310+ if constexpr (supports_stream_rate_limit<S1, proxy_session>)
2311+ {
2312+ stream_rate_limit (from, m_option.tcp_rate_limit_ );
2313+ }
2314+ if constexpr (supports_stream_rate_limit<S2, proxy_session>)
2315+ {
2316+ stream_rate_limit (to, m_option.tcp_rate_limit_ );
2317+ }
2318+
2319+ if constexpr (supports_stream_expires_after<S1, proxy_session>)
2320+ {
2321+ stream_expires_after (from, std::chrono::seconds (m_option.tcp_timeout_ ));
2322+ }
23042323
23052324 constexpr auto buf_size = 512 * 1024 ;
23062325
@@ -2312,49 +2331,73 @@ R"x*x*x(<html>
23122331 auto secondary_buf = buf1.get ();
23132332
23142333 // 首先邓读取第一个数据作为预备, 以用于后面的交替读写逻辑.
2334+ auto read_size = (bytes_to_be_sent == -1 ) ? buf_size : std::min<std::streamsize>(bytes_to_be_sent, buf_size);
23152335 boost::system::error_code ec;
2316- auto bytes = co_await from.async_read_some (net::buffer (primary_buf, buf_size), net_awaitable[ec]);
2336+ auto bytes = co_await from.async_read_some (net::buffer (primary_buf, read_size), net_awaitable[ec]);
2337+ if (bytes_to_be_sent != -1 ) bytes_to_be_sent -= bytes;
23172338 if (ec || m_abort)
23182339 {
23192340 if (bytes > 0 )
2320- co_await net::async_write (to,
2341+ bytes_transferred += co_await net::async_write (to,
23212342 net::buffer (primary_buf, bytes), net_awaitable[ec]);
23222343
23232344 to.shutdown (net::socket_base::shutdown_send, ec);
2324- co_return ;
2345+ co_return bytes_transferred ;
23252346 }
23262347
23272348 for (; !m_abort;)
23282349 {
2329- stream_expires_after (to, std::chrono::seconds (m_option.tcp_timeout_ ));
2330- stream_expires_after (from, std::chrono::seconds (m_option.tcp_timeout_ ));
2350+ if constexpr (supports_stream_expires_after<S2, proxy_session>)
2351+ {
2352+ stream_expires_after (to, std::chrono::seconds (m_option.tcp_timeout_ ));
2353+ }
23312354
2332- // 并发读写.
2333- auto [write_bytes, read_bytes] =
2334- co_await (
2335- net::async_write (to,
2336- net::buffer (primary_buf, bytes), net_awaitable[ec])
2337- &&
2338- from.async_read_some (
2339- net::buffer (secondary_buf, buf_size), net_awaitable[ec])
2340- );
2355+ if constexpr (supports_stream_expires_after<S1, proxy_session>)
2356+ {
2357+ stream_expires_after (from, std::chrono::seconds (m_option.tcp_timeout_ ));
2358+ }
23412359
2342- // 交换主从缓冲区.
2343- std::swap (primary_buf, secondary_buf);
2360+ read_size = (bytes_to_be_sent == -1 ) ? buf_size : std::min<std::streamsize>(bytes_to_be_sent, buf_size);
23442361
2345- bytes = read_bytes;
2346- bytes_transferred += bytes;
2362+ if (read_size > 0 )
2363+ {
2364+ // 并发读写.
2365+ auto [write_bytes, read_bytes] =
2366+ co_await (
2367+ net::async_write (to,
2368+ net::buffer (primary_buf, bytes), net_awaitable[ec])
2369+ &&
2370+ from.async_read_some (
2371+ net::buffer (secondary_buf, read_size), net_awaitable[ec])
2372+ );
2373+
2374+ // 交换主从缓冲区.
2375+ std::swap (primary_buf, secondary_buf);
2376+
2377+ bytes = read_bytes;
2378+ if (bytes_to_be_sent != -1 ) bytes_to_be_sent -= bytes;
2379+ bytes_transferred += write_bytes;
2380+ }
2381+ else
2382+ {
2383+ bytes_transferred += co_await net::async_write (to,
2384+ net::buffer (primary_buf, bytes), net_awaitable[ec]);
2385+ co_return bytes_transferred;
2386+ }
23472387
23482388 // 如果 async_write 失败, 则也无需要再读取数据, 如果
23492389 // async_read_some 失败, 则也无数据可用于写, 所以无论哪一种情况
23502390 // 都可以直接退出.
23512391 if (ec)
23522392 {
2353- to.shutdown (net::socket_base::shutdown_send, ec);
2354- from.shutdown (net::socket_base::shutdown_receive, ec);
2355- co_return ;
2393+ if constexpr (supports_shutdown<S2>)
2394+ to.shutdown (net::socket_base::shutdown_send, ec);
2395+ if constexpr (supports_shutdown<S1>)
2396+ from.shutdown (net::socket_base::shutdown_receive, ec);
2397+ co_return bytes_transferred;
23562398 }
23572399 }
2400+ co_return bytes_transferred;
23582401 }
23592402
23602403 template <typename Stream, typename Endpoint>
0 commit comments