From d55e2f8dfea49661e06934a7400cad1d80b5caca Mon Sep 17 00:00:00 2001 From: invist <35263248+c-f@users.noreply.github.com> Date: Wed, 13 Oct 2021 09:59:36 +0200 Subject: [PATCH] smalle upload refactoring --- pkg/httpserver/authlayer.go | 2 +- pkg/httpserver/httpserver.go | 22 ++++++++- pkg/httpserver/loglayer.go | 56 ----------------------- pkg/httpserver/uploadlayer.go | 85 ++++++++++++++++++++++++++++++++--- 4 files changed, 101 insertions(+), 64 deletions(-) diff --git a/pkg/httpserver/authlayer.go b/pkg/httpserver/authlayer.go index f2eff4b..297d863 100644 --- a/pkg/httpserver/authlayer.go +++ b/pkg/httpserver/authlayer.go @@ -5,7 +5,7 @@ import ( "net/http" ) -func (t *HTTPServer) basicauthlayer(handler http.Handler) http.HandlerFunc { +func (t *HTTPServer) basicauthlayer(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, pass, ok := r.BasicAuth() if !ok || user != t.options.BasicAuthUsername || pass != t.options.BasicAuthPassword { diff --git a/pkg/httpserver/httpserver.go b/pkg/httpserver/httpserver.go index 72da466..439de3c 100644 --- a/pkg/httpserver/httpserver.go +++ b/pkg/httpserver/httpserver.go @@ -32,6 +32,9 @@ type HTTPServer struct { layers http.Handler } +// LayerHandler is the interface of all layer funcs +type Middleware func(http.Handler) http.Handler + // New http server instance with options func New(options *Options) (*HTTPServer, error) { var h HTTPServer @@ -50,10 +53,25 @@ func New(options *Options) (*HTTPServer, error) { if options.Sandbox { dir = SandboxFileSystem{fs: http.Dir(options.Folder), RootFolder: options.Folder} } - h.layers = h.loglayer(http.FileServer(dir)) + + httpHandler := http.FileServer(dir) + addHandler := func(newHandler Middleware) { + httpHandler = newHandler(httpHandler) + } + + // middleware + if options.EnableUpload { + addHandler(h.uploadlayer) + } + if options.BasicAuthUsername != "" || options.BasicAuthPassword != "" { - h.layers = h.loglayer(h.basicauthlayer(http.FileServer(dir))) + addHandler(h.basicauthlayer) } + + httpHandler = h.loglayer(httpHandler) + + // add handler + h.layers = httpHandler h.options = options return &h, nil diff --git a/pkg/httpserver/loglayer.go b/pkg/httpserver/loglayer.go index 0e1a87a..468fb6a 100644 --- a/pkg/httpserver/loglayer.go +++ b/pkg/httpserver/loglayer.go @@ -2,11 +2,8 @@ package httpserver import ( "bytes" - "io/ioutil" "net/http" "net/http/httputil" - "path" - "path/filepath" "github.com/projectdiscovery/gologger" ) @@ -23,59 +20,6 @@ func (t *HTTPServer) loglayer(handler http.Handler) http.Handler { lrw := newLoggingResponseWriter(w) handler.ServeHTTP(lrw, r) - // Handles file write if enabled - if EnableUpload && r.Method == http.MethodPut { - // sandbox - calcolate absolute path - if t.options.Sandbox { - absPath, err := filepath.Abs(filepath.Join(t.options.Folder, r.URL.Path)) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusBadRequest) - return - } - // check if the path is within the configured folder - pattern := t.options.Folder + string(filepath.Separator) + "*" - matched, err := filepath.Match(pattern, absPath) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusBadRequest) - return - } else if !matched { - gologger.Print().Msg("pointing to unauthorized directory") - w.WriteHeader(http.StatusBadRequest) - return - } - } - - var ( - data []byte - err error - ) - if t.options.Sandbox { - maxFileSize := toMb(t.options.MaxFileSize) - // check header content length - if r.ContentLength > maxFileSize { - gologger.Print().Msg("request too large") - return - } - // body max length - r.Body = http.MaxBytesReader(w, r.Body, maxFileSize) - } - - data, err = ioutil.ReadAll(r.Body) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - err = handleUpload(t.options.Folder, path.Base(r.URL.Path), data) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - } - if EnableVerbose { headers := new(bytes.Buffer) lrw.Header().Write(headers) //nolint diff --git a/pkg/httpserver/uploadlayer.go b/pkg/httpserver/uploadlayer.go index 928ac60..9f4821f 100644 --- a/pkg/httpserver/uploadlayer.go +++ b/pkg/httpserver/uploadlayer.go @@ -3,21 +3,96 @@ package httpserver import ( "errors" "io/ioutil" + "net/http" + "os" + "path" "path/filepath" "strings" + + "github.com/projectdiscovery/gologger" ) +// uploadlayer handles PUT requests and save the file to disk +func (t *HTTPServer) uploadlayer(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handles file write if enabled + if EnableUpload && r.Method == http.MethodPut { + // sandbox - calcolate absolute path + if t.options.Sandbox { + absPath, err := filepath.Abs(filepath.Join(t.options.Folder, r.URL.Path)) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } + // check if the path is within the configured folder + pattern := t.options.Folder + string(filepath.Separator) + "*" + matched, err := filepath.Match(pattern, absPath) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } else if !matched { + gologger.Print().Msg("pointing to unauthorized directory") + w.WriteHeader(http.StatusBadRequest) + return + } + } + + var ( + data []byte + err error + ) + if t.options.Sandbox { + maxFileSize := toMb(t.options.MaxFileSize) + // check header content length + if r.ContentLength > maxFileSize { + gologger.Print().Msg("request too large") + return + } + // body max length + r.Body = http.MaxBytesReader(w, r.Body, maxFileSize) + } + + data, err = ioutil.ReadAll(r.Body) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + sanitizedPath := filepath.FromSlash(path.Clean("/" + strings.Trim(r.URL.Path, "/"))) + + err = handleUpload(t.options.Folder, sanitizedPath, data) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusInternalServerError) + return + } else { + w.WriteHeader(http.StatusCreated) + return + } + } + + handler.ServeHTTP(w, r) + }) +} + func handleUpload(base, file string, data []byte) error { // rejects all paths containing a non exhaustive list of invalid characters - This is only a best effort as the tool is meant for development if strings.ContainsAny(file, "\\`\"':") { return errors.New("invalid character") } - // allow upload only in subfolders - rel, err := filepath.Rel(base, file) - if rel == "" || err != nil { - return err + untrustedPath := filepath.Clean(filepath.Join(base, file)) + if !strings.HasPrefix(untrustedPath, filepath.Clean(base)) { + return errors.New("invalid path") + } + trustedPath := untrustedPath + + if _, err := os.Stat(path.Dir(trustedPath)); os.IsNotExist(err) { + return errors.New("invalid path") } - return ioutil.WriteFile(file, data, 0655) + return ioutil.WriteFile(trustedPath, data, 0655) }