This commit is contained in:
Muaz Ahmad 2023-09-22 14:58:33 +05:00
parent 25ce4427e6
commit c957c29188
7 changed files with 98 additions and 15 deletions

View file

@ -13,12 +13,20 @@ type Datum struct {
next *Datum
}
// linked chain of datums, specifically to store continuous segments
// when a packet is missed, start and end can be used to generate lost
// packet reports, and easily link when missing is received
// 1..2..3..4 6..7..8..9
// chain_1 chain_2
// nack: get 5
type DatumLink struct {
queued int
queued int // remove eventually, was to be used for ACK recv rate calcs, not needed
root *Datum
end *Datum
}
// data type and function to allow sorting so order can be ignored during
// linking since each is sequential on outset
type chains []*DatumLink
func (c chains) Len() (int) {
@ -29,6 +37,7 @@ func (c chains) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
}
// chain_1 is less then chain_2 when chain_1 ends before chain_2 starts
func (c chains) Less(i, j int) (bool) {
x_1 := c[i].end.seq_num
x_2 := c[j].root.seq_num
@ -43,6 +52,7 @@ type DatumStorage struct {
offshoots chains
}
// append new packet to end of buffer
func (buffer *DatumLink) NewDatum(pkt *Packet) {
datum := new(Datum)
datum.seq_num = pkt.header_info.(*DataHeader).seq_num
@ -53,6 +63,7 @@ func (buffer *DatumLink) NewDatum(pkt *Packet) {
buffer.end = datum
}
// create a new datumlink with root and end at the given packet
func NewDatumLink(pkt *Packet) (*DatumLink) {
buffer := new(DatumLink)
root_datum := new(Datum)
@ -66,12 +77,14 @@ func NewDatumLink(pkt *Packet) (*DatumLink) {
return buffer
}
// initialize storage with the given data packet, in the main chain
func NewDatumStorage(packet *Packet) (*DatumStorage) {
storage := new(DatumStorage)
storage.main = NewDatumLink(packet)
return storage
}
// purge all packets in the main chain except the last for future linkage
func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
curr_datum := storage.main.root
seq_num_end := storage.main.end.seq_num
@ -86,6 +99,7 @@ func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
return nil
}
// check if the given sequence number should already be inside the given buffer
func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) {
start := buffer.root.seq_num
end := buffer.end.seq_num
@ -98,6 +112,8 @@ func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) {
return true
}
// check if the given seq num lies before the given buffer starts
// buffer is After seq num?
func (buffer *DatumLink) After(new_seq_num uint32) (bool) {
start := buffer.root.seq_num
if serial_less(new_seq_num, start, 31) {
@ -106,6 +122,8 @@ func (buffer *DatumLink) After(new_seq_num uint32) (bool) {
return false
}
// check if the given seq num lies after the given buffer starts
// buffer is Before seq num?
func (buffer *DatumLink) Before(new_seq_num uint32) (bool) {
end := buffer.end.seq_num
if serial_less(end, new_seq_num, 31) {
@ -114,6 +132,7 @@ func (buffer *DatumLink) Before(new_seq_num uint32) (bool) {
return false
}
// add the given packet to the appropriate buffer or add a new offshoot buffer
func (storage *DatumStorage) NewDatum(pkt *Packet) {
new_pkt_num := pkt.header_info.(*DataHeader).seq_num
prev_num := (new_pkt_num - 1) % uint32(1 << 31)
@ -139,12 +158,16 @@ func (storage *DatumStorage) NewDatum(pkt *Packet) {
}
}
// Link 2 buffers to each other
func (buffer *DatumLink) Link(buffer_next *DatumLink) {
buffer.end.next = buffer_next.root
buffer.end = buffer_next.end
buffer.queued += buffer_next.queued
}
// check if the start of the buffer_next is sequentially next to the end of buffer
// or it is too late to get the true next packets to link and the packets must be linked
// given a 31-bit wrap around
func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_time uint32) (bool) {
seq_1 := buffer.end.seq_num
seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31))
@ -155,6 +178,7 @@ func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_ti
return false
}
// Given the current storage state, check what packets need to added to fully link main
func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
if len(storage.offshoots) == 0 {
return nil, false
@ -173,13 +197,14 @@ func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
return cif, true
}
// Try to relink all chains wherever possible or necessary due to TLPKTDROP
func (storage *DatumStorage) Relink(curr_time uint32) {
sort.Sort(storage.offshoots)
buffer := storage.main
i := 0
for i < len(storage.offshoots) {
if check_append_serial_next(buffer, storage.offshoots[i], curr_time) {
storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...)
storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) // nuke the chain since it is not needed anymore
} else {
buffer = storage.offshoots[i]
i++
@ -187,6 +212,7 @@ func (storage *DatumStorage) Relink(curr_time uint32) {
}
}
// check if a is less than b under serial arithmetic (modulo operations)
func serial_less(a uint32, b uint32, bits int) (bool) {
if (a < b && b-a < (1 << (bits - 1))) || (a > b && a-b > (1 << (bits - 1))) {
return true

View file

@ -15,7 +15,7 @@ func NewIntake(l net.PacketConn, max_conns int) (*Intake) {
intake := new(Intake)
intake.max_conns = max_conns
intake.tunnels = make([]*Tunnel, 0)
intake.buffer = make([]byte, 1500)
intake.buffer = make([]byte, 1500) // each packet is restricted to a max size of 1500
intake.socket = l
return intake
@ -26,9 +26,9 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
tunnel := new(Tunnel)
tunnel.socket = l
tunnel.peer = peer
tunnel.queue = make(chan []byte, 10)
tunnel.queue = make(chan []byte, 10) // packet buffer, will cause packet loss if low
intake.tunnels = append(intake.tunnels, tunnel)
go tunnel.Start()
go tunnel.Start() // start the tunnel SRT processing
return tunnel
}
return nil
@ -37,6 +37,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
var tunnel *Tunnel
for i := 0; i < len(intake.tunnels); i++ {
// check if tunnels are broken and remove
if intake.tunnels[i].broken {
intake.tunnels[i].Shutdown()
intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...)
@ -47,6 +48,9 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
tunnel = intake.tunnels[i]
}
}
// if no tunnel was found, make one
// should be after conclusion handshake, but wanted to keep all protocol
// related actions separate from UDP handling
if tunnel == nil {
tunnel = intake.NewTunnel(intake.socket, peer)
}
@ -58,14 +62,16 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
func (intake *Intake) Read() {
n, peer, err := intake.socket.ReadFrom(intake.buffer)
if err != nil {
return
return // ignore UDP errors
}
// find the SRT/UDT tunnel corresponding to the given peer
tunnel := intake.get_tunnel(peer)
if tunnel == nil {
return
}
pkt := make([]byte, n)
copy(pkt, intake.buffer[:n])
// send a copy to the corresponding tunnels packet queue if not full
select {
case tunnel.queue <- pkt:
default:

View file

@ -5,6 +5,7 @@ import (
"encoding/binary"
)
// arbitrary indexing
const (
DATA uint8 = iota
HANDSHAKE
@ -15,6 +16,7 @@ const (
DROP
)
// see SRT protocol RFC for information
type ControlHeader struct {
ctrl_type uint16
ctrl_subtype uint16
@ -80,23 +82,26 @@ type pckts_range struct {
end uint32
}
// should be safe to ignore
type DROPCIF struct {
to_drop pckts_range
}
// header and cif are interfaces to allow easier typing, fitting above structs
type Packet struct {
packet_type uint8
timestamp uint32
dest_sock uint32
header_info interface{}
header_info interface{}
cif interface{}
}
// completely pointless since only implementing receiver, here anyway
func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) {
info, ok_head := packet.header_info.(*DataHeader)
data, ok_data := packet.cif.([]byte)
if !ok_head || !ok_data {
return header, errors.New("data packet does not have data header")
return header, errors.New("data packet does not have data header or data")
}
binary.BigEndian.PutUint32(header[:4], info.seq_num)
head2 := (uint32(info.msg_flags) << 26) + info.msg_num
@ -162,6 +167,9 @@ func marshall_nack_cif(data *NACKCIF) ([]byte) {
return loss_list_bytes
}
// locations and length are determined by protocol,
// no real point abstracting that
// could be cleaner with reflects, but added work for very little real gain
func marshall_ack_cif(data *ACKCIF) ([]byte) {
cif := make([]byte, 28)
binary.BigEndian.PutUint32(cif[:4], data.last_acked)
@ -175,6 +183,7 @@ func marshall_ack_cif(data *ACKCIF) ([]byte) {
return cif
}
// same as above
func marshall_hs_cif(data *HandshakeCIF) ([]byte) {
cif := make([]byte, 48)
binary.BigEndian.PutUint32(cif[:4], data.version)
@ -231,7 +240,7 @@ func MarshallPacket(packet *Packet, agent *SRTManager) ([]byte, error) {
func parse_data_packet(pkt *Packet, buffer []byte) (error) {
info := new(DataHeader)
info.seq_num = binary.BigEndian.Uint32(buffer[:4])
info.msg_flags = buffer[4] >> 2
info.msg_flags = buffer[4] >> 2 // unsused since live streaming makes it irrelevant, kk for encrypt eventually
info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff
pkt.header_info = info

View file

@ -39,11 +39,12 @@ func NewSRTManager(l net.PacketConn) (*SRTManager) {
agent := new(SRTManager)
agent.init = time.Now()
agent.socket = l
agent.bw = 15000
agent.bw = 15000 // in pkts (mtu bytes) per second
agent.mtu = 1500
return agent
}
// adds basic information present in all packets, timestamp and destination SRT socket
func (agent *SRTManager) create_basic_header() (*Packet) {
packet := new(Packet)
packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds())
@ -74,6 +75,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) {
packet.cif = cif
// use the handshake as a placeholder ack-ackack rtt initializer
var init_ping_time [2]time.Time
init_ping_time[0] = time.Now()
agent.pings = append(agent.pings, init_ping_time)
@ -81,6 +83,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) {
return packet
}
// not ideal, but works
func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
t := uint32(time.Now().Unix()) >> 6
s := sha256.New()
@ -112,7 +115,7 @@ func (agent *SRTManager) create_conclusion_resp() (*Packet) {
cif := new(HandshakeCIF)
cif.version = 5
cif.ext_field = 0x1
cif.ext_field = 0x1 // 1 for HS-ext, does not allow encryption currently
cif.sock_id = 1
cif.mtu = agent.mtu
cif.max_flow = 8192
@ -145,6 +148,8 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
hs_cif := packet.cif.(*HandshakeCIF)
if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie {
for _, v := range hs_cif.hs_extensions {
// force client to add a stream_id for output location
// to do: add encryption handling
switch v.ext_type {
case 5:
writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte))
@ -159,6 +164,7 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
}
}
agent.pings[0][1] = time.Now()
// if output was successfully initialized, proceed with data looping
if agent.output != nil {
agent.state = DATA_LOOP
return resp_packet
@ -180,9 +186,13 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
packet.header_info = info
cif := new(ACKCIF)
// main has the latest unbroken chain, either no other packets, or
// missing packet which must be nak'd
cif.last_acked = agent.storage.main.end.seq_num
cif.bw = agent.bw
// calculate rtt variance from valid ping pairs, use last value as rtt of last
// exchange since that's what it is
var rtt_sum uint32
var rtt_2_sum uint32
var rtt uint32
@ -198,7 +208,10 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
cif.rtt = rtt
cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2))
// use the packets received since the last ack report was sent to calc
// estimated recv rates
cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100)
// arbitrary, should use len(channel) to set this but doesn't really seem to matter
cif.buff_size = 100
var bytes_recvd uint32
for _, v := range agent.pkt_sizes {
@ -210,6 +223,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
var next_ping_pair [2]time.Time
next_ping_pair[0] = time.Now()
// only keep last 100 acks, use offset for correct ackack ping indexing
if len(agent.pings) >= 100 {
agent.pings = append(agent.pings[1:], next_ping_pair)
agent.ping_offset++
@ -220,6 +234,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
return packet
}
// only need the recieve time from ackacks for rtt calcs, ignore otherwise
func (agent *SRTManager) handle_ackack(packet *Packet) {
ack_num := packet.header_info.(*ControlHeader).tsi
agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now()
@ -243,9 +258,12 @@ func (agent *SRTManager) create_nack_report() (*Packet) {
return packet
}
// handling packets during data loop
func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
switch packet.packet_type {
case DATA:
// if data, add to storage, linking, etc
// then check if ack or nack can be generated (every 10 ms)
agent.handle_data_storage(packet)
if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 {
return agent.create_ack_report()
@ -256,6 +274,8 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
case ACKACK:
agent.handle_ackack(packet)
case SHUTDOWN:
// state 3 should raise error and shutdown tunnel,
// for now start cleanup procedure in 10s
agent.state = 3
go CleanFiles(agent.stream_key, 10)
default:
@ -265,18 +285,26 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
}
func (agent *SRTManager) handle_data_storage(packet *Packet) {
// data packets always have []byte as "cif"
agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte))))
// initialize storage if does not exist, else add where it can
if agent.storage == nil {
agent.storage = NewDatumStorage(packet)
} else {
agent.storage.NewDatum(packet)
}
// attempt to relink any offshoots
// timestamp for TLPKTDROP
if len(agent.storage.offshoots) != 0 {
agent.storage.Relink(packet.timestamp)
}
// write out all possible packets
agent.storage.Expunge(agent.output)
}
// determines appropriate packets and responses depending on tunnel state
// some need to ignore depending on state, eg
// late induction requests during conclusion phase
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
switch agent.state {
case INDUCTION:

View file

@ -5,6 +5,8 @@ import (
"fmt"
)
// main entry point, no concept of tunnels in UDP so need to implement
// that separately and cannot simply add a max connlimit here like with RTMP
func NewServer(port string) (error) {
l, err := net.ListenPacket("udp", ":" + port)
if err != nil {
@ -15,12 +17,13 @@ func NewServer(port string) (error) {
}
func start(l net.PacketConn) {
// basic panic logging for debugging mostly
defer func() {
if r := recover(); r != nil {
fmt.Println(r)
}
}()
intake := NewIntake(l, 1)
intake := NewIntake(l, 1) // limit to one concurrent tunnel
for {
intake.Read()
}

View file

@ -7,9 +7,9 @@ import (
"stream_server/transcoder"
"time"
"path/filepath"
"fmt"
)
// spawn a transcoder instance and return its stdin pipe
func NewWriter(stream_key string) (io.WriteCloser, error) {
transcoder_in, err := transcoder.NewTranscoder(stream_key)
if err != nil {
@ -18,6 +18,9 @@ func NewWriter(stream_key string) (io.WriteCloser, error) {
return transcoder_in, nil
}
// check if the init.mp4 segment has been modified in between a given sleep time
// if it hasn't (stream disconnected for longer) delete it
// else a new stream has started which shouldn't be deleted
func CleanFiles(stream_key string, delay time.Duration) {
time.Sleep(delay * time.Second)
base_dir, _ := os.UserHomeDir()
@ -31,6 +34,9 @@ func CleanFiles(stream_key string, delay time.Duration) {
}
}
// stream_id is in reverse order, len is multiple of 4 padded with 0
// get the string, check if corresponding folder exists, then attempt
// to spawn a transcoder instance
func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) {
stream_key := make([]byte, 0)
for i := len(stream_id) - 1; i >= 0; i-- {
@ -50,6 +56,7 @@ func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) {
return nil, stream_key_string, false
}
// checks if folder exists corresponding to the stream_key
func check_stream_key(stream_key string) (bool) {
base_dir, _ := os.UserHomeDir()
if fileinfo, err := os.Stat(base_dir + "/live/" + stream_key); err == nil && fileinfo.IsDir() {

View file

@ -19,8 +19,10 @@ func (tunnel *Tunnel) Start() {
fmt.Println(r)
}
*a = true
}(&(tunnel.broken))
}(&(tunnel.broken)) // force mark tunnel for deletion if any error occurs
tunnel.state = NewSRTManager(tunnel.socket)
// central tunnel loop, read incoming, process and generate response
// write response if any
for {
packet, err := tunnel.ReadPacket()
if err != nil {
@ -38,6 +40,8 @@ func (tunnel *Tunnel) Start() {
}
}
// send a shutdown command, for use when tunnel gets broken
// not ideal but works
func (tunnel *Tunnel) Shutdown() {
if tunnel.state != nil && tunnel.state.state > 1 {
packet := tunnel.state.create_basic_header()
@ -63,6 +67,6 @@ func (tunnel *Tunnel) WritePacket(packet *Packet) {
}
func (tunnel *Tunnel) ReadPacket() (*Packet, error) {
packet := <- tunnel.queue
packet := <- tunnel.queue // blocking read, should add timeout here
return ParsePacket(packet)
}