diff --git a/src/transform.zig b/src/transform.zig index 45e8ffd..113fd0a 100644 --- a/src/transform.zig +++ b/src/transform.zig @@ -2,7 +2,63 @@ const std = @import("std"); const util = @import("util.zig"); -pub fn quantize(source: *util.Block, target: *util.BlockQuantized) void { - _ = target; - _ = source; +pub const dct_coeffs = gen_coeffs(); + +inline fn dct_cos(x: usize, f: usize) f16 { + const x_float: f16 = @floatFromInt(x); + const f_float: f16 = @floatFromInt(f); + return @cos((2.0 * x_float + 1.0) * f_float * std.math.pi / 16.0); +} + +inline fn dct_coeff(u: usize, v: usize) f16 { + return 0.25 * (if (u == 0) 1.0 / @sqrt(2.0) else 1.0) * (if (v == 0) 1.0 / @sqrt(2.0) else 1.0); +} + +inline fn zz_conv(u: usize, v: usize) struct { u: usize, v: usize } { + var band_i = u + v; + const band_max_u = @min(7, band_i); + const band_max_v = @min(7, band_i); + var idx: usize = 0; + for (0..band_i) |i| { + idx += zz_band_len(i); + } + if (band_i % 2 == 0) { + idx += band_max_v - v; + } else { + idx += band_max_u - u; + } + return .{ + .u = idx / 8, + .v = idx % 8, + }; +} + +inline fn zz_band_len(band_i: usize) usize { + return if (band_i < 8) band_i + 1 else 15 - band_i; +} + +fn gen_coeffs() [8][8]@Vector(64, f16) { + @setEvalBranchQuota(100000); + var ret: [8][8]@Vector(64, f16) = undefined; + for (0..8) |u| { + for (0..8) |v| { + const zz_idx = zz_conv(u, v); + for (0..8) |x| { + for (0..8) |y| { + ret[zz_idx.u][zz_idx.v][x * 8 + y] = dct_coeff(zz_idx.u, zz_idx.v) * dct_cos(x, zz_idx.u) * dct_cos(y, zz_idx.v); + } + } + } + } + @setEvalBranchQuota(1000); + return ret; +} + +pub fn quantize(source: *util.Block, target: *util.BlockQuantized, qtable: *util.QTable) void { + _ = qtable; + _ = target; + var holder: @Vector(64, f16) = undefined; + for (0..64) |i| { + holder[i] = @floatFromInt(source[i]); + } }