Skip to content

Commit

Permalink
[YUNIKORN-2967] Cleanup REST response headers (#994)
Browse files Browse the repository at this point in the history
Only respond with the allowed methods for the request, not with a
general all allowed set. OPTIONS is supported via the generic config.
Add a test to make sure a change in router does not break that.

Remove the Access-Control-Allow-Credentials as recommended in the RFC.
We also do not use cookies or authentication so not relevant to set.

Closes: #994

Signed-off-by: Craig Condit <[email protected]>
  • Loading branch information
wilfred-s authored and craigcondit committed Nov 14, 2024
1 parent d7d1408 commit ac32595
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 42 deletions.
65 changes: 34 additions & 31 deletions pkg/webservice/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func redirectDebug(w http.ResponseWriter, r *http.Request) {
}

func getStackInfo(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
var stack = func() []byte {
buf := make([]byte, 1024)
for {
Expand All @@ -145,7 +145,7 @@ func getStackInfo(w http.ResponseWriter, r *http.Request) {
}

func getClusterInfo(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

lists := schedulerContext.Load().GetPartitionMapClone()
clustersInfo := getClusterDAO(lists)
Expand All @@ -167,7 +167,7 @@ func validateQueue(queuePath string) error {
}

func validateConf(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
requestBytes, err := io.ReadAll(r.Body)
if err == nil {
_, err = configs.LoadSchedulerConfigFromByteArray(requestBytes)
Expand All @@ -184,11 +184,14 @@ func validateConf(w http.ResponseWriter, r *http.Request) {
}
}

func writeHeaders(w http.ResponseWriter) {
func writeHeaders(w http.ResponseWriter, method string) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "GET,POST,HEAD,OPTIONS")
methods := "GET, OPTIONS"
if method == http.MethodPost {
methods = "OPTIONS, POST"
}
w.Header().Set("Access-Control-Allow-Methods", methods)
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With,Content-Type,Accept,Origin")
}

Expand Down Expand Up @@ -233,7 +236,7 @@ func getClusterUtilJSON(partition *scheduler.PartitionContext) []*dao.ClusterUti
}
utils = append(utils, utilization)
}
} else if !getResource {
} else {
utilization := &dao.ClusterUtilDAOInfo{
ResourceType: "N/A",
Total: int64(-1),
Expand Down Expand Up @@ -446,7 +449,7 @@ func getNodesDAO(entries []*objects.Node) []*dao.NodeDAOInfo {
// Only check the default partition
// Deprecated - To be removed in next major release. Replaced with getNodesUtilisations
func getNodeUtilisation(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
partitionContext := schedulerContext.Load().GetPartitionWithoutClusterID(configs.DefaultPartition)
if partitionContext == nil {
buildJSONErrorResponse(w, PartitionDoesNotExists, http.StatusInternalServerError)
Expand Down Expand Up @@ -510,7 +513,7 @@ func getNodesUtilJSON(partition *scheduler.PartitionContext, name string) *dao.N
}

func getNodeUtilisations(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
var result []*dao.PartitionNodesUtilDAOInfo
for _, part := range schedulerContext.Load().GetPartitionMapClone() {
result = append(result, getPartitionNodesUtilJSON(part))
Expand Down Expand Up @@ -583,7 +586,7 @@ func getPartitionNodesUtilJSON(partition *scheduler.PartitionContext) *dao.Parti
}

func getApplicationHistory(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

// There is nothing to return but we did not really encounter a problem
if imHistory == nil {
Expand All @@ -600,7 +603,7 @@ func getApplicationHistory(w http.ResponseWriter, r *http.Request) {
}

func getContainerHistory(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

// There is nothing to return but we did not really encounter a problem
if imHistory == nil {
Expand All @@ -617,7 +620,7 @@ func getContainerHistory(w http.ResponseWriter, r *http.Request) {
}

func getClusterConfig(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

var marshalledConf []byte
var err error
Expand Down Expand Up @@ -653,7 +656,7 @@ func getClusterConfigDAO() *dao.ConfigDAOInfo {
}

func checkHealthStatus(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)

// Fetch last healthCheck result
result := schedulerContext.Load().GetLastHealthCheckResult()
Expand All @@ -675,8 +678,8 @@ func checkHealthStatus(w http.ResponseWriter, r *http.Request) {
}
}

func getPartitions(w http.ResponseWriter, _ *http.Request) {
writeHeaders(w)
func getPartitions(w http.ResponseWriter, r *http.Request) {
writeHeaders(w, r.Method)

lists := schedulerContext.Load().GetPartitionMapClone()
partitionsInfo := getPartitionInfoDAO(lists)
Expand All @@ -686,7 +689,7 @@ func getPartitions(w http.ResponseWriter, _ *http.Request) {
}

func getPartitionQueues(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -707,7 +710,7 @@ func getPartitionQueues(w http.ResponseWriter, r *http.Request) {
}

func getPartitionQueue(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -742,7 +745,7 @@ func getPartitionQueue(w http.ResponseWriter, r *http.Request) {
}

func getPartitionNodes(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -761,7 +764,7 @@ func getPartitionNodes(w http.ResponseWriter, r *http.Request) {
}

func getPartitionNode(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -786,7 +789,7 @@ func getPartitionNode(w http.ResponseWriter, r *http.Request) {
}

func getQueueApplications(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -827,7 +830,7 @@ func getQueueApplications(w http.ResponseWriter, r *http.Request) {
}

func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -876,7 +879,7 @@ func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) {
}

func getApplication(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -924,7 +927,7 @@ func getApplication(w http.ResponseWriter, r *http.Request) {
}

func getPartitionRules(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand All @@ -943,7 +946,7 @@ func getPartitionRules(w http.ResponseWriter, r *http.Request) {
}

func getQueueApplicationsByState(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -1180,8 +1183,8 @@ func getMetrics(w http.ResponseWriter, r *http.Request) {
promhttp.Handler().ServeHTTP(w, r)
}

func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) {
writeHeaders(w)
func getUsersResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w, r.Method)
userManager := ugm.GetUserManager()
trackers := userManager.GetUserTrackers()
result := make([]*dao.UserResourceUsageDAOInfo, len(trackers))
Expand All @@ -1194,7 +1197,7 @@ func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) {
}

func getUserResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -1222,7 +1225,7 @@ func getUserResourceUsage(w http.ResponseWriter, r *http.Request) {
}

func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
userManager := ugm.GetUserManager()
trackers := userManager.GetGroupTrackers()
result := make([]*dao.GroupResourceUsageDAOInfo, len(trackers))
Expand All @@ -1235,7 +1238,7 @@ func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) {
}

func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName, http.StatusBadRequest)
Expand Down Expand Up @@ -1263,7 +1266,7 @@ func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
}

func getEvents(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
eventSystem := events.GetEventSystem()
if !eventSystem.IsEventTrackingEnabled() {
buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError)
Expand Down Expand Up @@ -1311,7 +1314,7 @@ func getEvents(w http.ResponseWriter, r *http.Request) {
}

func getStream(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
eventSystem := events.GetEventSystem()
if !eventSystem.IsEventTrackingEnabled() {
buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError)
Expand Down
20 changes: 12 additions & 8 deletions pkg/webservice/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,16 +1285,18 @@ func TestGetPartitionQueueHandler(t *testing.T) {
func TestGetClusterInfo(t *testing.T) {
schedulerContext.Store(&scheduler.ClusterContext{})
resp := &MockResponseWriter{}
getClusterInfo(resp, nil)
req, err := http.NewRequest("GET", "/ws/v1/clusters", strings.NewReader(""))
assert.NilError(t, err, "error while creating http request")
getClusterInfo(resp, req)
var data []*dao.ClusterDAOInfo
err := json.Unmarshal(resp.outputBytes, &data)
err = json.Unmarshal(resp.outputBytes, &data)
assert.NilError(t, err)
assert.Equal(t, 0, len(data))

setup(t, configTwoLevelQueues, 2)

resp = &MockResponseWriter{}
getClusterInfo(resp, nil)
getClusterInfo(resp, req)
err = json.Unmarshal(resp.outputBytes, &data)
assert.NilError(t, err)
assert.Equal(t, 2, len(data))
Expand Down Expand Up @@ -1412,11 +1414,11 @@ func TestGetPartitionNode(t *testing.T) {
_, allocCreated, err := partition.UpdateAllocation(alloc1)
assert.NilError(t, err, "add alloc-1 should not have failed")
assert.Check(t, allocCreated)
falloc1 := newForeignAlloc("foreign-1", "", node1ID, resAlloc1, siCommon.AllocTypeDefault, 0)
falloc1 := newForeignAlloc("foreign-1", node1ID, resAlloc1, siCommon.AllocTypeDefault, 0)
_, allocCreated, err = partition.UpdateAllocation(falloc1)
assert.NilError(t, err, "add falloc-1 should not have failed")
assert.Check(t, allocCreated)
falloc2 := newForeignAlloc("foreign-2", "", node1ID, resAlloc2, siCommon.AllocTypeStatic, 123)
falloc2 := newForeignAlloc("foreign-2", node1ID, resAlloc2, siCommon.AllocTypeStatic, 123)
_, allocCreated, err = partition.UpdateAllocation(falloc2)
assert.NilError(t, err, "add falloc-2 should not have failed")
assert.Check(t, allocCreated)
Expand Down Expand Up @@ -1746,6 +1748,7 @@ func checkGetQueueAppByState(t *testing.T, partition, queue, state, status strin
url = fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", partition, queue, state, status)
}
req, err := http.NewRequest("GET", url, strings.NewReader(""))
assert.NilError(t, err, "unexpected error creating request")
req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{
httprouter.Param{Key: "partition", Value: partition},
httprouter.Param{Key: "queue", Value: queue},
Expand Down Expand Up @@ -1780,6 +1783,7 @@ func checkGetQueueAppByIllegalStateOrStatus(t *testing.T, partition, queue, stat
url = fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", partition, queue, state, status)
}
req, err := http.NewRequest("GET", url, strings.NewReader(""))
assert.NilError(t, err, "unexpected error creating request")
req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{
httprouter.Param{Key: "partition", Value: partition},
httprouter.Param{Key: "queue", Value: queue},
Expand Down Expand Up @@ -2115,9 +2119,9 @@ func TestFullStateDumpPath(t *testing.T) {
prepareSchedulerContext(t)

partitionContext := schedulerContext.Load().GetPartitionMapClone()
context := partitionContext[normalizedPartitionName]
ctx := partitionContext[normalizedPartitionName]
app := newApplication("appID", normalizedPartitionName, "root.default", rmID, security.UserGroup{})
err := context.AddApplication(app)
err := ctx.AddApplication(app)
assert.NilError(t, err, "failed to add Application to partition")

imHistory = history.NewInternalMetricsHistory(5)
Expand Down Expand Up @@ -3053,7 +3057,7 @@ func newAlloc(allocationKey string, appID string, nodeID string, resAlloc *resou
})
}

func newForeignAlloc(allocationKey string, appID string, nodeID string, resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation {
func newForeignAlloc(allocationKey string, nodeID string, resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation {
return objects.NewAllocationFromSI(&si.Allocation{
AllocationKey: allocationKey,
NodeID: nodeID,
Expand Down
2 changes: 1 addition & 1 deletion pkg/webservice/state_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type AggregatedStateInfo struct {
}

func getFullStateDump(w http.ResponseWriter, r *http.Request) {
writeHeaders(w)
writeHeaders(w, r.Method)
if err := doStateDump(w); err != nil {
buildJSONErrorResponse(w, err.Error(), http.StatusInternalServerError)
}
Expand Down
Loading

0 comments on commit ac32595

Please sign in to comment.