From c957c29188e71a4fa8a668166777b5249c0db1c9 Mon Sep 17 00:00:00 2001 From: Muaz Ahmad Date: Fri, 22 Sep 2023 14:58:33 +0500 Subject: [PATCH] comments --- srt/data_chain.go | 30 ++++++++++++++++++++++++++++-- srt/intake.go | 14 ++++++++++---- srt/packet.go | 15 ++++++++++++--- srt/protocol.go | 32 ++++++++++++++++++++++++++++++-- srt/server.go | 5 ++++- srt/stream_ids.go | 9 ++++++++- srt/tunnel.go | 8 ++++++-- 7 files changed, 98 insertions(+), 15 deletions(-) diff --git a/srt/data_chain.go b/srt/data_chain.go index 94e34fc..982f41e 100644 --- a/srt/data_chain.go +++ b/srt/data_chain.go @@ -13,12 +13,20 @@ type Datum struct { next *Datum } +// linked chain of datums, specifically to store continuous segments +// when a packet is missed, start and end can be used to generate lost +// packet reports, and easily link when missing is received +// 1..2..3..4 6..7..8..9 +// chain_1 chain_2 +// nack: get 5 type DatumLink struct { - queued int + queued int // remove eventually, was to be used for ACK recv rate calcs, not needed root *Datum end *Datum } +// data type and function to allow sorting so order can be ignored during +// linking since each is sequential on outset type chains []*DatumLink func (c chains) Len() (int) { @@ -29,6 +37,7 @@ func (c chains) Swap(i, j int) { c[i], c[j] = c[j], c[i] } +// chain_1 is less then chain_2 when chain_1 ends before chain_2 starts func (c chains) Less(i, j int) (bool) { x_1 := c[i].end.seq_num x_2 := c[j].root.seq_num @@ -43,6 +52,7 @@ type DatumStorage struct { offshoots chains } +// append new packet to end of buffer func (buffer *DatumLink) NewDatum(pkt *Packet) { datum := new(Datum) datum.seq_num = pkt.header_info.(*DataHeader).seq_num @@ -53,6 +63,7 @@ func (buffer *DatumLink) NewDatum(pkt *Packet) { buffer.end = datum } +// create a new datumlink with root and end at the given packet func NewDatumLink(pkt *Packet) (*DatumLink) { buffer := new(DatumLink) root_datum := new(Datum) @@ -66,12 +77,14 @@ func NewDatumLink(pkt *Packet) (*DatumLink) { return buffer } +// initialize storage with the given data packet, in the main chain func NewDatumStorage(packet *Packet) (*DatumStorage) { storage := new(DatumStorage) storage.main = NewDatumLink(packet) return storage } +// purge all packets in the main chain except the last for future linkage func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) { curr_datum := storage.main.root seq_num_end := storage.main.end.seq_num @@ -86,6 +99,7 @@ func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) { return nil } +// check if the given sequence number should already be inside the given buffer func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) { start := buffer.root.seq_num end := buffer.end.seq_num @@ -98,6 +112,8 @@ func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) { return true } +// check if the given seq num lies before the given buffer starts +// buffer is After seq num? func (buffer *DatumLink) After(new_seq_num uint32) (bool) { start := buffer.root.seq_num if serial_less(new_seq_num, start, 31) { @@ -106,6 +122,8 @@ func (buffer *DatumLink) After(new_seq_num uint32) (bool) { return false } +// check if the given seq num lies after the given buffer starts +// buffer is Before seq num? func (buffer *DatumLink) Before(new_seq_num uint32) (bool) { end := buffer.end.seq_num if serial_less(end, new_seq_num, 31) { @@ -114,6 +132,7 @@ func (buffer *DatumLink) Before(new_seq_num uint32) (bool) { return false } +// add the given packet to the appropriate buffer or add a new offshoot buffer func (storage *DatumStorage) NewDatum(pkt *Packet) { new_pkt_num := pkt.header_info.(*DataHeader).seq_num prev_num := (new_pkt_num - 1) % uint32(1 << 31) @@ -139,12 +158,16 @@ func (storage *DatumStorage) NewDatum(pkt *Packet) { } } +// Link 2 buffers to each other func (buffer *DatumLink) Link(buffer_next *DatumLink) { buffer.end.next = buffer_next.root buffer.end = buffer_next.end buffer.queued += buffer_next.queued } +// check if the start of the buffer_next is sequentially next to the end of buffer +// or it is too late to get the true next packets to link and the packets must be linked +// given a 31-bit wrap around func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_time uint32) (bool) { seq_1 := buffer.end.seq_num seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31)) @@ -155,6 +178,7 @@ func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_ti return false } +// Given the current storage state, check what packets need to added to fully link main func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) { if len(storage.offshoots) == 0 { return nil, false @@ -173,13 +197,14 @@ func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) { return cif, true } +// Try to relink all chains wherever possible or necessary due to TLPKTDROP func (storage *DatumStorage) Relink(curr_time uint32) { sort.Sort(storage.offshoots) buffer := storage.main i := 0 for i < len(storage.offshoots) { if check_append_serial_next(buffer, storage.offshoots[i], curr_time) { - storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) + storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) // nuke the chain since it is not needed anymore } else { buffer = storage.offshoots[i] i++ @@ -187,6 +212,7 @@ func (storage *DatumStorage) Relink(curr_time uint32) { } } +// check if a is less than b under serial arithmetic (modulo operations) func serial_less(a uint32, b uint32, bits int) (bool) { if (a < b && b-a < (1 << (bits - 1))) || (a > b && a-b > (1 << (bits - 1))) { return true diff --git a/srt/intake.go b/srt/intake.go index 3cf7c0b..4963597 100644 --- a/srt/intake.go +++ b/srt/intake.go @@ -15,7 +15,7 @@ func NewIntake(l net.PacketConn, max_conns int) (*Intake) { intake := new(Intake) intake.max_conns = max_conns intake.tunnels = make([]*Tunnel, 0) - intake.buffer = make([]byte, 1500) + intake.buffer = make([]byte, 1500) // each packet is restricted to a max size of 1500 intake.socket = l return intake @@ -26,9 +26,9 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) { tunnel := new(Tunnel) tunnel.socket = l tunnel.peer = peer - tunnel.queue = make(chan []byte, 10) + tunnel.queue = make(chan []byte, 10) // packet buffer, will cause packet loss if low intake.tunnels = append(intake.tunnels, tunnel) - go tunnel.Start() + go tunnel.Start() // start the tunnel SRT processing return tunnel } return nil @@ -37,6 +37,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) { func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { var tunnel *Tunnel for i := 0; i < len(intake.tunnels); i++ { + // check if tunnels are broken and remove if intake.tunnels[i].broken { intake.tunnels[i].Shutdown() intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...) @@ -47,6 +48,9 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { tunnel = intake.tunnels[i] } } + // if no tunnel was found, make one + // should be after conclusion handshake, but wanted to keep all protocol + // related actions separate from UDP handling if tunnel == nil { tunnel = intake.NewTunnel(intake.socket, peer) } @@ -58,14 +62,16 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { func (intake *Intake) Read() { n, peer, err := intake.socket.ReadFrom(intake.buffer) if err != nil { - return + return // ignore UDP errors } + // find the SRT/UDT tunnel corresponding to the given peer tunnel := intake.get_tunnel(peer) if tunnel == nil { return } pkt := make([]byte, n) copy(pkt, intake.buffer[:n]) + // send a copy to the corresponding tunnels packet queue if not full select { case tunnel.queue <- pkt: default: diff --git a/srt/packet.go b/srt/packet.go index b72cdac..296ef99 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -5,6 +5,7 @@ import ( "encoding/binary" ) +// arbitrary indexing const ( DATA uint8 = iota HANDSHAKE @@ -15,6 +16,7 @@ const ( DROP ) +// see SRT protocol RFC for information type ControlHeader struct { ctrl_type uint16 ctrl_subtype uint16 @@ -80,23 +82,26 @@ type pckts_range struct { end uint32 } +// should be safe to ignore type DROPCIF struct { to_drop pckts_range } +// header and cif are interfaces to allow easier typing, fitting above structs type Packet struct { packet_type uint8 timestamp uint32 dest_sock uint32 - header_info interface{} + header_info interface{} cif interface{} } +// completely pointless since only implementing receiver, here anyway func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) { info, ok_head := packet.header_info.(*DataHeader) data, ok_data := packet.cif.([]byte) if !ok_head || !ok_data { - return header, errors.New("data packet does not have data header") + return header, errors.New("data packet does not have data header or data") } binary.BigEndian.PutUint32(header[:4], info.seq_num) head2 := (uint32(info.msg_flags) << 26) + info.msg_num @@ -162,6 +167,9 @@ func marshall_nack_cif(data *NACKCIF) ([]byte) { return loss_list_bytes } +// locations and length are determined by protocol, +// no real point abstracting that +// could be cleaner with reflects, but added work for very little real gain func marshall_ack_cif(data *ACKCIF) ([]byte) { cif := make([]byte, 28) binary.BigEndian.PutUint32(cif[:4], data.last_acked) @@ -175,6 +183,7 @@ func marshall_ack_cif(data *ACKCIF) ([]byte) { return cif } +// same as above func marshall_hs_cif(data *HandshakeCIF) ([]byte) { cif := make([]byte, 48) binary.BigEndian.PutUint32(cif[:4], data.version) @@ -231,7 +240,7 @@ func MarshallPacket(packet *Packet, agent *SRTManager) ([]byte, error) { func parse_data_packet(pkt *Packet, buffer []byte) (error) { info := new(DataHeader) info.seq_num = binary.BigEndian.Uint32(buffer[:4]) - info.msg_flags = buffer[4] >> 2 + info.msg_flags = buffer[4] >> 2 // unsused since live streaming makes it irrelevant, kk for encrypt eventually info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff pkt.header_info = info diff --git a/srt/protocol.go b/srt/protocol.go index 80f1f8c..47cc157 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -39,11 +39,12 @@ func NewSRTManager(l net.PacketConn) (*SRTManager) { agent := new(SRTManager) agent.init = time.Now() agent.socket = l - agent.bw = 15000 + agent.bw = 15000 // in pkts (mtu bytes) per second agent.mtu = 1500 return agent } +// adds basic information present in all packets, timestamp and destination SRT socket func (agent *SRTManager) create_basic_header() (*Packet) { packet := new(Packet) packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds()) @@ -74,6 +75,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) { packet.cif = cif + // use the handshake as a placeholder ack-ackack rtt initializer var init_ping_time [2]time.Time init_ping_time[0] = time.Now() agent.pings = append(agent.pings, init_ping_time) @@ -81,6 +83,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) { return packet } +// not ideal, but works func (agent *SRTManager) make_syn_cookie(peer net.Addr) { t := uint32(time.Now().Unix()) >> 6 s := sha256.New() @@ -112,7 +115,7 @@ func (agent *SRTManager) create_conclusion_resp() (*Packet) { cif := new(HandshakeCIF) cif.version = 5 - cif.ext_field = 0x1 + cif.ext_field = 0x1 // 1 for HS-ext, does not allow encryption currently cif.sock_id = 1 cif.mtu = agent.mtu cif.max_flow = 8192 @@ -145,6 +148,8 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { hs_cif := packet.cif.(*HandshakeCIF) if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie { for _, v := range hs_cif.hs_extensions { + // force client to add a stream_id for output location + // to do: add encryption handling switch v.ext_type { case 5: writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte)) @@ -159,6 +164,7 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { } } agent.pings[0][1] = time.Now() + // if output was successfully initialized, proceed with data looping if agent.output != nil { agent.state = DATA_LOOP return resp_packet @@ -180,9 +186,13 @@ func (agent *SRTManager) create_ack_report() (*Packet) { packet.header_info = info cif := new(ACKCIF) + // main has the latest unbroken chain, either no other packets, or + // missing packet which must be nak'd cif.last_acked = agent.storage.main.end.seq_num cif.bw = agent.bw + // calculate rtt variance from valid ping pairs, use last value as rtt of last + // exchange since that's what it is var rtt_sum uint32 var rtt_2_sum uint32 var rtt uint32 @@ -198,7 +208,10 @@ func (agent *SRTManager) create_ack_report() (*Packet) { cif.rtt = rtt cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2)) + // use the packets received since the last ack report was sent to calc + // estimated recv rates cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100) + // arbitrary, should use len(channel) to set this but doesn't really seem to matter cif.buff_size = 100 var bytes_recvd uint32 for _, v := range agent.pkt_sizes { @@ -210,6 +223,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) { var next_ping_pair [2]time.Time next_ping_pair[0] = time.Now() + // only keep last 100 acks, use offset for correct ackack ping indexing if len(agent.pings) >= 100 { agent.pings = append(agent.pings[1:], next_ping_pair) agent.ping_offset++ @@ -220,6 +234,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) { return packet } +// only need the recieve time from ackacks for rtt calcs, ignore otherwise func (agent *SRTManager) handle_ackack(packet *Packet) { ack_num := packet.header_info.(*ControlHeader).tsi agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now() @@ -243,9 +258,12 @@ func (agent *SRTManager) create_nack_report() (*Packet) { return packet } +// handling packets during data loop func (agent *SRTManager) process_data(packet *Packet) (*Packet) { switch packet.packet_type { case DATA: + // if data, add to storage, linking, etc + // then check if ack or nack can be generated (every 10 ms) agent.handle_data_storage(packet) if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 { return agent.create_ack_report() @@ -256,6 +274,8 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) { case ACKACK: agent.handle_ackack(packet) case SHUTDOWN: + // state 3 should raise error and shutdown tunnel, + // for now start cleanup procedure in 10s agent.state = 3 go CleanFiles(agent.stream_key, 10) default: @@ -265,18 +285,26 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) { } func (agent *SRTManager) handle_data_storage(packet *Packet) { + // data packets always have []byte as "cif" agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte)))) + // initialize storage if does not exist, else add where it can if agent.storage == nil { agent.storage = NewDatumStorage(packet) } else { agent.storage.NewDatum(packet) } + // attempt to relink any offshoots + // timestamp for TLPKTDROP if len(agent.storage.offshoots) != 0 { agent.storage.Relink(packet.timestamp) } + // write out all possible packets agent.storage.Expunge(agent.output) } +// determines appropriate packets and responses depending on tunnel state +// some need to ignore depending on state, eg +// late induction requests during conclusion phase func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { switch agent.state { case INDUCTION: diff --git a/srt/server.go b/srt/server.go index 4fd0f86..b576e92 100644 --- a/srt/server.go +++ b/srt/server.go @@ -5,6 +5,8 @@ import ( "fmt" ) +// main entry point, no concept of tunnels in UDP so need to implement +// that separately and cannot simply add a max connlimit here like with RTMP func NewServer(port string) (error) { l, err := net.ListenPacket("udp", ":" + port) if err != nil { @@ -15,12 +17,13 @@ func NewServer(port string) (error) { } func start(l net.PacketConn) { + // basic panic logging for debugging mostly defer func() { if r := recover(); r != nil { fmt.Println(r) } }() - intake := NewIntake(l, 1) + intake := NewIntake(l, 1) // limit to one concurrent tunnel for { intake.Read() } diff --git a/srt/stream_ids.go b/srt/stream_ids.go index ab5b36c..7f598e1 100644 --- a/srt/stream_ids.go +++ b/srt/stream_ids.go @@ -7,9 +7,9 @@ import ( "stream_server/transcoder" "time" "path/filepath" - "fmt" ) +// spawn a transcoder instance and return its stdin pipe func NewWriter(stream_key string) (io.WriteCloser, error) { transcoder_in, err := transcoder.NewTranscoder(stream_key) if err != nil { @@ -18,6 +18,9 @@ func NewWriter(stream_key string) (io.WriteCloser, error) { return transcoder_in, nil } +// check if the init.mp4 segment has been modified in between a given sleep time +// if it hasn't (stream disconnected for longer) delete it +// else a new stream has started which shouldn't be deleted func CleanFiles(stream_key string, delay time.Duration) { time.Sleep(delay * time.Second) base_dir, _ := os.UserHomeDir() @@ -31,6 +34,9 @@ func CleanFiles(stream_key string, delay time.Duration) { } } +// stream_id is in reverse order, len is multiple of 4 padded with 0 +// get the string, check if corresponding folder exists, then attempt +// to spawn a transcoder instance func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) { stream_key := make([]byte, 0) for i := len(stream_id) - 1; i >= 0; i-- { @@ -50,6 +56,7 @@ func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) { return nil, stream_key_string, false } +// checks if folder exists corresponding to the stream_key func check_stream_key(stream_key string) (bool) { base_dir, _ := os.UserHomeDir() if fileinfo, err := os.Stat(base_dir + "/live/" + stream_key); err == nil && fileinfo.IsDir() { diff --git a/srt/tunnel.go b/srt/tunnel.go index ac0acbd..aa338b8 100644 --- a/srt/tunnel.go +++ b/srt/tunnel.go @@ -19,8 +19,10 @@ func (tunnel *Tunnel) Start() { fmt.Println(r) } *a = true - }(&(tunnel.broken)) + }(&(tunnel.broken)) // force mark tunnel for deletion if any error occurs tunnel.state = NewSRTManager(tunnel.socket) + // central tunnel loop, read incoming, process and generate response + // write response if any for { packet, err := tunnel.ReadPacket() if err != nil { @@ -38,6 +40,8 @@ func (tunnel *Tunnel) Start() { } } +// send a shutdown command, for use when tunnel gets broken +// not ideal but works func (tunnel *Tunnel) Shutdown() { if tunnel.state != nil && tunnel.state.state > 1 { packet := tunnel.state.create_basic_header() @@ -63,6 +67,6 @@ func (tunnel *Tunnel) WritePacket(packet *Packet) { } func (tunnel *Tunnel) ReadPacket() (*Packet, error) { - packet := <- tunnel.queue + packet := <- tunnel.queue // blocking read, should add timeout here return ParsePacket(packet) }