Skip to content

Commit

Permalink
Merge pull request #6353 from jhawk28/pull-improvements
Browse files Browse the repository at this point in the history
Improve pull for transit/bootstrap
  • Loading branch information
vishalchangrani authored Aug 28, 2024
2 parents fc56066 + b89f92d commit bada8ca
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 21 deletions.
46 changes: 36 additions & 10 deletions cmd/bootstrap/cmd/pull.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
package cmd

import (
"bytes"
"context"
"fmt"
"path/filepath"
"strings"
"sync"
"time"

"github.com/spf13/cobra"
"golang.org/x/sync/semaphore"

"github.com/onflow/flow-go/cmd"
"github.com/onflow/flow-go/cmd/bootstrap/gcs"
"github.com/onflow/flow-go/cmd/bootstrap/utils"
)

var (
flagNetwork string
flagBucketName string
flagNetwork string
flagBucketName string
flagConcurrency int64
)

// pullCmd represents a command to pull partner node details from the google
Expand All @@ -37,6 +42,7 @@ func addPullCmdFlags() {
cmd.MarkFlagRequired(pullCmd, "network")

pullCmd.Flags().StringVar(&flagBucketName, "bucket", "flow-genesis-bootstrap", "google bucket name")
pullCmd.Flags().Int64Var(&flagConcurrency, "concurrency", 2, "concurrency limit")
}

// pull partner node info from google bucket
Expand All @@ -62,15 +68,35 @@ func pull(cmd *cobra.Command, args []string) {
}
log.Info().Msgf("found %d files in google bucket", len(files))

sem := semaphore.NewWeighted(flagConcurrency)
wait := sync.WaitGroup{}
for _, file := range files {
if strings.Contains(file, "node-info.pub") {
fullOutpath := filepath.Join(flagOutdir, file)
log.Printf("downloading %s", file)

err = bucket.DownloadFile(ctx, client, fullOutpath, file)
if err != nil {
log.Error().Msgf("error trying download google bucket file: %v", err)
wait.Add(1)
go func(file gcs.GCSFile) {
_ = sem.Acquire(ctx, 1)
defer func() {
sem.Release(1)
wait.Done()
}()

if strings.Contains(file.Name, "node-info.pub") {
fullOutpath := filepath.Join(flagOutdir, file.Name)

fmd5 := utils.CalcMd5(fullOutpath)
// only skip files that have an MD5 hash
if file.MD5 != nil && bytes.Equal(fmd5, file.MD5) {
log.Printf("skipping %s", file)
return
}

log.Printf("downloading %s", file)
err = bucket.DownloadFile(ctx, client, fullOutpath, file.Name)
if err != nil {
log.Error().Msgf("error trying download google bucket file: %v", err)
}
}
}
}(file)
}

wait.Wait()
}
14 changes: 11 additions & 3 deletions cmd/bootstrap/gcs/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,19 @@ func (g *googleBucket) NewClient(ctx context.Context) (*storage.Client, error) {
return client, nil
}

// GCSFile describes a single object listed from a Google bucket.
type GCSFile struct {
	// Name is the object's name within the bucket.
	Name string
	// MD5 is the MD5 hash reported in the object's attributes. Callers treat
	// a nil value as "no hash available".
	MD5 []byte
}

// GetFiles returns a list of files (name and MD5 hash) within the Google bucket
func (g *googleBucket) GetFiles(ctx context.Context, client *storage.Client, prefix, delimiter string) ([]string, error) {
func (g *googleBucket) GetFiles(ctx context.Context, client *storage.Client, prefix, delimiter string) ([]GCSFile, error) {
it := client.Bucket(g.Name).Objects(ctx, &storage.Query{
Prefix: prefix,
Delimiter: delimiter,
})

var files []string
var files []GCSFile
for {
attrs, err := it.Next()
if err == iterator.Done {
Expand All @@ -50,7 +55,10 @@ func (g *googleBucket) GetFiles(ctx context.Context, client *storage.Client, pre
return nil, err
}

files = append(files, attrs.Name)
files = append(files, GCSFile{
Name: attrs.Name,
MD5: attrs.MD5,
})
}

return files, nil
Expand Down
1 change: 1 addition & 0 deletions cmd/bootstrap/transit/cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ var (
flagAccessAddress string
flagNodeRole string
flagTimeout time.Duration
flagConcurrency int64

flagWrapID string // wrap ID
flagVoteFile string
Expand Down
38 changes: 30 additions & 8 deletions cmd/bootstrap/transit/cmd/pull.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package cmd

import (
"bytes"
"context"
"fmt"
"io/fs"
"path/filepath"
"strings"
"sync"
"time"

"github.com/spf13/cobra"
"golang.org/x/sync/semaphore"

"github.com/onflow/flow-go/cmd/bootstrap/gcs"
"github.com/onflow/flow-go/cmd/bootstrap/utils"
model "github.com/onflow/flow-go/model/bootstrap"
"github.com/onflow/flow-go/model/flow"
)
Expand All @@ -32,6 +36,7 @@ func addPullCmdFlags() {
pullCmd.Flags().StringVarP(&flagToken, "token", "t", "", "token provided by the Flow team to access the Transit server")
pullCmd.Flags().StringVarP(&flagNodeRole, "role", "r", "", `node role (can be "collection", "consensus", "execution", "verification" or "access")`)
pullCmd.Flags().DurationVar(&flagTimeout, "timeout", time.Second*300, `timeout for pull`)
pullCmd.Flags().Int64Var(&flagConcurrency, "concurrency", 2, `concurrency limit for pull`)

_ = pullCmd.MarkFlagRequired("token")
_ = pullCmd.MarkFlagRequired("role")
Expand Down Expand Up @@ -78,17 +83,34 @@ func pull(cmd *cobra.Command, args []string) {
}
log.Info().Msgf("found %d files in Google Bucket", len(files))

// download found files
sem := semaphore.NewWeighted(flagConcurrency)
wait := sync.WaitGroup{}
for _, file := range files {
fullOutpath := filepath.Join(flagBootDir, "public-root-information", filepath.Base(file))

log.Info().Str("source", file).Str("dest", fullOutpath).Msgf("downloading file from transit servers")
err = bucket.DownloadFile(ctx, client, fullOutpath, file)
if err != nil {
log.Fatal().Err(err).Msgf("could not download google bucket file")
}
wait.Add(1)
go func(file gcs.GCSFile) {
_ = sem.Acquire(ctx, 1)
defer func() {
sem.Release(1)
wait.Done()
}()

fullOutpath := filepath.Join(flagBootDir, "public-root-information", filepath.Base(file.Name))
fmd5 := utils.CalcMd5(fullOutpath)
// only skip files that have an MD5 hash
if file.MD5 != nil && bytes.Equal(fmd5, file.MD5) {
log.Info().Str("source", file.Name).Str("dest", fullOutpath).Msgf("skipping existing file from transit servers")
return
}
log.Info().Str("source", file.Name).Str("dest", fullOutpath).Msgf("downloading file from transit servers")
err = bucket.DownloadFile(ctx, client, fullOutpath, file.Name)
if err != nil {
log.Fatal().Err(err).Msgf("could not download google bucket file")
}
}(file)
}

wait.Wait()

// download any extra files specific to node role
extraFiles := getAdditionalFilesToDownload(role, nodeID)
for _, file := range extraFiles {
Expand Down
25 changes: 25 additions & 0 deletions cmd/bootstrap/utils/md5.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package utils

// The Google Cloud Storage API only provides MD5 and CRC32C checksums, hence the linter flag for md5 is overridden.
// #nosec
import (
"crypto/md5"
"io"
"os"
)

// CalcMd5 returns the MD5 digest of the file located at outpath.
//
// It returns nil when the file cannot be opened or when reading it fails, so
// callers can compare the result against a remote hash without handling an
// error: a nil result never equals a real 16-byte digest.
func CalcMd5(outpath string) []byte {
	file, openErr := os.Open(outpath)
	if openErr != nil {
		return nil
	}
	defer file.Close()

	// #nosec
	hasher := md5.New()

	// Stream the file through the hasher instead of loading it into memory.
	_, copyErr := io.Copy(hasher, file)
	if copyErr != nil {
		return nil
	}

	digest := hasher.Sum(nil)
	return digest
}

0 comments on commit bada8ca

Please sign in to comment.