diff --git a/Sources/DemoServer.swift b/Sources/DemoServer.swift index e3ec153e..672e5194 100644 --- a/Sources/DemoServer.swift +++ b/Sources/DemoServer.swift @@ -101,6 +101,45 @@ public func demoServer(_ publicDir: String) -> HttpServer { return HttpResponse.ok(.html(response)) } + server.GET["/upload/logo"] = { r in + guard let resourceURL = Bundle.main.resourceURL else { + return .notFound + } + + let logoURL = resourceURL.appendingPathComponent("logo.png") + guard let exists = try? logoURL.path.exists(), true == exists else { + return .notFound + } + + guard let url = URL(string: "http://127.0.0.1:9080/upload/logo"), let body = try? Data(contentsOf: logoURL) else { + return .notFound + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.httpBody = body + request.setValue("application/octet-stream", forHTTPHeaderField: "Content-Type") + guard let data = try? NSURLConnection.sendSynchronousRequest(request, returning: nil) else { + return .badRequest(.html("Failed to send data")) + } + return .raw(200, "OK", [:], { writter in + try writter.write(data) + }) + } + + server.filePreprocess = true + server.POST["/upload/logo"] = { r in + guard let path = r.tempFile else { + return .badRequest(.html("no file")) + } + guard let file = try? path.openForReading() else { + return .notFound + } + return .raw(200, "OK", [:], { writter in + try writter.write(file) + }) + } + server.GET["/login"] = scopes { html { head { diff --git a/Sources/HttpParser.swift b/Sources/HttpParser.swift index ae7d0982..5e75f591 100644 --- a/Sources/HttpParser.swift +++ b/Sources/HttpParser.swift @@ -26,9 +26,13 @@ public class HttpParser { request.path = statusLineTokens[1] request.queryParams = extractQueryParams(request.path) request.headers = try readHeaders(socket) - if let contentLength = request.headers["content-length"], let contentLengthValue = Int(contentLength) { - request.body = try readBody(socket, size: contentLengthValue) - } +// if let contentLength = request.headers["content-length"], let contentLengthValue = Int(contentLength) { +// if request.headers["content-type"] == "application/octet-stream" { +// request.tempFile = try readFile(socket, length: contentLengthValue) +// } else { +// request.body = try readBody(socket, size: contentLengthValue) +// } +// } return request } @@ -75,10 +79,51 @@ public class HttpParser { // } } + public func readContent(_ socket: Socket, request: HttpRequest, filePreprocess: Bool) throws { + guard let contentType = request.headers["content-type"], + let contentLength = request.headers["content-length"], + let contentLengthValue = Int(contentLength) else { + return + } + let isFileUpload = contentType == "application/octet-stream" + if isFileUpload && filePreprocess { + request.tempFile = try readFile(socket, length: contentLengthValue) + } + else { + request.body = try readBody(socket, size: contentLengthValue) + } + } + + private let kBufferLength = 1024 + + private func readFile(_ socket: Socket, length: Int) throws -> String { + var offset = 0 + let filePath = NSTemporaryDirectory() + "/" + NSUUID().uuidString + let file = try filePath.openNewForWriting() + + while offset < length { + let length = offset + kBufferLength < length ? kBufferLength : length - offset + let buffer = try socket.read(length: length) + try file.write(buffer) + offset += buffer.count + } + file.close() + return filePath + } + private func readBody(_ socket: Socket, size: Int) throws -> [UInt8] { var body = [UInt8]() - for _ in 0.. [String: String] { diff --git a/Sources/HttpRequest.swift b/Sources/HttpRequest.swift index 23257d26..fb5dc83d 100644 --- a/Sources/HttpRequest.swift +++ b/Sources/HttpRequest.swift @@ -16,9 +16,17 @@ public class HttpRequest { public var body: [UInt8] = [] public var address: String? = "" public var params: [String: String] = [:] + public var tempFile: String? public init() {} + public func removeTempFileIfExists() throws { + if let path = tempFile, try path.exists() { + try FileManager.default.removeItem(atPath: path) + } + tempFile = nil + } + public func hasTokenForHeader(_ headerName: String, token: String) -> Bool { guard let headerValue = headers[headerName] else { return false diff --git a/Sources/HttpServer.swift b/Sources/HttpServer.swift index 6e72a80c..1608db26 100644 --- a/Sources/HttpServer.swift +++ b/Sources/HttpServer.swift @@ -40,7 +40,7 @@ public class HttpServer: HttpServerIO { } public var routes: [String] { - return router.routes(); + return router.routes() } public var notFoundHandler: ((HttpRequest) -> HttpResponse)? diff --git a/Sources/HttpServerIO.swift b/Sources/HttpServerIO.swift index 8e58e1bb..1aac9559 100644 --- a/Sources/HttpServerIO.swift +++ b/Sources/HttpServerIO.swift @@ -53,6 +53,14 @@ public class HttpServerIO { /// It's only used when the server is started with `forceIPv4` option set to false. /// Otherwise, `listenAddressIPv4` will be used. public var listenAddressIPv6: String? + + /// Bool representation of whether the file upload is preprocessed. + /// `true` if the file upload requires preprocessing when `content-type` is + /// `application/octet-stream`. `HttpParser` will create a temp file(`tempFile`) in + /// `NSTemporaryDirectory()`, and is deleted after the request ends. + /// Together, `body` will be empty. + /// `false` otherwise. + public var filePreprocess: Bool = false private let queue = DispatchQueue(label: "swifter.httpserverio.clientsockets") @@ -119,9 +127,22 @@ public class HttpServerIO { while self.operating, let request = try? parser.readHttpRequest(socket) { let request = request request.address = try? socket.peername() + + do { + try parser.readContent(socket, request: request, filePreprocess: filePreprocess) + } catch { + print("Failed to read content: \(error)") + break + } + let (params, handler) = self.dispatch(request) request.params = params let response = handler(request) + + if filePreprocess { + try? request.removeTempFileIfExists() + } + var keepConnection = parser.supportsKeepAlive(request.headers) do { if self.operating { diff --git a/Sources/Socket.swift b/Sources/Socket.swift index 54a5f1e4..38b14857 100644 --- a/Sources/Socket.swift +++ b/Sources/Socket.swift @@ -111,6 +111,19 @@ open class Socket: Hashable, Equatable { } } + open func read(length: Int) throws -> [UInt8] { + var buffer = [UInt8](repeating: 0, count: length) + let count = recv(self.socketFileDescriptor as Int32, &buffer, buffer.count, 0) + if count <= 0 { + throw SocketError.recvFailed(Errno.description()) + } + + if count < length { + buffer.removeSubrange(count.. UInt8 { var buffer = [UInt8](repeating: 0, count: 1) let next = recv(self.socketFileDescriptor as Int32, &buffer, Int(buffer.count), 0)