Add dual version handling, relative path handling
This commit is contained in:
parent
6669634565
commit
5b05de3ab3
3 changed files with 68 additions and 25 deletions
|
@ -61,16 +61,7 @@ fn purge_existing(old_hosts: std.fs.File, tmp_hosts: std.fs.File) !void {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn append_new(tmp_hosts: std.fs.File, addr: std.net.Address, local_domain: util.Domain) !void {
|
fn write_subdomains_map(writer: std.fs.File.Writer, addr: std.net.Address, local_domain: util.Domain) !void {
|
||||||
const writer = tmp_hosts.writer();
|
|
||||||
var buff = [_]u8{0x00} ** FILE_LINE_BUFF_SIZE;
|
|
||||||
|
|
||||||
// write output
|
|
||||||
const header_len = hosts_header.*.len;
|
|
||||||
@memcpy(buff[0..header_len], hosts_header);
|
|
||||||
try writer.writeAll(buff[0..header_len]);
|
|
||||||
try writer.writeByte('\n');
|
|
||||||
|
|
||||||
var subdomains = try get_subdomains();
|
var subdomains = try get_subdomains();
|
||||||
const domain = try util.getenv("TARGET_DOMAIN");
|
const domain = try util.getenv("TARGET_DOMAIN");
|
||||||
|
|
||||||
|
@ -86,6 +77,24 @@ fn append_new(tmp_hosts: std.fs.File, addr: std.net.Address, local_domain: util.
|
||||||
try std.fmt.format(writer, "{s} {s}.{s}\n", .{ addr_str, subdomain, domain });
|
try std.fmt.format(writer, "{s} {s}.{s}\n", .{ addr_str, subdomain, domain });
|
||||||
}
|
}
|
||||||
try std.fmt.format(writer, "{s} {s}\n", .{ addr_str, local_domain.name });
|
try std.fmt.format(writer, "{s} {s}\n", .{ addr_str, local_domain.name });
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append_new(tmp_hosts: std.fs.File, ip: util.IP, local_domain: util.Domain) !void {
|
||||||
|
const writer = tmp_hosts.writer();
|
||||||
|
var buff = [_]u8{0x00} ** FILE_LINE_BUFF_SIZE;
|
||||||
|
|
||||||
|
// write output
|
||||||
|
const header_len = hosts_header.*.len;
|
||||||
|
@memcpy(buff[0..header_len], hosts_header);
|
||||||
|
try writer.writeAll(buff[0..header_len]);
|
||||||
|
try writer.writeByte('\n');
|
||||||
|
|
||||||
|
if (ip.v6) |addr| {
|
||||||
|
try write_subdomains_map(writer, addr, local_domain);
|
||||||
|
}
|
||||||
|
if (ip.v4) |addr| {
|
||||||
|
try write_subdomains_map(writer, addr, local_domain);
|
||||||
|
}
|
||||||
|
|
||||||
// add extra final newline
|
// add extra final newline
|
||||||
try writer.writeByte('\n');
|
try writer.writeByte('\n');
|
||||||
|
@ -97,22 +106,20 @@ fn get_subdomains() !std.mem.SplitIterator(u8, .scalar) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_tmp_hosts(ip: util.IP, domain: util.Domain) !void {
|
fn create_tmp_hosts(ip: util.IP, domain: util.Domain) !void {
|
||||||
var old_hosts = try std.fs.openFileAbsoluteZ(try util.getenv("OLD_HOSTS_PATH"), .{ .mode = .read_only });
|
var old_hosts = try std.fs.Dir.openFileZ(util.cwd, try util.getenv("OLD_HOSTS_PATH"), .{ .mode = .read_only });
|
||||||
defer old_hosts.close();
|
defer old_hosts.close();
|
||||||
var tmp_hosts = try std.fs.createFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH"), .{ .truncate = true });
|
var tmp_hosts = try std.fs.Dir.createFileZ(util.cwd, try util.getenv("TMP_HOSTS_PATH"), .{ .truncate = true });
|
||||||
defer tmp_hosts.close();
|
defer tmp_hosts.close();
|
||||||
|
|
||||||
try purge_existing(old_hosts, tmp_hosts);
|
try purge_existing(old_hosts, tmp_hosts);
|
||||||
if (ip) |addr| {
|
try append_new(tmp_hosts, ip, domain);
|
||||||
try append_new(tmp_hosts, addr, domain);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn move_tmp_hosts() !void {
|
fn move_tmp_hosts() !void {
|
||||||
var target_hosts = try std.fs.createFileAbsoluteZ(try util.getenv("OLD_HOSTS_PATH"), .{ .lock = .exclusive });
|
var target_hosts = try std.fs.Dir.createFileZ(util.cwd, try util.getenv("OLD_HOSTS_PATH"), .{ .lock = .exclusive });
|
||||||
defer target_hosts.close();
|
defer target_hosts.close();
|
||||||
const target_writer = target_hosts.writer();
|
const target_writer = target_hosts.writer();
|
||||||
var tmp_hosts = try std.fs.openFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH"), .{});
|
var tmp_hosts = try std.fs.Dir.openFileZ(util.cwd, try util.getenv("TMP_HOSTS_PATH"), .{});
|
||||||
defer tmp_hosts.close();
|
defer tmp_hosts.close();
|
||||||
const source_reader = tmp_hosts.reader();
|
const source_reader = tmp_hosts.reader();
|
||||||
|
|
||||||
|
@ -122,5 +129,5 @@ fn move_tmp_hosts() !void {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
try std.fs.deleteFileAbsoluteZ(try util.getenv("TMP_HOSTS_PATH"));
|
try std.fs.Dir.deleteFileZ(util.cwd, try util.getenv("TMP_HOSTS_PATH"));
|
||||||
}
|
}
|
||||||
|
|
26
src/mdns.zig
26
src/mdns.zig
|
@ -2,7 +2,6 @@ const std = @import("std");
|
||||||
const util = @import("util.zig");
|
const util = @import("util.zig");
|
||||||
|
|
||||||
const MDNSError = error{
|
const MDNSError = error{
|
||||||
Unimplemented,
|
|
||||||
SocketInitFail,
|
SocketInitFail,
|
||||||
UDPConnectFail,
|
UDPConnectFail,
|
||||||
UDPSendFail,
|
UDPSendFail,
|
||||||
|
@ -14,10 +13,22 @@ const MDNSError = error{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn get_mdns(domain: util.Domain, ip_info: util.IPInfo) !util.IP {
|
pub fn get_mdns(domain: util.Domain, ip_info: util.IPInfo) !util.IP {
|
||||||
|
if (ip_info.version != .Both) {
|
||||||
const sock = try send_query(domain, ip_info);
|
const sock = try send_query(domain, ip_info);
|
||||||
return receive_response(sock, ip_info);
|
return receive_response(sock, ip_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ip_info4 = util.IPInfo{ .version = .IPv4, .interface = ip_info.interface };
|
||||||
|
var sock = try send_query(domain, ip_info4);
|
||||||
|
const addr_v4 = try receive_response(sock, ip_info4);
|
||||||
|
|
||||||
|
const ip_info6 = util.IPInfo{ .version = .IPv4, .interface = ip_info.interface };
|
||||||
|
sock = try send_query(domain, ip_info6);
|
||||||
|
const addr_v6 = try receive_response(sock, ip_info6);
|
||||||
|
|
||||||
|
return util.merge_addrs(addr_v4, addr_v6);
|
||||||
|
}
|
||||||
|
|
||||||
const socket = c_int;
|
const socket = c_int;
|
||||||
const MSG_BUFF_SIZE = 200;
|
const MSG_BUFF_SIZE = 200;
|
||||||
|
|
||||||
|
@ -25,6 +36,7 @@ fn get_mdns_socket(ip_info: util.IPInfo) !socket {
|
||||||
const sock = std.c.socket(switch (ip_info.version) {
|
const sock = std.c.socket(switch (ip_info.version) {
|
||||||
util.IP_VER_ENUM.IPv4 => std.c.AF.INET,
|
util.IP_VER_ENUM.IPv4 => std.c.AF.INET,
|
||||||
util.IP_VER_ENUM.IPv6 => std.c.AF.INET6,
|
util.IP_VER_ENUM.IPv6 => std.c.AF.INET6,
|
||||||
|
else => unreachable,
|
||||||
}, std.c.SOCK.DGRAM, std.c.IPPROTO.UDP);
|
}, std.c.SOCK.DGRAM, std.c.IPPROTO.UDP);
|
||||||
if (sock == -1) {
|
if (sock == -1) {
|
||||||
return MDNSError.SocketInitFail;
|
return MDNSError.SocketInitFail;
|
||||||
|
@ -59,6 +71,7 @@ fn construct_mdns_query(domain: util.Domain, ip_info: util.IPInfo, buff: []u8) !
|
||||||
[_]u8{ 0x00, switch (ip_info.version) {
|
[_]u8{ 0x00, switch (ip_info.version) {
|
||||||
util.IP_VER_ENUM.IPv4 => 0x01,
|
util.IP_VER_ENUM.IPv4 => 0x01,
|
||||||
util.IP_VER_ENUM.IPv6 => 0x1c,
|
util.IP_VER_ENUM.IPv6 => 0x1c,
|
||||||
|
else => unreachable,
|
||||||
} } ++ // A or AAAA record
|
} } ++ // A or AAAA record
|
||||||
[_]u8{ 0x00, 0x01 } // IN query
|
[_]u8{ 0x00, 0x01 } // IN query
|
||||||
;
|
;
|
||||||
|
@ -77,6 +90,7 @@ fn get_target_address(ip_info: util.IPInfo) !std.net.Address {
|
||||||
try std.fmt.format(writer, "ff02::fb%{s}", .{ip_info.interface.?});
|
try std.fmt.format(writer, "ff02::fb%{s}", .{ip_info.interface.?});
|
||||||
break :blk buf[0..writer.context.pos];
|
break :blk buf[0..writer.context.pos];
|
||||||
},
|
},
|
||||||
|
else => unreachable,
|
||||||
};
|
};
|
||||||
return std.net.Address.resolveIp(target_addr, 5353);
|
return std.net.Address.resolveIp(target_addr, 5353);
|
||||||
}
|
}
|
||||||
|
@ -153,6 +167,7 @@ fn parse_mdns_response(response: []u8, ip_info: util.IPInfo) !util.IP {
|
||||||
if (switch (ip_info.version) {
|
if (switch (ip_info.version) {
|
||||||
util.IP_VER_ENUM.IPv4 => ip_len != 4,
|
util.IP_VER_ENUM.IPv4 => ip_len != 4,
|
||||||
util.IP_VER_ENUM.IPv6 => ip_len != 16,
|
util.IP_VER_ENUM.IPv6 => ip_len != 16,
|
||||||
|
else => unreachable,
|
||||||
}) {
|
}) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -170,5 +185,12 @@ fn parse_mdns_response(response: []u8, ip_info: util.IPInfo) !util.IP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return addr;
|
if (addr == null) {
|
||||||
|
return MDNSError.NoMatchingAddress;
|
||||||
|
}
|
||||||
|
return switch (ip_info.version) {
|
||||||
|
.IPv4 => util.IP{ .v4 = addr },
|
||||||
|
.IPv6 => util.IP{ .v6 = addr },
|
||||||
|
else => unreachable,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
20
src/util.zig
20
src/util.zig
|
@ -7,14 +7,26 @@ const ArgError = error{
|
||||||
InterfaceRequired,
|
InterfaceRequired,
|
||||||
InvalidInterface,
|
InvalidInterface,
|
||||||
EnvVarNotSet,
|
EnvVarNotSet,
|
||||||
|
InvalidOldHostsPath,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const IP_VER_ENUM = enum(u3) {
|
pub const IP_VER_ENUM = enum(u3) {
|
||||||
|
Both = 0,
|
||||||
IPv4 = 4,
|
IPv4 = 4,
|
||||||
IPv6 = 6,
|
IPv6 = 6,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const IP = ?std.net.Address;
|
pub const IP = struct {
|
||||||
|
v4: ?std.net.Address = null,
|
||||||
|
v6: ?std.net.Address = null,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn merge_addrs(v4: IP, v6: IP) IP {
|
||||||
|
return IP{
|
||||||
|
.v4 = v4.v4,
|
||||||
|
.v6 = v6.v6,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
pub const IPInfo = struct {
|
pub const IPInfo = struct {
|
||||||
version: IP_VER_ENUM,
|
version: IP_VER_ENUM,
|
||||||
|
@ -31,10 +43,12 @@ pub fn getenv(key: [*:0]const u8) ![*:0]const u8 {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_perms() !void {
|
pub fn check_perms() !void {
|
||||||
var f = try std.fs.openFileAbsoluteZ(try getenv("OLD_HOSTS_PATH"), .{ .mode = .write_only });
|
var f = try std.fs.Dir.openFileZ(cwd, try getenv("OLD_HOSTS_PATH"), .{ .mode = .write_only });
|
||||||
f.close();
|
f.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const cwd = std.fs.cwd();
|
||||||
|
|
||||||
pub fn get_input() !struct {
|
pub fn get_input() !struct {
|
||||||
Domain,
|
Domain,
|
||||||
IPInfo,
|
IPInfo,
|
||||||
|
@ -49,7 +63,7 @@ pub fn get_input() !struct {
|
||||||
const ip_ver = std.meta.intToEnum(IP_VER_ENUM, std.fmt.parseInt(u3, ip_ver_str, 10) catch return ArgError.InvalidAddressVer) catch return ArgError.InvalidAddressVer;
|
const ip_ver = std.meta.intToEnum(IP_VER_ENUM, std.fmt.parseInt(u3, ip_ver_str, 10) catch return ArgError.InvalidAddressVer) catch return ArgError.InvalidAddressVer;
|
||||||
|
|
||||||
var iface: ?[:0]const u8 = null;
|
var iface: ?[:0]const u8 = null;
|
||||||
if (ip_ver == IP_VER_ENUM.IPv6) {
|
if (ip_ver != IP_VER_ENUM.IPv4) {
|
||||||
iface = args.next() orelse return ArgError.InterfaceRequired;
|
iface = args.next() orelse return ArgError.InterfaceRequired;
|
||||||
if (std.c.if_nametoindex(iface.?) == 0) {
|
if (std.c.if_nametoindex(iface.?) == 0) {
|
||||||
return ArgError.InvalidInterface;
|
return ArgError.InvalidInterface;
|
||||||
|
|
Loading…
Reference in a new issue