Skip to content

Commit

Permalink
inmemory: implement continuing watch based on resourceVersion
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischdi committed Jan 17, 2025
1 parent e08e606 commit 3bac618
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 19 deletions.
3 changes: 3 additions & 0 deletions test/infrastructure/inmemory/pkg/runtime/cache/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (c *cache) beforeCreate(_ string, obj client.Object, resourceVersion *uint6
// TODO: UID
obj.SetAnnotations(appendAnnotations(obj, lastSyncTimeAnnotation, now.Format(time.RFC3339)))
obj.SetResourceVersion(fmt.Sprintf("%d", *resourceVersion))
obj.SetGeneration(1)
*resourceVersion++
}

Expand All @@ -41,6 +42,7 @@ func (c *cache) afterCreate(resourceGroup string, obj client.Object) {
func (c *cache) beforeUpdate(_ string, oldObj, newObj client.Object, resourceVersion *uint64) {
newObj.SetCreationTimestamp(oldObj.GetCreationTimestamp())
newObj.SetResourceVersion(oldObj.GetResourceVersion())
newObj.SetGeneration(oldObj.GetGeneration())
// TODO: UID
newObj.SetAnnotations(appendAnnotations(newObj, lastSyncTimeAnnotation, oldObj.GetAnnotations()[lastSyncTimeAnnotation]))
if !oldObj.GetDeletionTimestamp().IsZero() {
Expand All @@ -51,6 +53,7 @@ func (c *cache) beforeUpdate(_ string, oldObj, newObj client.Object, resourceVer
newObj.SetAnnotations(appendAnnotations(newObj, lastSyncTimeAnnotation, now.Format(time.RFC3339)))

newObj.SetResourceVersion(fmt.Sprintf("%d", *resourceVersion))
newObj.SetGeneration(oldObj.GetGeneration() + 1)
*resourceVersion++
}
}
Expand Down
40 changes: 26 additions & 14 deletions test/infrastructure/inmemory/pkg/server/api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"

inmemoryruntime "sigs.k8s.io/cluster-api/test/infrastructure/inmemory/pkg/runtime"
inmemoryclient "sigs.k8s.io/cluster-api/test/infrastructure/inmemory/pkg/runtime/client"
inmemoryportforward "sigs.k8s.io/cluster-api/test/infrastructure/inmemory/pkg/server/api/portforward"
)

Expand Down Expand Up @@ -315,6 +316,25 @@ func (h *apiServerHandler) apiV1List(req *restful.Request, resp *restful.Respons
return
}

h.log.V(3).Info(fmt.Sprintf("Serving List for %v", req.Request.URL), "resourceGroup", resourceGroup)

list, err := h.apiV1list(ctx, req, *gvk, inmemoryClient)
if err != nil {
if status, ok := err.(apierrors.APIStatus); ok || errors.As(err, &status) {
_ = resp.WriteHeaderAndEntity(int(status.Status().Code), status)
return
}
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
return
}

if err := resp.WriteEntity(list); err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
return
}
}

func (h *apiServerHandler) apiV1list(ctx context.Context, req *restful.Request, gvk schema.GroupVersionKind, inmemoryClient inmemoryclient.Client) (*unstructured.UnstructuredList, error) {
// Reads and returns the requested data.
list := &unstructured.UnstructuredList{}
list.SetAPIVersion(gvk.GroupVersion().String())
Expand All @@ -328,33 +348,23 @@ func (h *apiServerHandler) apiV1List(req *restful.Request, resp *restful.Respons
// TODO: The only field Selector which works is for `spec.nodeName` on pods.
fieldSelector, err := fields.ParseSelector(req.QueryParameter("fieldSelector"))
if err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
return
return nil, err
}
if fieldSelector != nil {
listOpts = append(listOpts, client.MatchingFieldsSelector{Selector: fieldSelector})
}

labelSelector, err := labels.Parse(req.QueryParameter("labelSelector"))
if err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
return
return nil, err
}
if labelSelector != nil {
listOpts = append(listOpts, client.MatchingLabelsSelector{Selector: labelSelector})
}
if err := inmemoryClient.List(ctx, list, listOpts...); err != nil {
if status, ok := err.(apierrors.APIStatus); ok || errors.As(err, &status) {
_ = resp.WriteHeaderAndEntity(int(status.Status().Code), status)
return
}
_ = resp.WriteHeaderAndEntity(http.StatusInternalServerError, err.Error())
return
}
if err := resp.WriteEntity(list); err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
return
return nil, err
}
return list, nil
}

func (h *apiServerHandler) apiV1Watch(req *restful.Request, resp *restful.Response) {
Expand All @@ -372,6 +382,8 @@ func (h *apiServerHandler) apiV1Watch(req *restful.Request, resp *restful.Respon
return
}

h.log.V(3).Info(fmt.Sprintf("Serving Watch for %v", req.Request.URL), "resourceGroup", resourceGroup)

// If the request is a Watch handle it using watchForResource.
err = h.watchForResource(req, resp, resourceGroup, *gvk)
if err != nil {
Expand Down
85 changes: 80 additions & 5 deletions test/infrastructure/inmemory/pkg/server/api/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
"context"
"fmt"
"net/http"
"sort"
"strconv"
"time"

"github.com/emicklei/go-restful/v3"
"github.com/pkg/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/watch"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand All @@ -33,7 +34,7 @@ import (
// Event records a lifecycle event for a Kubernetes object.
type Event struct {
Type watch.EventType `json:"type,omitempty"`
Object runtime.Object `json:"object,omitempty"`
Object client.Object `json:"object,omitempty"`
}

// WatchEventDispatcher dispatches events for a single resourceGroup.
Expand Down Expand Up @@ -89,12 +90,12 @@ func (m *WatchEventDispatcher) OnGeneric(resourceGroup string, o client.Object)
func (h *apiServerHandler) watchForResource(req *restful.Request, resp *restful.Response, resourceGroup string, gvk schema.GroupVersionKind) (reterr error) {
ctx := req.Request.Context()
queryTimeout := req.QueryParameter("timeoutSeconds")
resourceVersion := req.QueryParameter("resourceVersion")
c := h.manager.GetCache()
i, err := c.GetInformerForKind(ctx, gvk)
if err != nil {
return err
}
h.log.Info(fmt.Sprintf("Serving Watch for %v", req.Request.URL))
// With an unbuffered event channel RemoveEventHandler could be blocked because it requires a lock on the informer.
// When Run stops reading from the channel the informer could be blocked with an unbuffered chanel and then RemoveEventHandler never goes through.
// 1000 is used to avoid deadlocks in clusters with a higher number of Machines/Nodes.
Expand All @@ -108,6 +109,50 @@ func (h *apiServerHandler) watchForResource(req *restful.Request, resp *restful.
return err
}

initialEvents := []Event{}
if resourceVersion != "" {
parsedResourceVersion, err := strconv.ParseUint(resourceVersion, 10, 64)
if err != nil {
return err
}

// Get at client to the resource group and list all relevant objects.
inmemoryClient := h.manager.GetResourceGroup(resourceGroup).GetClient()
list, err := h.apiV1list(ctx, req, gvk, inmemoryClient)
if err != nil {
return err
}

// Sort the objects by resourceVersion to later write the events in order.
sort.SliceStable(list.Items, func(i, j int) bool {
a, err := strconv.ParseUint(list.Items[i].GetResourceVersion(), 10, 64)
if err != nil {
panic(err)
}
b, err := strconv.ParseUint(list.Items[j].GetResourceVersion(), 10, 64)
if err != nil {
panic(err)
}
return a < b
})

// Loop over all items and fill the list of events which were missed since the last watch.
for _, obj := range list.Items {
objResourceVersion, err := strconv.ParseUint(obj.GetResourceVersion(), 10, 64)
if err != nil {
return err
}
if objResourceVersion <= parsedResourceVersion {
continue
}
eventType := watch.Modified
if obj.GetGeneration() == 0 {
eventType = watch.Added
}
initialEvents = append(initialEvents, Event{Type: eventType, Object: &obj})
}
}

// Defer cleanup which removes the event handler and ensures the channel is empty of events.
defer func() {
// Doing this to ensure the channel is empty.
Expand All @@ -124,11 +169,11 @@ func (h *apiServerHandler) watchForResource(req *restful.Request, resp *restful.
// Note: After we removed the handler, no new events will be written to the events channel.
}()

return watcher.Run(ctx, queryTimeout, resp)
return watcher.Run(ctx, queryTimeout, initialEvents, resp)
}

// Run serves a series of encoded events via HTTP with Transfer-Encoding: chunked.
func (m *WatchEventDispatcher) Run(ctx context.Context, timeout string, w http.ResponseWriter) error {
func (m *WatchEventDispatcher) Run(ctx context.Context, timeout string, initialEvents []Event, w http.ResponseWriter) error {
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("can't start Watch: can't get http.Flusher")
Expand All @@ -139,6 +184,12 @@ func (m *WatchEventDispatcher) Run(ctx context.Context, timeout string, w http.R
}
w.Header().Set("Transfer-Encoding", "chunked")
w.WriteHeader(http.StatusOK)
// Write all object events which happened since the last resourceVersion.
for _, event := range initialEvents {
if err := resp.WriteEntity(event); err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
}
}
flusher.Flush()

timeoutTimer, seconds, err := setTimer(timeout)
Expand All @@ -149,6 +200,18 @@ func (m *WatchEventDispatcher) Run(ctx context.Context, timeout string, w http.R
ctx, cancel := context.WithTimeout(ctx, seconds)
defer cancel()
defer timeoutTimer.Stop()

// Determine the highest written resourceVersion so we can filter out duplicated events from the channel.
minResourceVersion := uint64(0)
if len(initialEvents) > 0 {
minResourceVersion, err = strconv.ParseUint(initialEvents[len(initialEvents)-1].Object.GetResourceVersion(), 10, 64)
if err != nil {
return err
}
minResourceVersion++
}

var objResourceVersion uint64
for {
select {
case <-ctx.Done():
Expand All @@ -160,6 +223,18 @@ func (m *WatchEventDispatcher) Run(ctx context.Context, timeout string, w http.R
// End of results.
return nil
}

// Parse and check if the object has a higher resource version than we allow.
objResourceVersion, err = strconv.ParseUint(event.Object.GetResourceVersion(), 10, 64)
if err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
}

// Skip objects which were already written.
if objResourceVersion < minResourceVersion {
continue
}

if err := resp.WriteEntity(event); err != nil {
_ = resp.WriteErrorString(http.StatusInternalServerError, err.Error())
}
Expand Down

0 comments on commit 3bac618

Please sign in to comment.