diff --git a/src/http.lua b/src/http.lua index 84110d8e..beb1786d 100644 --- a/src/http.lua +++ b/src/http.lua @@ -203,8 +203,15 @@ end ----------------------------------------------------------------------------- local function adjusturi(reqt) local u = reqt - -- if there is a proxy, we need the full url. otherwise, just a part. - if not reqt.proxy and not _M.PROXY then + local proxy + if reqt.proxy then + proxy = url.parse(reqt.proxy) + end + + -- We just need the path if there's no proxy, + -- or if we use https over an http proxy. + -- Otherwise, we use a full url + if not proxy or (reqt.scheme == "https" and proxy.scheme == "http") then u = { path = socket.try(reqt.path, "invalid path 'nil'"), params = reqt.params, @@ -215,18 +222,7 @@ local function adjusturi(reqt) return url.build(u) end -local function adjustproxy(reqt) - local proxy = reqt.proxy or _M.PROXY - if proxy then - proxy = url.parse(proxy) - local proxy_create = SCHEMES[proxy.scheme].create(reqt) - return proxy.host, proxy.port or 3128, proxy_create - else - return reqt.host, reqt.port - end -end - -local function adjustheaders(reqt) +local function adjustheaders(reqt, https_connect) -- default headers local host = reqt.host local port = tostring(reqt.port) @@ -245,9 +241,8 @@ local function adjustheaders(reqt) url.unescape(reqt.password))) end -- if we have proxy authentication information, pass it along - local proxy = reqt.proxy or _M.PROXY - if proxy then - proxy = url.parse(proxy) + if reqt.proxy and (reqt.scheme == "http" or https_connect) then + local proxy = url.parse(reqt.proxy) if proxy.user and proxy.password then lower["proxy-authorization"] = "Basic " .. (mime.b64(proxy.user .. ":" .. proxy.password)) @@ -260,6 +255,82 @@ local function adjustheaders(reqt) return lower end +local function reg(conn) + local mt = getmetatable(conn.sock).__index + for name, method in pairs(mt) do + if type(method) == "function" then + conn[name] = function (self, ...) + return method(self.sock, ...) + end + end + end +end + +local function proxy_connect_create(params, proxy) + -- Copied and adapted from luasec's https.lua + -- in function ssl.http.tcp() + + local ssl = assert( + require("ssl"), 'LuaSocket: LuaSec not found') + local https = assert( + require("ssl.https"), 'LuaSocket: LuaSec not found') + local tcp = assert( + https.tcp, 'LuaSocket: Function tcp() not available from LuaSec') + + -- Force client mode + params.mode = "client" + -- 'create' function + return function () + local conn = {} + conn.proxy_sock = _M.open(proxy.host, proxy.port, proxy.create) + local try = conn.proxy_sock.try + + function conn:settimeout(...) + return self.proxy_sock.c:settimeout(https.TIMEOUT) + end + + -- Wrap the underlying connection function + function conn:connect(host, port) + conn.proxy_sock:sendrequestline("CONNECT", host .. ":" .. tostring(port)) + conn.proxy_sock:sendheaders(adjustheaders(proxy, true)) + + local code, _ = conn.proxy_sock:receivestatusline() + try(code == 200 or nil) + + self.sock = try(ssl.wrap(self.proxy_sock.c, params)) + self.sock:sni(host) + self.sock:settimeout(https.TIMEOUT) + try(self.sock:dohandshake()) + reg(self) + return 1 + end + + -- Close the underlying socket + function conn:close() + conn.proxy_sock:close() + end + + return conn + end +end + +local function adjustproxy(reqt) + if reqt.proxy then + local proxy = url.parse(reqt.proxy) + proxy.port = proxy.port or 3128 + proxy.create = SCHEMES[proxy.scheme].create(reqt) + + if reqt.scheme == "https" and proxy.scheme == "http" then + local wrapped_create = proxy_connect_create(reqt, proxy) + return reqt.host, reqt.port, wrapped_create + else + return proxy.host, proxy.port, proxy.create + end + else + return reqt.host, reqt.port, reqt.create + end +end + -- default url parts local default = { path ="/" @@ -269,6 +340,8 @@ local default = { local function adjustrequest(reqt) -- parse url if provided local nreqt = reqt.url and url.parse(reqt.url, default) or {} + -- get global proxy + nreqt.proxy = reqt.proxy or _M.PROXY -- explicit components override url for i,v in base.pairs(reqt) do nreqt[i] = v end -- default to scheme particulars @@ -283,7 +356,7 @@ local function adjustrequest(reqt) -- compute uri if user hasn't overriden nreqt.uri = reqt.uri or adjusturi(nreqt) -- adjust headers in request - nreqt.headers = adjustheaders(nreqt) + nreqt.headers = adjustheaders(nreqt, false) if nreqt.source and not nreqt.headers["content-length"] and not nreqt.headers["transfer-encoding"] @@ -291,7 +364,7 @@ local function adjustrequest(reqt) nreqt.headers["transfer-encoding"] = "chunked" end - -- ajust host and port if there is a proxy + -- adjust host, port and create if there is a proxy nreqt.host, nreqt.port, nreqt.create = adjustproxy(nreqt) return nreqt end