From 15bd8578df2d16f09180089d726bbd979b4633f8 Mon Sep 17 00:00:00 2001 From: Muaz Ahmad Date: Sat, 23 Sep 2023 19:12:46 +0500 Subject: [PATCH] basic encryption key material handling --- srt/crypt.go | 147 ++++++++++++++++++++++++++++++++++++++++++++++++ srt/packet.go | 25 ++++---- srt/protocol.go | 21 +++++++ 3 files changed, 183 insertions(+), 10 deletions(-) create mode 100644 srt/crypt.go diff --git a/srt/crypt.go b/srt/crypt.go new file mode 100644 index 0000000..e2e66cf --- /dev/null +++ b/srt/crypt.go @@ -0,0 +1,147 @@ +package srt + +import ( + "crypto/aes" + "crypto/sha1" + "crypto/hmac" + "hash" + "math" + "encoding/binary" + "errors" + "crypto/cipher" +) + +type CryptHandler struct { + salt [16]byte + key_len uint8 + odd_sek cipher.Block + even_sek cipher.Block +} + +// init will apply to KM message immediately +func NewCryptHandler(passphrase string, km_msg *KMMSG) (*CryptHandler) { + crypt := new(CryptHandler) + crypt.key_len = km_msg.key_len + crypt.salt = km_msg.salt + ok := crypt.Unwrap(km_msg.wrapped_key, passphrase, km_msg.key_type) + if !ok { + return nil + } + return crypt +} + +// PRF as defined in RFC 8018 +func PRF(h hash.Hash, input []byte) ([]byte) { + h.Reset() + h.Write(input) + return h.Sum(nil) +} + +// gets KEK from passphrase and salt according to fixed SRT iterations and algo from RFC doc +// see RFC 8018 for implementation details +func SRT_PBKDF2(passphrase string, salt []byte, dklen uint8) ([]byte) { + prf := hmac.New(sha1.New, []byte(passphrase)) + hlen := prf.Size() + l := int(math.Ceil(float64(dklen) / float64(hlen))) + r := int(dklen) - (l - 1) * hlen + + key := make([]byte, 0) + for block := 1; block <= l; block++ { + U := make([]byte, hlen) + T := make([]byte, hlen) + + block_i := make([]byte, 4) + binary.BigEndian.PutUint32(block_i, uint32(block)) + U = PRF(prf, append(salt, block_i...)) + copy(T, U) + // skip one iter since done above + for n := 1; n < 2048; n++ { + U = PRF(prf, U) + for x := range T { + T[x] ^= U[x] + } + } + // final block may not use entire SHA output, still need full during computation + if block == l { + T = T[:r] + } + // final key is appended sequence of all blocks computed independently + key = append(key, T...) + } + return key +} + +// See RFC 3394, inplace implementation +func AES_UNWRAP(key []byte, wrapped []byte) ([]byte, error) { + seks := make([]byte, 0) // bytes past IV + cipher, err := aes.NewCipher(key) + if err != nil { + return seks, err + } + A := wrapped[:8] // IV bytes + n := len(wrapped) / 8 - 1 + R := make([][]byte, n) // actual message (SEKs) + for i := range R { + R[i] = wrapped[(i + 1) * 8: (i + 2) * 8] + } + + for j := 5; j >= 0; j-- { + for i := n; i > 0; i-- { + t := make([]byte, 8) + binary.BigEndian.PutUint64(t, uint64(n * j + i)) + for k := range t { + t[k] ^= A[k] + } + B := make([]byte, 16) + cipher.Decrypt(B, append(t, R[i - 1]...)) + copy(A, B[:8]) + copy(R[i - 1], B[8:]) + } + } + + // SRT uses default IV, 8 repeating bytes of 0xa6 prepended in wrap, check if + // preserved in unwrap + for i := range A { + if A[i] != 0xa6 { + return seks, errors.New("IV not default") + } + } + // R is 8 byte blocks, keys can be 16-32 bytes, prepend all together and + // let wrappers figure it out + for _, v := range R { + seks = append(seks, v...) + } + return seks, nil +} + +// unwrap and store SEK ciphers, key_type defined as KK 2-bit value in Key Material from SRT docs +func (crypt *CryptHandler) Unwrap(wrapped_key []byte, passphrase string, key_type uint8) (bool) { + kek := SRT_PBKDF2(passphrase, crypt.salt[8:], crypt.key_len) + // need a copy since original will be sent back + wrapped_copy := make([]byte, len(wrapped_key)) + copy(wrapped_copy, wrapped_key) + seks, err := AES_UNWRAP(kek, wrapped_copy) + // either unwrap fails or key len does not match expected (1 or 2 SEKs len identical) + if err != nil || len(seks) % int(crypt.key_len) != 0 { + return false + } + // always have one SEK, if more bytes (second key) and peer did not send 2 keys + // something is wrong + sek_1 := seks[:crypt.key_len] + if len(seks) > int(crypt.key_len) && key_type != 3 { + return false + } + switch key_type { + case 1: + crypt.even_sek, _ = aes.NewCipher(sek_1) + case 2: + crypt.odd_sek, _ = aes.NewCipher(sek_1) + case 3: + sek_2 := seks[crypt.key_len:] + crypt.even_sek, _ = aes.NewCipher(sek_1) + crypt.odd_sek, _ = aes.NewCipher(sek_2) + default: + return false + } + return true +} diff --git a/srt/packet.go b/srt/packet.go index 1290a8a..e5f59ea 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -210,14 +210,18 @@ func marshall_hs_cif(data *HandshakeCIF) ([]byte) { binary.BigEndian.PutUint16(ext_buff[12:14], contents.recv_delay) binary.BigEndian.PutUint16(ext_buff[14:16], contents.send_delay) case 4: - contents := extension.ext_contents.(*KMMSG) - binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type)) - binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200)) - binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4)) - for i := 0; i < 16; i++ { - ext_buff[20 + i] = contents.salt[i] + contents, ok := extension.ext_contents.(*KMMSG) + if !ok { // handle km_state error + copy(ext_buff[4:8], extension.ext_contents.([]byte)) + } else { + binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type)) + binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200)) + binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4)) + for i := 0; i < 16; i++ { + ext_buff[20 + i] = contents.salt[i] + } + copy(ext_buff[36:], contents.wrapped_key) } - copy(ext_buff[36:], contents.wrapped_key) default: copy(ext_buff[4:], extension.ext_contents.([]byte)) } @@ -317,12 +321,13 @@ func parse_hs_cif(cif *HandshakeCIF, buffer []byte) (error) { content := new(KMMSG) content.key_type = extensions[7] & 0x3 content.key_len = extensions[19] * 4 - for i := 0; i < 4; i++ { + for i := 0; i < 16; i++ { content.salt[i] = extensions[20 + i] } - wrap_key_len := 4 + ext.ext_len - 24 + // -36 from actual content len, extensions includes headers as well + wrap_key_len := 4 + ext.ext_len - 36 content.wrapped_key = make([]byte, wrap_key_len) - copy(content.wrapped_key, extensions[24:24 + wrap_key_len]) + copy(content.wrapped_key, extensions[36:36 + wrap_key_len]) ext.ext_contents = content default: content := make([]byte, ext.ext_len) diff --git a/srt/protocol.go b/srt/protocol.go index 50c8ef6..5128c1e 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -18,6 +18,7 @@ const ( ) type SRTManager struct { + crypt *CryptHandler state uint8 init time.Time syn_cookie uint32 @@ -156,11 +157,31 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { agent.stream_key = stream_key if !ok { resp_packet.cif.(*HandshakeCIF).hs_type = 1003 + agent.state = 3 return resp_packet } else { agent.output = writer CleanFiles(agent.stream_key, 0) } + case 3: + resp_packet.cif.(*HandshakeCIF).ext_field = 3 + // passphrase harcoded for testing, should pass in somehow with a user management system + crypt_handler := NewCryptHandler("srttestpass", v.ext_contents.(*KMMSG)) + if crypt_handler == nil { // if sek unwrap required but fails + agent.state = 3 + resp_packet.cif.(*HandshakeCIF).hs_type = 1010 + resp_ext := new(HandshakeExtension) + resp_ext.ext_type = 4 + resp_ext.ext_len = 4 + km_state := make([]byte, 4) + km_state[3] = 4 // BADSECRET code + resp_ext.ext_contents = km_state + resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, resp_ext) + return resp_packet + } + // else return since needed + resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, v) + v.ext_type = 4 } } agent.pings[0][1] = time.Now()