149 lines
2.9 KiB
Go
149 lines
2.9 KiB
Go
package deadline
|
|
|
|
import (
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/metacubex/mihomo/common/atomic"
|
|
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/bufio"
|
|
"github.com/sagernet/sing/common/network"
|
|
)
|
|
|
|
// connReadResult is the outcome of one background read started by
// Conn.pipeRead. It is handed to Read / ReadBuffer through Conn.resultCh.
type connReadResult struct {
	// buffer holds the bytes read that have not been delivered to a caller
	// yet; consumers re-slice it as they copy data out.
	buffer []byte
	// err is the error returned by the underlying read together with buffer.
	err error
}
|
|
|
|
type Conn struct {
|
|
network.ExtendedConn
|
|
deadline atomic.TypedValue[time.Time]
|
|
pipeDeadline pipeDeadline
|
|
disablePipe atomic.Bool
|
|
inRead atomic.Bool
|
|
resultCh chan *connReadResult
|
|
}
|
|
|
|
func NewConn(conn net.Conn) *Conn {
|
|
c := &Conn{
|
|
ExtendedConn: bufio.NewExtendedConn(conn),
|
|
pipeDeadline: makePipeDeadline(),
|
|
resultCh: make(chan *connReadResult, 1),
|
|
}
|
|
c.resultCh <- nil
|
|
return c
|
|
}
|
|
|
|
func (c *Conn) Read(p []byte) (n int, err error) {
|
|
select {
|
|
case result := <-c.resultCh:
|
|
if result != nil {
|
|
n = copy(p, result.buffer)
|
|
err = result.err
|
|
if n >= len(result.buffer) {
|
|
c.resultCh <- nil // finish cache read
|
|
} else {
|
|
result.buffer = result.buffer[n:]
|
|
c.resultCh <- result // push back for next call
|
|
}
|
|
return
|
|
} else {
|
|
c.resultCh <- nil
|
|
break
|
|
}
|
|
case <-c.pipeDeadline.wait():
|
|
return 0, os.ErrDeadlineExceeded
|
|
}
|
|
|
|
if c.disablePipe.Load() {
|
|
return c.ExtendedConn.Read(p)
|
|
} else if c.deadline.Load().IsZero() {
|
|
c.inRead.Store(true)
|
|
defer c.inRead.Store(false)
|
|
return c.ExtendedConn.Read(p)
|
|
}
|
|
|
|
<-c.resultCh
|
|
go c.pipeRead(len(p))
|
|
|
|
return c.Read(p)
|
|
}
|
|
|
|
func (c *Conn) pipeRead(size int) {
|
|
buffer := make([]byte, size)
|
|
n, err := c.ExtendedConn.Read(buffer)
|
|
buffer = buffer[:n]
|
|
c.resultCh <- &connReadResult{
|
|
buffer: buffer,
|
|
err: err,
|
|
}
|
|
}
|
|
|
|
func (c *Conn) ReadBuffer(buffer *buf.Buffer) (err error) {
|
|
select {
|
|
case result := <-c.resultCh:
|
|
if result != nil {
|
|
n, _ := buffer.Write(result.buffer)
|
|
err = result.err
|
|
|
|
if n >= len(result.buffer) {
|
|
c.resultCh <- nil // finish cache read
|
|
} else {
|
|
result.buffer = result.buffer[n:]
|
|
c.resultCh <- result // push back for next call
|
|
}
|
|
return
|
|
} else {
|
|
c.resultCh <- nil
|
|
break
|
|
}
|
|
case <-c.pipeDeadline.wait():
|
|
return os.ErrDeadlineExceeded
|
|
}
|
|
|
|
if c.disablePipe.Load() {
|
|
return c.ExtendedConn.ReadBuffer(buffer)
|
|
} else if c.deadline.Load().IsZero() {
|
|
c.inRead.Store(true)
|
|
defer c.inRead.Store(false)
|
|
return c.ExtendedConn.ReadBuffer(buffer)
|
|
}
|
|
|
|
<-c.resultCh
|
|
go c.pipeRead(buffer.FreeLen())
|
|
|
|
return c.ReadBuffer(buffer)
|
|
}
|
|
|
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
|
if c.disablePipe.Load() {
|
|
return c.ExtendedConn.SetReadDeadline(t)
|
|
} else if c.inRead.Load() {
|
|
c.disablePipe.Store(true)
|
|
return c.ExtendedConn.SetReadDeadline(t)
|
|
}
|
|
c.deadline.Store(t)
|
|
c.pipeDeadline.set(t)
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) ReaderReplaceable() bool {
|
|
select {
|
|
case result := <-c.resultCh:
|
|
c.resultCh <- result
|
|
if result != nil {
|
|
return false // cache reading
|
|
} else {
|
|
break
|
|
}
|
|
default:
|
|
return false // pipe reading
|
|
}
|
|
return c.disablePipe.Load() || c.deadline.Load().IsZero()
|
|
}
|
|
|
|
func (c *Conn) Upstream() any {
|
|
return c.ExtendedConn
|
|
}
|