143 lines
3.4 KiB
Go
143 lines
3.4 KiB
Go
package ca
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Package-level trust state shared by the functions in this file.
var (
	// trustCerts holds the certificates registered through AddCertificate.
	trustCerts []*x509.Certificate

	// globalCertPool caches the pool assembled by initializeCertPool.
	globalCertPool *x509.CertPool

	// mutex is the lock taken around updates to trustCerts and globalCertPool.
	mutex sync.RWMutex

	// errNotMatch is returned by the fingerprint verifier when none of the
	// presented certificates hashes to the pinned fingerprint.
	errNotMatch = errors.New("certificate fingerprints do not match")
)
|
|
|
|
func AddCertificate(certificate string) error {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
if certificate == "" {
|
|
return fmt.Errorf("certificate is empty")
|
|
}
|
|
if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
|
|
trustCerts = append(trustCerts, cert)
|
|
return nil
|
|
} else {
|
|
return fmt.Errorf("add certificate failed")
|
|
}
|
|
}
|
|
|
|
func initializeCertPool() {
|
|
var err error
|
|
globalCertPool, err = x509.SystemCertPool()
|
|
if err != nil {
|
|
globalCertPool = x509.NewCertPool()
|
|
}
|
|
for _, cert := range trustCerts {
|
|
globalCertPool.AddCert(cert)
|
|
}
|
|
}
|
|
|
|
func ResetCertificate() {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
trustCerts = nil
|
|
initializeCertPool()
|
|
}
|
|
|
|
func getCertPool() *x509.CertPool {
|
|
if len(trustCerts) == 0 {
|
|
return nil
|
|
}
|
|
if globalCertPool == nil {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
if globalCertPool != nil {
|
|
return globalCertPool
|
|
}
|
|
initializeCertPool()
|
|
}
|
|
return globalCertPool
|
|
}
|
|
|
|
func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
// ssl pining
|
|
for i := range rawCerts {
|
|
rawCert := rawCerts[i]
|
|
cert, err := x509.ParseCertificate(rawCert)
|
|
if err == nil {
|
|
hash := sha256.Sum256(cert.Raw)
|
|
if bytes.Equal(fingerprint[:], hash[:]) {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
return errNotMatch
|
|
}
|
|
}
|
|
|
|
// convertFingerprint turns a textual SHA-256 fingerprint into its 32 raw
// bytes.
//
// Colons are removed and surrounding whitespace is trimmed before the rest is
// decoded as hexadecimal, so both "aabb..." and "AA:BB:..." are accepted. A
// hex decoding failure is returned as is; a decoded length other than 32
// bytes yields a length error.
func convertFingerprint(fingerprint string) (*[32]byte, error) {
	const digestLen = 32

	cleaned := strings.ReplaceAll(fingerprint, ":", "")
	cleaned = strings.TrimSpace(cleaned)

	decoded, err := hex.DecodeString(cleaned)
	if err != nil {
		return nil, err
	}
	if len(decoded) != digestLen {
		return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint")
	}

	var digest [digestLen]byte
	copy(digest[:], decoded)
	return &digest, nil
}
|
|
|
|
// GetTLSConfig specified fingerprint, customCA and customCAString
|
|
func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) {
|
|
if tlsConfig == nil {
|
|
tlsConfig = &tls.Config{}
|
|
}
|
|
var certificate []byte
|
|
var err error
|
|
if len(customCA) > 0 {
|
|
certificate, err = os.ReadFile(customCA)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load ca error: %w", err)
|
|
}
|
|
} else if customCAString != "" {
|
|
certificate = []byte(customCAString)
|
|
}
|
|
if len(certificate) > 0 {
|
|
certPool := x509.NewCertPool()
|
|
if !certPool.AppendCertsFromPEM(certificate) {
|
|
return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate)
|
|
}
|
|
tlsConfig.RootCAs = certPool
|
|
} else {
|
|
tlsConfig.RootCAs = getCertPool()
|
|
}
|
|
if len(fingerprint) > 0 {
|
|
var fingerprintBytes *[32]byte
|
|
fingerprintBytes, err = convertFingerprint(fingerprint)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tlsConfig = GetGlobalTLSConfig(tlsConfig)
|
|
tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes)
|
|
tlsConfig.InsecureSkipVerify = true
|
|
}
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
// GetSpecifiedFingerprintTLSConfig specified fingerprint
|
|
func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
|
|
return GetTLSConfig(tlsConfig, fingerprint, "", "")
|
|
}
|
|
|
|
func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config {
|
|
tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "")
|
|
return tlsConfig
|
|
}
|