@@ -4,9 +4,11 @@ import (
4
4
"context"
5
5
"crypto/hmac"
6
6
"crypto/sha1"
7
+ "encoding/base64"
7
8
"encoding/json"
8
9
"fmt"
9
10
"hash"
11
+ "io"
10
12
"io/ioutil"
11
13
math_rand "math/rand"
12
14
"net"
@@ -26,11 +28,13 @@ const (
26
28
)
27
29
28
30
var (
29
- defaultCVMAuthExpire = int64 (600 )
31
+ defaultTmpAuthExpire = int64 (600 )
30
32
defaultCVMSchema = "http"
31
33
defaultCVMMetaHost = "metadata.tencentyun.com"
32
34
defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
33
35
internalHost = regexp .MustCompile (`^.*cos-internal\.[a-z-1]+\.tencentcos\.cn$` )
36
+ defaultStsHost = "sts.tencentcloudapi.com"
37
+ defaultStsSchema = "https"
34
38
)
35
39
36
40
var DNSScatterDialContext = DNSScatterDialContextFunc
@@ -424,7 +428,7 @@ func (t *CVMCredentialTransport) GetRoles() ([]string, error) {
424
428
func (t * CVMCredentialTransport ) UpdateCredential (now int64 ) (string , string , string , error ) {
425
429
t .rwLocker .Lock ()
426
430
defer t .rwLocker .Unlock ()
427
- if t .expiredTime > now + defaultCVMAuthExpire {
431
+ if t .expiredTime > now + defaultTmpAuthExpire {
428
432
return t .secretID , t .secretKey , t .sessionToken , nil
429
433
}
430
434
roleName := t .RoleName
@@ -460,8 +464,8 @@ func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, st
460
464
func (t * CVMCredentialTransport ) GetCredential () (string , string , string , error ) {
461
465
now := time .Now ().Unix ()
462
466
t .rwLocker .RLock ()
463
- // 提前 defaultCVMAuthExpire 获取重新获取临时密钥
464
- if t .expiredTime <= now + defaultCVMAuthExpire {
467
+ // 提前 defaultTmpAuthExpire 获取重新获取临时密钥
468
+ if t .expiredTime <= now + defaultTmpAuthExpire {
465
469
expiredTime := t .expiredTime
466
470
t .rwLocker .RUnlock ()
467
471
secretID , secretKey , secretToken , err := t .UpdateCredential (now )
@@ -545,3 +549,208 @@ func (c *Credential) GetSecretId() string {
545
549
func (c * Credential ) GetToken () string {
546
550
return c .SessionToken
547
551
}
552
+
553
+ // 通过sts访问
554
+ type Credentials struct {
555
+ TmpSecretID string `json:"TmpSecretId,omitempty"`
556
+ TmpSecretKey string `json:"TmpSecretKey,omitempty"`
557
+ SessionToken string `json:"Token,omitempty"`
558
+ }
559
+ type CredentialError struct {
560
+ Code string `json:"Code,omitempty"`
561
+ Message string `json:"Message,omitempty"`
562
+ RequestId string `json:"RequestId,omitempty"`
563
+ }
564
+
565
+ func (e * CredentialError ) Error () string {
566
+ return fmt .Sprintf ("Code: %v, Message: %v, RequestId: %v" , e .Code , e .Message , e .RequestId )
567
+ }
568
+
569
+ type CredentialResult struct {
570
+ Credentials * Credentials `json:"Credentials,omitempty"`
571
+ ExpiredTime int64 `json:"ExpiredTime,omitempty"`
572
+ RequestId string `json:"RequestId,omitempty"`
573
+ Error * CredentialError `json:"Error,omitempty"`
574
+ }
575
+
576
+ type CredentialCompleteResult struct {
577
+ Response * CredentialResult `json:"Response"`
578
+ }
579
+
580
+ type CredentialPolicyStatement struct {
581
+ Action []string `json:"action,omitempty"`
582
+ Effect string `json:"effect,omitempty"`
583
+ Resource []string `json:"resource,omitempty"`
584
+ Condition map [string ]map [string ]interface {} `json:"condition,omitempty"`
585
+ }
586
+
587
+ type CredentialPolicy struct {
588
+ Version string `json:"version,omitempty"`
589
+ Statement []CredentialPolicyStatement `json:"statement,omitempty"`
590
+ }
591
+
592
+ type StsCredentialTransport struct {
593
+ Transport http.RoundTripper
594
+ SecretID string
595
+ SecretKey string
596
+ Policy * CredentialPolicy
597
+ Host string
598
+ Region string
599
+ expiredTime int64
600
+ credential Credentials
601
+ rwLocker sync.RWMutex
602
+ }
603
+
604
+ func (t * StsCredentialTransport ) UpdateCredential (now int64 ) (string , string , string , error ) {
605
+ t .rwLocker .Lock ()
606
+ defer t .rwLocker .Unlock ()
607
+ if t .expiredTime > now + defaultTmpAuthExpire {
608
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , nil
609
+ }
610
+ region := t .Region
611
+ if region == "" {
612
+ region = "ap-guangzhou"
613
+ }
614
+ policy , err := getPolicy (t .Policy )
615
+ if err != nil {
616
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , err
617
+ }
618
+ params := map [string ]interface {}{
619
+ "SecretId" : t .SecretID ,
620
+ "Policy" : url .QueryEscape (policy ),
621
+ "DurationSeconds" : 1800 ,
622
+ "Region" : region ,
623
+ "Timestamp" : time .Now ().Unix (),
624
+ "Nonce" : math_rand .Int (),
625
+ "Name" : "cos-sts-sdk" ,
626
+ "Action" : "GetFederationToken" ,
627
+ "Version" : "2018-08-13" ,
628
+ }
629
+ resp , err := t .sendRequest (params )
630
+ if err != nil {
631
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , err
632
+ }
633
+ defer resp .Body .Close ()
634
+ if resp .StatusCode > 299 {
635
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , fmt .Errorf ("sts StatusCode error: %v" , resp .StatusCode )
636
+ }
637
+ result := & CredentialCompleteResult {}
638
+ err = json .NewDecoder (resp .Body ).Decode (result )
639
+ if err == io .EOF {
640
+ err = nil // ignore EOF errors caused by empty response body
641
+ }
642
+ if err != nil {
643
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , err
644
+ }
645
+ if result .Response != nil && result .Response .Error != nil {
646
+ result .Response .Error .RequestId = result .Response .RequestId
647
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , result .Response .Error
648
+ }
649
+ if result .Response != nil && result .Response .Credentials != nil {
650
+ t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , t .expiredTime = result .Response .Credentials .TmpSecretID , result .Response .Credentials .TmpSecretKey , result .Response .Credentials .SessionToken , result .Response .ExpiredTime
651
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , nil
652
+ }
653
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , fmt .Errorf ("GetCredential failed, result: %v" , result .Response )
654
+ }
655
+
656
+ func (t * StsCredentialTransport ) GetCredential () (string , string , string , error ) {
657
+ now := time .Now ().Unix ()
658
+ t .rwLocker .RLock ()
659
+ // 提前 defaultTmpAuthExpire 获取重新获取临时密钥
660
+ if t .expiredTime <= now + defaultTmpAuthExpire {
661
+ expiredTime := t .expiredTime
662
+ t .rwLocker .RUnlock ()
663
+ secretID , secretKey , secretToken , err := t .UpdateCredential (now )
664
+ // 获取临时密钥失败但密钥未过期
665
+ if err != nil && now < expiredTime {
666
+ err = nil
667
+ }
668
+ return secretID , secretKey , secretToken , err
669
+ }
670
+ defer t .rwLocker .RUnlock ()
671
+ return t .credential .TmpSecretID , t .credential .TmpSecretKey , t .credential .SessionToken , nil
672
+ }
673
+
674
+ func (t * StsCredentialTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
675
+ ak , sk , token , err := t .GetCredential ()
676
+ if err != nil {
677
+ return nil , err
678
+ }
679
+ req = cloneRequest (req )
680
+ // 增加 Authorization header
681
+ authTime := NewAuthTime (defaultAuthExpire )
682
+ AddAuthorizationHeader (ak , sk , token , req , authTime )
683
+
684
+ resp , err := t .transport ().RoundTrip (req )
685
+ return resp , err
686
+ }
687
+
688
+ func (t * StsCredentialTransport ) transport () http.RoundTripper {
689
+ if t .Transport != nil {
690
+ return t .Transport
691
+ }
692
+ return http .DefaultTransport
693
+ }
694
+
695
+ func (t * StsCredentialTransport ) sendRequest (params map [string ]interface {}) (* http.Response , error ) {
696
+ paramValues := url.Values {}
697
+ for k , v := range params {
698
+ paramValues .Add (fmt .Sprintf ("%v" , k ), fmt .Sprintf ("%v" , v ))
699
+ }
700
+ sign := t .signed ("POST" , params )
701
+ paramValues .Add ("Signature" , sign )
702
+
703
+ host := defaultStsHost
704
+ if t .Host != "" {
705
+ host = t .Host
706
+ }
707
+ resp , err := http .DefaultClient .PostForm (defaultStsSchema + "://" + host , paramValues )
708
+ return resp , err
709
+ }
710
+
711
+ func (t * StsCredentialTransport ) signed (method string , params map [string ]interface {}) string {
712
+ host := defaultStsHost
713
+ if t .Host != "" {
714
+ host = t .Host
715
+ }
716
+ source := method + host + "/?" + makeFlat (params )
717
+
718
+ hmacObj := hmac .New (sha1 .New , []byte (t .SecretKey ))
719
+ hmacObj .Write ([]byte (source ))
720
+
721
+ sign := base64 .StdEncoding .EncodeToString (hmacObj .Sum (nil ))
722
+
723
+ return sign
724
+ }
725
+
726
+ func getPolicy (policy * CredentialPolicy ) (string , error ) {
727
+ if policy == nil {
728
+ return "" , nil
729
+ }
730
+ res := policy
731
+ if policy .Version == "" {
732
+ res = & CredentialPolicy {
733
+ Version : "2.0" ,
734
+ Statement : policy .Statement ,
735
+ }
736
+ }
737
+ bs , err := json .Marshal (res )
738
+ if err != nil {
739
+ return "" , err
740
+ }
741
+ return string (bs ), nil
742
+ }
743
+
744
+ func makeFlat (params map [string ]interface {}) string {
745
+ keys := make ([]string , 0 , len (params ))
746
+ for k , _ := range params {
747
+ keys = append (keys , k )
748
+ }
749
+ sort .Strings (keys )
750
+
751
+ var plainParms string
752
+ for _ , k := range keys {
753
+ plainParms += fmt .Sprintf ("&%v=%v" , k , params [k ])
754
+ }
755
+ return plainParms [1 :]
756
+ }
0 commit comments