From 48ae77f026a96e2ca475b0ff80cb6d22207ce52f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 23 Dec 2022 17:27:52 -1000 Subject: [PATCH] fix: the equality checks for DNSPointer and DNSService should be case insensitive (#1122) --- src/zeroconf/_dns.pxd | 1 + src/zeroconf/_dns.py | 11 ++++++----- tests/services/test_browser.py | 2 +- tests/test_dns.py | 19 +++++++++++++++++++ 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index 762e9319..14c7fb70 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -60,6 +60,7 @@ cdef class DNSPointer(DNSRecord): cdef public cython.int _hash cdef public object alias + cdef public object alias_key cdef _eq(self, DNSPointer other) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index f9e33541..5727d83a 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -314,14 +314,15 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" - __slots__ = ('_hash', 'alias') + __slots__ = ('_hash', 'alias', 'alias_key') def __init__( self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None ) -> None: super().__init__(name, type_, class_, ttl, created) self.alias = alias - self._hash = hash((self.key, type_, self.class_, alias)) + self.alias_key = self.alias.lower() + self._hash = hash((self.key, type_, self.class_, self.alias_key)) @property def max_size_compressed(self) -> int: @@ -343,7 +344,7 @@ def __eq__(self, other: Any) -> bool: def _eq(self, other) -> bool: # type: ignore[no-untyped-def] """Tests equality on alias.""" - return self.alias == other.alias and self._dns_entry_matches(other) + return self.alias_key == other.alias_key and self._dns_entry_matches(other) def __hash__(self) -> int: """Hash to compare like DNSPointer.""" @@ -415,7 +416,7 @@ def __init__( self.port = port self.server = server self.server_key = server.lower() - self._hash = hash((self.key, type_, self.class_, priority, weight, port, server)) + self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -434,7 +435,7 @@ def _eq(self, other) -> bool: # type: ignore[no-untyped-def] self.priority == other.priority and self.weight == other.weight and self.port == other.port - and self.server == other.server + and self.server_key == other.server_key and self._dns_entry_matches(other) ) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index fd588648..a3121e6d 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -176,7 +176,7 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de socket.AF_INET6, service_v6_second_address ) in service_info.addresses_by_version(r.IPVersion.V6Only) assert service_info.text == service_text - assert service_info.server == service_server + assert service_info.server.lower() == service_server.lower() service_updated_event.set() def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: diff --git a/tests/test_dns.py b/tests/test_dns.py index 59b4932a..08f805f0 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -283,6 +283,14 @@ def test_dns_pointer_record_hashablity(): assert len(record_set) == 2 +def test_dns_pointer_comparison_is_case_insensitive(): + """Test DNSPointer comparison is case insensitive.""" + ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') + ptr2 = r.DNSPointer('irrelevant'.upper(), const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') + + assert ptr1 == ptr2 + + def test_dns_text_record_hashablity(): """Test DNSText are hashable.""" text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') @@ -340,6 +348,17 @@ def test_dns_service_server_key(): assert srv1.server_key == 'x.local.' +def test_dns_service_server_comparison_is_case_insensitive(): + """Test DNSService server comparison is case insensitive.""" + srv1 = r.DNSService( + 'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'X.local.' + ) + srv2 = r.DNSService( + 'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'x.local.' + ) + assert srv1 == srv2 + + def test_dns_nsec_record_hashablity(): """Test DNSNsec are hashable.""" nsec1 = r.DNSNsec(