diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go
index d8eef014..63c56de9 100644
--- a/pkg/inference/models/http_handler.go
+++ b/pkg/inference/models/http_handler.go
@@ -23,6 +23,28 @@ import (
     "github.com/docker/model-runner/pkg/middleware"
 )

+// parseBoolQueryParam parses a boolean query parameter from the request.
+// Returns the parsed value, or false if the parameter is absent or unparseable
+// (logging a warning in the latter case). Treats presence of the key with an
+// empty value (e.g., `?force`) as true.
+func parseBoolQueryParam(r *http.Request, log logging.Logger, name string) bool {
+    q := r.URL.Query()
+    if !q.Has(name) {
+        return false
+    }
+    valStr := q.Get(name)
+    // Treat presence of key with empty value as true (e.g., `?force`)
+    if valStr == "" {
+        return true
+    }
+    val, err := strconv.ParseBool(valStr)
+    if err != nil {
+        log.Warn("error while parsing query parameter", "param", name, "value", valStr, "error", err)
+        return false
+    }
+    return val
+}
+
 // HTTPHandler manages inference model pulls and storage.
 type HTTPHandler struct {
     // log is the associated logger.
@@ -195,16 +217,7 @@ func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) {
 }

 func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) {
-    // Parse remote query parameter
-    remote := false
-    if r.URL.Query().Has("remote") {
-        val, err := strconv.ParseBool(r.URL.Query().Get("remote"))
-        if err != nil {
-            h.log.Warn("error while parsing remote query parameter", "error", err)
-        } else {
-            remote = val
-        }
-    }
+    remote := parseBoolQueryParam(r, h.log, "remote")

     var (
         apiModel *Model
@@ -309,14 +322,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request)

     modelRef := r.PathValue("name")

-    var force bool
-    if r.URL.Query().Has("force") {
-        if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
-            h.log.Warn("error while parsing force query parameter", "error", err)
-        } else {
-            force = val
-        }
-    }
+    force := parseBoolQueryParam(r, h.log, "force")

     // First try to delete without normalization (as ID), then with normalization if not found
     resp, err := h.manager.Delete(modelRef, force)
diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go
index a9f3077b..9a7195ed 100644
--- a/pkg/inference/scheduling/http_handler.go
+++ b/pkg/inference/scheduling/http_handler.go
@@ -23,6 +23,23 @@ import (

 type contextKey bool

+// readRequestBody reads up to maxSize bytes from the request body and writes
+// an appropriate HTTP error if reading fails. Returns (body, true) on success
+// or (nil, false) after writing the error response.
+func readRequestBody(w http.ResponseWriter, r *http.Request, maxSize int64) ([]byte, bool) {
+    body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSize))
+    if err != nil {
+        var maxBytesError *http.MaxBytesError
+        if errors.As(err, &maxBytesError) {
+            http.Error(w, "request too large", http.StatusBadRequest)
+        } else {
+            http.Error(w, "failed to read request body", http.StatusInternalServerError)
+        }
+        return nil, false
+    }
+    return body, true
+}
+
 const preloadOnlyKey contextKey = false

 // HTTPHandler handles HTTP requests for the scheduler.
@@ -132,14 +149,8 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque
     // Read the entire request body. We put some basic size constraints in place
     // to avoid DoS attacks. We do this early to avoid client write timeouts.
-    body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-    if err != nil {
-        var maxBytesError *http.MaxBytesError
-        if errors.As(err, &maxBytesError) {
-            http.Error(w, "request too large", http.StatusBadRequest)
-        } else {
-            http.Error(w, "failed to read request body", http.StatusInternalServerError)
-        }
+    body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+    if !ok {
         return
     }

@@ -338,14 +349,8 @@ func (h *HTTPHandler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {

 // Unload unloads the specified runners (backend, model) from the backend.
 // Currently, this doesn't work for runners that are handling an OpenAI request.
 func (h *HTTPHandler) Unload(w http.ResponseWriter, r *http.Request) {
-    body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-    if err != nil {
-        var maxBytesError *http.MaxBytesError
-        if errors.As(err, &maxBytesError) {
-            http.Error(w, "request too large", http.StatusBadRequest)
-        } else {
-            http.Error(w, "failed to read request body", http.StatusInternalServerError)
-        }
+    body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+    if !ok {
         return
     }

@@ -371,14 +376,8 @@ type installBackendRequest struct {

 // InstallBackend handles POST /install-backend requests.
 // It triggers on-demand installation of a deferred backend.
 func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
-    body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-    if err != nil {
-        var maxBytesError *http.MaxBytesError
-        if errors.As(err, &maxBytesError) {
-            http.Error(w, "request too large", http.StatusBadRequest)
-        } else {
-            http.Error(w, "failed to read request body", http.StatusInternalServerError)
-        }
+    body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+    if !ok {
         return
     }

@@ -404,6 +403,7 @@ func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
 func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
     // Determine the requested backend and ensure that it's valid.
     var backend inference.Backend
+    var err error
     if b := r.PathValue("backend"); b == "" {
         backend = h.scheduler.defaultBackend
     } else {
@@ -414,14 +414,8 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
         return
     }

-    body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-    if err != nil {
-        var maxBytesError *http.MaxBytesError
-        if errors.As(err, &maxBytesError) {
-            http.Error(w, "request too large", http.StatusBadRequest)
-        } else {
-            http.Error(w, "failed to read request body", http.StatusInternalServerError)
-        }
+    body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+    if !ok {
         return
     }
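For review purposes, a minimal standalone sketch of the semantics the new `parseBoolQueryParam` helper implements. The project's `logging.Logger` is swapped for the standard library logger so the sketch compiles on its own; the `/models/foo` URLs and the `main` driver are illustrative, not part of the patch.

```go
package main

import (
	"log"
	"net/http"
	"net/http/httptest"
	"strconv"
)

// parseBoolQueryParam mirrors the helper from the diff, with log.Printf
// standing in for the project's structured logger.
func parseBoolQueryParam(r *http.Request, name string) bool {
	q := r.URL.Query()
	if !q.Has(name) {
		return false
	}
	valStr := q.Get(name)
	if valStr == "" {
		return true // bare `?force` counts as true
	}
	val, err := strconv.ParseBool(valStr)
	if err != nil {
		log.Printf("error while parsing query parameter %q=%q: %v", name, valStr, err)
		return false
	}
	return val
}

func main() {
	for _, target := range []string{
		"/models/foo",              // absent -> false
		"/models/foo?remote",       // bare key -> true
		"/models/foo?remote=1",     // parseable -> true
		"/models/foo?remote=maybe", // unparseable -> warn, false
	} {
		r := httptest.NewRequest(http.MethodGet, target, nil)
		log.Printf("%s -> %v", target, parseBoolQueryParam(r, "remote"))
	}
}
```

Note the deliberate behavior change this encodes for `handleDeleteModel`: previously a bare `?force` failed `strconv.ParseBool("")` and left `force` false; the helper now treats key presence with an empty value as true.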
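Likewise, a self-contained sketch of the `readRequestBody` error path. The helper body is reproduced from the diff; the 8-byte cap and the `/unload` target are made up for the demonstration (the real handlers pass `maximumOpenAIInferenceRequestSize`). `*http.MaxBytesError` requires Go 1.19 or later.

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

// readRequestBody reads up to maxSize bytes from the request body and writes
// an appropriate HTTP error if reading fails, as in the diff.
func readRequestBody(w http.ResponseWriter, r *http.Request, maxSize int64) ([]byte, bool) {
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSize))
	if err != nil {
		var maxBytesError *http.MaxBytesError
		if errors.As(err, &maxBytesError) {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "failed to read request body", http.StatusInternalServerError)
		}
		return nil, false
	}
	return body, true
}

func main() {
	// A 10-byte body against an 8-byte cap triggers the MaxBytesError branch,
	// so the helper writes the 400 response and reports failure.
	w := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodPost, "/unload", strings.NewReader("0123456789"))
	if _, ok := readRequestBody(w, r, 8); !ok {
		fmt.Printf("status=%d body=%q\n", w.Code, strings.TrimSpace(w.Body.String()))
		// status=400 body="request too large"
	}
}
```

One design note: because the helper writes the error response itself, callers only check the boolean and return, which is what lets the four call sites in the diff collapse to two lines each.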