summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 16:18:53 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 16:18:53 +0000
commit1cdc15a87db98ea2a6a55d331e65ec1a4fc4f273 (patch)
tree34af891c87f9f96c9816500e46b7ea11588dc6ea
parentInitial commit. (diff)
downloadgolang-github-inetaf-tcpproxy-upstream/0.0_git20231102.2862066.tar.xz
golang-github-inetaf-tcpproxy-upstream/0.0_git20231102.2862066.zip
Adding upstream version 0.0~git20231102.2862066.upstream/0.0_git20231102.2862066upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r--.travis.yml45
-rw-r--r--CONTRIBUTING.md8
-rw-r--r--LICENSE202
-rw-r--r--README.md5
-rw-r--r--cmd/tlsrouter/README.md51
-rw-r--r--cmd/tlsrouter/config.go137
-rw-r--r--cmd/tlsrouter/config_test.go61
-rw-r--r--cmd/tlsrouter/e2e_test.go216
-rw-r--r--cmd/tlsrouter/main.go191
-rw-r--r--cmd/tlsrouter/sni.go232
-rw-r--r--cmd/tlsrouter/sni_test.go456
-rw-r--r--go.mod5
-rw-r--r--go.sum2
-rw-r--r--http.go125
-rw-r--r--listener.go108
-rw-r--r--listener_test.go49
-rw-r--r--scripts/prune_old_versions.go150
-rw-r--r--sni.go115
-rw-r--r--systemd/tlsrouter.service25
-rw-r--r--tcpproxy.go496
-rw-r--r--tcpproxy_test.go525
21 files changed, 3204 insertions, 0 deletions
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..a8d3a50
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,45 @@
+language: go
+go:
+- "1.16.x"
+- "1.17.x"
+- tip
+os:
+- linux
+script:
+- go build ./...
+- go test ./...
+- go vet ./...
+
+jobs:
+ include:
+ - stage: deploy
+ go: "1.16"
+ install:
+ - gem install fpm
+ script:
+ - go build ./cmd/tlsrouter
+ - fpm -s dir -t deb -n tlsrouter -v $(date '+%Y%m%d%H%M%S')
+ --license Apache2
+ --vendor "David Anderson <dave@natulte.net>"
+ --maintainer "David Anderson <dave@natulte.net>"
+ --description "TLS SNI router"
+ --url "https://github.com/inetaf/tcpproxy/tree/master/cmd/tlsrouter"
+ ./tlsrouter=/usr/bin/tlsrouter
+ ./systemd/tlsrouter.service=/lib/systemd/system/tlsrouter.service
+ deploy:
+ - provider: packagecloud
+ repository: tlsrouter
+ username: danderson
+ dist: debian/stretch
+ skip_cleanup: true
+ on:
+ branch: master
+ token:
+ secure: gNU3o70EU4oYeIS6pr0K5oLMGqqxrcf41EOv6c/YoHPVdV6Cx4j9NW0/ISgu6a1/Xf2NgWKT5BWwLpAuhmGdALuOz1Ah//YBWd9N8mGHGaC6RpOPDU8/9NkQdBEmjEH9sgX4PNOh1KQ7d7O0OH0g8RqJlJa0MkUYbTtN6KJ29oiUXxKmZM4D/iWB8VonKOnrtx1NwQL8jL8imZyEV/1fknhDwumz2iKeU1le4Neq9zkxwICMLUonmgphlrp+SDb1EOoHxT6cn51bqBQtQUplfC4dN4OQU/CPqE9E1N1noibvN29YA93qfcrjD3I95KT9wzq+3B6he33+kb0Gz+Cj5ypGy4P85l7TuX4CtQg0U3NAlJCk32IfsdjK+o47pdmADij9IIb9yKt+g99FMERkJJY5EInqEsxHlW/vNF5OqQCmpiHstZL4R2XaHEsWh6j77npnjjC1Aea8xZTWr8PTsbSzVkbG7bTmFpZoPH8eEmr4GNuw5gnbi6D1AJDjcA+UdY9s5qZNpzuWOqfhOFxL+zUW+8sHBvcoFw3R+pwHECs2LCL1c0xAC1LtNUnmW/gnwHavtvKkzErjR1P8Xl7obCbeChJjp+b/BcFYlNACldZcuzBAPyPwIdlWVyUonL4bm63upfMEEShiAIDDJ21y7fjsQK7CfPA7g25bpyo+hV8=
+ - provider: script
+ on:
+ branch: master
+ script: go run scripts/prune_old_versions.go -user=danderson -repo=tlsrouter -distro=debian -version=stretch -package=tlsrouter -arch=amd64 -limit=2
+ env:
+ # Packagecloud API key, for prune_old_versions.go
+ - secure: "SRcNwt+45QyPS1w9aGxMg9905Y6d9w4mBM29G6iTTnUB5nD7cAk4m+tf834knGSobVXlWcRnTDW8zrHdQ9yX22dPqCpH5qE+qzTmIvxRHrVJRMmPeYvligJ/9jYfHgQbvuRT8cUpIcpCQAla6rw8nXfKTOE3h8XqMP2hdc3DTVOu2HCfKCNco1tJ7is+AIAnFV2Wpsbb3ZsdKFvHvi2RKUfFaX61J1GNt2/XJIlZs8jC6Y1IAC+ftjql9UsAE/WjZ9fL0Ww1b9/LBIIGHXWI3HpVv9WvlhhIxIlJgOVjmU2lbSuj2w/EBDJ9cd1Qe+wJkT3yKzE1NRsNScVjGg+Ku5igJu/XXuaHkIX01+15BqgPduBYRL0atiNQDhqgBiSyVhXZBX9vsgsp0bgpKaBSF++CV18Q9dara8aljqqS33M3imO3I8JmXU10944QA9Wvu7pCYuIzXxhINcDXRvqxBqz5LnFJGwnGqngTrOCSVS2xn7Y+sjmhe1n5cPCEISlozfa9mPYPvMPp8zg3TbATOOM8CVfcpaNscLqa/+SExN3zMwSanjNKrBgoaQcBzGW5mIgSPxhXkWikBgapiEN7+2Y032Lhqdb9dYjH+EuwcnofspDjjMabWxnuJaln+E3/9vZi2ooQrBEtvymUTy4VMSnqwIX5bU7nPdIuQycdWhk="
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..188ad87
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,8 @@
+Contributions are welcome by pull request.
+
+You need to sign the Google Contributor License Agreement before your
+contributions can be accepted. You can find the individual and organization
+level CLAs here:
+
+Individual: https://cla.developers.google.com/about/google-individual
+Organization: https://cla.developers.google.com/about/google-corporate
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f526c21
--- /dev/null
+++ b/README.md
@@ -0,0 +1,5 @@
+# tcpproxy
+
+For library usage, see https://godoc.org/inet.af/tcpproxy/
+
+For CLI usage, see https://github.com/inetaf/tcpproxy/blob/master/cmd/tlsrouter/README.md
diff --git a/cmd/tlsrouter/README.md b/cmd/tlsrouter/README.md
new file mode 100644
index 0000000..d915c32
--- /dev/null
+++ b/cmd/tlsrouter/README.md
@@ -0,0 +1,51 @@
+# TLS SNI router
+
+[![license](https://img.shields.io/github/license/google/tlsrouter.svg?maxAge=2592000)](https://github.com/inetaf/tcpproxy/blob/master/LICENSE) [![Travis](https://img.shields.io/travis/google/tlsrouter.svg?maxAge=2592000)](https://travis-ci.org/google/tlsrouter) [![api](https://img.shields.io/badge/api-unstable-red.svg)](https://godoc.org/go.universe.tf/tlsrouter)
+
+TLSRouter is a TLS proxy that routes connections to backends based on
+the TLS SNI (Server Name Indication) of the TLS handshake. It carries
+no encryption keys and cannot decode the traffic that it proxies.
+
+## Installation
+
+Install TLSRouter via `go get`:
+
+```shell
+go get go.universe.tf/tcpproxy/cmd/tlsrouter
+```
+
+## Usage
+
+TLSRouter requires a configuration file that tells it what backend to
+use for a given hostname. The config file looks like:
+
+```
+# Basic hostname -> backend mapping
+go.universe.tf localhost:1234
+
+# DNS wildcards are understood as well.
+*.go.universe.tf 1.2.3.4:8080
+
+# DNS wildcards can go anywhere in name.
+google.* 10.20.30.40:443
+
+# RE2 regexes are also available
+/(alpha|beta|gamma)\.mon(itoring)?\.dave\.tf/ 100.200.100.200:443
+
+# If your backend supports HAProxy's PROXY protocol, you can enable
+# it to receive the real client ip:port.
+
+fancy.backend 2.3.4.5:443 PROXY
+```
+
+TLSRouter takes one mandatory commandline argument, the configuration file to use:
+
+```shell
+tlsrouter -conf tlsrouter.conf
+```
+
+Optional flags are:
+
+ * `-listen <addr>`: set the listen address (default `:443`)
+ * `-hello-timeout <duration>`: how long to wait for the start of the
+ TLS handshake (default `3s`)
diff --git a/cmd/tlsrouter/config.go b/cmd/tlsrouter/config.go
new file mode 100644
index 0000000..692b04b
--- /dev/null
+++ b/cmd/tlsrouter/config.go
@@ -0,0 +1,137 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "regexp"
+ "strings"
+ "sync"
+)
+
+// A Route maps a match on a domain name to a backend.
+type Route struct {
+ match *regexp.Regexp
+ backend string
+ proxyInfo bool
+}
+
+// Config stores the TLS routing configuration.
+type Config struct {
+ mu sync.Mutex
+ routes []Route
+}
+
+func dnsRegex(s string) (*regexp.Regexp, error) {
+ if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' {
+ return regexp.Compile(s[1 : len(s)-1])
+ }
+
+ var b []string
+ for _, f := range strings.Split(s, ".") {
+ switch f {
+ case "*":
+ b = append(b, `[^.]+`)
+ case "":
+ return nil, fmt.Errorf("DNS name %q has empty label", s)
+ default:
+ b = append(b, regexp.QuoteMeta(f))
+ }
+ }
+ return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`)))
+}
+
+// Match returns the backend for hostname, and whether to use the PROXY protocol.
+func (c *Config) Match(hostname string) (string, bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ for _, r := range c.routes {
+ if r.match.MatchString(hostname) {
+ return r.backend, r.proxyInfo
+ }
+ }
+ return "", false
+}
+
+// Read replaces the current Config with one read from r.
+func (c *Config) Read(r io.Reader) error {
+ var routes []Route
+ var backends []string
+
+ s := bufio.NewScanner(r)
+ for s.Scan() {
+ if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") {
+ // Comment, ignore.
+ continue
+ }
+
+ fs := strings.Fields(s.Text())
+ switch len(fs) {
+ case 0:
+ continue
+ case 1:
+ return fmt.Errorf("invalid %q on a line by itself", s.Text())
+ case 2:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ routes = append(routes, Route{re, fs[1], false})
+ backends = append(backends, fs[1])
+ case 3:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ if fs[2] != "PROXY" {
+ return errors.New("third item on a line can only be PROXY")
+ }
+ routes = append(routes, Route{re, fs[1], true})
+ backends = append(backends, fs[1])
+ default:
+ // TODO: multiple backends?
+ return fmt.Errorf("too many fields on line: %q", s.Text())
+ }
+ }
+ if err := s.Err(); err != nil {
+ return err
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.routes = routes
+ return nil
+}
+
+// ReadFile replaces the current Config with one read from path.
+func (c *Config) ReadFile(path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ return c.Read(f)
+}
+
+// ReadString replaces the current Config with one read from cfg.
+func (c *Config) ReadString(cfg string) error {
+ b := bytes.NewBufferString(cfg)
+ return c.Read(b)
+}
diff --git a/cmd/tlsrouter/config_test.go b/cmd/tlsrouter/config_test.go
new file mode 100644
index 0000000..9819b91
--- /dev/null
+++ b/cmd/tlsrouter/config_test.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestConfig(t *testing.T) {
+ type result struct {
+ backend string
+ proxy bool
+ }
+
+ cases := []struct {
+ Config string
+ Tests map[string]result
+ }{
+ {
+ Config: `
+# Comment
+go.universe.tf 1.2.3.4
+*.universe.tf 2.3.4.5
+# Comment
+google.* 3.4.5.6
+/gooo+gle\.com/ 4.5.6.7
+foobar.net 6.7.8.9 PROXY
+`,
+ Tests: map[string]result{
+ "go.universe.tf": result{"1.2.3.4", false},
+ "foo.universe.tf": result{"2.3.4.5", false},
+ "bar.universe.tf": result{"2.3.4.5", false},
+ "google.com": result{"3.4.5.6", false},
+ "google.fr": result{"3.4.5.6", false},
+ "goooooooooogle.com": result{"4.5.6.7", false},
+ "foobar.net": result{"6.7.8.9", true},
+
+ "blah.com": result{"", false},
+ "google.com.br": result{"", false},
+ "foo.bar.universe.tf": result{"", false},
+ "goooooglexcom": result{"", false},
+ },
+ },
+ }
+
+ for _, test := range cases {
+ var cfg Config
+ if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil {
+ t.Fatalf("Failed to read config (%s):\n%q", err, test.Config)
+ }
+
+ for hostname, expected := range test.Tests {
+ backend, proxy := cfg.Match(hostname)
+ if expected.backend != backend {
+ t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend)
+ }
+ if expected.proxy != proxy {
+ t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy)
+ }
+ }
+ }
+}
diff --git a/cmd/tlsrouter/e2e_test.go b/cmd/tlsrouter/e2e_test.go
new file mode 100644
index 0000000..6e54021
--- /dev/null
+++ b/cmd/tlsrouter/e2e_test.go
@@ -0,0 +1,216 @@
+package main
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ proxyproto "github.com/armon/go-proxyproto"
+)
+
+func TestRouting(t *testing.T) {
+ // Backend servers
+ s1, err := serveTLS(t, "server1", false, "test.com")
+ if err != nil {
+ t.Fatalf("serve TLS server1: %s", err)
+ }
+ defer s1.Close()
+
+ s2, err := serveTLS(t, "server2", false, "foo.net")
+ if err != nil {
+ t.Fatalf("serve TLS server2: %s", err)
+ }
+ defer s2.Close()
+
+ s4, err := serveTLS(t, "server4", true, "proxy.design")
+ if err != nil {
+ t.Fatalf("server TLS server4: %s", err)
+ }
+ defer s4.Close()
+
+ // One proxy
+ var p Proxy
+ l, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("create listener: %s", err)
+ }
+ defer l.Close()
+ go p.Serve(l)
+
+ if err := p.Config.ReadString(fmt.Sprintf(`
+test.com %s
+foo.net %s
+proxy.design %s PROXY
+`, s1.Addr(), s2.Addr(), s4.Addr())); err != nil {
+ t.Fatalf("configure proxy: %s", err)
+ }
+
+ for _, test := range []struct {
+ N, V string
+ P *x509.CertPool
+ OK bool
+ Transparent bool
+ }{
+ {"test.com", "server1", s1.Pool, true, false},
+ {"foo.net", "server2", s2.Pool, true, false},
+ {"bar.org", "", s1.Pool, false, false},
+ {"proxy.design", "server4", s4.Pool, true, true},
+ } {
+ res, transparent, err := getTLS(l.Addr().String(), test.N, test.P)
+ switch {
+ case test.OK && err != nil:
+ t.Fatalf("get %q failed: %s", test.N, err)
+ case !test.OK && err == nil:
+ t.Fatalf("get %q should have failed, but returned %q", test.N, res)
+ case test.OK && res != test.V:
+ t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ case test.OK && transparent != test.Transparent:
+ t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent)
+ }
+ }
+}
+
+// getTLS attempts to set up a TLS session using the given proxy
+// address, domain, and cert pool. It returns the value served by the
+// server, as well as a bool indicating whether the server knew the
+// true client address, indicating that the PROXY protocol was in use.
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) {
+ cfg := tls.Config{
+ RootCAs: pool,
+ ServerName: domain,
+ }
+ conn, err := tls.Dial("tcp", addr, &cfg)
+ if err != nil {
+ return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ }
+ defer conn.Close()
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ fs := strings.Split(string(bs), " ")
+ if len(fs) != 2 {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs))
+ }
+ transparent := fs[1] == conn.LocalAddr().String()
+ return fs[0], transparent, nil
+}
+
+type tlsServer struct {
+ Domains []string
+ Value string
+ Pool *x509.CertPool
+ Test *testing.T
+ NumHits uint32
+ l net.Listener
+}
+
+func (s *tlsServer) Serve() {
+ for {
+ c, err := s.l.Accept()
+ if err != nil {
+ s.Test.Logf("accept failed on %q: %s", s.Domains, err)
+ return
+ }
+ atomic.AddUint32(&s.NumHits, 1)
+ fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr())
+ c.Close()
+ }
+}
+
+func (s *tlsServer) Addr() string {
+ return s.l.Addr().String()
+}
+
+func (s *tlsServer) Close() error {
+ return s.l.Close()
+}
+
+func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) {
+ cert, pool, err := selfSignedCert(domains)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+
+ var l net.Listener
+
+ l, err = net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return nil, err
+ }
+
+ if understandProxy {
+ l = &proxyproto.Listener{Listener: l}
+ }
+
+ l = tls.NewListener(l, cfg)
+
+ ret := &tlsServer{
+ Domains: domains,
+ Value: value,
+ Pool: pool,
+ Test: t,
+ l: l,
+ }
+ go ret.Serve()
+ return ret, nil
+}
+
+func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) {
+ pkey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domains[0],
+ },
+ NotBefore: time.Now().Add(-5 * time.Minute),
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: domains[:],
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pkey.Public(), pkey)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM(cert.Bytes()) {
+ return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains)
+ }
+
+ return tlscert, pool, nil
+}
diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go
new file mode 100644
index 0000000..ff1a816
--- /dev/null
+++ b/cmd/tlsrouter/main.go
@@ -0,0 +1,191 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+ "time"
+)
+
+var (
+ cfgFile = flag.String("conf", "", "configuration file")
+ listen = flag.String("listen", ":443", "listening port")
+ helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello")
+)
+
+func main() {
+ flag.Parse()
+
+ p := &Proxy{}
+ if err := p.Config.ReadFile(*cfgFile); err != nil {
+ log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
+ }
+
+ log.Fatalf("%s", p.ListenAndServe(*listen))
+}
+
+// Proxy routes connections to backends based on a Config.
+type Proxy struct {
+ Config Config
+ l net.Listener
+}
+
+// Serve accepts connections from l and routes them according to TLS SNI.
+func (p *Proxy) Serve(l net.Listener) error {
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ return fmt.Errorf("accept new conn: %s", err)
+ }
+
+ conn := &Conn{
+ TCPConn: c.(*net.TCPConn),
+ config: &p.Config,
+ }
+ go conn.proxy()
+ }
+}
+
+// ListenAndServe creates a listener on addr calls Serve on it.
+func (p *Proxy) ListenAndServe(addr string) error {
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %s", err)
+ }
+ return p.Serve(l)
+}
+
+// A Conn handles the TLS proxying of one user connection.
+type Conn struct {
+ *net.TCPConn
+ config *Config
+
+ tlsMinor int
+ hostname string
+ backend string
+ backendConn *net.TCPConn
+}
+
+func (c *Conn) logf(msg string, args ...interface{}) {
+ msg = fmt.Sprintf(msg, args...)
+ log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg)
+}
+
+func (c *Conn) abort(alert byte, msg string, args ...interface{}) {
+ c.logf(msg, args...)
+ alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert}
+
+ if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.logf("error while setting write deadline during abort: %s", err)
+ // Do NOT send the alert if we can't set a write deadline,
+ // that could result in leaking a connection for an extended
+ // period.
+ return
+ }
+
+ if _, err := c.Write(alertMsg); err != nil {
+ c.logf("error while sending alert: %s", err)
+ }
+}
+
+func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) }
+func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) }
+
+func (c *Conn) proxy() {
+ defer c.Close()
+
+ if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.internalError("Setting read deadline for ClientHello: %s", err)
+ return
+ }
+
+ var (
+ err error
+ handshakeBuf bytes.Buffer
+ )
+ c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf))
+ if err != nil {
+ c.internalError("Extracting SNI: %s", err)
+ return
+ }
+
+ c.logf("extracted SNI %s", c.hostname)
+
+ if err = c.SetReadDeadline(time.Time{}); err != nil {
+ c.internalError("Clearing read deadline for ClientHello: %s", err)
+ return
+ }
+
+ addProxyHeader := false
+ c.backend, addProxyHeader = c.config.Match(c.hostname)
+ if c.backend == "" {
+ c.sniFailed("no backend found for %q", c.hostname)
+ return
+ }
+
+ c.logf("routing %q to %q", c.hostname, c.backend)
+ backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second)
+ if err != nil {
+ c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err)
+ return
+ }
+ defer backend.Close()
+
+ c.backendConn = backend.(*net.TCPConn)
+
+ // If the backend supports the HAProxy PROXY protocol, give it the
+ // real source information about the connection.
+ if addProxyHeader {
+ remote := c.TCPConn.RemoteAddr().(*net.TCPAddr)
+ local := c.TCPConn.LocalAddr().(*net.TCPAddr)
+ family := "TCP6"
+ if remote.IP.To4() != nil {
+ family = "TCP4"
+ }
+ if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil {
+ c.internalError("failed to send PROXY header to %q: %s", c.backend, err)
+ return
+ }
+ }
+
+ // Replay the piece of the handshake we had to read to do the
+ // routing, then blindly proxy any other bytes.
+ if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil {
+ c.internalError("failed to replay handshake to %q: %s", c.backend, err)
+ return
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go proxy(&wg, c.TCPConn, c.backendConn)
+ go proxy(&wg, c.backendConn, c.TCPConn)
+ wg.Wait()
+}
+
+func proxy(wg *sync.WaitGroup, a, b net.Conn) {
+ defer wg.Done()
+ atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn)
+ if _, err := io.Copy(atcp, btcp); err != nil {
+ log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err)
+ }
+ btcp.CloseWrite()
+ atcp.CloseRead()
+}
diff --git a/cmd/tlsrouter/sni.go b/cmd/tlsrouter/sni.go
new file mode 100644
index 0000000..ed79df2
--- /dev/null
+++ b/cmd/tlsrouter/sni.go
@@ -0,0 +1,232 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+)
+
+func extractSNI(r io.Reader) (string, int, error) {
+ handshake, tlsver, err := handshakeRecord(r)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading TLS record: %s", err)
+ }
+
+ sni, err := parseHello(handshake)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading ClientHello: %s", err)
+ }
+ if len(sni) == 0 {
+ // ClientHello did not present an SNI extension. Valid packet,
+ // no hostname.
+ return "", tlsver, nil
+ }
+
+ hostname, err := parseSNI(sni)
+ if err != nil {
+ return "", 0, fmt.Errorf("parsing SNI extension: %s", err)
+ }
+ return hostname, tlsver, nil
+}
+
+// Extract the indicated hostname, if any, from the given SNI
+// extension bytes.
+func parseSNI(b []byte) (string, error) {
+ b, _, err := vector(b, 2)
+ if err != nil {
+ return "", err
+ }
+
+ var ret []byte
+ for len(b) >= 3 {
+ typ := b[0]
+ ret, b, err = vector(b[1:], 2)
+ if err != nil {
+ return "", fmt.Errorf("truncated SNI extension")
+ }
+
+ if typ == sniHostnameID {
+ return string(ret), nil
+ }
+ }
+
+ if len(b) != 0 {
+ return "", fmt.Errorf("trailing garbage at end of SNI extension")
+ }
+
+ // No DNS-based SNI present.
+ return "", nil
+}
+
+const sniExtensionID = 0
+const sniHostnameID = 0
+
+// Parse a TLS handshake record as a ClientHello message and extract
+// the SNI extension bytes, if any.
+func parseHello(b []byte) ([]byte, error) {
+ if len(b) == 0 {
+ return nil, errors.New("zero length handshake record")
+ }
+ if b[0] != 1 {
+ return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0])
+ }
+
+ // We're expecting a stricter TLS parser to run after we've
+ // proxied, so we ignore any trailing bytes that might be present
+ // (e.g. another handshake message).
+ b, _, err := vector(b[1:], 3)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello: %s", err)
+ }
+
+ // ClientHello must be at least 34 bytes to reach the first vector
+ // length byte. The actual minimal size is larger than that, but
+ // vector() will correctly handle truncated packets.
+ if len(b) < 34 {
+ return nil, errors.New("ClientHello packet too short")
+ }
+
+ if b[0] != 3 {
+ return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
+ }
+ switch b[1] {
+ case 1, 2, 3:
+ // TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1])
+ }
+
+ // Skip over version and random struct
+ b = b[34:]
+
+ // We don't technically care about SessionID, but we care that the
+ // framing is well-formed all the way up to the SNI field, so that
+ // we are sure that we're pulling the same SNI bytes as the
+ // eventual TLS implementation.
+ vec, b, err := vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello SessionID: %s", err)
+ }
+ if len(vec) > 32 {
+ return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
+ }
+
+ // Likewise, we're just checking the bare minimum of framing.
+ vec, b, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err)
+ }
+ if len(vec) < 2 || len(vec)%2 != 0 {
+ return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
+ }
+
+ vec, b, err = vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err)
+ }
+ if len(vec) < 1 {
+ return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
+ }
+
+ // Finally, we reach the extensions.
+ if len(b) == 0 {
+ // No extensions. This is not an error, it just means we have
+ // no SNI payload.
+ return nil, nil
+ }
+ b, vec, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extensions: %s", err)
+ }
+ if len(vec) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec))
+ }
+
+ for len(b) >= 4 {
+ typ := binary.BigEndian.Uint16(b[:2])
+ vec, b, err = vector(b[2:], 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err)
+ }
+ if typ == sniExtensionID {
+ // Found the SNI extension, return its payload. We don't
+ // care about anything in the packet beyond this point.
+ return vec, nil
+ }
+ }
+
+ if len(b) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b))
+ }
+
+ // Successfully parsed all extensions, but there was no SNI.
+ return nil, nil
+}
+
+const maxTLSRecordLength = 16384
+
+// Read one TLS record, which must be for the handshake protocol, from r.
+func handshakeRecord(r io.Reader) ([]byte, int, error) {
+ var hdr struct {
+ Type uint8
+ Major, Minor uint8
+ Length uint16
+ }
+ if err := binary.Read(r, binary.BigEndian, &hdr); err != nil {
+ return nil, 0, fmt.Errorf("reading TLS record header: %s", err)
+ }
+
+ if hdr.Type != 22 {
+ return nil, 0, fmt.Errorf("TLS record is not a handshake")
+ }
+
+ if hdr.Major != 3 {
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+ switch hdr.Minor {
+ case 1, 2, 3:
+ // TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+
+ if hdr.Length > maxTLSRecordLength {
+ return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength)
+ }
+
+ ret := make([]byte, hdr.Length)
+ if _, err := io.ReadFull(r, ret); err != nil {
+ return nil, 0, err
+ }
+
+ return ret, int(hdr.Minor), nil
+}
+
+func vector(b []byte, lenBytes int) ([]byte, []byte, error) {
+ if len(b) < lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ var l int
+ for _, b := range b[:lenBytes] {
+ l = (l << 8) + int(b)
+ }
+ if len(b) < l+lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil
+}
diff --git a/cmd/tlsrouter/sni_test.go b/cmd/tlsrouter/sni_test.go
new file mode 100644
index 0000000..8c87d24
--- /dev/null
+++ b/cmd/tlsrouter/sni_test.go
@@ -0,0 +1,456 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+func slice(l int) []byte {
+ ret := make([]byte, l)
+ for i := 0; i < l; i++ {
+ ret[i] = byte(i)
+ }
+ return ret
+}
+
+func vec(l, lenBytes int) []byte {
+ b := slice(l)
+ vecLen := len(b)
+ ret := make([]byte, vecLen+l)
+ for i := l - 1; i >= 0; i-- {
+ ret[i] = byte(vecLen & 0xff)
+ vecLen >>= 8
+ }
+ copy(ret[l:], b)
+ return ret
+}
+
+func packet(bs ...[]byte) []byte {
+ var ret []byte
+ for _, b := range bs {
+ ret = append(ret, b...)
+ }
+ return ret
+}
+
+func offset(b []byte, off int) []byte {
+ return b[off:]
+}
+
+func TestVector(t *testing.T) {
+ tests := []struct {
+ in []byte
+ inLen int
+ out1, out2 []byte
+ err bool
+ }{
+ {
+ // 1b length
+ append([]byte{3}, slice(10)...), 1,
+ slice(3), offset(slice(10), 3), false,
+ },
+ {
+ // 1b length, no trailer
+ append([]byte{10}, slice(10)...), 1,
+ slice(10), []byte{}, false,
+ },
+ {
+ // 1b length, no vector
+ append([]byte{0}, slice(10)...), 1,
+ []byte{}, slice(10), false,
+ },
+ {
+ // 1b length, no vector or trailer
+ []byte{0}, 1,
+ []byte{}, []byte{}, false,
+ },
+ {
+ // 2b length, LSB only
+ append([]byte{0, 3}, slice(10)...), 2,
+ slice(3), offset(slice(10), 3), false,
+ },
+ {
+ // 2b length, MSB only
+ append([]byte{3, 0}, slice(1024)...), 2,
+ slice(768), offset(slice(1024), 768), false,
+ },
+ {
+ // 2b length, both bytes
+ append([]byte{3, 2}, slice(1024)...), 2,
+ slice(770), offset(slice(1024), 770), false,
+ },
+ {
+ // 3b length
+ append([]byte{1, 2, 3}, slice(100000)...), 3,
+ slice(66051), offset(slice(100000), 66051), false,
+ },
+ {
+ // no bytes
+ []byte{}, 1,
+ nil, nil, true,
+ },
+ {
+ // no slice
+ nil, 1,
+ nil, nil, true,
+ },
+ {
+ // not enough bytes for length
+ []byte{1}, 2,
+ nil, nil, true,
+ },
+ {
+ // no bytes after length
+ []byte{1}, 1,
+ nil, nil, true,
+ },
+ {
+ // not enough bytes for vector
+ []byte{4, 1, 2}, 1,
+ nil, nil, true,
+ },
+ }
+
+ for _, test := range tests {
+ actual1, actual2, err := vector(test.in, test.inLen)
+ if !test.err && (err != nil) {
+ t.Errorf("unexpected error %q", err)
+ }
+ if test.err && (err == nil) {
+ t.Errorf("unexpected success")
+ }
+ if err != nil {
+ continue
+ }
+ if !bytes.Equal(actual1, test.out1) {
+ t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1)
+ }
+ if !bytes.Equal(actual2, test.out2) {
+ t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2)
+ }
+ }
+}
+
+func TestHandshakeRecord(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out []byte
+ tlsver int
+ }{
+ {
+ // TLS 1.0, 1b packet
+ []byte{22, 3, 1, 0, 1, 3},
+ []byte{3},
+ 1,
+ },
+ {
+ // TLS 1.1, 1b packet
+ []byte{22, 3, 2, 0, 1, 3},
+ []byte{3},
+ 2,
+ },
+ {
+ // TLS 1.2, 1b packet
+ []byte{22, 3, 3, 0, 1, 3},
+ []byte{3},
+ 3,
+ },
+ {
+ // TLS 1.2, no payload bytes
+ []byte{22, 3, 3, 0, 0},
+ []byte{},
+ 3,
+ },
+ {
+ // TLS 1.2, >255b payload w/ trailing stuff
+ append([]byte{22, 3, 3, 3, 2}, slice(1024)...),
+ slice(770),
+ 3,
+ },
+ {
+ // TLS 1.2, 2^14 payload
+ append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...),
+ slice(maxTLSRecordLength),
+ 3,
+ },
+ {
+ // TLS 1.2, >2^14 payload
+ append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...),
+ nil,
+ 0,
+ },
+ {
+ // TLS 1.2, truncated payload
+ []byte{22, 3, 3, 0, 4, 1, 2},
+ nil,
+ 0,
+ },
+ {
+ // truncated header
+ []byte{22},
+ nil,
+ 0,
+ },
+ {
+ // wrong record type
+ []byte{42, 3, 3, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // wrong TLS major version
+ []byte{22, 2, 3, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // wrong TLS minor version
+ []byte{22, 3, 42, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // Obsolete SSL 3.0
+ []byte{22, 3, 0, 0, 1, 3},
+ nil,
+ 0,
+ },
+ }
+
+ for _, test := range tests {
+ r := bytes.NewBuffer(test.in)
+ actual, tlsver, err := handshakeRecord(r)
+ if test.out == nil && err == nil {
+ t.Errorf("unexpected success")
+ continue
+ }
+ if !bytes.Equal(test.out, actual) {
+ t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out)
+ }
+ if tlsver != test.tlsver {
+ t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver)
+ }
+ }
+}
+
+func TestParseHello(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out []byte
+ err bool
+ }{
+ {
+ // Wrong record type
+ packet([]byte{42, 0, 0, 1, 1}),
+ nil,
+ true,
+ },
+ {
+ // Truncated payload
+ packet([]byte{1, 0, 0, 1}),
+ nil,
+ true,
+ },
+ {
+ // Payload too small
+ packet([]byte{1, 0, 0, 1, 1}),
+ nil,
+ true,
+ },
+ {
+ // Unknown major version
+ packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // Unknown minor version
+ packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // Missing required variadic fields
+ packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // All zero variadic fields (no ciphersuites, no compression)
+ packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}),
+ nil,
+ true,
+ },
+ {
+ // All zero variadic fields (no ciphersuites, no compression, nonzero session ID)
+ packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}),
+ nil,
+ true,
+ },
+ {
+ // Session + ciphersuites, no compression
+ packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}),
+ nil,
+ true,
+ },
+ {
+ // First valid packet. TLS 1.0, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.1, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.2, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.2, garbage extensions
+ packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)),
+ nil,
+ true,
+ },
+ {
+ // empty extensions vector
+ packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}),
+ nil,
+ false,
+ },
+ {
+ // non-SNI extensions
+ packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}),
+ nil,
+ false,
+ },
+ {
+ // SNI present
+ packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}),
+ []byte{182},
+ false,
+ },
+ {
+ // Longer SNI
+ packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)),
+ slice(4),
+ false,
+ },
+ {
+ // Embedded SNI
+ packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}),
+ slice(4),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ actual, err := parseHello(test.in)
+ if test.err {
+ if err == nil {
+ t.Errorf("unexpected success")
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ continue
+ }
+ if !bytes.Equal(test.out, actual) {
+ t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out)
+ }
+ }
+}
+
+func TestParseSNI(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out string
+ err bool
+ }{
+ {
+ // Empty packet
+ []byte{},
+ "",
+ true,
+ },
+ {
+ // Truncated packet
+ []byte{0, 2, 1},
+ "",
+ true,
+ },
+ {
+ // Truncated packet within SNI vector
+ []byte{0, 2, 1, 2},
+ "",
+ true,
+ },
+ {
+ // Wrong SNI kind
+ []byte{0, 3, 1, 0, 0},
+ "",
+ false,
+ },
+ {
+ // Right SNI kind, no hostname
+ []byte{0, 3, 0, 0, 0},
+ "",
+ false,
+ },
+ {
+ // SNI hostname
+ packet([]byte{0, 6, 0, 0, 3}, []byte("lol")),
+ "lol",
+ false,
+ },
+ {
+ // Multiple SNI kinds
+ packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}),
+ "lol",
+ false,
+ },
+ {
+ // Multiple SNI hostnames (illegal, but we just return the first)
+ packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")),
+ "bar",
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ actual, err := parseSNI(test.in)
+ if test.err {
+ if err == nil {
+ t.Errorf("unexpected success")
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ continue
+ }
+ if test.out != actual {
+ t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out)
+ }
+ }
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..9c8ce9f
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module inet.af/tcpproxy
+
+go 1.16
+
+require github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..de51fb1
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a h1:AP/vsCIvJZ129pdm9Ek7bH7yutN3hByqsMoNrWAxRQc=
+github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU=
diff --git a/http.go b/http.go
new file mode 100644
index 0000000..d28c66f
--- /dev/null
+++ b/http.go
@@ -0,0 +1,125 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "net/http"
+)
+
+// AddHTTPHostRoute appends a route to the ipPort listener that
+// routes to dest if the incoming HTTP/1.x Host header name is
+// httpHost. If it doesn't match, rule processing continues for any
+// additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
+ p.AddHTTPHostMatchRoute(ipPort, equals(httpHost), dest)
+}
+
+// AddHTTPHostMatchRoute appends a route to the ipPort listener that
+// routes to dest if the incoming HTTP/1.x Host header name is
+// accepted by matcher. If it doesn't match, rule processing continues
+// for any additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) {
+ p.addRoute(ipPort, httpHostMatch{match, dest})
+}
+
+type httpHostMatch struct {
+ matcher Matcher
+ target Target
+}
+
+func (m httpHostMatch) match(br *bufio.Reader) (Target, string) {
+ hh := httpHostHeader(br)
+ if m.matcher(context.TODO(), hh) {
+ return m.target, hh
+ }
+ return nil, ""
+}
+
+// httpHostHeader returns the HTTP Host header from br without
+// consuming any of its bytes. It returns "" if it can't find one.
+func httpHostHeader(br *bufio.Reader) string {
+ const maxPeek = 4 << 10
+ peekSize := 0
+ for {
+ peekSize++
+ if peekSize > maxPeek {
+ b, _ := br.Peek(br.Buffered())
+ return httpHostHeaderFromBytes(b)
+ }
+ b, err := br.Peek(peekSize)
+ if n := br.Buffered(); n > peekSize {
+ b, _ = br.Peek(n)
+ peekSize = n
+ }
+ if len(b) > 0 {
+ if b[0] < 'A' || b[0] > 'Z' {
+ // Doesn't look like an HTTP verb
+ // (GET, POST, etc).
+ return ""
+ }
+ if bytes.Index(b, crlfcrlf) != -1 || bytes.Index(b, lflf) != -1 {
+ req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b)))
+ if err != nil {
+ return ""
+ }
+ if len(req.Header["Host"]) > 1 {
+ // TODO(bradfitz): what does
+ // ReadRequest do if there are
+ // multiple Host headers?
+ return ""
+ }
+ return req.Host
+ }
+ }
+ if err != nil {
+ return httpHostHeaderFromBytes(b)
+ }
+ }
+}
+
+var (
+ lfHostColon = []byte("\nHost:")
+ lfhostColon = []byte("\nhost:")
+ crlf = []byte("\r\n")
+ lf = []byte("\n")
+ crlfcrlf = []byte("\r\n\r\n")
+ lflf = []byte("\n\n")
+)
+
+func httpHostHeaderFromBytes(b []byte) string {
+ if i := bytes.Index(b, lfHostColon); i != -1 {
+ return string(bytes.TrimSpace(untilEOL(b[i+len(lfHostColon):])))
+ }
+ if i := bytes.Index(b, lfhostColon); i != -1 {
+ return string(bytes.TrimSpace(untilEOL(b[i+len(lfhostColon):])))
+ }
+ return ""
+}
+
+// untilEOL returns v, truncated before the first '\n' byte, if any.
+// The returned slice may include a '\r' at the end.
+func untilEOL(v []byte) []byte {
+ if i := bytes.IndexByte(v, '\n'); i != -1 {
+ return v[:i]
+ }
+ return v
+}
diff --git a/listener.go b/listener.go
new file mode 100644
index 0000000..1ddc48e
--- /dev/null
+++ b/listener.go
@@ -0,0 +1,108 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "io"
+ "net"
+ "sync"
+)
+
+// TargetListener implements both net.Listener and Target.
+// Matched Targets become accepted connections.
+type TargetListener struct {
+ Address string // Address is the string reported by TargetListener.Addr().String().
+
+ mu sync.Mutex
+ cond *sync.Cond
+ closed bool
+ nextConn net.Conn
+}
+
+var (
+ _ net.Listener = (*TargetListener)(nil)
+ _ Target = (*TargetListener)(nil)
+)
+
+func (tl *TargetListener) lock() {
+ tl.mu.Lock()
+ if tl.cond == nil {
+ tl.cond = sync.NewCond(&tl.mu)
+ }
+}
+
+type tcpAddr string
+
+func (a tcpAddr) Network() string { return "tcp" }
+func (a tcpAddr) String() string { return string(a) }
+
+// Addr returns the listener's Address field as a net.Addr.
+func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) }
+
+// Close stops listening for new connections. All new connections
+// routed to this listener will be closed. Already accepted
+// connections are not closed.
+func (tl *TargetListener) Close() error {
+ tl.lock()
+ if tl.closed {
+ tl.mu.Unlock()
+ return nil
+ }
+ tl.closed = true
+ tl.mu.Unlock()
+ tl.cond.Broadcast()
+ return nil
+}
+
+// HandleConn implements the Target interface. It blocks until tl is
+// closed or another goroutine has called Accept and received c.
+func (tl *TargetListener) HandleConn(c net.Conn) {
+ tl.lock()
+ defer tl.mu.Unlock()
+ for tl.nextConn != nil && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ c.Close()
+ return
+ }
+ tl.nextConn = c
+ tl.cond.Broadcast() // Signal might be sufficient; verify.
+ for tl.nextConn == c && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ c.Close()
+ return
+ }
+}
+
+// Accept implements the Accept method in the net.Listener interface.
+func (tl *TargetListener) Accept() (net.Conn, error) {
+ tl.lock()
+ for tl.nextConn == nil && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ tl.mu.Unlock()
+ return nil, io.EOF
+ }
+ c := tl.nextConn
+ tl.nextConn = nil
+ tl.mu.Unlock()
+ tl.cond.Broadcast() // Signal might be sufficient; verify.
+
+ return c, nil
+}
diff --git a/listener_test.go b/listener_test.go
new file mode 100644
index 0000000..70087ca
--- /dev/null
+++ b/listener_test.go
@@ -0,0 +1,49 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "io"
+ "testing"
+)
+
+func TestListenerAccept(t *testing.T) {
+ tl := new(TargetListener)
+ ch := make(chan interface{}, 1)
+ go func() {
+ for {
+ conn, err := tl.Accept()
+ if err != nil {
+ ch <- err
+ return
+ }
+ ch <- conn
+ }
+ }()
+
+ for i := 0; i < 3; i++ {
+ conn := new(Conn)
+ tl.HandleConn(conn)
+ got := <-ch
+ if got != conn {
+ t.Errorf("Accept conn = %v; want %v", got, conn)
+ }
+ }
+ tl.Close()
+ got := <-ch
+ if got != io.EOF {
+ t.Errorf("Accept error post-Close = %v; want io.EOF", got)
+ }
+}
diff --git a/scripts/prune_old_versions.go b/scripts/prune_old_versions.go
new file mode 100644
index 0000000..42e031e
--- /dev/null
+++ b/scripts/prune_old_versions.go
@@ -0,0 +1,150 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "sort"
+ "strings"
+ "time"
+)
+
+var (
+ user = flag.String("user", "", "username")
+ repo = flag.String("repo", "", "repository name")
+ pkgType = flag.String("pkg-type", "deb", "Package type, e.g. 'deb'")
+ distro = flag.String("distro", "", "distro name, e.g. 'debian'")
+ distroVersion = flag.String("version", "", "distro version, e.g. 'stretch'")
+ pkg = flag.String("package", "", "package name")
+ arch = flag.String("arch", "", "package architecture")
+ limit = flag.Int("limit", 2, "package versions to keep")
+)
+
+func fatalf(msg string, args ...interface{}) {
+ fmt.Printf(msg+"\n", args...)
+ os.Exit(1)
+}
+
+func main() {
+ flag.Parse()
+ if *user == "" {
+ fatalf("missing -user")
+ }
+ if *repo == "" {
+ fatalf("missing -repo")
+ }
+ if *pkgType == "" {
+ fatalf("missing -pkg-type")
+ }
+ if *distro == "" {
+ fatalf("missing -distro")
+ }
+ if *distroVersion == "" {
+ fatalf("missing -version")
+ }
+ if *pkg == "" {
+ fatalf("missing -package")
+ }
+ if *arch == "" {
+ fatalf("missing -arch")
+ }
+ if *limit < 1 {
+ fatalf("limit must be >= 1")
+ }
+
+ files, err := packageVersions(*user, *repo, *pkgType, *distro, *distroVersion, *pkg, *arch)
+ if err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+ if len(files) <= *limit {
+ fmt.Println("Below limit, no packages deleted")
+ return
+ }
+ delete := files[:len(files)-*limit]
+ keep := files[len(files)-*limit:]
+ if err = deletePackages(delete); err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+
+ fmt.Printf("Deleted:\n\n%s\n\nKept:\n\n%s\n", strings.Join(delete, "\n"), strings.Join(keep, "\n"))
+}
+
+type packageMeta struct {
+ Created time.Time `json:"created_at"`
+ Filename string `json:"filename"`
+}
+
+type metaSort []packageMeta
+
+func (m metaSort) Len() int { return len(m) }
+func (m metaSort) Less(i, j int) bool { return m[i].Created.Before(m[j].Created) }
+func (m metaSort) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+func packageVersions(user, repo, typ, distro, version, pkgname, arch string) ([]string, error) {
+ url := fmt.Sprintf("https://%s:@packagecloud.io/api/v1/repos/%s/%s/package/%s/%s/%s/%s/%s/versions.json", os.Getenv("PACKAGECLOUD_API_KEY"), user, repo, typ, distro, version, pkgname, arch)
+ resp, err := http.Get(url)
+ if err != nil {
+ return nil, fmt.Errorf("get versions.json: %s", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ msg, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("get error message of versions.json get: %s", err)
+ }
+ return nil, fmt.Errorf("get versions.json: %s (%q)", resp.Status, string(msg))
+ }
+
+ var files []packageMeta
+ if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
+ return nil, fmt.Errorf("decode versions.json: %s", err)
+ }
+
+ // Newest first
+ sort.Sort(metaSort(files))
+
+ var ret []string
+ for _, meta := range files {
+ ret = append(ret, fmt.Sprintf("/api/v1/repos/%s/%s/%s/%s/%s", user, repo, distro, version, meta.Filename))
+ }
+
+ return ret, nil
+}
+
+func deletePackages(urls []string) error {
+ for _, url := range urls {
+ fullURL := fmt.Sprintf("https://%s:@packagecloud.io%s", os.Getenv("PACKAGECLOUD_API_KEY"), url)
+ req, err := http.NewRequest("DELETE", fullURL, nil)
+ if err != nil {
+ return fmt.Errorf("build delete request for %s: %s", url, err)
+ }
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("delete %s: %s", url, err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ return fmt.Errorf("delete %s: %s", url, resp.Status)
+ }
+ }
+ return nil
+}
diff --git a/sni.go b/sni.go
new file mode 100644
index 0000000..c2d37e0
--- /dev/null
+++ b/sni.go
@@ -0,0 +1,115 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/tls"
+ "io"
+ "net"
+)
+
+// AddSNIRoute appends a route to the ipPort listener that routes to
+// dest if the incoming TLS SNI server name is sni. If it doesn't
+// match, rule processing continues for any additional routes on
+// ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ p.AddSNIMatchRoute(ipPort, equals(sni), dest)
+}
+
+// AddSNIMatchRoute appends a route to the ipPort listener that routes
+// to dest if the incoming TLS SNI server name is accepted by
+// matcher. If it doesn't match, rule processing continues for any
+// additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) {
+ p.addRoute(ipPort, sniMatch{matcher: matcher, target: dest})
+}
+
+// SNITargetFunc is the func callback used by Proxy.AddSNIRouteFunc.
+type SNITargetFunc func(ctx context.Context, sniName string) (t Target, ok bool)
+
+// AddSNIRouteFunc adds a route to ipPort that matches an SNI request and calls
+// fn to map its nap to a target.
+func (p *Proxy) AddSNIRouteFunc(ipPort string, fn SNITargetFunc) {
+ p.addRoute(ipPort, sniMatch{targetFunc: fn})
+}
+
+type sniMatch struct {
+ matcher Matcher
+ target Target
+
+ // Alternatively, if targetFunc is non-nil, it's used instead:
+ targetFunc SNITargetFunc
+}
+
+func (m sniMatch) match(br *bufio.Reader) (Target, string) {
+ sni := clientHelloServerName(br)
+ if sni == "" {
+ return nil, ""
+ }
+ if m.targetFunc != nil {
+ if t, ok := m.targetFunc(context.TODO(), sni); ok {
+ return t, sni
+ }
+ return nil, ""
+ }
+ if m.matcher(context.TODO(), sni) {
+ return m.target, sni
+ }
+ return nil, ""
+}
+
+// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
+// without consuming any bytes from br.
+// On any error, the empty string is returned.
+func clientHelloServerName(br *bufio.Reader) (sni string) {
+ const recordHeaderLen = 5
+ hdr, err := br.Peek(recordHeaderLen)
+ if err != nil {
+ return ""
+ }
+ const recordTypeHandshake = 0x16
+ if hdr[0] != recordTypeHandshake {
+ return "" // Not TLS.
+ }
+ recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
+ helloBytes, err := br.Peek(recordHeaderLen + recLen)
+ if err != nil {
+ return ""
+ }
+ tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
+ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
+ sni = hello.ServerName
+ return nil, nil
+ },
+ }).Handshake()
+ return
+}
+
+// sniSniffConn is a net.Conn that reads from r, fails on Writes,
+// and crashes otherwise.
+type sniSniffConn struct {
+ r io.Reader
+ net.Conn // nil; crash on any unexpected use
+}
+
+func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }
diff --git a/systemd/tlsrouter.service b/systemd/tlsrouter.service
new file mode 100644
index 0000000..23e8fe1
--- /dev/null
+++ b/systemd/tlsrouter.service
@@ -0,0 +1,25 @@
+[Unit]
+Description=TLS SNI proxy
+Documentation=https://github.com/google/tlsrouter
+
+[Service]
+WorkingDirectory=/tmp
+ExecStart=/usr/bin/tlsrouter -conf /etc/tlsrouter.conf
+Restart=always
+User=nobody
+Group=nogroup
+CapabilityBoundingSet=CAP_NET_BIND_SERVICE
+AmbientCapabilities=CAP_NET_BIND_SERVICE
+PrivateTmp=true
+PrivateDevices=true
+ProtectSystem=strict
+ProtectHome=true
+ProtectKernelTunables=true
+ProtectControlGroups=true
+ProtectKernelModules=true
+NoNewPrivileges=true
+SystemCallFilter=~@clock @cpu-emulation @debug @keyring @module @mount @obsolete @privileged @raw-io
+RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX
+
+[Install]
+WantedBy=multi-user.target
diff --git a/tcpproxy.go b/tcpproxy.go
new file mode 100644
index 0000000..1f03e32
--- /dev/null
+++ b/tcpproxy.go
@@ -0,0 +1,496 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpproxy lets users build TCP proxies, optionally making
+// routing decisions based on HTTP/1 Host headers and the SNI hostname
+// in TLS connections.
+//
+// Typical usage:
+//
+// var p tcpproxy.Proxy
+// p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081"))
+// p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082"))
+// p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback
+// p.AddSNIRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431"))
+// p.AddSNIRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
+// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
+// log.Fatal(p.Run())
+//
+// Calling Run (or Start) on a proxy also starts all the necessary
+// listeners.
+//
+// For each accepted connection, the rules for that ipPort are
+// matched, in order. If one matches (currently HTTP Host, SNI, or
+// always), then the connection is handed to the target.
+//
+// The two predefined Target implementations are:
+//
+// 1) DialProxy, proxying to another address (use the To func to return a
+// DialProxy value),
+//
+// 2) TargetListener, making the matched connection available via a
+// net.Listener.Accept call.
+//
+// But Target is an interface, so you can also write your own.
+//
+// Note that tcpproxy does not do any TLS encryption or decryption. It
+// only (via DialProxy) copies bytes around. The SNI hostname in the TLS
+// header is unencrypted, for better or worse.
+//
+// This package makes no API stability promises. If you depend on it,
+// vendor it.
+package tcpproxy
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "time"
+)
+
+// Proxy is a proxy. Its zero value is a valid proxy that does
+// nothing. Call methods to add routes before calling Start or Run.
+//
+// The order that routes are added in matters; each is matched in the order
+// registered.
+type Proxy struct {
+ configs map[string]*config // ip:port => config
+
+ lns []net.Listener
+ donec chan struct{} // closed before err
+ err error // any error from listening
+
+ // ListenFunc optionally specifies an alternate listen
+ // function. If nil, net.Dial is used.
+ // The provided net is always "tcp".
+ ListenFunc func(net, laddr string) (net.Listener, error)
+}
+
+// Matcher reports whether hostname matches the Matcher's criteria.
+type Matcher func(ctx context.Context, hostname string) bool
+
+// equals is a trivial Matcher that implements string equality.
+func equals(want string) Matcher {
+ return func(_ context.Context, got string) bool {
+ return want == got
+ }
+}
+
+// config contains the proxying state for one listener.
+type config struct {
+ routes []route
+}
+
+// A route matches a connection to a target.
+type route interface {
+ // match examines the initial bytes of a connection, looking for a
+ // match. If a match is found, match returns a non-nil Target to
+ // which the stream should be proxied. match returns nil if the
+ // connection doesn't match.
+ //
+ // match must not consume bytes from the given bufio.Reader, it
+ // can only Peek.
+ //
+ // If an sni or host header was parsed successfully, that will be
+ // returned as the second parameter.
+ match(*bufio.Reader) (Target, string)
+}
+
+func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
+ if p.ListenFunc != nil {
+ return p.ListenFunc
+ }
+ return net.Listen
+}
+
+func (p *Proxy) configFor(ipPort string) *config {
+ if p.configs == nil {
+ p.configs = make(map[string]*config)
+ }
+ if p.configs[ipPort] == nil {
+ p.configs[ipPort] = &config{}
+ }
+ return p.configs[ipPort]
+}
+
+func (p *Proxy) addRoute(ipPort string, r route) {
+ cfg := p.configFor(ipPort)
+ cfg.routes = append(cfg.routes, r)
+}
+
+// AddRoute appends an always-matching route to the ipPort listener,
+// directing any connection to dest.
+//
+// This is generally used as either the only rule (for simple TCP
+// proxies), or as the final fallback rule for an ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddRoute(ipPort string, dest Target) {
+ p.addRoute(ipPort, fixedTarget{dest})
+}
+
+type fixedTarget struct {
+ t Target
+}
+
+func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" }
+
+// Run is calls Start, and then Wait.
+//
+// It blocks until there's an error. The return value is always
+// non-nil.
+func (p *Proxy) Run() error {
+ if err := p.Start(); err != nil {
+ return err
+ }
+ return p.Wait()
+}
+
+// Wait waits for the Proxy to finish running. Currently this can only
+// happen if a Listener is closed, or Close is called on the proxy.
+//
+// It is only valid to call Wait after a successful call to Start.
+func (p *Proxy) Wait() error {
+ <-p.donec
+ return p.err
+}
+
+// Close closes all the proxy's self-opened listeners.
+func (p *Proxy) Close() error {
+ for _, c := range p.lns {
+ c.Close()
+ }
+ return nil
+}
+
+// Start creates a TCP listener for each unique ipPort from the
+// previously created routes and starts the proxy. It returns any
+// error from starting listeners.
+//
+// If it returns a non-nil error, any successfully opened listeners
+// are closed.
+func (p *Proxy) Start() error {
+ if p.donec != nil {
+ return errors.New("already started")
+ }
+ p.donec = make(chan struct{})
+ errc := make(chan error, len(p.configs))
+ p.lns = make([]net.Listener, 0, len(p.configs))
+ for ipPort, config := range p.configs {
+ ln, err := p.netListen()("tcp", ipPort)
+ if err != nil {
+ p.Close()
+ return err
+ }
+ p.lns = append(p.lns, ln)
+ go p.serveListener(errc, ln, config.routes)
+ }
+ go p.awaitFirstError(errc)
+ return nil
+}
+
+func (p *Proxy) awaitFirstError(errc <-chan error) {
+ p.err = <-errc
+ close(p.donec)
+}
+
+func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ ret <- err
+ return
+ }
+ go p.serveConn(c, routes)
+ }
+}
+
+// serveConn runs in its own goroutine and matches c against routes.
+// It returns whether it matched purely for testing.
+func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
+ br := bufio.NewReader(c)
+ for _, route := range routes {
+ if target, hostName := route.match(br); target != nil {
+ if n := br.Buffered(); n > 0 {
+ peeked, _ := br.Peek(br.Buffered())
+ c = &Conn{
+ HostName: hostName,
+ Peeked: peeked,
+ Conn: c,
+ }
+ }
+ target.HandleConn(c)
+ return true
+ }
+ }
+ // TODO: hook for this?
+ log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
+ c.Close()
+ return false
+}
+
+// Conn is an incoming connection that has had some bytes read from it
+// to determine how to route the connection. The Read method stitches
+// the peeked bytes and unread bytes back together.
+type Conn struct {
+ // HostName is the hostname field that was sent to the request router.
+ // In the case of TLS, this is the SNI header, in the case of HTTPHost
+ // route, it will be the host header. In the case of a fixed
+ // route, i.e. those created with AddRoute(), this will always be
+ // empty. This can be useful in the case where further routing decisions
+ // need to be made in the Target impementation.
+ HostName string
+
+ // Peeked are the bytes that have been read from Conn for the
+ // purposes of route matching, but have not yet been consumed
+ // by Read calls. It set to nil by Read when fully consumed.
+ Peeked []byte
+
+ // Conn is the underlying connection.
+ // It can be type asserted against *net.TCPConn or other types
+ // as needed. It should not be read from directly unless
+ // Peeked is nil.
+ net.Conn
+}
+
+func (c *Conn) Read(p []byte) (n int, err error) {
+ if len(c.Peeked) > 0 {
+ n = copy(p, c.Peeked)
+ c.Peeked = c.Peeked[n:]
+ if len(c.Peeked) == 0 {
+ c.Peeked = nil
+ }
+ return n, nil
+ }
+ return c.Conn.Read(p)
+}
+
+// Target is what an incoming matched connection is sent to.
+type Target interface {
+ // HandleConn is called when an incoming connection is
+ // matched. After the call to HandleConn, the tcpproxy
+ // package never touches the conn again. Implementations are
+ // responsible for closing the connection when needed.
+ //
+ // The concrete type of conn will be of type *Conn if any
+ // bytes have been consumed for the purposes of route
+ // matching.
+ HandleConn(net.Conn)
+}
+
+// To is shorthand way of writing &tcpproxy.DialProxy{Addr: addr}.
+func To(addr string) *DialProxy {
+ return &DialProxy{Addr: addr}
+}
+
+// DialProxy implements Target by dialing a new connection to Addr
+// and then proxying data back and forth.
+//
+// The To func is a shorthand way of creating a DialProxy.
+type DialProxy struct {
+ // Addr is the TCP address to proxy to.
+ Addr string
+
+ // KeepAlivePeriod sets the period between TCP keep alives.
+ // If zero, a default is used. To disable, use a negative number.
+ // The keep-alive is used for both the client connection and
+ KeepAlivePeriod time.Duration
+
+ // DialTimeout optionally specifies a dial timeout.
+ // If zero, a default is used.
+ // If negative, the timeout is disabled.
+ DialTimeout time.Duration
+
+ // DialContext optionally specifies an alternate dial function
+ // for TCP targets. If nil, the standard
+ // net.Dialer.DialContext method is used.
+ DialContext func(ctx context.Context, network, address string) (net.Conn, error)
+
+ // OnDialError optionally specifies an alternate way to handle errors dialing Addr.
+ // If nil, the error is logged and src is closed.
+ // If non-nil, src is not closed automatically.
+ OnDialError func(src net.Conn, dstDialErr error)
+
+ // ProxyProtocolVersion optionally specifies the version of
+ // HAProxy's PROXY protocol to use. The PROXY protocol provides
+ // connection metadata to the DialProxy target, via a header
+ // inserted ahead of the client's traffic. The DialProxy target
+ // must explicitly support and expect the PROXY header; there is
+ // no graceful downgrade.
+ // If zero, no PROXY header is sent. Currently, version 1 is supported.
+ ProxyProtocolVersion int
+}
+
+// UnderlyingConn returns c.Conn if c of type *Conn,
+// otherwise it returns c.
+func UnderlyingConn(c net.Conn) net.Conn {
+ if wrap, ok := c.(*Conn); ok {
+ return wrap.Conn
+ }
+ return c
+}
+
+func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
+ if c, ok := UnderlyingConn(c).(*net.TCPConn); ok {
+ return c, ok
+ }
+ if c, ok := c.(*net.TCPConn); ok {
+ return c, ok
+ }
+ return nil, false
+}
+
+func goCloseConn(c net.Conn) { go c.Close() }
+
+func closeRead(c net.Conn) {
+ if c, ok := tcpConn(c); ok {
+ c.CloseRead()
+ }
+}
+
+func closeWrite(c net.Conn) {
+ if c, ok := tcpConn(c); ok {
+ c.CloseWrite()
+ }
+}
+
+// HandleConn implements the Target interface.
+func (dp *DialProxy) HandleConn(src net.Conn) {
+ ctx := context.Background()
+ var cancel context.CancelFunc
+ if dp.DialTimeout >= 0 {
+ ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout())
+ }
+ dst, err := dp.dialContext()(ctx, "tcp", dp.Addr)
+ if cancel != nil {
+ cancel()
+ }
+ if err != nil {
+ dp.onDialError()(src, err)
+ return
+ }
+ defer goCloseConn(dst)
+
+ if err = dp.sendProxyHeader(dst, src); err != nil {
+ dp.onDialError()(src, err)
+ return
+ }
+ defer goCloseConn(src)
+
+ if ka := dp.keepAlivePeriod(); ka > 0 {
+ for _, c := range []net.Conn{src, dst} {
+ if c, ok := tcpConn(c); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
+ }
+ }
+
+ errc := make(chan error, 2)
+ go proxyCopy(errc, src, dst)
+ go proxyCopy(errc, dst, src)
+ <-errc
+ <-errc
+}
+
+func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
+ switch dp.ProxyProtocolVersion {
+ case 0:
+ return nil
+ case 1:
+ var srcAddr, dstAddr *net.TCPAddr
+ if a, ok := src.RemoteAddr().(*net.TCPAddr); ok {
+ srcAddr = a
+ }
+ if a, ok := src.LocalAddr().(*net.TCPAddr); ok {
+ dstAddr = a
+ }
+
+ if srcAddr == nil || dstAddr == nil {
+ _, err := io.WriteString(w, "PROXY UNKNOWN\r\n")
+ return err
+ }
+
+ family := "TCP4"
+ if srcAddr.IP.To4() == nil {
+ family = "TCP6"
+ }
+ _, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port)
+ return err
+ default:
+ return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion)
+ }
+}
+
+// proxyCopy is the function that copies bytes around.
+// It's a named function instead of a func literal so users get
+// named goroutines in debug goroutine stack dumps.
+func proxyCopy(errc chan<- error, dst, src net.Conn) {
+ defer closeRead(src)
+ defer closeWrite(dst)
+
+ // Before we unwrap src and/or dst, copy any buffered data.
+ if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 {
+ if _, err := dst.Write(wc.Peeked); err != nil {
+ errc <- err
+ return
+ }
+ wc.Peeked = nil
+ }
+
+ // Unwrap the src and dst from *Conn to *net.TCPConn so Go
+ // 1.11's splice optimization kicks in.
+ src = UnderlyingConn(src)
+ dst = UnderlyingConn(dst)
+
+ _, err := io.Copy(dst, src)
+ errc <- err
+}
+
+func (dp *DialProxy) keepAlivePeriod() time.Duration {
+ if dp.KeepAlivePeriod != 0 {
+ return dp.KeepAlivePeriod
+ }
+ return time.Minute
+}
+
+func (dp *DialProxy) dialTimeout() time.Duration {
+ if dp.DialTimeout > 0 {
+ return dp.DialTimeout
+ }
+ return 10 * time.Second
+}
+
+var defaultDialer = new(net.Dialer)
+
+func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
+ if dp.DialContext != nil {
+ return dp.DialContext
+ }
+ return defaultDialer.DialContext
+}
+
+func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) {
+ if dp.OnDialError != nil {
+ return dp.OnDialError
+ }
+ return func(src net.Conn, dstDialErr error) {
+ log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr)
+ src.Close()
+ }
+}
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
new file mode 100644
index 0000000..0346a7a
--- /dev/null
+++ b/tcpproxy_test.go
@@ -0,0 +1,525 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "strings"
+ "testing"
+ "time"
+)
+
+type noopTarget struct{}
+
+func (noopTarget) HandleConn(net.Conn) {}
+
+func TestMatchHTTPHost(t *testing.T) {
+ tests := []struct {
+ name string
+ r io.Reader
+ host string
+ want bool
+ }{
+ {
+ name: "match",
+ r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"),
+ host: "foo.com",
+ want: true,
+ },
+ {
+ name: "no-match",
+ r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"),
+ host: "bar.com",
+ want: false,
+ },
+ {
+ name: "match-huge-request",
+ r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')),
+ host: "foo.com",
+ want: true,
+ },
+ }
+ for i, tt := range tests {
+ name := tt.name
+ if name == "" {
+ name = fmt.Sprintf("test_index_%d", i)
+ }
+ t.Run(name, func(t *testing.T) {
+ br := bufio.NewReader(tt.r)
+ r := httpHostMatch{equals(tt.host), noopTarget{}}
+ m, name := r.match(br)
+ got := m != nil
+ if got != tt.want {
+ t.Fatalf("match = %v; want %v", got, tt.want)
+ }
+ if tt.want && name != tt.host {
+ t.Fatalf("host = %s; want %s", name, tt.host)
+ }
+ get := make([]byte, 3)
+ if _, err := io.ReadFull(br, get); err != nil {
+ t.Fatal(err)
+ }
+ if string(get) != "GET" {
+ t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get)
+ }
+ })
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+type recordWritesConn struct {
+ buf bytes.Buffer
+ net.Conn
+}
+
+func (c *recordWritesConn) Write(p []byte) (int, error) {
+ c.buf.Write(p)
+ return len(p), nil
+}
+
+func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF }
+
+func clientHelloRecord(t *testing.T, hostName string) string {
+ rec := new(recordWritesConn)
+ cl := tls.Client(rec, &tls.Config{ServerName: hostName})
+ cl.Handshake()
+
+ s := rec.buf.String()
+ if !strings.Contains(s, hostName) {
+ t.Fatalf("clientHello sent in test didn't contain %q", hostName)
+ }
+ return s
+}
+
+func TestSNI(t *testing.T) {
+ const hostName = "foo.com"
+ greeting := clientHelloRecord(t, hostName)
+ got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting)))
+ if got != hostName {
+ t.Errorf("got SNI %q; want %q", got, hostName)
+ }
+}
+
+func TestProxyStartNone(t *testing.T) {
+ var p Proxy
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ return ln
+}
+
+const testFrontAddr = "1.2.3.4:567"
+
+func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) {
+ return func(network, laddr string) (net.Listener, error) {
+ if network != "tcp" {
+ t.Errorf("got Listen call with network %q, not tcp", network)
+ return nil, errors.New("invalid network")
+ }
+ if laddr != testFrontAddr {
+ t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr)
+ panic("bogus address")
+ }
+ return ln, nil
+ }
+}
+
+func testProxy(t *testing.T, front net.Listener) *Proxy {
+ return &Proxy{
+ ListenFunc: testListenFunc(t, front),
+ }
+}
+
+func TestBufferedClose(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, To(back.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fromProxy.Close()
+ const msg = "message"
+ if _, err := io.WriteString(toFront, msg); err != nil {
+ t.Fatal(err)
+ }
+ // actively close toFront, the write should still make to the back.
+ toFront.Close()
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxyAlwaysMatch(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, To(back.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fromProxy.Close()
+ const msg = "message"
+ io.WriteString(toFront, msg)
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxyHTTP(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxySNI(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ msg := clientHelloRecord(t, "bar.com")
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestAddSNIRouteFunc(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) {
+ if sniName == "bar.com" {
+ return To(backBar.Addr().String()), true
+ }
+ t.Fatalf("failed to match %q", sniName)
+ return nil, false
+ })
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ msg := clientHelloRecord(t, "bar.com")
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+func TestProxyPROXYOut(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, &DialProxy{
+ Addr: back.Addr().String(),
+ ProxyProtocolVersion: 1,
+ })
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ io.WriteString(toFront, "foo")
+ toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ bs, err := ioutil.ReadAll(fromProxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ want := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).Port)
+ if string(bs) != want {
+ t.Fatalf("got %q; want %q", bs, want)
+ }
+}
+
+type tlsServer struct {
+ Listener net.Listener
+ Domain string
+ Test *testing.T
+}
+
+func (t *tlsServer) Start() {
+ cert := cert(t.Test, t.Domain)
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+
+ go func() {
+ for {
+ rawConn, err := t.Listener.Accept()
+ if err != nil {
+ return // assume Close()
+ }
+
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, t.Domain); err != nil {
+ t.Test.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+}
+
+func (t *tlsServer) Close() {
+ t.Listener.Close()
+}
+
+// cert creates a well-formed, but completely insecure self-signed
+// cert for domain.
+func cert(t *testing.T, domain string) tls.Certificate {
+ private, err := rsa.GenerateKey(rand.Reader, 1024)
+ if err != nil {
+ t.Fatal(err)
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domain,
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return tlscert
+}
+
+// newTLSServer starts a TLS server that serves a self-signed cert for
+// domain.
+func newTLSServer(t *testing.T, domain string) net.Listener {
+ cert := cert(t, domain)
+
+ l := newLocalListener(t)
+ go func() {
+ for {
+ rawConn, err := l.Accept()
+ if err != nil {
+ return // assume closed
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, domain); err != nil {
+ t.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+
+ return l
+}
+
+func readTLS(dest, domain string) (string, error) {
+ conn, err := tls.Dial("tcp", dest, &tls.Config{
+ ServerName: domain,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return "", err
+ }
+ defer conn.Close()
+
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", err
+ }
+ return string(bs), nil
+}