diff --git a/srt/data_chain.go b/srt/data_chain.go index d1a34d7..9576b73 100644 --- a/srt/data_chain.go +++ b/srt/data_chain.go @@ -3,6 +3,7 @@ package srt import ( "math" "sort" + "io" ) type Datum struct { @@ -70,17 +71,57 @@ func NewDatumStorage(packet *Packet) (*DatumStorage) { return storage } +func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) { + curr_datum := storage.main.root + seq_num_end := storage.main.end.seq_num + for curr_datum.seq_num != seq_num_end { + _, err := output.Write(curr_datum.data) + if err != nil { + return err + } + curr_datum = curr_datum.next + } + storage.main.root = curr_datum + return nil +} + +func (buffer *DatumLink) InRange(new_seq_num uint32) (bool) { + start := buffer.root.seq_num + end := buffer.end.seq_num + if new_seq_num == start || new_seq_num == end { + return true + } + if !((start < new_seq_num && new_seq_num - start < (1 << 30)) || (start > new_seq_num && start - new_seq_num > (1 << 30))) { + return false + } + if !((new_seq_num < end && end - new_seq_num < (1 << 30)) || (new_seq_num > end && new_seq_num - end > (1 << 30))) { + return false + } + return true +} + func (storage *DatumStorage) NewDatum(pkt *Packet) { - prev_num := (pkt.header_info.(*DataHeader).seq_num - 1) % uint32(math.Pow(2, 31)) + new_pkt_num := pkt.header_info.(*DataHeader).seq_num + prev_num := (new_pkt_num - 1) % uint32(1 << 31) if storage.main.end.seq_num == prev_num { storage.main.NewDatum(pkt) + } else if storage.main.InRange(new_pkt_num) { + return } else { + oldest := storage.main.root.seq_num + if (new_pkt_num < oldest && oldest - new_pkt_num < (1 << 30)) || (new_pkt_num > oldest && new_pkt_num - oldest > (1 << 30)) { + return + } for _, v := range storage.offshoots { if v.end.seq_num == prev_num { v.NewDatum(pkt) - break + return + } else if v.InRange(new_pkt_num) { + return } } + new_link := NewDatumLink(pkt) + storage.offshoots = append(storage.offshoots, new_link) } } @@ -93,13 +134,31 @@ func (buffer *DatumLink) Link(buffer_next *DatumLink) { func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink) (bool) { seq_1 := buffer.end.seq_num seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31)) - if seq_1 == seq_2 { + if buffer_next.root.seq_num == seq_2 { buffer.Link(buffer_next) return true } return false } +func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) { + if len(storage.offshoots) == 0 { + return nil, false + } + cif := new(NACKCIF) + init_range := new(pckts_range) + init_range.start = (storage.main.end.seq_num + 1) % (1 << 31) + init_range.end = (storage.offshoots[0].root.seq_num - 1) % (1 << 31) + cif.lost_pkts = append(cif.lost_pkts, init_range) + for i := 0; i < len(storage.offshoots) - 1; i++ { + new_range := new(pckts_range) + new_range.start = (storage.offshoots[i].end.seq_num + 1) % (1 << 31) + new_range.end = (storage.offshoots[i + 1].root.seq_num - 1) % (1 << 31) + cif.lost_pkts = append(cif.lost_pkts, new_range) + } + return cif, true +} + func (storage *DatumStorage) Relink() { sort.Sort(storage.offshoots) buffer := storage.main diff --git a/srt/intake.go b/srt/intake.go index 6f81bbf..3cf7c0b 100644 --- a/srt/intake.go +++ b/srt/intake.go @@ -26,7 +26,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) { tunnel := new(Tunnel) tunnel.socket = l tunnel.peer = peer - tunnel.queue = make(chan []byte, 100) + tunnel.queue = make(chan []byte, 10) intake.tunnels = append(intake.tunnels, tunnel) go tunnel.Start() return tunnel @@ -69,6 +69,6 @@ func (intake *Intake) Read() { select { case tunnel.queue <- pkt: default: - tunnel.broken = true + //tunnel.broken = true } } diff --git a/srt/packet.go b/srt/packet.go index a7d7e49..936a875 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -3,6 +3,7 @@ package srt import ( "errors" "encoding/binary" + "fmt" ) const ( @@ -109,7 +110,10 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) { if !ok_head { return header, errors.New("control packet does not have ctrl header") } - header[0] = 0x80 + binary.BigEndian.PutUint16(header[:2], ctrl_header.ctrl_type) + binary.BigEndian.PutUint16(header[2:4], ctrl_header.ctrl_subtype) + binary.BigEndian.PutUint32(header[4:8], ctrl_header.tsi) + header[0] |= 0x80 switch packet.packet_type { case HANDSHAKE: data, ok_data := packet.cif.(*HandshakeCIF) @@ -119,22 +123,48 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) { cif := marshall_hs_cif(data) return append(header, cif...), nil case SHUTDOWN: - header[1] |= 5 cif := make([]byte, 4) return append(header, cif...), nil case ACK: - header[1] |= 2 - binary.BigEndian.PutUint32(header[4:8], ctrl_header.tsi) data, ok_data := packet.cif.(*ACKCIF) if !ok_data { return header, errors.New("ACK has no data") } cif := marshall_ack_cif(data) return append(header, cif...), nil + case NAK: + data, ok_data := packet.cif.(*NACKCIF) + if !ok_data { + return header, errors.New("NAK has no data") + } + cif := marshall_nack_cif(data) + return append(header, cif...), nil } return header, errors.New("Control packet type not recognized") } +func marshall_nack_cif(data *NACKCIF) ([]byte) { + fmt.Println("new NAK") + var loss_list_bytes []byte + for _, pkts := range data.lost_pkts { + fmt.Println(pkts) + if pkts.start == pkts.end { + curr_bytes := make([]byte, 4) + binary.BigEndian.PutUint32(curr_bytes, pkts.start) + loss_list_bytes = append(loss_list_bytes, curr_bytes...) + } else { + curr_bytes := make([]byte, 8) + binary.BigEndian.PutUint32(curr_bytes[:4], pkts.start) + binary.BigEndian.PutUint32(curr_bytes[4:8], pkts.end) + if (pkts.end - pkts.start) % uint32(1 << 31) != 1 { + curr_bytes[0] |= 0x80 + } + loss_list_bytes = append(loss_list_bytes, curr_bytes...) + } + } + return loss_list_bytes +} + func marshall_ack_cif(data *ACKCIF) ([]byte) { cif := make([]byte, 28) binary.BigEndian.PutUint32(cif[:4], data.last_acked) diff --git a/srt/protocol.go b/srt/protocol.go index 54e7c17..4522bda 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -25,6 +25,7 @@ type SRTManager struct { storage *DatumStorage ack_idx uint32 pings [][2]time.Time + last_nack time.Time ping_offset int pkt_sizes []uint32 bw uint32 @@ -220,6 +221,24 @@ func (agent *SRTManager) handle_ackack(packet *Packet) { agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now() } +func (agent *SRTManager) create_nack_report() (*Packet) { + agent.last_nack = time.Now() + cif, ok := agent.storage.GenNACKCIF() + if !ok { + return nil + } + packet := agent.create_basic_header() + packet.packet_type = NAK + + info := new(ControlHeader) + info.ctrl_type = 3 + packet.header_info = info + + packet.cif = cif + + return packet +} + func (agent *SRTManager) process_data(packet *Packet) (*Packet) { switch packet.packet_type { case DATA: @@ -227,6 +246,9 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) { if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 { return agent.create_ack_report() } + if agent.last_nack.IsZero() || time.Now().Sub(agent.last_nack).Milliseconds() >= 10 { + return agent.create_nack_report() + } case ACKACK: agent.handle_ackack(packet) default: @@ -245,6 +267,7 @@ func (agent *SRTManager) handle_data_storage(packet *Packet) { if len(agent.storage.offshoots) != 0 { agent.storage.Relink() } + agent.storage.Expunge(agent.output) } func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { diff --git a/srt/tunnel.go b/srt/tunnel.go index 2fd169d..ac0acbd 100644 --- a/srt/tunnel.go +++ b/srt/tunnel.go @@ -46,6 +46,9 @@ func (tunnel *Tunnel) Shutdown() { info.ctrl_type = 5 packet.header_info = info tunnel.WritePacket(packet) + if tunnel.state.output != nil { + tunnel.state.output.Close() + } } }