file writing, lost packet handling

This commit is contained in:
Muaz Ahmad 2023-09-21 16:11:27 +05:00
parent e1852b6ae0
commit c2d9ebbeb4
5 changed files with 124 additions and 9 deletions

View file

@ -3,6 +3,7 @@ package srt
import ( import (
"math" "math"
"sort" "sort"
"io"
) )
type Datum struct { type Datum struct {
@ -70,17 +71,57 @@ func NewDatumStorage(packet *Packet) (*DatumStorage) {
return storage return storage
} }
func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
curr_datum := storage.main.root
seq_num_end := storage.main.end.seq_num
for curr_datum.seq_num != seq_num_end {
_, err := output.Write(curr_datum.data)
if err != nil {
return err
}
curr_datum = curr_datum.next
}
storage.main.root = curr_datum
return nil
}
func (buffer *DatumLink) InRange(new_seq_num uint32) (bool) {
start := buffer.root.seq_num
end := buffer.end.seq_num
if new_seq_num == start || new_seq_num == end {
return true
}
if !((start < new_seq_num && new_seq_num - start < (1 << 30)) || (start > new_seq_num && start - new_seq_num > (1 << 30))) {
return false
}
if !((new_seq_num < end && end - new_seq_num < (1 << 30)) || (new_seq_num > end && new_seq_num - end > (1 << 30))) {
return false
}
return true
}
func (storage *DatumStorage) NewDatum(pkt *Packet) { func (storage *DatumStorage) NewDatum(pkt *Packet) {
prev_num := (pkt.header_info.(*DataHeader).seq_num - 1) % uint32(math.Pow(2, 31)) new_pkt_num := pkt.header_info.(*DataHeader).seq_num
prev_num := (new_pkt_num - 1) % uint32(1 << 31)
if storage.main.end.seq_num == prev_num { if storage.main.end.seq_num == prev_num {
storage.main.NewDatum(pkt) storage.main.NewDatum(pkt)
} else if storage.main.InRange(new_pkt_num) {
return
} else { } else {
oldest := storage.main.root.seq_num
if (new_pkt_num < oldest && oldest - new_pkt_num < (1 << 30)) || (new_pkt_num > oldest && new_pkt_num - oldest > (1 << 30)) {
return
}
for _, v := range storage.offshoots { for _, v := range storage.offshoots {
if v.end.seq_num == prev_num { if v.end.seq_num == prev_num {
v.NewDatum(pkt) v.NewDatum(pkt)
break return
} else if v.InRange(new_pkt_num) {
return
} }
} }
new_link := NewDatumLink(pkt)
storage.offshoots = append(storage.offshoots, new_link)
} }
} }
@ -93,13 +134,31 @@ func (buffer *DatumLink) Link(buffer_next *DatumLink) {
func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink) (bool) { func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink) (bool) {
seq_1 := buffer.end.seq_num seq_1 := buffer.end.seq_num
seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31)) seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31))
if seq_1 == seq_2 { if buffer_next.root.seq_num == seq_2 {
buffer.Link(buffer_next) buffer.Link(buffer_next)
return true return true
} }
return false return false
} }
func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
if len(storage.offshoots) == 0 {
return nil, false
}
cif := new(NACKCIF)
init_range := new(pckts_range)
init_range.start = (storage.main.end.seq_num + 1) % (1 << 31)
init_range.end = (storage.offshoots[0].root.seq_num - 1) % (1 << 31)
cif.lost_pkts = append(cif.lost_pkts, init_range)
for i := 0; i < len(storage.offshoots) - 1; i++ {
new_range := new(pckts_range)
new_range.start = (storage.offshoots[i].end.seq_num + 1) % (1 << 31)
new_range.end = (storage.offshoots[i + 1].root.seq_num - 1) % (1 << 31)
cif.lost_pkts = append(cif.lost_pkts, new_range)
}
return cif, true
}
func (storage *DatumStorage) Relink() { func (storage *DatumStorage) Relink() {
sort.Sort(storage.offshoots) sort.Sort(storage.offshoots)
buffer := storage.main buffer := storage.main

View file

@ -26,7 +26,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
tunnel := new(Tunnel) tunnel := new(Tunnel)
tunnel.socket = l tunnel.socket = l
tunnel.peer = peer tunnel.peer = peer
tunnel.queue = make(chan []byte, 100) tunnel.queue = make(chan []byte, 10)
intake.tunnels = append(intake.tunnels, tunnel) intake.tunnels = append(intake.tunnels, tunnel)
go tunnel.Start() go tunnel.Start()
return tunnel return tunnel
@ -69,6 +69,6 @@ func (intake *Intake) Read() {
select { select {
case tunnel.queue <- pkt: case tunnel.queue <- pkt:
default: default:
tunnel.broken = true //tunnel.broken = true
} }
} }

View file

@ -3,6 +3,7 @@ package srt
import ( import (
"errors" "errors"
"encoding/binary" "encoding/binary"
"fmt"
) )
const ( const (
@ -109,7 +110,10 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) {
if !ok_head { if !ok_head {
return header, errors.New("control packet does not have ctrl header") return header, errors.New("control packet does not have ctrl header")
} }
header[0] = 0x80 binary.BigEndian.PutUint16(header[:2], ctrl_header.ctrl_type)
binary.BigEndian.PutUint16(header[2:4], ctrl_header.ctrl_subtype)
binary.BigEndian.PutUint32(header[4:8], ctrl_header.tsi)
header[0] |= 0x80
switch packet.packet_type { switch packet.packet_type {
case HANDSHAKE: case HANDSHAKE:
data, ok_data := packet.cif.(*HandshakeCIF) data, ok_data := packet.cif.(*HandshakeCIF)
@ -119,22 +123,48 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) {
cif := marshall_hs_cif(data) cif := marshall_hs_cif(data)
return append(header, cif...), nil return append(header, cif...), nil
case SHUTDOWN: case SHUTDOWN:
header[1] |= 5
cif := make([]byte, 4) cif := make([]byte, 4)
return append(header, cif...), nil return append(header, cif...), nil
case ACK: case ACK:
header[1] |= 2
binary.BigEndian.PutUint32(header[4:8], ctrl_header.tsi)
data, ok_data := packet.cif.(*ACKCIF) data, ok_data := packet.cif.(*ACKCIF)
if !ok_data { if !ok_data {
return header, errors.New("ACK has no data") return header, errors.New("ACK has no data")
} }
cif := marshall_ack_cif(data) cif := marshall_ack_cif(data)
return append(header, cif...), nil return append(header, cif...), nil
case NAK:
data, ok_data := packet.cif.(*NACKCIF)
if !ok_data {
return header, errors.New("NAK has no data")
}
cif := marshall_nack_cif(data)
return append(header, cif...), nil
} }
return header, errors.New("Control packet type not recognized") return header, errors.New("Control packet type not recognized")
} }
func marshall_nack_cif(data *NACKCIF) ([]byte) {
fmt.Println("new NAK")
var loss_list_bytes []byte
for _, pkts := range data.lost_pkts {
fmt.Println(pkts)
if pkts.start == pkts.end {
curr_bytes := make([]byte, 4)
binary.BigEndian.PutUint32(curr_bytes, pkts.start)
loss_list_bytes = append(loss_list_bytes, curr_bytes...)
} else {
curr_bytes := make([]byte, 8)
binary.BigEndian.PutUint32(curr_bytes[:4], pkts.start)
binary.BigEndian.PutUint32(curr_bytes[4:8], pkts.end)
if (pkts.end - pkts.start) % uint32(1 << 31) != 1 {
curr_bytes[0] |= 0x80
}
loss_list_bytes = append(loss_list_bytes, curr_bytes...)
}
}
return loss_list_bytes
}
func marshall_ack_cif(data *ACKCIF) ([]byte) { func marshall_ack_cif(data *ACKCIF) ([]byte) {
cif := make([]byte, 28) cif := make([]byte, 28)
binary.BigEndian.PutUint32(cif[:4], data.last_acked) binary.BigEndian.PutUint32(cif[:4], data.last_acked)

View file

@ -25,6 +25,7 @@ type SRTManager struct {
storage *DatumStorage storage *DatumStorage
ack_idx uint32 ack_idx uint32
pings [][2]time.Time pings [][2]time.Time
last_nack time.Time
ping_offset int ping_offset int
pkt_sizes []uint32 pkt_sizes []uint32
bw uint32 bw uint32
@ -220,6 +221,24 @@ func (agent *SRTManager) handle_ackack(packet *Packet) {
agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now() agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now()
} }
func (agent *SRTManager) create_nack_report() (*Packet) {
agent.last_nack = time.Now()
cif, ok := agent.storage.GenNACKCIF()
if !ok {
return nil
}
packet := agent.create_basic_header()
packet.packet_type = NAK
info := new(ControlHeader)
info.ctrl_type = 3
packet.header_info = info
packet.cif = cif
return packet
}
func (agent *SRTManager) process_data(packet *Packet) (*Packet) { func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
switch packet.packet_type { switch packet.packet_type {
case DATA: case DATA:
@ -227,6 +246,9 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 { if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 {
return agent.create_ack_report() return agent.create_ack_report()
} }
if agent.last_nack.IsZero() || time.Now().Sub(agent.last_nack).Milliseconds() >= 10 {
return agent.create_nack_report()
}
case ACKACK: case ACKACK:
agent.handle_ackack(packet) agent.handle_ackack(packet)
default: default:
@ -245,6 +267,7 @@ func (agent *SRTManager) handle_data_storage(packet *Packet) {
if len(agent.storage.offshoots) != 0 { if len(agent.storage.offshoots) != 0 {
agent.storage.Relink() agent.storage.Relink()
} }
agent.storage.Expunge(agent.output)
} }
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {

View file

@ -46,6 +46,9 @@ func (tunnel *Tunnel) Shutdown() {
info.ctrl_type = 5 info.ctrl_type = 5
packet.header_info = info packet.header_info = info
tunnel.WritePacket(packet) tunnel.WritePacket(packet)
if tunnel.state.output != nil {
tunnel.state.output.Close()
}
} }
} }