diff --git a/go.mod b/go.mod index 45875ee..76b010f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module stream_server -go 1.20 +go 1.21 diff --git a/srt/packet.go b/srt/packet.go index b58c401..a7d7e49 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -105,7 +105,7 @@ func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) { } func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) { - _, ok_head := packet.header_info.(*ControlHeader) + ctrl_header, ok_head := packet.header_info.(*ControlHeader) if !ok_head { return header, errors.New("control packet does not have ctrl header") } @@ -122,10 +122,32 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) { header[1] |= 5 cif := make([]byte, 4) return append(header, cif...), nil + case ACK: + header[1] |= 2 + binary.BigEndian.PutUint32(header[4:8], ctrl_header.tsi) + data, ok_data := packet.cif.(*ACKCIF) + if !ok_data { + return header, errors.New("ACK has no data") + } + cif := marshall_ack_cif(data) + return append(header, cif...), nil } return header, errors.New("Control packet type not recognized") } +func marshall_ack_cif(data *ACKCIF) ([]byte) { + cif := make([]byte, 28) + binary.BigEndian.PutUint32(cif[:4], data.last_acked) + binary.BigEndian.PutUint32(cif[4:8], data.rtt) + binary.BigEndian.PutUint32(cif[8:12], data.var_rtt) + binary.BigEndian.PutUint32(cif[12:16], data.buff_size) + binary.BigEndian.PutUint32(cif[16:20], data.pkt_recv_rate) + binary.BigEndian.PutUint32(cif[20:24], data.bw) + binary.BigEndian.PutUint32(cif[24:28], data.rcv_rate) + + return cif +} + func marshall_hs_cif(data *HandshakeCIF) ([]byte) { cif := make([]byte, 48) binary.BigEndian.PutUint32(cif[:4], data.version) @@ -212,10 +234,14 @@ func parse_ctrl_packet(pkt *Packet, buffer []byte) (error) { return parse_hs_cif(cif, buffer[16:]) } return errors.New("HS not long enough") + case 6: + pkt.packet_type = ACKACK + return nil case 5: return errors.New("Shutdown received") + default: + return errors.New("Unexpected control type") } - return errors.New("Unexpected control type") } func parse_hs_cif(cif *HandshakeCIF, buffer []byte) (error) { diff --git a/srt/protocol.go b/srt/protocol.go index 38a7cbd..1636c4b 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -2,10 +2,12 @@ package srt import ( "time" + "math" "net" "crypto/sha256" "fmt" "errors" + "io" ) const ( @@ -21,14 +23,21 @@ type SRTManager struct { socket net.PacketConn ctrl_sock_peer uint32 storage *DatumStorage - ack_interval uint8 + ack_idx uint32 + pings [][2]time.Time + ping_offset int + pkt_sizes []uint32 + bw uint32 + mtu uint32 + output io.WriteCloser } func NewSRTManager(l net.PacketConn) (*SRTManager) { agent := new(SRTManager) agent.init = time.Now() agent.socket = l - agent.ack_interval = 25 + agent.bw = 15000 + agent.mtu = 1500 return agent } @@ -52,7 +61,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) { cif.hs_type = 1 cif.syn_cookie = agent.syn_cookie cif.sock_id = 1 - cif.mtu = 1500 + cif.mtu = agent.mtu cif.max_flow = 8192 ip := agent.socket.LocalAddr().(*net.UDPAddr).IP @@ -62,6 +71,10 @@ func (agent *SRTManager) create_induction_resp() (*Packet) { packet.cif = cif + var init_ping_time [2]time.Time + init_ping_time[0] = time.Now() + agent.pings = append(agent.pings, init_ping_time) + return packet } @@ -98,7 +111,7 @@ func (agent *SRTManager) create_conclusion_resp() (*Packet) { cif.version = 5 cif.ext_field = 0x1 cif.sock_id = 1 - cif.mtu = 1500 + cif.mtu = agent.mtu cif.max_flow = 8192 ip := agent.socket.LocalAddr().(*net.UDPAddr).IP @@ -127,6 +140,7 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { if packet.packet_type == HANDSHAKE { hs_cif := packet.cif.(*HandshakeCIF) if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie { + agent.pings[0][1] = time.Now() agent.state = DATA_LOOP return agent.create_conclusion_resp() } @@ -135,16 +149,71 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { } func (agent *SRTManager) create_ack_report() (*Packet) { - return nil + packet := agent.create_basic_header() + packet.packet_type = ACK + + info := new(ControlHeader) + info.ctrl_type = 2 + agent.ack_idx++ + info.tsi = agent.ack_idx + packet.header_info = info + + cif := new(ACKCIF) + cif.last_acked = agent.storage.main.end.seq_num + cif.bw = agent.bw + + var rtt_sum uint32 + var rtt_2_sum uint32 + var rtt uint32 + var rtt_n uint32 + for _, v := range agent.pings { + if !v[0].IsZero() && !v[1].IsZero() { + rtt_n++ + rtt = uint32(v[1].Sub(v[0]).Microseconds()) + rtt_sum += rtt + rtt_2_sum += uint32(math.Pow(float64(rtt), 2)) + } + } + fmt.Println(rtt, rtt_sum, rtt_2_sum, rtt_n) + cif.rtt = rtt + cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2)) + + cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100) + cif.buff_size = 100 + var bytes_recvd uint32 + for _, v := range agent.pkt_sizes { + bytes_recvd += v + } + cif.rcv_rate = bytes_recvd * 100 + + packet.cif = cif + + var next_ping_pair [2]time.Time + next_ping_pair[0] = time.Now() + if len(agent.pings) >= 100 { + agent.pings = append(agent.pings[1:], next_ping_pair) + agent.ping_offset++ + } else { + agent.pings = append(agent.pings[:], next_ping_pair) + } + agent.pkt_sizes = make([]uint32, 0) + return packet +} + +func (agent *SRTManager) handle_ackack(packet *Packet) { + ack_num := packet.header_info.(*ControlHeader).tsi + agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now() } func (agent *SRTManager) process_data(packet *Packet) (*Packet) { switch packet.packet_type { case DATA: agent.handle_data_storage(packet) - if agent.storage.main.queued >= 25 { + if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 { return agent.create_ack_report() } + case ACKACK: + agent.handle_ackack(packet) default: return nil } @@ -152,6 +221,7 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) { } func (agent *SRTManager) handle_data_storage(packet *Packet) { + agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte)))) if agent.storage == nil { agent.storage = NewDatumStorage(packet) } else { diff --git a/srt/tunnel.go b/srt/tunnel.go index 7f84547..2fd169d 100644 --- a/srt/tunnel.go +++ b/srt/tunnel.go @@ -24,6 +24,7 @@ func (tunnel *Tunnel) Start() { for { packet, err := tunnel.ReadPacket() if err != nil { + fmt.Println(err) tunnel.broken = true } response, err := tunnel.state.Process(packet)