diff --git a/src/hosts.zig b/src/hosts.zig index 870e681..0edf623 100644 --- a/src/hosts.zig +++ b/src/hosts.zig @@ -61,16 +61,7 @@ fn purge_existing(old_hosts: std.fs.File, tmp_hosts: std.fs.File) !void { } } -fn append_new(tmp_hosts: std.fs.File, addr: std.net.Address, local_domain: util.Domain) !void { - const writer = tmp_hosts.writer(); - var buff = [_]u8{0x00} ** FILE_LINE_BUFF_SIZE; - - // write output - const header_len = hosts_header.*.len; - @memcpy(buff[0..header_len], hosts_header); - try writer.writeAll(buff[0..header_len]); - try writer.writeByte('\n'); - +fn write_subdomains_map(writer: std.fs.File.Writer, addr: std.net.Address, local_domain: util.Domain) !void { var subdomains = try get_subdomains(); const domain = try util.getenv("TARGET_DOMAIN"); @@ -86,6 +77,24 @@ fn append_new(tmp_hosts: std.fs.File, addr: std.net.Address, local_domain: util. try std.fmt.format(writer, "{s} {s}.{s}\n", .{ addr_str, subdomain, domain }); } try std.fmt.format(writer, "{s} {s}\n", .{ addr_str, local_domain.name }); +} + +fn append_new(tmp_hosts: std.fs.File, ip: util.IP, local_domain: util.Domain) !void { + const writer = tmp_hosts.writer(); + var buff = [_]u8{0x00} ** FILE_LINE_BUFF_SIZE; + + // write output + const header_len = hosts_header.*.len; + @memcpy(buff[0..header_len], hosts_header); + try writer.writeAll(buff[0..header_len]); + try writer.writeByte('\n'); + + if (ip.v6) |addr| { + try write_subdomains_map(writer, addr, local_domain); + } + if (ip.v4) |addr| { + try write_subdomains_map(writer, addr, local_domain); + } // add extra final newline try writer.writeByte('\n'); @@ -97,22 +106,20 @@ fn get_subdomains() !std.mem.SplitIterator(u8, .scalar) { } fn create_tmp_hosts(ip: util.IP, domain: util.Domain) !void { - var old_hosts = try std.fs.openFileAbsoluteZ(try util.getenv("OLD_HOSTS_PATH"), .{ .mode = .read_only }); + var old_hosts = try std.fs.Dir.openFileZ(util.cwd, try util.getenv("OLD_HOSTS_PATH"), .{ .mode = .read_only }); defer old_hosts.close(); - var tmp_hosts = try std.fs.createFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH"), .{ .truncate = true }); + var tmp_hosts = try std.fs.Dir.createFileZ(util.cwd, try util.getenv("TMP_HOSTS_PATH"), .{ .truncate = true }); defer tmp_hosts.close(); try purge_existing(old_hosts, tmp_hosts); - if (ip) |addr| { - try append_new(tmp_hosts, addr, domain); - } + try append_new(tmp_hosts, ip, domain); } fn move_tmp_hosts() !void { - var target_hosts = try std.fs.createFileAbsoluteZ(try util.getenv("OLD_HOSTS_PATH"), .{ .lock = .exclusive }); + var target_hosts = try std.fs.Dir.createFileZ(util.cwd, try util.getenv("OLD_HOSTS_PATH"), .{ .lock = .exclusive }); defer target_hosts.close(); const target_writer = target_hosts.writer(); - var tmp_hosts = try std.fs.openFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH"), .{}); + var tmp_hosts = try std.fs.Dir.openFileZ(util.cwd, try util.getenv("TMP_HOSTS_PATH"), .{}); defer tmp_hosts.close(); const source_reader = tmp_hosts.reader(); @@ -122,5 +129,5 @@ fn move_tmp_hosts() !void { } }; - try std.fs.deleteFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH")); + try std.fs.Dir.deleteFileZ(util.cwd, try util.getenv("TMP_HOSTS_PATH")); } diff --git a/src/mdns.zig b/src/mdns.zig index 7b3d098..1d5749b 100644 --- a/src/mdns.zig +++ b/src/mdns.zig @@ -2,7 +2,6 @@ const std = @import("std"); const util = @import("util.zig"); const MDNSError = error{ - Unimplemented, SocketInitFail, UDPConnectFail, UDPSendFail, @@ -14,8 +13,20 @@ const MDNSError = error{ }; pub fn get_mdns(domain: util.Domain, ip_info: util.IPInfo) !util.IP { - const sock = try send_query(domain, ip_info); - return receive_response(sock, ip_info); + if (ip_info.version != .Both) { + const sock = try send_query(domain, ip_info); + return receive_response(sock, ip_info); + } + + const ip_info4 = util.IPInfo{ .version = .IPv4, .interface = ip_info.interface }; + var sock = try send_query(domain, ip_info4); + const addr_v4 = try receive_response(sock, ip_info4); + + const ip_info6 = util.IPInfo{ .version = .IPv4, .interface = ip_info.interface }; + sock = try send_query(domain, ip_info6); + const addr_v6 = try receive_response(sock, ip_info6); + + return util.merge_addrs(addr_v4, addr_v6); } const socket = c_int; @@ -25,6 +36,7 @@ fn get_mdns_socket(ip_info: util.IPInfo) !socket { const sock = std.c.socket(switch (ip_info.version) { util.IP_VER_ENUM.IPv4 => std.c.AF.INET, util.IP_VER_ENUM.IPv6 => std.c.AF.INET6, + else => unreachable, }, std.c.SOCK.DGRAM, std.c.IPPROTO.UDP); if (sock == -1) { return MDNSError.SocketInitFail; @@ -59,6 +71,7 @@ fn construct_mdns_query(domain: util.Domain, ip_info: util.IPInfo, buff: []u8) ! [_]u8{ 0x00, switch (ip_info.version) { util.IP_VER_ENUM.IPv4 => 0x01, util.IP_VER_ENUM.IPv6 => 0x1c, + else => unreachable, } } ++ // A or AAAA record [_]u8{ 0x00, 0x01 } // IN query ; @@ -77,6 +90,7 @@ fn get_target_address(ip_info: util.IPInfo) !std.net.Address { try std.fmt.format(writer, "ff02::fb%{s}", .{ip_info.interface.?}); break :blk buf[0..writer.context.pos]; }, + else => unreachable, }; return std.net.Address.resolveIp(target_addr, 5353); } @@ -153,6 +167,7 @@ fn parse_mdns_response(response: []u8, ip_info: util.IPInfo) !util.IP { if (switch (ip_info.version) { util.IP_VER_ENUM.IPv4 => ip_len != 4, util.IP_VER_ENUM.IPv6 => ip_len != 16, + else => unreachable, }) { continue; } @@ -170,5 +185,12 @@ fn parse_mdns_response(response: []u8, ip_info: util.IPInfo) !util.IP { } } } - return addr; + if (addr == null) { + return MDNSError.NoMatchingAddress; + } + return switch (ip_info.version) { + .IPv4 => util.IP{ .v4 = addr }, + .IPv6 => util.IP{ .v6 = addr }, + else => unreachable, + }; } diff --git a/src/util.zig b/src/util.zig index 7bb1afc..208376d 100644 --- a/src/util.zig +++ b/src/util.zig @@ -7,14 +7,26 @@ const ArgError = error{ InterfaceRequired, InvalidInterface, EnvVarNotSet, + InvalidOldHostsPath, }; pub const IP_VER_ENUM = enum(u3) { + Both = 0, IPv4 = 4, IPv6 = 6, }; -pub const IP = ?std.net.Address; +pub const IP = struct { + v4: ?std.net.Address = null, + v6: ?std.net.Address = null, +}; + +pub fn merge_addrs(v4: IP, v6: IP) IP { + return IP{ + .v4 = v4.v4, + .v6 = v6.v6, + }; +} pub const IPInfo = struct { version: IP_VER_ENUM, @@ -31,10 +43,12 @@ pub fn getenv(key: [*:0]const u8) ![*:0]const u8 { } pub fn check_perms() !void { - var f = try std.fs.openFileAbsoluteZ(try getenv("OLD_HOSTS_PATH"), .{ .mode = .write_only }); + var f = try std.fs.Dir.openFileZ(cwd, try getenv("OLD_HOSTS_PATH"), .{ .mode = .write_only }); f.close(); } +pub const cwd = std.fs.cwd(); + pub fn get_input() !struct { Domain, IPInfo, @@ -49,7 +63,7 @@ pub fn get_input() !struct { const ip_ver = std.meta.intToEnum(IP_VER_ENUM, std.fmt.parseInt(u3, ip_ver_str, 10) catch return ArgError.InvalidAddressVer) catch return ArgError.InvalidAddressVer; var iface: ?[:0]const u8 = null; - if (ip_ver == IP_VER_ENUM.IPv6) { + if (ip_ver != IP_VER_ENUM.IPv4) { iface = args.next() orelse return ArgError.InterfaceRequired; if (std.c.if_nametoindex(iface.?) == 0) { return ArgError.InvalidInterface;