From 32f414e19dfa12ed6911b66676bab8e79f0995f0 Mon Sep 17 00:00:00 2001 From: Muaz Ahmad Date: Wed, 20 Sep 2023 15:06:32 +0500 Subject: [PATCH] basic data packet handling --- srt/data_chain.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++ srt/intake.go | 5 +- srt/packet.go | 4 ++ srt/protocol.go | 42 ++++++++++++++--- srt/server.go | 6 +++ srt/tunnel.go | 14 +++++- 6 files changed, 179 insertions(+), 8 deletions(-) create mode 100644 srt/data_chain.go diff --git a/srt/data_chain.go b/srt/data_chain.go new file mode 100644 index 0000000..d1a34d7 --- /dev/null +++ b/srt/data_chain.go @@ -0,0 +1,116 @@ +package srt + +import ( + "math" + "sort" +) + +type Datum struct { + seq_num uint32 + data []byte + next *Datum +} + +type DatumLink struct { + queued int + root *Datum + end *Datum +} + +type chains []*DatumLink + +func (c chains) Len() (int) { + return len(c) +} + +func (c chains) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} + +func (c chains) Less(i, j int) (bool) { + x_1 := c[i].end.seq_num + x_2 := c[j].root.seq_num + serial_add_limit := uint32(math.Pow(2, 30)) + if (x_1 < x_2 && x_2 - x_1 < serial_add_limit) || (x_1 > x_2 && x_1 - x_2 > serial_add_limit) { + return true + } + return false +} + +type DatumStorage struct { + main *DatumLink + offshoots chains +} + +func (buffer *DatumLink) NewDatum(pkt *Packet) { + datum := new(Datum) + datum.seq_num = pkt.header_info.(*DataHeader).seq_num + datum.data = pkt.cif.([]byte) + + buffer.queued += 1 + buffer.end.next = datum + buffer.end = datum +} + +func NewDatumLink(pkt *Packet) (*DatumLink) { + buffer := new(DatumLink) + root_datum := new(Datum) + root_datum.seq_num = pkt.header_info.(*DataHeader).seq_num + root_datum.data = pkt.cif.([]byte) + + buffer.root = root_datum + buffer.end = root_datum + buffer.queued = 1 + return buffer +} + +func NewDatumStorage(packet *Packet) (*DatumStorage) { + storage := new(DatumStorage) + storage.main = NewDatumLink(packet) + return storage +} + +func (storage *DatumStorage) NewDatum(pkt *Packet) { + prev_num := (pkt.header_info.(*DataHeader).seq_num - 1) % uint32(math.Pow(2, 31)) + if storage.main.end.seq_num == prev_num { + storage.main.NewDatum(pkt) + } else { + for _, v := range storage.offshoots { + if v.end.seq_num == prev_num { + v.NewDatum(pkt) + break + } + } + } +} + +func (buffer *DatumLink) Link(buffer_next *DatumLink) { + buffer.end.next = buffer_next.root + buffer.end = buffer_next.end + buffer.queued += buffer_next.queued +} + +func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink) (bool) { + seq_1 := buffer.end.seq_num + seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31)) + if seq_1 == seq_2 { + buffer.Link(buffer_next) + return true + } + return false +} + +func (storage *DatumStorage) Relink() { + sort.Sort(storage.offshoots) + buffer := storage.main + i := 0 + for i < len(storage.offshoots) { + if check_append_serial_next(buffer, storage.offshoots[i]) { + storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) + } else { + buffer = storage.offshoots[i] + i++ + } + } +} + diff --git a/srt/intake.go b/srt/intake.go index f226336..6f81bbf 100644 --- a/srt/intake.go +++ b/srt/intake.go @@ -26,7 +26,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) { tunnel := new(Tunnel) tunnel.socket = l tunnel.peer = peer - tunnel.queue = make(chan []byte, 10) + tunnel.queue = make(chan []byte, 100) intake.tunnels = append(intake.tunnels, tunnel) go tunnel.Start() return tunnel @@ -38,6 +38,7 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { var tunnel *Tunnel for i := 0; i < len(intake.tunnels); i++ { if intake.tunnels[i].broken { + intake.tunnels[i].Shutdown() intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...) i-- continue @@ -52,6 +53,8 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { return tunnel } + + func (intake *Intake) Read() { n, peer, err := intake.socket.ReadFrom(intake.buffer) if err != nil { diff --git a/srt/packet.go b/srt/packet.go index d1c0c94..b58c401 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -118,6 +118,10 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) { } cif := marshall_hs_cif(data) return append(header, cif...), nil + case SHUTDOWN: + header[1] |= 5 + cif := make([]byte, 4) + return append(header, cif...), nil } return header, errors.New("Control packet type not recognized") } diff --git a/srt/protocol.go b/srt/protocol.go index 00c6549..38a7cbd 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -20,12 +20,15 @@ type SRTManager struct { syn_cookie uint32 socket net.PacketConn ctrl_sock_peer uint32 + storage *DatumStorage + ack_interval uint8 } func NewSRTManager(l net.PacketConn) (*SRTManager) { agent := new(SRTManager) agent.init = time.Now() agent.socket = l + agent.ack_interval = 25 return agent } @@ -72,16 +75,16 @@ func (agent *SRTManager) make_syn_cookie(peer net.Addr) { } } -func (agent *SRTManager) process_induction(packet *Packet) (*Packet) { +func (agent *SRTManager) process_induction(packet *Packet) (*Packet, error) { if packet.packet_type == HANDSHAKE { hs_cif := packet.cif.(*HandshakeCIF) if hs_cif.hs_type == 1 { agent.state = CONCLUSION agent.ctrl_sock_peer = hs_cif.sock_id - return agent.create_induction_resp() + return agent.create_induction_resp(), nil } } - return nil + return nil, errors.New("Packet was not handshake") } func (agent *SRTManager) create_conclusion_resp() (*Packet) { @@ -131,15 +134,42 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { return nil } +func (agent *SRTManager) create_ack_report() (*Packet) { + return nil +} + +func (agent *SRTManager) process_data(packet *Packet) (*Packet) { + switch packet.packet_type { + case DATA: + agent.handle_data_storage(packet) + if agent.storage.main.queued >= 25 { + return agent.create_ack_report() + } + default: + return nil + } + return nil +} + +func (agent *SRTManager) handle_data_storage(packet *Packet) { + if agent.storage == nil { + agent.storage = NewDatumStorage(packet) + } else { + agent.storage.NewDatum(packet) + } + if len(agent.storage.offshoots) != 0 { + agent.storage.Relink() + } +} + func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { switch agent.state { case INDUCTION: - return agent.process_induction(packet), nil + return agent.process_induction(packet) case CONCLUSION: return agent.process_conclusion(packet), nil case DATA_LOOP: - fmt.Println(packet) - return nil, nil + return agent.process_data(packet), nil default: return nil, errors.New("State not implemented") } diff --git a/srt/server.go b/srt/server.go index 556d1b6..4fd0f86 100644 --- a/srt/server.go +++ b/srt/server.go @@ -2,6 +2,7 @@ package srt import ( "net" + "fmt" ) func NewServer(port string) (error) { @@ -14,6 +15,11 @@ func NewServer(port string) (error) { } func start(l net.PacketConn) { + defer func() { + if r := recover(); r != nil { + fmt.Println(r) + } + }() intake := NewIntake(l, 1) for { intake.Read() diff --git a/srt/tunnel.go b/srt/tunnel.go index 05f53e4..7f84547 100644 --- a/srt/tunnel.go +++ b/srt/tunnel.go @@ -21,7 +21,6 @@ func (tunnel *Tunnel) Start() { *a = true }(&(tunnel.broken)) tunnel.state = NewSRTManager(tunnel.socket) - tunnel.state.make_syn_cookie(tunnel.peer) for { packet, err := tunnel.ReadPacket() if err != nil { @@ -29,6 +28,7 @@ func (tunnel *Tunnel) Start() { } response, err := tunnel.state.Process(packet) if err != nil { + fmt.Println(err) tunnel.broken = true } if response != nil { @@ -37,10 +37,22 @@ func (tunnel *Tunnel) Start() { } } +func (tunnel *Tunnel) Shutdown() { + if tunnel.state != nil && tunnel.state.state > 1 { + packet := tunnel.state.create_basic_header() + packet.packet_type = SHUTDOWN + info := new(ControlHeader) + info.ctrl_type = 5 + packet.header_info = info + tunnel.WritePacket(packet) + } +} + func (tunnel *Tunnel) WritePacket(packet *Packet) { buffer, err := MarshallPacket(packet, tunnel.state) if err != nil { tunnel.broken = true + fmt.Println(err) return } tunnel.socket.WriteTo(buffer, tunnel.peer)