From 5eaac2e5520c09f00f1d68a45c7208561ac6809f Mon Sep 17 00:00:00 2001 From: Muaz Ahmad Date: Mon, 11 Mar 2024 14:00:21 +0500 Subject: [PATCH] wrapped all RTMP tcp reads in io.ReadFull --- rtmp/chunk.go | 16 ++++++++-------- rtmp/handshake.go | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/rtmp/chunk.go b/rtmp/chunk.go index 5784ad5..02a69e5 100644 --- a/rtmp/chunk.go +++ b/rtmp/chunk.go @@ -43,18 +43,18 @@ type ChunkBuffers struct { // reads the initial variable size header that defines the chunk's format and chunkstream id // 5.3.1.1 func read_basic_header(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (format uint8, csid uint32, err error) { - if _, err = conn.Read(chunk_bufs_ptr.fmt_csid_byte); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.fmt_csid_byte); err != nil { return } format = uint8(chunk_bufs_ptr.fmt_csid_byte[0] >> 6) // get first 2 bits for format 0-3 switch chunk_bufs_ptr.fmt_csid_byte[0] & 0x3f { // last 6 bits 0-63 case 0: // csid 0 is invalid, means true csid is the next byte, + 64 for the 6 bits prior - if _, err = conn.Read(chunk_bufs_ptr.csid_true[1:]); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.csid_true[1:]); err != nil { return } csid = uint32(chunk_bufs_ptr.csid_true[1]) + 64 case 1: // csid 1 is invalid, means true csid is in the next 2 bytes, reverse order (little endian) and the 64, reconstruct - if _, err = conn.Read(chunk_bufs_ptr.csid_true); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.csid_true); err != nil { return } csid = uint32(binary.LittleEndian.Uint16(chunk_bufs_ptr.csid_true)) + 64 @@ -66,7 +66,7 @@ func read_basic_header(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (format uint func read_time(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (time uint32, extended_time bool, err error) { chunk_bufs_ptr.time[0] = 0 - if _, err = conn.Read(chunk_bufs_ptr.time[1:]); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.time[1:]); err != nil { return } time = binary.BigEndian.Uint32(chunk_bufs_ptr.time) @@ -77,7 +77,7 @@ func read_time(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (time uint32, extend } func read_msg_len(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (msg_len uint32, err error) { - if _, err = conn.Read(chunk_bufs_ptr.msg_len[1:]); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.msg_len[1:]); err != nil { return } msg_len = binary.BigEndian.Uint32(chunk_bufs_ptr.msg_len) @@ -85,7 +85,7 @@ func read_msg_len(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (msg_len uint32, } func read_msg_typeid(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (msg_type uint8, err error) { - if _, err = conn.Read(chunk_bufs_ptr.msg_typeid); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.msg_typeid); err != nil { return } msg_type = chunk_bufs_ptr.msg_typeid[0] @@ -93,7 +93,7 @@ func read_msg_typeid(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (msg_type uint } func read_msg_streamid(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (msg_streamid uint32, err error) { - if _, err = conn.Read(chunk_bufs_ptr.msg_streamid); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.msg_streamid); err != nil { return } msg_streamid = binary.LittleEndian.Uint32(chunk_bufs_ptr.msg_streamid) @@ -101,7 +101,7 @@ func read_msg_streamid(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (msg_streami } func read_time_extd(conn net.Conn, chunk_bufs_ptr *ChunkBuffers) (time uint32, err error) { - if _, err = conn.Read(chunk_bufs_ptr.time); err != nil { + if _, err = io.ReadFull(conn, chunk_bufs_ptr.time); err != nil { return } time = binary.BigEndian.Uint32(chunk_bufs_ptr.time) diff --git a/rtmp/handshake.go b/rtmp/handshake.go index 8be432a..a36bbfb 100644 --- a/rtmp/handshake.go +++ b/rtmp/handshake.go @@ -1,6 +1,7 @@ package rtmp import ( + "io" "net" "time" "encoding/binary" @@ -17,7 +18,7 @@ func DoHandshake(conn net.Conn) (hs_success bool) { // force handshake to finish in under 15 seconds (aribtrary) or throw an error conn.SetDeadline(time.Now().Add(15 * time.Second)) - if _, err := conn.Read(C0C1C2); err != nil || C0C1C2[0] != 3 { + if _, err := io.ReadFull(conn, C0C1C2); err != nil || C0C1C2[0] != 3 { return } copy(C0C1C2[1:1536], S0S1S2[1+1536:]) @@ -26,7 +27,7 @@ func DoHandshake(conn net.Conn) (hs_success bool) { if _, err := conn.Write(S0S1S2); err != nil { // specs say only send S0S1 and wait for C2 before sending S2, obs doesn't care apparently return } - if _, err := conn.Read(C0C1C2[1+1536:]); err != nil { + if _, err := io.ReadFull(conn, C0C1C2[1+1536:]); err != nil { return } hs_success = true