Skip to content

Commit

Permalink
Refactored the test files with helpers to test backend
Browse files Browse the repository at this point in the history
```
func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string)
func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool
```
  • Loading branch information
yangchenyun committed Apr 11, 2018
1 parent 2b928d9 commit 1be8ffa
Showing 1 changed file with 93 additions and 69 deletions.
162 changes: 93 additions & 69 deletions tcpproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,38 +169,90 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
}
}

func TestProxyAlwaysMatch(t *testing.T) {
front := newLocalListener(t)
defer front.Close()
back := newLocalListener(t)
defer back.Close()
func testRouteToBackendWithExpected(t *testing.T, toFront net.Conn, back net.Listener, msg string, expected string) {
io.WriteString(toFront, msg)
fromProxy, err := back.Accept()
if err != nil {
t.Fatal(err)
}

p := testProxy(t, front)
p.AddRoute(testFrontAddr, To(back.Addr().String()))
if err := p.Start(); err != nil {
buf := make([]byte, len(expected))
if _, err := io.ReadFull(fromProxy, buf); err != nil {
t.Fatal(err)
}
if string(buf) != expected {
t.Fatalf("got %q; want %q", buf, expected)
}
}

func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) {
toFront, err := net.Dial("tcp", front.Addr().String())
if err != nil {
t.Fatal(err)
}
defer toFront.Close()

fromProxy, err := back.Accept()
testRouteToBackendWithExpected(t, toFront, back, msg, msg)
}

// test the backend is not receiving traffic
func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool {
done := make(chan bool)
toFront, err := net.Dial("tcp", front.Addr().String())
if err != nil {
t.Fatal(err)
}
const msg = "message"
io.WriteString(toFront, msg)
defer toFront.Close()

buf := make([]byte, len(msg))
if _, err := io.ReadFull(fromProxy, buf); err != nil {
timeC := time.NewTimer(10 * time.Millisecond).C
acceptC := make(chan struct{})
go func() {
io.WriteString(toFront, msg)
fromProxy, err := back.Accept()
acceptC <- struct{}{}
{
if err == nil {
buf := make([]byte, len(msg))
if _, err := io.ReadFull(fromProxy, buf); err != nil {
t.Fatal(err)
}
t.Fatalf("Expect backend to not receive message, but found %s", string(buf))
}
err, ok := err.(net.Error)
if !ok || !err.Timeout() {
t.Fatalf("Expect backend to timeout, but found err: %v", err)
}
}
}()
go func() {
select {
case <-timeC:
{
done <- true
}
case <-acceptC:
{
t.Fatal("Expect backend to not receive message")
done <- true
}
}
}()
return done
}

func TestProxyAlwaysMatch(t *testing.T) {
front := newLocalListener(t)
defer front.Close()
back := newLocalListener(t)
defer back.Close()

p := testProxy(t, front)
p.AddRoute(testFrontAddr, To(back.Addr().String()))
if err := p.Start(); err != nil {
t.Fatal(err)
}
if string(buf) != msg {
t.Fatalf("got %q; want %q", buf, msg)
}

testRouteToBackend(t, front, back, "message")
}

func TestProxyHTTP(t *testing.T) {
Expand All @@ -219,27 +271,9 @@ func TestProxyHTTP(t *testing.T) {
t.Fatal(err)
}

toFront, err := net.Dial("tcp", front.Addr().String())
if err != nil {
t.Fatal(err)
}
defer toFront.Close()

const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
io.WriteString(toFront, msg)

fromProxy, err := backBar.Accept()
if err != nil {
t.Fatal(err)
}

buf := make([]byte, len(msg))
if _, err := io.ReadFull(fromProxy, buf); err != nil {
t.Fatal(err)
}
if string(buf) != msg {
t.Fatalf("got %q; want %q", buf, msg)
}
testRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n")
<-testNotRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: boo.com\r\n\r\n")
testRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n")
}

func TestProxySNI(t *testing.T) {
Expand All @@ -258,27 +292,32 @@ func TestProxySNI(t *testing.T) {
t.Fatal(err)
}

toFront, err := net.Dial("tcp", front.Addr().String())
if err != nil {
t.Fatal(err)
}
defer toFront.Close()
testRouteToBackend(t, front, backBar, clientHelloRecord(t, "bar.com"))
<-testNotRouteToBackend(t, front, backBar, clientHelloRecord(t, "foo.com"))
testRouteToBackend(t, front, backFoo, clientHelloRecord(t, "foo.com"))
}

msg := clientHelloRecord(t, "bar.com")
io.WriteString(toFront, msg)
func TestProxyRemoveRoute(t *testing.T) {
front := newLocalListener(t)
defer front.Close()
p := testProxy(t, front)

fromProxy, err := backBar.Accept()
if err != nil {
t.Fatal(err)
}
// NOTE: Needs to register testFrontAddr before server starts
p.AddSNIRoute(testFrontAddr, "unused.com", noopTarget{})

buf := make([]byte, len(msg))
if _, err := io.ReadFull(fromProxy, buf); err != nil {
if err := p.Start(); err != nil {
t.Fatal(err)
}
if string(buf) != msg {
t.Fatalf("got %q; want %q", buf, msg)
}

backBar := newLocalListener(t)
defer backBar.Close()
routeID := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))

msg := clientHelloRecord(t, "bar.com")
testRouteToBackend(t, front, backBar, msg)

p.RemoveRoute(testFrontAddr, routeID)
<-testNotRouteToBackend(t, front, backBar, msg)
}

func TestProxyPROXYOut(t *testing.T) {
Expand All @@ -301,23 +340,8 @@ func TestProxyPROXYOut(t *testing.T) {
t.Fatal(err)
}

io.WriteString(toFront, "foo")
toFront.Close()

fromProxy, err := back.Accept()
if err != nil {
t.Fatal(err)
}

bs, err := ioutil.ReadAll(fromProxy)
if err != nil {
t.Fatal(err)
}

want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port)
if string(bs) != want {
t.Fatalf("got %q; want %q", bs, want)
}
testRouteToBackendWithExpected(t, toFront, back, "foo", want)
}

type tlsServer struct {
Expand Down

0 comments on commit 1be8ffa

Please sign in to comment.