diff --git a/options.go b/options.go index 6bee0b5..d01236c 100644 --- a/options.go +++ b/options.go @@ -210,7 +210,12 @@ func Body(body io.Reader) Option { } } -// Data sets raw string into the request body. +// Data sets data of request body. It also deduces Content-Type based on +// input data types: +// +// 1. auto deduce by [http.DetectContentType]: io.Reader, []byte +// 2. "application/json": struct, slice(except []byte), and map +// 3. "text/plain": others func Data(data any) Option { return func(opts *Options) { opts.Data = data diff --git a/request.go b/request.go index 6ed6820..80867c6 100644 --- a/request.go +++ b/request.go @@ -10,6 +10,7 @@ import ( "io" "mime/multipart" "net/http" + "reflect" "github.com/Wenchy/requests/internal/auth" ) @@ -82,14 +83,16 @@ func request(c *Client, method, url string, opts *Options) (*Response, error) { func requestData(c *Client, method, url string, opts *Options) (*Response, error) { body := bytes.NewBuffer(nil) if opts.Data != nil { - d := fmt.Sprintf("%v", opts.Data) - _, err := body.WriteString(d) + contentType, bytes, err := deduceContentTypeAndBody(opts.Data) + if err != nil { + return nil, err + } + _, err = body.Write(bytes) if err != nil { return nil, err } + opts.Headers.Set("Content-Type", contentType) } - // TODO: judge content type - // opts.Headers["Content-Type"] = "application/x-www-form-urlencoded" opts.Body = body return c.request(method, url, opts, body.Bytes()) } @@ -105,7 +108,7 @@ func requestForm(c *Client, method, url string, opts *Options) (*Response, error return nil, err } } - opts.Headers.Set("Content-Type", "application/x-www-form-urlencoded") + opts.Headers.Set("Content-Type", formContentType) opts.Body = body return c.request(method, url, opts, body.Bytes()) } @@ -123,7 +126,7 @@ func requestJSON(c *Client, method, url string, opts *Options) (*Response, error return nil, err } } - opts.Headers.Set("Content-Type", "application/json") + opts.Headers.Set("Content-Type", jsonContentType) opts.Body = body return c.request(method, url, opts, body.Bytes()) } @@ -171,3 +174,30 @@ var dispatchers map[bodyType]dispatcher = map[bodyType]dispatcher{ bodyTypeJSON: requestJSON, bodyTypeFiles: requestFiles, } + +var ( + plainTextType = "text/plain; charset=utf-8" + jsonContentType = "application/json" + formContentType = "application/x-www-form-urlencoded" +) + +// deduceContentTypeAndBody parses content type and request body from request data +func deduceContentTypeAndBody(data any) (string, []byte, error) { + if reader, ok := data.(io.Reader); ok { + body, err := io.ReadAll(reader) + return http.DetectContentType(body), body, err + } + bodyValue := reflect.Indirect(reflect.ValueOf(data)) + switch bodyValue.Kind() { + case reflect.Struct, reflect.Map, reflect.Slice: + // check slice here to differentiate between any slice vs byte slice + if body, ok := data.([]byte); ok { + return http.DetectContentType(body), body, nil + } else { + body, err := json.Marshal(data) + return jsonContentType, body, err + } + default: + return plainTextType, fmt.Appendf(nil, "%v", bodyValue.Interface()), nil + } +} diff --git a/request_test.go b/request_test.go index 5d2d26b..7ea69cb 100644 --- a/request_test.go +++ b/request_test.go @@ -1,6 +1,7 @@ package requests import ( + "bytes" "context" "crypto/md5" "encoding/hex" @@ -776,3 +777,147 @@ func TestInterceptors(t *testing.T) { }) } } + +func toPtr[T any](v T) *T { + return &v +} + +func Test_deduceContentTypeAndBody(t *testing.T) { + type mystruct struct { + A int + B string + } + tests := []struct { + name string + body any + want string + want2 []byte + }{ + { + name: "int", + body: 123, + want: plainTextType, + want2: []byte("123"), + }, + { + name: "*int", + body: toPtr(123), + want: plainTextType, + want2: []byte("123"), + }, + { + name: "string", + body: "abc", + want: plainTextType, + want2: []byte("abc"), + }, + { + name: "*string", + body: toPtr("abc"), + want: plainTextType, + want2: []byte("abc"), + }, + { + name: "bytes", + body: []byte("abc"), + want: plainTextType, + want2: []byte("abc"), + }, + { + name: "struct", + body: mystruct{A: 123, B: "abc"}, + want: jsonContentType, + want2: []byte(`{"A":123,"B":"abc"}`), + }, + { + name: "*struct", + body: &mystruct{A: 123, B: "abc"}, + want: jsonContentType, + want2: []byte(`{"A":123,"B":"abc"}`), + }, + { + name: "map", + body: map[int]string{1: "a", 2: "b", 3: "c"}, + want: jsonContentType, + want2: []byte(`{"1":"a","2":"b","3":"c"}`), + }, + { + name: "[]int", + body: []int{123, 456}, + want: jsonContentType, + want2: []byte("[123,456]"), + }, + { + name: "[]*int", + body: []*int{toPtr(123), toPtr(456)}, + want: jsonContentType, + want2: []byte("[123,456]"), + }, + { + name: "[]string", + body: []string{"abc", "def"}, + want: jsonContentType, + want2: []byte(`["abc","def"]`), + }, + { + name: "[]*string", + body: []*string{toPtr("abc"), toPtr("def")}, + want: jsonContentType, + want2: []byte(`["abc","def"]`), + }, + { + name: "[]bytes", + body: [][]byte{[]byte("abc"), []byte("def")}, + want: jsonContentType, + want2: []byte(`["YWJj","ZGVm"]`), + }, + { + name: "[]struct", + body: []mystruct{ + {A: 123, B: "abc"}, + {A: 456, B: "def"}, + }, + want: jsonContentType, + want2: []byte(`[{"A":123,"B":"abc"},{"A":456,"B":"def"}]`), + }, + { + name: "[]*struct", + body: []*mystruct{ + {A: 123, B: "abc"}, + {A: 456, B: "def"}, + }, + want: jsonContentType, + want2: []byte(`[{"A":123,"B":"abc"},{"A":456,"B":"def"}]`), + }, + { + name: "[]map", + body: []map[int]string{ + {1: "a", 2: "b", 3: "c"}, + {4: "d", 5: "e", 6: "f"}, + }, + want: jsonContentType, + want2: []byte(`[{"1":"a","2":"b","3":"c"},{"4":"d","5":"e","6":"f"}]`), + }, + { + name: "io.Reader", + body: bytes.NewBuffer([]byte("abc")), + want: plainTextType, + want2: []byte("abc"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got2, gotErr := deduceContentTypeAndBody(tt.body) + if gotErr != nil { + t.Errorf("detectContentType() failed: %v", gotErr) + return + } + if got != tt.want { + t.Errorf("detectContentType() = %v, want %v", got, tt.want) + } + if string(got2) != string(tt.want2) { + t.Errorf("detectContentType() = %v, want %v", string(got2), string(tt.want2)) + } + }) + } +}