96 lines
2.6 KiB
Go
96 lines
2.6 KiB
Go
|
package shadowaead
|
||
|
|
||
|
import (
|
||
|
"crypto/rand"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net"
|
||
|
|
||
|
"github.com/Dreamacro/clash/common/pool"
|
||
|
)
|
||
|
|
||
|
// ErrShortPacket means that the packet is too short for a valid encrypted packet.
|
||
|
// ErrShortPacket is returned by Unpack when pkt is shorter than the salt, or
// shorter than the salt plus the AEAD overhead.
var ErrShortPacket = errors.New("short packet")
|
||
|
|
||
|
// _zerononce is an all-zero buffer that Pack and Unpack slice down to
// aead.NonceSize() and pass as the nonce. It is read-only and must never be
// written to. 128 bytes is more than enough.
var _zerononce [128]byte
|
||
|
|
||
|
// Pack encrypts plaintext using Cipher with a randomly generated salt and
|
||
|
// returns a slice of dst containing the encrypted packet and any error occurred.
|
||
|
// Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead().
|
||
|
func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) {
|
||
|
saltSize := ciph.SaltSize()
|
||
|
salt := dst[:saltSize]
|
||
|
if _, err := rand.Read(salt); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
aead, err := ciph.Encrypter(salt)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len(dst) < saltSize+len(plaintext)+aead.Overhead() {
|
||
|
return nil, io.ErrShortBuffer
|
||
|
}
|
||
|
b := aead.Seal(dst[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil)
|
||
|
return dst[:saltSize+len(b)], nil
|
||
|
}
|
||
|
|
||
|
// Unpack decrypts pkt using Cipher and returns a slice of dst containing the decrypted payload and any error occurred.
|
||
|
// Ensure len(dst) >= len(pkt) - aead.SaltSize() - aead.Overhead().
|
||
|
func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
|
||
|
saltSize := ciph.SaltSize()
|
||
|
if len(pkt) < saltSize {
|
||
|
return nil, ErrShortPacket
|
||
|
}
|
||
|
salt := pkt[:saltSize]
|
||
|
aead, err := ciph.Decrypter(salt)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len(pkt) < saltSize+aead.Overhead() {
|
||
|
return nil, ErrShortPacket
|
||
|
}
|
||
|
if saltSize+len(dst)+aead.Overhead() < len(pkt) {
|
||
|
return nil, io.ErrShortBuffer
|
||
|
}
|
||
|
b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
|
||
|
return b, err
|
||
|
}
|
||
|
|
||
|
type PacketConn struct {
|
||
|
net.PacketConn
|
||
|
Cipher
|
||
|
}
|
||
|
|
||
|
// maxPacketSize is the size, in bytes, of the scratch buffer WriteTo requests
// from the pool to hold one encrypted packet.
const maxPacketSize = 64 * 1024
|
||
|
|
||
|
// NewPacketConn wraps a net.PacketConn with cipher
|
||
|
func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
|
||
|
return &PacketConn{PacketConn: c, Cipher: ciph}
|
||
|
}
|
||
|
|
||
|
// WriteTo encrypts b and write to addr using the embedded PacketConn.
|
||
|
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||
|
buf := pool.Get(maxPacketSize)
|
||
|
defer pool.Put(buf)
|
||
|
buf, err := Pack(buf, b, c)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
_, err = c.PacketConn.WriteTo(buf, addr)
|
||
|
return len(b), err
|
||
|
}
|
||
|
|
||
|
// ReadFrom reads from the embedded PacketConn and decrypts into b.
|
||
|
func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||
|
n, addr, err := c.PacketConn.ReadFrom(b)
|
||
|
if err != nil {
|
||
|
return n, addr, err
|
||
|
}
|
||
|
bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
|
||
|
if err != nil {
|
||
|
return n, addr, err
|
||
|
}
|
||
|
copy(b, bb)
|
||
|
return len(bb), addr, err
|
||
|
}
|