feat: Prepare to specify the fingerprint function

This commit is contained in:
Skyxim 2022-07-10 21:56:33 +08:00
parent fef9f95e65
commit dbce268692

View file

@ -12,61 +12,88 @@ import (
"time" "time"
) )
var fingerprints [][32]byte var globalFingerprints [][32]byte
var rwLock sync.Mutex var mutex sync.Mutex
var defaultTLSConfig = &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificate,
}
var verifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
fingerprints := fingerprints
var preErr error func verifyPeerCertificateAndFingerprints(fingerprints [][32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
for i := range rawCerts { return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
rawCert := rawCerts[i]
cert, err := x509.ParseCertificate(rawCert)
if err == nil {
opts := x509.VerifyOptions{
CurrentTime: time.Now(),
}
if _, err := cert.Verify(opts); err == nil { var preErr error
return nil for i := range rawCerts {
} else { rawCert := rawCerts[i]
fingerprint := sha256.Sum256(cert.Raw) cert, err := x509.ParseCertificate(rawCert)
for _, fp := range fingerprints { if err == nil {
if bytes.Equal(fingerprint[:], fp[:]) { opts := x509.VerifyOptions{
return nil CurrentTime: time.Now(),
}
} }
preErr = err if _, err := cert.Verify(opts); err == nil {
return nil
} else {
fingerprint := sha256.Sum256(cert.Raw)
for _, fp := range fingerprints {
if bytes.Equal(fingerprint[:], fp[:]) {
return nil
}
}
preErr = err
}
} }
} }
}
return preErr return preErr
}
} }
func AddCertFingerprint(fingerprint string) error { func AddCertFingerprint(fingerprint string) error {
fp := strings.Replace(fingerprint, ":", "", -1) fpByte, err2 := convertFingerprint(fingerprint)
fpByte, err := hex.DecodeString(fp) if err2 != nil {
if err != nil { return err2
return err
} }
if len(fpByte) != 32 { mutex.Lock()
return fmt.Errorf("fingerprint string length error,need sha25 fingerprint") globalFingerprints = append(globalFingerprints, *fpByte)
} mutex.Unlock()
rwLock.Lock()
fingerprints = append(fingerprints, *(*[32]byte)(fpByte))
rwLock.Unlock()
return nil return nil
} }
func convertFingerprint(fingerprint string) (*[32]byte, error) {
fp := strings.Replace(fingerprint, ":", "", -1)
fpByte, err := hex.DecodeString(fp)
if err != nil {
return nil, err
}
if len(fpByte) != 32 {
return nil, fmt.Errorf("fingerprint string length error,need sha25 fingerprint")
}
return (*[32]byte)(fpByte), nil
}
func GetDefaultTLSConfig() *tls.Config { func GetDefaultTLSConfig() *tls.Config {
return defaultTLSConfig return &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateAndFingerprints(globalFingerprints),
}
}
// GetTLSConfigWithSpecifiedFingerprint specified fingerprint
func GetTLSConfigWithSpecifiedFingerprint(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
if fingerprintBytes, err := convertFingerprint(fingerprint); err != nil {
return nil, err
} else {
if tlsConfig == nil {
return &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateAndFingerprints([][32]byte{*fingerprintBytes}),
}, nil
} else {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints([][32]byte{*fingerprintBytes})
return tlsConfig, nil
}
}
} }
func MixinTLSConfig(tlsConfig *tls.Config) *tls.Config { func MixinTLSConfig(tlsConfig *tls.Config) *tls.Config {
@ -75,6 +102,6 @@ func MixinTLSConfig(tlsConfig *tls.Config) *tls.Config {
} }
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = verifyPeerCertificate tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints(globalFingerprints)
return tlsConfig return tlsConfig
} }