Skip to content

Commit

Permalink
Merge branch 'main' into adityahegde/request-project-access
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaHegde committed Jul 12, 2024
2 parents 55d0151 + 88f0168 commit 71a438a
Show file tree
Hide file tree
Showing 41 changed files with 690 additions and 287 deletions.
2 changes: 1 addition & 1 deletion admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ type InsertVirtualFileOptions struct {

type Asset struct {
ID string
OrganizationID string `db:"org_id"`
OrganizationID *string `db:"org_id"`
Path string `db:"path"`
OwnerID string `db:"owner_id"`
CreatedOn time.Time `db:"created_on"`
Expand Down
9 changes: 1 addition & 8 deletions admin/database/postgres/migrations/0034.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1 @@
CREATE TABLE project_access_requests (
id UUID NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY,
user_id UUID REFERENCES users (id) ON DELETE CASCADE,
project_id UUID REFERENCES projects (id) ON DELETE CASCADE,
created_on TIMESTAMPTZ NOT NULL DEFAULT now()
);

CREATE UNIQUE INDEX project_access_requests_user_id_project_idx ON project_access_requests (user_id, project_id);
ALTER TABLE assets ALTER COLUMN org_id DROP NOT NULL;
8 changes: 8 additions & 0 deletions admin/database/postgres/migrations/0035.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CREATE TABLE project_access_requests (
id UUID NOT NULL DEFAULT uuid_generate_v4() PRIMARY KEY,
user_id UUID REFERENCES users (id) ON DELETE CASCADE,
project_id UUID REFERENCES projects (id) ON DELETE CASCADE,
created_on TIMESTAMPTZ NOT NULL DEFAULT now()
);

CREATE UNIQUE INDEX project_access_requests_user_id_project_idx ON project_access_requests (user_id, project_id);
4 changes: 2 additions & 2 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ func (c *connection) InsertUsergroup(ctx context.Context, opts *database.InsertU

func (c *connection) UpdateUsergroupName(ctx context.Context, name, groupID string) (*database.Usergroup, error) {
res := &database.Usergroup{}
err := c.getDB(ctx).QueryRowxContext(ctx, "UPDATE usergroups SET name=$1 WHERE id=$2 RETURNING *", name, groupID).StructScan(res)
err := c.getDB(ctx).QueryRowxContext(ctx, "UPDATE usergroups SET name=$1, updated_on=now() WHERE id=$2 RETURNING *", name, groupID).StructScan(res)
if err != nil {
return nil, parseErr("usergroup", err)
}
Expand All @@ -734,7 +734,7 @@ func (c *connection) UpdateUsergroupName(ctx context.Context, name, groupID stri

func (c *connection) UpdateUsergroupDescription(ctx context.Context, description, groupID string) (*database.Usergroup, error) {
res := &database.Usergroup{}
err := c.getDB(ctx).QueryRowxContext(ctx, "UPDATE usergroups SET description=$1 WHERE id=$2 RETURNING *", description, groupID).StructScan(res)
err := c.getDB(ctx).QueryRowxContext(ctx, "UPDATE usergroups SET description=$1, updated_on=now() WHERE id=$2 RETURNING *", description, groupID).StructScan(res)
if err != nil {
return nil, parseErr("usergroup", err)
}
Expand Down
4 changes: 2 additions & 2 deletions admin/deployments.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func (s *Service) HibernateDeployments(ctx context.Context) error {

s.Logger.Info("hibernate: deleting deployment", zap.String("project_id", proj.ID), zap.String("deployment_id", depl.ID))

err = s.teardownDeployment(ctx, depl)
err = s.TeardownDeployment(ctx, depl)
if err != nil {
s.Logger.Error("hibernate: teardown deployment error", zap.String("project_id", proj.ID), zap.String("deployment_id", depl.ID), zap.Error(err), observability.ZapCtx(ctx))
continue
Expand Down Expand Up @@ -366,7 +366,7 @@ func (s *Service) OpenRuntimeClient(host, audience string) (*client.Client, erro
return rt, nil
}

func (s *Service) teardownDeployment(ctx context.Context, depl *database.Deployment) error {
func (s *Service) TeardownDeployment(ctx context.Context, depl *database.Deployment) error {
// Delete the deployment
err := s.DB.DeleteDeployment(ctx, depl.ID)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions admin/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (s *Service) CreateProject(ctx context.Context, org *database.Organization,
Annotations: proj.Annotations,
})
if err != nil {
err2 := s.teardownDeployment(ctx, depl)
err2 := s.TeardownDeployment(ctx, depl)
err3 := s.DB.DeleteProject(ctx, proj.ID)
return nil, multierr.Combine(err, err2, err3)
}
Expand All @@ -119,7 +119,7 @@ func (s *Service) TeardownProject(ctx context.Context, p *database.Project) erro
}

for _, d := range ds {
err := s.teardownDeployment(ctx, d)
err := s.TeardownDeployment(ctx, d)
if err != nil {
return err
}
Expand Down Expand Up @@ -285,13 +285,13 @@ func (s *Service) TriggerRedeploy(ctx context.Context, proj *database.Project, p
Annotations: proj.Annotations,
})
if err != nil {
err2 := s.teardownDeployment(ctx, newDepl)
err2 := s.TeardownDeployment(ctx, newDepl)
return nil, multierr.Combine(err, err2)
}

// Delete old prod deployment if exists
if prevDepl != nil {
err = s.teardownDeployment(ctx, prevDepl)
err = s.TeardownDeployment(ctx, prevDepl)
if err != nil {
s.Logger.Error("trigger redeploy: could not teardown old deployment", zap.String("deployment_id", prevDepl.ID), zap.Error(err), observability.ZapCtx(ctx))
}
Expand Down
60 changes: 51 additions & 9 deletions admin/provisioner/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package provisioner
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"text/template"
Expand Down Expand Up @@ -44,9 +47,10 @@ type KubernetesTemplatePaths struct {
}

type KubernetesProvisioner struct {
Spec *KubernetesSpec
clientset *kubernetes.Clientset
templates *template.Template
Spec *KubernetesSpec
clientset *kubernetes.Clientset
templates *template.Template
templatesChecksum string
}

type TemplateData struct {
Expand Down Expand Up @@ -94,18 +98,33 @@ func NewKubernetes(spec json.RawMessage) (*KubernetesProvisioner, error) {
delete(funcMap, "env")
delete(funcMap, "expandenv")

// Parse the template definitions
templates := template.Must(template.New("").Funcs(funcMap).ParseFiles(
// Define template files
templateFiles := []string{
ksp.TemplatePaths.HTTPIngress,
ksp.TemplatePaths.GRPCIngress,
ksp.TemplatePaths.Service,
ksp.TemplatePaths.StatefulSet,
))
}

// Parse the template definitions
templates := template.Must(template.New("").Funcs(funcMap).ParseFiles(templateFiles...))

// Calculate the combined sha256 sum of all the template files
h := sha256.New()
for _, f := range templateFiles {
d, err := os.ReadFile(f)
if err != nil {
return nil, err
}
h.Write(d)
}
templatesChecksum := hex.EncodeToString(h.Sum(nil))

return &KubernetesProvisioner{
Spec: ksp,
clientset: clientset,
templates: templates,
Spec: ksp,
clientset: clientset,
templates: templates,
templatesChecksum: templatesChecksum,
}, nil
}

Expand Down Expand Up @@ -165,6 +184,7 @@ func (p *KubernetesProvisioner) Provision(ctx context.Context, opts *ProvisionOp

// Create statefulset
sts.ObjectMeta.Name = names.StatefulSet
sts.ObjectMeta.Annotations["checksum/templates"] = p.templatesChecksum
p.addCommonLabels(opts.ProvisionID, sts.ObjectMeta.Labels)
_, err = p.clientset.AppsV1().StatefulSets(p.Spec.Namespace).Create(ctx, sts, metav1.CreateOptions{})
if err != nil {
Expand Down Expand Up @@ -308,6 +328,28 @@ func (p *KubernetesProvisioner) CheckCapacity(ctx context.Context) error {
return nil
}

func (p *KubernetesProvisioner) ValidateConfig(ctx context.Context, provisionID string) (bool, error) {
// Get Kubernetes resource names
names := p.getResourceNames(provisionID)

// Get the statefulset
sts, err := p.clientset.AppsV1().StatefulSets(p.Spec.Namespace).Get(ctx, names.StatefulSet, metav1.GetOptions{})
if err != nil {
return false, err
}

// Compare the provisioned templates checksum with the current one
if sts.ObjectMeta.Annotations["checksum/templates"] != p.templatesChecksum {
return false, nil
}

return true, nil
}

func (p *KubernetesProvisioner) Type() string {
return "kubernetes"
}

func (p *KubernetesProvisioner) getResourceNames(provisionID string) ResourceNames {
return ResourceNames{
StatefulSet: fmt.Sprintf("runtime-%s", provisionID),
Expand Down
2 changes: 2 additions & 0 deletions admin/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type Provisioner interface {
AwaitReady(ctx context.Context, provisionID string) error
Update(ctx context.Context, provisionID string, newVersion string) error
CheckCapacity(ctx context.Context) error
ValidateConfig(ctx context.Context, provisionID string) (bool, error)
Type() string
}

type ProvisionOptions struct {
Expand Down
9 changes: 9 additions & 0 deletions admin/provisioner/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,12 @@ func (p *StaticProvisioner) Update(ctx context.Context, provisionID, newVersion
// No-op
return nil
}

func (p *StaticProvisioner) ValidateConfig(ctx context.Context, provisionID string) (bool, error) {
// No-op
return true, nil
}

func (p *StaticProvisioner) Type() string {
return "static"
}
47 changes: 24 additions & 23 deletions admin/server/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ func (s *Server) GetProject(ctx context.Context, req *adminv1.GetProjectRequest)
Rule: &runtimev1.SecurityRule_Access{
Access: &runtimev1.SecurityRuleAccess{
Condition: fmt.Sprintf(
"NOT ('{{.self.kind}}'='%s' OR '{{.self.kind}}'='%s' AND '{{ .self.name }}'=%s)",
"NOT ('{{.self.kind}}'='%s' OR '{{.self.kind}}'='%s' AND '{{ lower .self.name }}'=%s)",
runtime.ResourceKindTheme,
runtime.ResourceKindMetricsView,
duckdbsql.EscapeStringValue(mdl.MetricsView),
duckdbsql.EscapeStringValue(strings.ToLower(mdl.MetricsView)),
),
Allow: false,
},
Expand Down Expand Up @@ -331,20 +331,14 @@ func (s *Server) CreateProject(ctx context.Context, req *adminv1.CreateProjectRe
attribute.String("args.archive_asset_id", req.ArchiveAssetId),
)

// Check the request is made by a user
claims := auth.GetClaims(ctx)
if claims.OwnerType() != auth.OwnerTypeUser {
return nil, status.Error(codes.Unauthenticated, "not authenticated as a user")
}
userID := claims.OwnerID()

// Find parent org
org, err := s.admin.DB.FindOrganizationByName(ctx, req.OrganizationName)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}

// Check permissions
claims := auth.GetClaims(ctx)
if !claims.OrganizationPermissions(ctx, org.ID).CreateProjects {
return nil, status.Error(codes.PermissionDenied, "does not have permission to create projects")
}
Expand Down Expand Up @@ -387,12 +381,19 @@ func (s *Server) CreateProject(ctx context.Context, req *adminv1.CreateProjectRe
req.ProdVersion = "latest"
}

// Capture creating user (can be nil if created with a service token)
var userID *string
if claims.OwnerType() == auth.OwnerTypeUser {
tmp := claims.OwnerID()
userID = &tmp
}

opts := &database.InsertProjectOptions{
OrganizationID: org.ID,
Name: req.Name,
Description: req.Description,
Public: req.Public,
CreatedByUserID: &userID,
CreatedByUserID: userID,
Provisioner: req.Provisioner,
ProdVersion: req.ProdVersion,
ProdOLAPDriver: req.ProdOlapDriver,
Expand All @@ -403,8 +404,13 @@ func (s *Server) CreateProject(ctx context.Context, req *adminv1.CreateProjectRe
}

if req.GithubUrl != "" {
// Github projects must be configured by a user so we can ensure that they're allowed to access the repo.
if userID == nil {
return nil, status.Error(codes.Unauthenticated, "not authenticated as a user")
}

// Check Github app is installed and caller has access on the repo
installationID, err := s.getAndCheckGithubInstallationID(ctx, req.GithubUrl, userID)
installationID, err := s.getAndCheckGithubInstallationID(ctx, req.GithubUrl, *userID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -495,18 +501,13 @@ func (s *Server) UpdateProject(ctx context.Context, req *adminv1.UpdateProjectRe
observability.AddRequestAttributes(ctx, attribute.String("args.new_name", *req.NewName))
}

// Check the request is made by a user
claims := auth.GetClaims(ctx)
if claims.OwnerType() != auth.OwnerTypeUser {
return nil, status.Error(codes.Unauthenticated, "not authenticated")
}

// Find project
proj, err := s.admin.DB.FindProjectByName(ctx, req.OrganizationName, req.Name)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}

claims := auth.GetClaims(ctx)
if !claims.ProjectPermissions(ctx, proj.OrganizationID, proj.ID).ManageProject {
return nil, status.Error(codes.PermissionDenied, "does not have permission to delete project")
}
Expand All @@ -519,6 +520,11 @@ func (s *Server) UpdateProject(ctx context.Context, req *adminv1.UpdateProjectRe
if req.GithubUrl != nil {
// If changing the Github URL, check github app is installed and caller has access on the repo
if safeStr(proj.GithubURL) != *req.GithubUrl {
// Github projects must be configured by a user so we can ensure that they're allowed to access the repo.
if claims.OwnerType() != auth.OwnerTypeUser {
return nil, status.Error(codes.Unauthenticated, "not authenticated as a user")
}

_, err = s.getAndCheckGithubInstallationID(ctx, *req.GithubUrl, claims.OwnerID())
if err != nil {
return nil, err
Expand Down Expand Up @@ -598,12 +604,7 @@ func (s *Server) UpdateProjectVariables(ctx context.Context, req *adminv1.Update
return nil, status.Error(codes.InvalidArgument, err.Error())
}

// Check the request is made by a user
claims := auth.GetClaims(ctx)
if claims.OwnerType() != auth.OwnerTypeUser {
return nil, status.Error(codes.Unauthenticated, "not authenticated")
}

if !claims.ProjectPermissions(ctx, proj.OrganizationID, proj.ID).ManageProject {
return nil, status.Error(codes.PermissionDenied, "does not have permission to update project variables")
}
Expand Down Expand Up @@ -1476,7 +1477,7 @@ func (s *Server) hasAssetUsagePermission(ctx context.Context, id, orgID, ownerID
if err != nil {
return false
}
return asset.OrganizationID == orgID && asset.OwnerID == ownerID
return asset.OrganizationID != nil && *asset.OrganizationID == orgID && asset.OwnerID == ownerID
}

func deploymentToDTO(d *database.Deployment) *adminv1.Deployment {
Expand Down
15 changes: 14 additions & 1 deletion admin/worker/run_autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,21 @@ func (w *Worker) allRecommendations(ctx context.Context) ([]metrics.AutoscalerSl
return recs, true, nil
}

// shouldScale determines whether scaling operations should be initiated based on the comparison of
// the current number of slots (originSlots) and the recommended number of slots (recommendSlots).
func shouldScale(originSlots, recommendSlots int) bool {
// Temproray disable scale DOWN - Tony
if recommendSlots <= originSlots {
return false
}

lowerBound := float64(originSlots) * (1 - scaleThreshold)
upperBound := float64(originSlots) * (1 + scaleThreshold)
return float64(recommendSlots) < lowerBound || float64(recommendSlots) > upperBound
if float64(recommendSlots) >= lowerBound && float64(recommendSlots) <= upperBound {
return false
}

// TODO: Skip scaling for manually assigned slots

return true
}
Loading

0 comments on commit 71a438a

Please sign in to comment.