Skip to content

Commit

Permalink
Merge pull request #270 from tencentyun/feature_jojoliang_9baf1e2a
Browse files Browse the repository at this point in the history
download分块可以超过10000, 增加sts签名方式
  • Loading branch information
agin719 authored Jul 11, 2024
2 parents 7b3b630 + 808cb04 commit 91931e0
Show file tree
Hide file tree
Showing 5 changed files with 489 additions and 8 deletions.
217 changes: 213 additions & 4 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"fmt"
"hash"
"io"
"io/ioutil"
math_rand "math/rand"
"net"
Expand All @@ -26,11 +28,13 @@ const (
)

var (
defaultCVMAuthExpire = int64(600)
defaultTmpAuthExpire = int64(600)
defaultCVMSchema = "http"
defaultCVMMetaHost = "metadata.tencentyun.com"
defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
internalHost = regexp.MustCompile(`^.*cos-internal\.[a-z-1]+\.tencentcos\.cn$`)
defaultStsHost = "sts.tencentcloudapi.com"
defaultStsSchema = "https"
)

var DNSScatterDialContext = DNSScatterDialContextFunc
Expand Down Expand Up @@ -424,7 +428,7 @@ func (t *CVMCredentialTransport) GetRoles() ([]string, error) {
func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, string, error) {
t.rwLocker.Lock()
defer t.rwLocker.Unlock()
if t.expiredTime > now+defaultCVMAuthExpire {
if t.expiredTime > now+defaultTmpAuthExpire {
return t.secretID, t.secretKey, t.sessionToken, nil
}
roleName := t.RoleName
Expand Down Expand Up @@ -460,8 +464,8 @@ func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, st
func (t *CVMCredentialTransport) GetCredential() (string, string, string, error) {
now := time.Now().Unix()
t.rwLocker.RLock()
// 提前 defaultCVMAuthExpire 获取重新获取临时密钥
if t.expiredTime <= now+defaultCVMAuthExpire {
// 提前 defaultTmpAuthExpire 获取重新获取临时密钥
if t.expiredTime <= now+defaultTmpAuthExpire {
expiredTime := t.expiredTime
t.rwLocker.RUnlock()
secretID, secretKey, secretToken, err := t.UpdateCredential(now)
Expand Down Expand Up @@ -545,3 +549,208 @@ func (c *Credential) GetSecretId() string {
func (c *Credential) GetToken() string {
return c.SessionToken
}

// 通过sts访问
type Credentials struct {
TmpSecretID string `json:"TmpSecretId,omitempty"`
TmpSecretKey string `json:"TmpSecretKey,omitempty"`
SessionToken string `json:"Token,omitempty"`
}
type CredentialError struct {
Code string `json:"Code,omitempty"`
Message string `json:"Message,omitempty"`
RequestId string `json:"RequestId,omitempty"`
}

func (e *CredentialError) Error() string {
return fmt.Sprintf("Code: %v, Message: %v, RequestId: %v", e.Code, e.Message, e.RequestId)
}

type CredentialResult struct {
Credentials *Credentials `json:"Credentials,omitempty"`
ExpiredTime int64 `json:"ExpiredTime,omitempty"`
RequestId string `json:"RequestId,omitempty"`
Error *CredentialError `json:"Error,omitempty"`
}

type CredentialCompleteResult struct {
Response *CredentialResult `json:"Response"`
}

type CredentialPolicyStatement struct {
Action []string `json:"action,omitempty"`
Effect string `json:"effect,omitempty"`
Resource []string `json:"resource,omitempty"`
Condition map[string]map[string]interface{} `json:"condition,omitempty"`
}

type CredentialPolicy struct {
Version string `json:"version,omitempty"`
Statement []CredentialPolicyStatement `json:"statement,omitempty"`
}

type StsCredentialTransport struct {
Transport http.RoundTripper
SecretID string
SecretKey string
Policy *CredentialPolicy
Host string
Region string
expiredTime int64
credential Credentials
rwLocker sync.RWMutex
}

func (t *StsCredentialTransport) UpdateCredential(now int64) (string, string, string, error) {
t.rwLocker.Lock()
defer t.rwLocker.Unlock()
if t.expiredTime > now+defaultTmpAuthExpire {
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil
}
region := t.Region
if region == "" {
region = "ap-guangzhou"
}
policy, err := getPolicy(t.Policy)
if err != nil {
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err
}
params := map[string]interface{}{
"SecretId": t.SecretID,
"Policy": url.QueryEscape(policy),
"DurationSeconds": 1800,
"Region": region,
"Timestamp": time.Now().Unix(),
"Nonce": math_rand.Int(),
"Name": "cos-sts-sdk",
"Action": "GetFederationToken",
"Version": "2018-08-13",
}
resp, err := t.sendRequest(params)
if err != nil {
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err
}
defer resp.Body.Close()
if resp.StatusCode > 299 {
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, fmt.Errorf("sts StatusCode error: %v", resp.StatusCode)
}
result := &CredentialCompleteResult{}
err = json.NewDecoder(resp.Body).Decode(result)
if err == io.EOF {
err = nil // ignore EOF errors caused by empty response body
}
if err != nil {
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err
}
if result.Response != nil && result.Response.Error != nil {
result.Response.Error.RequestId = result.Response.RequestId
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, result.Response.Error
}
if result.Response != nil && result.Response.Credentials != nil {
t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, t.expiredTime = result.Response.Credentials.TmpSecretID, result.Response.Credentials.TmpSecretKey, result.Response.Credentials.SessionToken, result.Response.ExpiredTime
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil
}
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, fmt.Errorf("GetCredential failed, result: %v", result.Response)
}

func (t *StsCredentialTransport) GetCredential() (string, string, string, error) {
now := time.Now().Unix()
t.rwLocker.RLock()
// 提前 defaultTmpAuthExpire 获取重新获取临时密钥
if t.expiredTime <= now+defaultTmpAuthExpire {
expiredTime := t.expiredTime
t.rwLocker.RUnlock()
secretID, secretKey, secretToken, err := t.UpdateCredential(now)
// 获取临时密钥失败但密钥未过期
if err != nil && now < expiredTime {
err = nil
}
return secretID, secretKey, secretToken, err
}
defer t.rwLocker.RUnlock()
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil
}

func (t *StsCredentialTransport) RoundTrip(req *http.Request) (*http.Response, error) {
ak, sk, token, err := t.GetCredential()
if err != nil {
return nil, err
}
req = cloneRequest(req)
// 增加 Authorization header
authTime := NewAuthTime(defaultAuthExpire)
AddAuthorizationHeader(ak, sk, token, req, authTime)

resp, err := t.transport().RoundTrip(req)
return resp, err
}

func (t *StsCredentialTransport) transport() http.RoundTripper {
if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
}

func (t *StsCredentialTransport) sendRequest(params map[string]interface{}) (*http.Response, error) {
paramValues := url.Values{}
for k, v := range params {
paramValues.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
sign := t.signed("POST", params)
paramValues.Add("Signature", sign)

host := defaultStsHost
if t.Host != "" {
host = t.Host
}
resp, err := http.DefaultClient.PostForm(defaultStsSchema+"://"+host, paramValues)
return resp, err
}

func (t *StsCredentialTransport) signed(method string, params map[string]interface{}) string {
host := defaultStsHost
if t.Host != "" {
host = t.Host
}
source := method + host + "/?" + makeFlat(params)

hmacObj := hmac.New(sha1.New, []byte(t.SecretKey))
hmacObj.Write([]byte(source))

sign := base64.StdEncoding.EncodeToString(hmacObj.Sum(nil))

return sign
}

func getPolicy(policy *CredentialPolicy) (string, error) {
if policy == nil {
return "", nil
}
res := policy
if policy.Version == "" {
res = &CredentialPolicy{
Version: "2.0",
Statement: policy.Statement,
}
}
bs, err := json.Marshal(res)
if err != nil {
return "", err
}
return string(bs), nil
}

func makeFlat(params map[string]interface{}) string {
keys := make([]string, 0, len(params))
for k, _ := range params {
keys = append(keys, k)
}
sort.Strings(keys)

var plainParms string
for _, k := range keys {
plainParms += fmt.Sprintf("&%v=%v", k, params[k])
}
return plainParms[1:]
}
Loading

0 comments on commit 91931e0

Please sign in to comment.