mihomo/component/ca/config.go

144 lines
3.4 KiB
Go
Raw Normal View History

2023-09-22 14:45:34 +08:00
package ca
2022-07-10 20:44:24 +08:00
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
2022-07-10 20:44:24 +08:00
"fmt"
2023-09-22 14:45:34 +08:00
"os"
"strings"
2022-07-10 20:44:24 +08:00
"sync"
)
2023-03-27 22:27:59 +08:00
var (
	// trustCerts holds the extra certificates registered via AddCertificate.
	trustCerts []*x509.Certificate

	// globalCertPool caches the system roots merged with trustCerts.
	// A nil value means the pool has not been built yet.
	globalCertPool *x509.CertPool

	// mutex protects trustCerts and globalCertPool.
	mutex sync.RWMutex

	// errNotMatch is returned by the pinning callback when no peer
	// certificate matches the configured fingerprint.
	errNotMatch = errors.New("certificate fingerprints do not match")
)
2022-07-10 20:44:24 +08:00
2023-02-25 22:01:20 +08:00
func AddCertificate(certificate string) error {
mutex.Lock()
defer mutex.Unlock()
2023-02-25 22:01:20 +08:00
if certificate == "" {
return fmt.Errorf("certificate is empty")
}
2023-03-27 22:27:59 +08:00
if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
trustCerts = append(trustCerts, cert)
return nil
} else {
2023-02-25 22:01:20 +08:00
return fmt.Errorf("add certificate failed")
}
}
func initializeCertPool() {
var err error
2023-09-22 14:45:34 +08:00
globalCertPool, err = x509.SystemCertPool()
if err != nil {
2023-09-22 14:45:34 +08:00
globalCertPool = x509.NewCertPool()
}
for _, cert := range trustCerts {
2023-09-22 14:45:34 +08:00
globalCertPool.AddCert(cert)
}
}
2023-02-26 20:38:32 +08:00
func ResetCertificate() {
2023-02-25 22:01:20 +08:00
mutex.Lock()
defer mutex.Unlock()
2023-03-27 22:27:59 +08:00
trustCerts = nil
initializeCertPool()
2023-03-27 22:27:59 +08:00
}
func getCertPool() *x509.CertPool {
if len(trustCerts) == 0 {
return nil
}
2023-09-22 14:45:34 +08:00
if globalCertPool == nil {
mutex.Lock()
defer mutex.Unlock()
2023-09-22 14:45:34 +08:00
if globalCertPool != nil {
return globalCertPool
}
initializeCertPool()
2023-03-27 22:27:59 +08:00
}
2023-09-22 14:45:34 +08:00
return globalCertPool
}
func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// ssl pining
for i := range rawCerts {
rawCert := rawCerts[i]
cert, err := x509.ParseCertificate(rawCert)
if err == nil {
hash := sha256.Sum256(cert.Raw)
if bytes.Equal(fingerprint[:], hash[:]) {
return nil
}
2022-07-10 20:44:24 +08:00
}
}
return errNotMatch
}
2022-07-10 20:44:24 +08:00
}
// convertFingerprint parses a hex-encoded SHA-256 fingerprint into its
// 32-byte form.
//
// Before decoding, colons (as in "AB:CD:...") are removed and surrounding
// whitespace is trimmed. Upper- and lower-case hex digits are both accepted.
//
// It returns an error if the remaining text is not valid hex or does not
// decode to exactly 32 bytes. Hex errors are wrapped, so errors.As can still
// reach the underlying encoding/hex error.
func convertFingerprint(fingerprint string) (*[32]byte, error) {
	fingerprint = strings.TrimSpace(strings.ReplaceAll(fingerprint, ":", ""))
	fpByte, err := hex.DecodeString(fingerprint)
	if err != nil {
		return nil, fmt.Errorf("fingerprint %q is not valid hex: %w", fingerprint, err)
	}
	if len(fpByte) != 32 {
		return nil, fmt.Errorf("fingerprint string length error, need sha256 fingerprint: got %d bytes, want 32", len(fpByte))
	}
	return (*[32]byte)(fpByte), nil
}
2023-09-22 14:45:34 +08:00
// GetTLSConfig specified fingerprint, customCA and customCAString
func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) {
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
var certificate []byte
var err error
if len(customCA) > 0 {
certificate, err = os.ReadFile(customCA)
if err != nil {
return nil, fmt.Errorf("load ca error: %w", err)
}
} else if customCAString != "" {
certificate = []byte(customCAString)
}
if len(certificate) > 0 {
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(certificate) {
return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate)
}
tlsConfig.RootCAs = certPool
} else {
2023-09-22 14:45:34 +08:00
tlsConfig.RootCAs = getCertPool()
}
if len(fingerprint) > 0 {
var fingerprintBytes *[32]byte
fingerprintBytes, err = convertFingerprint(fingerprint)
if err != nil {
return nil, err
}
tlsConfig = GetGlobalTLSConfig(tlsConfig)
tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes)
tlsConfig.InsecureSkipVerify = true
}
2023-09-22 14:45:34 +08:00
return tlsConfig, nil
}
// GetSpecifiedFingerprintTLSConfig specified fingerprint
func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
return GetTLSConfig(tlsConfig, fingerprint, "", "")
2022-07-10 20:44:24 +08:00
}
func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config {
2023-09-22 14:45:34 +08:00
tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "")
2022-07-10 20:44:24 +08:00
return tlsConfig
}