@@ -131,6 +131,15 @@ pub enum CheckSupportMethod {
131131 Head ,
132132}
133133
134+ /// Which URL should be used for subsequent range requests?
135+ pub enum RangeRequestUrlSource {
136+ /// Use the initial request URL
137+ Request ,
138+
139+ /// Use the initial response URL
140+ Response ,
141+ }
142+
134143fn error_for_status ( response : reqwest:: Response ) -> reqwest_middleware:: Result < Response > {
135144 response
136145 . error_for_status ( )
@@ -143,6 +152,7 @@ impl AsyncHttpRangeReader {
143152 client : impl Into < reqwest_middleware:: ClientWithMiddleware > ,
144153 url : reqwest:: Url ,
145154 check_method : CheckSupportMethod ,
155+ range_request_url_source : RangeRequestUrlSource ,
146156 extra_headers : HeaderMap ,
147157 ) -> Result < ( Self , HeaderMap ) , AsyncHttpRangeReaderError > {
148158 let client = client. into ( ) ;
@@ -156,15 +166,23 @@ impl AsyncHttpRangeReader {
156166 )
157167 . await ?;
158168 let response_headers = response. headers ( ) . clone ( ) ;
159- let self_ = Self :: from_tail_response ( client, response, extra_headers) . await ?;
169+ let url = match range_request_url_source {
170+ RangeRequestUrlSource :: Request => url,
171+ RangeRequestUrlSource :: Response => response. url ( ) . clone ( ) ,
172+ } ;
173+ let self_ = Self :: from_tail_response ( client, response, url, extra_headers) . await ?;
160174 Ok ( ( self_, response_headers) )
161175 }
162176 CheckSupportMethod :: Head => {
163177 let response =
164178 Self :: initial_head_request ( client. clone ( ) , url. clone ( ) , HeaderMap :: default ( ) )
165179 . await ?;
166180 let response_headers = response. headers ( ) . clone ( ) ;
167- let self_ = Self :: from_head_response ( client, response, extra_headers) . await ?;
181+ let url = match range_request_url_source {
182+ RangeRequestUrlSource :: Request => url,
183+ RangeRequestUrlSource :: Response => response. url ( ) . clone ( ) ,
184+ } ;
185+ let self_ = Self :: from_head_response ( client, response, url, extra_headers) . await ?;
168186 Ok ( ( self_, response_headers) )
169187 }
170188 }
@@ -200,6 +218,7 @@ impl AsyncHttpRangeReader {
200218 pub async fn from_tail_response (
201219 client : impl Into < reqwest_middleware:: ClientWithMiddleware > ,
202220 tail_request_response : Response ,
221+ url : Url ,
203222 extra_headers : HeaderMap ,
204223 ) -> Result < Self , AsyncHttpRangeReaderError > {
205224 let client = client. into ( ) ;
@@ -245,7 +264,7 @@ impl AsyncHttpRangeReader {
245264 let ( state_tx, state_rx) = watch:: channel ( StreamerState :: default ( ) ) ;
246265 tokio:: spawn ( run_streamer (
247266 client,
248- tail_request_response . url ( ) . clone ( ) ,
267+ url,
249268 extra_headers,
250269 Some ( ( tail_request_response, start) ) ,
251270 memory_map,
@@ -300,6 +319,7 @@ impl AsyncHttpRangeReader {
300319 pub async fn from_head_response (
301320 client : impl Into < reqwest_middleware:: ClientWithMiddleware > ,
302321 head_response : Response ,
322+ url : Url ,
303323 extra_headers : HeaderMap ,
304324 ) -> Result < Self , AsyncHttpRangeReaderError > {
305325 let client = client. into ( ) ;
@@ -345,7 +365,7 @@ impl AsyncHttpRangeReader {
345365 let ( state_tx, state_rx) = watch:: channel ( StreamerState :: default ( ) ) ;
346366 tokio:: spawn ( run_streamer (
347367 client,
348- head_response . url ( ) . clone ( ) ,
368+ url,
349369 extra_headers,
350370 None ,
351371 memory_map,
@@ -688,6 +708,7 @@ mod test {
688708 Client :: new ( ) ,
689709 server. url ( ) . join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) . unwrap ( ) ,
690710 check_method,
711+ RangeRequestUrlSource :: Response ,
691712 HeaderMap :: default ( ) ,
692713 )
693714 . await
@@ -783,6 +804,57 @@ mod test {
783804 Client :: new ( ) ,
784805 server. url ( ) . join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) . unwrap ( ) ,
785806 check_method,
807+ RangeRequestUrlSource :: Response ,
808+ HeaderMap :: default ( ) ,
809+ )
810+ . await
811+ . expect ( "bla" ) ;
812+
813+ // Also open a simple file reader
814+ let mut file = tokio:: fs:: File :: open ( path. join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) )
815+ . await
816+ . unwrap ( ) ;
817+
818+ // Read until the end and make sure that the contents matches
819+ let mut range_read = vec ! [ 0 ; 64 * 1024 ] ;
820+ let mut file_read = vec ! [ 0 ; 64 * 1024 ] ;
821+ loop {
822+ // Read with the async reader
823+ let range_read_bytes = range. read ( & mut range_read) . await . unwrap ( ) ;
824+
825+ // Read directly from the file
826+ let file_read_bytes = file
827+ . read_exact ( & mut file_read[ 0 ..range_read_bytes] )
828+ . await
829+ . unwrap ( ) ;
830+
831+ assert_eq ! ( range_read_bytes, file_read_bytes) ;
832+ assert_eq ! (
833+ range_read[ 0 ..range_read_bytes] ,
834+ file_read[ 0 ..file_read_bytes]
835+ ) ;
836+
837+ if file_read_bytes == 0 && range_read_bytes == 0 {
838+ break ;
839+ }
840+ }
841+ }
842+
843+ #[ rstest]
844+ #[ case( RangeRequestUrlSource :: Request ) ]
845+ #[ case( RangeRequestUrlSource :: Response ) ]
846+ #[ tokio:: test]
847+ async fn async_range_reader_url_source ( #[ case] url_source : RangeRequestUrlSource ) {
848+ // Spawn a static file server
849+ let path = Path :: new ( & std:: env:: var ( "CARGO_MANIFEST_DIR" ) . unwrap ( ) ) . join ( "test-data" ) ;
850+ let server = StaticDirectoryServer :: new ( & path) ;
851+
852+ // Construct an AsyncRangeReader
853+ let ( mut range, _) = AsyncHttpRangeReader :: new (
854+ Client :: new ( ) ,
855+ server. url ( ) . join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) . unwrap ( ) ,
856+ CheckSupportMethod :: Head ,
857+ url_source,
786858 HeaderMap :: default ( ) ,
787859 )
788860 . await
@@ -825,6 +897,7 @@ mod test {
825897 Client :: new ( ) ,
826898 server. url ( ) . join ( "not-found" ) . unwrap ( ) ,
827899 CheckSupportMethod :: Head ,
900+ RangeRequestUrlSource :: Response ,
828901 HeaderMap :: default ( ) ,
829902 )
830903 . await
0 commit comments