Skip to content

Commit

Permalink
fix: the equality checks for DNSPointer and DNSService should be case…
Browse files Browse the repository at this point in the history
… insensitive (#1122)
  • Loading branch information
bdraco committed Dec 24, 2022
1 parent d6115c8 commit 48ae77f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/zeroconf/_dns.pxd
Expand Up @@ -60,6 +60,7 @@ cdef class DNSPointer(DNSRecord):

cdef public cython.int _hash
cdef public object alias
cdef public object alias_key

cdef _eq(self, DNSPointer other)

Expand Down
11 changes: 6 additions & 5 deletions src/zeroconf/_dns.py
Expand Up @@ -314,14 +314,15 @@ class DNSPointer(DNSRecord):

"""A DNS pointer record"""

__slots__ = ('_hash', 'alias')
__slots__ = ('_hash', 'alias', 'alias_key')

def __init__(
self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.alias = alias
self._hash = hash((self.key, type_, self.class_, alias))
self.alias_key = self.alias.lower()
self._hash = hash((self.key, type_, self.class_, self.alias_key))

@property
def max_size_compressed(self) -> int:
Expand All @@ -343,7 +344,7 @@ def __eq__(self, other: Any) -> bool:

def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
"""Tests equality on alias."""
return self.alias == other.alias and self._dns_entry_matches(other)
return self.alias_key == other.alias_key and self._dns_entry_matches(other)

def __hash__(self) -> int:
"""Hash to compare like DNSPointer."""
Expand Down Expand Up @@ -415,7 +416,7 @@ def __init__(
self.port = port
self.server = server
self.server_key = server.lower()
self._hash = hash((self.key, type_, self.class_, priority, weight, port, server))
self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key))

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
Expand All @@ -434,7 +435,7 @@ def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
self.priority == other.priority
and self.weight == other.weight
and self.port == other.port
and self.server == other.server
and self.server_key == other.server_key
and self._dns_entry_matches(other)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/services/test_browser.py
Expand Up @@ -176,7 +176,7 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
socket.AF_INET6, service_v6_second_address
) in service_info.addresses_by_version(r.IPVersion.V6Only)
assert service_info.text == service_text
assert service_info.server == service_server
assert service_info.server.lower() == service_server.lower()
service_updated_event.set()

def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_dns.py
Expand Up @@ -283,6 +283,14 @@ def test_dns_pointer_record_hashablity():
assert len(record_set) == 2


def test_dns_pointer_comparison_is_case_insensitive():
"""Test DNSPointer comparison is case insensitive."""
ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
ptr2 = r.DNSPointer('irrelevant'.upper(), const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')

assert ptr1 == ptr2


def test_dns_text_record_hashablity():
"""Test DNSText are hashable."""
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
Expand Down Expand Up @@ -340,6 +348,17 @@ def test_dns_service_server_key():
assert srv1.server_key == 'x.local.'


def test_dns_service_server_comparison_is_case_insensitive():
"""Test DNSService server comparison is case insensitive."""
srv1 = r.DNSService(
'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'X.local.'
)
srv2 = r.DNSService(
'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'x.local.'
)
assert srv1 == srv2


def test_dns_nsec_record_hashablity():
"""Test DNSNsec are hashable."""
nsec1 = r.DNSNsec(
Expand Down

0 comments on commit 48ae77f

Please sign in to comment.