diff --git a/storage/azure.go b/storage/azure.go index 10cd97e..b72e0de 100644 --- a/storage/azure.go +++ b/storage/azure.go @@ -3,21 +3,33 @@ package storage import ( "context" "fmt" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" ) type AzureClient struct { account string cli *azblob.Client + + useIAM bool + + // sasCli is used to generate SAS token. + // When we want to copy object under two different service accounts, AD auth is not supported. + // So we need to use AD auth to generate SAS token and use SAS token to copy object. + sasCli *service.Client } func NewAzureClient(cfg Cfg) (*AzureClient, error) { endpoint := fmt.Sprintf("https://%s.blob.core.windows.net", cfg.AK) var cli *azblob.Client + var sasCli *service.Client if cfg.UseIAM { cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { @@ -27,6 +39,10 @@ func NewAzureClient(cfg Cfg) (*AzureClient, error) { if err != nil { return nil, fmt.Errorf("storage: new azure client %w", err) } + sasCli, err = service.NewClient(endpoint, cred, nil) + if err != nil { + return nil, fmt.Errorf("storage: new azure service client %w", err) + } } else { cred, err := azblob.NewSharedKeyCredential(cfg.AK, cfg.SK) if err != nil { @@ -36,9 +52,10 @@ func NewAzureClient(cfg Cfg) (*AzureClient, error) { if err != nil { return nil, fmt.Errorf("storage: new azure client %w", err) } + // sasCli is not used when use shared key auth } - return &AzureClient{account: cfg.AK, cli: cli}, nil + return &AzureClient{account: cfg.AK, useIAM: cfg.UseIAM, cli: cli, sasCli: sasCli}, nil } func (a *AzureClient) CopyObject(ctx context.Context, i CopyObjectInput) error { @@ -46,11 +63,21 @@ func (a *AzureClient) CopyObject(ctx context.Context, i CopyObjectInput) error { if !ok { return fmt.Errorf("storage: azure copy object dest client is not azure client") } + var url string + // When we want to copy object under two different service accounts, AD auth is not supported. + if srcCli.useIAM && (srcCli.account != a.account) { + queryParam, err := a.getSAS(i) + if err != nil { + return fmt.Errorf("storage: azure get sas %w", err) + } + url = fmt.Sprintf("https://%s.blob.core.windows.net/%s/%s?%s", srcCli.account, i.SrcBucket, i.SrcKey, queryParam.Encode()) + } else { + url = fmt.Sprintf("https://%s.blob.core.windows.net/%s/%s", srcCli.account, i.SrcBucket, i.SrcKey) + } - url := fmt.Sprintf("https://%s.blob.core.windows.net/%s/%s", srcCli.account, i.SrcBucket, i.SrcKey) _, err := a.cli.ServiceClient(). NewContainerClient(i.DestBucket). - NewBlobClient(i.DestKey). + NewBlockBlobClient(i.DestKey). StartCopyFromURL(ctx, url, nil) if err != nil { return fmt.Errorf("storage: azure start copy from url %w", err) @@ -59,6 +86,36 @@ func (a *AzureClient) CopyObject(ctx context.Context, i CopyObjectInput) error { return nil } +func (a *AzureClient) getSAS(i CopyObjectInput) (sas.QueryParameters, error) { + srcCli, ok := i.SrcCli.(*AzureClient) + if !ok { + return sas.QueryParameters{}, fmt.Errorf("storage: azure copy object dest client is not azure client") + } + + now := time.Now().UTC().Add(-10 * time.Second) + expiry := now.Add(48 * time.Hour) + info := service.KeyInfo{ + Start: to.Ptr(now.UTC().Format(sas.TimeFormat)), + Expiry: to.Ptr(expiry.UTC().Format(sas.TimeFormat)), + } + udc, err := srcCli.sasCli.GetUserDelegationCredential(context.Background(), info, nil) + if err != nil { + return sas.QueryParameters{}, fmt.Errorf("storage: azure get user delegation credential %w", err) + } + sasQueryParams, err := sas.BlobSignatureValues{ + Protocol: sas.ProtocolHTTPS, + StartTime: time.Now().UTC().Add(time.Second * -10), + ExpiryTime: time.Now().UTC().Add(60 * time.Minute), + Permissions: to.Ptr(sas.ContainerPermissions{Read: true, List: true}).String(), + ContainerName: i.SrcBucket, + }.SignWithUserDelegation(udc) + if err != nil { + return sas.QueryParameters{}, fmt.Errorf("storage: azure sign with user delegation %w", err) + } + + return sasQueryParams, nil +} + func (a *AzureClient) HeadBucket(ctx context.Context, bucket string) error { page := a.cli.NewListContainersPager(&azblob.ListContainersOptions{Prefix: &bucket}) for page.More() {