Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/httpserver/authlayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"net/http"
)

func (t *HTTPServer) basicauthlayer(handler http.Handler) http.HandlerFunc {
func (t *HTTPServer) basicauthlayer(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, pass, ok := r.BasicAuth()
if !ok || user != t.options.BasicAuthUsername || pass != t.options.BasicAuthPassword {
Expand Down
22 changes: 20 additions & 2 deletions pkg/httpserver/httpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type HTTPServer struct {
layers http.Handler
}

// LayerHandler is the interface of all layer funcs
type Middleware func(http.Handler) http.Handler

// New http server instance with options
func New(options *Options) (*HTTPServer, error) {
var h HTTPServer
Expand All @@ -53,10 +56,25 @@ func New(options *Options) (*HTTPServer, error) {
if options.Sandbox {
dir = SandboxFileSystem{fs: http.Dir(options.Folder), RootFolder: options.Folder}
}
h.layers = h.loglayer(http.FileServer(dir))

httpHandler := http.FileServer(dir)
addHandler := func(newHandler Middleware) {
httpHandler = newHandler(httpHandler)
}

// middleware
if options.EnableUpload {
addHandler(h.uploadlayer)
}

if options.BasicAuthUsername != "" || options.BasicAuthPassword != "" {
h.layers = h.loglayer(h.basicauthlayer(http.FileServer(dir)))
addHandler(h.basicauthlayer)
}

httpHandler = h.loglayer(httpHandler)

// add handler
h.layers = httpHandler
h.options = options

return &h, nil
Expand Down
57 changes: 0 additions & 57 deletions pkg/httpserver/loglayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@ package httpserver

import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httputil"
"path"
"path/filepath"

"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/simplehttpserver/pkg/unit"
)

// Convenience globals
Expand All @@ -33,59 +29,6 @@ func (t *HTTPServer) loglayer(handler http.Handler) http.Handler {
lrw := newLoggingResponseWriter(w, t.options.MaxDumpBodySize)
handler.ServeHTTP(lrw, r)

// Handles file write if enabled
if EnableUpload && r.Method == http.MethodPut {
// sandbox - calcolate absolute path
if t.options.Sandbox {
absPath, err := filepath.Abs(filepath.Join(t.options.Folder, r.URL.Path))
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusBadRequest)
return
}
// check if the path is within the configured folder
pattern := t.options.Folder + string(filepath.Separator) + "*"
matched, err := filepath.Match(pattern, absPath)
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusBadRequest)
return
} else if !matched {
gologger.Print().Msg("pointing to unauthorized directory")
w.WriteHeader(http.StatusBadRequest)
return
}
}

var (
data []byte
err error
)
if t.options.Sandbox {
maxFileSize := unit.ToMb(t.options.MaxFileSize)
// check header content length
if r.ContentLength > maxFileSize {
gologger.Print().Msg("request too large")
return
}
// body max length
r.Body = http.MaxBytesReader(w, r.Body, maxFileSize)
}

data, err = ioutil.ReadAll(r.Body)
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
err = handleUpload(t.options.Folder, path.Base(r.URL.Path), data)
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

if EnableVerbose {
headers := new(bytes.Buffer)
lrw.Header().Write(headers) //nolint
Expand Down
86 changes: 81 additions & 5 deletions pkg/httpserver/uploadlayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,97 @@ package httpserver
import (
"errors"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"

"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/simplehttpserver/pkg/unit"
)

// uploadlayer handles PUT requests and save the file to disk
func (t *HTTPServer) uploadlayer(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handles file write if enabled
if EnableUpload && r.Method == http.MethodPut {
// sandbox - calcolate absolute path
if t.options.Sandbox {
absPath, err := filepath.Abs(filepath.Join(t.options.Folder, r.URL.Path))
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusBadRequest)
return
}
// check if the path is within the configured folder
pattern := t.options.Folder + string(filepath.Separator) + "*"
matched, err := filepath.Match(pattern, absPath)
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusBadRequest)
return
} else if !matched {
gologger.Print().Msg("pointing to unauthorized directory")
w.WriteHeader(http.StatusBadRequest)
return
}
}

var (
data []byte
err error
)
if t.options.Sandbox {
maxFileSize := unit.ToMb(t.options.MaxFileSize)
// check header content length
if r.ContentLength > maxFileSize {
gologger.Print().Msg("request too large")
return
}
// body max length
r.Body = http.MaxBytesReader(w, r.Body, maxFileSize)
}

data, err = ioutil.ReadAll(r.Body)
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}

sanitizedPath := filepath.FromSlash(path.Clean("/" + strings.Trim(r.URL.Path, "/")))

err = handleUpload(t.options.Folder, sanitizedPath, data)
if err != nil {
gologger.Print().Msgf("%s\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
} else {
w.WriteHeader(http.StatusCreated)
return
}
}

handler.ServeHTTP(w, r)
})
}

func handleUpload(base, file string, data []byte) error {
// rejects all paths containing a non exhaustive list of invalid characters - This is only a best effort as the tool is meant for development
if strings.ContainsAny(file, "\\`\"':") {
return errors.New("invalid character")
}

// allow upload only in subfolders
rel, err := filepath.Rel(base, file)
if rel == "" || err != nil {
return err
untrustedPath := filepath.Clean(filepath.Join(base, file))
if !strings.HasPrefix(untrustedPath, filepath.Clean(base)) {
return errors.New("invalid path")
}
trustedPath := untrustedPath

if _, err := os.Stat(path.Dir(trustedPath)); os.IsNotExist(err) {
return errors.New("invalid path")
}

return ioutil.WriteFile(file, data, 0655)
return ioutil.WriteFile(trustedPath, data, 0655)
}