basic data packet handling

This commit is contained in:
Muaz Ahmad 2023-09-20 15:06:32 +05:00
parent 19c601e9be
commit 32f414e19d
6 changed files with 179 additions and 8 deletions

116
srt/data_chain.go Normal file
View file

@ -0,0 +1,116 @@
package srt
import (
"math"
"sort"
)
type Datum struct {
seq_num uint32
data []byte
next *Datum
}
type DatumLink struct {
queued int
root *Datum
end *Datum
}
type chains []*DatumLink
func (c chains) Len() (int) {
return len(c)
}
func (c chains) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
}
func (c chains) Less(i, j int) (bool) {
x_1 := c[i].end.seq_num
x_2 := c[j].root.seq_num
serial_add_limit := uint32(math.Pow(2, 30))
if (x_1 < x_2 && x_2 - x_1 < serial_add_limit) || (x_1 > x_2 && x_1 - x_2 > serial_add_limit) {
return true
}
return false
}
type DatumStorage struct {
main *DatumLink
offshoots chains
}
func (buffer *DatumLink) NewDatum(pkt *Packet) {
datum := new(Datum)
datum.seq_num = pkt.header_info.(*DataHeader).seq_num
datum.data = pkt.cif.([]byte)
buffer.queued += 1
buffer.end.next = datum
buffer.end = datum
}
func NewDatumLink(pkt *Packet) (*DatumLink) {
buffer := new(DatumLink)
root_datum := new(Datum)
root_datum.seq_num = pkt.header_info.(*DataHeader).seq_num
root_datum.data = pkt.cif.([]byte)
buffer.root = root_datum
buffer.end = root_datum
buffer.queued = 1
return buffer
}
func NewDatumStorage(packet *Packet) (*DatumStorage) {
storage := new(DatumStorage)
storage.main = NewDatumLink(packet)
return storage
}
func (storage *DatumStorage) NewDatum(pkt *Packet) {
prev_num := (pkt.header_info.(*DataHeader).seq_num - 1) % uint32(math.Pow(2, 31))
if storage.main.end.seq_num == prev_num {
storage.main.NewDatum(pkt)
} else {
for _, v := range storage.offshoots {
if v.end.seq_num == prev_num {
v.NewDatum(pkt)
break
}
}
}
}
func (buffer *DatumLink) Link(buffer_next *DatumLink) {
buffer.end.next = buffer_next.root
buffer.end = buffer_next.end
buffer.queued += buffer_next.queued
}
func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink) (bool) {
seq_1 := buffer.end.seq_num
seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31))
if seq_1 == seq_2 {
buffer.Link(buffer_next)
return true
}
return false
}
func (storage *DatumStorage) Relink() {
sort.Sort(storage.offshoots)
buffer := storage.main
i := 0
for i < len(storage.offshoots) {
if check_append_serial_next(buffer, storage.offshoots[i]) {
storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...)
} else {
buffer = storage.offshoots[i]
i++
}
}
}

View file

@ -26,7 +26,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
tunnel := new(Tunnel) tunnel := new(Tunnel)
tunnel.socket = l tunnel.socket = l
tunnel.peer = peer tunnel.peer = peer
tunnel.queue = make(chan []byte, 10) tunnel.queue = make(chan []byte, 100)
intake.tunnels = append(intake.tunnels, tunnel) intake.tunnels = append(intake.tunnels, tunnel)
go tunnel.Start() go tunnel.Start()
return tunnel return tunnel
@ -38,6 +38,7 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
var tunnel *Tunnel var tunnel *Tunnel
for i := 0; i < len(intake.tunnels); i++ { for i := 0; i < len(intake.tunnels); i++ {
if intake.tunnels[i].broken { if intake.tunnels[i].broken {
intake.tunnels[i].Shutdown()
intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...) intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...)
i-- i--
continue continue
@ -52,6 +53,8 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
return tunnel return tunnel
} }
func (intake *Intake) Read() { func (intake *Intake) Read() {
n, peer, err := intake.socket.ReadFrom(intake.buffer) n, peer, err := intake.socket.ReadFrom(intake.buffer)
if err != nil { if err != nil {

View file

@ -118,6 +118,10 @@ func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) {
} }
cif := marshall_hs_cif(data) cif := marshall_hs_cif(data)
return append(header, cif...), nil return append(header, cif...), nil
case SHUTDOWN:
header[1] |= 5
cif := make([]byte, 4)
return append(header, cif...), nil
} }
return header, errors.New("Control packet type not recognized") return header, errors.New("Control packet type not recognized")
} }

View file

@ -20,12 +20,15 @@ type SRTManager struct {
syn_cookie uint32 syn_cookie uint32
socket net.PacketConn socket net.PacketConn
ctrl_sock_peer uint32 ctrl_sock_peer uint32
storage *DatumStorage
ack_interval uint8
} }
func NewSRTManager(l net.PacketConn) (*SRTManager) { func NewSRTManager(l net.PacketConn) (*SRTManager) {
agent := new(SRTManager) agent := new(SRTManager)
agent.init = time.Now() agent.init = time.Now()
agent.socket = l agent.socket = l
agent.ack_interval = 25
return agent return agent
} }
@ -72,16 +75,16 @@ func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
} }
} }
func (agent *SRTManager) process_induction(packet *Packet) (*Packet) { func (agent *SRTManager) process_induction(packet *Packet) (*Packet, error) {
if packet.packet_type == HANDSHAKE { if packet.packet_type == HANDSHAKE {
hs_cif := packet.cif.(*HandshakeCIF) hs_cif := packet.cif.(*HandshakeCIF)
if hs_cif.hs_type == 1 { if hs_cif.hs_type == 1 {
agent.state = CONCLUSION agent.state = CONCLUSION
agent.ctrl_sock_peer = hs_cif.sock_id agent.ctrl_sock_peer = hs_cif.sock_id
return agent.create_induction_resp() return agent.create_induction_resp(), nil
} }
} }
return nil return nil, errors.New("Packet was not handshake")
} }
func (agent *SRTManager) create_conclusion_resp() (*Packet) { func (agent *SRTManager) create_conclusion_resp() (*Packet) {
@ -131,15 +134,42 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
return nil return nil
} }
func (agent *SRTManager) create_ack_report() (*Packet) {
return nil
}
func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
switch packet.packet_type {
case DATA:
agent.handle_data_storage(packet)
if agent.storage.main.queued >= 25 {
return agent.create_ack_report()
}
default:
return nil
}
return nil
}
func (agent *SRTManager) handle_data_storage(packet *Packet) {
if agent.storage == nil {
agent.storage = NewDatumStorage(packet)
} else {
agent.storage.NewDatum(packet)
}
if len(agent.storage.offshoots) != 0 {
agent.storage.Relink()
}
}
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
switch agent.state { switch agent.state {
case INDUCTION: case INDUCTION:
return agent.process_induction(packet), nil return agent.process_induction(packet)
case CONCLUSION: case CONCLUSION:
return agent.process_conclusion(packet), nil return agent.process_conclusion(packet), nil
case DATA_LOOP: case DATA_LOOP:
fmt.Println(packet) return agent.process_data(packet), nil
return nil, nil
default: default:
return nil, errors.New("State not implemented") return nil, errors.New("State not implemented")
} }

View file

@ -2,6 +2,7 @@ package srt
import ( import (
"net" "net"
"fmt"
) )
func NewServer(port string) (error) { func NewServer(port string) (error) {
@ -14,6 +15,11 @@ func NewServer(port string) (error) {
} }
func start(l net.PacketConn) { func start(l net.PacketConn) {
defer func() {
if r := recover(); r != nil {
fmt.Println(r)
}
}()
intake := NewIntake(l, 1) intake := NewIntake(l, 1)
for { for {
intake.Read() intake.Read()

View file

@ -21,7 +21,6 @@ func (tunnel *Tunnel) Start() {
*a = true *a = true
}(&(tunnel.broken)) }(&(tunnel.broken))
tunnel.state = NewSRTManager(tunnel.socket) tunnel.state = NewSRTManager(tunnel.socket)
tunnel.state.make_syn_cookie(tunnel.peer)
for { for {
packet, err := tunnel.ReadPacket() packet, err := tunnel.ReadPacket()
if err != nil { if err != nil {
@ -29,6 +28,7 @@ func (tunnel *Tunnel) Start() {
} }
response, err := tunnel.state.Process(packet) response, err := tunnel.state.Process(packet)
if err != nil { if err != nil {
fmt.Println(err)
tunnel.broken = true tunnel.broken = true
} }
if response != nil { if response != nil {
@ -37,10 +37,22 @@ func (tunnel *Tunnel) Start() {
} }
} }
func (tunnel *Tunnel) Shutdown() {
if tunnel.state != nil && tunnel.state.state > 1 {
packet := tunnel.state.create_basic_header()
packet.packet_type = SHUTDOWN
info := new(ControlHeader)
info.ctrl_type = 5
packet.header_info = info
tunnel.WritePacket(packet)
}
}
func (tunnel *Tunnel) WritePacket(packet *Packet) { func (tunnel *Tunnel) WritePacket(packet *Packet) {
buffer, err := MarshallPacket(packet, tunnel.state) buffer, err := MarshallPacket(packet, tunnel.state)
if err != nil { if err != nil {
tunnel.broken = true tunnel.broken = true
fmt.Println(err)
return return
} }
tunnel.socket.WriteTo(buffer, tunnel.peer) tunnel.socket.WriteTo(buffer, tunnel.peer)