diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index cc9d9206..004e498a 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -268,78 +268,83 @@ async def handle_async_request(self, request: Request) -> Response: async with self._connect_lock: if not self._connected: - target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) - - connect_url = URL( - scheme=self._proxy_origin.scheme, - host=self._proxy_origin.host, - port=self._proxy_origin.port, - target=target, - ) - connect_headers = merge_headers( - [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers - ) - connect_request = Request( - method=b"CONNECT", - url=connect_url, - headers=connect_headers, - extensions=request.extensions, - ) - connect_response = await self._connection.handle_async_request( - connect_request - ) - - if connect_response.status < 200 or connect_response.status > 299: - reason_bytes = connect_response.extensions.get("reason_phrase", b"") - reason_str = reason_bytes.decode("ascii", errors="ignore") - msg = "%d %s" % (connect_response.status, reason_str) - await self._connection.aclose() - raise ProxyError(msg) - - stream = connect_response.extensions["network_stream"] - - # Upgrade the stream to SSL - ssl_context = ( - default_ssl_context() - if self._ssl_context is None - else self._ssl_context - ) - alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] - ssl_context.set_alpn_protocols(alpn_protocols) - - kwargs = { - "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), - "timeout": timeout, - } - async with Trace("start_tls", logger, request, kwargs) as trace: - stream = await stream.start_tls(**kwargs) - trace.return_value = stream - - # Determine if we should be using HTTP/1.1 or HTTP/2 - ssl_object = stream.get_extra_info("ssl_object") - http2_negotiated = ( - ssl_object is not None - and ssl_object.selected_alpn_protocol() == "h2" - ) - - # Create the HTTP/1.1 or HTTP/2 connection - if http2_negotiated or (self._http2 and not self._http1): - from .http2 import AsyncHTTP2Connection - - self._connection = AsyncHTTP2Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + try: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, ) - else: - self._connection = AsyncHTTP11Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = await self._connection.handle_async_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + await self._connection.aclose() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" ) - self._connected = True + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + except Exception: + if not self._connected and not self._connection.is_closed(): + await self._connection.aclose() + raise return await self._connection.handle_async_request(request) def can_handle_request(self, origin: Origin) -> bool: diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index ecca88f7..6789ce41 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -268,78 +268,83 @@ def handle_request(self, request: Request) -> Response: with self._connect_lock: if not self._connected: - target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) - - connect_url = URL( - scheme=self._proxy_origin.scheme, - host=self._proxy_origin.host, - port=self._proxy_origin.port, - target=target, - ) - connect_headers = merge_headers( - [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers - ) - connect_request = Request( - method=b"CONNECT", - url=connect_url, - headers=connect_headers, - extensions=request.extensions, - ) - connect_response = self._connection.handle_request( - connect_request - ) - - if connect_response.status < 200 or connect_response.status > 299: - reason_bytes = connect_response.extensions.get("reason_phrase", b"") - reason_str = reason_bytes.decode("ascii", errors="ignore") - msg = "%d %s" % (connect_response.status, reason_str) - self._connection.close() - raise ProxyError(msg) - - stream = connect_response.extensions["network_stream"] - - # Upgrade the stream to SSL - ssl_context = ( - default_ssl_context() - if self._ssl_context is None - else self._ssl_context - ) - alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] - ssl_context.set_alpn_protocols(alpn_protocols) - - kwargs = { - "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), - "timeout": timeout, - } - with Trace("start_tls", logger, request, kwargs) as trace: - stream = stream.start_tls(**kwargs) - trace.return_value = stream - - # Determine if we should be using HTTP/1.1 or HTTP/2 - ssl_object = stream.get_extra_info("ssl_object") - http2_negotiated = ( - ssl_object is not None - and ssl_object.selected_alpn_protocol() == "h2" - ) - - # Create the HTTP/1.1 or HTTP/2 connection - if http2_negotiated or (self._http2 and not self._http1): - from .http2 import HTTP2Connection - - self._connection = HTTP2Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + try: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, ) - else: - self._connection = HTTP11Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = self._connection.handle_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + self._connection.close() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" ) - self._connected = True + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + except Exception: + if not self._connected and not self._connection.is_closed(): + self._connection.close() + raise return self._connection.handle_request(request) def can_handle_request(self, origin: Origin) -> bool: