stream-server/srt/crypt.go
2023-09-23 20:36:22 +05:00

167 lines
4.2 KiB
Go

package srt
import (
"crypto/aes"
"crypto/sha1"
"crypto/hmac"
"hash"
"math"
"encoding/binary"
"errors"
"crypto/cipher"
)
// CryptHandler holds the key material used to decrypt SRT data packets.
// It is populated from a received Key Material (KM) message.
type CryptHandler struct {
	// salt is copied verbatim from the KM message. Its last 8 bytes feed the
	// KEK derivation, and its first 14 bytes are mixed into the CTR IV.
	salt [16]byte
	// key_len is the SEK (and KEK) length in bytes, taken from the KM message.
	key_len uint8
	// odd_sek and even_sek are AES ciphers for the odd and even SEKs.
	// Either one stays nil until a KM message supplying that key is unwrapped.
	odd_sek  cipher.Block
	even_sek cipher.Block
}
// NewCryptHandler builds a CryptHandler from a received KM message. It
// immediately unwraps the message's SEKs with passphrase.
//
// It returns nil if km_msg is nil or if Unwrap rejects the key material.
// Unwrap rejects material if, for example, the passphrase is wrong (the
// RFC 3394 integrity check fails) or the material is malformed.
func NewCryptHandler(passphrase string, km_msg *KMMSG) *CryptHandler {
	// Without this guard, a missing message would panic on the field reads
	// below. nil is already this function's failure result.
	if km_msg == nil {
		return nil
	}
	crypt := &CryptHandler{
		key_len: km_msg.key_len,
		salt:    km_msg.salt,
	}
	if !crypt.Unwrap(km_msg.wrapped_key, passphrase, km_msg.key_type) {
		return nil
	}
	return crypt
}
// PRF as defined in RFC 8018
func PRF(h hash.Hash, input []byte) ([]byte) {
h.Reset()
h.Write(input)
return h.Sum(nil)
}
// gets KEK from passphrase and salt according to fixed SRT iterations and algo from RFC doc
// see RFC 8018 for implementation details
func SRT_PBKDF2(passphrase string, salt []byte, dklen uint8) ([]byte) {
prf := hmac.New(sha1.New, []byte(passphrase))
hlen := prf.Size()
l := int(math.Ceil(float64(dklen) / float64(hlen)))
r := int(dklen) - (l - 1) * hlen
key := make([]byte, 0)
for block := 1; block <= l; block++ {
U := make([]byte, hlen)
T := make([]byte, hlen)
block_i := make([]byte, 4)
binary.BigEndian.PutUint32(block_i, uint32(block))
U = PRF(prf, append(salt, block_i...))
copy(T, U)
// skip one iter since done above
for n := 1; n < 2048; n++ {
U = PRF(prf, U)
for x := range T {
T[x] ^= U[x]
}
}
// final block may not use entire SHA output, still need full during computation
if block == l {
T = T[:r]
}
// final key is appended sequence of all blocks computed independently
key = append(key, T...)
}
return key
}
// AES_UNWRAP reverses the AES Key Wrap of RFC 3394 (section 2.2.2,
// index-based form), using key as the KEK. It works in place: wrapped is
// overwritten with intermediate state, so pass a copy if the original bytes
// are still needed.
//
// wrapped must be a whole number of 64-bit blocks. It must contain the
// 8-byte integrity register followed by at least two key blocks. The result
// is a fresh slice holding the concatenated unwrapped blocks; callers split
// it into individual SEKs.
//
// An error is returned, and the result is nil, in any of these cases:
//   - key is not a valid AES key;
//   - wrapped has an invalid length;
//   - the recovered integrity value is not the default IV
//     (0xA6A6A6A6A6A6A6A6), for example because the KEK is wrong.
func AES_UNWRAP(key []byte, wrapped []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	// RFC 3394 requires n >= 2 key blocks plus the A register. Reject
	// anything else rather than panicking on short input or silently
	// ignoring trailing bytes.
	if len(wrapped)%8 != 0 || len(wrapped) < 24 {
		return nil, errors.New("wrapped key has invalid length")
	}
	n := len(wrapped)/8 - 1
	a := wrapped[:8] // integrity check register A (aliases wrapped)
	r := wrapped[8:] // R[1..n], 8 bytes each (aliases wrapped)

	// One cipher block reused for every step. Decrypting with dst == src is
	// allowed because the two overlap entirely.
	buf := make([]byte, aes.BlockSize)
	for j := 5; j >= 0; j-- {
		for i := n; i >= 1; i-- {
			ri := r[(i-1)*8 : i*8]
			// B = AES-1(K, (A ^ t) | R[i]) with t = n*j + i.
			binary.BigEndian.PutUint64(buf[:8], binary.BigEndian.Uint64(a)^uint64(n*j+i))
			copy(buf[8:], ri)
			block.Decrypt(buf, buf)
			// A = MSB(64, B); R[i] = LSB(64, B).
			copy(a, buf[:8])
			copy(ri, buf[8:])
		}
	}
	// SRT wraps with the RFC 3394 default IV (eight 0xA6 bytes). A correct
	// unwrap must recover it exactly.
	for _, v := range a {
		if v != 0xa6 {
			return nil, errors.New("IV not default")
		}
	}
	seks := make([]byte, len(r))
	copy(seks, r)
	return seks, nil
}
// unwrap and store SEK ciphers, key_type defined as KK 2-bit value in Key Material from SRT docs
func (crypt *CryptHandler) Unwrap(wrapped_key []byte, passphrase string, key_type uint8) (bool) {
kek := SRT_PBKDF2(passphrase, crypt.salt[8:], crypt.key_len)
// need a copy since original will be sent back
wrapped_copy := make([]byte, len(wrapped_key))
copy(wrapped_copy, wrapped_key)
seks, err := AES_UNWRAP(kek, wrapped_copy)
// either unwrap fails or key len does not match expected (1 or 2 SEKs len identical)
if err != nil || len(seks) % int(crypt.key_len) != 0 {
return false
}
// always have one SEK, if more bytes (second key) and peer did not send 2 keys
// something is wrong
sek_1 := seks[:crypt.key_len]
if len(seks) > int(crypt.key_len) && key_type != 3 {
return false
}
switch key_type {
case 1:
crypt.even_sek, _ = aes.NewCipher(sek_1)
case 2:
crypt.odd_sek, _ = aes.NewCipher(sek_1)
case 3:
sek_2 := seks[crypt.key_len:]
crypt.even_sek, _ = aes.NewCipher(sek_1)
crypt.odd_sek, _ = aes.NewCipher(sek_2)
default:
return false
}
return true
}
// Decrypt decrypts pkt's payload in place with AES-CTR. It picks the SEK
// from the packet's KK bits (msg_flags & 0x6): 2 selects the even key and
// 4 selects the odd key.
//
// The counter block is 16 bytes. It starts as zeros, the 32-bit sequence
// number is written big-endian at bytes 10..13, and the first 14 bytes are
// then XORed with the salt. Bytes 14..15 start at zero, and CTR mode
// increments the block as it advances.
//
// The packet is left unchanged in any of these cases:
//   - it is not marked as encrypted with either key;
//   - the selected key has not been installed;
//   - its header is not a *DataHeader;
//   - its payload is not a []byte.
func (crypt *CryptHandler) Decrypt(pkt *Packet) {
	if pkt == nil {
		return
	}
	hdr, ok := pkt.header_info.(*DataHeader)
	if !ok || hdr == nil {
		return
	}
	payload, ok := pkt.cif.([]byte)
	if !ok {
		return
	}

	var sek cipher.Block
	switch hdr.msg_flags & 0x6 {
	case 2:
		sek = crypt.even_sek
	case 4:
		sek = crypt.odd_sek
	}
	// Covers unencrypted packets and keys never installed.
	// cipher.NewCTR would panic on a nil Block.
	if sek == nil {
		return
	}

	// The CTR IV must be exactly one cipher block (16 bytes for AES),
	// whatever the key length. Sizing it by key_len makes cipher.NewCTR
	// panic for 24- and 32-byte keys.
	iv := make([]byte, aes.BlockSize)
	binary.BigEndian.PutUint32(iv[10:14], hdr.seq_num)
	for i := 0; i < 14; i++ {
		iv[i] ^= crypt.salt[i]
	}
	cipher.NewCTR(sek, iv).XORKeyStream(payload, payload)
}