From 93045d80f7a9c277bef71c1a50649823bd7a2b4f Mon Sep 17 00:00:00 2001 From: Muaz Ahmad Date: Sun, 12 May 2024 18:52:13 +0500 Subject: [PATCH] Add env var args, old values purging --- src/hosts.zig | 64 +++++++++++++++++++++++++++++++++++++++++++-------- src/main.zig | 4 ++-- src/mdns.zig | 2 +- src/util.zig | 9 ++++++-- 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/src/hosts.zig b/src/hosts.zig index f452a94..fdbcb7b 100644 --- a/src/hosts.zig +++ b/src/hosts.zig @@ -1,19 +1,25 @@ const std = @import("std"); const util = @import("util.zig"); +const hosts_header = "# local-etc-hosts-updater"; + pub fn update_hosts(ip: util.IP) !void { try create_tmp_hosts(ip); try move_tmp_hosts(); } -fn create_tmp_hosts(ip: util.IP) !void { - _ = ip; +inline fn check_purge(line: []u8, purging: *bool, skip_last_line: *bool) void { + if (purging.* and std.mem.eql(u8, line, "")) { + purging.* = false; + skip_last_line.* = true; + } - var old_hosts = try std.fs.openFileAbsolute("/etc/hosts", .{ .mode = .read_only }); - defer old_hosts.close(); - var tmp_hosts = try std.fs.createFileAbsolute("/tmp/hosts", .{ .truncate = true }); - defer tmp_hosts.close(); + if (std.mem.eql(u8, line, hosts_header)) { + purging.* = true; + } +} +fn purge_existing(old_hosts: std.fs.File, tmp_hosts: std.fs.File) !void { var buff = [_]u8{0x00} ** 50; var buff_stream = std.io.fixedBufferStream(&buff); const buff_reader = buff_stream.reader(); @@ -21,18 +27,58 @@ fn create_tmp_hosts(ip: util.IP) !void { var old_reader = old_hosts.reader(); const tmp_writer = tmp_hosts.writer(); + var purging = false; + var skip_last_line = false; while (old_reader.streamUntilDelimiter(buff_writer, '\n', 50 - 1)) |_| { + defer buff_stream.pos = 0; var n = buff_stream.pos; // do things to line + check_purge(buff[0..n], &purging, &skip_last_line); + + if (purging) { + continue; + } + if (skip_last_line) { + skip_last_line = false; + continue; + } + + // copy line buff[n] = '\n'; n += 1; buff_stream.pos = 0; try buff_reader.streamUntilDelimiter(tmp_writer, '\n', 50); try tmp_writer.writeByte('\n'); - buff_stream.pos = 0; - } else |_| { - // EOF + } else |err| { + if (err != error.EndOfStream) { + return err; + } + } +} + +fn append_new(tmp_hosts: std.fs.File, addr: std.net.Address) !void { + _ = addr; + const writer = tmp_hosts.writer(); + var buff = [_]u8{0x00} ** 50; + + const header_len = hosts_header.*.len; + @memcpy(buff[0..header_len], hosts_header); + try writer.writeAll(buff[0..header_len]); + try writer.writeByte('\n'); + + try writer.writeByte('\n'); +} + +fn create_tmp_hosts(ip: util.IP) !void { + var old_hosts = try std.fs.openFileAbsoluteZ(try util.getenv("OLD_HOSTS_PATH"), .{ .mode = .read_only }); + defer old_hosts.close(); + var tmp_hosts = try std.fs.createFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH"), .{ .truncate = true }); + defer tmp_hosts.close(); + + try purge_existing(old_hosts, tmp_hosts); + if (ip) |addr| { + try append_new(tmp_hosts, addr); } } diff --git a/src/main.zig b/src/main.zig index 75f1815..de7ed0d 100644 --- a/src/main.zig +++ b/src/main.zig @@ -5,7 +5,7 @@ const hosts = @import("hosts.zig"); pub fn main() !void { try util.check_perms(); - const domain, const ip_ver = try util.get_input(); - const ip = try mdns.get_mdns(domain, ip_ver); + const domain, const ip_info = try util.get_input(); + const ip = try mdns.get_mdns(domain, ip_info); try hosts.update_hosts(ip); } diff --git a/src/mdns.zig b/src/mdns.zig index f5c8521..15343fe 100644 --- a/src/mdns.zig +++ b/src/mdns.zig @@ -165,5 +165,5 @@ fn parse_mdns_response(response: []u8, ip_info: util.IPInfo) !util.IP { } } } - return addr orelse MDNSError.NoMatchingAddress; + return addr; } diff --git a/src/util.zig b/src/util.zig index 9114dde..48ed5a8 100644 --- a/src/util.zig +++ b/src/util.zig @@ -6,6 +6,7 @@ const ArgError = error{ InvalidAddressVer, InterfaceRequired, InvalidInterface, + EnvVarNotSet, }; pub const IP_VER_ENUM = enum(u3) { @@ -13,7 +14,7 @@ pub const IP_VER_ENUM = enum(u3) { IPv6 = 6, }; -pub const IP = std.net.Address; +pub const IP = ?std.net.Address; pub const IPInfo = struct { version: IP_VER_ENUM, @@ -25,8 +26,12 @@ pub const Domain = struct { labels: [5][]const u8, }; +pub fn getenv(key: [*:0]const u8) ![*:0]const u8 { + return (std.c.getenv(key) orelse ArgError.EnvVarNotSet); +} + pub fn check_perms() !void { - var f = try std.fs.openFileAbsolute("/etc/hosts", .{ .mode = .write_only }); + var f = try std.fs.openFileAbsoluteZ(try getenv("OLD_HOSTS_PATH"), .{ .mode = .write_only }); f.close(); }