stream-server/srt/protocol.go
2023-09-23 20:36:22 +05:00

348 lines
9.3 KiB
Go

package srt
import (
"time"
"math"
"net"
"crypto/sha256"
"fmt"
"errors"
"io"
)
// Connection states stored in SRTManager.state. A successful handshake moves
// the tunnel INDUCTION -> CONCLUSION -> DATA_LOOP; BROKEN marks a tunnel that
// was rejected or shut down, after which Process only returns errors.
const (
	INDUCTION  uint8 = 0 // waiting for the caller's induction handshake
	CONCLUSION uint8 = 1 // induction answered, waiting for the conclusion handshake
	DATA_LOOP  uint8 = 2 // handshake done: data, ACK/ACKACK, NAK and shutdown handling
	BROKEN     uint8 = 3 // rejected or shut down
)
// SRTManager holds the per-connection state of one SRT tunnel as seen by the
// listener: the handshake state machine, the peer's socket ID, the receive
// buffer, ACK/RTT bookkeeping and the writer the received stream goes to.
type SRTManager struct {
	crypt          *CryptHandler  // set when a key-material extension is accepted during conclusion; nil otherwise
	state          uint8          // INDUCTION, CONCLUSION, DATA_LOOP or BROKEN (handle_data_storage may also set 4)
	init           time.Time      // creation time; outgoing packet timestamps are microseconds since this
	syn_cookie     uint32         // cookie sent in the induction response and compared on conclusion
	socket         net.PacketConn // UDP socket; its local IP is advertised in handshake responses
	ctrl_sock_peer uint32         // peer's SRT socket ID, used as the destination of outgoing packets
	storage        *DatumStorage  // receive buffer, created from the first data packet
	ack_idx        uint32         // number of the most recently sent ACK
	pings          [][2]time.Time // [ACK sent, ACKACK received] pairs, at most 100; index 0 starts as the handshake pair
	last_nack      time.Time      // time of the last NAK attempt (set even when no NAK was produced)
	ping_offset    int            // entries evicted from the front of pings; ACK number - ping_offset = index
	pkt_sizes      []uint32       // payload sizes of data packets received since the last ACK
	bw             uint32         // advertised bandwidth, packets (of mtu bytes) per second
	mtu            uint32         // advertised MTU in bytes
	output         io.WriteCloser // where the reassembled stream is written; chosen via the stream ID extension
	stream_key     string         // stream key returned by CheckStreamID
}
// NewSRTManager returns a manager for one tunnel on the packet connection l.
// It starts in the INDUCTION state (the zero value), stamps the creation time
// used for packet timestamps, and defaults the advertised bandwidth to 15000
// packets per second and the MTU to 1500 bytes.
func NewSRTManager(l net.PacketConn) *SRTManager {
	return &SRTManager{
		init:   time.Now(),
		socket: l,
		bw:     15000, // in pkts (mtu bytes) per second
		mtu:    1500,
	}
}
// create_basic_header allocates a packet pre-filled with the fields every
// outgoing packet carries: the timestamp (microseconds elapsed since the
// manager was created, truncated to 32 bits) and the peer's SRT socket ID as
// the destination.
func (agent *SRTManager) create_basic_header() *Packet {
	elapsed := time.Since(agent.init)
	return &Packet{
		timestamp: uint32(elapsed.Microseconds()),
		dest_sock: agent.ctrl_sock_peer,
	}
}
func (agent *SRTManager) create_induction_resp() (*Packet) {
packet := agent.create_basic_header()
packet.packet_type = HANDSHAKE
info := new(ControlHeader)
packet.header_info = info
cif := new(HandshakeCIF)
cif.version = 5
cif.ext_field = 0x4a17
cif.hs_type = 1
cif.syn_cookie = agent.syn_cookie
cif.sock_id = 1
cif.mtu = agent.mtu
cif.max_flow = 8192
ip := agent.socket.LocalAddr().(*net.UDPAddr).IP
for i := 0; i < len(ip); i++ {
cif.peer_ip[i] = ip[i]
}
packet.cif = cif
// use the handshake as a placeholder ack-ackack rtt initializer
var init_ping_time [2]time.Time
init_ping_time[0] = time.Now()
agent.pings = append(agent.pings, init_ping_time)
return packet
}
// make_syn_cookie derives the SYN cookie for peer and stores it in
// agent.syn_cookie. The scheme is ad hoc, but works.
//
// Layout of the 32-bit cookie:
//   - bits 27-31: low 5 bits of the current 64-second window (Unix seconds >> 6)
//   - bits 24-26: zero
//   - bits 0-23:  last three bytes of SHA-256(peer.String() + decimal window)
//
// The cookie depends only on the peer address and the time window. It
// deliberately does not mix in the previous agent.syn_cookie, so repeated
// calls for the same peer within one window produce the same value.
func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
	window := uint32(time.Now().Unix()) >> 6
	digest := sha256.Sum256([]byte(peer.String() + fmt.Sprintf("%d", window)))
	agent.syn_cookie = (window%32)<<27 |
		uint32(digest[29])<<16 |
		uint32(digest[30])<<8 |
		uint32(digest[31])
}
// process_induction handles a packet received in the INDUCTION state.
//
// For an induction handshake (hs_type 1) it records the caller's socket ID as
// the destination for future packets, advances the state to CONCLUSION and
// returns the induction response. Any other input leaves the state untouched
// and returns an error: a non-handshake packet, a handshake whose CIF is
// missing or is not a *HandshakeCIF, or a handshake of another type.
func (agent *SRTManager) process_induction(packet *Packet) (*Packet, error) {
	if packet.packet_type != HANDSHAKE {
		return nil, errors.New("Packet was not handshake")
	}
	// The packet came off the network: use comma-ok so a malformed CIF is
	// rejected instead of panicking.
	hs_cif, ok := packet.cif.(*HandshakeCIF)
	if !ok || hs_cif == nil || hs_cif.hs_type != 1 {
		return nil, errors.New("Packet was not handshake")
	}
	agent.state = CONCLUSION
	agent.ctrl_sock_peer = hs_cif.sock_id
	return agent.create_induction_resp(), nil
}
// create_conclusion_resp builds the skeleton of the conclusion reply. It has
// handshake version 5, ext_field 0x1, socket ID 1, agent.mtu, a flow window
// of 8192 and the socket's local IP. It carries a single extension of type 2
// (presumably HSRSP): 12 in ext_len, an HSEMSG with version 0x00010000,
// recv/send delays of 120 (presumably ms of latency) and the flag bits
// 0x01|0x02|0x04|0x08|0x20 (SRT capability flags; confirm against the spec).
//
// hs_type is left at its zero value; process_conclusion overwrites it when
// rejecting.
func (agent *SRTManager) create_conclusion_resp() *Packet {
	packet := agent.create_basic_header()
	packet.packet_type = HANDSHAKE
	packet.header_info = new(ControlHeader)

	cif := &HandshakeCIF{
		version:   5,
		ext_field: 0x1, // 1 for HS-ext, does not allow encryption currently
		sock_id:   1,
		mtu:       agent.mtu,
		max_flow:  8192,
	}
	// peer_ip carries this side's address, i.e. the socket's local IP.
	for i, b := range agent.socket.LocalAddr().(*net.UDPAddr).IP {
		cif.peer_ip[i] = b
	}

	cif.hs_extensions = append(cif.hs_extensions, &HandshakeExtension{
		ext_type: 2,
		ext_len:  12,
		ext_contents: &HSEMSG{
			flags:      uint32(0x01 | 0x02 | 0x04 | 0x08 | 0x20),
			version:    uint32(0x00010000),
			recv_delay: 120,
			send_delay: 120,
		},
	})
	packet.cif = cif
	return packet
}
// process_conclusion handles a packet received in the CONCLUSION state and
// always returns a conclusion response to send back.
//
// The request is considered only if it is a handshake with hs_type 0xffffffff
// (the conclusion type) whose syn_cookie equals agent.syn_cookie. Its
// extensions are then walked in order:
//   - type 5 (stream ID): passed to CheckStreamID. On failure the response is
//     rejected with hs_type 1003 and the state set to 3 (BROKEN). On success
//     the returned writer becomes agent.output and CleanFiles(stream_key, 0)
//     is called.
//   - type 3 (key material, decoded as *KMMSG): a CryptHandler is built with a
//     hard-coded passphrase. On failure the response is rejected with hs_type
//     1010 plus a type-4 extension whose last byte is 4 (BADSECRET), and the
//     state is set to 3 (BROKEN). On success the extension is echoed back as
//     type 4 and agent.crypt is set.
//
// After the walk, the placeholder RTT sample (agent.pings[0]) is completed. If
// agent.output is set, the state becomes DATA_LOOP and the response is
// returned as-is. Every other path returns the response with hs_type 1000 and
// leaves the state unchanged.
//
// NOTE(review): on acceptance hs_type stays at create_conclusion_resp's zero
// value; SRT listeners normally answer with the conclusion type (0xffffffff).
// Confirm the serializer/peer tolerates this.
// NOTE(review): if a type-3 extension fails after a type-5 one opened
// agent.output, the writer is not closed here. Confirm the owner closes it.
func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
	resp_packet := agent.create_conclusion_resp()
	if packet.packet_type == HANDSHAKE {
		// NOTE(review): unchecked assertion; panics if a HANDSHAKE packet's cif is not *HandshakeCIF.
		hs_cif := packet.cif.(*HandshakeCIF)
		if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie {
			for _, v := range hs_cif.hs_extensions {
				// force client to add a stream_id for output location
				// to do: add encryption handling
				switch v.ext_type {
				case 5:
					// Stream ID: selects the output writer and stream key.
					// NOTE(review): unchecked assertion on ext_contents.
					writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte))
					agent.stream_key = stream_key
					if !ok {
						// Reject and mark the tunnel broken (3 == BROKEN).
						resp_packet.cif.(*HandshakeCIF).hs_type = 1003
						agent.state = 3
						return resp_packet
					} else {
						agent.output = writer
						CleanFiles(agent.stream_key, 0)
					}
				case 3:
					resp_packet.cif.(*HandshakeCIF).ext_field = 3
					// passphrase hardcoded for testing, should pass in somehow with a user management system
					crypt_handler := NewCryptHandler("srttestpass", v.ext_contents.(*KMMSG))
					if crypt_handler == nil { // if sek unwrap required but fails
						// Reject with 1010 and attach a 4-byte type-4 extension carrying state 4.
						agent.state = 3
						resp_packet.cif.(*HandshakeCIF).hs_type = 1010
						resp_ext := new(HandshakeExtension)
						resp_ext.ext_type = 4
						resp_ext.ext_len = 4
						km_state := make([]byte, 4)
						km_state[3] = 4 // BADSECRET code
						resp_ext.ext_contents = km_state
						resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, resp_ext)
						return resp_packet
					}
					// Success: echo the key material back. v is a
					// *HandshakeExtension, so retyping it to 4 after the append
					// also changes the entry just added to the response.
					resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, v)
					v.ext_type = 4
					agent.crypt = crypt_handler
				}
			}
			// Complete the placeholder RTT sample opened by the induction response.
			agent.pings[0][1] = time.Now()
			// if output was successfully initialized, proceed with data looping
			if agent.output != nil {
				agent.state = DATA_LOOP
				return resp_packet
			}
		}
	}
	// Not a valid conclusion (or no stream ID was supplied): generic rejection.
	resp_packet.cif.(*HandshakeCIF).hs_type = 1000
	return resp_packet
}
// create_ack_report builds a full ACK control packet (control type 2) numbered
// with the next agent.ack_idx.
//
// Contents:
//   - last_acked: the end of the storage's main (unbroken) chain.
//   - bw: the advertised bandwidth.
//   - rtt: the most recent completed ACK->ACKACK round trip, in microseconds.
//   - var_rtt: the standard deviation of all completed samples in
//     agent.pings, in microseconds (the same unit as rtt).
//   - pkt_recv_rate / rcv_rate: packets and bytes per second received since
//     the previous ACK was sent.
//
// If no completed sample exists (for example, the handshake sample was evicted
// and no ACKACK has arrived since), the SRT initial estimates of 100 ms and
// 50 ms are reported instead.
//
// Side effects: increments agent.ack_idx and opens a new ping pair for this
// ACK. At most 100 pairs are kept; agent.ping_offset advances on eviction so
// ACKACK numbers still map to the right index. agent.pkt_sizes is reset.
func (agent *SRTManager) create_ack_report() *Packet {
	packet := agent.create_basic_header()
	packet.packet_type = ACK
	info := new(ControlHeader)
	info.ctrl_type = 2
	agent.ack_idx++
	info.tsi = agent.ack_idx
	packet.header_info = info

	// Saturating float64 -> uint32 conversion; out-of-range float conversions
	// are implementation-defined in Go.
	to_u32 := func(x float64) uint32 {
		switch {
		case math.IsNaN(x) || x <= 0:
			return 0
		case x >= math.MaxUint32:
			return math.MaxUint32
		}
		return uint32(x)
	}

	cif := new(ACKCIF)
	// main has the latest unbroken chain: either there are no other packets,
	// or the next one is missing and must be NAK'd.
	cif.last_acked = agent.storage.main.end.seq_num
	cif.bw = agent.bw

	// Accumulate in float64: a single squared RTT above ~65.5 ms already
	// overflows uint32.
	var samples int
	var last, sum, sum_sq float64
	for _, pair := range agent.pings {
		if pair[0].IsZero() || pair[1].IsZero() {
			continue
		}
		rtt := float64(pair[1].Sub(pair[0]).Microseconds())
		samples++
		last = rtt
		sum += rtt
		sum_sq += rtt * rtt
	}
	if samples == 0 {
		// Nothing to average; avoid dividing by zero.
		cif.rtt = 100000
		cif.var_rtt = 50000
	} else {
		mean := sum / float64(samples)
		variance := sum_sq/float64(samples) - mean*mean
		// The last completed exchange is what the RTT currently is.
		cif.rtt = to_u32(last)
		// Clamp tiny negative values from float rounding before the sqrt.
		cif.var_rtt = to_u32(math.Sqrt(math.Max(variance, 0)))
	}

	// Rates over the interval since the previous ACK (the last ping entry).
	// ACKs only go out once >= 10 ms have passed and a data packet arrives,
	// so the interval can be longer than 10 ms. Never use less than 10 ms,
	// which also bounds the scale factor.
	interval := 10 * time.Millisecond
	if n := len(agent.pings); n > 0 && !agent.pings[n-1][0].IsZero() {
		if since := time.Since(agent.pings[n-1][0]); since > interval {
			interval = since
		}
	}
	per_second := float64(time.Second) / float64(interval)
	var bytes_recvd float64
	for _, size := range agent.pkt_sizes {
		bytes_recvd += float64(size)
	}
	cif.pkt_recv_rate = to_u32(float64(len(agent.pkt_sizes)) * per_second)
	cif.rcv_rate = to_u32(bytes_recvd * per_second)
	// arbitrary, should use len(channel) to set this but doesn't really seem to matter
	cif.buff_size = 100
	packet.cif = cif

	// Open this ACK's RTT sample; keep only the last 100.
	next := [2]time.Time{time.Now()}
	if len(agent.pings) >= 100 {
		agent.pings = append(agent.pings[1:], next)
		agent.ping_offset++
	} else {
		agent.pings = append(agent.pings, next)
	}
	agent.pkt_sizes = agent.pkt_sizes[:0]
	return packet
}
// handle_ackack completes the RTT sample opened by the ACK that this ACKACK
// answers. Only the receive time is used; the rest of the packet is ignored.
//
// The ACKACK's tsi echoes the ACK number. It maps to an index in agent.pings
// after subtracting agent.ping_offset (the number of entries evicted from the
// front). The packet comes from the network, so several cases are dropped
// instead of panicking or corrupting a finished sample:
//   - a missing or unexpected control header;
//   - an ACK number outside the current window (stale or never issued);
//   - a sample that already has a receive time (duplicate ACKACK).
func (agent *SRTManager) handle_ackack(packet *Packet) {
	info, ok := packet.header_info.(*ControlHeader)
	if !ok || info == nil {
		return
	}
	idx := int(info.tsi) - agent.ping_offset
	if idx < 0 || idx >= len(agent.pings) {
		return
	}
	if !agent.pings[idx][1].IsZero() {
		return
	}
	agent.pings[idx][1] = time.Now()
}
// create_nack_report returns a NAK control packet (control type 3) whose CIF
// is whatever agent.storage.GenNACKCIF produces (presumably the missing
// sequence ranges). It returns nil when GenNACKCIF reports nothing to send.
//
// agent.last_nack is updated on every call, including calls that produce no
// packet, so process_data's 10 ms NAK throttle restarts either way.
func (agent *SRTManager) create_nack_report() *Packet {
	agent.last_nack = time.Now()
	loss_cif, have_losses := agent.storage.GenNACKCIF()
	if !have_losses {
		return nil
	}
	nak := agent.create_basic_header()
	nak.packet_type = NAK
	nak.header_info = &ControlHeader{ctrl_type: 3}
	nak.cif = loss_cif
	return nak
}
// process_data handles one packet while the tunnel is in DATA_LOOP and returns
// the control packet to send in reply, or nil.
//
//   - DATA: decrypt (when a crypt handler was negotiated), then store and
//     flush. Reply with an ACK if >= 10 ms have passed since the last ping
//     entry was opened (the last ACK, or the handshake). Otherwise reply with
//     a NAK attempt if none was made yet or >= 10 ms have passed since the
//     last one.
//   - ACKACK: record the RTT sample; no reply.
//   - SHUTDOWN: mark the tunnel BROKEN and start file cleanup in a background
//     goroutine; no reply.
//   - anything else is ignored.
func (agent *SRTManager) process_data(packet *Packet) *Packet {
	kind := packet.packet_type
	if kind == ACKACK {
		agent.handle_ackack(packet)
		return nil
	}
	if kind == SHUTDOWN {
		// BROKEN should eventually raise an error and tear the tunnel down;
		// for now start the cleanup procedure in 10s.
		agent.state = BROKEN
		go CleanFiles(agent.stream_key, 10)
		return nil
	}
	if kind != DATA {
		return nil
	}

	if agent.crypt != nil {
		agent.crypt.Decrypt(packet)
	}
	agent.handle_data_storage(packet)

	last_ack_sent := agent.pings[len(agent.pings)-1][0]
	if time.Since(last_ack_sent) >= 10*time.Millisecond {
		return agent.create_ack_report()
	}
	nack_due := agent.last_nack.IsZero() || time.Since(agent.last_nack) >= 10*time.Millisecond
	if nack_due {
		return agent.create_nack_report()
	}
	return nil
}
// handle_data_storage processes the payload of one data packet:
//  1. Records its size for the next ACK's rate estimate.
//  2. Inserts it into agent.storage, creating the storage from the first
//     packet.
//  3. Relinks pending offshoots when there are any; the packet timestamp is
//     passed for TLPKTDROP.
//  4. Writes every releasable packet to agent.output.
//
// A write error sets agent.state to 4.
//
// NOTE(review): 4 is not one of the declared states, so Process will return
// "State not implemented" afterwards. BROKEN may have been intended; confirm.
func (agent *SRTManager) handle_data_storage(packet *Packet) {
	// data packets always carry a []byte payload as their "cif"
	payload := packet.cif.([]byte)
	agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(payload)))

	if agent.storage != nil {
		agent.storage.NewDatum(packet)
	} else {
		agent.storage = NewDatumStorage(packet)
	}

	if len(agent.storage.offshoots) > 0 {
		agent.storage.Relink(packet.timestamp)
	}

	if err := agent.storage.Expunge(agent.output); err != nil {
		agent.state = 4
	}
}
// Process routes one received packet to the handler for the tunnel's current
// state and returns the reply to send, which may be nil. Packets that do not
// fit the current phase are dealt with by that phase's handler; for example, a
// late induction request during conclusion gets a rejection reply.
//
// Errors: INDUCTION returns one when the packet is not an induction
// handshake; BROKEN always returns "Tunnel shutdown"; an unrecognised state
// returns "State not implemented". CONCLUSION and DATA_LOOP never return an
// error, and a rejected conclusion is signalled inside the reply packet.
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
	current := agent.state
	if current == INDUCTION {
		return agent.process_induction(packet)
	}
	if current == CONCLUSION {
		return agent.process_conclusion(packet), nil
	}
	if current == DATA_LOOP {
		return agent.process_data(packet), nil
	}
	if current == BROKEN {
		return nil, errors.New("Tunnel shutdown")
	}
	return nil, errors.New("State not implemented")
}