diff --git a/rtmp/amf/decode.go b/rtmp/amf/decode.go index b297a34..9c1108c 100644 --- a/rtmp/amf/decode.go +++ b/rtmp/amf/decode.go @@ -12,7 +12,6 @@ func DecodeAMF(data *[]byte) (AMFObj, error) { var byte_idx uint32 var root_obj_idx uint8 amf_root_obj := make(AMFObj) - for { top_level_obj, err := read_next(data, &byte_idx) if err != nil { @@ -20,6 +19,9 @@ func DecodeAMF(data *[]byte) (AMFObj, error) { } amf_root_obj[root_obj_idx] = top_level_obj root_obj_idx += 1 + if byte_idx == uint32(len(*data)) { + break + } } return amf_root_obj, nil } @@ -28,8 +30,9 @@ func read_bytes(data *[]byte, byte_idx *uint32, n uint32) ([]byte, error) { if int(*byte_idx + n) > len(*data) { return make([]byte, 0), errors.New("Read goes past end") } + read_slice := (*data)[*byte_idx:*byte_idx + n] *byte_idx += n - return (*data)[*byte_idx:*byte_idx + n], nil + return read_slice, nil } func read_number(data *[]byte, byte_idx *uint32) (float64, error) { @@ -40,20 +43,49 @@ func read_number(data *[]byte, byte_idx *uint32) (float64, error) { return math.Float64frombits(binary.BigEndian.Uint64(float_bytes)), nil } +func read_bool(data *[]byte, byte_idx *uint32) (bool, error) { + bool_byte, err := read_bytes(data, byte_idx, 1) + if err != nil { + return false, err + } + if bool_byte[0] > 1 { + return false, errors.New("bool byte must be 0 or 1") + } + return bool_byte[0] != 0, nil +} + +func read_string(data *[]byte, byte_idx *uint32) (string, error) { + string_len, err := read_bytes(data, byte_idx, 2) + if err != nil { + return "", err + } + string_bytes, err := read_bytes(data, byte_idx, uint32(binary.BigEndian.Uint16(string_len))) + if err != nil { + return "", err + } + return string(string_bytes), err +} + func read_next(data *[]byte, byte_idx *uint32) (interface{}, error) { data_type, err := read_bytes(data, byte_idx, 1) + var next_obj interface{} if err != nil { return nil, err } switch data_type[0] { case 0: - next_obj, err := read_number(data, byte_idx) - if err != nil { - return nil, err - } - return next_obj, nil + next_obj, err = read_number(data, byte_idx) + case 1: + next_obj, err = read_bool(data, byte_idx) + case 2: + next_obj, err = read_string(data, byte_idx) + default: + return nil, errors.New("Unhandled data type") } - return nil, errors.New("tmp") + if err != nil { + return nil, err + } + return next_obj, nil }