567 lines
26 KiB
Diff
567 lines
26 KiB
Diff
diff --git a/requirements.txt b/requirements.txt
|
|
index e26e6b3..b16569f 100644
|
|
--- a/requirements.txt
|
|
+++ b/requirements.txt
|
|
@@ -7,7 +7,7 @@ h11 @ git+https://github.com/python-hyper/h11.git@master
|
|
# Explicit optionals
|
|
a2wsgi==1.10.7
|
|
wsproto==1.2.0
|
|
-websockets==13.1
|
|
+websockets==14.1
|
|
|
|
# Packaging
|
|
build==1.2.2.post1
|
|
diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py
|
|
index 63d7daf..5aef174 100644
|
|
--- a/tests/middleware/test_logging.py
|
|
+++ b/tests/middleware/test_logging.py
|
|
@@ -8,8 +8,7 @@ import typing
|
|
|
|
import httpx
|
|
import pytest
|
|
-import websockets
|
|
-import websockets.client
|
|
+from websockets.asyncio.client import connect
|
|
|
|
from tests.utils import run_server
|
|
from uvicorn import Config
|
|
@@ -107,8 +106,8 @@ async def test_trace_logging_on_ws_protocol(
|
|
break
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.open
|
|
+ async with connect(url):
|
|
+ return True
|
|
|
|
config = Config(
|
|
app=websocket_app,
|
|
diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py
|
|
index d300c45..4b5f195 100644
|
|
--- a/tests/middleware/test_proxy_headers.py
|
|
+++ b/tests/middleware/test_proxy_headers.py
|
|
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
|
import httpx
|
|
import httpx._transports.asgi
|
|
import pytest
|
|
-import websockets.client
|
|
+from websockets.asyncio.client import connect
|
|
|
|
from tests.response import Response
|
|
from tests.utils import run_server
|
|
@@ -479,7 +479,7 @@ async def test_proxy_headers_websocket_x_forwarded_proto(
|
|
async with run_server(config):
|
|
url = f"ws://127.0.0.1:{unused_tcp_port}"
|
|
headers = {X_FORWARDED_FOR: "1.2.3.4", X_FORWARDED_PROTO: forwarded_proto}
|
|
- async with websockets.client.connect(url, extra_headers=headers) as websocket:
|
|
+ async with connect(url, additional_headers=headers) as websocket:
|
|
data = await websocket.recv()
|
|
assert data == expected
|
|
|
|
diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py
|
|
index e728544..b9035ec 100644
|
|
--- a/tests/protocols/test_websocket.py
|
|
+++ b/tests/protocols/test_websocket.py
|
|
@@ -12,6 +12,8 @@ import websockets.asyncio.client
|
|
import websockets.client
|
|
import websockets.exceptions
|
|
from typing_extensions import TypedDict
|
|
+from websockets.asyncio.client import ClientConnection, connect
|
|
+from websockets.exceptions import ConnectionClosed, ConnectionClosedError, InvalidHandshake, InvalidStatus
|
|
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
|
|
from websockets.typing import Subprotocol
|
|
|
|
@@ -130,8 +132,8 @@ async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.open
|
|
+ async with connect(url):
|
|
+ return True
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -146,7 +148,7 @@ async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProt
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config) as server:
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}"):
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}"):
|
|
# Attempt shutdown while connection is still open
|
|
await server.shutdown()
|
|
|
|
@@ -160,8 +162,8 @@ async def test_supports_permessage_deflate_extension(
|
|
|
|
async def open_connection(url: str):
|
|
extension_factories = [ClientPerMessageDeflateFactory()]
|
|
- async with websockets.client.connect(url, extensions=extension_factories) as websocket:
|
|
- return [extension.name for extension in websocket.extensions]
|
|
+ async with connect(url, extensions=extension_factories) as websocket:
|
|
+ return [extension.name for extension in websocket.protocol.extensions]
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -180,8 +182,8 @@ async def test_can_disable_permessage_deflate_extension(
|
|
# enable per-message deflate on the client, so that we can check the server
|
|
# won't support it when it's disabled.
|
|
extension_factories = [ClientPerMessageDeflateFactory()]
|
|
- async with websockets.client.connect(url, extensions=extension_factories) as websocket:
|
|
- return [extension.name for extension in websocket.extensions]
|
|
+ async with connect(url, extensions=extension_factories) as websocket:
|
|
+ return [extension.name for extension in websocket.protocol.extensions]
|
|
|
|
config = Config(
|
|
app=App,
|
|
@@ -203,8 +205,8 @@ async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
|
|
|
async def open_connection(url: str):
|
|
try:
|
|
- await websockets.client.connect(url)
|
|
- except websockets.exceptions.InvalidHandshake:
|
|
+ await connect(url)
|
|
+ except InvalidHandshake:
|
|
return False
|
|
return True # pragma: no cover
|
|
|
|
@@ -224,8 +226,8 @@ async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProto
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url, extra_headers=[("username", "abraão")]) as websocket:
|
|
- return websocket.open
|
|
+ async with connect(url, additional_headers=[("username", "abraão")]):
|
|
+ return True
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -239,8 +241,9 @@ async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTT
|
|
await self.send({"type": "websocket.accept", "headers": [(b"extra", b"header")]})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.response_headers
|
|
+ async with connect(url) as websocket:
|
|
+ assert websocket.response
|
|
+ return websocket.response.headers
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -258,8 +261,8 @@ async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.open
|
|
+ async with connect(url):
|
|
+ return True
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -276,7 +279,7 @@ async def test_send_text_data_to_client(
|
|
await self.send({"type": "websocket.send", "text": "123"})
|
|
|
|
async def get_data(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
return await websocket.recv()
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
@@ -294,7 +297,7 @@ async def test_send_binary_data_to_client(
|
|
await self.send({"type": "websocket.send", "bytes": b"123"})
|
|
|
|
async def get_data(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
return await websocket.recv()
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
@@ -313,7 +316,7 @@ async def test_send_and_close_connection(
|
|
await self.send({"type": "websocket.close"})
|
|
|
|
async def get_data(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
data = await websocket.recv()
|
|
is_open = True
|
|
try:
|
|
@@ -342,7 +345,7 @@ async def test_send_text_data_to_server(
|
|
await self.send({"type": "websocket.send", "text": _text})
|
|
|
|
async def send_text(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
await websocket.send("abc")
|
|
return await websocket.recv()
|
|
|
|
@@ -365,7 +368,7 @@ async def test_send_binary_data_to_server(
|
|
await self.send({"type": "websocket.send", "bytes": _bytes})
|
|
|
|
async def send_text(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
await websocket.send(b"abc")
|
|
return await websocket.recv()
|
|
|
|
@@ -387,7 +390,7 @@ async def test_send_after_protocol_close(
|
|
await self.send({"type": "websocket.send", "text": "123"})
|
|
|
|
async def get_data(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
data = await websocket.recv()
|
|
is_open = True
|
|
try:
|
|
@@ -407,14 +410,14 @@ async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
|
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
|
pass
|
|
|
|
- async def connect(url: str):
|
|
- await websockets.client.connect(url)
|
|
+ async def open_connection(url: str):
|
|
+ await connect(url)
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
- with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
|
|
- await connect(f"ws://127.0.0.1:{unused_tcp_port}")
|
|
- assert exc_info.value.status_code == 500
|
|
+ with pytest.raises(InvalidStatus) as exc_info:
|
|
+ await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
|
|
+ assert exc_info.value.response.status_code == 500
|
|
|
|
|
|
async def test_send_before_handshake(
|
|
@@ -423,14 +426,14 @@ async def test_send_before_handshake(
|
|
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
|
await send({"type": "websocket.send", "text": "123"})
|
|
|
|
- async def connect(url: str):
|
|
- await websockets.client.connect(url)
|
|
+ async def open_connection(url: str):
|
|
+ await connect(url)
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
- with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
|
|
- await connect(f"ws://127.0.0.1:{unused_tcp_port}")
|
|
- assert exc_info.value.status_code == 500
|
|
+ with pytest.raises(InvalidStatus) as exc_info:
|
|
+ await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
|
|
+ assert exc_info.value.response.status_code == 500
|
|
|
|
|
|
async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
|
@@ -440,10 +443,10 @@ async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cl
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
- with pytest.raises(websockets.exceptions.ConnectionClosed):
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
+ with pytest.raises(ConnectionClosed):
|
|
_ = await websocket.recv()
|
|
- assert websocket.close_code == 1006
|
|
+ assert websocket.protocol.close_code == 1006
|
|
|
|
|
|
async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
|
@@ -458,10 +461,10 @@ async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
- with pytest.raises(websockets.exceptions.ConnectionClosed):
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
+ with pytest.raises(ConnectionClosed):
|
|
_ = await websocket.recv()
|
|
- assert websocket.close_code == 1006
|
|
+ assert websocket.protocol.close_code == 1006
|
|
|
|
|
|
@pytest.mark.parametrize("code", [None, 1000, 1001])
|
|
@@ -493,13 +496,13 @@ async def test_app_close(
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
await websocket.ping()
|
|
await websocket.send("abc")
|
|
- with pytest.raises(websockets.exceptions.ConnectionClosed):
|
|
+ with pytest.raises(ConnectionClosed):
|
|
await websocket.recv()
|
|
- assert websocket.close_code == (code or 1000)
|
|
- assert websocket.close_reason == (reason or "")
|
|
+ assert websocket.protocol.close_code == (code or 1000)
|
|
+ assert websocket.protocol.close_reason == (reason or "")
|
|
|
|
|
|
async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
|
|
@@ -518,7 +521,7 @@ async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTP
|
|
break
|
|
|
|
async def websocket_session(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
+ async with connect(url) as websocket:
|
|
await websocket.ping()
|
|
await websocket.send("abc")
|
|
await websocket.close(code=1001, reason="custom reason")
|
|
@@ -555,7 +558,7 @@ async def test_client_connection_lost(
|
|
port=unused_tcp_port,
|
|
)
|
|
async with run_server(config):
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
websocket.transport.close()
|
|
await asyncio.sleep(0.1)
|
|
got_disconnect_event_before_shutdown = got_disconnect_event
|
|
@@ -583,7 +586,7 @@ async def test_client_connection_lost_on_send(
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
url = f"ws://127.0.0.1:{unused_tcp_port}"
|
|
- async with websockets.client.connect(url):
|
|
+ async with connect(url):
|
|
await asyncio.sleep(0.1)
|
|
disconnect.set()
|
|
|
|
@@ -642,11 +645,11 @@ async def test_send_close_on_server_shutdown(
|
|
disconnect_message = message
|
|
break
|
|
|
|
- websocket: websockets.client.WebSocketClientProtocol | None = None
|
|
+ websocket: ClientConnection | None = None
|
|
|
|
async def websocket_session(uri: str):
|
|
nonlocal websocket
|
|
- async with websockets.client.connect(uri) as ws_connection:
|
|
+ async with connect(uri) as ws_connection:
|
|
websocket = ws_connection
|
|
await server_shutdown_event.wait()
|
|
|
|
@@ -676,9 +679,7 @@ async def test_subprotocols(
|
|
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
|
|
|
|
async def get_subprotocol(url: str):
|
|
- async with websockets.client.connect(
|
|
- url, subprotocols=[Subprotocol("proto1"), Subprotocol("proto2")]
|
|
- ) as websocket:
|
|
+ async with connect(url, subprotocols=[Subprotocol("proto1"), Subprotocol("proto2")]) as websocket:
|
|
return websocket.subprotocol
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
@@ -688,7 +689,7 @@ async def test_subprotocols(
|
|
|
|
|
|
MAX_WS_BYTES = 1024 * 1024 * 16
|
|
-MAX_WS_BYTES_PLUS1 = MAX_WS_BYTES + 1
|
|
+MAX_WS_BYTES_PLUS1 = MAX_WS_BYTES + 10
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
@@ -731,15 +732,15 @@ async def test_send_binary_data_to_server_bigger_than_default_on_websockets(
|
|
port=unused_tcp_port,
|
|
)
|
|
async with run_server(config):
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}", max_size=client_size_sent) as ws:
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}", max_size=client_size_sent) as ws:
|
|
await ws.send(b"\x01" * client_size_sent)
|
|
if expected_result == 0:
|
|
data = await ws.recv()
|
|
assert data == b"\x01" * client_size_sent
|
|
else:
|
|
- with pytest.raises(websockets.exceptions.ConnectionClosedError):
|
|
+ with pytest.raises(ConnectionClosedError):
|
|
await ws.recv()
|
|
- assert ws.close_code == expected_result
|
|
+ assert ws.protocol.close_code == expected_result
|
|
|
|
|
|
async def test_server_reject_connection(
|
|
@@ -764,10 +765,10 @@ async def test_server_reject_connection(
|
|
disconnected_message = await receive()
|
|
|
|
async def websocket_session(url: str):
|
|
- with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
|
|
- async with websockets.client.connect(url):
|
|
+ with pytest.raises(InvalidStatus) as exc_info:
|
|
+ async with connect(url):
|
|
pass # pragma: no cover
|
|
- assert exc_info.value.status_code == 403
|
|
+ assert exc_info.value.response.status_code == 403
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -937,10 +938,10 @@ async def test_server_reject_connection_with_invalid_msg(
|
|
await send(message)
|
|
|
|
async def websocket_session(url: str):
|
|
- with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
|
|
- async with websockets.client.connect(url):
|
|
+ with pytest.raises(InvalidStatus) as exc_info:
|
|
+ async with connect(url):
|
|
pass # pragma: no cover
|
|
- assert exc_info.value.status_code == 404
|
|
+ assert exc_info.value.response.status_code == 404
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -971,10 +972,10 @@ async def test_server_reject_connection_with_missing_body(
|
|
# no further message
|
|
|
|
async def websocket_session(url: str):
|
|
- with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
|
|
- async with websockets.client.connect(url):
|
|
+ with pytest.raises(InvalidStatus) as exc_info:
|
|
+ async with connect(url):
|
|
pass # pragma: no cover
|
|
- assert exc_info.value.status_code == 404
|
|
+ assert exc_info.value.response.status_code == 404
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -1014,17 +1015,17 @@ async def test_server_multiple_websocket_http_response_start_events(
|
|
exception_message = str(exc)
|
|
|
|
async def websocket_session(url: str):
|
|
- with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
|
|
- async with websockets.client.connect(url):
|
|
+ with pytest.raises(InvalidStatus) as exc_info:
|
|
+ async with connect(url):
|
|
pass # pragma: no cover
|
|
- assert exc_info.value.status_code == 404
|
|
+ assert exc_info.value.response.status_code == 404
|
|
|
|
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
|
|
|
|
assert exception_message == (
|
|
- "Expected ASGI message 'websocket.http.response.body' but got " "'websocket.http.response.start'."
|
|
+ "Expected ASGI message 'websocket.http.response.body' but got 'websocket.http.response.start'."
|
|
)
|
|
|
|
|
|
@@ -1053,7 +1054,7 @@ async def test_server_can_read_messages_in_buffer_after_close(
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
- async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
+ async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
|
|
await websocket.send(b"abc")
|
|
await websocket.send(b"abc")
|
|
await websocket.send(b"abc")
|
|
@@ -1070,8 +1071,9 @@ async def test_default_server_headers(
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.response_headers
|
|
+ async with connect(url) as websocket:
|
|
+ assert websocket.response
|
|
+ return websocket.response.headers
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -1085,8 +1087,9 @@ async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls:
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.response_headers
|
|
+ async with connect(url) as websocket:
|
|
+ assert websocket.response
|
|
+ return websocket.response.headers
|
|
|
|
config = Config(
|
|
app=App,
|
|
@@ -1108,8 +1111,9 @@ async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.response_headers
|
|
+ async with connect(url) as websocket:
|
|
+ assert websocket.response
|
|
+ return websocket.response.headers
|
|
|
|
config = Config(
|
|
app=App,
|
|
@@ -1140,8 +1144,9 @@ async def test_multiple_server_header(
|
|
)
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.response_headers
|
|
+ async with connect(url) as websocket:
|
|
+ assert websocket.response
|
|
+ return websocket.response.headers
|
|
|
|
config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
|
|
async with run_server(config):
|
|
@@ -1176,8 +1181,8 @@ async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HT
|
|
await self.send({"type": "websocket.accept"})
|
|
|
|
async def open_connection(url: str):
|
|
- async with websockets.client.connect(url) as websocket:
|
|
- return websocket.open
|
|
+ async with connect(url):
|
|
+ return True
|
|
|
|
async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
|
|
if scope["type"] == "lifespan":
|
|
diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py
|
|
index cd6c54f..685d6b6 100644
|
|
--- a/uvicorn/protocols/websockets/websockets_impl.py
|
|
+++ b/uvicorn/protocols/websockets/websockets_impl.py
|
|
@@ -13,8 +13,7 @@ from websockets.datastructures import Headers
|
|
from websockets.exceptions import ConnectionClosed
|
|
from websockets.extensions.base import ServerExtensionFactory
|
|
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
|
-from websockets.legacy.server import HTTPResponse
|
|
-from websockets.server import WebSocketServerProtocol
|
|
+from websockets.legacy.server import HTTPResponse, WebSocketServerProtocol
|
|
from websockets.typing import Subprotocol
|
|
|
|
from uvicorn._types import (
|
|
diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py
|
|
index 828afe5..5d84bff 100644
|
|
--- a/uvicorn/protocols/websockets/wsproto_impl.py
|
|
+++ b/uvicorn/protocols/websockets/wsproto_impl.py
|
|
@@ -149,12 +149,13 @@ class WSProtocol(asyncio.Protocol):
|
|
self.writable.set() # pragma: full coverage
|
|
|
|
def shutdown(self) -> None:
|
|
- if self.handshake_complete:
|
|
- self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
|
- output = self.conn.send(wsproto.events.CloseConnection(code=1012))
|
|
- self.transport.write(output)
|
|
- else:
|
|
- self.send_500_response()
|
|
+ if not self.response_started:
|
|
+ if self.handshake_complete:
|
|
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
|
+ output = self.conn.send(wsproto.events.CloseConnection(code=1012))
|
|
+ self.transport.write(output)
|
|
+ else:
|
|
+ self.send_500_response()
|
|
self.transport.close()
|
|
|
|
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
|
@@ -221,13 +222,14 @@ class WSProtocol(asyncio.Protocol):
|
|
    def send_500_response(self) -> None:
|
|
        if self.response_started or self.handshake_complete:
|
|
            return  # we cannot send responses anymore
|
|
+        reject_data = b"Internal Server Error"
|
|
        headers: list[tuple[bytes, bytes]] = [
|
|
            (b"content-type", b"text/plain; charset=utf-8"),
|
|
+            (b"content-length", str(len(reject_data)).encode()),
|
|
            (b"connection", b"close"),
|
|
-            (b"content-length", b"21"),
|
|
]
|
|
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
|
|
- output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
|
|
+ output += self.conn.send(wsproto.events.RejectData(data=reject_data))
|
|
self.transport.write(output)
|
|
|
|
async def run_asgi(self) -> None:
|