diff --git a/srt/intake.go b/srt/intake.go index 28cd985..f226336 100644 --- a/srt/intake.go +++ b/srt/intake.go @@ -2,7 +2,6 @@ package srt import ( "net" - "fmt" ) type Intake struct { @@ -28,6 +27,8 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) { tunnel.socket = l tunnel.peer = peer tunnel.queue = make(chan []byte, 10) + intake.tunnels = append(intake.tunnels, tunnel) + go tunnel.Start() return tunnel } return nil @@ -60,8 +61,11 @@ func (intake *Intake) Read() { if tunnel == nil { return } - fmt.Println(string(intake.buffer[:n])) pkt := make([]byte, n) copy(pkt, intake.buffer[:n]) - tunnel.queue <- pkt + select { + case tunnel.queue <- pkt: + default: + tunnel.broken = true + } } diff --git a/srt/packet.go b/srt/packet.go index 4503151..6747047 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -1,16 +1,289 @@ package srt import ( + "errors" + "encoding/binary" ) -type Packet struct { - raw []byte +const ( + DATA uint8 = iota + HANDSHAKE + ACK + NAK + ACKACK + SHUTDOWN + DROP +) + +type ControlHeader struct { + ctrl_type uint16 + ctrl_subtype uint16 + tsi uint32 } -func MarshallPacket(packet *Packet) ([]byte, error) { - return packet.raw, nil +type DataHeader struct { + seq_num uint32 + msg_flags uint8 + msg_num uint32 +} + +type HandshakeCIF struct { + version uint32 + enc_field uint16 + ext_field uint16 + init_seq_num uint32 + mtu uint32 + max_flow uint32 + hs_type uint32 + sock_id uint32 + syn_cookie uint32 + peer_ip [16]byte + hs_extensions []*HandshakeExtension +} + +type HandshakeExtension struct { + ext_type uint16 + ext_len uint32 + ext_contents interface{} +} + +type HSEMSG struct { + version uint32 + flags uint32 + recv_delay uint16 + send_delay uint16 +} + +type KMMSG struct { + key_type uint8 + key_len uint8 + salt [16]byte + wrapped_key []byte +} + +type ACKCIF struct { + last_acked uint32 + rtt uint32 + var_rtt uint32 + buff_size uint32 + pkt_recv_rate uint32 + bw uint32 + rcv_rate uint32 +} + +type NACKCIF struct { + lost_pkts []*pckts_range +} + +type pckts_range struct { + start uint32 + end uint32 +} + +type DROPCIF struct { + to_drop pckts_range +} + +type Packet struct { + packet_type uint8 + timestamp uint32 + dest_sock uint32 + header_info interface{} + cif interface{} +} + +func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) { + info, ok_head := packet.header_info.(*DataHeader) + data, ok_data := packet.cif.([]byte) + if !ok_head || !ok_data { + return header, errors.New("data packet does not have data header") + } + binary.BigEndian.PutUint32(header[:4], info.seq_num) + head2 := (uint32(info.msg_flags) << 26) + info.msg_num + binary.BigEndian.PutUint32(header[4:8], head2) + return append(header, data...), nil +} + +func marshall_ctrl_packet(packet *Packet, header []byte) ([]byte, error) { + _, ok_head := packet.header_info.(*ControlHeader) + if !ok_head { + return header, errors.New("control packet does not have ctrl header") + } + header[0] = 0x80 + switch packet.packet_type { + case HANDSHAKE: + data, ok_data := packet.cif.(*HandshakeCIF) + if !ok_data { + return header, errors.New("Handshake has no data") + } + cif := marshall_hs_cif(data) + return append(header, cif...), nil + } + return header, errors.New("Control packet type not recognized") +} + +func marshall_hs_cif(data *HandshakeCIF) ([]byte) { + cif := make([]byte, 48) + binary.BigEndian.PutUint32(cif[:4], data.version) + binary.BigEndian.PutUint16(cif[4:6], data.enc_field) + binary.BigEndian.PutUint16(cif[6:8], data.ext_field) + binary.BigEndian.PutUint32(cif[8:12], data.init_seq_num) + binary.BigEndian.PutUint32(cif[12:16], data.mtu) + binary.BigEndian.PutUint32(cif[16:20], data.max_flow) + binary.BigEndian.PutUint32(cif[20:24], data.hs_type) + binary.BigEndian.PutUint32(cif[24:28], data.sock_id) + binary.BigEndian.PutUint32(cif[28:32], data.syn_cookie) + for i := 0; i < 16; i++ { + cif[32 + i] = data.peer_ip[i] + } + for _, extension := range data.hs_extensions { + ext_buff := make([]byte, int(extension.ext_len) + 4) + binary.BigEndian.PutUint16(ext_buff[:2], extension.ext_type) + binary.BigEndian.PutUint16(ext_buff[2:4], uint16(extension.ext_len / 4)) + switch extension.ext_type { + case 2: + contents := extension.ext_contents.(*HSEMSG) + binary.BigEndian.PutUint32(ext_buff[4:8], contents.version) + binary.BigEndian.PutUint32(ext_buff[8:12], contents.flags) + binary.BigEndian.PutUint16(ext_buff[12:14], contents.recv_delay) + binary.BigEndian.PutUint16(ext_buff[14:16], contents.send_delay) + case 4: + contents := extension.ext_contents.(*KMMSG) + binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type)) + binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200)) + binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4)) + for i := 0; i < 16; i++ { + ext_buff[20 + i] = contents.salt[i] + } + copy(ext_buff[36:], contents.wrapped_key) + default: + copy(ext_buff[4:], extension.ext_contents.([]byte)) + } + cif = append(cif, ext_buff...) + } + return cif +} + +func MarshallPacket(packet *Packet, agent *SRTManager) ([]byte, error) { + header := make([]byte, 16) + binary.BigEndian.PutUint32(header[8:12], packet.timestamp) + binary.BigEndian.PutUint32(header[12:16], packet.dest_sock) + if packet.packet_type == DATA { + return marshall_data_packet(packet, header) + } else { + return marshall_ctrl_packet(packet, header) + } +} + +func parse_data_packet(pkt *Packet, buffer []byte) (error) { + info := new(DataHeader) + info.seq_num = binary.BigEndian.Uint32(buffer[:4]) + info.msg_flags = buffer[4] >> 2 + info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff + + pkt.header_info = info + if len(buffer) > 16 { + data := make([]byte, len(buffer) - 16) + copy(data, buffer[16:]) + pkt.cif = data + return nil + } + return errors.New("Data Packet has no data") +} + +func parse_ctrl_packet(pkt *Packet, buffer []byte) (error) { + info := new(ControlHeader) + info.ctrl_type = binary.BigEndian.Uint16(buffer[:2]) & 0x7fff + info.ctrl_subtype = binary.BigEndian.Uint16(buffer[2:4]) + info.tsi = binary.BigEndian.Uint32(buffer[4:8]) + + pkt.header_info = info + + switch info.ctrl_type { + case 0: + pkt.packet_type = HANDSHAKE + if len(buffer) >= 64 { + cif := new(HandshakeCIF) + pkt.cif = cif + return parse_hs_cif(cif, buffer[16:]) + } + return errors.New("HS not long enough") + } + return errors.New("Unexpected control type") +} + +func parse_hs_cif(cif *HandshakeCIF, buffer []byte) (error) { + cif.version = binary.BigEndian.Uint32(buffer[:4]) + cif.enc_field = binary.BigEndian.Uint16(buffer[4:6]) + cif.ext_field = binary.BigEndian.Uint16(buffer[6:8]) + cif.init_seq_num = binary.BigEndian.Uint32(buffer[8:12]) + cif.mtu = binary.BigEndian.Uint32(buffer[12:16]) + cif.max_flow = binary.BigEndian.Uint32(buffer[16:20]) + cif.hs_type = binary.BigEndian.Uint32(buffer[20:24]) + cif.sock_id = binary.BigEndian.Uint32(buffer[24:28]) + cif.syn_cookie = binary.BigEndian.Uint32(buffer[28:32]) + for i := 0; i < 16; i++ { + cif.peer_ip[i] = buffer[32 + i] + } + extensions := buffer[48:] + for len(extensions) != 0 { + if len(extensions) <= 4 { + return errors.New("Extension present, shorter than header") + } + ext := new(HandshakeExtension) + ext.ext_type = binary.BigEndian.Uint16(extensions[:2]) + ext.ext_len = uint32(binary.BigEndian.Uint16(extensions[2:4])) * 4 + if len(extensions) <= 4 + int(ext.ext_len) { + return errors.New("Extension shorter than advertised") + } + switch ext.ext_type { + case 1: + content := new(HSEMSG) + content.version = binary.BigEndian.Uint32(extensions[4:8]) + content.flags = binary.BigEndian.Uint32(extensions[8:12]) + content.recv_delay = binary.BigEndian.Uint16(extensions[12:14]) + content.send_delay = binary.BigEndian.Uint16(extensions[14:16]) + ext.ext_contents = content + case 3: + content := new(KMMSG) + content.key_type = extensions[7] & 0x3 + content.key_len = extensions[19] + for i := 0; i < 4; i++ { + content.salt[i] = extensions[20 + i] + } + wrap_key_len := 4 + ext.ext_len - 24 + content.wrapped_key = make([]byte, wrap_key_len) + copy(content.wrapped_key, extensions[24:24 + wrap_key_len]) + ext.ext_contents = content + default: + content := make([]byte, ext.ext_len) + copy(content, extensions[4:4 + ext.ext_len]) + ext.ext_contents = content + } + cif.hs_extensions = append(cif.hs_extensions, ext) + extensions = extensions[4 + ext.ext_len:] + } + return nil } func ParsePacket(buffer []byte) (*Packet, error) { - return &Packet{buffer}, nil + if len(buffer) < 16 { + return nil, errors.New("packet too short") + } + pkt := new(Packet) + pkt.timestamp = binary.BigEndian.Uint32(buffer[8:12]) + pkt.dest_sock = binary.BigEndian.Uint32(buffer[12:16]) + + if buffer[0] >> 7 == 0 { + err := parse_data_packet(pkt, buffer) + if err != nil { + return pkt, err + } + } else { + err := parse_ctrl_packet(pkt, buffer) + if err != nil { + return pkt, err + } + } + return pkt, nil } diff --git a/srt/protocol.go b/srt/protocol.go index fabcda2..1973d72 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -1,7 +1,11 @@ package srt import ( + "time" + "net" + "crypto/sha256" "fmt" + "errors" ) const ( @@ -12,9 +16,74 @@ const ( type SRTManager struct { state uint8 + init time.Time + syn_cookie uint32 + socket net.PacketConn + ctrl_sock_peer uint32 } -func (proto *SRTManager) Decide(packet *Packet) (*Packet, error) { - fmt.Println(*packet) - return nil, nil +func NewSRTManager(l net.PacketConn) (*SRTManager) { + agent := new(SRTManager) + agent.init = time.Now() + agent.socket = l + return agent +} + +func (agent *SRTManager) create_induction_resp() (*Packet) { + packet := new(Packet) + packet.timestamp = uint32(time.Now().Sub(agent.init).Milliseconds()) + packet.dest_sock = agent.ctrl_sock_peer + packet.packet_type = HANDSHAKE + + info := new(ControlHeader) + packet.header_info = info + + cif := new(HandshakeCIF) + cif.version = 5 + cif.ext_field = 0x4a17 + cif.hs_type = 1 + cif.syn_cookie = agent.syn_cookie + cif.sock_id = 1 + cif.mtu = 1500 + cif.max_flow = 8192 + + ip := agent.socket.LocalAddr().(*net.UDPAddr).IP + for i := 0; i < len(ip); i++ { + cif.peer_ip[i] = ip[i] + } + + packet.cif = cif + + return packet +} + +func (agent *SRTManager) make_syn_cookie(peer net.Addr) { + t := uint32(time.Now().Unix()) >> 6 + s := sha256.New() + s.Write([]byte(peer.String() + fmt.Sprintf("%d", t))) + agent.syn_cookie = (agent.syn_cookie + t % 32) << 3 + for _, v := range s.Sum(nil)[29:] { + agent.syn_cookie = (agent.syn_cookie << 8) + uint32(v) + } +} + +func (agent *SRTManager) process_induction(packet *Packet) (*Packet) { + if packet.packet_type == HANDSHAKE { + hs_cif := packet.cif.(*HandshakeCIF) + if hs_cif.hs_type == 1 { + agent.state = 1 + agent.ctrl_sock_peer = hs_cif.sock_id + return agent.create_induction_resp() + } + } + return nil +} + +func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { + switch agent.state { + case 0: + return agent.process_induction(packet), nil + default: + return nil, errors.New("State not implemented") + } } diff --git a/srt/tunnel.go b/srt/tunnel.go index 8c4e4ff..1f6a514 100644 --- a/srt/tunnel.go +++ b/srt/tunnel.go @@ -2,6 +2,7 @@ package srt import ( "net" + "fmt" ) type Tunnel struct { @@ -13,14 +14,22 @@ type Tunnel struct { } func (tunnel *Tunnel) Start() { - tunnel.state = new(SRTManager) + defer func(a *bool) { + if r := recover(); r != nil { + fmt.Println(r) + } + *a = true + }(&(tunnel.broken)) + tunnel.state = NewSRTManager(tunnel.socket) + tunnel.state.make_syn_cookie(tunnel.peer) for { packet, err := tunnel.ReadPacket() if err != nil { tunnel.broken = true } - response, err := tunnel.state.Decide(packet) + response, err := tunnel.state.Process(packet) if err != nil { + fmt.Println("error in processing") tunnel.broken = true } if response != nil { @@ -30,7 +39,7 @@ func (tunnel *Tunnel) Start() { } func (tunnel *Tunnel) WritePacket(packet *Packet) { - buffer, err := MarshallPacket(packet) + buffer, err := MarshallPacket(packet, tunnel.state) if err != nil { tunnel.broken = true return