# coding: utf-8
from __future__ import print_function

import re

from .bit import get_bits, set_bits
from .buffer import Buffer, BufferError

# Byte values 33..126 (printable ASCII, excluding space). DNSLabel._decode
# emits these verbatim and renders every other byte as a "\DDD" escape.
LDH = set(range(33, 127))

# Matches a "\DDD" decimal escape (the form DNSLabel._decode produces) so
# DNSLabel.__init__ can convert it back into the corresponding character.
ESCAPE = re.compile(r"\\([0-9][0-9][0-9])")

# State of the avahi-bug #379 workaround consulted by DNSBuffer.decode_name:
#   0 = disabled (a forward compression pointer is an error)
#   1 = enabled by set_avahi_379(), occurrence not yet reported
#   2 = enabled and already reported by log_avahi_379()
avahi_379 = 0


def set_avahi_379():
    """Enable the avahi-bug #379 workaround in DNSBuffer.decode_name."""
    global avahi_379
    avahi_379 = 1


def log_avahi_379(args):
    """Report a forward compression pointer once, then stay silent.

    If the occurrence has not been reported yet (flag != 2), the flag is set
    to 2 and a BufferError carrying an explanatory message is raised. Any
    later call returns None so the caller can apply its fallback.

    :param args: ``(offset, pointer, length)`` tuple used to format the message.
    :raises BufferError: when the occurrence has not been reported yet.
    """
    global avahi_379
    if avahi_379 == 2:
        return
    avahi_379 = 2
    t = "Invalid pointer in DNSLabel [offset=%d,pointer=%d,length=%d];\n\033[35m NOTE: this is probably avahi-bug #379, packet corruption in Avahi's mDNS-reflection feature. Copyparty has a workaround and is OK, but other devices need either --zm4 or --zm6"
    raise BufferError(t % args)


class DNSLabelError(Exception):
    """Raised when a name or one of its components is too long to encode."""

    pass


class DNSLabel(object):
    """A domain name held as a tuple of components in ``self.label``.

    For text/bytes input the components are ``bytes`` with the trailing root
    dot removed. Equality and hashing ignore ASCII case.
    """

    def __init__(self, label):
        """Build a label from a DNSLabel, a sequence of components, or a name.

        :param label: another DNSLabel (its components are shared), a list or
            tuple of components, or a dotted name as text or bytes. An empty
            value or a lone "." yields the root (empty) label.
        """
        if type(label) == DNSLabel:
            self.label = label.label
        elif type(label) in (list, tuple):
            self.label = tuple(label)
        elif not label or label in (b".", "."):
            self.label = ()
        elif type(label) is not bytes:
            # Text input: on Python 3 (str != bytes) undo "\DDD" escapes first,
            # then IDNA-encode and split on dots.
            if type("") != type(b""):
                label = ESCAPE.sub(lambda m: chr(int(m[1])), label)
            self.label = tuple(label.encode("idna").rstrip(b".").split(b"."))
        else:
            # Bytes input: on Python 2 (str is bytes) undo "\DDD" escapes.
            if type("") == type(b""):
                label = ESCAPE.sub(lambda m: chr(int(m.groups()[0])), label)
            self.label = tuple(label.rstrip(b".").split(b"."))

    def add(self, name):
        """Return a new label: ``name`` followed by this label's components."""
        new = DNSLabel(name)
        if self.label:
            new.label += self.label
        return new

    def idna(self):
        """Return the name IDNA-decoded to text, with a trailing dot."""
        return ".".join([s.decode("idna") for s in self.label]) + "."

    def _decode(self, s):
        """Render one component; bytes outside LDH become "\\DDD" escapes.

        :param s: a ``bytearray`` component (iterating yields ints).
        """
        if set(s).issubset(LDH):
            return s.decode()
        return "".join([(chr(c) if (c in LDH) else "\\%03d" % c) for c in s])

    def __str__(self):
        return ".".join([self._decode(bytearray(s)) for s in self.label]) + "."

    def __repr__(self):
        # The format string needs a %s placeholder: formatting an argument
        # into "" raises TypeError and breaks every "%r" use of a DNSLabel.
        return "<DNSLabel: '%s'>" % str(self)

    def __hash__(self):
        # Case-insensitive, consistent with __eq__.
        return hash(tuple(map(lambda x: x.lower(), self.label)))

    def __ne__(self, other):
        return not self == other

    def __eq__(self, other):
        # Non-DNSLabel operands are converted first, so "example.com"
        # compares equal to DNSLabel("example.com").
        if type(other) != DNSLabel:
            return self.__eq__(DNSLabel(other))
        return [l.lower() for l in self.label] == [l.lower() for l in other.label]

    def __len__(self):
        """Length in bytes of the dotted form (without the trailing dot)."""
        return len(b".".join(self.label))


class DNSBuffer(Buffer):
    """Buffer that can read and write (optionally compressed) domain names."""

    def __init__(self, data=b""):
        super(DNSBuffer, self).__init__(data)
        # Maps a tuple of name components (a name suffix) to the buffer offset
        # where encode_name wrote it, for reuse as a compression target.
        self.names = {}

    def decode_name(self, last=-1, _seen=None):
        """Read a possibly compressed domain name starting at ``self.offset``.

        Labels are read until a zero-length terminator or a compression
        pointer. A pointer is followed recursively; afterwards the offset is
        restored to just past the pointer.

        :param last: offset just past the pointer that led to this call
            (-1 at top level); used to detect a pointer to itself.
        :param _seen: internal set of pointer targets already followed while
            decoding this name; guards against pointer cycles.
        :returns: the decoded DNSLabel.
        :raises BufferError: on a pointer loop, on a pointer that does not
            point backwards (unless the avahi-bug #379 workaround is active),
            or on a component that ``bytes.decode()`` rejects.
        """
        if _seen is None:
            _seen = set()
        label = []
        while True:
            (length,) = self.unpack("!B")
            # Top two bits set => compression pointer (assumes
            # get_bits(value, offset, bits) extracts `bits` bits at `offset`).
            if get_bits(length, 6, 2) == 3:
                # Re-read the length byte as the first half of the 16-bit pointer.
                self.offset -= 1
                pointer = get_bits(self.unpack("!H")[0], 0, 14)
                save = self.offset
                # last == save catches a pointer to itself; _seen catches longer
                # cycles (e.g. two overlapping pointers aimed at each other),
                # which would otherwise recurse until RecursionError.
                if last == save or pointer in _seen:
                    raise BufferError(
                        "Recursive pointer in DNSLabel [offset=%d,pointer=%d,length=%d]"
                        % (self.offset, pointer, len(self.data))
                    )
                if pointer < self.offset:
                    self.offset = pointer
                elif avahi_379:
                    # Reports (raises) once; afterwards substitute a placeholder
                    # component. append, not extend: extend(b"a") would add the
                    # int 97 on Python 3.
                    log_avahi_379((self.offset, pointer, len(self.data)))
                    label.append(b"a")
                    return DNSLabel(label)
                else:
                    raise BufferError(
                        "Invalid pointer in DNSLabel [offset=%d,pointer=%d,length=%d]"
                        % (self.offset, pointer, len(self.data))
                    )
                _seen.add(pointer)
                label.extend(self.decode_name(save, _seen).label)
                self.offset = save
                return DNSLabel(label)
            if length == 0:
                return DNSLabel(label)
            part = self.get(length)
            try:
                part.decode()
            except UnicodeDecodeError:
                raise BufferError("Invalid label <%s>" % part)
            label.append(part)

    def encode_name(self, name):
        """Write ``name`` using compression pointers to suffixes already written.

        :param name: a DNSLabel or anything DNSLabel() accepts.
        :raises DNSLabelError: if the name exceeds 253 bytes or any component
            exceeds 63 bytes.
        """
        if not isinstance(name, DNSLabel):
            name = DNSLabel(name)
        if len(name) > 253:
            raise DNSLabelError("Domain label too long: %r" % name)
        name = list(name.label)
        while name:
            key = tuple(name)
            if key in self.names:
                # Emit a pointer: top two bits set, low 14 bits = stored offset.
                pointer = set_bits(self.names[key], 3, 14, 2)
                self.pack("!H", pointer)
                return
            # A pointer holds only 14 bits of offset; remembering a larger one
            # would later produce a truncated pointer to the wrong data.
            if self.offset < 0x4000:
                self.names[key] = self.offset
            element = name.pop(0)
            if len(element) > 63:
                raise DNSLabelError("Label component too long: %r" % element)
            self.pack("!B", len(element))
            self.append(element)
        self.append(b"\x00")

    def encode_name_nocompress(self, name):
        """Write ``name`` as plain length-prefixed components (no pointers).

        :param name: a DNSLabel or anything DNSLabel() accepts.
        :raises DNSLabelError: if the name exceeds 253 bytes or any component
            exceeds 63 bytes.
        """
        if not isinstance(name, DNSLabel):
            name = DNSLabel(name)
        if len(name) > 253:
            raise DNSLabelError("Domain label too long: %r" % name)
        for element in name.label:
            if len(element) > 63:
                raise DNSLabelError("Label component too long: %r" % element)
            self.pack("!B", len(element))
            self.append(element)
        self.append(b"\x00")