Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: the equality checks for DNSPointer and DNSService should be case insensitive #1122

Merged
merged 3 commits into from Dec 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/zeroconf/_dns.pxd
Expand Up @@ -60,6 +60,7 @@ cdef class DNSPointer(DNSRecord):

cdef public cython.int _hash
cdef public object alias
cdef public object alias_key

cdef _eq(self, DNSPointer other)

Expand Down
11 changes: 6 additions & 5 deletions src/zeroconf/_dns.py
Expand Up @@ -314,14 +314,15 @@ class DNSPointer(DNSRecord):

"""A DNS pointer record"""

__slots__ = ('_hash', 'alias')
__slots__ = ('_hash', 'alias', 'alias_key')

def __init__(
self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.alias = alias
self._hash = hash((self.key, type_, self.class_, alias))
self.alias_key = self.alias.lower()
self._hash = hash((self.key, type_, self.class_, self.alias_key))

@property
def max_size_compressed(self) -> int:
Expand All @@ -343,7 +344,7 @@ def __eq__(self, other: Any) -> bool:

def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
"""Tests equality on alias."""
return self.alias == other.alias and self._dns_entry_matches(other)
return self.alias_key == other.alias_key and self._dns_entry_matches(other)

def __hash__(self) -> int:
"""Hash to compare like DNSPointer."""
Expand Down Expand Up @@ -415,7 +416,7 @@ def __init__(
self.port = port
self.server = server
self.server_key = server.lower()
self._hash = hash((self.key, type_, self.class_, priority, weight, port, server))
self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key))

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
Expand All @@ -434,7 +435,7 @@ def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
self.priority == other.priority
and self.weight == other.weight
and self.port == other.port
and self.server == other.server
and self.server_key == other.server_key
and self._dns_entry_matches(other)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/services/test_browser.py
Expand Up @@ -176,7 +176,7 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
socket.AF_INET6, service_v6_second_address
) in service_info.addresses_by_version(r.IPVersion.V6Only)
assert service_info.text == service_text
assert service_info.server == service_server
assert service_info.server.lower() == service_server.lower()
service_updated_event.set()

def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_dns.py
Expand Up @@ -283,6 +283,14 @@ def test_dns_pointer_record_hashablity():
assert len(record_set) == 2


def test_dns_pointer_comparison_is_case_insensitive():
"""Test DNSPointer comparison is case insensitive."""
ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
ptr2 = r.DNSPointer('irrelevant'.upper(), const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')

assert ptr1 == ptr2


def test_dns_text_record_hashablity():
"""Test DNSText are hashable."""
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
Expand Down Expand Up @@ -340,6 +348,17 @@ def test_dns_service_server_key():
assert srv1.server_key == 'x.local.'


def test_dns_service_server_comparison_is_case_insensitive():
"""Test DNSService server comparison is case insensitive."""
srv1 = r.DNSService(
'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'X.local.'
)
srv2 = r.DNSService(
'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'x.local.'
)
assert srv1 == srv2


def test_dns_nsec_record_hashablity():
"""Test DNSNsec are hashable."""
nsec1 = r.DNSNsec(
Expand Down