diff --git a/src/hosts.zig b/src/hosts.zig index 4b0fd71..8a74ea1 100644 --- a/src/hosts.zig +++ b/src/hosts.zig @@ -1,6 +1,6 @@ const std = @import("std"); const util = @import("util.zig"); -pub fn update_hosts(ips: util.IPs) !void { - std.debug.print("{any}\n", .{ips}); +pub fn update_hosts(ips: util.IP) !void { + std.debug.print("{}\n", .{ips}); } diff --git a/src/mdns.zig b/src/mdns.zig index 9da7dd4..f64ba03 100644 --- a/src/mdns.zig +++ b/src/mdns.zig @@ -7,11 +7,14 @@ const MDNSError = error{ UDPConnectFail, UDPSendFail, UDPRecvFail, + NotResponse, + NoMatchingAddress, + AddressBadFormat, }; -pub fn get_mdns(domain: util.Domain, ip_info: util.IPInfo) !util.IPs { +pub fn get_mdns(domain: util.Domain, ip_info: util.IPInfo) !util.IP { const sock = try send_query(domain, ip_info); - return receive_query(sock); + return receive_response(sock, ip_info); } const socket = c_int; @@ -59,10 +62,7 @@ fn construct_mdns_query(domain: util.Domain, ip_info: util.IPInfo, buff: []u8) ! return n; } -fn send_query(domain: util.Domain, ip_info: util.IPInfo) !socket { - const sock = try get_mdns_socket(ip_info); - errdefer _ = std.c.close(sock); - +fn get_target_address(ip_info: util.IPInfo) !std.net.Address { const target_addr: []const u8 = switch (ip_info.version) { util.IP_VER_ENUM.IPv4 => "224.0.0.251", util.IP_VER_ENUM.IPv6 => blk: { @@ -73,7 +73,14 @@ fn send_query(domain: util.Domain, ip_info: util.IPInfo) !socket { break :blk buf[0..writer.context.pos]; }, }; - const addr = try std.net.Address.resolveIp(target_addr, 5353); + return std.net.Address.resolveIp(target_addr, 5353); +} + +fn send_query(domain: util.Domain, ip_info: util.IPInfo) !socket { + const sock = try get_mdns_socket(ip_info); + errdefer _ = std.c.close(sock); + + const addr = try get_target_address(ip_info); var buff = [_]u8{0x00} ** MSG_BUFF_SIZE; @@ -84,10 +91,10 @@ fn send_query(domain: util.Domain, ip_info: util.IPInfo) !socket { return sock; } -fn receive_query(sock: socket) !util.IPs { +fn receive_response(sock: socket, ip_info: util.IPInfo) !util.IP { defer _ = std.c.close(sock); - var buff = [_]u8{0x00} ** MSG_BUFF_SIZE; + var buff = [_]u8{0x00} ** MSG_BUFF_SIZE; const n: usize = blk: { const ret = std.c.recv(sock, &buff, MSG_BUFF_SIZE, 0); if (ret < 0) { @@ -97,10 +104,64 @@ fn receive_query(sock: socket) !util.IPs { } }; - std.debug.print("{any}\n", .{buff[0..n]}); - - return util.IPs{ - .v4 = "", - .v6 = "", - }; + return parse_mdns_response(buff[0..n], ip_info); +} + +inline fn skip_name(pos: *usize, buf: []u8) void { + if (buf[pos.*] & 0xc0 != 0) { + pos.* += 2; // compressed form, skip flag + position + return; + } + while (buf[pos.*] != 0) { + pos.* += buf[pos.*] + 1; + } + pos.* += 1; // skip final NULL +} + +inline fn read_u16(bytes: []u8) u16 { + return std.mem.nativeToBig(u16, std.mem.bytesToValue(u16, bytes)); +} + +fn parse_mdns_response(response: []u8, ip_info: util.IPInfo) !util.IP { + var pos: usize = 0; + + if (response[2] & 0x80 == 0) { // check packet is response + return MDNSError.NotResponse; + } + + const n_answers = read_u16(response[6..8]); + pos = 12; // start of question segment + + skip_name(&pos, response); + + pos += 4; // skip type and IN + QU flags + + var addr: ?std.net.Address = null; + for (0..n_answers) |_| { + skip_name(&pos, response); + pos += 8; // skip type + cache flag + ttl; + const ip_len = read_u16(response[pos .. pos + 2]); + pos += 2; + const ip_bytes = response[pos .. pos + ip_len]; + pos += ip_len; + + if (switch (ip_info.version) { + util.IP_VER_ENUM.IPv4 => ip_len != 4, + util.IP_VER_ENUM.IPv6 => ip_len != 16, + }) { + continue; + } + if (ip_len == 4) { + var addr_buff = [_]u8{0x00} ** 4; + @memcpy(&addr_buff, ip_bytes); + addr = std.net.Address.initIp4(addr_buff, 0); + } else if (ip_len == 16) { + if (ip_bytes[0] == 0xfd) { + var addr_buff = [_]u8{0x00} ** 16; + @memcpy(&addr_buff, ip_bytes); + addr = std.net.Address.initIp6(addr_buff, 0, 0, @intCast(std.c.if_nametoindex(ip_info.interface.?))); + } + } + } + return addr orelse MDNSError.NoMatchingAddress; } diff --git a/src/util.zig b/src/util.zig index a63b190..9114dde 100644 --- a/src/util.zig +++ b/src/util.zig @@ -13,10 +13,7 @@ pub const IP_VER_ENUM = enum(u3) { IPv6 = 6, }; -pub const IPs = struct { - v4: []const u8, - v6: []const u8, -}; +pub const IP = std.net.Address; pub const IPInfo = struct { version: IP_VER_ENUM,