From 808cb0447724a20cb081055bf0012db9baf1e2af Mon Sep 17 00:00:00 2001 From: jojoliang Date: Thu, 11 Jul 2024 17:49:39 +0800 Subject: [PATCH] =?UTF-8?q?download=E5=88=86=E5=9D=97=E5=8F=AF=E4=BB=A5?= =?UTF-8?q?=E8=B6=85=E8=BF=8710000,=20=E5=A2=9E=E5=8A=A0sts=E7=AD=BE?= =?UTF-8?q?=E5=90=8D=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- auth.go | 217 ++++++++++++++++++++++++++++++++- auth_test.go | 172 +++++++++++++++++++++++++- cos.go | 2 +- example/object/get_with_sts.go | 73 +++++++++++ object.go | 33 ++++- 5 files changed, 489 insertions(+), 8 deletions(-) create mode 100644 example/object/get_with_sts.go diff --git a/auth.go b/auth.go index 41b3bfe..f9e3c10 100644 --- a/auth.go +++ b/auth.go @@ -4,9 +4,11 @@ import ( "context" "crypto/hmac" "crypto/sha1" + "encoding/base64" "encoding/json" "fmt" "hash" + "io" "io/ioutil" math_rand "math/rand" "net" @@ -26,11 +28,13 @@ const ( ) var ( - defaultCVMAuthExpire = int64(600) + defaultTmpAuthExpire = int64(600) defaultCVMSchema = "http" defaultCVMMetaHost = "metadata.tencentyun.com" defaultCVMCredURI = "latest/meta-data/cam/security-credentials" internalHost = regexp.MustCompile(`^.*cos-internal\.[a-z-1]+\.tencentcos\.cn$`) + defaultStsHost = "sts.tencentcloudapi.com" + defaultStsSchema = "https" ) var DNSScatterDialContext = DNSScatterDialContextFunc @@ -424,7 +428,7 @@ func (t *CVMCredentialTransport) GetRoles() ([]string, error) { func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, string, error) { t.rwLocker.Lock() defer t.rwLocker.Unlock() - if t.expiredTime > now+defaultCVMAuthExpire { + if t.expiredTime > now+defaultTmpAuthExpire { return t.secretID, t.secretKey, t.sessionToken, nil } roleName := t.RoleName @@ -460,8 +464,8 @@ func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, st func (t *CVMCredentialTransport) GetCredential() (string, string, string, error) { now := time.Now().Unix() t.rwLocker.RLock() - // 提前 defaultCVMAuthExpire 获取重新获取临时密钥 - if t.expiredTime <= now+defaultCVMAuthExpire { + // 提前 defaultTmpAuthExpire 获取重新获取临时密钥 + if t.expiredTime <= now+defaultTmpAuthExpire { expiredTime := t.expiredTime t.rwLocker.RUnlock() secretID, secretKey, secretToken, err := t.UpdateCredential(now) @@ -545,3 +549,208 @@ func (c *Credential) GetSecretId() string { func (c *Credential) GetToken() string { return c.SessionToken } + +// 通过sts访问 +type Credentials struct { + TmpSecretID string `json:"TmpSecretId,omitempty"` + TmpSecretKey string `json:"TmpSecretKey,omitempty"` + SessionToken string `json:"Token,omitempty"` +} +type CredentialError struct { + Code string `json:"Code,omitempty"` + Message string `json:"Message,omitempty"` + RequestId string `json:"RequestId,omitempty"` +} + +func (e *CredentialError) Error() string { + return fmt.Sprintf("Code: %v, Message: %v, RequestId: %v", e.Code, e.Message, e.RequestId) +} + +type CredentialResult struct { + Credentials *Credentials `json:"Credentials,omitempty"` + ExpiredTime int64 `json:"ExpiredTime,omitempty"` + RequestId string `json:"RequestId,omitempty"` + Error *CredentialError `json:"Error,omitempty"` +} + +type CredentialCompleteResult struct { + Response *CredentialResult `json:"Response"` +} + +type CredentialPolicyStatement struct { + Action []string `json:"action,omitempty"` + Effect string `json:"effect,omitempty"` + Resource []string `json:"resource,omitempty"` + Condition map[string]map[string]interface{} `json:"condition,omitempty"` +} + +type CredentialPolicy struct { + Version string `json:"version,omitempty"` + Statement []CredentialPolicyStatement `json:"statement,omitempty"` +} + +type StsCredentialTransport struct { + Transport http.RoundTripper + SecretID string + SecretKey string + Policy *CredentialPolicy + Host string + Region string + expiredTime int64 + credential Credentials + rwLocker sync.RWMutex +} + +func (t *StsCredentialTransport) UpdateCredential(now int64) (string, string, string, error) { + t.rwLocker.Lock() + defer t.rwLocker.Unlock() + if t.expiredTime > now+defaultTmpAuthExpire { + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil + } + region := t.Region + if region == "" { + region = "ap-guangzhou" + } + policy, err := getPolicy(t.Policy) + if err != nil { + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err + } + params := map[string]interface{}{ + "SecretId": t.SecretID, + "Policy": url.QueryEscape(policy), + "DurationSeconds": 1800, + "Region": region, + "Timestamp": time.Now().Unix(), + "Nonce": math_rand.Int(), + "Name": "cos-sts-sdk", + "Action": "GetFederationToken", + "Version": "2018-08-13", + } + resp, err := t.sendRequest(params) + if err != nil { + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err + } + defer resp.Body.Close() + if resp.StatusCode > 299 { + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, fmt.Errorf("sts StatusCode error: %v", resp.StatusCode) + } + result := &CredentialCompleteResult{} + err = json.NewDecoder(resp.Body).Decode(result) + if err == io.EOF { + err = nil // ignore EOF errors caused by empty response body + } + if err != nil { + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err + } + if result.Response != nil && result.Response.Error != nil { + result.Response.Error.RequestId = result.Response.RequestId + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, result.Response.Error + } + if result.Response != nil && result.Response.Credentials != nil { + t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, t.expiredTime = result.Response.Credentials.TmpSecretID, result.Response.Credentials.TmpSecretKey, result.Response.Credentials.SessionToken, result.Response.ExpiredTime + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil + } + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, fmt.Errorf("GetCredential failed, result: %v", result.Response) +} + +func (t *StsCredentialTransport) GetCredential() (string, string, string, error) { + now := time.Now().Unix() + t.rwLocker.RLock() + // 提前 defaultTmpAuthExpire 获取重新获取临时密钥 + if t.expiredTime <= now+defaultTmpAuthExpire { + expiredTime := t.expiredTime + t.rwLocker.RUnlock() + secretID, secretKey, secretToken, err := t.UpdateCredential(now) + // 获取临时密钥失败但密钥未过期 + if err != nil && now < expiredTime { + err = nil + } + return secretID, secretKey, secretToken, err + } + defer t.rwLocker.RUnlock() + return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil +} + +func (t *StsCredentialTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ak, sk, token, err := t.GetCredential() + if err != nil { + return nil, err + } + req = cloneRequest(req) + // 增加 Authorization header + authTime := NewAuthTime(defaultAuthExpire) + AddAuthorizationHeader(ak, sk, token, req, authTime) + + resp, err := t.transport().RoundTrip(req) + return resp, err +} + +func (t *StsCredentialTransport) transport() http.RoundTripper { + if t.Transport != nil { + return t.Transport + } + return http.DefaultTransport +} + +func (t *StsCredentialTransport) sendRequest(params map[string]interface{}) (*http.Response, error) { + paramValues := url.Values{} + for k, v := range params { + paramValues.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v)) + } + sign := t.signed("POST", params) + paramValues.Add("Signature", sign) + + host := defaultStsHost + if t.Host != "" { + host = t.Host + } + resp, err := http.DefaultClient.PostForm(defaultStsSchema+"://"+host, paramValues) + return resp, err +} + +func (t *StsCredentialTransport) signed(method string, params map[string]interface{}) string { + host := defaultStsHost + if t.Host != "" { + host = t.Host + } + source := method + host + "/?" + makeFlat(params) + + hmacObj := hmac.New(sha1.New, []byte(t.SecretKey)) + hmacObj.Write([]byte(source)) + + sign := base64.StdEncoding.EncodeToString(hmacObj.Sum(nil)) + + return sign +} + +func getPolicy(policy *CredentialPolicy) (string, error) { + if policy == nil { + return "", nil + } + res := policy + if policy.Version == "" { + res = &CredentialPolicy{ + Version: "2.0", + Statement: policy.Statement, + } + } + bs, err := json.Marshal(res) + if err != nil { + return "", err + } + return string(bs), nil +} + +func makeFlat(params map[string]interface{}) string { + keys := make([]string, 0, len(params)) + for k, _ := range params { + keys = append(keys, k) + } + sort.Strings(keys) + + var plainParms string + for _, k := range keys { + plainParms += fmt.Sprintf("&%v=%v", k, params[k]) + } + return plainParms[1:] +} diff --git a/auth_test.go b/auth_test.go index a903410..b0b1348 100644 --- a/auth_test.go +++ b/auth_test.go @@ -235,7 +235,7 @@ func TestCVMCredentialTransportErr(t *testing.T) { secretID: "ak", secretKey: "sk", sessionToken: "token", - expiredTime: nt + defaultCVMAuthExpire + 1, + expiredTime: nt + defaultTmpAuthExpire + 1, } // 密钥未超时 ak, sk, token, err := transport.UpdateCredential(nt) @@ -243,7 +243,7 @@ func TestCVMCredentialTransportErr(t *testing.T) { t.Errorf("UpdateCredential failed, return: %v, %v, %v, want: %v", ak, sk, token, *transport) } // 密钥超时,GetRoles返回错误 - transport.expiredTime = nt + defaultCVMAuthExpire - 1 + transport.expiredTime = nt + defaultTmpAuthExpire - 1 ak, sk, token, err = transport.UpdateCredential(nt) if ak != transport.secretID || sk != transport.secretKey || token != transport.sessionToken || err == nil { t.Errorf("UpdateCredential failed, return: %v, %v, %v, want: %v", ak, sk, token, *transport) @@ -317,3 +317,171 @@ func TestCredentialTransport(t *testing.T) { client.GetCredential() } + +func TestStsCredentialTransport(t *testing.T) { + setup() + defer teardown() + uri := client.BaseURL.BucketURL.String() + ak := "test_ak" + sk := "test_sk" + token := "test_token" + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-cos-security-token") != token { + t.Errorf("StsCredentialTransport x-cos-security-token error, want:%v, return:%v\n", token, r.Header.Get("x-cos-security-token")) + } + auth := r.Header.Get("Authorization") + if auth == "" { + t.Error("StsCredentialTransport didn't add Authorization header") + } + field := strings.Split(auth, "&") + if len(field) != 7 { + t.Errorf("StsCredentialTransport Authorization header format error: %v\n", auth) + } + st_et := strings.Split(strings.Split(field[2], "=")[1], ";") + st, _ := strconv.ParseInt(st_et[0], 10, 64) + et, _ := strconv.ParseInt(st_et[1], 10, 64) + authTime := &AuthTime{ + SignStartTime: time.Unix(st, 0), + SignEndTime: time.Unix(et, 0), + KeyStartTime: time.Unix(st, 0), + KeyEndTime: time.Unix(et, 0), + } + host := strings.TrimLeft(uri, "http://") + req, _ := http.NewRequest("GET", uri, nil) + req.Header.Add("Host", host) + expect := newAuthorization(ak, sk, req, authTime, true) + if expect != auth { + t.Errorf("StsCredentialTransport Authorization error, want:%v, return:%v\n", expect, auth) + } + }) + + // CVM http server + cvm_mux := http.NewServeMux() + cvm_server := httptest.NewServer(cvm_mux) + defer cvm_server.Close() + // 将默认 CVM Host 修改成测试IP:PORT + defaultStsSchema = "http" + defaultStsHost = strings.TrimLeft(cvm_server.URL, "http://") + + cvm_mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, fmt.Sprintf(`{ + "Response": { + "Credentials": { + "TmpSecretId": "%v", + "TmpSecretKey": "%v", + "Token": "%v" + }, + "Expiration": "2023-06-14T05:06:57Z", + "ExpiredTime": 1686719217, + "RequestId": "59a5e07e-4147-4d2e-a808-dca76ac5b3fd" + } + }`, ak, sk, token)) + }) + + client.client.Transport = &StsCredentialTransport{} + req, _ := http.NewRequest("GET", client.BaseURL.BucketURL.String(), nil) + _, err := client.doAPI(context.Background(), req, nil, true) + if err != nil { + t.Errorf("doAPI failed: %v", err) + } + + _, err = client.doAPI(context.Background(), req, nil, true) + if err != nil { + t.Errorf("doAPI failed: %v", err) + } + + client.client.Transport = &StsCredentialTransport{ + Policy: &CredentialPolicy{ + Statement: []CredentialPolicyStatement{ + { + // 密钥的权限列表。简单上传和分片需要以下的权限,其他权限列表请看 https://cloud.tencent.com/document/product/436/31923 + Action: []string{ + // 简单上传 + "name/cos:GetObject", + }, + Effect: "allow", + Resource: []string{ + // 这里改成允许的路径前缀,可以根据自己网站的用户登录态判断允许上传的具体路径,例子: a.jpg 或者 a/* 或者 * (使用通配符*存在重大安全风险, 请谨慎评估使用) + // 存储桶的命名格式为 BucketName-APPID,此处填写的 bucket 必须为此格式 + "qcs::cos:ap-guangzhou:uid/1250000000:test-12500000000/*", + }, + }, + }, + }, + Host: strings.TrimLeft(cvm_server.URL, "http://"), + } + req, _ = http.NewRequest("GET", client.BaseURL.BucketURL.String(), nil) + _, err = client.doAPI(context.Background(), req, nil, true) + if err != nil { + t.Errorf("doAPI failed: %v", err) + } + +} + +func TestStsCredentialTransportErr(t *testing.T) { + setup() + defer teardown() + + // CVM http server + cvm_mux := http.NewServeMux() + cvm_server := httptest.NewServer(cvm_mux) + defer cvm_server.Close() + // 将默认 CVM Host 修改成测试IP:PORT + defaultStsSchema = "http" + + var expectErr int + cvm_mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if expectErr == 0 { + w.WriteHeader(http.StatusGatewayTimeout) + } else if expectErr == 1 { + fmt.Fprint(w, `{"RequestId": "59a5e07e-4147-4d2e-a808-dca76ac5b3fd",}`) + } else if expectErr == 2 { + fmt.Fprint(w, `{"Response": { + "Error": { + "Code": "error", + "Message": "error" + }, + "RequestId": "59a5e07e-4147-4d2e-a808-dca76ac5b3fd" + }}`) + } else if expectErr == 3 { + fmt.Fprint(w, `{"Response": { + "Expiration": "2023-06-14T05:06:57Z", + "ExpiredTime": 1686719217, + "RequestId": "59a5e07e-4147-4d2e-a808-dca76ac5b3fd" + }}`) + } + }) + + client.client.Transport = &StsCredentialTransport{ + Host: strings.TrimLeft(cvm_server.URL, "http://"), + } + req, _ := http.NewRequest("GET", client.BaseURL.BucketURL.String(), nil) + expectErr = 0 + _, err := client.doAPI(context.Background(), req, nil, true) + if err == nil { + t.Errorf("doAPI expect error") + } + expectErr = 1 + _, err = client.doAPI(context.Background(), req, nil, true) + if err == nil { + t.Errorf("doAPI expect error") + } + expectErr = 2 + _, err = client.doAPI(context.Background(), req, nil, true) + if err == nil { + t.Errorf("doAPI expect error") + } + expectErr = 3 + _, err = client.doAPI(context.Background(), req, nil, true) + if err == nil { + t.Errorf("doAPI expect error") + } + if (&CredentialError{ + Code: "error", + Message: "error", + RequestId: "error", + }).Error() != "Code: error, Message: error, RequestId: error" { + t.Errorf("CredentialError format error") + } +} diff --git a/cos.go b/cos.go index 0dca512..3beb10c 100644 --- a/cos.go +++ b/cos.go @@ -25,7 +25,7 @@ import ( const ( // Version current go sdk version - Version = "0.7.53" + Version = "0.7.54" UserAgent = "cos-go-sdk-v5/" + Version contentTypeXML = "application/xml" defaultServiceBaseURL = "http://service.cos.myqcloud.com" diff --git a/example/object/get_with_sts.go b/example/object/get_with_sts.go new file mode 100644 index 0000000..daf34f4 --- /dev/null +++ b/example/object/get_with_sts.go @@ -0,0 +1,73 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + + "github.com/tencentyun/cos-go-sdk-v5" + "github.com/tencentyun/cos-go-sdk-v5/debug" +) + +func log_status(err error) { + if err == nil { + return + } + if cos.IsNotFoundError(err) { + // WARN + fmt.Println("WARN: Resource is not existed") + } else if e, ok := cos.IsCOSError(err); ok { + fmt.Printf("ERROR: Code: %v\n", e.Code) + fmt.Printf("ERROR: Message: %v\n", e.Message) + fmt.Printf("ERROR: Resource: %v\n", e.Resource) + fmt.Printf("ERROR: RequestId: %v\n", e.RequestID) + // ERROR + } else { + fmt.Printf("ERROR: %v\n", err) + // ERROR + } +} + +func main() { + // 存储桶名称,由bucketname-appid 组成,appid必须填入,可以在COS控制台查看存储桶名称。 https://console.cloud.tencent.com/cos5/bucket + // 替换为用户的 region,存储桶region可以在COS控制台“存储桶概览”查看 https://console.cloud.tencent.com/ ,关于地域的详情见 https://cloud.tencent.com/document/product/436/6224 。 + u, _ := url.Parse("https://test-1259654469.cos.ap-guangzhou.myqcloud.com") + b := &cos.BaseURL{BucketURL: u} + c := cos.NewClient(b, &http.Client{ + Transport: &cos.StsCredentialTransport{ + SecretID: os.Getenv("COS_SECRETID"), + SecretKey: os.Getenv("COS_SECRETKEY"), + Transport: &debug.DebugRequestTransport{ + RequestHeader: true, + RequestBody: true, + ResponseHeader: true, + ResponseBody: false, + }, + Policy: &cos.CredentialPolicy{ + Statement: []cos.CredentialPolicyStatement{ + { + // 密钥的权限列表。简单上传和分片需要以下的权限,其他权限列表请看 https://cloud.tencent.com/document/product/436/31923 + Action: []string{ + // 简单上传 + "name/cos:GetObject", + }, + Effect: "allow", + Resource: []string{ + // 这里改成允许的路径前缀,可以根据自己网站的用户登录态判断允许上传的具体路径,例子: a.jpg 或者 a/* 或者 * (使用通配符*存在重大安全风险, 请谨慎评估使用) + // 存储桶的命名格式为 BucketName-APPID,此处填写的 bucket 必须为此格式 + "qcs::cos:ap-guangzhou:uid/1259654469:test-1259654469/example", + }, + }, + }, + }, + }, + }) + + // Case1 上传对象 + name := "example" + // Case3 通过本地文件上传对象 + _, err := c.Object.GetToFile(context.Background(), name, "./test", nil) // 请求的超时时间为 min{context超时时间, HTTP超时时间} + log_status(err) +} diff --git a/object.go b/object.go index 59a7df6..368782d 100644 --- a/object.go +++ b/object.go @@ -1475,6 +1475,37 @@ func SplitSizeIntoChunks(totalBytes int64, partSize int64) ([]Chunk, int, error) return chunks, int(partNum), nil } +func SplitSizeIntoChunksToDownload(totalBytes int64, partSize int64) ([]Chunk, int, error) { + var partNum int64 + if partSize > 0 { + if partSize < 1024*1024 { + return nil, 0, errors.New("partSize>=1048576 is required") + } + partNum = totalBytes / partSize + } else { + partNum, partSize = DividePart(totalBytes, 16) + } + + var chunks []Chunk + var chunk = Chunk{} + for i := int64(0); i < partNum; i++ { + chunk.Number = int(i + 1) + chunk.OffSet = i * partSize + chunk.Size = partSize + chunks = append(chunks, chunk) + } + + if totalBytes%partSize > 0 { + chunk.Number = len(chunks) + 1 + chunk.OffSet = int64(len(chunks)) * partSize + chunk.Size = totalBytes % partSize + chunks = append(chunks, chunk) + partNum++ + } + + return chunks, int(partNum), nil +} + func (s *ObjectService) checkDownloadedParts(opt *MultiDownloadCPInfo, chfile string, chunks []Chunk) (*MultiDownloadCPInfo, bool) { var defaultRes MultiDownloadCPInfo defaultRes = *opt @@ -1551,7 +1582,7 @@ func (s *ObjectService) Download(ctx context.Context, name string, filepath stri } // 切分 - chunks, partNum, err := SplitSizeIntoChunks(totalBytes, opt.PartSize*1024*1024) + chunks, partNum, err := SplitSizeIntoChunksToDownload(totalBytes, opt.PartSize*1024*1024) if err != nil { return resp, err }