Skip to content

Commit 91931e0

Browse files
authored
Merge pull request #270 from tencentyun/feature_jojoliang_9baf1e2a
download分块可以超过10000, 增加sts签名方式
2 parents 7b3b630 + 808cb04 commit 91931e0

File tree

5 files changed

+489
-8
lines changed

5 files changed

+489
-8
lines changed

auth.go

Lines changed: 213 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import (
44
"context"
55
"crypto/hmac"
66
"crypto/sha1"
7+
"encoding/base64"
78
"encoding/json"
89
"fmt"
910
"hash"
11+
"io"
1012
"io/ioutil"
1113
math_rand "math/rand"
1214
"net"
@@ -26,11 +28,13 @@ const (
2628
)
2729

2830
var (
29-
defaultCVMAuthExpire = int64(600)
31+
defaultTmpAuthExpire = int64(600)
3032
defaultCVMSchema = "http"
3133
defaultCVMMetaHost = "metadata.tencentyun.com"
3234
defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
3335
internalHost = regexp.MustCompile(`^.*cos-internal\.[a-z-1]+\.tencentcos\.cn$`)
36+
defaultStsHost = "sts.tencentcloudapi.com"
37+
defaultStsSchema = "https"
3438
)
3539

3640
var DNSScatterDialContext = DNSScatterDialContextFunc
@@ -424,7 +428,7 @@ func (t *CVMCredentialTransport) GetRoles() ([]string, error) {
424428
func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, string, error) {
425429
t.rwLocker.Lock()
426430
defer t.rwLocker.Unlock()
427-
if t.expiredTime > now+defaultCVMAuthExpire {
431+
if t.expiredTime > now+defaultTmpAuthExpire {
428432
return t.secretID, t.secretKey, t.sessionToken, nil
429433
}
430434
roleName := t.RoleName
@@ -460,8 +464,8 @@ func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, st
460464
func (t *CVMCredentialTransport) GetCredential() (string, string, string, error) {
461465
now := time.Now().Unix()
462466
t.rwLocker.RLock()
463-
// 提前 defaultCVMAuthExpire 获取重新获取临时密钥
464-
if t.expiredTime <= now+defaultCVMAuthExpire {
467+
// 提前 defaultTmpAuthExpire 获取重新获取临时密钥
468+
if t.expiredTime <= now+defaultTmpAuthExpire {
465469
expiredTime := t.expiredTime
466470
t.rwLocker.RUnlock()
467471
secretID, secretKey, secretToken, err := t.UpdateCredential(now)
@@ -545,3 +549,208 @@ func (c *Credential) GetSecretId() string {
545549
func (c *Credential) GetToken() string {
546550
return c.SessionToken
547551
}
552+
553+
// 通过sts访问
554+
type Credentials struct {
555+
TmpSecretID string `json:"TmpSecretId,omitempty"`
556+
TmpSecretKey string `json:"TmpSecretKey,omitempty"`
557+
SessionToken string `json:"Token,omitempty"`
558+
}
559+
type CredentialError struct {
560+
Code string `json:"Code,omitempty"`
561+
Message string `json:"Message,omitempty"`
562+
RequestId string `json:"RequestId,omitempty"`
563+
}
564+
565+
func (e *CredentialError) Error() string {
566+
return fmt.Sprintf("Code: %v, Message: %v, RequestId: %v", e.Code, e.Message, e.RequestId)
567+
}
568+
569+
type CredentialResult struct {
570+
Credentials *Credentials `json:"Credentials,omitempty"`
571+
ExpiredTime int64 `json:"ExpiredTime,omitempty"`
572+
RequestId string `json:"RequestId,omitempty"`
573+
Error *CredentialError `json:"Error,omitempty"`
574+
}
575+
576+
type CredentialCompleteResult struct {
577+
Response *CredentialResult `json:"Response"`
578+
}
579+
580+
type CredentialPolicyStatement struct {
581+
Action []string `json:"action,omitempty"`
582+
Effect string `json:"effect,omitempty"`
583+
Resource []string `json:"resource,omitempty"`
584+
Condition map[string]map[string]interface{} `json:"condition,omitempty"`
585+
}
586+
587+
type CredentialPolicy struct {
588+
Version string `json:"version,omitempty"`
589+
Statement []CredentialPolicyStatement `json:"statement,omitempty"`
590+
}
591+
592+
type StsCredentialTransport struct {
593+
Transport http.RoundTripper
594+
SecretID string
595+
SecretKey string
596+
Policy *CredentialPolicy
597+
Host string
598+
Region string
599+
expiredTime int64
600+
credential Credentials
601+
rwLocker sync.RWMutex
602+
}
603+
604+
func (t *StsCredentialTransport) UpdateCredential(now int64) (string, string, string, error) {
605+
t.rwLocker.Lock()
606+
defer t.rwLocker.Unlock()
607+
if t.expiredTime > now+defaultTmpAuthExpire {
608+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil
609+
}
610+
region := t.Region
611+
if region == "" {
612+
region = "ap-guangzhou"
613+
}
614+
policy, err := getPolicy(t.Policy)
615+
if err != nil {
616+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err
617+
}
618+
params := map[string]interface{}{
619+
"SecretId": t.SecretID,
620+
"Policy": url.QueryEscape(policy),
621+
"DurationSeconds": 1800,
622+
"Region": region,
623+
"Timestamp": time.Now().Unix(),
624+
"Nonce": math_rand.Int(),
625+
"Name": "cos-sts-sdk",
626+
"Action": "GetFederationToken",
627+
"Version": "2018-08-13",
628+
}
629+
resp, err := t.sendRequest(params)
630+
if err != nil {
631+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err
632+
}
633+
defer resp.Body.Close()
634+
if resp.StatusCode > 299 {
635+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, fmt.Errorf("sts StatusCode error: %v", resp.StatusCode)
636+
}
637+
result := &CredentialCompleteResult{}
638+
err = json.NewDecoder(resp.Body).Decode(result)
639+
if err == io.EOF {
640+
err = nil // ignore EOF errors caused by empty response body
641+
}
642+
if err != nil {
643+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, err
644+
}
645+
if result.Response != nil && result.Response.Error != nil {
646+
result.Response.Error.RequestId = result.Response.RequestId
647+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, result.Response.Error
648+
}
649+
if result.Response != nil && result.Response.Credentials != nil {
650+
t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, t.expiredTime = result.Response.Credentials.TmpSecretID, result.Response.Credentials.TmpSecretKey, result.Response.Credentials.SessionToken, result.Response.ExpiredTime
651+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil
652+
}
653+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, fmt.Errorf("GetCredential failed, result: %v", result.Response)
654+
}
655+
656+
func (t *StsCredentialTransport) GetCredential() (string, string, string, error) {
657+
now := time.Now().Unix()
658+
t.rwLocker.RLock()
659+
// 提前 defaultTmpAuthExpire 获取重新获取临时密钥
660+
if t.expiredTime <= now+defaultTmpAuthExpire {
661+
expiredTime := t.expiredTime
662+
t.rwLocker.RUnlock()
663+
secretID, secretKey, secretToken, err := t.UpdateCredential(now)
664+
// 获取临时密钥失败但密钥未过期
665+
if err != nil && now < expiredTime {
666+
err = nil
667+
}
668+
return secretID, secretKey, secretToken, err
669+
}
670+
defer t.rwLocker.RUnlock()
671+
return t.credential.TmpSecretID, t.credential.TmpSecretKey, t.credential.SessionToken, nil
672+
}
673+
674+
func (t *StsCredentialTransport) RoundTrip(req *http.Request) (*http.Response, error) {
675+
ak, sk, token, err := t.GetCredential()
676+
if err != nil {
677+
return nil, err
678+
}
679+
req = cloneRequest(req)
680+
// 增加 Authorization header
681+
authTime := NewAuthTime(defaultAuthExpire)
682+
AddAuthorizationHeader(ak, sk, token, req, authTime)
683+
684+
resp, err := t.transport().RoundTrip(req)
685+
return resp, err
686+
}
687+
688+
func (t *StsCredentialTransport) transport() http.RoundTripper {
689+
if t.Transport != nil {
690+
return t.Transport
691+
}
692+
return http.DefaultTransport
693+
}
694+
695+
func (t *StsCredentialTransport) sendRequest(params map[string]interface{}) (*http.Response, error) {
696+
paramValues := url.Values{}
697+
for k, v := range params {
698+
paramValues.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
699+
}
700+
sign := t.signed("POST", params)
701+
paramValues.Add("Signature", sign)
702+
703+
host := defaultStsHost
704+
if t.Host != "" {
705+
host = t.Host
706+
}
707+
resp, err := http.DefaultClient.PostForm(defaultStsSchema+"://"+host, paramValues)
708+
return resp, err
709+
}
710+
711+
func (t *StsCredentialTransport) signed(method string, params map[string]interface{}) string {
712+
host := defaultStsHost
713+
if t.Host != "" {
714+
host = t.Host
715+
}
716+
source := method + host + "/?" + makeFlat(params)
717+
718+
hmacObj := hmac.New(sha1.New, []byte(t.SecretKey))
719+
hmacObj.Write([]byte(source))
720+
721+
sign := base64.StdEncoding.EncodeToString(hmacObj.Sum(nil))
722+
723+
return sign
724+
}
725+
726+
func getPolicy(policy *CredentialPolicy) (string, error) {
727+
if policy == nil {
728+
return "", nil
729+
}
730+
res := policy
731+
if policy.Version == "" {
732+
res = &CredentialPolicy{
733+
Version: "2.0",
734+
Statement: policy.Statement,
735+
}
736+
}
737+
bs, err := json.Marshal(res)
738+
if err != nil {
739+
return "", err
740+
}
741+
return string(bs), nil
742+
}
743+
744+
func makeFlat(params map[string]interface{}) string {
745+
keys := make([]string, 0, len(params))
746+
for k, _ := range params {
747+
keys = append(keys, k)
748+
}
749+
sort.Strings(keys)
750+
751+
var plainParms string
752+
for _, k := range keys {
753+
plainParms += fmt.Sprintf("&%v=%v", k, params[k])
754+
}
755+
return plainParms[1:]
756+
}

0 commit comments

Comments
 (0)