Compare commits

..

10 commits

10 changed files with 332 additions and 36 deletions

View file

@ -2,17 +2,17 @@
Extremely basic live stream server Extremely basic live stream server
Currently implements a degraded subset of RTMP and SRT (un-encrypted) for ingest. Currently implements a degraded subset of RTMP and SRT for ingest. Encryption for SRT is implemented with AES-CTR with hardcoded passphrase "srttestpass". Only 16-byte keys have been tested, but nothing should stop the implementation from having problems in principle
Uses the std lib http server implementation for the http serving side. Uses the std lib http server implementation for the http serving side.
**Not intended for actual use**. The stream key use is not secure and is used to handle directories without a user db system, than to provide auth **Not intended for actual use**. The stream key use is not secure and is used to handle directories without a user db system, than to provide auth. Same goes for the SRT passphrase. Also just accepts connections so will get DDOS'd immediately.
Limits to a single stream at a time, mostly for the lack of db to handle connections and user information rather than concurrency problems. Limits to a single stream at a time, mostly for the lack of db to handle connections and user information rather than concurrency problems.
Currently always transcodes to vp9 + opus, segments to fragmented mp4. Creates one segment playlist, no manifest. Uses ffmpeg Currently always transcodes to vp9 + opus, segments to fragmented mp4. Creates one segment playlist, no manifest. Uses ffmpeg
HTTP streaming relies on hls-player-js. Will be broken for standard hls players until I figure out how to modify the `EXT-X-MAP:URI` field to prepend a path prefix without changing directories. HTTP streaming relies on hls-player-js. Will be broken for standard hls players until I figure out how to modify the `EXT-X-MAP:URI` field to prepend a path prefix without changing directories. Or finish the transcoder project
Currently produces no logs nor debug info. Will just abandon a connection if there is a problem. Will not send any RTMP replies since flash server docs seem dead and abort messages are netStream commands. Currently produces no logs nor debug info. Will just abandon a connection if there is a problem. Will not send any RTMP replies since flash server docs seem dead and abort messages are netStream commands.

14
main.go
View file

@ -4,19 +4,25 @@ import (
"stream_server/rtmp" "stream_server/rtmp"
"stream_server/http" "stream_server/http"
"stream_server/srt" "stream_server/srt"
"flag"
) )
const ( const (
SRVTYPE_RTMP uint8 = iota SRVTYPE_RTMP uint = iota
SRVTYPE_SRT SRVTYPE_SRT
) )
func main() { func main() {
err := NewIngestServer(SRVTYPE_SRT, "7878") ingest_type := flag.Uint("ingest_type", 0, "Ingest server type, 0 for RTMP, 1 for SRT")
ingest_port := flag.String("ingest_port", "7878", "Port for stream intake")
http_port := flag.String("http_port", "7879", "Port to serve http requests")
flag.Parse()
err := NewIngestServer(*ingest_type, *ingest_port)
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = NewHTTPServer("7879") err = NewHTTPServer(*http_port)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -24,7 +30,7 @@ func main() {
} }
} }
func NewIngestServer(srvr_type uint8, port string) (error) { func NewIngestServer(srvr_type uint, port string) (error) {
var err error var err error
switch srvr_type { switch srvr_type {
case 0: case 0:

167
srt/crypt.go Normal file
View file

@ -0,0 +1,167 @@
package srt
import (
"crypto/aes"
"crypto/sha1"
"crypto/hmac"
"hash"
"math"
"encoding/binary"
"errors"
"crypto/cipher"
)
type CryptHandler struct {
salt [16]byte
key_len uint8
odd_sek cipher.Block
even_sek cipher.Block
}
// init will apply to KM message immediately
func NewCryptHandler(passphrase string, km_msg *KMMSG) (*CryptHandler) {
crypt := new(CryptHandler)
crypt.key_len = km_msg.key_len
crypt.salt = km_msg.salt
ok := crypt.Unwrap(km_msg.wrapped_key, passphrase, km_msg.key_type)
if !ok {
return nil
}
return crypt
}
// PRF as defined in RFC 8018
func PRF(h hash.Hash, input []byte) ([]byte) {
h.Reset()
h.Write(input)
return h.Sum(nil)
}
// gets KEK from passphrase and salt according to fixed SRT iterations and algo from RFC doc
// see RFC 8018 for implementation details
func SRT_PBKDF2(passphrase string, salt []byte, dklen uint8) ([]byte) {
prf := hmac.New(sha1.New, []byte(passphrase))
hlen := prf.Size()
l := int(math.Ceil(float64(dklen) / float64(hlen)))
r := int(dklen) - (l - 1) * hlen
key := make([]byte, 0)
for block := 1; block <= l; block++ {
U := make([]byte, hlen)
T := make([]byte, hlen)
block_i := make([]byte, 4)
binary.BigEndian.PutUint32(block_i, uint32(block))
U = PRF(prf, append(salt, block_i...))
copy(T, U)
// skip one iter since done above
for n := 1; n < 2048; n++ {
U = PRF(prf, U)
for x := range T {
T[x] ^= U[x]
}
}
// final block may not use entire SHA output, still need full during computation
if block == l {
T = T[:r]
}
// final key is appended sequence of all blocks computed independently
key = append(key, T...)
}
return key
}
// See RFC 3394, inplace implementation
func AES_UNWRAP(key []byte, wrapped []byte) ([]byte, error) {
seks := make([]byte, 0) // bytes past IV
cipher, err := aes.NewCipher(key)
if err != nil {
return seks, err
}
A := wrapped[:8] // IV bytes
n := len(wrapped) / 8 - 1
R := make([][]byte, n) // actual message (SEKs)
for i := range R {
R[i] = wrapped[(i + 1) * 8: (i + 2) * 8]
}
for j := 5; j >= 0; j-- {
for i := n; i > 0; i-- {
t := make([]byte, 8)
binary.BigEndian.PutUint64(t, uint64(n * j + i))
for k := range t {
t[k] ^= A[k]
}
B := make([]byte, 16)
cipher.Decrypt(B, append(t, R[i - 1]...))
copy(A, B[:8])
copy(R[i - 1], B[8:])
}
}
// SRT uses default IV, 8 repeating bytes of 0xa6 prepended in wrap, check if
// preserved in unwrap
for i := range A {
if A[i] != 0xa6 {
return seks, errors.New("IV not default")
}
}
// R is 8 byte blocks, keys can be 16-32 bytes, prepend all together and
// let wrappers figure it out
for _, v := range R {
seks = append(seks, v...)
}
return seks, nil
}
// unwrap and store SEK ciphers, key_type defined as KK 2-bit value in Key Material from SRT docs
func (crypt *CryptHandler) Unwrap(wrapped_key []byte, passphrase string, key_type uint8) (bool) {
kek := SRT_PBKDF2(passphrase, crypt.salt[8:], crypt.key_len)
// need a copy since original will be sent back
wrapped_copy := make([]byte, len(wrapped_key))
copy(wrapped_copy, wrapped_key)
seks, err := AES_UNWRAP(kek, wrapped_copy)
// either unwrap fails or key len does not match expected (1 or 2 SEKs len identical)
if err != nil || len(seks) % int(crypt.key_len) != 0 {
return false
}
// always have one SEK, if more bytes (second key) and peer did not send 2 keys
// something is wrong
sek_1 := seks[:crypt.key_len]
if len(seks) > int(crypt.key_len) && key_type != 3 {
return false
}
switch key_type {
case 1:
crypt.even_sek, _ = aes.NewCipher(sek_1)
case 2:
crypt.odd_sek, _ = aes.NewCipher(sek_1)
case 3:
sek_2 := seks[crypt.key_len:]
crypt.even_sek, _ = aes.NewCipher(sek_1)
crypt.odd_sek, _ = aes.NewCipher(sek_2)
default:
return false
}
return true
}
func (crypt *CryptHandler) Decrypt(pkt *Packet) {
var sek cipher.Block
switch pkt.header_info.(*DataHeader).msg_flags & 0x6 {
case 2:
sek = crypt.even_sek
case 4:
sek = crypt.odd_sek
default:
return
}
IV := make([]byte, crypt.key_len)
binary.BigEndian.PutUint32(IV[10:14], pkt.header_info.(*DataHeader).seq_num)
for i := 0; i < 14; i++ {
IV[i] ^= crypt.salt[i]
}
ctr := cipher.NewCTR(sek, IV)
ctr.XORKeyStream(pkt.cif.([]byte), pkt.cif.([]byte))
}

View file

@ -13,12 +13,20 @@ type Datum struct {
next *Datum next *Datum
} }
// linked chain of datums, specifically to store continuous segments
// when a packet is missed, start and end can be used to generate lost
// packet reports, and easily link when missing is received
// 1..2..3..4 6..7..8..9
// chain_1 chain_2
// nack: get 5
type DatumLink struct { type DatumLink struct {
queued int queued int // remove eventually, was to be used for ACK recv rate calcs, not needed
root *Datum root *Datum
end *Datum end *Datum
} }
// data type and function to allow sorting so order can be ignored during
// linking since each is sequential on outset
type chains []*DatumLink type chains []*DatumLink
func (c chains) Len() (int) { func (c chains) Len() (int) {
@ -29,6 +37,7 @@ func (c chains) Swap(i, j int) {
c[i], c[j] = c[j], c[i] c[i], c[j] = c[j], c[i]
} }
// chain_1 is less then chain_2 when chain_1 ends before chain_2 starts
func (c chains) Less(i, j int) (bool) { func (c chains) Less(i, j int) (bool) {
x_1 := c[i].end.seq_num x_1 := c[i].end.seq_num
x_2 := c[j].root.seq_num x_2 := c[j].root.seq_num
@ -43,6 +52,7 @@ type DatumStorage struct {
offshoots chains offshoots chains
} }
// append new packet to end of buffer
func (buffer *DatumLink) NewDatum(pkt *Packet) { func (buffer *DatumLink) NewDatum(pkt *Packet) {
datum := new(Datum) datum := new(Datum)
datum.seq_num = pkt.header_info.(*DataHeader).seq_num datum.seq_num = pkt.header_info.(*DataHeader).seq_num
@ -53,6 +63,7 @@ func (buffer *DatumLink) NewDatum(pkt *Packet) {
buffer.end = datum buffer.end = datum
} }
// create a new datumlink with root and end at the given packet
func NewDatumLink(pkt *Packet) (*DatumLink) { func NewDatumLink(pkt *Packet) (*DatumLink) {
buffer := new(DatumLink) buffer := new(DatumLink)
root_datum := new(Datum) root_datum := new(Datum)
@ -66,12 +77,14 @@ func NewDatumLink(pkt *Packet) (*DatumLink) {
return buffer return buffer
} }
// initialize storage with the given data packet, in the main chain
func NewDatumStorage(packet *Packet) (*DatumStorage) { func NewDatumStorage(packet *Packet) (*DatumStorage) {
storage := new(DatumStorage) storage := new(DatumStorage)
storage.main = NewDatumLink(packet) storage.main = NewDatumLink(packet)
return storage return storage
} }
// purge all packets in the main chain except the last for future linkage
func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) { func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
curr_datum := storage.main.root curr_datum := storage.main.root
seq_num_end := storage.main.end.seq_num seq_num_end := storage.main.end.seq_num
@ -86,6 +99,7 @@ func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
return nil return nil
} }
// check if the given sequence number should already be inside the given buffer
func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) { func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) {
start := buffer.root.seq_num start := buffer.root.seq_num
end := buffer.end.seq_num end := buffer.end.seq_num
@ -98,6 +112,8 @@ func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) {
return true return true
} }
// check if the given seq num lies before the given buffer starts
// buffer is After seq num?
func (buffer *DatumLink) After(new_seq_num uint32) (bool) { func (buffer *DatumLink) After(new_seq_num uint32) (bool) {
start := buffer.root.seq_num start := buffer.root.seq_num
if serial_less(new_seq_num, start, 31) { if serial_less(new_seq_num, start, 31) {
@ -106,6 +122,8 @@ func (buffer *DatumLink) After(new_seq_num uint32) (bool) {
return false return false
} }
// check if the given seq num lies after the given buffer starts
// buffer is Before seq num?
func (buffer *DatumLink) Before(new_seq_num uint32) (bool) { func (buffer *DatumLink) Before(new_seq_num uint32) (bool) {
end := buffer.end.seq_num end := buffer.end.seq_num
if serial_less(end, new_seq_num, 31) { if serial_less(end, new_seq_num, 31) {
@ -114,6 +132,7 @@ func (buffer *DatumLink) Before(new_seq_num uint32) (bool) {
return false return false
} }
// add the given packet to the appropriate buffer or add a new offshoot buffer
func (storage *DatumStorage) NewDatum(pkt *Packet) { func (storage *DatumStorage) NewDatum(pkt *Packet) {
new_pkt_num := pkt.header_info.(*DataHeader).seq_num new_pkt_num := pkt.header_info.(*DataHeader).seq_num
prev_num := (new_pkt_num - 1) % uint32(1 << 31) prev_num := (new_pkt_num - 1) % uint32(1 << 31)
@ -139,12 +158,16 @@ func (storage *DatumStorage) NewDatum(pkt *Packet) {
} }
} }
// Link 2 buffers to each other
func (buffer *DatumLink) Link(buffer_next *DatumLink) { func (buffer *DatumLink) Link(buffer_next *DatumLink) {
buffer.end.next = buffer_next.root buffer.end.next = buffer_next.root
buffer.end = buffer_next.end buffer.end = buffer_next.end
buffer.queued += buffer_next.queued buffer.queued += buffer_next.queued
} }
// check if the start of the buffer_next is sequentially next to the end of buffer
// or it is too late to get the true next packets to link and the packets must be linked
// given a 31-bit wrap around
func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_time uint32) (bool) { func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_time uint32) (bool) {
seq_1 := buffer.end.seq_num seq_1 := buffer.end.seq_num
seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31)) seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31))
@ -155,6 +178,7 @@ func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_ti
return false return false
} }
// Given the current storage state, check what packets need to added to fully link main
func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) { func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
if len(storage.offshoots) == 0 { if len(storage.offshoots) == 0 {
return nil, false return nil, false
@ -173,13 +197,14 @@ func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
return cif, true return cif, true
} }
// Try to relink all chains wherever possible or necessary due to TLPKTDROP
func (storage *DatumStorage) Relink(curr_time uint32) { func (storage *DatumStorage) Relink(curr_time uint32) {
sort.Sort(storage.offshoots) sort.Sort(storage.offshoots)
buffer := storage.main buffer := storage.main
i := 0 i := 0
for i < len(storage.offshoots) { for i < len(storage.offshoots) {
if check_append_serial_next(buffer, storage.offshoots[i], curr_time) { if check_append_serial_next(buffer, storage.offshoots[i], curr_time) {
storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) // nuke the chain since it is not needed anymore
} else { } else {
buffer = storage.offshoots[i] buffer = storage.offshoots[i]
i++ i++
@ -187,6 +212,7 @@ func (storage *DatumStorage) Relink(curr_time uint32) {
} }
} }
// check if a is less than b under serial arithmetic (modulo operations)
func serial_less(a uint32, b uint32, bits int) (bool) { func serial_less(a uint32, b uint32, bits int) (bool) {
if (a < b && b-a < (1 << (bits - 1))) || (a > b && a-b > (1 << (bits - 1))) { if (a < b && b-a < (1 << (bits - 1))) || (a > b && a-b > (1 << (bits - 1))) {
return true return true

View file

@ -15,7 +15,7 @@ func NewIntake(l net.PacketConn, max_conns int) (*Intake) {
intake := new(Intake) intake := new(Intake)
intake.max_conns = max_conns intake.max_conns = max_conns
intake.tunnels = make([]*Tunnel, 0) intake.tunnels = make([]*Tunnel, 0)
intake.buffer = make([]byte, 1500) intake.buffer = make([]byte, 1500) // each packet is restricted to a max size of 1500
intake.socket = l intake.socket = l
return intake return intake
@ -26,9 +26,9 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
tunnel := new(Tunnel) tunnel := new(Tunnel)
tunnel.socket = l tunnel.socket = l
tunnel.peer = peer tunnel.peer = peer
tunnel.queue = make(chan []byte, 10) tunnel.queue = make(chan []byte, 10) // packet buffer, will cause packet loss if low
intake.tunnels = append(intake.tunnels, tunnel) intake.tunnels = append(intake.tunnels, tunnel)
go tunnel.Start() go tunnel.Start() // start the tunnel SRT processing
return tunnel return tunnel
} }
return nil return nil
@ -37,6 +37,7 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
var tunnel *Tunnel var tunnel *Tunnel
for i := 0; i < len(intake.tunnels); i++ { for i := 0; i < len(intake.tunnels); i++ {
// check if tunnels are broken and remove
if intake.tunnels[i].broken { if intake.tunnels[i].broken {
intake.tunnels[i].Shutdown() intake.tunnels[i].Shutdown()
intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...) intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...)
@ -47,6 +48,9 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
tunnel = intake.tunnels[i] tunnel = intake.tunnels[i]
} }
} }
// if no tunnel was found, make one
// should be after conclusion handshake, but wanted to keep all protocol
// related actions separate from UDP handling
if tunnel == nil { if tunnel == nil {
tunnel = intake.NewTunnel(intake.socket, peer) tunnel = intake.NewTunnel(intake.socket, peer)
} }
@ -58,14 +62,16 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
func (intake *Intake) Read() { func (intake *Intake) Read() {
n, peer, err := intake.socket.ReadFrom(intake.buffer) n, peer, err := intake.socket.ReadFrom(intake.buffer)
if err != nil { if err != nil {
return return // ignore UDP errors
} }
// find the SRT/UDT tunnel corresponding to the given peer
tunnel := intake.get_tunnel(peer) tunnel := intake.get_tunnel(peer)
if tunnel == nil { if tunnel == nil {
return return
} }
pkt := make([]byte, n) pkt := make([]byte, n)
copy(pkt, intake.buffer[:n]) copy(pkt, intake.buffer[:n])
// send a copy to the corresponding tunnels packet queue if not full
select { select {
case tunnel.queue <- pkt: case tunnel.queue <- pkt:
default: default:

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
) )
// arbitrary indexing
const ( const (
DATA uint8 = iota DATA uint8 = iota
HANDSHAKE HANDSHAKE
@ -15,6 +16,7 @@ const (
DROP DROP
) )
// see SRT protocol RFC for information
type ControlHeader struct { type ControlHeader struct {
ctrl_type uint16 ctrl_type uint16
ctrl_subtype uint16 ctrl_subtype uint16
@ -80,10 +82,12 @@ type pckts_range struct {
end uint32 end uint32
} }
// should be safe to ignore
type DROPCIF struct { type DROPCIF struct {
to_drop pckts_range to_drop pckts_range
} }
// header and cif are interfaces to allow easier typing, fitting above structs
type Packet struct { type Packet struct {
packet_type uint8 packet_type uint8
timestamp uint32 timestamp uint32
@ -92,11 +96,12 @@ type Packet struct {
cif interface{} cif interface{}
} }
// completely pointless since only implementing receiver, here anyway
func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) { func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) {
info, ok_head := packet.header_info.(*DataHeader) info, ok_head := packet.header_info.(*DataHeader)
data, ok_data := packet.cif.([]byte) data, ok_data := packet.cif.([]byte)
if !ok_head || !ok_data { if !ok_head || !ok_data {
return header, errors.New("data packet does not have data header") return header, errors.New("data packet does not have data header or data")
} }
binary.BigEndian.PutUint32(header[:4], info.seq_num) binary.BigEndian.PutUint32(header[:4], info.seq_num)
head2 := (uint32(info.msg_flags) << 26) + info.msg_num head2 := (uint32(info.msg_flags) << 26) + info.msg_num
@ -162,6 +167,9 @@ func marshall_nack_cif(data *NACKCIF) ([]byte) {
return loss_list_bytes return loss_list_bytes
} }
// locations and length are determined by protocol,
// no real point abstracting that
// could be cleaner with reflects, but added work for very little real gain
func marshall_ack_cif(data *ACKCIF) ([]byte) { func marshall_ack_cif(data *ACKCIF) ([]byte) {
cif := make([]byte, 28) cif := make([]byte, 28)
binary.BigEndian.PutUint32(cif[:4], data.last_acked) binary.BigEndian.PutUint32(cif[:4], data.last_acked)
@ -175,6 +183,7 @@ func marshall_ack_cif(data *ACKCIF) ([]byte) {
return cif return cif
} }
// same as above
func marshall_hs_cif(data *HandshakeCIF) ([]byte) { func marshall_hs_cif(data *HandshakeCIF) ([]byte) {
cif := make([]byte, 48) cif := make([]byte, 48)
binary.BigEndian.PutUint32(cif[:4], data.version) binary.BigEndian.PutUint32(cif[:4], data.version)
@ -201,14 +210,18 @@ func marshall_hs_cif(data *HandshakeCIF) ([]byte) {
binary.BigEndian.PutUint16(ext_buff[12:14], contents.recv_delay) binary.BigEndian.PutUint16(ext_buff[12:14], contents.recv_delay)
binary.BigEndian.PutUint16(ext_buff[14:16], contents.send_delay) binary.BigEndian.PutUint16(ext_buff[14:16], contents.send_delay)
case 4: case 4:
contents := extension.ext_contents.(*KMMSG) contents, ok := extension.ext_contents.(*KMMSG)
binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type)) if !ok { // handle km_state error
binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200)) copy(ext_buff[4:8], extension.ext_contents.([]byte))
binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4)) } else {
for i := 0; i < 16; i++ { binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type))
ext_buff[20 + i] = contents.salt[i] binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200))
binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4))
for i := 0; i < 16; i++ {
ext_buff[20 + i] = contents.salt[i]
}
copy(ext_buff[36:], contents.wrapped_key)
} }
copy(ext_buff[36:], contents.wrapped_key)
default: default:
copy(ext_buff[4:], extension.ext_contents.([]byte)) copy(ext_buff[4:], extension.ext_contents.([]byte))
} }
@ -231,7 +244,7 @@ func MarshallPacket(packet *Packet, agent *SRTManager) ([]byte, error) {
func parse_data_packet(pkt *Packet, buffer []byte) (error) { func parse_data_packet(pkt *Packet, buffer []byte) (error) {
info := new(DataHeader) info := new(DataHeader)
info.seq_num = binary.BigEndian.Uint32(buffer[:4]) info.seq_num = binary.BigEndian.Uint32(buffer[:4])
info.msg_flags = buffer[4] >> 2 info.msg_flags = buffer[4] >> 2 // unsused since live streaming makes it irrelevant, kk for encrypt eventually
info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff
pkt.header_info = info pkt.header_info = info
@ -307,13 +320,14 @@ func parse_hs_cif(cif *HandshakeCIF, buffer []byte) (error) {
case 3: case 3:
content := new(KMMSG) content := new(KMMSG)
content.key_type = extensions[7] & 0x3 content.key_type = extensions[7] & 0x3
content.key_len = extensions[19] content.key_len = extensions[19] * 4
for i := 0; i < 4; i++ { for i := 0; i < 16; i++ {
content.salt[i] = extensions[20 + i] content.salt[i] = extensions[20 + i]
} }
wrap_key_len := 4 + ext.ext_len - 24 // -36 from actual content len, extensions includes headers as well
wrap_key_len := 4 + ext.ext_len - 36
content.wrapped_key = make([]byte, wrap_key_len) content.wrapped_key = make([]byte, wrap_key_len)
copy(content.wrapped_key, extensions[24:24 + wrap_key_len]) copy(content.wrapped_key, extensions[36:36 + wrap_key_len])
ext.ext_contents = content ext.ext_contents = content
default: default:
content := make([]byte, ext.ext_len) content := make([]byte, ext.ext_len)

View file

@ -18,6 +18,7 @@ const (
) )
type SRTManager struct { type SRTManager struct {
crypt *CryptHandler
state uint8 state uint8
init time.Time init time.Time
syn_cookie uint32 syn_cookie uint32
@ -39,11 +40,12 @@ func NewSRTManager(l net.PacketConn) (*SRTManager) {
agent := new(SRTManager) agent := new(SRTManager)
agent.init = time.Now() agent.init = time.Now()
agent.socket = l agent.socket = l
agent.bw = 15000 agent.bw = 15000 // in pkts (mtu bytes) per second
agent.mtu = 1500 agent.mtu = 1500
return agent return agent
} }
// adds basic information present in all packets, timestamp and destination SRT socket
func (agent *SRTManager) create_basic_header() (*Packet) { func (agent *SRTManager) create_basic_header() (*Packet) {
packet := new(Packet) packet := new(Packet)
packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds()) packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds())
@ -74,6 +76,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) {
packet.cif = cif packet.cif = cif
// use the handshake as a placeholder ack-ackack rtt initializer
var init_ping_time [2]time.Time var init_ping_time [2]time.Time
init_ping_time[0] = time.Now() init_ping_time[0] = time.Now()
agent.pings = append(agent.pings, init_ping_time) agent.pings = append(agent.pings, init_ping_time)
@ -81,6 +84,7 @@ func (agent *SRTManager) create_induction_resp() (*Packet) {
return packet return packet
} }
// not ideal, but works
func (agent *SRTManager) make_syn_cookie(peer net.Addr) { func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
t := uint32(time.Now().Unix()) >> 6 t := uint32(time.Now().Unix()) >> 6
s := sha256.New() s := sha256.New()
@ -112,7 +116,7 @@ func (agent *SRTManager) create_conclusion_resp() (*Packet) {
cif := new(HandshakeCIF) cif := new(HandshakeCIF)
cif.version = 5 cif.version = 5
cif.ext_field = 0x1 cif.ext_field = 0x1 // 1 for HS-ext, does not allow encryption currently
cif.sock_id = 1 cif.sock_id = 1
cif.mtu = agent.mtu cif.mtu = agent.mtu
cif.max_flow = 8192 cif.max_flow = 8192
@ -145,20 +149,44 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
hs_cif := packet.cif.(*HandshakeCIF) hs_cif := packet.cif.(*HandshakeCIF)
if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie { if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie {
for _, v := range hs_cif.hs_extensions { for _, v := range hs_cif.hs_extensions {
// force client to add a stream_id for output location
// to do: add encryption handling
switch v.ext_type { switch v.ext_type {
case 5: case 5:
writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte)) writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte))
agent.stream_key = stream_key agent.stream_key = stream_key
if !ok { if !ok {
resp_packet.cif.(*HandshakeCIF).hs_type = 1003 resp_packet.cif.(*HandshakeCIF).hs_type = 1003
agent.state = 3
return resp_packet return resp_packet
} else { } else {
agent.output = writer agent.output = writer
CleanFiles(agent.stream_key, 0) CleanFiles(agent.stream_key, 0)
} }
case 3:
resp_packet.cif.(*HandshakeCIF).ext_field = 3
// passphrase harcoded for testing, should pass in somehow with a user management system
crypt_handler := NewCryptHandler("srttestpass", v.ext_contents.(*KMMSG))
if crypt_handler == nil { // if sek unwrap required but fails
agent.state = 3
resp_packet.cif.(*HandshakeCIF).hs_type = 1010
resp_ext := new(HandshakeExtension)
resp_ext.ext_type = 4
resp_ext.ext_len = 4
km_state := make([]byte, 4)
km_state[3] = 4 // BADSECRET code
resp_ext.ext_contents = km_state
resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, resp_ext)
return resp_packet
}
// else return since needed
resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, v)
v.ext_type = 4
agent.crypt = crypt_handler
} }
} }
agent.pings[0][1] = time.Now() agent.pings[0][1] = time.Now()
// if output was successfully initialized, proceed with data looping
if agent.output != nil { if agent.output != nil {
agent.state = DATA_LOOP agent.state = DATA_LOOP
return resp_packet return resp_packet
@ -180,9 +208,13 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
packet.header_info = info packet.header_info = info
cif := new(ACKCIF) cif := new(ACKCIF)
// main has the latest unbroken chain, either no other packets, or
// missing packet which must be nak'd
cif.last_acked = agent.storage.main.end.seq_num cif.last_acked = agent.storage.main.end.seq_num
cif.bw = agent.bw cif.bw = agent.bw
// calculate rtt variance from valid ping pairs, use last value as rtt of last
// exchange since that's what it is
var rtt_sum uint32 var rtt_sum uint32
var rtt_2_sum uint32 var rtt_2_sum uint32
var rtt uint32 var rtt uint32
@ -198,7 +230,10 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
cif.rtt = rtt cif.rtt = rtt
cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2)) cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2))
// use the packets received since the last ack report was sent to calc
// estimated recv rates
cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100) cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100)
// arbitrary, should use len(channel) to set this but doesn't really seem to matter
cif.buff_size = 100 cif.buff_size = 100
var bytes_recvd uint32 var bytes_recvd uint32
for _, v := range agent.pkt_sizes { for _, v := range agent.pkt_sizes {
@ -210,6 +245,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
var next_ping_pair [2]time.Time var next_ping_pair [2]time.Time
next_ping_pair[0] = time.Now() next_ping_pair[0] = time.Now()
// only keep last 100 acks, use offset for correct ackack ping indexing
if len(agent.pings) >= 100 { if len(agent.pings) >= 100 {
agent.pings = append(agent.pings[1:], next_ping_pair) agent.pings = append(agent.pings[1:], next_ping_pair)
agent.ping_offset++ agent.ping_offset++
@ -220,6 +256,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
return packet return packet
} }
// only need the recieve time from ackacks for rtt calcs, ignore otherwise
func (agent *SRTManager) handle_ackack(packet *Packet) { func (agent *SRTManager) handle_ackack(packet *Packet) {
ack_num := packet.header_info.(*ControlHeader).tsi ack_num := packet.header_info.(*ControlHeader).tsi
agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now() agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now()
@ -243,9 +280,15 @@ func (agent *SRTManager) create_nack_report() (*Packet) {
return packet return packet
} }
// handling packets during data loop
func (agent *SRTManager) process_data(packet *Packet) (*Packet) { func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
switch packet.packet_type { switch packet.packet_type {
case DATA: case DATA:
// if data, add to storage, linking, etc
// then check if ack or nack can be generated (every 10 ms)
if agent.crypt != nil {
agent.crypt.Decrypt(packet)
}
agent.handle_data_storage(packet) agent.handle_data_storage(packet)
if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 { if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 {
return agent.create_ack_report() return agent.create_ack_report()
@ -256,7 +299,9 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
case ACKACK: case ACKACK:
agent.handle_ackack(packet) agent.handle_ackack(packet)
case SHUTDOWN: case SHUTDOWN:
agent.state = 3 // state 3 should raise error and shutdown tunnel,
// for now start cleanup procedure in 10s
agent.state = BROKEN
go CleanFiles(agent.stream_key, 10) go CleanFiles(agent.stream_key, 10)
default: default:
return nil return nil
@ -265,18 +310,28 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
} }
func (agent *SRTManager) handle_data_storage(packet *Packet) { func (agent *SRTManager) handle_data_storage(packet *Packet) {
// data packets always have []byte as "cif"
agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte)))) agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte))))
// initialize storage if does not exist, else add where it can
if agent.storage == nil { if agent.storage == nil {
agent.storage = NewDatumStorage(packet) agent.storage = NewDatumStorage(packet)
} else { } else {
agent.storage.NewDatum(packet) agent.storage.NewDatum(packet)
} }
// attempt to relink any offshoots
// timestamp for TLPKTDROP
if len(agent.storage.offshoots) != 0 { if len(agent.storage.offshoots) != 0 {
agent.storage.Relink(packet.timestamp) agent.storage.Relink(packet.timestamp)
} }
agent.storage.Expunge(agent.output) // write out all possible packets
if err := agent.storage.Expunge(agent.output); err != nil {
agent.state = BROKEN
}
} }
// determines appropriate packets and responses depending on tunnel state
// some need to ignore depending on state, eg
// late induction requests during conclusion phase
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
switch agent.state { switch agent.state {
case INDUCTION: case INDUCTION:

View file

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
) )
// main entry point, no concept of tunnels in UDP so need to implement
// that separately and cannot simply add a max connlimit here like with RTMP
func NewServer(port string) (error) { func NewServer(port string) (error) {
l, err := net.ListenPacket("udp", ":" + port) l, err := net.ListenPacket("udp", ":" + port)
if err != nil { if err != nil {
@ -15,12 +17,13 @@ func NewServer(port string) (error) {
} }
func start(l net.PacketConn) { func start(l net.PacketConn) {
// basic panic logging for debugging mostly
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
fmt.Println(r) fmt.Println(r)
} }
}() }()
intake := NewIntake(l, 1) intake := NewIntake(l, 1) // limit to one concurrent tunnel
for { for {
intake.Read() intake.Read()
} }

View file

@ -7,9 +7,9 @@ import (
"stream_server/transcoder" "stream_server/transcoder"
"time" "time"
"path/filepath" "path/filepath"
"fmt"
) )
// spawn a transcoder instance and return its stdin pipe
func NewWriter(stream_key string) (io.WriteCloser, error) { func NewWriter(stream_key string) (io.WriteCloser, error) {
transcoder_in, err := transcoder.NewTranscoder(stream_key) transcoder_in, err := transcoder.NewTranscoder(stream_key)
if err != nil { if err != nil {
@ -18,11 +18,17 @@ func NewWriter(stream_key string) (io.WriteCloser, error) {
return transcoder_in, nil return transcoder_in, nil
} }
// check if the init.mp4 segment has been modified in between a given sleep time
// if it hasn't (stream disconnected for longer) delete it
// else a new stream has started which shouldn't be deleted
func CleanFiles(stream_key string, delay time.Duration) { func CleanFiles(stream_key string, delay time.Duration) {
time.Sleep(delay * time.Second) time.Sleep(delay * time.Second)
base_dir, _ := os.UserHomeDir() base_dir, _ := os.UserHomeDir()
stream_dir := base_dir + "/live/" + stream_key stream_dir := base_dir + "/live/" + stream_key
fileinfo, _ := os.Stat(stream_dir + "/init.mp4") fileinfo, file_ok := os.Stat(stream_dir + "/init.mp4")
if file_ok != nil {
return
}
if time.Now().Sub(fileinfo.ModTime()) > delay * time.Second { if time.Now().Sub(fileinfo.ModTime()) > delay * time.Second {
leftover_files, _ := filepath.Glob(stream_dir + "/*") leftover_files, _ := filepath.Glob(stream_dir + "/*")
for _, file := range leftover_files { for _, file := range leftover_files {
@ -31,6 +37,9 @@ func CleanFiles(stream_key string, delay time.Duration) {
} }
} }
// stream_id is in reverse order, len is multiple of 4 padded with 0
// get the string, check if corresponding folder exists, then attempt
// to spawn a transcoder instance
func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) { func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) {
stream_key := make([]byte, 0) stream_key := make([]byte, 0)
for i := len(stream_id) - 1; i >= 0; i-- { for i := len(stream_id) - 1; i >= 0; i-- {
@ -50,6 +59,7 @@ func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) {
return nil, stream_key_string, false return nil, stream_key_string, false
} }
// checks if folder exists corresponding to the stream_key
func check_stream_key(stream_key string) (bool) { func check_stream_key(stream_key string) (bool) {
base_dir, _ := os.UserHomeDir() base_dir, _ := os.UserHomeDir()
if fileinfo, err := os.Stat(base_dir + "/live/" + stream_key); err == nil && fileinfo.IsDir() { if fileinfo, err := os.Stat(base_dir + "/live/" + stream_key); err == nil && fileinfo.IsDir() {

View file

@ -19,9 +19,16 @@ func (tunnel *Tunnel) Start() {
fmt.Println(r) fmt.Println(r)
} }
*a = true *a = true
}(&(tunnel.broken)) }(&(tunnel.broken)) // force mark tunnel for deletion if any error occurs
tunnel.state = NewSRTManager(tunnel.socket) tunnel.state = NewSRTManager(tunnel.socket)
// central tunnel loop, read incoming, process and generate response
// write response if any
for { for {
// force check since no new packets after shutdown
if tunnel.state.state == 3 {
tunnel.broken = true
break
}
packet, err := tunnel.ReadPacket() packet, err := tunnel.ReadPacket()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@ -38,6 +45,8 @@ func (tunnel *Tunnel) Start() {
} }
} }
// send a shutdown command, for use when tunnel gets broken
// not ideal but works
func (tunnel *Tunnel) Shutdown() { func (tunnel *Tunnel) Shutdown() {
if tunnel.state != nil && tunnel.state.state > 1 { if tunnel.state != nil && tunnel.state.state > 1 {
packet := tunnel.state.create_basic_header() packet := tunnel.state.create_basic_header()
@ -63,6 +72,6 @@ func (tunnel *Tunnel) WritePacket(packet *Packet) {
} }
func (tunnel *Tunnel) ReadPacket() (*Packet, error) { func (tunnel *Tunnel) ReadPacket() (*Packet, error) {
packet := <- tunnel.queue packet := <- tunnel.queue // blocking read, should add timeout here
return ParsePacket(packet) return ParsePacket(packet)
} }