diff --git a/rtmp/chunk.go b/rtmp/chunk.go index a7b0569..dd2e7bf 100644 --- a/rtmp/chunk.go +++ b/rtmp/chunk.go @@ -13,7 +13,6 @@ type chnk_stream struct { } type message struct { -// msg_strm_id uint32 data []byte curr_bytes_read uint32 timestamp uint32 @@ -21,80 +20,166 @@ type message struct { msg_len uint32 } -func OpenStreamsMapInit() (map[uint32]*chnk_stream, map[uint32]*message) { +type chunk_bufs struct { + time []byte + msg_len []byte + msg_typeid []byte + msg_streamid []byte + fmt_csid_byte []byte + csid_true []byte +} + +func new_chunk_bufs() (*chunk_bufs) { + new_chunk_bufs := chunk_bufs{ + make([]byte, 4), + make([]byte, 4), + make([]byte, 1), + make([]byte, 4), + make([]byte, 1), + make([]byte, 2), + } + return &new_chunk_bufs +} + +func OpenStreamsMapInit() (map[uint32]*chnk_stream, map[uint32]*message, *chunk_bufs) { open_chnkstrms := make(map[uint32]*chnk_stream) open_msgs := make(map[uint32]*message) - return open_chnkstrms, open_msgs + chunk_bufs_ptr := new_chunk_bufs() + return open_chnkstrms, open_msgs, chunk_bufs_ptr } -func read_basic_header(conn net.Conn) (uint8, uint32, error) { - fmt_csid_byte := make([]byte, 1) - if _, err := conn.Read(fmt_csid_byte); err != nil { - return 0, 0, err +func read_basic_header(conn net.Conn, chunk_bufs_ptr *chunk_bufs) (format uint8, csid uint32, err error) { + if _, err = conn.Read(chunk_bufs_ptr.fmt_csid_byte); err != nil { + return } - format := uint8(fmt_csid_byte[0] >> 6) - var csid uint32 - switch fmt_csid_byte[0] & 0x3f { + format = uint8(chunk_bufs_ptr.fmt_csid_byte[0] >> 6) + switch chunk_bufs_ptr.fmt_csid_byte[0] & 0x3f { case 0: - csid_true := make([]byte, 1) - if _, err := conn.Read(csid_true); err != nil { - return 0, 0, err + if _, err = conn.Read(chunk_bufs_ptr.csid_true[1:]); err != nil { + return } - csid = uint32(csid_true[0]) + 64 + csid = uint32(chunk_bufs_ptr.csid_true[1]) + 64 case 1: - csid_true := make([]byte, 2) - if _, err := conn.Read(csid_true); err != nil { - return 0, 0, err + if _, err = conn.Read(chunk_bufs_ptr.csid_true); err != nil { + return } - csid = uint32(binary.LittleEndian.Uint16(csid_true)) + 64 + csid = uint32(binary.LittleEndian.Uint16(chunk_bufs_ptr.csid_true)) + 64 default: - csid = uint32(fmt_csid_byte[0] & 0x3f) + csid = uint32(chunk_bufs_ptr.fmt_csid_byte[0] & 0x3f) } - return format, csid, nil + return } -func read_message_header_0(conn net.Conn, chnk_stream_ptr *chnk_stream, msg_ptr *message) (error) { - var extended_time bool - timestamp := make([]byte, 4) - if _, err := conn.Read(timestamp[1:]); err != nil { - return err +func read_time(conn net.Conn, chunk_bufs_ptr *chunk_bufs) (time uint32, extended_time bool, err error) { + chunk_bufs_ptr.time[0] = 0 + if _, err = conn.Read(chunk_bufs_ptr.time[1:]); err != nil { + return } - chnk_stream_ptr.timestamp = binary.BigEndian.Uint32(timestamp) - - if chnk_stream_ptr.timestamp ^ 0xffffff == 0 { + time = binary.BigEndian.Uint32(chunk_bufs_ptr.time) + if time ^ 0xffffff == 0 { extended_time = true } + return +} - msg_len := make([]byte, 4) - if _, err := conn.Read(msg_len[1:]); err != nil { +func read_msg_len(conn net.Conn, chunk_bufs_ptr *chunk_bufs) (msg_len uint32, err error) { + if _, err = conn.Read(chunk_bufs_ptr.msg_len[1:]); err != nil { + return + } + msg_len = binary.BigEndian.Uint32(chunk_bufs_ptr.msg_len) + return +} + +func read_msg_typeid(conn net.Conn, chunk_bufs_ptr *chunk_bufs) (msg_type uint8, err error) { + if _, err = conn.Read(chunk_bufs_ptr.msg_typeid); err != nil { + return + } + msg_type = chunk_bufs_ptr.msg_typeid[0] + return +} + +func read_msg_streamid(conn net.Conn, chunk_bufs_ptr *chunk_bufs) (msg_streamid uint32, err error) { + if _, err = conn.Read(chunk_bufs_ptr.msg_streamid); err != nil { + return + } + msg_streamid = binary.BigEndian.Uint32(chunk_bufs_ptr.msg_streamid) + return +} + +func read_time_extd(conn net.Conn, chunk_bufs_ptr *chunk_bufs) (time uint32, err error) { + if _, err = conn.Read(chunk_bufs_ptr.time); err != nil { + return + } + time = binary.BigEndian.Uint32(chunk_bufs_ptr.time) + return +} + +func read_message_header_0(conn net.Conn, chnk_stream_ptr *chnk_stream, msg_ptr *message, chunk_bufs_ptr *chunk_bufs) (error) { + var extended_time bool + var err error + chnk_stream_ptr.timestamp, extended_time, err = read_time(conn, chunk_bufs_ptr) + if err != nil { + return err + } + + msg_ptr.msg_len, err = read_msg_len(conn, chunk_bufs_ptr) + if err != nil { return err } - msg_ptr.msg_len = binary.BigEndian.Uint32(msg_len) msg_ptr.data = make([]byte, msg_ptr.msg_len) - msg_typeid := make([]byte, 1) - if _, err := conn.Read(msg_typeid); err != nil { + msg_ptr.msg_type, err = read_msg_typeid(conn, chunk_bufs_ptr) + if err != nil { return err } - msg_ptr.msg_type = msg_typeid[0] - msg_streamid := make([]byte, 4) - if _, err := conn.Read(msg_streamid); err != nil { + chnk_stream_ptr.last_msg_strm_id, err = read_msg_streamid(conn, chunk_bufs_ptr) + if err != nil { return err } - chnk_stream_ptr.last_msg_strm_id = binary.BigEndian.Uint32(msg_streamid) if extended_time { - if _, err := conn.Read(timestamp); err != nil { + chnk_stream_ptr.timestamp, err = read_time_extd(conn, chunk_bufs_ptr) + if err != nil { return err } - chnk_stream_ptr.timestamp = binary.BigEndian.Uint32(timestamp) } msg_ptr.timestamp = chnk_stream_ptr.timestamp return nil } +func read_message_header_1(conn net.Conn, chnk_stream_ptr *chnk_stream, msg_ptr *message, chunk_bufs_ptr *chunk_bufs) (error) { + var extended_time bool + var err error + + chnk_stream_ptr.timedelta, extended_time, err = read_time(conn, chunk_bufs_ptr) + if err != nil { + return err + } + msg_ptr.msg_len, err = read_msg_len(conn, chunk_bufs_ptr) + if err != nil { + return err + } + msg_ptr.data = make([]byte, msg_ptr.msg_len) + + msg_ptr.msg_type, err = read_msg_typeid(conn, chunk_bufs_ptr) + if err != nil { + return err + } + + if extended_time { + chnk_stream_ptr.timedelta, err = read_time_extd(conn, chunk_bufs_ptr) + if err != nil { + return err + } + } + chnk_stream_ptr.timestamp += chnk_stream_ptr.timedelta + msg_ptr.timestamp = chnk_stream_ptr.timestamp + + return nil +} + func read_chunk_data(conn net.Conn, msg_ptr *message, chnk_size uint32) (error) { bytes_left := msg_ptr.msg_len - msg_ptr.curr_bytes_read var buffer_end uint32 @@ -112,32 +197,46 @@ func read_chunk_data(conn net.Conn, msg_ptr *message, chnk_size uint32) (error) } -func ReadChunk(conn net.Conn, open_chnkstrms map[uint32]*chnk_stream, open_msgs map[uint32]*message, chnk_size uint32) (*message, error){ +func ReadChunk(conn net.Conn, open_chnkstrms map[uint32]*chnk_stream, open_msgs map[uint32]*message, chnk_size uint32, chunk_bufs_ptr *chunk_bufs) (*message, error){ conn.SetDeadline(time.Now().Add(10 * time.Second)) + + var chnkstream_ptr *chnk_stream + var msg_ptr *message - format, csid, err := read_basic_header(conn) + format, csid, err := read_basic_header(conn, chunk_bufs_ptr) if err != nil { return nil, err } switch format { case 0: - chnkstream_ptr := new(chnk_stream) + chnkstream_ptr = new(chnk_stream) open_chnkstrms[csid] = chnkstream_ptr - msg_ptr := new(message) + msg_ptr = new(message) - if err := read_message_header_0(conn, chnkstream_ptr, msg_ptr); err != nil { + if err := read_message_header_0(conn, chnkstream_ptr, msg_ptr, chunk_bufs_ptr); err != nil { return nil, err } - if err := read_chunk_data(conn, msg_ptr, chnk_size); err != nil { + case 1: + chnkstream_ptr = open_chnkstrms[csid] + msg_ptr = new(message) + + if err := read_message_header_1(conn, chnkstream_ptr, msg_ptr, chunk_bufs_ptr); err != nil { return nil, err } - if msg_ptr.msg_len > chnk_size { - open_msgs[chnkstream_ptr.last_msg_strm_id] = msg_ptr - return nil, nil - } else { - return msg_ptr, nil - } + } + + + if err := read_chunk_data(conn, msg_ptr, chnk_size); err != nil { + return nil, err + } + + if msg_ptr.curr_bytes_read < msg_ptr.msg_len { + open_msgs[chnkstream_ptr.last_msg_strm_id] = msg_ptr + return nil, nil + } else { + delete(open_msgs, chnkstream_ptr.last_msg_strm_id) + return msg_ptr, nil } conn.SetDeadline(time.Time{}) diff --git a/rtmp/connect.go b/rtmp/connect.go index 002521b..29abca8 100644 --- a/rtmp/connect.go +++ b/rtmp/connect.go @@ -6,13 +6,13 @@ import ( ) func NegotiateConnect(conn net.Conn) (bool) { - open_chnkstrms, open_msgs := OpenStreamsMapInit() - full_msg_ptr, err := ReadChunk(conn, open_chnkstrms, open_msgs, 4096) + open_chnkstrms, open_msgs, chunk_bufs := OpenStreamsMapInit() + full_msg_ptr, err := ReadChunk(conn, open_chnkstrms, open_msgs, 4096, chunk_bufs) if err != nil { return false } fmt.Printf("%08b\n", full_msg_ptr.data) - full_msg_ptr, err = ReadChunk(conn, open_chnkstrms, open_msgs, 4096) + full_msg_ptr, err = ReadChunk(conn, open_chnkstrms, open_msgs, 4096, chunk_bufs) if err != nil { return false }