Compare commits

..

No commits in common. "9b85fdb45603dd2ea111813ea9e6042c6ea536e9" and "25ce4427e65486e0a9f189919319543ad0f5233c" have entirely different histories.

10 changed files with 36 additions and 332 deletions

View file

@ -2,17 +2,17 @@
Extremely basic live stream server Extremely basic live stream server
Currently implements a degraded subset of RTMP and SRT for ingest. Encryption for SRT is implemented with AES-CTR with hardcoded passphrase "srttestpass". Only 16-byte keys have been tested, but nothing should stop the implementation from having problems in principle Currently implements a degraded subset of RTMP and SRT (un-encrypted) for ingest.
Uses the std lib http server implementation for the http serving side. Uses the std lib http server implementation for the http serving side.
**Not intended for actual use**. The stream key use is not secure and is used to handle directories without a user db system, than to provide auth. Same goes for the SRT passphrase. Also just accepts connections so will get DDOS'd immediately. **Not intended for actual use**. The stream key use is not secure and is used to handle directories without a user db system, than to provide auth
Limits to a single stream at a time, mostly for the lack of db to handle connections and user information rather than concurrency problems. Limits to a single stream at a time, mostly for the lack of db to handle connections and user information rather than concurrency problems.
Currently always transcodes to vp9 + opus, segments to fragmented mp4. Creates one segment playlist, no manifest. Uses ffmpeg Currently always transcodes to vp9 + opus, segments to fragmented mp4. Creates one segment playlist, no manifest. Uses ffmpeg
HTTP streaming relies on hls-player-js. Will be broken for standard hls players until I figure out how to modify the `EXT-X-MAP:URI` field to prepend a path prefix without changing directories. Or finish the transcoder project HTTP streaming relies on hls-player-js. Will be broken for standard hls players until I figure out how to modify the `EXT-X-MAP:URI` field to prepend a path prefix without changing directories.
Currently produces no logs nor debug info. Will just abandon a connection if there is a problem. Will not send any RTMP replies since flash server docs seem dead and abort messages are netStream commands. Currently produces no logs nor debug info. Will just abandon a connection if there is a problem. Will not send any RTMP replies since flash server docs seem dead and abort messages are netStream commands.

14
main.go
View file

@ -4,25 +4,19 @@ import (
"stream_server/rtmp" "stream_server/rtmp"
"stream_server/http" "stream_server/http"
"stream_server/srt" "stream_server/srt"
"flag"
) )
const ( const (
SRVTYPE_RTMP uint = iota SRVTYPE_RTMP uint8 = iota
SRVTYPE_SRT SRVTYPE_SRT
) )
func main() { func main() {
ingest_type := flag.Uint("ingest_type", 0, "Ingest server type, 0 for RTMP, 1 for SRT") err := NewIngestServer(SRVTYPE_SRT, "7878")
ingest_port := flag.String("ingest_port", "7878", "Port for stream intake")
http_port := flag.String("http_port", "7879", "Port to serve http requests")
flag.Parse()
err := NewIngestServer(*ingest_type, *ingest_port)
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = NewHTTPServer(*http_port) err = NewHTTPServer("7879")
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -30,7 +24,7 @@ func main() {
} }
} }
func NewIngestServer(srvr_type uint, port string) (error) { func NewIngestServer(srvr_type uint8, port string) (error) {
var err error var err error
switch srvr_type { switch srvr_type {
case 0: case 0:

View file

@ -1,167 +0,0 @@
package srt
import (
"crypto/aes"
"crypto/sha1"
"crypto/hmac"
"hash"
"math"
"encoding/binary"
"errors"
"crypto/cipher"
)
type CryptHandler struct {
salt [16]byte
key_len uint8
odd_sek cipher.Block
even_sek cipher.Block
}
// init will apply to KM message immediately
func NewCryptHandler(passphrase string, km_msg *KMMSG) (*CryptHandler) {
crypt := new(CryptHandler)
crypt.key_len = km_msg.key_len
crypt.salt = km_msg.salt
ok := crypt.Unwrap(km_msg.wrapped_key, passphrase, km_msg.key_type)
if !ok {
return nil
}
return crypt
}
// PRF as defined in RFC 8018
func PRF(h hash.Hash, input []byte) ([]byte) {
h.Reset()
h.Write(input)
return h.Sum(nil)
}
// gets KEK from passphrase and salt according to fixed SRT iterations and algo from RFC doc
// see RFC 8018 for implementation details
func SRT_PBKDF2(passphrase string, salt []byte, dklen uint8) ([]byte) {
prf := hmac.New(sha1.New, []byte(passphrase))
hlen := prf.Size()
l := int(math.Ceil(float64(dklen) / float64(hlen)))
r := int(dklen) - (l - 1) * hlen
key := make([]byte, 0)
for block := 1; block <= l; block++ {
U := make([]byte, hlen)
T := make([]byte, hlen)
block_i := make([]byte, 4)
binary.BigEndian.PutUint32(block_i, uint32(block))
U = PRF(prf, append(salt, block_i...))
copy(T, U)
// skip one iter since done above
for n := 1; n < 2048; n++ {
U = PRF(prf, U)
for x := range T {
T[x] ^= U[x]
}
}
// final block may not use entire SHA output, still need full during computation
if block == l {
T = T[:r]
}
// final key is appended sequence of all blocks computed independently
key = append(key, T...)
}
return key
}
// See RFC 3394, inplace implementation
func AES_UNWRAP(key []byte, wrapped []byte) ([]byte, error) {
seks := make([]byte, 0) // bytes past IV
cipher, err := aes.NewCipher(key)
if err != nil {
return seks, err
}
A := wrapped[:8] // IV bytes
n := len(wrapped) / 8 - 1
R := make([][]byte, n) // actual message (SEKs)
for i := range R {
R[i] = wrapped[(i + 1) * 8: (i + 2) * 8]
}
for j := 5; j >= 0; j-- {
for i := n; i > 0; i-- {
t := make([]byte, 8)
binary.BigEndian.PutUint64(t, uint64(n * j + i))
for k := range t {
t[k] ^= A[k]
}
B := make([]byte, 16)
cipher.Decrypt(B, append(t, R[i - 1]...))
copy(A, B[:8])
copy(R[i - 1], B[8:])
}
}
// SRT uses default IV, 8 repeating bytes of 0xa6 prepended in wrap, check if
// preserved in unwrap
for i := range A {
if A[i] != 0xa6 {
return seks, errors.New("IV not default")
}
}
// R is 8 byte blocks, keys can be 16-32 bytes, prepend all together and
// let wrappers figure it out
for _, v := range R {
seks = append(seks, v...)
}
return seks, nil
}
// unwrap and store SEK ciphers, key_type defined as KK 2-bit value in Key Material from SRT docs
func (crypt *CryptHandler) Unwrap(wrapped_key []byte, passphrase string, key_type uint8) (bool) {
kek := SRT_PBKDF2(passphrase, crypt.salt[8:], crypt.key_len)
// need a copy since original will be sent back
wrapped_copy := make([]byte, len(wrapped_key))
copy(wrapped_copy, wrapped_key)
seks, err := AES_UNWRAP(kek, wrapped_copy)
// either unwrap fails or key len does not match expected (1 or 2 SEKs len identical)
if err != nil || len(seks) % int(crypt.key_len) != 0 {
return false
}
// always have one SEK, if more bytes (second key) and peer did not send 2 keys
// something is wrong
sek_1 := seks[:crypt.key_len]
if len(seks) > int(crypt.key_len) && key_type != 3 {
return false
}
switch key_type {
case 1:
crypt.even_sek, _ = aes.NewCipher(sek_1)
case 2:
crypt.odd_sek, _ = aes.NewCipher(sek_1)
case 3:
sek_2 := seks[crypt.key_len:]
crypt.even_sek, _ = aes.NewCipher(sek_1)
crypt.odd_sek, _ = aes.NewCipher(sek_2)
default:
return false
}
return true
}
func (crypt *CryptHandler) Decrypt(pkt *Packet) {
var sek cipher.Block
switch pkt.header_info.(*DataHeader).msg_flags & 0x6 {
case 2:
sek = crypt.even_sek
case 4:
sek = crypt.odd_sek
default:
return
}
IV := make([]byte, crypt.key_len)
binary.BigEndian.PutUint32(IV[10:14], pkt.header_info.(*DataHeader).seq_num)
for i := 0; i < 14; i++ {
IV[i] ^= crypt.salt[i]
}
ctr := cipher.NewCTR(sek, IV)
ctr.XORKeyStream(pkt.cif.([]byte), pkt.cif.([]byte))
}

View file

@ -13,20 +13,12 @@ type Datum struct {
next *Datum next *Datum
} }
// linked chain of datums, specifically to store continuous segments
// when a packet is missed, start and end can be used to generate lost
// packet reports, and easily link when missing is received
// 1..2..3..4 6..7..8..9
// chain_1 chain_2
// nack: get 5
type DatumLink struct { type DatumLink struct {
queued int // remove eventually, was to be used for ACK recv rate calcs, not needed queued int
root *Datum root *Datum
end *Datum end *Datum
} }
// data type and function to allow sorting so order can be ignored during
// linking since each is sequential on outset
type chains []*DatumLink type chains []*DatumLink
func (c chains) Len() (int) { func (c chains) Len() (int) {
@ -37,7 +29,6 @@ func (c chains) Swap(i, j int) {
c[i], c[j] = c[j], c[i] c[i], c[j] = c[j], c[i]
} }
// chain_1 is less then chain_2 when chain_1 ends before chain_2 starts
func (c chains) Less(i, j int) (bool) { func (c chains) Less(i, j int) (bool) {
x_1 := c[i].end.seq_num x_1 := c[i].end.seq_num
x_2 := c[j].root.seq_num x_2 := c[j].root.seq_num
@ -52,7 +43,6 @@ type DatumStorage struct {
offshoots chains offshoots chains
} }
// append new packet to end of buffer
func (buffer *DatumLink) NewDatum(pkt *Packet) { func (buffer *DatumLink) NewDatum(pkt *Packet) {
datum := new(Datum) datum := new(Datum)
datum.seq_num = pkt.header_info.(*DataHeader).seq_num datum.seq_num = pkt.header_info.(*DataHeader).seq_num
@ -63,7 +53,6 @@ func (buffer *DatumLink) NewDatum(pkt *Packet) {
buffer.end = datum buffer.end = datum
} }
// create a new datumlink with root and end at the given packet
func NewDatumLink(pkt *Packet) (*DatumLink) { func NewDatumLink(pkt *Packet) (*DatumLink) {
buffer := new(DatumLink) buffer := new(DatumLink)
root_datum := new(Datum) root_datum := new(Datum)
@ -77,14 +66,12 @@ func NewDatumLink(pkt *Packet) (*DatumLink) {
return buffer return buffer
} }
// initialize storage with the given data packet, in the main chain
func NewDatumStorage(packet *Packet) (*DatumStorage) { func NewDatumStorage(packet *Packet) (*DatumStorage) {
storage := new(DatumStorage) storage := new(DatumStorage)
storage.main = NewDatumLink(packet) storage.main = NewDatumLink(packet)
return storage return storage
} }
// purge all packets in the main chain except the last for future linkage
func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) { func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
curr_datum := storage.main.root curr_datum := storage.main.root
seq_num_end := storage.main.end.seq_num seq_num_end := storage.main.end.seq_num
@ -99,7 +86,6 @@ func (storage *DatumStorage) Expunge(output io.WriteCloser) (error) {
return nil return nil
} }
// check if the given sequence number should already be inside the given buffer
func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) { func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) {
start := buffer.root.seq_num start := buffer.root.seq_num
end := buffer.end.seq_num end := buffer.end.seq_num
@ -112,8 +98,6 @@ func (buffer *DatumLink) Holds(new_seq_num uint32) (bool) {
return true return true
} }
// check if the given seq num lies before the given buffer starts
// buffer is After seq num?
func (buffer *DatumLink) After(new_seq_num uint32) (bool) { func (buffer *DatumLink) After(new_seq_num uint32) (bool) {
start := buffer.root.seq_num start := buffer.root.seq_num
if serial_less(new_seq_num, start, 31) { if serial_less(new_seq_num, start, 31) {
@ -122,8 +106,6 @@ func (buffer *DatumLink) After(new_seq_num uint32) (bool) {
return false return false
} }
// check if the given seq num lies after the given buffer starts
// buffer is Before seq num?
func (buffer *DatumLink) Before(new_seq_num uint32) (bool) { func (buffer *DatumLink) Before(new_seq_num uint32) (bool) {
end := buffer.end.seq_num end := buffer.end.seq_num
if serial_less(end, new_seq_num, 31) { if serial_less(end, new_seq_num, 31) {
@ -132,7 +114,6 @@ func (buffer *DatumLink) Before(new_seq_num uint32) (bool) {
return false return false
} }
// add the given packet to the appropriate buffer or add a new offshoot buffer
func (storage *DatumStorage) NewDatum(pkt *Packet) { func (storage *DatumStorage) NewDatum(pkt *Packet) {
new_pkt_num := pkt.header_info.(*DataHeader).seq_num new_pkt_num := pkt.header_info.(*DataHeader).seq_num
prev_num := (new_pkt_num - 1) % uint32(1 << 31) prev_num := (new_pkt_num - 1) % uint32(1 << 31)
@ -158,16 +139,12 @@ func (storage *DatumStorage) NewDatum(pkt *Packet) {
} }
} }
// Link 2 buffers to each other
func (buffer *DatumLink) Link(buffer_next *DatumLink) { func (buffer *DatumLink) Link(buffer_next *DatumLink) {
buffer.end.next = buffer_next.root buffer.end.next = buffer_next.root
buffer.end = buffer_next.end buffer.end = buffer_next.end
buffer.queued += buffer_next.queued buffer.queued += buffer_next.queued
} }
// check if the start of the buffer_next is sequentially next to the end of buffer
// or it is too late to get the true next packets to link and the packets must be linked
// given a 31-bit wrap around
func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_time uint32) (bool) { func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_time uint32) (bool) {
seq_1 := buffer.end.seq_num seq_1 := buffer.end.seq_num
seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31)) seq_2 := (seq_1 + 1) % uint32(math.Pow(2, 31))
@ -178,7 +155,6 @@ func check_append_serial_next(buffer *DatumLink, buffer_next *DatumLink, curr_ti
return false return false
} }
// Given the current storage state, check what packets need to added to fully link main
func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) { func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
if len(storage.offshoots) == 0 { if len(storage.offshoots) == 0 {
return nil, false return nil, false
@ -197,14 +173,13 @@ func (storage *DatumStorage) GenNACKCIF() (*NACKCIF, bool) {
return cif, true return cif, true
} }
// Try to relink all chains wherever possible or necessary due to TLPKTDROP
func (storage *DatumStorage) Relink(curr_time uint32) { func (storage *DatumStorage) Relink(curr_time uint32) {
sort.Sort(storage.offshoots) sort.Sort(storage.offshoots)
buffer := storage.main buffer := storage.main
i := 0 i := 0
for i < len(storage.offshoots) { for i < len(storage.offshoots) {
if check_append_serial_next(buffer, storage.offshoots[i], curr_time) { if check_append_serial_next(buffer, storage.offshoots[i], curr_time) {
storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...) // nuke the chain since it is not needed anymore storage.offshoots = append(storage.offshoots[:i], storage.offshoots[i + 1:]...)
} else { } else {
buffer = storage.offshoots[i] buffer = storage.offshoots[i]
i++ i++
@ -212,7 +187,6 @@ func (storage *DatumStorage) Relink(curr_time uint32) {
} }
} }
// check if a is less than b under serial arithmetic (modulo operations)
func serial_less(a uint32, b uint32, bits int) (bool) { func serial_less(a uint32, b uint32, bits int) (bool) {
if (a < b && b-a < (1 << (bits - 1))) || (a > b && a-b > (1 << (bits - 1))) { if (a < b && b-a < (1 << (bits - 1))) || (a > b && a-b > (1 << (bits - 1))) {
return true return true

View file

@ -15,7 +15,7 @@ func NewIntake(l net.PacketConn, max_conns int) (*Intake) {
intake := new(Intake) intake := new(Intake)
intake.max_conns = max_conns intake.max_conns = max_conns
intake.tunnels = make([]*Tunnel, 0) intake.tunnels = make([]*Tunnel, 0)
intake.buffer = make([]byte, 1500) // each packet is restricted to a max size of 1500 intake.buffer = make([]byte, 1500)
intake.socket = l intake.socket = l
return intake return intake
@ -26,9 +26,9 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
tunnel := new(Tunnel) tunnel := new(Tunnel)
tunnel.socket = l tunnel.socket = l
tunnel.peer = peer tunnel.peer = peer
tunnel.queue = make(chan []byte, 10) // packet buffer, will cause packet loss if low tunnel.queue = make(chan []byte, 10)
intake.tunnels = append(intake.tunnels, tunnel) intake.tunnels = append(intake.tunnels, tunnel)
go tunnel.Start() // start the tunnel SRT processing go tunnel.Start()
return tunnel return tunnel
} }
return nil return nil
@ -37,7 +37,6 @@ func (intake *Intake) NewTunnel(l net.PacketConn, peer net.Addr) (*Tunnel) {
func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) { func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
var tunnel *Tunnel var tunnel *Tunnel
for i := 0; i < len(intake.tunnels); i++ { for i := 0; i < len(intake.tunnels); i++ {
// check if tunnels are broken and remove
if intake.tunnels[i].broken { if intake.tunnels[i].broken {
intake.tunnels[i].Shutdown() intake.tunnels[i].Shutdown()
intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...) intake.tunnels = append(intake.tunnels[:i], intake.tunnels[i+1:]...)
@ -48,9 +47,6 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
tunnel = intake.tunnels[i] tunnel = intake.tunnels[i]
} }
} }
// if no tunnel was found, make one
// should be after conclusion handshake, but wanted to keep all protocol
// related actions separate from UDP handling
if tunnel == nil { if tunnel == nil {
tunnel = intake.NewTunnel(intake.socket, peer) tunnel = intake.NewTunnel(intake.socket, peer)
} }
@ -62,16 +58,14 @@ func (intake *Intake) get_tunnel(peer net.Addr) (*Tunnel) {
func (intake *Intake) Read() { func (intake *Intake) Read() {
n, peer, err := intake.socket.ReadFrom(intake.buffer) n, peer, err := intake.socket.ReadFrom(intake.buffer)
if err != nil { if err != nil {
return // ignore UDP errors return
} }
// find the SRT/UDT tunnel corresponding to the given peer
tunnel := intake.get_tunnel(peer) tunnel := intake.get_tunnel(peer)
if tunnel == nil { if tunnel == nil {
return return
} }
pkt := make([]byte, n) pkt := make([]byte, n)
copy(pkt, intake.buffer[:n]) copy(pkt, intake.buffer[:n])
// send a copy to the corresponding tunnels packet queue if not full
select { select {
case tunnel.queue <- pkt: case tunnel.queue <- pkt:
default: default:

View file

@ -5,7 +5,6 @@ import (
"encoding/binary" "encoding/binary"
) )
// arbitrary indexing
const ( const (
DATA uint8 = iota DATA uint8 = iota
HANDSHAKE HANDSHAKE
@ -16,7 +15,6 @@ const (
DROP DROP
) )
// see SRT protocol RFC for information
type ControlHeader struct { type ControlHeader struct {
ctrl_type uint16 ctrl_type uint16
ctrl_subtype uint16 ctrl_subtype uint16
@ -82,12 +80,10 @@ type pckts_range struct {
end uint32 end uint32
} }
// should be safe to ignore
type DROPCIF struct { type DROPCIF struct {
to_drop pckts_range to_drop pckts_range
} }
// header and cif are interfaces to allow easier typing, fitting above structs
type Packet struct { type Packet struct {
packet_type uint8 packet_type uint8
timestamp uint32 timestamp uint32
@ -96,12 +92,11 @@ type Packet struct {
cif interface{} cif interface{}
} }
// completely pointless since only implementing receiver, here anyway
func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) { func marshall_data_packet(packet *Packet, header []byte) ([]byte, error) {
info, ok_head := packet.header_info.(*DataHeader) info, ok_head := packet.header_info.(*DataHeader)
data, ok_data := packet.cif.([]byte) data, ok_data := packet.cif.([]byte)
if !ok_head || !ok_data { if !ok_head || !ok_data {
return header, errors.New("data packet does not have data header or data") return header, errors.New("data packet does not have data header")
} }
binary.BigEndian.PutUint32(header[:4], info.seq_num) binary.BigEndian.PutUint32(header[:4], info.seq_num)
head2 := (uint32(info.msg_flags) << 26) + info.msg_num head2 := (uint32(info.msg_flags) << 26) + info.msg_num
@ -167,9 +162,6 @@ func marshall_nack_cif(data *NACKCIF) ([]byte) {
return loss_list_bytes return loss_list_bytes
} }
// locations and length are determined by protocol,
// no real point abstracting that
// could be cleaner with reflects, but added work for very little real gain
func marshall_ack_cif(data *ACKCIF) ([]byte) { func marshall_ack_cif(data *ACKCIF) ([]byte) {
cif := make([]byte, 28) cif := make([]byte, 28)
binary.BigEndian.PutUint32(cif[:4], data.last_acked) binary.BigEndian.PutUint32(cif[:4], data.last_acked)
@ -183,7 +175,6 @@ func marshall_ack_cif(data *ACKCIF) ([]byte) {
return cif return cif
} }
// same as above
func marshall_hs_cif(data *HandshakeCIF) ([]byte) { func marshall_hs_cif(data *HandshakeCIF) ([]byte) {
cif := make([]byte, 48) cif := make([]byte, 48)
binary.BigEndian.PutUint32(cif[:4], data.version) binary.BigEndian.PutUint32(cif[:4], data.version)
@ -210,10 +201,7 @@ func marshall_hs_cif(data *HandshakeCIF) ([]byte) {
binary.BigEndian.PutUint16(ext_buff[12:14], contents.recv_delay) binary.BigEndian.PutUint16(ext_buff[12:14], contents.recv_delay)
binary.BigEndian.PutUint16(ext_buff[14:16], contents.send_delay) binary.BigEndian.PutUint16(ext_buff[14:16], contents.send_delay)
case 4: case 4:
contents, ok := extension.ext_contents.(*KMMSG) contents := extension.ext_contents.(*KMMSG)
if !ok { // handle km_state error
copy(ext_buff[4:8], extension.ext_contents.([]byte))
} else {
binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type)) binary.BigEndian.PutUint32(ext_buff[4:8], uint32(0x12202900) | uint32(contents.key_type))
binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200)) binary.BigEndian.PutUint32(ext_buff[12:16], uint32(0x02000200))
binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4)) binary.BigEndian.PutUint32(ext_buff[16:20], uint32(0x0400) | uint32(contents.key_len / 4))
@ -221,7 +209,6 @@ func marshall_hs_cif(data *HandshakeCIF) ([]byte) {
ext_buff[20 + i] = contents.salt[i] ext_buff[20 + i] = contents.salt[i]
} }
copy(ext_buff[36:], contents.wrapped_key) copy(ext_buff[36:], contents.wrapped_key)
}
default: default:
copy(ext_buff[4:], extension.ext_contents.([]byte)) copy(ext_buff[4:], extension.ext_contents.([]byte))
} }
@ -244,7 +231,7 @@ func MarshallPacket(packet *Packet, agent *SRTManager) ([]byte, error) {
func parse_data_packet(pkt *Packet, buffer []byte) (error) { func parse_data_packet(pkt *Packet, buffer []byte) (error) {
info := new(DataHeader) info := new(DataHeader)
info.seq_num = binary.BigEndian.Uint32(buffer[:4]) info.seq_num = binary.BigEndian.Uint32(buffer[:4])
info.msg_flags = buffer[4] >> 2 // unsused since live streaming makes it irrelevant, kk for encrypt eventually info.msg_flags = buffer[4] >> 2
info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff info.msg_num = binary.BigEndian.Uint32(buffer[4:8]) & 0x03ffffff
pkt.header_info = info pkt.header_info = info
@ -320,14 +307,13 @@ func parse_hs_cif(cif *HandshakeCIF, buffer []byte) (error) {
case 3: case 3:
content := new(KMMSG) content := new(KMMSG)
content.key_type = extensions[7] & 0x3 content.key_type = extensions[7] & 0x3
content.key_len = extensions[19] * 4 content.key_len = extensions[19]
for i := 0; i < 16; i++ { for i := 0; i < 4; i++ {
content.salt[i] = extensions[20 + i] content.salt[i] = extensions[20 + i]
} }
// -36 from actual content len, extensions includes headers as well wrap_key_len := 4 + ext.ext_len - 24
wrap_key_len := 4 + ext.ext_len - 36
content.wrapped_key = make([]byte, wrap_key_len) content.wrapped_key = make([]byte, wrap_key_len)
copy(content.wrapped_key, extensions[36:36 + wrap_key_len]) copy(content.wrapped_key, extensions[24:24 + wrap_key_len])
ext.ext_contents = content ext.ext_contents = content
default: default:
content := make([]byte, ext.ext_len) content := make([]byte, ext.ext_len)

View file

@ -18,7 +18,6 @@ const (
) )
type SRTManager struct { type SRTManager struct {
crypt *CryptHandler
state uint8 state uint8
init time.Time init time.Time
syn_cookie uint32 syn_cookie uint32
@ -40,12 +39,11 @@ func NewSRTManager(l net.PacketConn) (*SRTManager) {
agent := new(SRTManager) agent := new(SRTManager)
agent.init = time.Now() agent.init = time.Now()
agent.socket = l agent.socket = l
agent.bw = 15000 // in pkts (mtu bytes) per second agent.bw = 15000
agent.mtu = 1500 agent.mtu = 1500
return agent return agent
} }
// adds basic information present in all packets, timestamp and destination SRT socket
func (agent *SRTManager) create_basic_header() (*Packet) { func (agent *SRTManager) create_basic_header() (*Packet) {
packet := new(Packet) packet := new(Packet)
packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds()) packet.timestamp = uint32(time.Now().Sub(agent.init).Microseconds())
@ -76,7 +74,6 @@ func (agent *SRTManager) create_induction_resp() (*Packet) {
packet.cif = cif packet.cif = cif
// use the handshake as a placeholder ack-ackack rtt initializer
var init_ping_time [2]time.Time var init_ping_time [2]time.Time
init_ping_time[0] = time.Now() init_ping_time[0] = time.Now()
agent.pings = append(agent.pings, init_ping_time) agent.pings = append(agent.pings, init_ping_time)
@ -84,7 +81,6 @@ func (agent *SRTManager) create_induction_resp() (*Packet) {
return packet return packet
} }
// not ideal, but works
func (agent *SRTManager) make_syn_cookie(peer net.Addr) { func (agent *SRTManager) make_syn_cookie(peer net.Addr) {
t := uint32(time.Now().Unix()) >> 6 t := uint32(time.Now().Unix()) >> 6
s := sha256.New() s := sha256.New()
@ -116,7 +112,7 @@ func (agent *SRTManager) create_conclusion_resp() (*Packet) {
cif := new(HandshakeCIF) cif := new(HandshakeCIF)
cif.version = 5 cif.version = 5
cif.ext_field = 0x1 // 1 for HS-ext, does not allow encryption currently cif.ext_field = 0x1
cif.sock_id = 1 cif.sock_id = 1
cif.mtu = agent.mtu cif.mtu = agent.mtu
cif.max_flow = 8192 cif.max_flow = 8192
@ -149,44 +145,20 @@ func (agent *SRTManager) process_conclusion(packet *Packet) (*Packet) {
hs_cif := packet.cif.(*HandshakeCIF) hs_cif := packet.cif.(*HandshakeCIF)
if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie { if hs_cif.hs_type == 0xffffffff && hs_cif.syn_cookie == agent.syn_cookie {
for _, v := range hs_cif.hs_extensions { for _, v := range hs_cif.hs_extensions {
// force client to add a stream_id for output location
// to do: add encryption handling
switch v.ext_type { switch v.ext_type {
case 5: case 5:
writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte)) writer, stream_key, ok := CheckStreamID(v.ext_contents.([]byte))
agent.stream_key = stream_key agent.stream_key = stream_key
if !ok { if !ok {
resp_packet.cif.(*HandshakeCIF).hs_type = 1003 resp_packet.cif.(*HandshakeCIF).hs_type = 1003
agent.state = 3
return resp_packet return resp_packet
} else { } else {
agent.output = writer agent.output = writer
CleanFiles(agent.stream_key, 0) CleanFiles(agent.stream_key, 0)
} }
case 3:
resp_packet.cif.(*HandshakeCIF).ext_field = 3
// passphrase harcoded for testing, should pass in somehow with a user management system
crypt_handler := NewCryptHandler("srttestpass", v.ext_contents.(*KMMSG))
if crypt_handler == nil { // if sek unwrap required but fails
agent.state = 3
resp_packet.cif.(*HandshakeCIF).hs_type = 1010
resp_ext := new(HandshakeExtension)
resp_ext.ext_type = 4
resp_ext.ext_len = 4
km_state := make([]byte, 4)
km_state[3] = 4 // BADSECRET code
resp_ext.ext_contents = km_state
resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, resp_ext)
return resp_packet
}
// else return since needed
resp_packet.cif.(*HandshakeCIF).hs_extensions = append(resp_packet.cif.(*HandshakeCIF).hs_extensions, v)
v.ext_type = 4
agent.crypt = crypt_handler
} }
} }
agent.pings[0][1] = time.Now() agent.pings[0][1] = time.Now()
// if output was successfully initialized, proceed with data looping
if agent.output != nil { if agent.output != nil {
agent.state = DATA_LOOP agent.state = DATA_LOOP
return resp_packet return resp_packet
@ -208,13 +180,9 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
packet.header_info = info packet.header_info = info
cif := new(ACKCIF) cif := new(ACKCIF)
// main has the latest unbroken chain, either no other packets, or
// missing packet which must be nak'd
cif.last_acked = agent.storage.main.end.seq_num cif.last_acked = agent.storage.main.end.seq_num
cif.bw = agent.bw cif.bw = agent.bw
// calculate rtt variance from valid ping pairs, use last value as rtt of last
// exchange since that's what it is
var rtt_sum uint32 var rtt_sum uint32
var rtt_2_sum uint32 var rtt_2_sum uint32
var rtt uint32 var rtt uint32
@ -230,10 +198,7 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
cif.rtt = rtt cif.rtt = rtt
cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2)) cif.var_rtt = uint32(rtt_2_sum / rtt_n) - uint32(math.Pow(float64(rtt_sum / rtt_n), 2))
// use the packets received since the last ack report was sent to calc
// estimated recv rates
cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100) cif.pkt_recv_rate = uint32(len(agent.pkt_sizes) * 100)
// arbitrary, should use len(channel) to set this but doesn't really seem to matter
cif.buff_size = 100 cif.buff_size = 100
var bytes_recvd uint32 var bytes_recvd uint32
for _, v := range agent.pkt_sizes { for _, v := range agent.pkt_sizes {
@ -245,7 +210,6 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
var next_ping_pair [2]time.Time var next_ping_pair [2]time.Time
next_ping_pair[0] = time.Now() next_ping_pair[0] = time.Now()
// only keep last 100 acks, use offset for correct ackack ping indexing
if len(agent.pings) >= 100 { if len(agent.pings) >= 100 {
agent.pings = append(agent.pings[1:], next_ping_pair) agent.pings = append(agent.pings[1:], next_ping_pair)
agent.ping_offset++ agent.ping_offset++
@ -256,7 +220,6 @@ func (agent *SRTManager) create_ack_report() (*Packet) {
return packet return packet
} }
// only need the recieve time from ackacks for rtt calcs, ignore otherwise
func (agent *SRTManager) handle_ackack(packet *Packet) { func (agent *SRTManager) handle_ackack(packet *Packet) {
ack_num := packet.header_info.(*ControlHeader).tsi ack_num := packet.header_info.(*ControlHeader).tsi
agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now() agent.pings[int(ack_num) - agent.ping_offset][1] = time.Now()
@ -280,15 +243,9 @@ func (agent *SRTManager) create_nack_report() (*Packet) {
return packet return packet
} }
// handling packets during data loop
func (agent *SRTManager) process_data(packet *Packet) (*Packet) { func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
switch packet.packet_type { switch packet.packet_type {
case DATA: case DATA:
// if data, add to storage, linking, etc
// then check if ack or nack can be generated (every 10 ms)
if agent.crypt != nil {
agent.crypt.Decrypt(packet)
}
agent.handle_data_storage(packet) agent.handle_data_storage(packet)
if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 { if time.Now().Sub(agent.pings[len(agent.pings) - 1][0]).Milliseconds() >= 10 {
return agent.create_ack_report() return agent.create_ack_report()
@ -299,9 +256,7 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
case ACKACK: case ACKACK:
agent.handle_ackack(packet) agent.handle_ackack(packet)
case SHUTDOWN: case SHUTDOWN:
// state 3 should raise error and shutdown tunnel, agent.state = 3
// for now start cleanup procedure in 10s
agent.state = BROKEN
go CleanFiles(agent.stream_key, 10) go CleanFiles(agent.stream_key, 10)
default: default:
return nil return nil
@ -310,28 +265,18 @@ func (agent *SRTManager) process_data(packet *Packet) (*Packet) {
} }
func (agent *SRTManager) handle_data_storage(packet *Packet) { func (agent *SRTManager) handle_data_storage(packet *Packet) {
// data packets always have []byte as "cif"
agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte)))) agent.pkt_sizes = append(agent.pkt_sizes, uint32(len(packet.cif.([]byte))))
// initialize storage if does not exist, else add where it can
if agent.storage == nil { if agent.storage == nil {
agent.storage = NewDatumStorage(packet) agent.storage = NewDatumStorage(packet)
} else { } else {
agent.storage.NewDatum(packet) agent.storage.NewDatum(packet)
} }
// attempt to relink any offshoots
// timestamp for TLPKTDROP
if len(agent.storage.offshoots) != 0 { if len(agent.storage.offshoots) != 0 {
agent.storage.Relink(packet.timestamp) agent.storage.Relink(packet.timestamp)
} }
// write out all possible packets agent.storage.Expunge(agent.output)
if err := agent.storage.Expunge(agent.output); err != nil {
agent.state = BROKEN
}
} }
// determines appropriate packets and responses depending on tunnel state
// some need to ignore depending on state, eg
// late induction requests during conclusion phase
func (agent *SRTManager) Process(packet *Packet) (*Packet, error) { func (agent *SRTManager) Process(packet *Packet) (*Packet, error) {
switch agent.state { switch agent.state {
case INDUCTION: case INDUCTION:

View file

@ -5,8 +5,6 @@ import (
"fmt" "fmt"
) )
// main entry point, no concept of tunnels in UDP so need to implement
// that separately and cannot simply add a max connlimit here like with RTMP
func NewServer(port string) (error) { func NewServer(port string) (error) {
l, err := net.ListenPacket("udp", ":" + port) l, err := net.ListenPacket("udp", ":" + port)
if err != nil { if err != nil {
@ -17,13 +15,12 @@ func NewServer(port string) (error) {
} }
func start(l net.PacketConn) { func start(l net.PacketConn) {
// basic panic logging for debugging mostly
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
fmt.Println(r) fmt.Println(r)
} }
}() }()
intake := NewIntake(l, 1) // limit to one concurrent tunnel intake := NewIntake(l, 1)
for { for {
intake.Read() intake.Read()
} }

View file

@ -7,9 +7,9 @@ import (
"stream_server/transcoder" "stream_server/transcoder"
"time" "time"
"path/filepath" "path/filepath"
"fmt"
) )
// spawn a transcoder instance and return its stdin pipe
func NewWriter(stream_key string) (io.WriteCloser, error) { func NewWriter(stream_key string) (io.WriteCloser, error) {
transcoder_in, err := transcoder.NewTranscoder(stream_key) transcoder_in, err := transcoder.NewTranscoder(stream_key)
if err != nil { if err != nil {
@ -18,17 +18,11 @@ func NewWriter(stream_key string) (io.WriteCloser, error) {
return transcoder_in, nil return transcoder_in, nil
} }
// check if the init.mp4 segment has been modified in between a given sleep time
// if it hasn't (stream disconnected for longer) delete it
// else a new stream has started which shouldn't be deleted
func CleanFiles(stream_key string, delay time.Duration) { func CleanFiles(stream_key string, delay time.Duration) {
time.Sleep(delay * time.Second) time.Sleep(delay * time.Second)
base_dir, _ := os.UserHomeDir() base_dir, _ := os.UserHomeDir()
stream_dir := base_dir + "/live/" + stream_key stream_dir := base_dir + "/live/" + stream_key
fileinfo, file_ok := os.Stat(stream_dir + "/init.mp4") fileinfo, _ := os.Stat(stream_dir + "/init.mp4")
if file_ok != nil {
return
}
if time.Now().Sub(fileinfo.ModTime()) > delay * time.Second { if time.Now().Sub(fileinfo.ModTime()) > delay * time.Second {
leftover_files, _ := filepath.Glob(stream_dir + "/*") leftover_files, _ := filepath.Glob(stream_dir + "/*")
for _, file := range leftover_files { for _, file := range leftover_files {
@ -37,9 +31,6 @@ func CleanFiles(stream_key string, delay time.Duration) {
} }
} }
// stream_id is in reverse order, len is multiple of 4 padded with 0
// get the string, check if corresponding folder exists, then attempt
// to spawn a transcoder instance
func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) { func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) {
stream_key := make([]byte, 0) stream_key := make([]byte, 0)
for i := len(stream_id) - 1; i >= 0; i-- { for i := len(stream_id) - 1; i >= 0; i-- {
@ -59,7 +50,6 @@ func CheckStreamID(stream_id []byte) (io.WriteCloser, string, bool) {
return nil, stream_key_string, false return nil, stream_key_string, false
} }
// checks if folder exists corresponding to the stream_key
func check_stream_key(stream_key string) (bool) { func check_stream_key(stream_key string) (bool) {
base_dir, _ := os.UserHomeDir() base_dir, _ := os.UserHomeDir()
if fileinfo, err := os.Stat(base_dir + "/live/" + stream_key); err == nil && fileinfo.IsDir() { if fileinfo, err := os.Stat(base_dir + "/live/" + stream_key); err == nil && fileinfo.IsDir() {

View file

@ -19,16 +19,9 @@ func (tunnel *Tunnel) Start() {
fmt.Println(r) fmt.Println(r)
} }
*a = true *a = true
}(&(tunnel.broken)) // force mark tunnel for deletion if any error occurs }(&(tunnel.broken))
tunnel.state = NewSRTManager(tunnel.socket) tunnel.state = NewSRTManager(tunnel.socket)
// central tunnel loop, read incoming, process and generate response
// write response if any
for { for {
// force check since no new packets after shutdown
if tunnel.state.state == 3 {
tunnel.broken = true
break
}
packet, err := tunnel.ReadPacket() packet, err := tunnel.ReadPacket()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@ -45,8 +38,6 @@ func (tunnel *Tunnel) Start() {
} }
} }
// send a shutdown command, for use when tunnel gets broken
// not ideal but works
func (tunnel *Tunnel) Shutdown() { func (tunnel *Tunnel) Shutdown() {
if tunnel.state != nil && tunnel.state.state > 1 { if tunnel.state != nil && tunnel.state.state > 1 {
packet := tunnel.state.create_basic_header() packet := tunnel.state.create_basic_header()
@ -72,6 +63,6 @@ func (tunnel *Tunnel) WritePacket(packet *Packet) {
} }
func (tunnel *Tunnel) ReadPacket() (*Packet, error) { func (tunnel *Tunnel) ReadPacket() (*Packet, error) {
packet := <- tunnel.queue // blocking read, should add timeout here packet := <- tunnel.queue
return ParsePacket(packet) return ParsePacket(packet)
} }