diff --git a/srt/packet.go b/srt/packet.go index 6747047..d1c0c94 100644 --- a/srt/packet.go +++ b/srt/packet.go @@ -208,6 +208,8 @@ func parse_ctrl_packet(pkt *Packet, buffer []byte) (error) { return parse_hs_cif(cif, buffer[16:]) } return errors.New("HS not long enough") + case 5: + return errors.New("Shutdown received") } return errors.New("Unexpected control type") } @@ -233,7 +235,7 @@ func parse_hs_cif(cif *HandshakeCIF, buffer []byte) (error) { ext := new(HandshakeExtension) ext.ext_type = binary.BigEndian.Uint16(extensions[:2]) ext.ext_len = uint32(binary.BigEndian.Uint16(extensions[2:4])) * 4 - if len(extensions) <= 4 + int(ext.ext_len) { + if len(extensions) < 4 + int(ext.ext_len) { return errors.New("Extension shorter than advertised") } switch ext.ext_type { diff --git a/srt/protocol.go b/srt/protocol.go index 1973d72..00c6549 100644 --- a/srt/protocol.go +++ b/srt/protocol.go @@ -29,10 +29,15 @@ func NewSRTManager(l net.PacketConn) (*SRTManager) { return agent } -func (agent *SRTManager) create_induction_resp() (*Packet) { +func (agent *SRTManager) create_basic_header() (*Packet) { packet := new(Packet) - packet.timestamp = uint32(time.Now().Sub(agent.init).Milliseconds()) + packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds()) packet.dest_sock = agent.ctrl_sock_peer + return packet +} + +func (agent *SRTManager) create_induction_resp() (*Packet) { + packet := agent.create_basic_header() packet.packet_type = HANDSHAKE info := new(ControlHeader) @@ -71,7 +76,7 @@ func (agent *SRTManager) process_induction(packet *Packet) (*Packet) { if packet.packet_type == HANDSHAKE { hs_cif := packet.cif.(*HandshakeCIF) if hs_cif.hs_type == 1 { - agent.state = 1 + agent.state = CONCLUSION agent.ctrl_sock_peer = hs_cif.sock_id return agent.create_induction_resp() } @@ -79,10 +84,62 @@ func (agent *SRTManager) process_induction(packet *Packet) (*Packet) { return nil } +func (agent *SRTManager) create_conclusion_resp() (*Packet) { + packet := agent.create_basic_header() + packet.packet_type = HANDSHAKE + + info := new(ControlHeader) + packet.header_info = info + + cif := new(HandshakeCIF) + cif.version = 5 + cif.ext_field = 0x1 + cif.sock_id = 1 + cif.mtu = 1500 + cif.max_flow = 8192 + + ip := agent.socket.LocalAddr().(*net.UDPAddr).IP + for i := 0; i < len(ip); i++ { + cif.peer_ip[i] = ip[i] + } + + hs_ext := new(HandshakeExtension) + hs_ext.ext_type = 2 + hs_ext.ext_len = 12 + hs_msg := new(HSEMSG) + hs_msg.flags = uint32(0x01 | 0x02 | 0x04 | 0x08 | 0x20) + hs_msg.version = uint32(0x00010000) + hs_msg.recv_delay = 120 + hs_msg.send_delay = 120 + hs_ext.ext_contents = hs_msg + + cif.hs_extensions = append(cif.hs_extensions, hs_ext) + + packet.cif = cif + + return packet +} + +func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) { + if packet.packet_type == HANDSHAKE { + hs_cif := packet.cif.(*HandshakeCIF) + if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie { + agent.state = DATA_LOOP + return agent.create_conclusion_resp() + } + } + return nil +} + func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { switch agent.state { - case 0: + case INDUCTION: return agent.process_induction(packet), nil + case CONCLUSION: + return agent.process_conclusion(packet), nil + case DATA_LOOP: + fmt.Println(packet) + return nil, nil default: return nil, errors.New("State not implemented") } diff --git a/srt/tunnel.go b/srt/tunnel.go index 1f6a514..05f53e4 100644 --- a/srt/tunnel.go +++ b/srt/tunnel.go @@ -29,7 +29,6 @@ func (tunnel *Tunnel) Start() { } response, err := tunnel.state.Process(packet) if err != nil { - fmt.Println("error in processing") tunnel.broken = true } if response != nil {