stream-server/srt/protocol.go
2023-09-21 13:36:11 +05:00

261 lines
5.9 KiB
Go

package srt
import (
"time"
"math"
"net"
"crypto/sha256"
"fmt"
"errors"
"io"
)
// Connection states of an SRTManager. The zero value is INDUCTION, so a fresh
// manager starts there. process_induction moves it to CONCLUSION once the
// induction request has been answered. process_conclusion moves it to
// DATA_LOOP once the conclusion request has been accepted.
const (
	INDUCTION  uint8 = 0
	CONCLUSION uint8 = 1
	DATA_LOOP  uint8 = 2
)
// SRTManager holds the per-connection state of the responding side of an SRT
// session. That state covers:
//   - the handshake state machine;
//   - RTT samples taken from handshake and ACK/ACKACK exchanges;
//   - the counters used to fill ACK reports.
type SRTManager struct {
state uint8 // INDUCTION, CONCLUSION or DATA_LOOP
init time.Time // creation time; outgoing packet timestamps are microseconds since this instant
syn_cookie uint32 // cookie sent in the induction response and expected back in the conclusion request
socket net.PacketConn // socket whose local address is advertised in handshake responses
ctrl_sock_peer uint32 // peer's socket ID from its induction request; used as the destination socket ID
storage *DatumStorage // received DATA packets; created lazily on the first one
ack_idx uint32 // number of the last ACK sent (copied into the ACK's type-specific info)
pings [][2]time.Time // RTT samples: [0] = when our handshake response/ACK was built, [1] = when the matching conclusion/ACKACK arrived (zero until then)
ping_offset int // entries dropped from the front of pings; ACK number minus ping_offset is the pings index
pkt_sizes []uint32 // payload sizes of DATA packets received since the last ACK report
bw uint32 // value copied into the bandwidth field of ACK reports
mtu uint32 // MTU advertised in handshake responses
output io.WriteCloser // sink chosen by CheckStreamID from the stream ID handshake extension
}
// NewSRTManager returns a manager in the INDUCTION state that answers
// handshakes on behalf of l. Outgoing packet timestamps are measured from the
// moment of this call. The reported bandwidth starts at 15000 and the
// advertised MTU at 1500.
func NewSRTManager(l net.PacketConn) *SRTManager {
	return &SRTManager{
		init:   time.Now(),
		socket: l,
		bw:     15000,
		mtu:    1500,
	}
}
// create_basic_header returns a packet carrying only the fields shared by every
// outgoing packet:
//   - the timestamp: microseconds since the manager was created, truncated to 32 bits;
//   - the destination socket ID: the peer's socket ID.
func (agent *SRTManager) create_basic_header() *Packet {
	elapsed := time.Since(agent.init)
	return &Packet{
		timestamp: uint32(elapsed.Microseconds()),
		dest_sock: agent.ctrl_sock_peer,
	}
}
// create_induction_resp builds the reply to an induction request. The reply is
// a HANDSHAKE control packet (version 5, hs_type 1) carrying:
//   - the current SYN cookie;
//   - our MTU and a flow window of 8192;
//   - the local IP address of agent.socket.
//
// Side effect: it starts a new RTT sample in agent.pings (send time only);
// process_conclusion completes it when the conclusion request arrives.
func (agent *SRTManager) create_induction_resp() *Packet {
	packet := agent.create_basic_header()
	packet.packet_type = HANDSHAKE
	packet.header_info = new(ControlHeader)

	cif := new(HandshakeCIF)
	cif.version = 5
	// 0x4a17 is the SRT "magic code" a listener places in the extension field
	// of its induction response (per the SRT handshake specification).
	cif.ext_field = 0x4a17
	cif.hs_type = 1
	cif.syn_cookie = agent.syn_cookie
	cif.sock_id = 1
	cif.mtu = agent.mtu
	cif.max_flow = 8192
	// Advertise our local address. The comma-ok assertion avoids a panic when
	// the socket is not UDP; copy() never writes past the end of peer_ip.
	// If the address is unavailable the field stays zero.
	if addr, ok := agent.socket.LocalAddr().(*net.UDPAddr); ok {
		copy(cif.peer_ip[:], addr.IP)
	}
	packet.cif = cif

	// Start the handshake RTT sample; the receive time is filled in later.
	agent.pings = append(agent.pings, [2]time.Time{time.Now()})
	return packet
}
// make_syn_cookie derives a SYN cookie for peer and stores it in
// agent.syn_cookie. The cookie depends only on the peer address and the
// current 64-second time window:
//
//	bits 27..31  window number modulo 32
//	bits 24..26  always zero
//	bits  0..23  last three bytes of SHA-256(peer.String() + window)
func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
	window := uint32(time.Now().Unix()) >> 6
	digest := sha256.Sum256([]byte(peer.String() + fmt.Sprintf("%d", window)))

	// Start from the window bits alone. The previous implementation added the
	// old agent.syn_cookie into the new value, so the result drifted with the
	// history of earlier calls instead of being a function of (peer, window).
	// For the first call on a fresh manager the result is unchanged.
	cookie := (window % 32) << 3
	for _, b := range digest[29:] {
		cookie = cookie<<8 | uint32(b)
	}
	agent.syn_cookie = cookie
}
// process_induction handles a packet received in the INDUCTION state.
//
// For an induction handshake request (HANDSHAKE with hs_type 1) it:
//   - records the peer's socket ID;
//   - moves the state to CONCLUSION;
//   - returns the induction response to send.
//
// Any other packet is rejected with an error and the state is left unchanged.
// This includes a HANDSHAKE whose CIF is not a *HandshakeCIF: network input
// must not be able to panic the server.
func (agent *SRTManager) process_induction(packet *Packet) (*Packet, error) {
	if packet.packet_type != HANDSHAKE {
		return nil, errors.New("Packet was not handshake")
	}
	hsCIF, ok := packet.cif.(*HandshakeCIF)
	if !ok || hsCIF == nil {
		return nil, errors.New("Packet was not handshake")
	}
	if hsCIF.hs_type != 1 {
		// A valid handshake of the wrong type; report it accurately.
		return nil, errors.New("handshake was not an induction request")
	}
	agent.state = CONCLUSION
	agent.ctrl_sock_peer = hsCIF.sock_id
	return agent.create_induction_resp(), nil
}
// create_conclusion_resp builds the reply to a conclusion request. The reply is
// a HANDSHAKE packet (version 5, extension field 0x1) carrying:
//   - our MTU and a flow window of 8192;
//   - the local IP address of agent.socket;
//   - one handshake extension of type 2 describing the SRT options we agree to.
//
// hs_type is left at its zero value. process_conclusion overwrites it with a
// rejection code when the request is refused.
func (agent *SRTManager) create_conclusion_resp() *Packet {
	packet := agent.create_basic_header()
	packet.packet_type = HANDSHAKE
	packet.header_info = new(ControlHeader)

	cif := new(HandshakeCIF)
	cif.version = 5
	cif.ext_field = 0x1
	cif.sock_id = 1
	cif.mtu = agent.mtu
	cif.max_flow = 8192
	// Comma-ok plus copy(): no panic for a non-UDP socket or an address longer
	// than peer_ip. The field stays zero when no UDP address is available.
	if addr, ok := agent.socket.LocalAddr().(*net.UDPAddr); ok {
		copy(cif.peer_ip[:], addr.IP)
	}

	hsMsg := new(HSEMSG)
	// Flag bits as named by the SRT handshake extension:
	// TSBPDSND 0x01, TSBPDRCV 0x02, CRYPT 0x04, TLPKTDROP 0x08, REXMITFLG 0x20.
	hsMsg.flags = uint32(0x01 | 0x02 | 0x04 | 0x08 | 0x20)
	hsMsg.version = uint32(0x00010000) // SRT version 1.0.0 (major<<16 | minor<<8 | patch)
	// Latencies; presumably milliseconds as in the SRT spec. TODO confirm.
	hsMsg.recv_delay = 120
	hsMsg.send_delay = 120

	hsExt := new(HandshakeExtension)
	hsExt.ext_type = 2
	hsExt.ext_len = 12
	hsExt.ext_contents = hsMsg
	cif.hs_extensions = append(cif.hs_extensions, hsExt)

	packet.cif = cif
	return packet
}
// process_conclusion handles a packet received in the CONCLUSION state. It
// always returns a handshake response to send back.
//
// The request is accepted only if all of the following hold:
//   - it is a HANDSHAKE conclusion (hs_type 0xffffffff);
//   - it echoes agent.syn_cookie;
//   - it carries a stream ID extension (type 5) that CheckStreamID accepts.
//
// On acceptance:
//   - the writer returned by CheckStreamID becomes agent.output;
//   - the handshake RTT sample (pings[0]) is completed;
//   - the state moves to DATA_LOOP.
//
// On rejection the state is unchanged:
//   - a stream ID that is unacceptable or malformed yields hs_type 1003;
//   - every other failure yields hs_type 1000.
//
// Malformed network input is rejected instead of causing a panic.
//
// NOTE(review): on success the response keeps the zero hs_type set by
// create_conclusion_resp. SRT conclusion responses normally carry 0xffffffff;
// confirm the serializer handles this.
func (agent *SRTManager) process_conclusion(packet *Packet) *Packet {
	resp := agent.create_conclusion_resp()
	respCIF := resp.cif.(*HandshakeCIF) // always a *HandshakeCIF: built just above

	if packet.packet_type != HANDSHAKE {
		respCIF.hs_type = 1000
		return resp
	}
	req, ok := packet.cif.(*HandshakeCIF)
	if !ok || req == nil || req.hs_type != 0xffffffff || req.syn_cookie != agent.syn_cookie {
		respCIF.hs_type = 1000
		return resp
	}

	for _, ext := range req.hs_extensions {
		if ext.ext_type != 5 {
			continue
		}
		// The stream ID comes from the network. A payload of the wrong type is
		// treated like an unacceptable stream ID.
		streamID, isBytes := ext.ext_contents.([]byte)
		if !isBytes {
			respCIF.hs_type = 1003
			return resp
		}
		writer, accepted := CheckStreamID(streamID)
		if !accepted {
			respCIF.hs_type = 1003
			return resp
		}
		agent.output = writer
	}

	// Complete the RTT sample started by create_induction_resp.
	if len(agent.pings) > 0 {
		agent.pings[0][1] = time.Now()
	}
	if agent.output == nil {
		// No stream ID extension was supplied: nowhere to write the stream.
		respCIF.hs_type = 1000
		return resp
	}
	agent.state = DATA_LOOP
	return resp
}
// create_ack_report builds a full ACK control packet (ctrl_type 2) that
// summarizes traffic since the previous report.
//
// Reported values:
//   - last_acked: seq_num of the `end` datum of agent.storage.main.
//   - rtt: latest (by send time) completed RTT sample, in microseconds.
//   - var_rtt: standard deviation of all completed samples, in microseconds;
//     0 when there are none.
//   - pkt_recv_rate / rcv_rate: packets / bytes received since the last report,
//     multiplied by 100 (i.e. assuming reports are ~10 ms apart).
//
// Side effects:
//   - increments agent.ack_idx, which becomes the ACK number;
//   - appends a new RTT sample to agent.pings, keeping at most 100 entries
//     (dropped entries are counted in ping_offset);
//   - empties agent.pkt_sizes.
//
// agent.storage must be non-nil, i.e. at least one DATA packet has been stored.
func (agent *SRTManager) create_ack_report() *Packet {
	packet := agent.create_basic_header()
	packet.packet_type = ACK
	info := new(ControlHeader)
	info.ctrl_type = 2
	agent.ack_idx++
	info.tsi = agent.ack_idx
	packet.header_info = info

	cif := new(ACKCIF)
	cif.last_acked = agent.storage.main.end.seq_num
	cif.bw = agent.bw

	// RTT statistics over every completed sample. Sums are kept in float64.
	// The old uint32 sums overflowed as soon as one RTT exceeded ~65 ms,
	// because its square no longer fits in 32 bits.
	var last uint32
	var n, sum, sumSq float64
	for _, p := range agent.pings {
		if p[0].IsZero() || p[1].IsZero() {
			continue
		}
		last = uint32(p[1].Sub(p[0]).Microseconds())
		x := float64(last)
		n++
		sum += x
		sumSq += x * x
	}
	cif.rtt = last
	// With no completed sample (e.g. the peer never answers ACKs and the
	// handshake sample has rolled out of the window) the old code divided by
	// zero. var_rtt now stays 0 in that case.
	if n > 0 {
		mean := sum / n
		// Report the standard deviation so the value is in microseconds, like
		// the ACK's RTT variance field, and fits in 32 bits.
		if variance := sumSq/n - mean*mean; variance > 0 { // rounding can make it slightly negative
			cif.var_rtt = uint32(math.Sqrt(variance))
		}
	}

	// The x100 turns "per report" into "per second" under the ~10 ms reporting
	// interval used by process_data.
	cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100)
	cif.buff_size = 100
	var bytesRecvd uint32
	for _, size := range agent.pkt_sizes {
		bytesRecvd += size
	}
	cif.rcv_rate = bytesRecvd * 100
	packet.cif = cif

	// Start the RTT sample that the ACKACK for this ACK will complete.
	next := [2]time.Time{time.Now()}
	if len(agent.pings) >= 100 {
		agent.pings = append(agent.pings[1:], next)
		agent.ping_offset++
	} else {
		agent.pings = append(agent.pings, next)
	}
	// Reuse the buffer for the next reporting interval.
	agent.pkt_sizes = agent.pkt_sizes[:0]
	return packet
}
// handle_ackack completes the RTT sample for the ACK acknowledged by an ACKACK
// packet. The ACK number is the control header's tsi; subtracting
// agent.ping_offset gives its index in agent.pings.
//
// Some ACKACKs are ignored rather than causing an out-of-range panic:
//   - ACKACKs whose sample has already left the 100-entry window;
//   - ACKACKs that name a sample that does not exist;
//   - packets whose header is not a *ControlHeader.
//
// These packets come from the network and must not be able to crash the session.
func (agent *SRTManager) handle_ackack(packet *Packet) {
	hdr, ok := packet.header_info.(*ControlHeader)
	if !ok || hdr == nil {
		return
	}
	idx := int(hdr.tsi) - agent.ping_offset
	if idx < 0 || idx >= len(agent.pings) {
		return
	}
	agent.pings[idx][1] = time.Now()
}
// process_data handles a packet in the DATA_LOOP state. A nil result means
// nothing needs to be sent.
//   - DATA packets are stored. An ACK report is returned once at least 10 ms
//     have passed since the newest RTT sample was started (the previous ACK,
//     or the induction response).
//   - ACKACK packets complete an RTT sample.
//   - Every other packet type is ignored.
func (agent *SRTManager) process_data(packet *Packet) *Packet {
	if packet.packet_type == DATA {
		agent.handle_data_storage(packet)
		lastStart := agent.pings[len(agent.pings)-1][0]
		if time.Since(lastStart) >= 10*time.Millisecond {
			return agent.create_ack_report()
		}
		return nil
	}
	if packet.packet_type == ACKACK {
		agent.handle_ackack(packet)
	}
	return nil
}
// handle_data_storage stores one DATA packet and records its payload size for
// the next ACK report. The payload is expected to be a []byte CIF.
//   - The storage is created from the first packet; later packets are added
//     with NewDatum.
//   - If the storage then holds any offshoots (presumably out-of-order chains;
//     see DatumStorage), Relink is called.
func (agent *SRTManager) handle_data_storage(packet *Packet) {
	payload := packet.cif.([]byte)
	agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(payload)))

	if agent.storage != nil {
		agent.storage.NewDatum(packet)
	} else {
		agent.storage = NewDatumStorage(packet)
	}

	if len(agent.storage.offshoots) > 0 {
		agent.storage.Relink()
	}
}
// Process feeds one received packet through the connection state machine. It
// returns the packet to send in reply, or nil when no reply is needed.
// Errors are returned in two cases:
//   - in the INDUCTION state, when the packet is not an acceptable induction request;
//   - when the manager is in an unknown state.
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
	if agent.state == INDUCTION {
		return agent.process_induction(packet)
	}

	var reply *Packet
	switch agent.state {
	case CONCLUSION:
		reply = agent.process_conclusion(packet)
	case DATA_LOOP:
		reply = agent.process_data(packet)
	default:
		return nil, errors.New("State not implemented")
	}
	return reply, nil
}