diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 16:18:53 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 16:18:53 +0000 |
commit | 1cdc15a87db98ea2a6a55d331e65ec1a4fc4f273 (patch) | |
tree | 34af891c87f9f96c9816500e46b7ea11588dc6ea | |
parent | Initial commit. (diff) | |
download | golang-github-inetaf-tcpproxy-upstream/0.0_git20231102.2862066.tar.xz golang-github-inetaf-tcpproxy-upstream/0.0_git20231102.2862066.zip |
Adding upstream version 0.0~git20231102.2862066.upstream/0.0_git20231102.2862066upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r-- | .travis.yml | 45 | ||||
-rw-r--r-- | CONTRIBUTING.md | 8 | ||||
-rw-r--r-- | LICENSE | 202 | ||||
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | cmd/tlsrouter/README.md | 51 | ||||
-rw-r--r-- | cmd/tlsrouter/config.go | 137 | ||||
-rw-r--r-- | cmd/tlsrouter/config_test.go | 61 | ||||
-rw-r--r-- | cmd/tlsrouter/e2e_test.go | 216 | ||||
-rw-r--r-- | cmd/tlsrouter/main.go | 191 | ||||
-rw-r--r-- | cmd/tlsrouter/sni.go | 232 | ||||
-rw-r--r-- | cmd/tlsrouter/sni_test.go | 456 | ||||
-rw-r--r-- | go.mod | 5 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | http.go | 125 | ||||
-rw-r--r-- | listener.go | 108 | ||||
-rw-r--r-- | listener_test.go | 49 | ||||
-rw-r--r-- | scripts/prune_old_versions.go | 150 | ||||
-rw-r--r-- | sni.go | 115 | ||||
-rw-r--r-- | systemd/tlsrouter.service | 25 | ||||
-rw-r--r-- | tcpproxy.go | 496 | ||||
-rw-r--r-- | tcpproxy_test.go | 525 |
21 files changed, 3204 insertions, 0 deletions
diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..a8d3a50 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,45 @@ +language: go +go: +- "1.16.x" +- "1.17.x" +- tip +os: +- linux +script: +- go build ./... +- go test ./... +- go vet ./... + +jobs: + include: + - stage: deploy + go: "1.16" + install: + - gem install fpm + script: + - go build ./cmd/tlsrouter + - fpm -s dir -t deb -n tlsrouter -v $(date '+%Y%m%d%H%M%S') + --license Apache2 + --vendor "David Anderson <dave@natulte.net>" + --maintainer "David Anderson <dave@natulte.net>" + --description "TLS SNI router" + --url "https://github.com/inetaf/tcpproxy/tree/master/cmd/tlsrouter" + ./tlsrouter=/usr/bin/tlsrouter + ./systemd/tlsrouter.service=/lib/systemd/system/tlsrouter.service + deploy: + - provider: packagecloud + repository: tlsrouter + username: danderson + dist: debian/stretch + skip_cleanup: true + on: + branch: master + token: + secure: gNU3o70EU4oYeIS6pr0K5oLMGqqxrcf41EOv6c/YoHPVdV6Cx4j9NW0/ISgu6a1/Xf2NgWKT5BWwLpAuhmGdALuOz1Ah//YBWd9N8mGHGaC6RpOPDU8/9NkQdBEmjEH9sgX4PNOh1KQ7d7O0OH0g8RqJlJa0MkUYbTtN6KJ29oiUXxKmZM4D/iWB8VonKOnrtx1NwQL8jL8imZyEV/1fknhDwumz2iKeU1le4Neq9zkxwICMLUonmgphlrp+SDb1EOoHxT6cn51bqBQtQUplfC4dN4OQU/CPqE9E1N1noibvN29YA93qfcrjD3I95KT9wzq+3B6he33+kb0Gz+Cj5ypGy4P85l7TuX4CtQg0U3NAlJCk32IfsdjK+o47pdmADij9IIb9yKt+g99FMERkJJY5EInqEsxHlW/vNF5OqQCmpiHstZL4R2XaHEsWh6j77npnjjC1Aea8xZTWr8PTsbSzVkbG7bTmFpZoPH8eEmr4GNuw5gnbi6D1AJDjcA+UdY9s5qZNpzuWOqfhOFxL+zUW+8sHBvcoFw3R+pwHECs2LCL1c0xAC1LtNUnmW/gnwHavtvKkzErjR1P8Xl7obCbeChJjp+b/BcFYlNACldZcuzBAPyPwIdlWVyUonL4bm63upfMEEShiAIDDJ21y7fjsQK7CfPA7g25bpyo+hV8= + - provider: script + on: + branch: master + script: go run scripts/prune_old_versions.go -user=danderson -repo=tlsrouter -distro=debian -version=stretch -package=tlsrouter -arch=amd64 -limit=2 + env: + # Packagecloud API key, for prune_old_versions.go + - secure: "SRcNwt+45QyPS1w9aGxMg9905Y6d9w4mBM29G6iTTnUB5nD7cAk4m+tf834knGSobVXlWcRnTDW8zrHdQ9yX22dPqCpH5qE+qzTmIvxRHrVJRMmPeYvligJ/9jYfHgQbvuRT8cUpIcpCQAla6rw8nXfKTOE3h8XqMP2hdc3DTVOu2HCfKCNco1tJ7is+AIAnFV2Wpsbb3ZsdKFvHvi2RKUfFaX61J1GNt2/XJIlZs8jC6Y1IAC+ftjql9UsAE/WjZ9fL0Ww1b9/LBIIGHXWI3HpVv9WvlhhIxIlJgOVjmU2lbSuj2w/EBDJ9cd1Qe+wJkT3yKzE1NRsNScVjGg+Ku5igJu/XXuaHkIX01+15BqgPduBYRL0atiNQDhqgBiSyVhXZBX9vsgsp0bgpKaBSF++CV18Q9dara8aljqqS33M3imO3I8JmXU10944QA9Wvu7pCYuIzXxhINcDXRvqxBqz5LnFJGwnGqngTrOCSVS2xn7Y+sjmhe1n5cPCEISlozfa9mPYPvMPp8zg3TbATOOM8CVfcpaNscLqa/+SExN3zMwSanjNKrBgoaQcBzGW5mIgSPxhXkWikBgapiEN7+2Y032Lhqdb9dYjH+EuwcnofspDjjMabWxnuJaln+E3/9vZi2ooQrBEtvymUTy4VMSnqwIX5bU7nPdIuQycdWhk=" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..188ad87 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,8 @@ +Contributions are welcome by pull request. + +You need to sign the Google Contributor License Agreement before your +contributions can be accepted. You can find the individual and organization +level CLAs here: + +Individual: https://cla.developers.google.com/about/google-individual +Organization: https://cla.developers.google.com/about/google-corporate @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f526c21 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# tcpproxy + +For library usage, see https://godoc.org/inet.af/tcpproxy/ + +For CLI usage, see https://github.com/inetaf/tcpproxy/blob/master/cmd/tlsrouter/README.md diff --git a/cmd/tlsrouter/README.md b/cmd/tlsrouter/README.md new file mode 100644 index 0000000..d915c32 --- /dev/null +++ b/cmd/tlsrouter/README.md @@ -0,0 +1,51 @@ +# TLS SNI router + +[![license](https://img.shields.io/github/license/google/tlsrouter.svg?maxAge=2592000)](https://github.com/inetaf/tcpproxy/blob/master/LICENSE) [![Travis](https://img.shields.io/travis/google/tlsrouter.svg?maxAge=2592000)](https://travis-ci.org/google/tlsrouter) [![api](https://img.shields.io/badge/api-unstable-red.svg)](https://godoc.org/go.universe.tf/tlsrouter) + +TLSRouter is a TLS proxy that routes connections to backends based on +the TLS SNI (Server Name Indication) of the TLS handshake. It carries +no encryption keys and cannot decode the traffic that it proxies. + +## Installation + +Install TLSRouter via `go get`: + +```shell +go get go.universe.tf/tcpproxy/cmd/tlsrouter +``` + +## Usage + +TLSRouter requires a configuration file that tells it what backend to +use for a given hostname. The config file looks like: + +``` +# Basic hostname -> backend mapping +go.universe.tf localhost:1234 + +# DNS wildcards are understood as well. +*.go.universe.tf 1.2.3.4:8080 + +# DNS wildcards can go anywhere in name. +google.* 10.20.30.40:443 + +# RE2 regexes are also available +/(alpha|beta|gamma)\.mon(itoring)?\.dave\.tf/ 100.200.100.200:443 + +# If your backend supports HAProxy's PROXY protocol, you can enable +# it to receive the real client ip:port. + +fancy.backend 2.3.4.5:443 PROXY +``` + +TLSRouter takes one mandatory commandline argument, the configuration file to use: + +```shell +tlsrouter -conf tlsrouter.conf +``` + +Optional flags are: + + * `-listen <addr>`: set the listen address (default `:443`) + * `-hello-timeout <duration>`: how long to wait for the start of the + TLS handshake (default `3s`) diff --git a/cmd/tlsrouter/config.go b/cmd/tlsrouter/config.go new file mode 100644 index 0000000..692b04b --- /dev/null +++ b/cmd/tlsrouter/config.go @@ -0,0 +1,137 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "regexp" + "strings" + "sync" +) + +// A Route maps a match on a domain name to a backend. +type Route struct { + match *regexp.Regexp + backend string + proxyInfo bool +} + +// Config stores the TLS routing configuration. +type Config struct { + mu sync.Mutex + routes []Route +} + +func dnsRegex(s string) (*regexp.Regexp, error) { + if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' { + return regexp.Compile(s[1 : len(s)-1]) + } + + var b []string + for _, f := range strings.Split(s, ".") { + switch f { + case "*": + b = append(b, `[^.]+`) + case "": + return nil, fmt.Errorf("DNS name %q has empty label", s) + default: + b = append(b, regexp.QuoteMeta(f)) + } + } + return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`))) +} + +// Match returns the backend for hostname, and whether to use the PROXY protocol. +func (c *Config) Match(hostname string) (string, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, r := range c.routes { + if r.match.MatchString(hostname) { + return r.backend, r.proxyInfo + } + } + return "", false +} + +// Read replaces the current Config with one read from r. +func (c *Config) Read(r io.Reader) error { + var routes []Route + var backends []string + + s := bufio.NewScanner(r) + for s.Scan() { + if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") { + // Comment, ignore. + continue + } + + fs := strings.Fields(s.Text()) + switch len(fs) { + case 0: + continue + case 1: + return fmt.Errorf("invalid %q on a line by itself", s.Text()) + case 2: + re, err := dnsRegex(fs[0]) + if err != nil { + return err + } + routes = append(routes, Route{re, fs[1], false}) + backends = append(backends, fs[1]) + case 3: + re, err := dnsRegex(fs[0]) + if err != nil { + return err + } + if fs[2] != "PROXY" { + return errors.New("third item on a line can only be PROXY") + } + routes = append(routes, Route{re, fs[1], true}) + backends = append(backends, fs[1]) + default: + // TODO: multiple backends? + return fmt.Errorf("too many fields on line: %q", s.Text()) + } + } + if err := s.Err(); err != nil { + return err + } + + c.mu.Lock() + defer c.mu.Unlock() + c.routes = routes + return nil +} + +// ReadFile replaces the current Config with one read from path. +func (c *Config) ReadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + return c.Read(f) +} + +// ReadString replaces the current Config with one read from cfg. +func (c *Config) ReadString(cfg string) error { + b := bytes.NewBufferString(cfg) + return c.Read(b) +} diff --git a/cmd/tlsrouter/config_test.go b/cmd/tlsrouter/config_test.go new file mode 100644 index 0000000..9819b91 --- /dev/null +++ b/cmd/tlsrouter/config_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestConfig(t *testing.T) { + type result struct { + backend string + proxy bool + } + + cases := []struct { + Config string + Tests map[string]result + }{ + { + Config: ` +# Comment +go.universe.tf 1.2.3.4 +*.universe.tf 2.3.4.5 +# Comment +google.* 3.4.5.6 +/gooo+gle\.com/ 4.5.6.7 +foobar.net 6.7.8.9 PROXY +`, + Tests: map[string]result{ + "go.universe.tf": result{"1.2.3.4", false}, + "foo.universe.tf": result{"2.3.4.5", false}, + "bar.universe.tf": result{"2.3.4.5", false}, + "google.com": result{"3.4.5.6", false}, + "google.fr": result{"3.4.5.6", false}, + "goooooooooogle.com": result{"4.5.6.7", false}, + "foobar.net": result{"6.7.8.9", true}, + + "blah.com": result{"", false}, + "google.com.br": result{"", false}, + "foo.bar.universe.tf": result{"", false}, + "goooooglexcom": result{"", false}, + }, + }, + } + + for _, test := range cases { + var cfg Config + if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil { + t.Fatalf("Failed to read config (%s):\n%q", err, test.Config) + } + + for hostname, expected := range test.Tests { + backend, proxy := cfg.Match(hostname) + if expected.backend != backend { + t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend) + } + if expected.proxy != proxy { + t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy) + } + } + } +} diff --git a/cmd/tlsrouter/e2e_test.go b/cmd/tlsrouter/e2e_test.go new file mode 100644 index 0000000..6e54021 --- /dev/null +++ b/cmd/tlsrouter/e2e_test.go @@ -0,0 +1,216 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + proxyproto "github.com/armon/go-proxyproto" +) + +func TestRouting(t *testing.T) { + // Backend servers + s1, err := serveTLS(t, "server1", false, "test.com") + if err != nil { + t.Fatalf("serve TLS server1: %s", err) + } + defer s1.Close() + + s2, err := serveTLS(t, "server2", false, "foo.net") + if err != nil { + t.Fatalf("serve TLS server2: %s", err) + } + defer s2.Close() + + s4, err := serveTLS(t, "server4", true, "proxy.design") + if err != nil { + t.Fatalf("server TLS server4: %s", err) + } + defer s4.Close() + + // One proxy + var p Proxy + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("create listener: %s", err) + } + defer l.Close() + go p.Serve(l) + + if err := p.Config.ReadString(fmt.Sprintf(` +test.com %s +foo.net %s +proxy.design %s PROXY +`, s1.Addr(), s2.Addr(), s4.Addr())); err != nil { + t.Fatalf("configure proxy: %s", err) + } + + for _, test := range []struct { + N, V string + P *x509.CertPool + OK bool + Transparent bool + }{ + {"test.com", "server1", s1.Pool, true, false}, + {"foo.net", "server2", s2.Pool, true, false}, + {"bar.org", "", s1.Pool, false, false}, + {"proxy.design", "server4", s4.Pool, true, true}, + } { + res, transparent, err := getTLS(l.Addr().String(), test.N, test.P) + switch { + case test.OK && err != nil: + t.Fatalf("get %q failed: %s", test.N, err) + case !test.OK && err == nil: + t.Fatalf("get %q should have failed, but returned %q", test.N, res) + case test.OK && res != test.V: + t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) + case test.OK && transparent != test.Transparent: + t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent) + } + } +} + +// getTLS attempts to set up a TLS session using the given proxy +// address, domain, and cert pool. It returns the value served by the +// server, as well as a bool indicating whether the server knew the +// true client address, indicating that the PROXY protocol was in use. +func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) { + cfg := tls.Config{ + RootCAs: pool, + ServerName: domain, + } + conn, err := tls.Dial("tcp", addr, &cfg) + if err != nil { + return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) + } + defer conn.Close() + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + } + fs := strings.Split(string(bs), " ") + if len(fs) != 2 { + return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs)) + } + transparent := fs[1] == conn.LocalAddr().String() + return fs[0], transparent, nil +} + +type tlsServer struct { + Domains []string + Value string + Pool *x509.CertPool + Test *testing.T + NumHits uint32 + l net.Listener +} + +func (s *tlsServer) Serve() { + for { + c, err := s.l.Accept() + if err != nil { + s.Test.Logf("accept failed on %q: %s", s.Domains, err) + return + } + atomic.AddUint32(&s.NumHits, 1) + fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr()) + c.Close() + } +} + +func (s *tlsServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tlsServer) Close() error { + return s.l.Close() +} + +func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) { + cert, pool, err := selfSignedCert(domains) + if err != nil { + return nil, err + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + cfg.BuildNameToCertificate() + + var l net.Listener + + l, err = net.Listen("tcp", "localhost:0") + if err != nil { + return nil, err + } + + if understandProxy { + l = &proxyproto.Listener{Listener: l} + } + + l = tls.NewListener(l, cfg) + + ret := &tlsServer{ + Domains: domains, + Value: value, + Pool: pool, + Test: t, + l: l, + } + go ret.Serve() + return ret, nil +} + +func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) { + pkey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return tls.Certificate{}, nil, err + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domains[0], + }, + NotBefore: time.Now().Add(-5 * time.Minute), + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: domains[:], + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pkey.Public(), pkey) + if err != nil { + return tls.Certificate{}, nil, err + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + return tls.Certificate{}, nil, err + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert.Bytes()) { + return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains) + } + + return tlscert, pool, nil +} diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go new file mode 100644 index 0000000..ff1a816 --- /dev/null +++ b/cmd/tlsrouter/main.go @@ -0,0 +1,191 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +var ( + cfgFile = flag.String("conf", "", "configuration file") + listen = flag.String("listen", ":443", "listening port") + helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") +) + +func main() { + flag.Parse() + + p := &Proxy{} + if err := p.Config.ReadFile(*cfgFile); err != nil { + log.Fatalf("Failed to read config %q: %s", *cfgFile, err) + } + + log.Fatalf("%s", p.ListenAndServe(*listen)) +} + +// Proxy routes connections to backends based on a Config. +type Proxy struct { + Config Config + l net.Listener +} + +// Serve accepts connections from l and routes them according to TLS SNI. +func (p *Proxy) Serve(l net.Listener) error { + for { + c, err := l.Accept() + if err != nil { + return fmt.Errorf("accept new conn: %s", err) + } + + conn := &Conn{ + TCPConn: c.(*net.TCPConn), + config: &p.Config, + } + go conn.proxy() + } +} + +// ListenAndServe creates a listener on addr calls Serve on it. +func (p *Proxy) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("create listener: %s", err) + } + return p.Serve(l) +} + +// A Conn handles the TLS proxying of one user connection. +type Conn struct { + *net.TCPConn + config *Config + + tlsMinor int + hostname string + backend string + backendConn *net.TCPConn +} + +func (c *Conn) logf(msg string, args ...interface{}) { + msg = fmt.Sprintf(msg, args...) + log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg) +} + +func (c *Conn) abort(alert byte, msg string, args ...interface{}) { + c.logf(msg, args...) + alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert} + + if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil { + c.logf("error while setting write deadline during abort: %s", err) + // Do NOT send the alert if we can't set a write deadline, + // that could result in leaking a connection for an extended + // period. + return + } + + if _, err := c.Write(alertMsg); err != nil { + c.logf("error while sending alert: %s", err) + } +} + +func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) } +func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) } + +func (c *Conn) proxy() { + defer c.Close() + + if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil { + c.internalError("Setting read deadline for ClientHello: %s", err) + return + } + + var ( + err error + handshakeBuf bytes.Buffer + ) + c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf)) + if err != nil { + c.internalError("Extracting SNI: %s", err) + return + } + + c.logf("extracted SNI %s", c.hostname) + + if err = c.SetReadDeadline(time.Time{}); err != nil { + c.internalError("Clearing read deadline for ClientHello: %s", err) + return + } + + addProxyHeader := false + c.backend, addProxyHeader = c.config.Match(c.hostname) + if c.backend == "" { + c.sniFailed("no backend found for %q", c.hostname) + return + } + + c.logf("routing %q to %q", c.hostname, c.backend) + backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) + if err != nil { + c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) + return + } + defer backend.Close() + + c.backendConn = backend.(*net.TCPConn) + + // If the backend supports the HAProxy PROXY protocol, give it the + // real source information about the connection. + if addProxyHeader { + remote := c.TCPConn.RemoteAddr().(*net.TCPAddr) + local := c.TCPConn.LocalAddr().(*net.TCPAddr) + family := "TCP6" + if remote.IP.To4() != nil { + family = "TCP4" + } + if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil { + c.internalError("failed to send PROXY header to %q: %s", c.backend, err) + return + } + } + + // Replay the piece of the handshake we had to read to do the + // routing, then blindly proxy any other bytes. + if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { + c.internalError("failed to replay handshake to %q: %s", c.backend, err) + return + } + + var wg sync.WaitGroup + wg.Add(2) + go proxy(&wg, c.TCPConn, c.backendConn) + go proxy(&wg, c.backendConn, c.TCPConn) + wg.Wait() +} + +func proxy(wg *sync.WaitGroup, a, b net.Conn) { + defer wg.Done() + atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn) + if _, err := io.Copy(atcp, btcp); err != nil { + log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err) + } + btcp.CloseWrite() + atcp.CloseRead() +} diff --git a/cmd/tlsrouter/sni.go b/cmd/tlsrouter/sni.go new file mode 100644 index 0000000..ed79df2 --- /dev/null +++ b/cmd/tlsrouter/sni.go @@ -0,0 +1,232 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +func extractSNI(r io.Reader) (string, int, error) { + handshake, tlsver, err := handshakeRecord(r) + if err != nil { + return "", 0, fmt.Errorf("reading TLS record: %s", err) + } + + sni, err := parseHello(handshake) + if err != nil { + return "", 0, fmt.Errorf("reading ClientHello: %s", err) + } + if len(sni) == 0 { + // ClientHello did not present an SNI extension. Valid packet, + // no hostname. + return "", tlsver, nil + } + + hostname, err := parseSNI(sni) + if err != nil { + return "", 0, fmt.Errorf("parsing SNI extension: %s", err) + } + return hostname, tlsver, nil +} + +// Extract the indicated hostname, if any, from the given SNI +// extension bytes. +func parseSNI(b []byte) (string, error) { + b, _, err := vector(b, 2) + if err != nil { + return "", err + } + + var ret []byte + for len(b) >= 3 { + typ := b[0] + ret, b, err = vector(b[1:], 2) + if err != nil { + return "", fmt.Errorf("truncated SNI extension") + } + + if typ == sniHostnameID { + return string(ret), nil + } + } + + if len(b) != 0 { + return "", fmt.Errorf("trailing garbage at end of SNI extension") + } + + // No DNS-based SNI present. + return "", nil +} + +const sniExtensionID = 0 +const sniHostnameID = 0 + +// Parse a TLS handshake record as a ClientHello message and extract +// the SNI extension bytes, if any. +func parseHello(b []byte) ([]byte, error) { + if len(b) == 0 { + return nil, errors.New("zero length handshake record") + } + if b[0] != 1 { + return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0]) + } + + // We're expecting a stricter TLS parser to run after we've + // proxied, so we ignore any trailing bytes that might be present + // (e.g. another handshake message). + b, _, err := vector(b[1:], 3) + if err != nil { + return nil, fmt.Errorf("reading ClientHello: %s", err) + } + + // ClientHello must be at least 34 bytes to reach the first vector + // length byte. The actual minimal size is larger than that, but + // vector() will correctly handle truncated packets. + if len(b) < 34 { + return nil, errors.New("ClientHello packet too short") + } + + if b[0] != 3 { + return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1]) + } + switch b[1] { + case 1, 2, 3: + // TLS 1.0, TLS 1.1, TLS 1.2 + default: + return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1]) + } + + // Skip over version and random struct + b = b[34:] + + // We don't technically care about SessionID, but we care that the + // framing is well-formed all the way up to the SNI field, so that + // we are sure that we're pulling the same SNI bytes as the + // eventual TLS implementation. + vec, b, err := vector(b, 1) + if err != nil { + return nil, fmt.Errorf("reading ClientHello SessionID: %s", err) + } + if len(vec) > 32 { + return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec)) + } + + // Likewise, we're just checking the bare minimum of framing. + vec, b, err = vector(b, 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err) + } + if len(vec) < 2 || len(vec)%2 != 0 { + return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec)) + } + + vec, b, err = vector(b, 1) + if err != nil { + return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err) + } + if len(vec) < 1 { + return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec)) + } + + // Finally, we reach the extensions. + if len(b) == 0 { + // No extensions. This is not an error, it just means we have + // no SNI payload. + return nil, nil + } + b, vec, err = vector(b, 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello extensions: %s", err) + } + if len(vec) != 0 { + return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec)) + } + + for len(b) >= 4 { + typ := binary.BigEndian.Uint16(b[:2]) + vec, b, err = vector(b[2:], 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err) + } + if typ == sniExtensionID { + // Found the SNI extension, return its payload. We don't + // care about anything in the packet beyond this point. + return vec, nil + } + } + + if len(b) != 0 { + return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b)) + } + + // Successfully parsed all extensions, but there was no SNI. + return nil, nil +} + +const maxTLSRecordLength = 16384 + +// Read one TLS record, which must be for the handshake protocol, from r. +func handshakeRecord(r io.Reader) ([]byte, int, error) { + var hdr struct { + Type uint8 + Major, Minor uint8 + Length uint16 + } + if err := binary.Read(r, binary.BigEndian, &hdr); err != nil { + return nil, 0, fmt.Errorf("reading TLS record header: %s", err) + } + + if hdr.Type != 22 { + return nil, 0, fmt.Errorf("TLS record is not a handshake") + } + + if hdr.Major != 3 { + return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) + } + switch hdr.Minor { + case 1, 2, 3: + // TLS 1.0, TLS 1.1, TLS 1.2 + default: + return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) + } + + if hdr.Length > maxTLSRecordLength { + return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength) + } + + ret := make([]byte, hdr.Length) + if _, err := io.ReadFull(r, ret); err != nil { + return nil, 0, err + } + + return ret, int(hdr.Minor), nil +} + +func vector(b []byte, lenBytes int) ([]byte, []byte, error) { + if len(b) < lenBytes { + return nil, nil, errors.New("not enough space in packet for vector") + } + var l int + for _, b := range b[:lenBytes] { + l = (l << 8) + int(b) + } + if len(b) < l+lenBytes { + return nil, nil, errors.New("not enough space in packet for vector") + } + return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil +} diff --git a/cmd/tlsrouter/sni_test.go b/cmd/tlsrouter/sni_test.go new file mode 100644 index 0000000..8c87d24 --- /dev/null +++ b/cmd/tlsrouter/sni_test.go @@ -0,0 +1,456 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "testing" +) + +func slice(l int) []byte { + ret := make([]byte, l) + for i := 0; i < l; i++ { + ret[i] = byte(i) + } + return ret +} + +func vec(l, lenBytes int) []byte { + b := slice(l) + vecLen := len(b) + ret := make([]byte, vecLen+l) + for i := l - 1; i >= 0; i-- { + ret[i] = byte(vecLen & 0xff) + vecLen >>= 8 + } + copy(ret[l:], b) + return ret +} + +func packet(bs ...[]byte) []byte { + var ret []byte + for _, b := range bs { + ret = append(ret, b...) + } + return ret +} + +func offset(b []byte, off int) []byte { + return b[off:] +} + +func TestVector(t *testing.T) { + tests := []struct { + in []byte + inLen int + out1, out2 []byte + err bool + }{ + { + // 1b length + append([]byte{3}, slice(10)...), 1, + slice(3), offset(slice(10), 3), false, + }, + { + // 1b length, no trailer + append([]byte{10}, slice(10)...), 1, + slice(10), []byte{}, false, + }, + { + // 1b length, no vector + append([]byte{0}, slice(10)...), 1, + []byte{}, slice(10), false, + }, + { + // 1b length, no vector or trailer + []byte{0}, 1, + []byte{}, []byte{}, false, + }, + { + // 2b length, LSB only + append([]byte{0, 3}, slice(10)...), 2, + slice(3), offset(slice(10), 3), false, + }, + { + // 2b length, MSB only + append([]byte{3, 0}, slice(1024)...), 2, + slice(768), offset(slice(1024), 768), false, + }, + { + // 2b length, both bytes + append([]byte{3, 2}, slice(1024)...), 2, + slice(770), offset(slice(1024), 770), false, + }, + { + // 3b length + append([]byte{1, 2, 3}, slice(100000)...), 3, + slice(66051), offset(slice(100000), 66051), false, + }, + { + // no bytes + []byte{}, 1, + nil, nil, true, + }, + { + // no slice + nil, 1, + nil, nil, true, + }, + { + // not enough bytes for length + []byte{1}, 2, + nil, nil, true, + }, + { + // no bytes after length + []byte{1}, 1, + nil, nil, true, + }, + { + // not enough bytes for vector + []byte{4, 1, 2}, 1, + nil, nil, true, + }, + } + + for _, test := range tests { + actual1, actual2, err := vector(test.in, test.inLen) + if !test.err && (err != nil) { + t.Errorf("unexpected error %q", err) + } + if test.err && (err == nil) { + t.Errorf("unexpected success") + } + if err != nil { + continue + } + if !bytes.Equal(actual1, test.out1) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1) + } + if !bytes.Equal(actual2, test.out2) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2) + } + } +} + +func TestHandshakeRecord(t *testing.T) { + tests := []struct { + in []byte + out []byte + tlsver int + }{ + { + // TLS 1.0, 1b packet + []byte{22, 3, 1, 0, 1, 3}, + []byte{3}, + 1, + }, + { + // TLS 1.1, 1b packet + []byte{22, 3, 2, 0, 1, 3}, + []byte{3}, + 2, + }, + { + // TLS 1.2, 1b packet + []byte{22, 3, 3, 0, 1, 3}, + []byte{3}, + 3, + }, + { + // TLS 1.2, no payload bytes + []byte{22, 3, 3, 0, 0}, + []byte{}, + 3, + }, + { + // TLS 1.2, >255b payload w/ trailing stuff + append([]byte{22, 3, 3, 3, 2}, slice(1024)...), + slice(770), + 3, + }, + { + // TLS 1.2, 2^14 payload + append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...), + slice(maxTLSRecordLength), + 3, + }, + { + // TLS 1.2, >2^14 payload + append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...), + nil, + 0, + }, + { + // TLS 1.2, truncated payload + []byte{22, 3, 3, 0, 4, 1, 2}, + nil, + 0, + }, + { + // truncated header + []byte{22}, + nil, + 0, + }, + { + // wrong record type + []byte{42, 3, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS major version + []byte{22, 2, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS minor version + []byte{22, 3, 42, 0, 1, 3}, + nil, + 0, + }, + { + // Obsolete SSL 3.0 + []byte{22, 3, 0, 0, 1, 3}, + nil, + 0, + }, + } + + for _, test := range tests { + r := bytes.NewBuffer(test.in) + actual, tlsver, err := handshakeRecord(r) + if test.out == nil && err == nil { + t.Errorf("unexpected success") + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out) + } + if tlsver != test.tlsver { + t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver) + } + } +} + +func TestParseHello(t *testing.T) { + tests := []struct { + in []byte + out []byte + err bool + }{ + { + // Wrong record type + packet([]byte{42, 0, 0, 1, 1}), + nil, + true, + }, + { + // Truncated payload + packet([]byte{1, 0, 0, 1}), + nil, + true, + }, + { + // Payload too small + packet([]byte{1, 0, 0, 1, 1}), + nil, + true, + }, + { + // Unknown major version + packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)), + nil, + true, + }, + { + // Unknown minor version + packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)), + nil, + true, + }, + { + // Missing required variadic fields + packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression) + packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression, nonzero session ID) + packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}), + nil, + true, + }, + { + // Session + ciphersuites, no compression + packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}), + nil, + true, + }, + { + // First valid packet. TLS 1.0, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.1, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, garbage extensions + packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)), + nil, + true, + }, + { + // empty extensions vector + packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}), + nil, + false, + }, + { + // non-SNI extensions + packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}), + nil, + false, + }, + { + // SNI present + packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}), + []byte{182}, + false, + }, + { + // Longer SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)), + slice(4), + false, + }, + { + // Embedded SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}), + slice(4), + false, + }, + } + + for _, test := range tests { + actual, err := parseHello(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out) + } + } +} + +func TestParseSNI(t *testing.T) { + tests := []struct { + in []byte + out string + err bool + }{ + { + // Empty packet + []byte{}, + "", + true, + }, + { + // Truncated packet + []byte{0, 2, 1}, + "", + true, + }, + { + // Truncated packet within SNI vector + []byte{0, 2, 1, 2}, + "", + true, + }, + { + // Wrong SNI kind + []byte{0, 3, 1, 0, 0}, + "", + false, + }, + { + // Right SNI kind, no hostname + []byte{0, 3, 0, 0, 0}, + "", + false, + }, + { + // SNI hostname + packet([]byte{0, 6, 0, 0, 3}, []byte("lol")), + "lol", + false, + }, + { + // Multiple SNI kinds + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}), + "lol", + false, + }, + { + // Multiple SNI hostnames (illegal, but we just return the first) + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")), + "bar", + false, + }, + } + + for _, test := range tests { + actual, err := parseSNI(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if test.out != actual { + t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out) + } + } +} @@ -0,0 +1,5 @@ +module inet.af/tcpproxy + +go 1.16 + +require github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a @@ -0,0 +1,2 @@ +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a h1:AP/vsCIvJZ129pdm9Ek7bH7yutN3hByqsMoNrWAxRQc= +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= @@ -0,0 +1,125 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "bufio" + "bytes" + "context" + "net/http" +) + +// AddHTTPHostRoute appends a route to the ipPort listener that +// routes to dest if the incoming HTTP/1.x Host header name is +// httpHost. If it doesn't match, rule processing continues for any +// additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { + p.AddHTTPHostMatchRoute(ipPort, equals(httpHost), dest) +} + +// AddHTTPHostMatchRoute appends a route to the ipPort listener that +// routes to dest if the incoming HTTP/1.x Host header name is +// accepted by matcher. If it doesn't match, rule processing continues +// for any additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) { + p.addRoute(ipPort, httpHostMatch{match, dest}) +} + +type httpHostMatch struct { + matcher Matcher + target Target +} + +func (m httpHostMatch) match(br *bufio.Reader) (Target, string) { + hh := httpHostHeader(br) + if m.matcher(context.TODO(), hh) { + return m.target, hh + } + return nil, "" +} + +// httpHostHeader returns the HTTP Host header from br without +// consuming any of its bytes. It returns "" if it can't find one. +func httpHostHeader(br *bufio.Reader) string { + const maxPeek = 4 << 10 + peekSize := 0 + for { + peekSize++ + if peekSize > maxPeek { + b, _ := br.Peek(br.Buffered()) + return httpHostHeaderFromBytes(b) + } + b, err := br.Peek(peekSize) + if n := br.Buffered(); n > peekSize { + b, _ = br.Peek(n) + peekSize = n + } + if len(b) > 0 { + if b[0] < 'A' || b[0] > 'Z' { + // Doesn't look like an HTTP verb + // (GET, POST, etc). + return "" + } + if bytes.Index(b, crlfcrlf) != -1 || bytes.Index(b, lflf) != -1 { + req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b))) + if err != nil { + return "" + } + if len(req.Header["Host"]) > 1 { + // TODO(bradfitz): what does + // ReadRequest do if there are + // multiple Host headers? + return "" + } + return req.Host + } + } + if err != nil { + return httpHostHeaderFromBytes(b) + } + } +} + +var ( + lfHostColon = []byte("\nHost:") + lfhostColon = []byte("\nhost:") + crlf = []byte("\r\n") + lf = []byte("\n") + crlfcrlf = []byte("\r\n\r\n") + lflf = []byte("\n\n") +) + +func httpHostHeaderFromBytes(b []byte) string { + if i := bytes.Index(b, lfHostColon); i != -1 { + return string(bytes.TrimSpace(untilEOL(b[i+len(lfHostColon):]))) + } + if i := bytes.Index(b, lfhostColon); i != -1 { + return string(bytes.TrimSpace(untilEOL(b[i+len(lfhostColon):]))) + } + return "" +} + +// untilEOL returns v, truncated before the first '\n' byte, if any. +// The returned slice may include a '\r' at the end. +func untilEOL(v []byte) []byte { + if i := bytes.IndexByte(v, '\n'); i != -1 { + return v[:i] + } + return v +} diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..1ddc48e --- /dev/null +++ b/listener.go @@ -0,0 +1,108 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "io" + "net" + "sync" +) + +// TargetListener implements both net.Listener and Target. +// Matched Targets become accepted connections. +type TargetListener struct { + Address string // Address is the string reported by TargetListener.Addr().String(). + + mu sync.Mutex + cond *sync.Cond + closed bool + nextConn net.Conn +} + +var ( + _ net.Listener = (*TargetListener)(nil) + _ Target = (*TargetListener)(nil) +) + +func (tl *TargetListener) lock() { + tl.mu.Lock() + if tl.cond == nil { + tl.cond = sync.NewCond(&tl.mu) + } +} + +type tcpAddr string + +func (a tcpAddr) Network() string { return "tcp" } +func (a tcpAddr) String() string { return string(a) } + +// Addr returns the listener's Address field as a net.Addr. +func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) } + +// Close stops listening for new connections. All new connections +// routed to this listener will be closed. Already accepted +// connections are not closed. +func (tl *TargetListener) Close() error { + tl.lock() + if tl.closed { + tl.mu.Unlock() + return nil + } + tl.closed = true + tl.mu.Unlock() + tl.cond.Broadcast() + return nil +} + +// HandleConn implements the Target interface. It blocks until tl is +// closed or another goroutine has called Accept and received c. +func (tl *TargetListener) HandleConn(c net.Conn) { + tl.lock() + defer tl.mu.Unlock() + for tl.nextConn != nil && !tl.closed { + tl.cond.Wait() + } + if tl.closed { + c.Close() + return + } + tl.nextConn = c + tl.cond.Broadcast() // Signal might be sufficient; verify. + for tl.nextConn == c && !tl.closed { + tl.cond.Wait() + } + if tl.closed { + c.Close() + return + } +} + +// Accept implements the Accept method in the net.Listener interface. +func (tl *TargetListener) Accept() (net.Conn, error) { + tl.lock() + for tl.nextConn == nil && !tl.closed { + tl.cond.Wait() + } + if tl.closed { + tl.mu.Unlock() + return nil, io.EOF + } + c := tl.nextConn + tl.nextConn = nil + tl.mu.Unlock() + tl.cond.Broadcast() // Signal might be sufficient; verify. + + return c, nil +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 0000000..70087ca --- /dev/null +++ b/listener_test.go @@ -0,0 +1,49 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "io" + "testing" +) + +func TestListenerAccept(t *testing.T) { + tl := new(TargetListener) + ch := make(chan interface{}, 1) + go func() { + for { + conn, err := tl.Accept() + if err != nil { + ch <- err + return + } + ch <- conn + } + }() + + for i := 0; i < 3; i++ { + conn := new(Conn) + tl.HandleConn(conn) + got := <-ch + if got != conn { + t.Errorf("Accept conn = %v; want %v", got, conn) + } + } + tl.Close() + got := <-ch + if got != io.EOF { + t.Errorf("Accept error post-Close = %v; want io.EOF", got) + } +} diff --git a/scripts/prune_old_versions.go b/scripts/prune_old_versions.go new file mode 100644 index 0000000..42e031e --- /dev/null +++ b/scripts/prune_old_versions.go @@ -0,0 +1,150 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "net/http" + "os" + "sort" + "strings" + "time" +) + +var ( + user = flag.String("user", "", "username") + repo = flag.String("repo", "", "repository name") + pkgType = flag.String("pkg-type", "deb", "Package type, e.g. 'deb'") + distro = flag.String("distro", "", "distro name, e.g. 'debian'") + distroVersion = flag.String("version", "", "distro version, e.g. 'stretch'") + pkg = flag.String("package", "", "package name") + arch = flag.String("arch", "", "package architecture") + limit = flag.Int("limit", 2, "package versions to keep") +) + +func fatalf(msg string, args ...interface{}) { + fmt.Printf(msg+"\n", args...) + os.Exit(1) +} + +func main() { + flag.Parse() + if *user == "" { + fatalf("missing -user") + } + if *repo == "" { + fatalf("missing -repo") + } + if *pkgType == "" { + fatalf("missing -pkg-type") + } + if *distro == "" { + fatalf("missing -distro") + } + if *distroVersion == "" { + fatalf("missing -version") + } + if *pkg == "" { + fatalf("missing -package") + } + if *arch == "" { + fatalf("missing -arch") + } + if *limit < 1 { + fatalf("limit must be >= 1") + } + + files, err := packageVersions(*user, *repo, *pkgType, *distro, *distroVersion, *pkg, *arch) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + if len(files) <= *limit { + fmt.Println("Below limit, no packages deleted") + return + } + delete := files[:len(files)-*limit] + keep := files[len(files)-*limit:] + if err = deletePackages(delete); err != nil { + fmt.Println(err) + os.Exit(1) + } + + fmt.Printf("Deleted:\n\n%s\n\nKept:\n\n%s\n", strings.Join(delete, "\n"), strings.Join(keep, "\n")) +} + +type packageMeta struct { + Created time.Time `json:"created_at"` + Filename string `json:"filename"` +} + +type metaSort []packageMeta + +func (m metaSort) Len() int { return len(m) } +func (m metaSort) Less(i, j int) bool { return m[i].Created.Before(m[j].Created) } +func (m metaSort) Swap(i, j int) { m[i], m[j] = m[j], m[i] } + +func packageVersions(user, repo, typ, distro, version, pkgname, arch string) ([]string, error) { + url := fmt.Sprintf("https://%s:@packagecloud.io/api/v1/repos/%s/%s/package/%s/%s/%s/%s/%s/versions.json", os.Getenv("PACKAGECLOUD_API_KEY"), user, repo, typ, distro, version, pkgname, arch) + resp, err := http.Get(url) + if err != nil { + return nil, fmt.Errorf("get versions.json: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + msg, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("get error message of versions.json get: %s", err) + } + return nil, fmt.Errorf("get versions.json: %s (%q)", resp.Status, string(msg)) + } + + var files []packageMeta + if err := json.NewDecoder(resp.Body).Decode(&files); err != nil { + return nil, fmt.Errorf("decode versions.json: %s", err) + } + + // Newest first + sort.Sort(metaSort(files)) + + var ret []string + for _, meta := range files { + ret = append(ret, fmt.Sprintf("/api/v1/repos/%s/%s/%s/%s/%s", user, repo, distro, version, meta.Filename)) + } + + return ret, nil +} + +func deletePackages(urls []string) error { + for _, url := range urls { + fullURL := fmt.Sprintf("https://%s:@packagecloud.io%s", os.Getenv("PACKAGECLOUD_API_KEY"), url) + req, err := http.NewRequest("DELETE", fullURL, nil) + if err != nil { + return fmt.Errorf("build delete request for %s: %s", url, err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("delete %s: %s", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return fmt.Errorf("delete %s: %s", url, resp.Status) + } + } + return nil +} @@ -0,0 +1,115 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "io" + "net" +) + +// AddSNIRoute appends a route to the ipPort listener that routes to +// dest if the incoming TLS SNI server name is sni. If it doesn't +// match, rule processing continues for any additional routes on +// ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { + p.AddSNIMatchRoute(ipPort, equals(sni), dest) +} + +// AddSNIMatchRoute appends a route to the ipPort listener that routes +// to dest if the incoming TLS SNI server name is accepted by +// matcher. If it doesn't match, rule processing continues for any +// additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) { + p.addRoute(ipPort, sniMatch{matcher: matcher, target: dest}) +} + +// SNITargetFunc is the func callback used by Proxy.AddSNIRouteFunc. +type SNITargetFunc func(ctx context.Context, sniName string) (t Target, ok bool) + +// AddSNIRouteFunc adds a route to ipPort that matches an SNI request and calls +// fn to map its nap to a target. +func (p *Proxy) AddSNIRouteFunc(ipPort string, fn SNITargetFunc) { + p.addRoute(ipPort, sniMatch{targetFunc: fn}) +} + +type sniMatch struct { + matcher Matcher + target Target + + // Alternatively, if targetFunc is non-nil, it's used instead: + targetFunc SNITargetFunc +} + +func (m sniMatch) match(br *bufio.Reader) (Target, string) { + sni := clientHelloServerName(br) + if sni == "" { + return nil, "" + } + if m.targetFunc != nil { + if t, ok := m.targetFunc(context.TODO(), sni); ok { + return t, sni + } + return nil, "" + } + if m.matcher(context.TODO(), sni) { + return m.target, sni + } + return nil, "" +} + +// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// without consuming any bytes from br. +// On any error, the empty string is returned. +func clientHelloServerName(br *bufio.Reader) (sni string) { + const recordHeaderLen = 5 + hdr, err := br.Peek(recordHeaderLen) + if err != nil { + return "" + } + const recordTypeHandshake = 0x16 + if hdr[0] != recordTypeHandshake { + return "" // Not TLS. + } + recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] + helloBytes, err := br.Peek(recordHeaderLen + recLen) + if err != nil { + return "" + } + tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + sni = hello.ServerName + return nil, nil + }, + }).Handshake() + return +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } diff --git a/systemd/tlsrouter.service b/systemd/tlsrouter.service new file mode 100644 index 0000000..23e8fe1 --- /dev/null +++ b/systemd/tlsrouter.service @@ -0,0 +1,25 @@ +[Unit] +Description=TLS SNI proxy +Documentation=https://github.com/google/tlsrouter + +[Service] +WorkingDirectory=/tmp +ExecStart=/usr/bin/tlsrouter -conf /etc/tlsrouter.conf +Restart=always +User=nobody +Group=nogroup +CapabilityBoundingSet=CAP_NET_BIND_SERVICE +AmbientCapabilities=CAP_NET_BIND_SERVICE +PrivateTmp=true +PrivateDevices=true +ProtectSystem=strict +ProtectHome=true +ProtectKernelTunables=true +ProtectControlGroups=true +ProtectKernelModules=true +NoNewPrivileges=true +SystemCallFilter=~@clock @cpu-emulation @debug @keyring @module @mount @obsolete @privileged @raw-io +RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX + +[Install] +WantedBy=multi-user.target diff --git a/tcpproxy.go b/tcpproxy.go new file mode 100644 index 0000000..1f03e32 --- /dev/null +++ b/tcpproxy.go @@ -0,0 +1,496 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcpproxy lets users build TCP proxies, optionally making +// routing decisions based on HTTP/1 Host headers and the SNI hostname +// in TLS connections. +// +// Typical usage: +// +// var p tcpproxy.Proxy +// p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081")) +// p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082")) +// p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback +// p.AddSNIRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431")) +// p.AddSNIRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432")) +// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback +// log.Fatal(p.Run()) +// +// Calling Run (or Start) on a proxy also starts all the necessary +// listeners. +// +// For each accepted connection, the rules for that ipPort are +// matched, in order. If one matches (currently HTTP Host, SNI, or +// always), then the connection is handed to the target. +// +// The two predefined Target implementations are: +// +// 1) DialProxy, proxying to another address (use the To func to return a +// DialProxy value), +// +// 2) TargetListener, making the matched connection available via a +// net.Listener.Accept call. +// +// But Target is an interface, so you can also write your own. +// +// Note that tcpproxy does not do any TLS encryption or decryption. It +// only (via DialProxy) copies bytes around. The SNI hostname in the TLS +// header is unencrypted, for better or worse. +// +// This package makes no API stability promises. If you depend on it, +// vendor it. +package tcpproxy + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "log" + "net" + "time" +) + +// Proxy is a proxy. Its zero value is a valid proxy that does +// nothing. Call methods to add routes before calling Start or Run. +// +// The order that routes are added in matters; each is matched in the order +// registered. +type Proxy struct { + configs map[string]*config // ip:port => config + + lns []net.Listener + donec chan struct{} // closed before err + err error // any error from listening + + // ListenFunc optionally specifies an alternate listen + // function. If nil, net.Dial is used. + // The provided net is always "tcp". + ListenFunc func(net, laddr string) (net.Listener, error) +} + +// Matcher reports whether hostname matches the Matcher's criteria. +type Matcher func(ctx context.Context, hostname string) bool + +// equals is a trivial Matcher that implements string equality. +func equals(want string) Matcher { + return func(_ context.Context, got string) bool { + return want == got + } +} + +// config contains the proxying state for one listener. +type config struct { + routes []route +} + +// A route matches a connection to a target. +type route interface { + // match examines the initial bytes of a connection, looking for a + // match. If a match is found, match returns a non-nil Target to + // which the stream should be proxied. match returns nil if the + // connection doesn't match. + // + // match must not consume bytes from the given bufio.Reader, it + // can only Peek. + // + // If an sni or host header was parsed successfully, that will be + // returned as the second parameter. + match(*bufio.Reader) (Target, string) +} + +func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { + if p.ListenFunc != nil { + return p.ListenFunc + } + return net.Listen +} + +func (p *Proxy) configFor(ipPort string) *config { + if p.configs == nil { + p.configs = make(map[string]*config) + } + if p.configs[ipPort] == nil { + p.configs[ipPort] = &config{} + } + return p.configs[ipPort] +} + +func (p *Proxy) addRoute(ipPort string, r route) { + cfg := p.configFor(ipPort) + cfg.routes = append(cfg.routes, r) +} + +// AddRoute appends an always-matching route to the ipPort listener, +// directing any connection to dest. +// +// This is generally used as either the only rule (for simple TCP +// proxies), or as the final fallback rule for an ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddRoute(ipPort string, dest Target) { + p.addRoute(ipPort, fixedTarget{dest}) +} + +type fixedTarget struct { + t Target +} + +func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" } + +// Run is calls Start, and then Wait. +// +// It blocks until there's an error. The return value is always +// non-nil. +func (p *Proxy) Run() error { + if err := p.Start(); err != nil { + return err + } + return p.Wait() +} + +// Wait waits for the Proxy to finish running. Currently this can only +// happen if a Listener is closed, or Close is called on the proxy. +// +// It is only valid to call Wait after a successful call to Start. +func (p *Proxy) Wait() error { + <-p.donec + return p.err +} + +// Close closes all the proxy's self-opened listeners. +func (p *Proxy) Close() error { + for _, c := range p.lns { + c.Close() + } + return nil +} + +// Start creates a TCP listener for each unique ipPort from the +// previously created routes and starts the proxy. It returns any +// error from starting listeners. +// +// If it returns a non-nil error, any successfully opened listeners +// are closed. +func (p *Proxy) Start() error { + if p.donec != nil { + return errors.New("already started") + } + p.donec = make(chan struct{}) + errc := make(chan error, len(p.configs)) + p.lns = make([]net.Listener, 0, len(p.configs)) + for ipPort, config := range p.configs { + ln, err := p.netListen()("tcp", ipPort) + if err != nil { + p.Close() + return err + } + p.lns = append(p.lns, ln) + go p.serveListener(errc, ln, config.routes) + } + go p.awaitFirstError(errc) + return nil +} + +func (p *Proxy) awaitFirstError(errc <-chan error) { + p.err = <-errc + close(p.donec) +} + +func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) { + for { + c, err := ln.Accept() + if err != nil { + ret <- err + return + } + go p.serveConn(c, routes) + } +} + +// serveConn runs in its own goroutine and matches c against routes. +// It returns whether it matched purely for testing. +func (p *Proxy) serveConn(c net.Conn, routes []route) bool { + br := bufio.NewReader(c) + for _, route := range routes { + if target, hostName := route.match(br); target != nil { + if n := br.Buffered(); n > 0 { + peeked, _ := br.Peek(br.Buffered()) + c = &Conn{ + HostName: hostName, + Peeked: peeked, + Conn: c, + } + } + target.HandleConn(c) + return true + } + } + // TODO: hook for this? + log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) + c.Close() + return false +} + +// Conn is an incoming connection that has had some bytes read from it +// to determine how to route the connection. The Read method stitches +// the peeked bytes and unread bytes back together. +type Conn struct { + // HostName is the hostname field that was sent to the request router. + // In the case of TLS, this is the SNI header, in the case of HTTPHost + // route, it will be the host header. In the case of a fixed + // route, i.e. those created with AddRoute(), this will always be + // empty. This can be useful in the case where further routing decisions + // need to be made in the Target impementation. + HostName string + + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. + net.Conn +} + +func (c *Conn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.Conn.Read(p) +} + +// Target is what an incoming matched connection is sent to. +type Target interface { + // HandleConn is called when an incoming connection is + // matched. After the call to HandleConn, the tcpproxy + // package never touches the conn again. Implementations are + // responsible for closing the connection when needed. + // + // The concrete type of conn will be of type *Conn if any + // bytes have been consumed for the purposes of route + // matching. + HandleConn(net.Conn) +} + +// To is shorthand way of writing &tcpproxy.DialProxy{Addr: addr}. +func To(addr string) *DialProxy { + return &DialProxy{Addr: addr} +} + +// DialProxy implements Target by dialing a new connection to Addr +// and then proxying data back and forth. +// +// The To func is a shorthand way of creating a DialProxy. +type DialProxy struct { + // Addr is the TCP address to proxy to. + Addr string + + // KeepAlivePeriod sets the period between TCP keep alives. + // If zero, a default is used. To disable, use a negative number. + // The keep-alive is used for both the client connection and + KeepAlivePeriod time.Duration + + // DialTimeout optionally specifies a dial timeout. + // If zero, a default is used. + // If negative, the timeout is disabled. + DialTimeout time.Duration + + // DialContext optionally specifies an alternate dial function + // for TCP targets. If nil, the standard + // net.Dialer.DialContext method is used. + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + + // OnDialError optionally specifies an alternate way to handle errors dialing Addr. + // If nil, the error is logged and src is closed. + // If non-nil, src is not closed automatically. + OnDialError func(src net.Conn, dstDialErr error) + + // ProxyProtocolVersion optionally specifies the version of + // HAProxy's PROXY protocol to use. The PROXY protocol provides + // connection metadata to the DialProxy target, via a header + // inserted ahead of the client's traffic. The DialProxy target + // must explicitly support and expect the PROXY header; there is + // no graceful downgrade. + // If zero, no PROXY header is sent. Currently, version 1 is supported. + ProxyProtocolVersion int +} + +// UnderlyingConn returns c.Conn if c of type *Conn, +// otherwise it returns c. +func UnderlyingConn(c net.Conn) net.Conn { + if wrap, ok := c.(*Conn); ok { + return wrap.Conn + } + return c +} + +func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) { + if c, ok := UnderlyingConn(c).(*net.TCPConn); ok { + return c, ok + } + if c, ok := c.(*net.TCPConn); ok { + return c, ok + } + return nil, false +} + +func goCloseConn(c net.Conn) { go c.Close() } + +func closeRead(c net.Conn) { + if c, ok := tcpConn(c); ok { + c.CloseRead() + } +} + +func closeWrite(c net.Conn) { + if c, ok := tcpConn(c); ok { + c.CloseWrite() + } +} + +// HandleConn implements the Target interface. +func (dp *DialProxy) HandleConn(src net.Conn) { + ctx := context.Background() + var cancel context.CancelFunc + if dp.DialTimeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout()) + } + dst, err := dp.dialContext()(ctx, "tcp", dp.Addr) + if cancel != nil { + cancel() + } + if err != nil { + dp.onDialError()(src, err) + return + } + defer goCloseConn(dst) + + if err = dp.sendProxyHeader(dst, src); err != nil { + dp.onDialError()(src, err) + return + } + defer goCloseConn(src) + + if ka := dp.keepAlivePeriod(); ka > 0 { + for _, c := range []net.Conn{src, dst} { + if c, ok := tcpConn(c); ok { + c.SetKeepAlive(true) + c.SetKeepAlivePeriod(ka) + } + } + } + + errc := make(chan error, 2) + go proxyCopy(errc, src, dst) + go proxyCopy(errc, dst, src) + <-errc + <-errc +} + +func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { + switch dp.ProxyProtocolVersion { + case 0: + return nil + case 1: + var srcAddr, dstAddr *net.TCPAddr + if a, ok := src.RemoteAddr().(*net.TCPAddr); ok { + srcAddr = a + } + if a, ok := src.LocalAddr().(*net.TCPAddr); ok { + dstAddr = a + } + + if srcAddr == nil || dstAddr == nil { + _, err := io.WriteString(w, "PROXY UNKNOWN\r\n") + return err + } + + family := "TCP4" + if srcAddr.IP.To4() == nil { + family = "TCP6" + } + _, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port) + return err + default: + return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion) + } +} + +// proxyCopy is the function that copies bytes around. +// It's a named function instead of a func literal so users get +// named goroutines in debug goroutine stack dumps. +func proxyCopy(errc chan<- error, dst, src net.Conn) { + defer closeRead(src) + defer closeWrite(dst) + + // Before we unwrap src and/or dst, copy any buffered data. + if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 { + if _, err := dst.Write(wc.Peeked); err != nil { + errc <- err + return + } + wc.Peeked = nil + } + + // Unwrap the src and dst from *Conn to *net.TCPConn so Go + // 1.11's splice optimization kicks in. + src = UnderlyingConn(src) + dst = UnderlyingConn(dst) + + _, err := io.Copy(dst, src) + errc <- err +} + +func (dp *DialProxy) keepAlivePeriod() time.Duration { + if dp.KeepAlivePeriod != 0 { + return dp.KeepAlivePeriod + } + return time.Minute +} + +func (dp *DialProxy) dialTimeout() time.Duration { + if dp.DialTimeout > 0 { + return dp.DialTimeout + } + return 10 * time.Second +} + +var defaultDialer = new(net.Dialer) + +func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { + if dp.DialContext != nil { + return dp.DialContext + } + return defaultDialer.DialContext +} + +func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) { + if dp.OnDialError != nil { + return dp.OnDialError + } + return func(src net.Conn, dstDialErr error) { + log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr) + src.Close() + } +} diff --git a/tcpproxy_test.go b/tcpproxy_test.go new file mode 100644 index 0000000..0346a7a --- /dev/null +++ b/tcpproxy_test.go @@ -0,0 +1,525 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "io" + "io/ioutil" + "math/big" + "net" + "strings" + "testing" + "time" +) + +type noopTarget struct{} + +func (noopTarget) HandleConn(net.Conn) {} + +func TestMatchHTTPHost(t *testing.T) { + tests := []struct { + name string + r io.Reader + host string + want bool + }{ + { + name: "match", + r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), + host: "foo.com", + want: true, + }, + { + name: "no-match", + r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), + host: "bar.com", + want: false, + }, + { + name: "match-huge-request", + r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')), + host: "foo.com", + want: true, + }, + } + for i, tt := range tests { + name := tt.name + if name == "" { + name = fmt.Sprintf("test_index_%d", i) + } + t.Run(name, func(t *testing.T) { + br := bufio.NewReader(tt.r) + r := httpHostMatch{equals(tt.host), noopTarget{}} + m, name := r.match(br) + got := m != nil + if got != tt.want { + t.Fatalf("match = %v; want %v", got, tt.want) + } + if tt.want && name != tt.host { + t.Fatalf("host = %s; want %s", name, tt.host) + } + get := make([]byte, 3) + if _, err := io.ReadFull(br, get); err != nil { + t.Fatal(err) + } + if string(get) != "GET" { + t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get) + } + }) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +type recordWritesConn struct { + buf bytes.Buffer + net.Conn +} + +func (c *recordWritesConn) Write(p []byte) (int, error) { + c.buf.Write(p) + return len(p), nil +} + +func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF } + +func clientHelloRecord(t *testing.T, hostName string) string { + rec := new(recordWritesConn) + cl := tls.Client(rec, &tls.Config{ServerName: hostName}) + cl.Handshake() + + s := rec.buf.String() + if !strings.Contains(s, hostName) { + t.Fatalf("clientHello sent in test didn't contain %q", hostName) + } + return s +} + +func TestSNI(t *testing.T) { + const hostName = "foo.com" + greeting := clientHelloRecord(t, hostName) + got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting))) + if got != hostName { + t.Errorf("got SNI %q; want %q", got, hostName) + } +} + +func TestProxyStartNone(t *testing.T) { + var p Proxy + if err := p.Start(); err != nil { + t.Fatal(err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp", "[::1]:0") + if err != nil { + t.Fatal(err) + } + } + return ln +} + +const testFrontAddr = "1.2.3.4:567" + +func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) { + return func(network, laddr string) (net.Listener, error) { + if network != "tcp" { + t.Errorf("got Listen call with network %q, not tcp", network) + return nil, errors.New("invalid network") + } + if laddr != testFrontAddr { + t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr) + panic("bogus address") + } + return ln, nil + } +} + +func testProxy(t *testing.T, front net.Listener) *Proxy { + return &Proxy{ + ListenFunc: testListenFunc(t, front), + } +} + +func TestBufferedClose(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + defer fromProxy.Close() + const msg = "message" + if _, err := io.WriteString(toFront, msg); err != nil { + t.Fatal(err) + } + // actively close toFront, the write should still make to the back. + toFront.Close() + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxyAlwaysMatch(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + defer fromProxy.Close() + const msg = "message" + io.WriteString(toFront, msg) + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxyHTTP(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxySNI(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + msg := clientHelloRecord(t, "bar.com") + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestAddSNIRouteFunc(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) { + if sniName == "bar.com" { + return To(backBar.Addr().String()), true + } + t.Fatalf("failed to match %q", sniName) + return nil, false + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + msg := clientHelloRecord(t, "bar.com") + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} +func TestProxyPROXYOut(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, &DialProxy{ + Addr: back.Addr().String(), + ProxyProtocolVersion: 1, + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + + io.WriteString(toFront, "foo") + toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + + bs, err := ioutil.ReadAll(fromProxy) + if err != nil { + t.Fatal(err) + } + + want := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).Port) + if string(bs) != want { + t.Fatalf("got %q; want %q", bs, want) + } +} + +type tlsServer struct { + Listener net.Listener + Domain string + Test *testing.T +} + +func (t *tlsServer) Start() { + cert := cert(t.Test, t.Domain) + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + cfg.BuildNameToCertificate() + + go func() { + for { + rawConn, err := t.Listener.Accept() + if err != nil { + return // assume Close() + } + + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, t.Domain); err != nil { + t.Test.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() +} + +func (t *tlsServer) Close() { + t.Listener.Close() +} + +// cert creates a well-formed, but completely insecure self-signed +// cert for domain. +func cert(t *testing.T, domain string) tls.Certificate { + private, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domain, + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private) + if err != nil { + t.Fatal(err) + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + t.Fatal(err) + } + + return tlscert +} + +// newTLSServer starts a TLS server that serves a self-signed cert for +// domain. +func newTLSServer(t *testing.T, domain string) net.Listener { + cert := cert(t, domain) + + l := newLocalListener(t) + go func() { + for { + rawConn, err := l.Accept() + if err != nil { + return // assume closed + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + cfg.BuildNameToCertificate() + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, domain); err != nil { + t.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() + + return l +} + +func readTLS(dest, domain string) (string, error) { + conn, err := tls.Dial("tcp", dest, &tls.Config{ + ServerName: domain, + InsecureSkipVerify: true, + }) + if err != nil { + return "", err + } + defer conn.Close() + + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", err + } + return string(bs), nil +} |