Skip to content

Commit

Permalink
Implement mediaMTX RTMP auth (#3240)
Browse files Browse the repository at this point in the history
* Experiment with mediaMTX auth

Instead of using the runOnReady mediaMTX config I am using this instead:
`authHTTPAddress: http://localhost:5936/live/video-to-video/start`

* Wait for ffprobe call to succeed before trying to pull the stream

* Keep auth and start calls separate

* Auth on ready hook

* Fix kick connection and improve logging

* improve log line
  • Loading branch information
mjh1 authored Nov 13, 2024
1 parent 9b44187 commit b87c7c3
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 5 deletions.
7 changes: 5 additions & 2 deletions media/rtmp2segment.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ type MediaSegmenter struct {
}

func (ms *MediaSegmenter) RunSegmentation(in string, segmentHandler SegmentHandler) {

outFilePattern := filepath.Join(ms.Workdir, randomString()+"-%d.ts")
completionSignal := make(chan bool, 1)
wg := &sync.WaitGroup{}
Expand All @@ -36,15 +35,19 @@ func (ms *MediaSegmenter) RunSegmentation(in string, segmentHandler SegmentHandl
defer wg.Done()
processSegments(segmentHandler, outFilePattern, completionSignal)
}()

ffmpeg.FfmpegSetLogLevel(ffmpeg.FFLogWarning)
ffmpeg.Transcode3(&ffmpeg.TranscodeOptionsIn{
_, err := ffmpeg.Transcode3(&ffmpeg.TranscodeOptionsIn{
Fname: in,
}, []ffmpeg.TranscodeOptions{{
Oname: outFilePattern,
AudioEncoder: ffmpeg.ComponentOptions{Name: "copy"},
VideoEncoder: ffmpeg.ComponentOptions{Name: "copy"},
Muxer: ffmpeg.ComponentOptions{Name: "segment"},
}})
if err != nil {
slog.Error("Failed to run segmentation", "in", in, "err", err)
}
completionSignal <- true
slog.Info("sent completion signal, now waiting")
wg.Wait()
Expand Down
55 changes: 54 additions & 1 deletion server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"

Expand Down Expand Up @@ -360,12 +361,40 @@ func (ls *LivepeerServer) StartLiveVideo() http.Handler {
http.Error(w, "Missing stream name", http.StatusBadRequest)
return
}
sourceID := r.FormValue("source_id")
if sourceID == "" {
http.Error(w, "Missing source_id", http.StatusBadRequest)
return
}
sourceType := r.FormValue("source_type")
if sourceType == "" {
http.Error(w, "Missing source_type", http.StatusBadRequest)
return
}

if streamName == "out-stream" {
// skip for now; we don't want to re-publish our own outputs
return
}
ctx := clog.AddVal(r.Context(), "stream", streamName)
ctx = clog.AddVal(ctx, "source_id", sourceID)
ctx = clog.AddVal(ctx, "source_type", sourceType)

err := authenticateAIStream(AuthWebhookURL, AIAuthRequest{
Stream: streamName,
})
if err != nil {
kickErr := kickInputConnection(sourceID, sourceType)
if kickErr != nil {
clog.Errorf(ctx, "failed to kick input connection: %s", kickErr.Error())
}
clog.Errorf(ctx, "Live AI auth failed: %s", err.Error())
http.Error(w, "Forbidden", http.StatusForbidden)
return
}

requestID := string(core.RandomManifestID())
ctx := clog.AddVal(r.Context(), "request_id", requestID)
ctx = clog.AddVal(ctx, "request_id", requestID)
clog.Infof(ctx, "Received live video AI request for %s", streamName)

// Kick off the RTMP pull and segmentation as soon as possible
Expand All @@ -389,3 +418,27 @@ func (ls *LivepeerServer) StartLiveVideo() http.Handler {
processAIRequest(ctx, params, req)
})
}

const mediaMTXControlPort = "9997"

func kickInputConnection(sourceID string, sourceType string) error {
var apiPath string
switch sourceType {
case "webrtcSession":
apiPath = "webrtcsessions"
case "rtmpConn":
apiPath = "rtmpconns"
default:
return fmt.Errorf("invalid sourceType: %s", sourceType)
}

resp, err := http.Post(fmt.Sprintf("http://localhost:%s/v3/%s/kick/%s", mediaMTXControlPort, apiPath, sourceID), "", nil)
if err != nil {
return err
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("kick connection failed with status code: %d body: %s", resp.StatusCode, body)
}
return nil
}
40 changes: 38 additions & 2 deletions server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -34,7 +34,7 @@ func authenticateStream(authURL *url.URL, incomingRequestURL string) (*authWebho
return nil, err
}

rbody, err := ioutil.ReadAll(resp.Body)
rbody, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("status=%d error=%s", resp.StatusCode, string(rbody))
Expand Down Expand Up @@ -95,3 +95,39 @@ func (a authWebhookResponse) areProfilesEqual(b authWebhookResponse) bool {

return string(profilesA) == string(profilesB)
}

type AIAuthRequest struct {
Stream string `json:"stream"`
// TODO not sure what params we need yet
}

func authenticateAIStream(authURL *url.URL, req AIAuthRequest) error {
if authURL == nil {
return nil
}
started := time.Now()

jsonValue, err := json.Marshal(req)
if err != nil {
return err
}

resp, err := http.Post(authURL.String(), "application/json", bytes.NewBuffer(jsonValue))
if err != nil {
return err
}

rbody, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("status=%d error=%s", resp.StatusCode, string(rbody))
}

took := time.Since(started)
glog.Infof("AI Stream authentication for authURL=%s stream=%s dur=%s", authURL, req.Stream, took)
if monitor.Enabled {
monitor.AuthWebhookFinished(took)
}

return nil
}
10 changes: 10 additions & 0 deletions server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ func TestAuthSucceeds(t *testing.T) {
require.Equal(t, "456", resp.StreamID)
}

func TestAILiveAuthSucceeds(t *testing.T) {
s, serverURL := stubAuthServer(t, http.StatusOK, `{}`)
defer s.Close()

err := authenticateAIStream(serverURL, AIAuthRequest{
Stream: "stream",
})
require.NoError(t, err)
}

func TestNoErrorWhenTranscodeAuthHeaderNotPassed(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "some.com/url", nil)
require.NoError(t, err)
Expand Down

0 comments on commit b87c7c3

Please sign in to comment.