Skip to content

Commit 48ae77f

Browse files
authoredDec 24, 2022
fix: the equality checks for DNSPointer and DNSService should be case insensitive (#1122)
1 parent d6115c8 commit 48ae77f

File tree

4 files changed

+27
-6
lines changed

4 files changed

+27
-6
lines changed
 

‎src/zeroconf/_dns.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ cdef class DNSPointer(DNSRecord):
6060

6161
cdef public cython.int _hash
6262
cdef public object alias
63+
cdef public object alias_key
6364

6465
cdef _eq(self, DNSPointer other)
6566

‎src/zeroconf/_dns.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,15 @@ class DNSPointer(DNSRecord):
314314

315315
"""A DNS pointer record"""
316316

317-
__slots__ = ('_hash', 'alias')
317+
__slots__ = ('_hash', 'alias', 'alias_key')
318318

319319
def __init__(
320320
self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
321321
) -> None:
322322
super().__init__(name, type_, class_, ttl, created)
323323
self.alias = alias
324-
self._hash = hash((self.key, type_, self.class_, alias))
324+
self.alias_key = self.alias.lower()
325+
self._hash = hash((self.key, type_, self.class_, self.alias_key))
325326

326327
@property
327328
def max_size_compressed(self) -> int:
@@ -343,7 +344,7 @@ def __eq__(self, other: Any) -> bool:
343344

344345
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
345346
"""Tests equality on alias."""
346-
return self.alias == other.alias and self._dns_entry_matches(other)
347+
return self.alias_key == other.alias_key and self._dns_entry_matches(other)
347348

348349
def __hash__(self) -> int:
349350
"""Hash to compare like DNSPointer."""
@@ -415,7 +416,7 @@ def __init__(
415416
self.port = port
416417
self.server = server
417418
self.server_key = server.lower()
418-
self._hash = hash((self.key, type_, self.class_, priority, weight, port, server))
419+
self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key))
419420

420421
def write(self, out: 'DNSOutgoing') -> None:
421422
"""Used in constructing an outgoing packet"""
@@ -434,7 +435,7 @@ def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
434435
self.priority == other.priority
435436
and self.weight == other.weight
436437
and self.port == other.port
437-
and self.server == other.server
438+
and self.server_key == other.server_key
438439
and self._dns_entry_matches(other)
439440
)
440441

‎tests/services/test_browser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
176176
socket.AF_INET6, service_v6_second_address
177177
) in service_info.addresses_by_version(r.IPVersion.V6Only)
178178
assert service_info.text == service_text
179-
assert service_info.server == service_server
179+
assert service_info.server.lower() == service_server.lower()
180180
service_updated_event.set()
181181

182182
def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:

‎tests/test_dns.py

+19
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,14 @@ def test_dns_pointer_record_hashablity():
283283
assert len(record_set) == 2
284284

285285

286+
def test_dns_pointer_comparison_is_case_insensitive():
287+
"""Test DNSPointer comparison is case insensitive."""
288+
ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
289+
ptr2 = r.DNSPointer('irrelevant'.upper(), const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
290+
291+
assert ptr1 == ptr2
292+
293+
286294
def test_dns_text_record_hashablity():
287295
"""Test DNSText are hashable."""
288296
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
@@ -340,6 +348,17 @@ def test_dns_service_server_key():
340348
assert srv1.server_key == 'x.local.'
341349

342350

351+
def test_dns_service_server_comparison_is_case_insensitive():
352+
"""Test DNSService server comparison is case insensitive."""
353+
srv1 = r.DNSService(
354+
'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'X.local.'
355+
)
356+
srv2 = r.DNSService(
357+
'X._tcp._http.local.', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'x.local.'
358+
)
359+
assert srv1 == srv2
360+
361+
343362
def test_dns_nsec_record_hashablity():
344363
"""Test DNSNsec are hashable."""
345364
nsec1 = r.DNSNsec(

0 commit comments

Comments
 (0)
Please sign in to comment.