Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions handler.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package openapi

import (
"bytes"
"fmt"
"strings"

"net/http"

"sync"

"github.com/getkin/kin-openapi/openapi3filter"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -99,22 +102,34 @@ func (oapi OpenAPI) ServeHTTP(w http.ResponseWriter, req *http.Request, next cad
}
}

wrapper := &WrapperResponseWriter{ResponseWriter: w}
if err := next.ServeHTTP(wrapper, req); nil != err {
// In case we shouldn't validate responses, we're going to execute the next handler and return early (less overhead)
if (nil == route) || (nil == oapi.Check) || (nil == oapi.contentMap) {
return next.ServeHTTP(w, req)
}

// get a buffer to hold the response body
respBuf := bufPool.Get().(*bytes.Buffer)
respBuf.Reset()
defer bufPool.Put(respBuf)

shouldBuffer := func(status int, header http.Header) bool {
return true
}
rec := caddyhttp.NewResponseRecorder(w, respBuf, shouldBuffer)
if err := next.ServeHTTP(rec, req); nil != err {
return err
}

if nil != oapi.contentMap {
contentType := w.Header().Get("Content-Type")
if "" == contentType {
return nil
}
contentType = strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0]))
_, ok := oapi.contentMap[contentType]
if !ok {
return nil
}
// if ResponseRecorder was not buffered, we don't need to validate response
if !rec.Buffered() {
return nil
}

contentType := w.Header().Get("Content-Type")
contentType = strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0]))

_, ok := oapi.contentMap[contentType]
if ok {
validateReqInput := &openapi3filter.RequestValidationInput{
Request: req,
PathParams: pathParams,
Expand All @@ -126,19 +141,35 @@ func (oapi OpenAPI) ServeHTTP(w http.ResponseWriter, req *http.Request, next cad
},
}

if (nil != wrapper.Buffer) && (len(wrapper.Buffer) > 0) {
body := rec.Buffer().Bytes()

if (nil != body) && (len(body) > 0) {
validateRespInput := &openapi3filter.ResponseValidationInput{
RequestValidationInput: validateReqInput,
Status: wrapper.StatusCode,
Header: http.Header{"Content-Type": oapi.Check.ResponseBody},
Status: rec.Status(),
Header: rec.Header(),
}
validateRespInput.SetBodyBytes(wrapper.Buffer)
validateRespInput.SetBodyBytes(body)
if err := openapi3filter.ValidateResponse(req.Context(), validateRespInput); nil != err {
respErr := err.(*openapi3filter.ResponseError)
replacer.Set(OPENAPI_RESPONSE_ERROR, respErr.Error())
oapi.err(fmt.Sprintf("<< %s %s %s: %s", getIP(req), req.Method, req.RequestURI, respErr.Error()))
if oapi.LogError {
oapi.err(fmt.Sprintf("<< %s %s %s: %s", getIP(req), req.Method, req.RequestURI, respErr.Error()))
}
if !oapi.FallThrough {
return err
}
}
}
}

rec.WriteResponse()

return nil
}

var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
16 changes: 0 additions & 16 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,6 @@ import (
"github.com/open-policy-agent/opa/rego"
)

type WrapperResponseWriter struct {
http.ResponseWriter
StatusCode int
Buffer []byte
}

func (w *WrapperResponseWriter) WriteHeader(sc int) {
w.ResponseWriter.WriteHeader(sc)
w.StatusCode = sc
}

func (w *WrapperResponseWriter) Write(buff []byte) (int, error) {
w.Buffer = append(w.Buffer[:], buff[:]...)
return w.ResponseWriter.Write(buff)
}

func getIP(req *http.Request) string {
ip := req.Header.Get("X-Forwarded-For")
if "" != ip {
Expand Down