348 lines
9.3 KiB
Go
348 lines
9.3 KiB
Go
package srt
|
|
|
|
import (
|
|
"time"
|
|
"math"
|
|
"net"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"errors"
|
|
"io"
|
|
)
|
|
|
|
// Connection states of an SRTManager, in the order a session normally moves
// through them. The numeric values are spelled out because other code in this
// file assigns raw numbers to SRTManager.state.
const (
	// INDUCTION is the initial state: waiting for the first handshake.
	INDUCTION uint8 = 0
	// CONCLUSION is entered once the induction handshake has been answered.
	CONCLUSION uint8 = 1
	// DATA_LOOP is entered once the conclusion handshake succeeded and an
	// output has been set up.
	DATA_LOOP uint8 = 2
	// BROKEN marks a tunnel that was rejected or shut down.
	BROKEN uint8 = 3
)
|
|
|
|
type SRTManager struct {
|
|
crypt *CryptHandler
|
|
state uint8
|
|
init time.Time
|
|
syn_cookie uint32
|
|
socket net.PacketConn
|
|
ctrl_sock_peer uint32
|
|
storage *DatumStorage
|
|
ack_idx uint32
|
|
pings [][2]time.Time
|
|
last_nack time.Time
|
|
ping_offset int
|
|
pkt_sizes []uint32
|
|
bw uint32
|
|
mtu uint32
|
|
output io.WriteCloser
|
|
stream_key string
|
|
}
|
|
|
|
func NewSRTManager(l net.PacketConn) (*SRTManager) {
|
|
agent := new(SRTManager)
|
|
agent.init = time.Now()
|
|
agent.socket = l
|
|
agent.bw = 15000 // in pkts (mtu bytes) per second
|
|
agent.mtu = 1500
|
|
return agent
|
|
}
|
|
|
|
// adds basic information present in all packets, timestamp and destination SRT socket
|
|
func (agent *SRTManager) create_basic_header() (*Packet) {
|
|
packet := new(Packet)
|
|
packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds())
|
|
packet.dest_sock = agent.ctrl_sock_peer
|
|
return packet
|
|
}
|
|
|
|
func (agent *SRTManager) create_induction_resp() (*Packet) {
|
|
packet := agent.create_basic_header()
|
|
packet.packet_type = HANDSHAKE
|
|
|
|
info := new(ControlHeader)
|
|
packet.header_info = info
|
|
|
|
cif := new(HandshakeCIF)
|
|
cif.version = 5
|
|
cif.ext_field = 0x4a17
|
|
cif.hs_type = 1
|
|
cif.syn_cookie = agent.syn_cookie
|
|
cif.sock_id = 1
|
|
cif.mtu = agent.mtu
|
|
cif.max_flow = 8192
|
|
|
|
ip := agent.socket.LocalAddr().(*net.UDPAddr).IP
|
|
for i := 0; i < len(ip); i++ {
|
|
cif.peer_ip[i] = ip[i]
|
|
}
|
|
|
|
packet.cif = cif
|
|
|
|
// use the handshake as a placeholder ack-ackack rtt initializer
|
|
var init_ping_time [2]time.Time
|
|
init_ping_time[0] = time.Now()
|
|
agent.pings = append(agent.pings, init_ping_time)
|
|
|
|
return packet
|
|
}
|
|
|
|
// not ideal, but works
|
|
func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
|
|
t := uint32(time.Now().Unix()) >> 6
|
|
s := sha256.New()
|
|
s.Write([]byte(peer.String() + fmt.Sprintf("%d", t)))
|
|
agent.syn_cookie = (agent.syn_cookie + t % 32) << 3
|
|
for _, v := range s.Sum(nil)[29:] {
|
|
agent.syn_cookie = (agent.syn_cookie << 8) + uint32(v)
|
|
}
|
|
}
|
|
|
|
func (agent *SRTManager) process_induction(packet *Packet) (*Packet, error) {
|
|
if packet.packet_type == HANDSHAKE {
|
|
hs_cif := packet.cif.(*HandshakeCIF)
|
|
if hs_cif.hs_type == 1 {
|
|
agent.state = CONCLUSION
|
|
agent.ctrl_sock_peer = hs_cif.sock_id
|
|
return agent.create_induction_resp(), nil
|
|
}
|
|
}
|
|
return nil, errors.New("Packet was not handshake")
|
|
}
|
|
|
|
func (agent *SRTManager) create_conclusion_resp() (*Packet) {
|
|
packet := agent.create_basic_header()
|
|
packet.packet_type = HANDSHAKE
|
|
|
|
info := new(ControlHeader)
|
|
packet.header_info = info
|
|
|
|
cif := new(HandshakeCIF)
|
|
cif.version = 5
|
|
cif.ext_field = 0x1 // 1 for HS-ext, does not allow encryption currently
|
|
cif.sock_id = 1
|
|
cif.mtu = agent.mtu
|
|
cif.max_flow = 8192
|
|
|
|
ip := agent.socket.LocalAddr().(*net.UDPAddr).IP
|
|
for i := 0; i < len(ip); i++ {
|
|
cif.peer_ip[i] = ip[i]
|
|
}
|
|
|
|
hs_ext := new(HandshakeExtension)
|
|
hs_ext.ext_type = 2
|
|
hs_ext.ext_len = 12
|
|
hs_msg := new(HSEMSG)
|
|
hs_msg.flags = uint32(0x01 | 0x02 | 0x04 | 0x08 | 0x20)
|
|
hs_msg.version = uint32(0x00010000)
|
|
hs_msg.recv_delay = 120
|
|
hs_msg.send_delay = 120
|
|
hs_ext.ext_contents = hs_msg
|
|
|
|
cif.hs_extensions = append(cif.hs_extensions, hs_ext)
|
|
|
|
packet.cif = cif
|
|
|
|
return packet
|
|
}
|
|
|
|
func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
|
|
resp_packet := agent.create_conclusion_resp()
|
|
if packet.packet_type == HANDSHAKE {
|
|
hs_cif := packet.cif.(*HandshakeCIF)
|
|
if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie {
|
|
for _, v := range hs_cif.hs_extensions {
|
|
// force client to add a stream_id for output location
|
|
// to do: add encryption handling
|
|
switch v.ext_type {
|
|
case 5:
|
|
writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte))
|
|
agent.stream_key = stream_key
|
|
if !ok {
|
|
resp_packet.cif.(*HandshakeCIF).hs_type = 1003
|
|
agent.state = 3
|
|
return resp_packet
|
|
} else {
|
|
agent.output = writer
|
|
CleanFiles(agent.stream_key, 0)
|
|
}
|
|
case 3:
|
|
resp_packet.cif.(*HandshakeCIF).ext_field = 3
|
|
// passphrase harcoded for testing, should pass in somehow with a user management system
|
|
crypt_handler := NewCryptHandler("srttestpass", v.ext_contents.(*KMMSG))
|
|
if crypt_handler == nil { // if sek unwrap required but fails
|
|
agent.state = 3
|
|
resp_packet.cif.(*HandshakeCIF).hs_type = 1010
|
|
resp_ext := new(HandshakeExtension)
|
|
resp_ext.ext_type = 4
|
|
resp_ext.ext_len = 4
|
|
km_state := make([]byte, 4)
|
|
km_state[3] = 4 // BADSECRET code
|
|
resp_ext.ext_contents = km_state
|
|
resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, resp_ext)
|
|
return resp_packet
|
|
}
|
|
// else return since needed
|
|
resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, v)
|
|
v.ext_type = 4
|
|
agent.crypt = crypt_handler
|
|
}
|
|
}
|
|
agent.pings[0][1] = time.Now()
|
|
// if output was successfully initialized, proceed with data looping
|
|
if agent.output != nil {
|
|
agent.state = DATA_LOOP
|
|
return resp_packet
|
|
}
|
|
}
|
|
}
|
|
resp_packet.cif.(*HandshakeCIF).hs_type = 1000
|
|
return resp_packet
|
|
}
|
|
|
|
func (agent *SRTManager) create_ack_report() (*Packet) {
|
|
packet := agent.create_basic_header()
|
|
packet.packet_type = ACK
|
|
|
|
info := new(ControlHeader)
|
|
info.ctrl_type = 2
|
|
agent.ack_idx++
|
|
info.tsi = agent.ack_idx
|
|
packet.header_info = info
|
|
|
|
cif := new(ACKCIF)
|
|
// main has the latest unbroken chain, either no other packets, or
|
|
// missing packet which must be nak'd
|
|
cif.last_acked = agent.storage.main.end.seq_num
|
|
cif.bw = agent.bw
|
|
|
|
// calculate rtt variance from valid ping pairs, use last value as rtt of last
|
|
// exchange since that's what it is
|
|
var rtt_sum uint32
|
|
var rtt_2_sum uint32
|
|
var rtt uint32
|
|
var rtt_n uint32
|
|
for _, v := range agent.pings {
|
|
if !v[0].IsZero() && !v[1].IsZero() {
|
|
rtt_n++
|
|
rtt = uint32(v[1].Sub(v[0]).Microseconds())
|
|
rtt_sum += rtt
|
|
rtt_2_sum += uint32(math.Pow(float64(rtt), 2))
|
|
}
|
|
}
|
|
cif.rtt = rtt
|
|
cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2))
|
|
|
|
// use the packets received since the last ack report was sent to calc
|
|
// estimated recv rates
|
|
cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100)
|
|
// arbitrary, should use len(channel) to set this but doesn't really seem to matter
|
|
cif.buff_size = 100
|
|
var bytes_recvd uint32
|
|
for _, v := range agent.pkt_sizes {
|
|
bytes_recvd += v
|
|
}
|
|
cif.rcv_rate = bytes_recvd * 100
|
|
|
|
packet.cif = cif
|
|
|
|
var next_ping_pair [2]time.Time
|
|
next_ping_pair[0] = time.Now()
|
|
// only keep last 100 acks, use offset for correct ackack ping indexing
|
|
if len(agent.pings) >= 100 {
|
|
agent.pings = append(agent.pings[1:], next_ping_pair)
|
|
agent.ping_offset++
|
|
} else {
|
|
agent.pings = append(agent.pings[:], next_ping_pair)
|
|
}
|
|
agent.pkt_sizes = make([]uint32, 0)
|
|
return packet
|
|
}
|
|
|
|
// only need the recieve time from ackacks for rtt calcs, ignore otherwise
|
|
func (agent *SRTManager) handle_ackack(packet *Packet) {
|
|
ack_num := packet.header_info.(*ControlHeader).tsi
|
|
agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now()
|
|
}
|
|
|
|
func (agent *SRTManager) create_nack_report() (*Packet) {
|
|
agent.last_nack = time.Now()
|
|
cif, ok := agent.storage.GenNACKCIF()
|
|
if !ok {
|
|
return nil
|
|
}
|
|
packet := agent.create_basic_header()
|
|
packet.packet_type = NAK
|
|
|
|
info := new(ControlHeader)
|
|
info.ctrl_type = 3
|
|
packet.header_info = info
|
|
|
|
packet.cif = cif
|
|
|
|
return packet
|
|
}
|
|
|
|
// handling packets during data loop
|
|
func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
|
|
switch packet.packet_type {
|
|
case DATA:
|
|
// if data, add to storage, linking, etc
|
|
// then check if ack or nack can be generated (every 10 ms)
|
|
if agent.crypt != nil {
|
|
agent.crypt.Decrypt(packet)
|
|
}
|
|
agent.handle_data_storage(packet)
|
|
if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 {
|
|
return agent.create_ack_report()
|
|
}
|
|
if agent.last_nack.IsZero() || time.Now().Sub(agent.last_nack).Milliseconds() >= 10 {
|
|
return agent.create_nack_report()
|
|
}
|
|
case ACKACK:
|
|
agent.handle_ackack(packet)
|
|
case SHUTDOWN:
|
|
// state 3 should raise error and shutdown tunnel,
|
|
// for now start cleanup procedure in 10s
|
|
agent.state = 3
|
|
go CleanFiles(agent.stream_key, 10)
|
|
default:
|
|
return nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (agent *SRTManager) handle_data_storage(packet *Packet) {
|
|
// data packets always have []byte as "cif"
|
|
agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte))))
|
|
// initialize storage if does not exist, else add where it can
|
|
if agent.storage == nil {
|
|
agent.storage = NewDatumStorage(packet)
|
|
} else {
|
|
agent.storage.NewDatum(packet)
|
|
}
|
|
// attempt to relink any offshoots
|
|
// timestamp for TLPKTDROP
|
|
if len(agent.storage.offshoots) != 0 {
|
|
agent.storage.Relink(packet.timestamp)
|
|
}
|
|
// write out all possible packets
|
|
if err := agent.storage.Expunge(agent.output); err != nil {
|
|
agent.state = 4
|
|
}
|
|
}
|
|
|
|
// determines appropriate packets and responses depending on tunnel state
|
|
// some need to ignore depending on state, eg
|
|
// late induction requests during conclusion phase
|
|
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
|
|
switch agent.state {
|
|
case INDUCTION:
|
|
return agent.process_induction(packet)
|
|
case CONCLUSION:
|
|
return agent.process_conclusion(packet), nil
|
|
case DATA_LOOP:
|
|
return agent.process_data(packet), nil
|
|
case BROKEN:
|
|
return nil, errors.New("Tunnel shutdown")
|
|
default:
|
|
return nil, errors.New("State not implemented")
|
|
}
|
|
}
|