diff --git a/common/tls/config.go b/common/tls/config.go index 70aa390d..c19b0179 100644 --- a/common/tls/config.go +++ b/common/tls/config.go @@ -12,61 +12,88 @@ import ( "time" ) -var fingerprints [][32]byte -var rwLock sync.Mutex -var defaultTLSConfig = &tls.Config{ - InsecureSkipVerify: true, - VerifyPeerCertificate: verifyPeerCertificate, -} -var verifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - fingerprints := fingerprints +var globalFingerprints [][32]byte +var mutex sync.Mutex - var preErr error - for i := range rawCerts { - rawCert := rawCerts[i] - cert, err := x509.ParseCertificate(rawCert) - if err == nil { - opts := x509.VerifyOptions{ - CurrentTime: time.Now(), - } +func verifyPeerCertificateAndFingerprints(fingerprints [][32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - if _, err := cert.Verify(opts); err == nil { - return nil - } else { - fingerprint := sha256.Sum256(cert.Raw) - for _, fp := range fingerprints { - if bytes.Equal(fingerprint[:], fp[:]) { - return nil - } + var preErr error + for i := range rawCerts { + rawCert := rawCerts[i] + cert, err := x509.ParseCertificate(rawCert) + if err == nil { + opts := x509.VerifyOptions{ + CurrentTime: time.Now(), } - preErr = err + if _, err := cert.Verify(opts); err == nil { + return nil + } else { + fingerprint := sha256.Sum256(cert.Raw) + for _, fp := range fingerprints { + if bytes.Equal(fingerprint[:], fp[:]) { + return nil + } + } + + preErr = err + } } } - } - return preErr + return preErr + } } func AddCertFingerprint(fingerprint string) error { - fp := strings.Replace(fingerprint, ":", "", -1) - fpByte, err := hex.DecodeString(fp) - if err != nil { - return err + fpByte, err2 := convertFingerprint(fingerprint) + if err2 != nil { + return err2 } - if len(fpByte) != 32 { - return fmt.Errorf("fingerprint string length error,need sha25 fingerprint") - } - - rwLock.Lock() - fingerprints = append(fingerprints, *(*[32]byte)(fpByte)) - rwLock.Unlock() + mutex.Lock() + globalFingerprints = append(globalFingerprints, *fpByte) + mutex.Unlock() return nil } +func convertFingerprint(fingerprint string) (*[32]byte, error) { + fp := strings.Replace(fingerprint, ":", "", -1) + fpByte, err := hex.DecodeString(fp) + if err != nil { + return nil, err + } + + if len(fpByte) != 32 { + return nil, fmt.Errorf("fingerprint string length error,need sha25 fingerprint") + } + return (*[32]byte)(fpByte), nil +} + func GetDefaultTLSConfig() *tls.Config { - return defaultTLSConfig + return &tls.Config{ + InsecureSkipVerify: true, + VerifyPeerCertificate: verifyPeerCertificateAndFingerprints(globalFingerprints), + } +} + +// GetTLSConfigWithSpecifiedFingerprint specified fingerprint +func GetTLSConfigWithSpecifiedFingerprint(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) { + if fingerprintBytes, err := convertFingerprint(fingerprint); err != nil { + return nil, err + } else { + if tlsConfig == nil { + return &tls.Config{ + InsecureSkipVerify: true, + VerifyPeerCertificate: verifyPeerCertificateAndFingerprints([][32]byte{*fingerprintBytes}), + }, nil + } else { + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints([][32]byte{*fingerprintBytes}) + return tlsConfig, nil + } + } } func MixinTLSConfig(tlsConfig *tls.Config) *tls.Config { @@ -75,6 +102,6 @@ func MixinTLSConfig(tlsConfig *tls.Config) *tls.Config { } tlsConfig.InsecureSkipVerify = true - tlsConfig.VerifyPeerCertificate = verifyPeerCertificate + tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints(globalFingerprints) return tlsConfig }