Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17894.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove usage of internal header encoding API.
40 changes: 16 additions & 24 deletions synapse/http/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,17 @@
# "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616
# section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be
# consumed by the immediate recipient and not be forwarded on.
HOP_BY_HOP_HEADERS = {
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"TE",
"Trailers",
"Transfer-Encoding",
"Upgrade",
HOP_BY_HOP_HEADERS_LOWERCASE = {
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailers",
"transfer-encoding",
"upgrade",
}

if hasattr(Headers, "_canonicalNameCaps"):
# Twisted < 24.7.0rc1
_canonicalHeaderName = Headers()._canonicalNameCaps # type: ignore[attr-defined]
else:
# Twisted >= 24.7.0rc1
# But note that `_encodeName` still exists on prior versions,
# it just encodes differently
_canonicalHeaderName = Headers()._encodeName
assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE)


def parse_connection_header_value(
Expand All @@ -92,12 +84,12 @@ def parse_connection_header_value(

Returns:
The set of header names that should not be copied over from the remote response.
The keys are capitalized in canonical capitalization.
The keys are lowercased.
"""
extra_headers_to_remove: Set[str] = set()
if connection_header_value:
extra_headers_to_remove = {
_canonicalHeaderName(connection_option.strip()).decode("ascii")
connection_option.decode("ascii").strip().lower()
for connection_option in connection_header_value.split(b",")
}

Expand Down Expand Up @@ -194,18 +186,18 @@ def _send_response(

# The `Connection` header also defines which headers should not be copied over.
connection_header = response_headers.getRawHeaders(b"connection")
extra_headers_to_remove = parse_connection_header_value(
extra_headers_to_remove_lowercase = parse_connection_header_value(
connection_header[0] if connection_header else None
)

# Copy headers.
for k, v in response_headers.getAllRawHeaders():
# Do not copy over any hop-by-hop headers. These are meant to only be
# consumed by the immediate recipient and not be forwarded on.
header_key = k.decode("ascii")
header_key_lowercase = k.decode("ascii").lower()
if (
header_key in HOP_BY_HOP_HEADERS
or header_key in extra_headers_to_remove
header_key_lowercase in HOP_BY_HOP_HEADERS_LOWERCASE
or header_key_lowercase in extra_headers_to_remove_lowercase
):
continue

Expand Down
23 changes: 19 additions & 4 deletions tests/http/test_matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,12 +903,19 @@ def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None:
headers=Headers(
{
"Content-Type": ["application/json"],
"Connection": ["close, X-Foo, X-Bar"],
"X-Test": ["test"],
# Define some hop-by-hop headers (try with varying casing to
# make sure we still match-up the headers)
"Connection": ["close, X-fOo, X-Bar, X-baz"],
# Should be removed because it's defined in the `Connection` header
"X-Foo": ["foo"],
"X-Bar": ["bar"],
# (not in canonical case)
"x-baZ": ["baz"],
# Should be removed because it's a hop-by-hop header
"Proxy-Authorization": "abcdef",
# Should be removed because it's a hop-by-hop header (not in canonical case)
"transfer-EnCoDiNg": "abcdef",
}
),
)
Expand Down Expand Up @@ -938,9 +945,17 @@ def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None:
header_names = set(headers.keys())

# Make sure the response does not include the hop-by-hop headers
self.assertNotIn(b"X-Foo", header_names)
self.assertNotIn(b"X-Bar", header_names)
self.assertNotIn(b"Proxy-Authorization", header_names)
self.assertIncludes(
header_names,
{
b"Content-Type",
b"X-Test",
# Default headers from Twisted
b"Date",
b"Server",
},
exact=True,
)
# Make sure the response is as expected back on the main worker
self.assertEqual(res, {"foo": "bar"})

Expand Down
32 changes: 24 additions & 8 deletions tests/http/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,42 @@

from parameterized import parameterized

from synapse.http.proxy import parse_connection_header_value
from synapse.http.proxy import (
HOP_BY_HOP_HEADERS_LOWERCASE,
parse_connection_header_value,
)

from tests.unittest import TestCase


def mix_case(s: str) -> str:
"""
Mix up the case of each character in the string (upper or lower case)
"""
return "".join(c.upper() if i % 2 == 0 else c.lower() for i, c in enumerate(s))


class ProxyTests(TestCase):
@parameterized.expand(
[
[b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
[b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}],
# No whitespace
[b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}],
[b"close,X-Foo,X-Bar", {"close", "x-foo", "x-bar"}],
# More whitespace
[b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
[b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}],
# "close" directive in not the first position
[b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}],
[b"X-Foo, X-Bar, close", {"x-foo", "x-bar", "close"}],
# Normalizes header capitalization
[b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}],
[b"keep-alive, x-fOo, x-bAr", {"keep-alive", "x-foo", "x-bar"}],
# Handles header names with whitespace
[
b"keep-alive, x foo, x bar",
{"Keep-Alive", "X foo", "X bar"},
{"keep-alive", "x foo", "x bar"},
],
# Make sure we handle all of the hop-by-hop headers
[
mix_case(", ".join(HOP_BY_HOP_HEADERS_LOWERCASE)).encode("ascii"),
HOP_BY_HOP_HEADERS_LOWERCASE,
],
]
)
Expand All @@ -54,7 +69,8 @@ def test_parse_connection_header_value(
"""
Tests that the connection header value is parsed correctly
"""
self.assertEqual(
self.assertIncludes(
expected_extra_headers_to_remove,
parse_connection_header_value(connection_header_value),
exact=True,
)
Loading