diff options
Diffstat (limited to 'src/jaegertracing/thrift/lib/go')
116 files changed, 17048 insertions, 0 deletions
diff --git a/src/jaegertracing/thrift/lib/go/Makefile.am b/src/jaegertracing/thrift/lib/go/Makefile.am new file mode 100644 index 000000000..0dfa5fadc --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/Makefile.am @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = . + +if WITH_TESTS +SUBDIRS += test +endif + +install: + @echo '##############################################################' + @echo '##############################################################' + @echo 'The Go client library should be installed via "go get", please see /lib/go/README.md' + @echo '##############################################################' + @echo '##############################################################' + +check-local: + GOPATH=`pwd` $(GO) test -race ./thrift + +clean-local: + $(RM) -rf pkg + +all-local: + GOPATH=`pwd` $(GO) build ./thrift + +EXTRA_DIST = \ + thrift \ + coding_standards.md \ + README.md diff --git a/src/jaegertracing/thrift/lib/go/README.md b/src/jaegertracing/thrift/lib/go/README.md new file mode 100644 index 000000000..ce6d5edc1 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/README.md @@ -0,0 +1,83 @@ +Thrift Go Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + + +Using Thrift with Go +==================== + +Thrift supports Go 1.7+ + +In following Go conventions, we recommend you use the 'go' tool to install +Thrift for go. + + $ go get github.com/apache/thrift/lib/go/thrift/... + +Will retrieve and install the most recent version of the package. + + +A note about optional fields +============================ + +The thrift-to-Go compiler tries to represent thrift IDL structs as Go structs. +We must be able to distinguish between optional fields that are set to their +default value and optional values which are actually unset, so the generated +code represents optional fields via pointers. + +This is generally intuitive and works well much of the time, but Go does not +have a syntax for creating a pointer to a constant in a single expression. That +is, given a struct like + + struct SomeIDLType { + OptionalField *int32 + } + +, the following will not compile: + + x := &SomeIDLType{ + OptionalField: &(3), + } + +(Nor is there any other syntax that's built in to the language) + +As such, we provide some helpers that do just this under lib/go/thrift/. E.g., + + x := &SomeIDLType{ + OptionalField: thrift.Int32Ptr(3), + } + +And so on. The code generator also creates analogous helpers for user-defined +typedefs and enums. + +Adding custom tags to generated Thrift structs +============================================== + +You can add tags to the auto-generated thrift structs using the following format: + + struct foo { + 1: required string Bar (go.tag = "some_tag:\"some_tag_value\"") + } + +which will generate: + + type Foo struct { + Bar string `thrift:"bar,1,required" some_tag:"some_tag_value"` + } diff --git a/src/jaegertracing/thrift/lib/go/coding_standards.md b/src/jaegertracing/thrift/lib/go/coding_standards.md new file mode 100644 index 000000000..fa0390bb5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/coding_standards.md @@ -0,0 +1 @@ +Please follow [General Coding Standards](/doc/coding_standards.md) diff --git a/src/jaegertracing/thrift/lib/go/test/BinaryKeyTest.thrift b/src/jaegertracing/thrift/lib/go/test/BinaryKeyTest.thrift new file mode 100644 index 000000000..71cb61472 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/BinaryKeyTest.thrift @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Make sure that thrift produce compilable code for binary key +struct testStruct { + 1: required map<binary,string> bin_to_string +} + diff --git a/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceServiceTest.thrift b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceServiceTest.thrift new file mode 100644 index 000000000..aade3d7d6 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceServiceTest.thrift @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +namespace go conflict.context + +include "ConflictNamespaceTestD.thrift" + +service ConflictService { + ConflictNamespaceTestD.ThingD thingFunc() +} diff --git a/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestA.thrift b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestA.thrift new file mode 100644 index 000000000..2749da90b --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestA.thrift @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +namespace go conflicta.common + +struct ThingA { + 1: bool value +} diff --git a/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestB.thrift b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestB.thrift new file mode 100644 index 000000000..b1940ff9f --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestB.thrift @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +namespace go conflictb.common + +struct ThingB { + 1: bool value +} diff --git a/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestC.thrift b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestC.thrift new file mode 100644 index 000000000..7d5ee2582 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestC.thrift @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +namespace go common + +struct ThingC { + 1: bool value +} diff --git a/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestD.thrift b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestD.thrift new file mode 100644 index 000000000..8fe7f5e01 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestD.thrift @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +namespace go conflictd.context + +struct ThingD { + 1: bool value +} diff --git a/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestSuperThing.thrift b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestSuperThing.thrift new file mode 100644 index 000000000..2348a621f --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ConflictNamespaceTestSuperThing.thrift @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +include "ConflictNamespaceTestA.thrift" +include "ConflictNamespaceTestB.thrift" +include "ConflictNamespaceTestC.thrift" +include "ConflictNamespaceTestD.thrift" + +struct SuperThing { + 1: ConflictNamespaceTestA.ThingA thing_a + 2: ConflictNamespaceTestB.ThingB thing_b + 3: ConflictNamespaceTestC.ThingC thing_c + 4: ConflictNamespaceTestD.ThingD thing_d +} diff --git a/src/jaegertracing/thrift/lib/go/test/DontExportRWTest.thrift b/src/jaegertracing/thrift/lib/go/test/DontExportRWTest.thrift new file mode 100644 index 000000000..39e5a6fc8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/DontExportRWTest.thrift @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct InnerStruct { + 1: required string id +} + +struct TestStruct { + 1: required string id + 2: required InnerStruct inner +} diff --git a/src/jaegertracing/thrift/lib/go/test/ErrorTest.thrift b/src/jaegertracing/thrift/lib/go/test/ErrorTest.thrift new file mode 100644 index 000000000..33b664434 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ErrorTest.thrift @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ +struct TestStruct +{ + 1: map<string, string> m, + 2: list<string> l, + 3: set<string> s, + 4: i32 i +} + +service ErrorTest +{ + TestStruct testStruct(1: TestStruct thing) + string testString(1: string s) +} diff --git a/src/jaegertracing/thrift/lib/go/test/GoTagTest.thrift b/src/jaegertracing/thrift/lib/go/test/GoTagTest.thrift new file mode 100644 index 000000000..5667c6e54 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/GoTagTest.thrift @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct tagged { + 1: string string_thing, + 2: i64 int_thing (go.tag = "json:\"int_thing,string\""), + 3: optional i64 optional_int_thing + 4: optional bool optional_bool_thing = false +} diff --git a/src/jaegertracing/thrift/lib/go/test/IgnoreInitialismsTest.thrift b/src/jaegertracing/thrift/lib/go/test/IgnoreInitialismsTest.thrift new file mode 100644 index 000000000..0b30e9ed5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/IgnoreInitialismsTest.thrift @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct IgnoreInitialismsTest { + 1: i64 id, + 2: i64 my_id, + 3: i64 num_cpu, + 4: i64 num_gpu, + 5: i64 my_ID, +} diff --git a/src/jaegertracing/thrift/lib/go/test/IncludesTest.thrift b/src/jaegertracing/thrift/lib/go/test/IncludesTest.thrift new file mode 100644 index 000000000..3f60321f7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/IncludesTest.thrift @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +include "ThriftTest.thrift" +include "NamespacedTest.thrift" + +const ThriftTest.UserId USERID = 42 +const NamespacedTest.UserId USERID1 = 41 +const ThriftTest.MapType MAPCONSTANT = {'hello':{}, 'goodnight':{}} + +const i32 TWO = NamespacedTest.Stuff.TWO +const i32 THREE = NamespacedTest.THREE + +struct testStruct { + 1: list<ThriftTest.Numberz> listNumbers +} + +struct TestStruct2 { + 1: testStruct blah, + 2: ThriftTest.UserId id, + 3: NamespacedTest.Stuff stuff, +} + +service testService extends ThriftTest.SecondService { + ThriftTest.CrazyNesting getCrazyNesting( + 1: ThriftTest.StructA a, + 2: ThriftTest.Numberz numbers + ) throws(1: ThriftTest.Xception err1), + + void getSomeValue_DO_NOT_CALL(), +} + +service ExtendedService extends testService { + void extendedMethod(), + NamespacedTest.StuffStruct extendedMethod2(), +} + +service Extended2Service extends NamespacedTest.NamespacedService { + void extendedMethod3(), +} + +typedef map<ThriftTest.UserId, map<NamespacedTest.UserId, list<TestStruct2> > > ComplexMapType + +struct ComplexMapStruct { + 1: ComplexMapType complex, +} + +service ComplexMapService { + ComplexMapStruct transformMap(1: ComplexMapStruct input), +} + diff --git a/src/jaegertracing/thrift/lib/go/test/InitialismsTest.thrift b/src/jaegertracing/thrift/lib/go/test/InitialismsTest.thrift new file mode 100644 index 000000000..ad86ac19d --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/InitialismsTest.thrift @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct InitialismsTest { + 1: string user_id, + 2: string server_url, + 3: string id, +} diff --git a/src/jaegertracing/thrift/lib/go/test/Makefile.am b/src/jaegertracing/thrift/lib/go/test/Makefile.am new file mode 100644 index 000000000..244ddff1f --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/Makefile.am @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFTARGS = -out gopath/src/ --gen go:thrift_import=thrift$(COMPILER_EXTRAFLAG) +THRIFTTEST = $(top_srcdir)/test/ThriftTest.thrift + +# Thrift for GO has problems with complex map keys: THRIFT-2063 +gopath: $(THRIFT) $(THRIFTTEST) \ + IncludesTest.thrift \ + NamespacedTest.thrift \ + MultiplexedProtocolTest.thrift \ + OnewayTest.thrift \ + OptionalFieldsTest.thrift \ + RequiredFieldTest.thrift \ + ServicesTest.thrift \ + GoTagTest.thrift \ + TypedefFieldTest.thrift \ + RefAnnotationFieldsTest.thrift \ + UnionDefaultValueTest.thrift \ + UnionBinaryTest.thrift \ + ErrorTest.thrift \ + NamesTest.thrift \ + InitialismsTest.thrift \ + DontExportRWTest.thrift \ + dontexportrwtest/compile_test.go \ + IgnoreInitialismsTest.thrift \ + ConflictNamespaceTestA.thrift \ + ConflictNamespaceTestB.thrift \ + ConflictNamespaceTestC.thrift \ + ConflictNamespaceTestD.thrift \ + ConflictNamespaceTestSuperThing.thrift \ + ConflictNamespaceServiceTest.thrift + mkdir -p gopath/src + grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift + $(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift + $(THRIFT) $(THRIFTARGS) BinaryKeyTest.thrift + $(THRIFT) $(THRIFTARGS) MultiplexedProtocolTest.thrift + $(THRIFT) $(THRIFTARGS) OnewayTest.thrift + $(THRIFT) $(THRIFTARGS) OptionalFieldsTest.thrift + $(THRIFT) $(THRIFTARGS) RequiredFieldTest.thrift + $(THRIFT) $(THRIFTARGS) ServicesTest.thrift + $(THRIFT) $(THRIFTARGS) GoTagTest.thrift + $(THRIFT) $(THRIFTARGS) TypedefFieldTest.thrift + $(THRIFT) $(THRIFTARGS) RefAnnotationFieldsTest.thrift + $(THRIFT) $(THRIFTARGS) UnionDefaultValueTest.thrift + $(THRIFT) $(THRIFTARGS) UnionBinaryTest.thrift + $(THRIFT) $(THRIFTARGS) ErrorTest.thrift + $(THRIFT) $(THRIFTARGS) NamesTest.thrift + $(THRIFT) $(THRIFTARGS) InitialismsTest.thrift + $(THRIFT) $(THRIFTARGS),read_write_private DontExportRWTest.thrift + $(THRIFT) $(THRIFTARGS),ignore_initialisms IgnoreInitialismsTest.thrift + $(THRIFT) $(THRIFTARGS) ConflictNamespaceTestA.thrift + $(THRIFT) $(THRIFTARGS) ConflictNamespaceTestB.thrift + $(THRIFT) $(THRIFTARGS) ConflictNamespaceTestC.thrift + $(THRIFT) $(THRIFTARGS) ConflictNamespaceTestD.thrift + $(THRIFT) $(THRIFTARGS) ConflictNamespaceTestSuperThing.thrift + $(THRIFT) $(THRIFTARGS) ConflictNamespaceServiceTest.thrift + GOPATH=`pwd`/gopath $(GO) get github.com/golang/mock/gomock || true + sed -i 's/\"context\"/\"golang.org\/x\/net\/context\"/g' gopath/src/github.com/golang/mock/gomock/controller.go || true + GOPATH=`pwd`/gopath $(GO) get github.com/golang/mock/gomock + ln -nfs ../../../thrift gopath/src/thrift + ln -nfs ../../tests gopath/src/tests + cp -r ./dontexportrwtest gopath/src + touch gopath + +check: gopath + GOPATH=`pwd`/gopath $(GO) build \ + includestest \ + binarykeytest \ + servicestest \ + typedeffieldtest \ + refannotationfieldstest \ + errortest \ + namestest \ + initialismstest \ + dontexportrwtest \ + ignoreinitialismstest \ + unionbinarytest \ + conflictnamespacetestsuperthing \ + conflict/context/conflict_service-remote + GOPATH=`pwd`/gopath $(GO) test thrift tests dontexportrwtest + +clean-local: + $(RM) -r gopath ThriftTest.thrift gen-go + +client: stubs + $(GO) run TestClient.go + +EXTRA_DIST = \ + dontexportrwtest \ + tests \ + BinaryKeyTest.thrift \ + GoTagTest.thrift \ + IncludesTest.thrift \ + MultiplexedProtocolTest.thrift \ + NamespacedTest.thrift \ + OnewayTest.thrift \ + OptionalFieldsTest.thrift \ + RequiredFieldTest.thrift \ + RefAnnotationFieldsTest.thrift \ + UnionDefaultValueTest.thrift \ + UnionBinaryTest.thrift \ + ServicesTest.thrift \ + TypedefFieldTest.thrift \ + ErrorTest.thrift \ + NamesTest.thrift \ + InitialismsTest.thrift \ + DontExportRWTest.thrift \ + IgnoreInitialismsTest.thrift \ + ConflictNamespaceTestA.thrift \ + ConflictNamespaceTestB.thrift \ + ConflictNamespaceTestC.thrift \ + ConflictNamespaceTestD.thrift \ + ConflictNamespaceTestSuperThing.thrift + ConflictNamespaceServiceTest.thrift diff --git a/src/jaegertracing/thrift/lib/go/test/MultiplexedProtocolTest.thrift b/src/jaegertracing/thrift/lib/go/test/MultiplexedProtocolTest.thrift new file mode 100644 index 000000000..b263f5925 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/MultiplexedProtocolTest.thrift @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +service First { + i64 returnOne(); +} + +service Second { + i64 returnTwo(); +} + diff --git a/src/jaegertracing/thrift/lib/go/test/NamesTest.thrift b/src/jaegertracing/thrift/lib/go/test/NamesTest.thrift new file mode 100644 index 000000000..e7e9563e8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/NamesTest.thrift @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct NamesTest { + 1: required string type +} + +service NameCollisionOne +{ + void blahBlah() +} + +service NameCollisionTwo +{ + void blahBlah() +} diff --git a/src/jaegertracing/thrift/lib/go/test/NamespacedTest.thrift b/src/jaegertracing/thrift/lib/go/test/NamespacedTest.thrift new file mode 100644 index 000000000..a91035062 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/NamespacedTest.thrift @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +include "ThriftTest.thrift" + +namespace go lib.go.test.namespacedtest + +enum Stuff { + ONE = 1, + TWO = 2, +} + +const i32 THREE = 3; + +typedef i64 UserId + +struct StuffStruct { + 2: Stuff stuff, +} + +service NamespacedService { + ThriftTest.UserId getUserID(), +} + diff --git a/src/jaegertracing/thrift/lib/go/test/OnewayTest.thrift b/src/jaegertracing/thrift/lib/go/test/OnewayTest.thrift new file mode 100644 index 000000000..3242f80fd --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/OnewayTest.thrift @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +service OneWay { + oneway void hi(1: i64 i, 2: string s) + void emptyfunc() + i64 echo_int(1: i64 param) +} diff --git a/src/jaegertracing/thrift/lib/go/test/OptionalFieldsTest.thrift b/src/jaegertracing/thrift/lib/go/test/OptionalFieldsTest.thrift new file mode 100644 index 000000000..2afc157c7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/OptionalFieldsTest.thrift @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct structA { + 1: required i64 sa_i +} + +struct all_optional { + 1: optional string s = "DEFAULT", + 2: optional i64 i = 42, + 3: optional bool b = false, + 4: optional string s2, + 5: optional i64 i2, + 6: optional bool b2, + 7: optional structA aa, + 9: optional list<i64> l, + 10: optional list<i64> l2 = [1, 2], + 11: optional map<i64, i64> m, + 12: optional map<i64, i64> m2 = {1:2, 3:4}, + 13: optional binary bin, + 14: optional binary bin2 = "asdf", +} + +struct structB { + 1: required structA required_struct_thing + 2: optional structA optional_struct_thing +} + +struct structC { + 1: string s, + 2: required i32 i, + 3: optional bool b, + 4: required string s2, +} diff --git a/src/jaegertracing/thrift/lib/go/test/RefAnnotationFieldsTest.thrift b/src/jaegertracing/thrift/lib/go/test/RefAnnotationFieldsTest.thrift new file mode 100644 index 000000000..b7f28c28f --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/RefAnnotationFieldsTest.thrift @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct structA { + 1: required i64 sa_i +} + +struct all_referenced { + 1: optional string s = "DEFAULT" (cpp.ref = ""), + 2: optional i64 i = 42 (cpp.ref = ""), + 3: optional bool b = false (cpp.ref = ""), + 4: optional string s2 (cpp.ref = ""), + 5: optional i64 i2 (cpp.ref = ""), + 6: optional bool b2 (cpp.ref = ""), + 7: optional structA aa (cpp.ref = ""), + 9: optional list<i64> l (cpp.ref = ""), + 10: optional list<i64> l2 = [1, 2] (cpp.ref = ""), + 11: optional map<i64, i64> m (cpp.ref = ""), + 12: optional map<i64, i64> m2 = {1:2, 3:4} (cpp.ref = ""), + 13: optional binary bin (cpp.ref = ""), + 14: optional binary bin2 = "asdf" (cpp.ref = ""), + + 15: required string ref_s = "DEFAULT" (cpp.ref = ""), + 16: required i64 ref_i = 42 (cpp.ref = ""), + 17: required bool ref_b = false (cpp.ref = ""), + 18: required string ref_s2 (cpp.ref = ""), + 19: required i64 ref_i2 (cpp.ref = ""), + 20: required bool ref_b2 (cpp.ref = ""), + 21: required structA ref_aa (cpp.ref = ""), + 22: required list<i64> ref_l (cpp.ref = ""), + 23: required list<i64> ref_l2 = [1, 2] (cpp.ref = ""), + 24: required map<i64, i64> ref_m (cpp.ref = ""), + 25: required map<i64, i64> ref_m2 = {1:2, 3:4} (cpp.ref = ""), + 26: required binary ref_bin (cpp.ref = ""), + 27: required binary ref_bin2 = "asdf" (cpp.ref = ""), + +} + +struct structB { + 1: required structA required_struct_thing + 2: optional structA optional_struct_thing +} diff --git a/src/jaegertracing/thrift/lib/go/test/RequiredFieldTest.thrift b/src/jaegertracing/thrift/lib/go/test/RequiredFieldTest.thrift new file mode 100644 index 000000000..4a2dcaeba --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/RequiredFieldTest.thrift @@ -0,0 +1,7 @@ +struct RequiredField { + 1: required string name +} + +struct OtherThing { + 1: required i16 value +} diff --git a/src/jaegertracing/thrift/lib/go/test/ServicesTest.thrift b/src/jaegertracing/thrift/lib/go/test/ServicesTest.thrift new file mode 100644 index 000000000..882b03acb --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/ServicesTest.thrift @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# We are only testing that generated code compiles, no correctness checking is done + +exception moderate_disaster { + 1: i32 errorCode, + 2: string message +} + +exception total_disaster { + 1: string message + 2: optional bool president_was_woken_up = false +} + +struct struct_a { + 1: required i64 whatever +} + +service a_serv { + void voidfunc(), + void void_with_1ex() throws(1: moderate_disaster err1) + void void_with_2ex() throws(1: moderate_disaster err1, 2:total_disaster err2) + + string stringfunc() + string stringfunc_1ex() throws(1: moderate_disaster err1) + string stringfunc_2ex() throws(1: moderate_disaster err1, 2:total_disaster err2) + + i64 i64func() + i64 i64func_1ex() throws(1: moderate_disaster err1) + i64 i64func_2ex() throws(1: moderate_disaster err1, 2:total_disaster err2) + + list<string> list_of_strings_func() + list<string> list_of_strings_func_1ex() throws(1: moderate_disaster err1) + list<string> list_of_strings_func_2ex() throws(1: moderate_disaster err1, 2:total_disaster err2) + + map<i64,string> map_func() + map<i64,string> map_func_1ex() throws(1: moderate_disaster err1) + map<i64,string> map_func_2ex() throws(1: moderate_disaster err1, 2:total_disaster err2) + + struct_a struct_a_func() + struct_a struct_a_func_1ex() throws(1: moderate_disaster err1) + struct_a struct_a_func_2ex() throws(1: moderate_disaster err1, 2:total_disaster err2) + + void voidfunc_1int(1: i64 i), + void void_with_1ex_1int(1: i64 i) throws(1: moderate_disaster err1) + void void_with_2ex_1int(1: i64 i) throws(1: moderate_disaster err1, 2:total_disaster err2) + + string stringfunc_1int(1: i64 i) + string stringfunc_1ex_1int(1: i64 i) throws(1: moderate_disaster err1) + string stringfunc_2ex_1int(1: i64 i) throws(1: moderate_disaster err1, 2:total_disaster err2) + + i64 i64func_1int(1: i64 i) + i64 i64func_1ex_1int(1: i64 i) throws(1: moderate_disaster err1) + i64 i64func_2ex_1int(1: i64 i) throws(1: moderate_disaster err1, 2:total_disaster err2) + + list<string> list_of_strings_func_1int(1: i64 i) + list<string> list_of_strings_func_1ex_1int(1: i64 i) throws(1: moderate_disaster err1) + list<string> list_of_strings_func_2ex_1int(1: i64 i) throws(1: moderate_disaster err1, 2:total_disaster err2) + + map<i64,string> map_func_1int(1: i64 i) + map<i64,string> map_func_1ex_1int(1: i64 i) throws(1: moderate_disaster err1) + map<i64,string> map_func_2ex_1int(1: i64 i) throws(1: moderate_disaster err1, 2:total_disaster err2) + + struct_a struct_a_func_1int(1: i64 i) + struct_a struct_a_func_1ex_1int(1: i64 i) throws(1: moderate_disaster err1) + struct_a struct_a_func_2ex_1int(1: i64 i) throws(1: moderate_disaster err1, 2:total_disaster err2) + + void voidfunc_1int_1s(1: i64 i, 2: string s), + void void_with_1ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1) + void void_with_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2) + + string stringfunc_1int_1s(1: i64 i, 2: string s) + string stringfunc_1ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1) + string stringfunc_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2) + + i64 i64func_1int_1s(1: i64 i, 2: string s) + i64 i64func_1ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1) + i64 i64func_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2) + + list<string> list_of_strings_func_1int_1s(1: i64 i, 2: string s) + list<string> list_of_strings_func_1ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1) + list<string> list_of_strings_func_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2) + + map<i64,string> map_func_1int_1s(1: i64 i, 2: string s) + map<i64,string> map_func_1ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1) + map<i64,string> map_func_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2) + + struct_a struct_a_func_1int_1s(1: i64 i, 2: string s) + struct_a struct_a_func_1ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1) + struct_a struct_a_func_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2) + + struct_a struct_a_func_1struct_a(1: struct_a st) + +} diff --git a/src/jaegertracing/thrift/lib/go/test/TypedefFieldTest.thrift b/src/jaegertracing/thrift/lib/go/test/TypedefFieldTest.thrift new file mode 100644 index 000000000..390e8c88e --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/TypedefFieldTest.thrift @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# We are only testing that generated code compiles, no correctness checking is done + +enum Details { + Everything = 0 + StateOnly = 1 + StateAndOptions = 2 + SomethingElse = 3 +} + +typedef list< Details> DetailsWanted + +struct BaseRequest { + 1 : optional string RequestID +} + +struct GetMyDetails { + 1 : required BaseRequest base_ + 2 : required string ObjectID + 3 : optional DetailsWanted DetailsWanted +} diff --git a/src/jaegertracing/thrift/lib/go/test/UnionBinaryTest.thrift b/src/jaegertracing/thrift/lib/go/test/UnionBinaryTest.thrift new file mode 100644 index 000000000..f77112bda --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/UnionBinaryTest.thrift @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# See https://issues.apache.org/jira/browse/THRIFT-4573 +union Sample { + 1: map<string, string> u1, + 2: binary u2, + 3: list<string> u3 +} diff --git a/src/jaegertracing/thrift/lib/go/test/UnionDefaultValueTest.thrift b/src/jaegertracing/thrift/lib/go/test/UnionDefaultValueTest.thrift new file mode 100644 index 000000000..4c9348003 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/UnionDefaultValueTest.thrift @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +struct Option1 { +} + +struct Option2 { + 1: optional string name +} + +union Descendant { + 1: Option1 option1 + 2: Option2 option2 +} + +struct TestStruct { + 1: optional Descendant descendant = { "option1": {}} +} diff --git a/src/jaegertracing/thrift/lib/go/test/dontexportrwtest/compile_test.go b/src/jaegertracing/thrift/lib/go/test/dontexportrwtest/compile_test.go new file mode 100644 index 000000000..cf6763e29 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/dontexportrwtest/compile_test.go @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package dontexportrwtest + +import ( + "testing" +) + +// Make sure that thrift generates non-exported read/write methods if +// read_write_private option is specified +func TestReadWriteMethodsArePrivate(t *testing.T) { + // This will only compile if read/write methods exist + s := NewTestStruct() + _ = s.read + _ = s.write + + is := NewInnerStruct() + _ = is.read + _ = is.write +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/binary_key_test.go b/src/jaegertracing/thrift/lib/go/test/tests/binary_key_test.go new file mode 100644 index 000000000..aa9619391 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/binary_key_test.go @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "binarykeytest" + "testing" +) + +func TestBinaryMapKeyGeneratesString(t *testing.T) { + s := binarykeytest.NewTestStruct() + //This will only compile if BinToString has type of map[string]string + s.BinToString = make(map[string]string) +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/client_error_test.go b/src/jaegertracing/thrift/lib/go/test/tests/client_error_test.go new file mode 100644 index 000000000..fdec4ea57 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/client_error_test.go @@ -0,0 +1,873 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "errors" + "errortest" + "testing" + "thrift" + + "github.com/golang/mock/gomock" +) + +// TestCase: Comprehensive call and reply workflow in the client. +// Setup mock to fail at a certain position. Return true if position exists otherwise false. +func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error) bool { + var err error = nil + + if failAt == 0 { + err = failWith + } + last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.CALL, int32(1)).Return(err) + if failAt == 0 { + return true + } + if failAt == 1 { + err = failWith + } + last = protocol.EXPECT().WriteStructBegin("testStruct_args").Return(err).After(last) + if failAt == 1 { + return true + } + if failAt == 2 { + err = failWith + } + last = protocol.EXPECT().WriteFieldBegin("thing", thrift.TType(thrift.STRUCT), int16(1)).Return(err).After(last) + if failAt == 2 { + return true + } + if failAt == 3 { + err = failWith + } + last = protocol.EXPECT().WriteStructBegin("TestStruct").Return(err).After(last) + if failAt == 3 { + return true + } + if failAt == 4 { + err = failWith + } + last = protocol.EXPECT().WriteFieldBegin("m", thrift.TType(thrift.MAP), int16(1)).Return(err).After(last) + if failAt == 4 { + return true + } + if failAt == 5 { + err = failWith + } + last = protocol.EXPECT().WriteMapBegin(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0).Return(err).After(last) + if failAt == 5 { + return true + } + if failAt == 6 { + err = failWith + } + last = protocol.EXPECT().WriteMapEnd().Return(err).After(last) + if failAt == 6 { + return true + } + if failAt == 7 { + err = failWith + } + last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last) + if failAt == 7 { + return true + } + if failAt == 8 { + err = failWith + } + last = protocol.EXPECT().WriteFieldBegin("l", thrift.TType(thrift.LIST), int16(2)).Return(err).After(last) + if failAt == 8 { + return true + } + if failAt == 9 { + err = failWith + } + last = protocol.EXPECT().WriteListBegin(thrift.TType(thrift.STRING), 0).Return(err).After(last) + if failAt == 9 { + return true + } + if failAt == 10 { + err = failWith + } + last = protocol.EXPECT().WriteListEnd().Return(err).After(last) + if failAt == 10 { + return true + } + if failAt == 11 { + err = failWith + } + last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last) + if failAt == 11 { + return true + } + if failAt == 12 { + err = failWith + } + + last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.SET), int16(3)).Return(err).After(last) + if failAt == 12 { + return true + } + if failAt == 13 { + err = failWith + } + last = protocol.EXPECT().WriteSetBegin(thrift.TType(thrift.STRING), 0).Return(err).After(last) + if failAt == 13 { + return true + } + if failAt == 14 { + err = failWith + } + last = protocol.EXPECT().WriteSetEnd().Return(err).After(last) + if failAt == 14 { + return true + } + if failAt == 15 { + err = failWith + } + last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last) + if failAt == 15 { + return true + } + if failAt == 16 { + err = failWith + } + last = protocol.EXPECT().WriteFieldBegin("i", thrift.TType(thrift.I32), int16(4)).Return(err).After(last) + if failAt == 16 { + return true + } + if failAt == 17 { + err = failWith + } + last = protocol.EXPECT().WriteI32(int32(3)).Return(err).After(last) + if failAt == 17 { + return true + } + if failAt == 18 { + err = failWith + } + last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last) + if failAt == 18 { + return true + } + if failAt == 19 { + err = failWith + } + last = protocol.EXPECT().WriteFieldStop().Return(err).After(last) + if failAt == 19 { + return true + } + if failAt == 20 { + err = failWith + } + last = protocol.EXPECT().WriteStructEnd().Return(err).After(last) + if failAt == 20 { + return true + } + if failAt == 21 { + err = failWith + } + last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last) + if failAt == 21 { + return true + } + if failAt == 22 { + err = failWith + } + last = protocol.EXPECT().WriteFieldStop().Return(err).After(last) + if failAt == 22 { + return true + } + if failAt == 23 { + err = failWith + } + last = protocol.EXPECT().WriteStructEnd().Return(err).After(last) + if failAt == 23 { + return true + } + if failAt == 24 { + err = failWith + } + last = protocol.EXPECT().WriteMessageEnd().Return(err).After(last) + if failAt == 24 { + return true + } + if failAt == 25 { + err = failWith + } + last = protocol.EXPECT().Flush(context.Background()).Return(err).After(last) + if failAt == 25 { + return true + } + if failAt == 26 { + err = failWith + } + last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.REPLY, int32(1), err).After(last) + if failAt == 26 { + return true + } + if failAt == 27 { + err = failWith + } + last = protocol.EXPECT().ReadStructBegin().Return("testStruct_args", err).After(last) + if failAt == 27 { + return true + } + if failAt == 28 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STRUCT), int16(0), err).After(last) + if failAt == 28 { + return true + } + if failAt == 29 { + err = failWith + } + last = protocol.EXPECT().ReadStructBegin().Return("TestStruct", err).After(last) + if failAt == 29 { + return true + } + if failAt == 30 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("m", thrift.TType(thrift.MAP), int16(1), err).After(last) + if failAt == 30 { + return true + } + if failAt == 31 { + err = failWith + } + last = protocol.EXPECT().ReadMapBegin().Return(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0, err).After(last) + if failAt == 31 { + return true + } + if failAt == 32 { + err = failWith + } + last = protocol.EXPECT().ReadMapEnd().Return(err).After(last) + if failAt == 32 { + return true + } + if failAt == 33 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 33 { + return true + } + if failAt == 34 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("l", thrift.TType(thrift.LIST), int16(2), err).After(last) + if failAt == 34 { + return true + } + if failAt == 35 { + err = failWith + } + last = protocol.EXPECT().ReadListBegin().Return(thrift.TType(thrift.STRING), 0, err).After(last) + if failAt == 35 { + return true + } + if failAt == 36 { + err = failWith + } + last = protocol.EXPECT().ReadListEnd().Return(err).After(last) + if failAt == 36 { + return true + } + if failAt == 37 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 37 { + return true + } + if failAt == 38 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("s", thrift.TType(thrift.SET), int16(3), err).After(last) + if failAt == 38 { + return true + } + if failAt == 39 { + err = failWith + } + last = protocol.EXPECT().ReadSetBegin().Return(thrift.TType(thrift.STRING), 0, err).After(last) + if failAt == 39 { + return true + } + if failAt == 40 { + err = failWith + } + last = protocol.EXPECT().ReadSetEnd().Return(err).After(last) + if failAt == 40 { + return true + } + if failAt == 41 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 41 { + return true + } + if failAt == 42 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(4), err).After(last) + if failAt == 42 { + return true + } + if failAt == 43 { + err = failWith + } + last = protocol.EXPECT().ReadI32().Return(int32(3), err).After(last) + if failAt == 43 { + return true + } + if failAt == 44 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 44 { + return true + } + if failAt == 45 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(5), err).After(last) + if failAt == 45 { + return true + } + if failAt == 46 { + err = failWith + } + last = protocol.EXPECT().ReadStructEnd().Return(err).After(last) + if failAt == 46 { + return true + } + if failAt == 47 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 47 { + return true + } + if failAt == 48 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), err).After(last) + if failAt == 48 { + return true + } + if failAt == 49 { + err = failWith + } + last = protocol.EXPECT().ReadStructEnd().Return(err).After(last) + if failAt == 49 { + return true + } + if failAt == 50 { + err = failWith + } + last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last) + if failAt == 50 { + return true + } + return false +} + +// TestCase: Comprehensive call and reply workflow in the client. +// Expecting TTransportError on fail. +func TestClientReportTTransportErrors(t *testing.T) { + mockCtrl := gomock.NewController(t) + + thing := errortest.NewTestStruct() + thing.M = make(map[string]string) + thing.L = make([]string, 0) + thing.S = make([]string, 0) + thing.I = 3 + + err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + if !prepareClientCallReply(protocol, i, err) { + return + } + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, retErr := client.TestStruct(defaultCtx, thing) + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + err2, ok := retErr.(thrift.TTransportException) + if !ok { + t.Fatal("Expected a TTrasportException") + } + + if err2.TypeId() != thrift.TIMED_OUT { + t.Fatal("Expected TIMED_OUT error") + } + } +} + +// TestCase: Comprehensive call and reply workflow in the client. +// Expecting TTransportError on fail. +// Similar to TestClientReportTTransportErrors, but using legacy client constructor. +func TestClientReportTTransportErrorsLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + thing := errortest.NewTestStruct() + thing.M = make(map[string]string) + thing.L = make([]string, 0) + thing.S = make([]string, 0) + thing.I = 3 + + err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + if !prepareClientCallReply(protocol, i, err) { + return + } + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) + _, retErr := client.TestStruct(defaultCtx, thing) + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + err2, ok := retErr.(thrift.TTransportException) + if !ok { + t.Fatal("Expected a TTrasportException") + } + + if err2.TypeId() != thrift.TIMED_OUT { + t.Fatal("Expected TIMED_OUT error") + } + } +} + +// TestCase: Comprehensive call and reply workflow in the client. +// Expecting TTProtocolErrors on fail. +func TestClientReportTProtocolErrors(t *testing.T) { + mockCtrl := gomock.NewController(t) + + thing := errortest.NewTestStruct() + thing.M = make(map[string]string) + thing.L = make([]string, 0) + thing.S = make([]string, 0) + thing.I = 3 + + err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test")) + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + if !prepareClientCallReply(protocol, i, err) { + return + } + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, retErr := client.TestStruct(defaultCtx, thing) + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + err2, ok := retErr.(thrift.TProtocolException) + if !ok { + t.Fatal("Expected a TProtocolException") + } + if err2.TypeId() != thrift.INVALID_DATA { + t.Fatal("Expected INVALID_DATA error") + } + } +} + +// TestCase: Comprehensive call and reply workflow in the client. +// Expecting TTProtocolErrors on fail. +// Similar to TestClientReportTProtocolErrors, but using legacy client constructor. +func TestClientReportTProtocolErrorsLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + thing := errortest.NewTestStruct() + thing.M = make(map[string]string) + thing.L = make([]string, 0) + thing.S = make([]string, 0) + thing.I = 3 + + err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test")) + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + if !prepareClientCallReply(protocol, i, err) { + return + } + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) + _, retErr := client.TestStruct(defaultCtx, thing) + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + err2, ok := retErr.(thrift.TProtocolException) + if !ok { + t.Fatal("Expected a TProtocolException") + } + if err2.TypeId() != thrift.INVALID_DATA { + t.Fatal("Expected INVALID_DATA error") + } + } +} + +// TestCase: call and reply with exception workflow in the client. +// Setup mock to fail at a certain position. Return true if position exists otherwise false. +func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith error) bool { + var err error = nil + + // No need to test failure in this block, because it is covered in other test cases + last := protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)) + last = protocol.EXPECT().WriteStructBegin("testString_args").After(last) + last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)).After(last) + last = protocol.EXPECT().WriteString("test").After(last) + last = protocol.EXPECT().WriteFieldEnd().After(last) + last = protocol.EXPECT().WriteFieldStop().After(last) + last = protocol.EXPECT().WriteStructEnd().After(last) + last = protocol.EXPECT().WriteMessageEnd().After(last) + last = protocol.EXPECT().Flush(context.Background()).After(last) + + // Reading the exception, might fail. + if failAt == 0 { + err = failWith + } + last = protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.EXCEPTION, int32(1), err).After(last) + if failAt == 0 { + return true + } + if failAt == 1 { + err = failWith + } + last = protocol.EXPECT().ReadStructBegin().Return("TApplicationException", err).After(last) + if failAt == 1 { + return true + } + if failAt == 2 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last) + if failAt == 2 { + return true + } + if failAt == 3 { + err = failWith + } + last = protocol.EXPECT().ReadString().Return("test", err).After(last) + if failAt == 3 { + return true + } + if failAt == 4 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 4 { + return true + } + if failAt == 5 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("type", thrift.TType(thrift.I32), int16(2), err).After(last) + if failAt == 5 { + return true + } + if failAt == 6 { + err = failWith + } + last = protocol.EXPECT().ReadI32().Return(int32(thrift.PROTOCOL_ERROR), err).After(last) + if failAt == 6 { + return true + } + if failAt == 7 { + err = failWith + } + last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last) + if failAt == 7 { + return true + } + if failAt == 8 { + err = failWith + } + last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last) + if failAt == 8 { + return true + } + if failAt == 9 { + err = failWith + } + last = protocol.EXPECT().ReadStructEnd().Return(err).After(last) + if failAt == 9 { + return true + } + if failAt == 10 { + err = failWith + } + last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last) + if failAt == 10 { + return true + } + + return false +} + +// TestCase: call and reply with exception workflow in the client. +func TestClientCallException(t *testing.T) { + mockCtrl := gomock.NewController(t) + + err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + willComplete := !prepareClientCallException(protocol, i, err) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, retErr := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + + if !willComplete { + err2, ok := retErr.(thrift.TTransportException) + if !ok { + t.Fatal("Expected a TTransportException") + } + if err2.TypeId() != thrift.TIMED_OUT { + t.Fatal("Expected TIMED_OUT error") + } + } else { + err2, ok := retErr.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected a TApplicationException") + } + if err2.TypeId() != thrift.PROTOCOL_ERROR { + t.Fatal("Expected PROTOCOL_ERROR error") + } + break + } + } +} + +// TestCase: call and reply with exception workflow in the client. +// Similar to TestClientCallException, but using legacy client constructor. +func TestClientCallExceptionLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + err := thrift.NewTTransportException(thrift.TIMED_OUT, "test") + for i := 0; ; i++ { + protocol := NewMockTProtocol(mockCtrl) + willComplete := !prepareClientCallException(protocol, i, err) + + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) + _, retErr := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + + if !willComplete { + err2, ok := retErr.(thrift.TTransportException) + if !ok { + t.Fatal("Expected a TTransportException") + } + if err2.TypeId() != thrift.TIMED_OUT { + t.Fatal("Expected TIMED_OUT error") + } + } else { + err2, ok := retErr.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected a TApplicationException") + } + if err2.TypeId() != thrift.PROTOCOL_ERROR { + t.Fatal("Expected PROTOCOL_ERROR error") + } + break + } + } +} + +// TestCase: Mismatching sequence id has been received in the client. +func TestClientSeqIdMismatch(t *testing.T) { + mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(context.Background()), + protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil), + ) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.BAD_SEQUENCE_ID { + t.Fatal("Expected BAD_SEQUENCE_ID error") + } +} + +// TestCase: Mismatching sequence id has been received in the client. +// Similar to TestClientSeqIdMismatch, but using legacy client constructor. +func TestClientSeqIdMismatchLegeacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(context.Background()), + protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil), + ) + + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.BAD_SEQUENCE_ID { + t.Fatal("Expected BAD_SEQUENCE_ID error") + } +} + +// TestCase: Wrong method name has been received in the client. +func TestClientWrongMethodName(t *testing.T) { + mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(context.Background()), + protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil), + ) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.WRONG_METHOD_NAME { + t.Fatal("Expected WRONG_METHOD_NAME error") + } +} + +// TestCase: Wrong method name has been received in the client. +// Similar to TestClientWrongMethodName, but using legacy client constructor. +func TestClientWrongMethodNameLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(context.Background()), + protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil), + ) + + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.WRONG_METHOD_NAME { + t.Fatal("Expected WRONG_METHOD_NAME error") + } +} + +// TestCase: Wrong message type has been received in the client. +func TestClientWrongMessageType(t *testing.T) { + mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(context.Background()), + protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil), + ) + + client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol)) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.INVALID_MESSAGE_TYPE_EXCEPTION { + t.Fatal("Expected INVALID_MESSAGE_TYPE_EXCEPTION error") + } +} + +// TestCase: Wrong message type has been received in the client. +// Similar to TestClientWrongMessageType, but using legacy client constructor. +func TestClientWrongMessageTypeLegacy(t *testing.T) { + mockCtrl := gomock.NewController(t) + transport := thrift.NewTMemoryBuffer() + protocol := NewMockTProtocol(mockCtrl) + gomock.InOrder( + protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)), + protocol.EXPECT().WriteStructBegin("testString_args"), + protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)), + protocol.EXPECT().WriteString("test"), + protocol.EXPECT().WriteFieldEnd(), + protocol.EXPECT().WriteFieldStop(), + protocol.EXPECT().WriteStructEnd(), + protocol.EXPECT().WriteMessageEnd(), + protocol.EXPECT().Flush(context.Background()), + protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil), + ) + + client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol) + _, err := client.TestString(defaultCtx, "test") + mockCtrl.Finish() + appErr, ok := err.(thrift.TApplicationException) + if !ok { + t.Fatal("Expected TApplicationException") + } + if appErr.TypeId() != thrift.INVALID_MESSAGE_TYPE_EXCEPTION { + t.Fatal("Expected INVALID_MESSAGE_TYPE_EXCEPTION error") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/context.go b/src/jaegertracing/thrift/lib/go/test/tests/context.go new file mode 100644 index 000000000..a93a82b8f --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/context.go @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" +) + +var defaultCtx = context.Background() diff --git a/src/jaegertracing/thrift/lib/go/test/tests/encoding_json_test.go b/src/jaegertracing/thrift/lib/go/test/tests/encoding_json_test.go new file mode 100644 index 000000000..12d4566fb --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/encoding_json_test.go @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "encoding" + "encoding/json" + "testing" + "thrifttest" +) + +func TestEnumIsTextMarshaller(t *testing.T) { + one := thrifttest.Numberz_ONE + var tm encoding.TextMarshaler = one + b, err := tm.MarshalText() + if err != nil { + t.Fatalf("Unexpected error from MarshalText: %s", err) + } + if string(b) != one.String() { + t.Errorf("MarshalText(%s) = %s, expected = %s", one, b, one) + } +} + +func TestEnumIsTextUnmarshaller(t *testing.T) { + var tm encoding.TextUnmarshaler = thrifttest.NumberzPtr(thrifttest.Numberz_TWO) + err := tm.UnmarshalText([]byte("TWO")) + if err != nil { + t.Fatalf("Unexpected error from UnmarshalText(TWO): %s", err) + } + if *(tm.(*thrifttest.Numberz)) != thrifttest.Numberz_TWO { + t.Errorf("UnmarshalText(TWO) = %s", tm) + } + + err = tm.UnmarshalText([]byte("NAN")) + if err == nil { + t.Errorf("Error from UnmarshalText(NAN)") + } +} + +func TestJSONMarshalUnmarshal(t *testing.T) { + s1 := thrifttest.StructB{ + Aa: &thrifttest.StructA{S: "Aa"}, + Ab: &thrifttest.StructA{S: "Ab"}, + } + + b, err := json.Marshal(s1) + if err != nil { + t.Fatalf("Unexpected error from json.Marshal: %s", err) + } + + s2 := thrifttest.StructB{} + err = json.Unmarshal(b, &s2) + if err != nil { + t.Fatalf("Unexpected error from json.Unmarshal: %s", err) + } + + if *s1.Aa != *s2.Aa || *s1.Ab != *s2.Ab { + t.Logf("s1 = %+v", s1) + t.Logf("s2 = %+v", s2) + t.Errorf("json: Unmarshal(Marshal(s)) != s") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/gotag_test.go b/src/jaegertracing/thrift/lib/go/test/tests/gotag_test.go new file mode 100644 index 000000000..ff2f14ecf --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/gotag_test.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "gotagtest" + "reflect" + "testing" +) + +func TestDefaultTag(t *testing.T) { + s := gotagtest.Tagged{} + st := reflect.TypeOf(s) + field, ok := st.FieldByName("StringThing") + if !ok || field.Tag.Get("json") != "string_thing" { + t.Error("Unexpected default tag value") + } +} + +func TestCustomTag(t *testing.T) { + s := gotagtest.Tagged{} + st := reflect.TypeOf(s) + field, ok := st.FieldByName("IntThing") + if !ok || field.Tag.Get("json") != "int_thing,string" { + t.Error("Unexpected custom tag value") + } +} + +func TestOptionalTag(t *testing.T) { + s := gotagtest.Tagged{} + st := reflect.TypeOf(s) + field, ok := st.FieldByName("OptionalIntThing") + if !ok || field.Tag.Get("json") != "optional_int_thing,omitempty" { + t.Error("Unexpected default tag value for optional field") + } +} + +func TestOptionalTagWithDefaultValue(t *testing.T) { + s := gotagtest.Tagged{} + st := reflect.TypeOf(s) + field, ok := st.FieldByName("OptionalBoolThing") + if !ok || field.Tag.Get("json") != "optional_bool_thing" { + t.Error("Unexpected default tag value for optional field that has a default value") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/ignoreinitialisms_test.go b/src/jaegertracing/thrift/lib/go/test/tests/ignoreinitialisms_test.go new file mode 100644 index 000000000..3cd5f6509 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/ignoreinitialisms_test.go @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "ignoreinitialismstest" + "reflect" + "testing" +) + +func TestIgnoreInitialismsFlagIsHonoured(t *testing.T) { + s := ignoreinitialismstest.IgnoreInitialismsTest{} + st := reflect.TypeOf(s) + _, ok := st.FieldByName("Id") + if !ok { + t.Error("Id attribute is missing!") + } + _, ok = st.FieldByName("MyId") + if !ok { + t.Error("MyId attribute is missing!") + } + _, ok = st.FieldByName("NumCpu") + if !ok { + t.Error("NumCpu attribute is missing!") + } + _, ok = st.FieldByName("NumGpu") + if !ok { + t.Error("NumGpu attribute is missing!") + } + _, ok = st.FieldByName("My_ID") + if !ok { + t.Error("My_ID attribute is missing!") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/initialisms_test.go b/src/jaegertracing/thrift/lib/go/test/tests/initialisms_test.go new file mode 100644 index 000000000..40923d24c --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/initialisms_test.go @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "initialismstest" + "reflect" + "testing" +) + +func TestThatCommonInitialismsAreFixed(t *testing.T) { + s := initialismstest.InitialismsTest{} + st := reflect.TypeOf(s) + _, ok := st.FieldByName("UserID") + if !ok { + t.Error("UserID attribute is missing!") + } + _, ok = st.FieldByName("ServerURL") + if !ok { + t.Error("ServerURL attribute is missing!") + } + _, ok = st.FieldByName("ID") + if !ok { + t.Error("ID attribute is missing!") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/multiplexed_protocol_test.go b/src/jaegertracing/thrift/lib/go/test/tests/multiplexed_protocol_test.go new file mode 100644 index 000000000..61ac62828 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/multiplexed_protocol_test.go @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "multiplexedprotocoltest" + "net" + "testing" + "thrift" + "time" +) + +func FindAvailableTCPServerPort() net.Addr { + if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil { + panic("Could not find available server port") + } else { + defer l.Close() + return l.Addr() + } +} + +type FirstImpl struct{} + +func (f *FirstImpl) ReturnOne(ctx context.Context) (r int64, err error) { + return 1, nil +} + +type SecondImpl struct{} + +func (s *SecondImpl) ReturnTwo(ctx context.Context) (r int64, err error) { + return 2, nil +} + +func createTransport(addr net.Addr) (thrift.TTransport, error) { + socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) + transport := thrift.NewTFramedTransport(socket) + err := transport.Open() + if err != nil { + return nil, err + } + return transport, nil +} + +func TestMultiplexedProtocolFirst(t *testing.T) { + processor := thrift.NewTMultiplexedProcessor() + protocolFactory := thrift.NewTBinaryProtocolFactoryDefault() + transportFactory := thrift.NewTTransportFactory() + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + addr := FindAvailableTCPServerPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) + if err != nil { + t.Fatal("Unable to create server socket", err) + } + server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) + + firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{}) + processor.RegisterProcessor("FirstService", firstProcessor) + + secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}) + processor.RegisterProcessor("SecondService", secondProcessor) + + defer server.Stop() + go server.Serve() + time.Sleep(10 * time.Millisecond) + + transport, err := createTransport(addr) + if err != nil { + t.Fatal(err) + } + defer transport.Close() + protocol := thrift.NewTMultiplexedProtocol(thrift.NewTBinaryProtocolTransport(transport), "FirstService") + + client := multiplexedprotocoltest.NewFirstClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err := client.ReturnOne(defaultCtx) + if err != nil { + t.Fatal("Unable to call first server:", err) + } else if ret != 1 { + t.Fatal("Unexpected result from server: ", ret) + } +} + +func TestMultiplexedProtocolSecond(t *testing.T) { + processor := thrift.NewTMultiplexedProcessor() + protocolFactory := thrift.NewTBinaryProtocolFactoryDefault() + transportFactory := thrift.NewTTransportFactory() + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + addr := FindAvailableTCPServerPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) + if err != nil { + t.Fatal("Unable to create server socket", err) + } + server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) + + firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{}) + processor.RegisterProcessor("FirstService", firstProcessor) + + secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}) + processor.RegisterProcessor("SecondService", secondProcessor) + + defer server.Stop() + go server.Serve() + time.Sleep(10 * time.Millisecond) + + transport, err := createTransport(addr) + if err != nil { + t.Fatal(err) + } + defer transport.Close() + protocol := thrift.NewTMultiplexedProtocol(thrift.NewTBinaryProtocolTransport(transport), "SecondService") + + client := multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err := client.ReturnTwo(defaultCtx) + if err != nil { + t.Fatal("Unable to call second server:", err) + } else if ret != 2 { + t.Fatal("Unexpected result from server: ", ret) + } +} + +func TestMultiplexedProtocolLegacy(t *testing.T) { + processor := thrift.NewTMultiplexedProcessor() + protocolFactory := thrift.NewTBinaryProtocolFactoryDefault() + transportFactory := thrift.NewTTransportFactory() + transportFactory = thrift.NewTFramedTransportFactory(transportFactory) + addr := FindAvailableTCPServerPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) + if err != nil { + t.Fatal("Unable to create server socket", err) + } + server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) + + firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{}) + processor.RegisterProcessor("FirstService", firstProcessor) + + secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}) + processor.RegisterProcessor("SecondService", secondProcessor) + + defer server.Stop() + go server.Serve() + time.Sleep(10 * time.Millisecond) + + transport, err := createTransport(addr) + if err != nil { + t.Error(err) + return + } + defer transport.Close() + + protocol := thrift.NewTBinaryProtocolTransport(transport) + client := multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err := client.ReturnTwo(defaultCtx) + //expect error since default processor is not registered + if err == nil { + t.Fatal("Expecting error") + } + + //register default processor and call again + processor.RegisterDefault(multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{})) + transport, err = createTransport(addr) + if err != nil { + t.Error(err) + return + } + defer transport.Close() + + protocol = thrift.NewTBinaryProtocolTransport(transport) + client = multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol)) + + ret, err = client.ReturnTwo(defaultCtx) + if err != nil { + t.Fatal("Unable to call legacy server:", err) + } + if ret != 2 { + t.Fatal("Unexpected result from server: ", ret) + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/names_test.go b/src/jaegertracing/thrift/lib/go/test/tests/names_test.go new file mode 100644 index 000000000..90b63a3b4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/names_test.go @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "namestest" + "reflect" + "testing" +) + +func TestThatAttributeNameSubstituionDoesNotOccur(t *testing.T) { + s := namestest.NamesTest{} + st := reflect.TypeOf(s) + _, ok := st.FieldByName("Type") + if !ok { + t.Error("Type attribute is missing!") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/one_way_test.go b/src/jaegertracing/thrift/lib/go/test/tests/one_way_test.go new file mode 100644 index 000000000..48d0bbe38 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/one_way_test.go @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "fmt" + "net" + "onewaytest" + "testing" + "thrift" + "time" +) + +func findPort() net.Addr { + if l, err := net.Listen("tcp", "127.0.0.1:0"); err != nil { + panic("Could not find available server port") + } else { + defer l.Close() + return l.Addr() + } +} + +type impl struct{} + +func (i *impl) Hi(ctx context.Context, in int64, s string) (err error) { fmt.Println("Hi!"); return } +func (i *impl) Emptyfunc(ctx context.Context) (err error) { return } +func (i *impl) EchoInt(ctx context.Context, param int64) (r int64, err error) { return param, nil } + +const TIMEOUT = time.Second + +var addr net.Addr +var server *thrift.TSimpleServer +var client *onewaytest.OneWayClient + +func TestInitOneway(t *testing.T) { + var err error + addr = findPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) + if err != nil { + t.Fatal("Unable to create server socket", err) + } + processor := onewaytest.NewOneWayProcessor(&impl{}) + server = thrift.NewTSimpleServer2(processor, serverTransport) + + go server.Serve() + time.Sleep(10 * time.Millisecond) +} + +func TestInitOnewayClient(t *testing.T) { + transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) + protocol := thrift.NewTBinaryProtocolTransport(transport) + client = onewaytest.NewOneWayClient(thrift.NewTStandardClient(protocol, protocol)) + err := transport.Open() + if err != nil { + t.Fatal("Unable to open client socket", err) + } +} + +func TestCallOnewayServer(t *testing.T) { + //call oneway function + err := client.Hi(defaultCtx, 1, "") + if err != nil { + t.Fatal("Unexpected error: ", err) + } + //There is no way to detect protocol problems with single oneway call so we call it second time + i, err := client.EchoInt(defaultCtx, 42) + if err != nil { + t.Fatal("Unexpected error: ", err) + } + if i != 42 { + t.Fatal("Unexpected returned value: ", i) + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/optional_fields_test.go b/src/jaegertracing/thrift/lib/go/test/tests/optional_fields_test.go new file mode 100644 index 000000000..34ad6605a --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/optional_fields_test.go @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "bytes" + gomock "github.com/golang/mock/gomock" + "optionalfieldstest" + "testing" + "thrift" +) + +func TestIsSetReturnFalseOnCreation(t *testing.T) { + ao := optionalfieldstest.NewAllOptional() + if ao.IsSetS() { + t.Errorf("Optional field S is set on initialization") + } + if ao.IsSetI() { + t.Errorf("Optional field I is set on initialization") + } + if ao.IsSetB() { + t.Errorf("Optional field B is set on initialization") + } + if ao.IsSetS2() { + t.Errorf("Optional field S2 is set on initialization") + } + if ao.IsSetI2() { + t.Errorf("Optional field I2 is set on initialization") + } + if ao.IsSetB2() { + t.Errorf("Optional field B2 is set on initialization") + } + if ao.IsSetAa() { + t.Errorf("Optional field Aa is set on initialization") + } + if ao.IsSetL() { + t.Errorf("Optional field L is set on initialization") + } + if ao.IsSetL2() { + t.Errorf("Optional field L2 is set on initialization") + } + if ao.IsSetM() { + t.Errorf("Optional field M is set on initialization") + } + if ao.IsSetM2() { + t.Errorf("Optional field M2 is set on initialization") + } + if ao.IsSetBin() { + t.Errorf("Optional field Bin is set on initialization") + } + if ao.IsSetBin2() { + t.Errorf("Optional field Bin2 is set on initialization") + } +} + +func TestDefaultValuesOnCreation(t *testing.T) { + ao := optionalfieldstest.NewAllOptional() + if ao.GetS() != "DEFAULT" { + t.Errorf("Unexpected default value %#v for field S", ao.GetS()) + } + if ao.GetI() != 42 { + t.Errorf("Unexpected default value %#v for field I", ao.GetI()) + } + if ao.GetB() != false { + t.Errorf("Unexpected default value %#v for field B", ao.GetB()) + } + if ao.GetS2() != "" { + t.Errorf("Unexpected default value %#v for field S2", ao.GetS2()) + } + if ao.GetI2() != 0 { + t.Errorf("Unexpected default value %#v for field I2", ao.GetI2()) + } + if ao.GetB2() != false { + t.Errorf("Unexpected default value %#v for field B2", ao.GetB2()) + } + if l := ao.GetL(); len(l) != 0 { + t.Errorf("Unexpected default value %#v for field L", l) + } + if l := ao.GetL2(); len(l) != 2 || l[0] != 1 || l[1] != 2 { + t.Errorf("Unexpected default value %#v for field L2", l) + } + //FIXME: should we return empty map here? + if m := ao.GetM(); m != nil { + t.Errorf("Unexpected default value %#v for field M", m) + } + if m := ao.GetM2(); len(m) != 2 || m[1] != 2 || m[3] != 4 { + t.Errorf("Unexpected default value %#v for field M2", m) + } + if bv := ao.GetBin(); bv != nil { + t.Errorf("Unexpected default value %#v for field Bin", bv) + } + if bv := ao.GetBin2(); !bytes.Equal(bv, []byte("asdf")) { + t.Errorf("Unexpected default value %#v for field Bin2", bv) + } +} + +func TestInitialValuesOnCreation(t *testing.T) { + ao := optionalfieldstest.NewAllOptional() + if ao.S != "DEFAULT" { + t.Errorf("Unexpected initial value %#v for field S", ao.S) + } + if ao.I != 42 { + t.Errorf("Unexpected initial value %#v for field I", ao.I) + } + if ao.B != false { + t.Errorf("Unexpected initial value %#v for field B", ao.B) + } + if ao.S2 != nil { + t.Errorf("Unexpected initial value %#v for field S2", ao.S2) + } + if ao.I2 != nil { + t.Errorf("Unexpected initial value %#v for field I2", ao.I2) + } + if ao.B2 != nil { + t.Errorf("Unexpected initial value %#v for field B2", ao.B2) + } + if ao.L != nil || len(ao.L) != 0 { + t.Errorf("Unexpected initial value %#v for field L", ao.L) + } + if ao.L2 != nil { + t.Errorf("Unexpected initial value %#v for field L2", ao.L2) + } + if ao.M != nil { + t.Errorf("Unexpected initial value %#v for field M", ao.M) + } + if ao.M2 != nil { + t.Errorf("Unexpected initial value %#v for field M2", ao.M2) + } + if ao.Bin != nil || len(ao.Bin) != 0 { + t.Errorf("Unexpected initial value %#v for field Bin", ao.Bin) + } + if !bytes.Equal(ao.Bin2, []byte("asdf")) { + t.Errorf("Unexpected initial value %#v for field Bin2", ao.Bin2) + } +} + +func TestIsSetReturnTrueAfterUpdate(t *testing.T) { + ao := optionalfieldstest.NewAllOptional() + ao.S = "somevalue" + ao.I = 123 + ao.B = true + ao.Aa = optionalfieldstest.NewStructA() + if !ao.IsSetS() { + t.Errorf("Field S should be set") + } + if !ao.IsSetI() { + t.Errorf("Field I should be set") + } + if !ao.IsSetB() { + t.Errorf("Field B should be set") + } + if !ao.IsSetAa() { + t.Errorf("Field aa should be set") + } +} + +func TestListNotEmpty(t *testing.T) { + ao := optionalfieldstest.NewAllOptional() + ao.L = []int64{1, 2, 3} + if !ao.IsSetL() { + t.Errorf("Field L should be set") + } +} + +//Make sure that optional fields are not being serialized +func TestNoOptionalUnsetFieldsOnWire(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + proto := NewMockTProtocol(mockCtrl) + gomock.InOrder( + proto.EXPECT().WriteStructBegin("all_optional").Return(nil), + proto.EXPECT().WriteFieldStop().Return(nil), + proto.EXPECT().WriteStructEnd().Return(nil), + ) + ao := optionalfieldstest.NewAllOptional() + ao.Write(proto) +} + +func TestNoSetToDefaultFieldsOnWire(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + proto := NewMockTProtocol(mockCtrl) + gomock.InOrder( + proto.EXPECT().WriteStructBegin("all_optional").Return(nil), + proto.EXPECT().WriteFieldStop().Return(nil), + proto.EXPECT().WriteStructEnd().Return(nil), + ) + ao := optionalfieldstest.NewAllOptional() + ao.I = 42 + ao.Write(proto) +} + +//Make sure that only one field is being serialized when set to non-default +func TestOneISetFieldOnWire(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + proto := NewMockTProtocol(mockCtrl) + gomock.InOrder( + proto.EXPECT().WriteStructBegin("all_optional").Return(nil), + proto.EXPECT().WriteFieldBegin("i", thrift.TType(thrift.I64), int16(2)).Return(nil), + proto.EXPECT().WriteI64(int64(123)).Return(nil), + proto.EXPECT().WriteFieldEnd().Return(nil), + proto.EXPECT().WriteFieldStop().Return(nil), + proto.EXPECT().WriteStructEnd().Return(nil), + ) + ao := optionalfieldstest.NewAllOptional() + ao.I = 123 + ao.Write(proto) +} + +func TestOneLSetFieldOnWire(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + proto := NewMockTProtocol(mockCtrl) + gomock.InOrder( + proto.EXPECT().WriteStructBegin("all_optional").Return(nil), + proto.EXPECT().WriteFieldBegin("l", thrift.TType(thrift.LIST), int16(9)).Return(nil), + proto.EXPECT().WriteListBegin(thrift.TType(thrift.I64), 2).Return(nil), + proto.EXPECT().WriteI64(int64(1)).Return(nil), + proto.EXPECT().WriteI64(int64(2)).Return(nil), + proto.EXPECT().WriteListEnd().Return(nil), + proto.EXPECT().WriteFieldEnd().Return(nil), + proto.EXPECT().WriteFieldStop().Return(nil), + proto.EXPECT().WriteStructEnd().Return(nil), + ) + ao := optionalfieldstest.NewAllOptional() + ao.L = []int64{1, 2} + ao.Write(proto) +} + +func TestOneBinSetFieldOnWire(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + proto := NewMockTProtocol(mockCtrl) + gomock.InOrder( + proto.EXPECT().WriteStructBegin("all_optional").Return(nil), + proto.EXPECT().WriteFieldBegin("bin", thrift.TType(thrift.STRING), int16(13)).Return(nil), + proto.EXPECT().WriteBinary([]byte("somebytestring")).Return(nil), + proto.EXPECT().WriteFieldEnd().Return(nil), + proto.EXPECT().WriteFieldStop().Return(nil), + proto.EXPECT().WriteStructEnd().Return(nil), + ) + ao := optionalfieldstest.NewAllOptional() + ao.Bin = []byte("somebytestring") + ao.Write(proto) +} + +func TestOneEmptyBinSetFieldOnWire(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + proto := NewMockTProtocol(mockCtrl) + gomock.InOrder( + proto.EXPECT().WriteStructBegin("all_optional").Return(nil), + proto.EXPECT().WriteFieldBegin("bin", thrift.TType(thrift.STRING), int16(13)).Return(nil), + proto.EXPECT().WriteBinary([]byte{}).Return(nil), + proto.EXPECT().WriteFieldEnd().Return(nil), + proto.EXPECT().WriteFieldStop().Return(nil), + proto.EXPECT().WriteStructEnd().Return(nil), + ) + ao := optionalfieldstest.NewAllOptional() + ao.Bin = []byte{} + ao.Write(proto) +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/protocol_mock.go b/src/jaegertracing/thrift/lib/go/test/tests/protocol_mock.go new file mode 100644 index 000000000..51d7a02ff --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/protocol_mock.go @@ -0,0 +1,513 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Automatically generated by MockGen. DO NOT EDIT! +// Source: thrift (interfaces: TProtocol) + +package tests + +import ( + "context" + thrift "thrift" + + gomock "github.com/golang/mock/gomock" +) + +// Mock of TProtocol interface +type MockTProtocol struct { + ctrl *gomock.Controller + recorder *_MockTProtocolRecorder +} + +// Recorder for MockTProtocol (not exported) +type _MockTProtocolRecorder struct { + mock *MockTProtocol +} + +func NewMockTProtocol(ctrl *gomock.Controller) *MockTProtocol { + mock := &MockTProtocol{ctrl: ctrl} + mock.recorder = &_MockTProtocolRecorder{mock} + return mock +} + +func (_m *MockTProtocol) EXPECT() *_MockTProtocolRecorder { + return _m.recorder +} + +func (_m *MockTProtocol) Flush(ctx context.Context) error { + ret := _m.ctrl.Call(_m, "Flush") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) Flush(ctx context.Context) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Flush") +} + +func (_m *MockTProtocol) ReadBinary() ([]byte, error) { + ret := _m.ctrl.Call(_m, "ReadBinary") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadBinary() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadBinary") +} + +func (_m *MockTProtocol) ReadBool() (bool, error) { + ret := _m.ctrl.Call(_m, "ReadBool") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadBool() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadBool") +} + +func (_m *MockTProtocol) ReadByte() (int8, error) { + ret := _m.ctrl.Call(_m, "ReadByte") + ret0, _ := ret[0].(int8) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadByte() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadByte") +} + +func (_m *MockTProtocol) ReadDouble() (float64, error) { + ret := _m.ctrl.Call(_m, "ReadDouble") + ret0, _ := ret[0].(float64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadDouble() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadDouble") +} + +func (_m *MockTProtocol) ReadFieldBegin() (string, thrift.TType, int16, error) { + ret := _m.ctrl.Call(_m, "ReadFieldBegin") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(thrift.TType) + ret2, _ := ret[2].(int16) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +func (_mr *_MockTProtocolRecorder) ReadFieldBegin() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadFieldBegin") +} + +func (_m *MockTProtocol) ReadFieldEnd() error { + ret := _m.ctrl.Call(_m, "ReadFieldEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) ReadFieldEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadFieldEnd") +} + +func (_m *MockTProtocol) ReadI16() (int16, error) { + ret := _m.ctrl.Call(_m, "ReadI16") + ret0, _ := ret[0].(int16) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadI16() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI16") +} + +func (_m *MockTProtocol) ReadI32() (int32, error) { + ret := _m.ctrl.Call(_m, "ReadI32") + ret0, _ := ret[0].(int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadI32() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI32") +} + +func (_m *MockTProtocol) ReadI64() (int64, error) { + ret := _m.ctrl.Call(_m, "ReadI64") + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadI64() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadI64") +} + +func (_m *MockTProtocol) ReadListBegin() (thrift.TType, int, error) { + ret := _m.ctrl.Call(_m, "ReadListBegin") + ret0, _ := ret[0].(thrift.TType) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +func (_mr *_MockTProtocolRecorder) ReadListBegin() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadListBegin") +} + +func (_m *MockTProtocol) ReadListEnd() error { + ret := _m.ctrl.Call(_m, "ReadListEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) ReadListEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadListEnd") +} + +func (_m *MockTProtocol) ReadMapBegin() (thrift.TType, thrift.TType, int, error) { + ret := _m.ctrl.Call(_m, "ReadMapBegin") + ret0, _ := ret[0].(thrift.TType) + ret1, _ := ret[1].(thrift.TType) + ret2, _ := ret[2].(int) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +func (_mr *_MockTProtocolRecorder) ReadMapBegin() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMapBegin") +} + +func (_m *MockTProtocol) ReadMapEnd() error { + ret := _m.ctrl.Call(_m, "ReadMapEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) ReadMapEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMapEnd") +} + +func (_m *MockTProtocol) ReadMessageBegin() (string, thrift.TMessageType, int32, error) { + ret := _m.ctrl.Call(_m, "ReadMessageBegin") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(thrift.TMessageType) + ret2, _ := ret[2].(int32) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +func (_mr *_MockTProtocolRecorder) ReadMessageBegin() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessageBegin") +} + +func (_m *MockTProtocol) ReadMessageEnd() error { + ret := _m.ctrl.Call(_m, "ReadMessageEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) ReadMessageEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessageEnd") +} + +func (_m *MockTProtocol) ReadSetBegin() (thrift.TType, int, error) { + ret := _m.ctrl.Call(_m, "ReadSetBegin") + ret0, _ := ret[0].(thrift.TType) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +func (_mr *_MockTProtocolRecorder) ReadSetBegin() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadSetBegin") +} + +func (_m *MockTProtocol) ReadSetEnd() error { + ret := _m.ctrl.Call(_m, "ReadSetEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) ReadSetEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadSetEnd") +} + +func (_m *MockTProtocol) ReadString() (string, error) { + ret := _m.ctrl.Call(_m, "ReadString") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadString() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadString") +} + +func (_m *MockTProtocol) ReadStructBegin() (string, error) { + ret := _m.ctrl.Call(_m, "ReadStructBegin") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockTProtocolRecorder) ReadStructBegin() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadStructBegin") +} + +func (_m *MockTProtocol) ReadStructEnd() error { + ret := _m.ctrl.Call(_m, "ReadStructEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) ReadStructEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadStructEnd") +} + +func (_m *MockTProtocol) Skip(_param0 thrift.TType) error { + ret := _m.ctrl.Call(_m, "Skip", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) Skip(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Skip", arg0) +} + +func (_m *MockTProtocol) Transport() thrift.TTransport { + ret := _m.ctrl.Call(_m, "Transport") + ret0, _ := ret[0].(thrift.TTransport) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) Transport() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Transport") +} + +func (_m *MockTProtocol) WriteBinary(_param0 []byte) error { + ret := _m.ctrl.Call(_m, "WriteBinary", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteBinary(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteBinary", arg0) +} + +func (_m *MockTProtocol) WriteBool(_param0 bool) error { + ret := _m.ctrl.Call(_m, "WriteBool", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteBool(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteBool", arg0) +} + +func (_m *MockTProtocol) WriteByte(_param0 int8) error { + ret := _m.ctrl.Call(_m, "WriteByte", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteByte(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteByte", arg0) +} + +func (_m *MockTProtocol) WriteDouble(_param0 float64) error { + ret := _m.ctrl.Call(_m, "WriteDouble", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteDouble(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteDouble", arg0) +} + +func (_m *MockTProtocol) WriteFieldBegin(_param0 string, _param1 thrift.TType, _param2 int16) error { + ret := _m.ctrl.Call(_m, "WriteFieldBegin", _param0, _param1, _param2) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteFieldBegin(arg0, arg1, arg2 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldBegin", arg0, arg1, arg2) +} + +func (_m *MockTProtocol) WriteFieldEnd() error { + ret := _m.ctrl.Call(_m, "WriteFieldEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteFieldEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldEnd") +} + +func (_m *MockTProtocol) WriteFieldStop() error { + ret := _m.ctrl.Call(_m, "WriteFieldStop") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteFieldStop() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteFieldStop") +} + +func (_m *MockTProtocol) WriteI16(_param0 int16) error { + ret := _m.ctrl.Call(_m, "WriteI16", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteI16(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI16", arg0) +} + +func (_m *MockTProtocol) WriteI32(_param0 int32) error { + ret := _m.ctrl.Call(_m, "WriteI32", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteI32(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI32", arg0) +} + +func (_m *MockTProtocol) WriteI64(_param0 int64) error { + ret := _m.ctrl.Call(_m, "WriteI64", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteI64(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteI64", arg0) +} + +func (_m *MockTProtocol) WriteListBegin(_param0 thrift.TType, _param1 int) error { + ret := _m.ctrl.Call(_m, "WriteListBegin", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteListBegin(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteListBegin", arg0, arg1) +} + +func (_m *MockTProtocol) WriteListEnd() error { + ret := _m.ctrl.Call(_m, "WriteListEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteListEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteListEnd") +} + +func (_m *MockTProtocol) WriteMapBegin(_param0 thrift.TType, _param1 thrift.TType, _param2 int) error { + ret := _m.ctrl.Call(_m, "WriteMapBegin", _param0, _param1, _param2) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteMapBegin(arg0, arg1, arg2 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMapBegin", arg0, arg1, arg2) +} + +func (_m *MockTProtocol) WriteMapEnd() error { + ret := _m.ctrl.Call(_m, "WriteMapEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteMapEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMapEnd") +} + +func (_m *MockTProtocol) WriteMessageBegin(_param0 string, _param1 thrift.TMessageType, _param2 int32) error { + ret := _m.ctrl.Call(_m, "WriteMessageBegin", _param0, _param1, _param2) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteMessageBegin(arg0, arg1, arg2 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessageBegin", arg0, arg1, arg2) +} + +func (_m *MockTProtocol) WriteMessageEnd() error { + ret := _m.ctrl.Call(_m, "WriteMessageEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteMessageEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessageEnd") +} + +func (_m *MockTProtocol) WriteSetBegin(_param0 thrift.TType, _param1 int) error { + ret := _m.ctrl.Call(_m, "WriteSetBegin", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteSetBegin(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteSetBegin", arg0, arg1) +} + +func (_m *MockTProtocol) WriteSetEnd() error { + ret := _m.ctrl.Call(_m, "WriteSetEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteSetEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteSetEnd") +} + +func (_m *MockTProtocol) WriteString(_param0 string) error { + ret := _m.ctrl.Call(_m, "WriteString", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteString(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteString", arg0) +} + +func (_m *MockTProtocol) WriteStructBegin(_param0 string) error { + ret := _m.ctrl.Call(_m, "WriteStructBegin", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteStructBegin(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteStructBegin", arg0) +} + +func (_m *MockTProtocol) WriteStructEnd() error { + ret := _m.ctrl.Call(_m, "WriteStructEnd") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockTProtocolRecorder) WriteStructEnd() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteStructEnd") +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/protocols_test.go b/src/jaegertracing/thrift/lib/go/test/tests/protocols_test.go new file mode 100644 index 000000000..cffd9c3f7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/protocols_test.go @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "testing" + "thrift" + "thrifttest" +) + +func RunSocketTestSuite(t *testing.T, protocolFactory thrift.TProtocolFactory, + transportFactory thrift.TTransportFactory) { + // server + var err error + addr = FindAvailableTCPServerPort() + serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT) + if err != nil { + t.Fatal("Unable to create server socket", err) + } + processor := thrifttest.NewThriftTestProcessor(NewThriftTestHandler()) + server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory) + server.Listen() + + go server.Serve() + + // client + var transport thrift.TTransport = thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) + transport, err = transportFactory.GetTransport(transport) + if err != nil { + t.Fatal(err) + } + var protocol thrift.TProtocol = protocolFactory.GetProtocol(transport) + thriftTestClient := thrifttest.NewThriftTestClient(thrift.NewTStandardClient(protocol, protocol)) + err = transport.Open() + if err != nil { + t.Fatal("Unable to open client socket", err) + } + + driver := NewThriftTestDriver(t, thriftTestClient) + driver.Start() +} + +// Run test suite using TJSONProtocol +func TestTJSONProtocol(t *testing.T) { + RunSocketTestSuite(t, + thrift.NewTJSONProtocolFactory(), + thrift.NewTTransportFactory()) + RunSocketTestSuite(t, + thrift.NewTJSONProtocolFactory(), + thrift.NewTBufferedTransportFactory(8912)) + RunSocketTestSuite(t, + thrift.NewTJSONProtocolFactory(), + thrift.NewTFramedTransportFactory(thrift.NewTTransportFactory())) +} + +// Run test suite using TBinaryProtocol +func TestTBinaryProtocol(t *testing.T) { + RunSocketTestSuite(t, + thrift.NewTBinaryProtocolFactoryDefault(), + thrift.NewTTransportFactory()) + RunSocketTestSuite(t, + thrift.NewTBinaryProtocolFactoryDefault(), + thrift.NewTBufferedTransportFactory(8912)) + RunSocketTestSuite(t, + thrift.NewTBinaryProtocolFactoryDefault(), + thrift.NewTFramedTransportFactory(thrift.NewTTransportFactory())) +} + +// Run test suite using TCompactBinaryProtocol +func TestTCompactProtocol(t *testing.T) { + RunSocketTestSuite(t, + thrift.NewTCompactProtocolFactory(), + thrift.NewTTransportFactory()) + RunSocketTestSuite(t, + thrift.NewTCompactProtocolFactory(), + thrift.NewTBufferedTransportFactory(8912)) + RunSocketTestSuite(t, + thrift.NewTCompactProtocolFactory(), + thrift.NewTFramedTransportFactory(thrift.NewTTransportFactory())) +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/required_fields_test.go b/src/jaegertracing/thrift/lib/go/test/tests/required_fields_test.go new file mode 100644 index 000000000..3fa414ad8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/required_fields_test.go @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "github.com/golang/mock/gomock" + "optionalfieldstest" + "requiredfieldtest" + "testing" + "thrift" +) + +func TestRequiredField_SucecssWhenSet(t *testing.T) { + // create a new RequiredField instance with the required field set + source := &requiredfieldtest.RequiredField{Name: "this is a test"} + sourceData, err := thrift.NewTSerializer().Write(context.Background(), source) + if err != nil { + t.Fatalf("failed to serialize %T: %v", source, err) + } + + d := thrift.NewTDeserializer() + err = d.Read(&requiredfieldtest.RequiredField{}, sourceData) + if err != nil { + t.Fatalf("Did not expect an error when trying to deserialize the requiredfieldtest.RequiredField: %v", err) + } +} + +func TestRequiredField_ErrorWhenMissing(t *testing.T) { + // create a new OtherThing instance, without setting the required field + source := &requiredfieldtest.OtherThing{} + sourceData, err := thrift.NewTSerializer().Write(context.Background(), source) + if err != nil { + t.Fatalf("failed to serialize %T: %v", source, err) + } + + // attempt to deserialize into a different type (which should fail) + d := thrift.NewTDeserializer() + err = d.Read(&requiredfieldtest.RequiredField{}, sourceData) + if err == nil { + t.Fatal("Expected an error when trying to deserialize an object which is missing a required field") + } +} + +func TestStructReadRequiredFields(t *testing.T) { + mockCtrl := gomock.NewController(t) + protocol := NewMockTProtocol(mockCtrl) + testStruct := optionalfieldstest.NewStructC() + + // None of required fields are set + gomock.InOrder( + protocol.EXPECT().ReadStructBegin().Return("StructC", nil), + protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil), + protocol.EXPECT().ReadStructEnd().Return(nil), + ) + + err := testStruct.Read(protocol) + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + if err == nil { + t.Fatal("Expected read to fail") + } + err2, ok := err.(thrift.TProtocolException) + if !ok { + t.Fatal("Expected a TProtocolException") + } + if err2.TypeId() != thrift.INVALID_DATA { + t.Fatal("Expected INVALID_DATA TProtocolException") + } + + // One of the required fields is set + gomock.InOrder( + protocol.EXPECT().ReadStructBegin().Return("StructC", nil), + protocol.EXPECT().ReadFieldBegin().Return("I", thrift.TType(thrift.I32), int16(2), nil), + protocol.EXPECT().ReadI32().Return(int32(1), nil), + protocol.EXPECT().ReadFieldEnd().Return(nil), + protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil), + protocol.EXPECT().ReadStructEnd().Return(nil), + ) + + err = testStruct.Read(protocol) + mockCtrl.Finish() + mockCtrl = gomock.NewController(t) + if err == nil { + t.Fatal("Expected read to fail") + } + err2, ok = err.(thrift.TProtocolException) + if !ok { + t.Fatal("Expected a TProtocolException") + } + if err2.TypeId() != thrift.INVALID_DATA { + t.Fatal("Expected INVALID_DATA TProtocolException") + } + + // Both of the required fields are set + gomock.InOrder( + protocol.EXPECT().ReadStructBegin().Return("StructC", nil), + protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(2), nil), + protocol.EXPECT().ReadI32().Return(int32(1), nil), + protocol.EXPECT().ReadFieldEnd().Return(nil), + protocol.EXPECT().ReadFieldBegin().Return("s2", thrift.TType(thrift.STRING), int16(4), nil), + protocol.EXPECT().ReadString().Return("test", nil), + protocol.EXPECT().ReadFieldEnd().Return(nil), + protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil), + protocol.EXPECT().ReadStructEnd().Return(nil), + ) + + err = testStruct.Read(protocol) + mockCtrl.Finish() + if err != nil { + t.Fatal("Expected read to succeed") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/struct_args_rets_test.go b/src/jaegertracing/thrift/lib/go/test/tests/struct_args_rets_test.go new file mode 100644 index 000000000..81e9b2658 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/struct_args_rets_test.go @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + st "servicestest" +) + +//this function is never called, it will fail to compile if check is failed +func staticCheckStructArgsResults() { + //Check that struct args and results are passed by reference + var sa *st.StructA = &st.StructA{} + var iface st.AServ + var err error + + sa, err = iface.StructAFunc_1structA(defaultCtx, sa) + _ = err + _ = sa +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/thrifttest_driver.go b/src/jaegertracing/thrift/lib/go/test/tests/thrifttest_driver.go new file mode 100644 index 000000000..de54cbcaa --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/thrifttest_driver.go @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * 'License'); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "reflect" + "testing" + "thrifttest" +) + +type ThriftTestDriver struct { + client thrifttest.ThriftTest + t *testing.T +} + +func NewThriftTestDriver(t *testing.T, client thrifttest.ThriftTest) *ThriftTestDriver { + return &ThriftTestDriver{client, t} +} + +func (p *ThriftTestDriver) Start() { + client := p.client + t := p.t + + if client.TestVoid(defaultCtx) != nil { + t.Fatal("TestVoid failed") + } + + if r, err := client.TestString(defaultCtx, "Test"); r != "Test" || err != nil { + t.Fatal("TestString with simple text failed") + } + + if r, err := client.TestString(defaultCtx, ""); r != "" || err != nil { + t.Fatal("TestString with empty text failed") + } + + stringTest := "Afrikaans, Alemannisch, Aragonés, العربية, مصرى, " + + "Asturianu, Aymar aru, Azərbaycan, Башҡорт, Boarisch, Žemaitėška, " + + "Беларуская, Беларуская (тарашкевіца), Български, Bamanankan, " + + "বাংলা, Brezhoneg, Bosanski, Català, Mìng-dĕ̤ng-ngṳ̄, Нохчийн, " + + "Cebuano, ᏣᎳᎩ, Česky, Словѣ́ньскъ / ⰔⰎⰑⰂⰡⰐⰠⰔⰍⰟ, Чӑвашла, Cymraeg, " + + "Dansk, Zazaki, ދިވެހިބަސް, Ελληνικά, Emiliàn e rumagnòl, English, " + + "Esperanto, Español, Eesti, Euskara, فارسی, Suomi, Võro, Føroyskt, " + + "Français, Arpetan, Furlan, Frysk, Gaeilge, 贛語, Gàidhlig, Galego, " + + "Avañe'ẽ, ગુજરાતી, Gaelg, עברית, हिन्दी, Fiji Hindi, Hrvatski, " + + "Kreyòl ayisyen, Magyar, Հայերեն, Interlingua, Bahasa Indonesia, " + + "Ilokano, Ido, Íslenska, Italiano, 日本語, Lojban, Basa Jawa, " + + "ქართული, Kongo, Kalaallisut, ಕನ್ನಡ, 한국어, Къарачай-Малкъар, " + + "Ripoarisch, Kurdî, Коми, Kernewek, Кыргызча, Latina, Ladino, " + + "Lëtzebuergesch, Limburgs, Lingála, ລາວ, Lietuvių, Latviešu, Basa " + + "Banyumasan, Malagasy, Македонски, മലയാളം, मराठी, مازِرونی, Bahasa " + + "Melayu, Nnapulitano, Nedersaksisch, नेपाल भाषा, Nederlands, " + + "Norsk (nynorsk), Norsk (bokmål), Nouormand, Diné bizaad, " + + "Occitan, Иронау, Papiamentu, Deitsch, Polski, پنجابی, پښتو, " + + "Norfuk / Pitkern, Português, Runa Simi, Rumantsch, Romani, Română, " + + "Русский, Саха тыла, Sardu, Sicilianu, Scots, Sámegiella, Simple " + + "English, Slovenčina, Slovenščina, Српски / Srpski, Seeltersk, " + + "Svenska, Kiswahili, தமிழ், తెలుగు, Тоҷикӣ, ไทย, Türkmençe, Tagalog, " + + "Türkçe, Татарча/Tatarça, Українська, اردو, Tiếng Việt, Volapük, " + + "Walon, Winaray, 吴语, isiXhosa, ייִדיש, Yorùbá, Zeêuws, 中文, " + + "Bân-lâm-gú, 粵語" + + if r, err := client.TestString(defaultCtx, stringTest); r != stringTest || err != nil { + t.Fatal("TestString with all languages failed") + } + + specialCharacters := "quote: \" backslash:" + + " backspace: \b formfeed: \f newline: \n return: \r tab: " + + " now-all-of-them-together: '\\\b\n\r\t'" + + " now-a-bunch-of-junk: !@#$%&()(&%$#{}{}<><><" + + " char-to-test-json-parsing: ]] \"]] \\\" }}}{ [[[ " + + if r, err := client.TestString(defaultCtx, specialCharacters); r != specialCharacters || err != nil { + t.Fatal("TestString with specialCharacters failed") + } + + if r, err := client.TestByte(defaultCtx, 1); r != 1 || err != nil { + t.Fatal("TestByte(1) failed") + } + if r, err := client.TestByte(defaultCtx, 0); r != 0 || err != nil { + t.Fatal("TestByte(0) failed") + } + if r, err := client.TestByte(defaultCtx, -1); r != -1 || err != nil { + t.Fatal("TestByte(-1) failed") + } + if r, err := client.TestByte(defaultCtx, -127); r != -127 || err != nil { + t.Fatal("TestByte(-127) failed") + } + + if r, err := client.TestI32(defaultCtx, -1); r != -1 || err != nil { + t.Fatal("TestI32(-1) failed") + } + if r, err := client.TestI32(defaultCtx, 1); r != 1 || err != nil { + t.Fatal("TestI32(1) failed") + } + + if r, err := client.TestI64(defaultCtx, -5); r != -5 || err != nil { + t.Fatal("TestI64(-5) failed") + } + if r, err := client.TestI64(defaultCtx, 5); r != 5 || err != nil { + t.Fatal("TestI64(5) failed") + } + if r, err := client.TestI64(defaultCtx, -34359738368); r != -34359738368 || err != nil { + t.Fatal("TestI64(-34359738368) failed") + } + + if r, err := client.TestDouble(defaultCtx, -5.2098523); r != -5.2098523 || err != nil { + t.Fatal("TestDouble(-5.2098523) failed") + } + if r, err := client.TestDouble(defaultCtx, -7.012052175215044); r != -7.012052175215044 || err != nil { + t.Fatal("TestDouble(-7.012052175215044) failed") + } + + // TODO: add testBinary() call + + out := thrifttest.NewXtruct() + out.StringThing = "Zero" + out.ByteThing = 1 + out.I32Thing = -3 + out.I64Thing = 1000000 + if r, err := client.TestStruct(defaultCtx, out); !reflect.DeepEqual(r, out) || err != nil { + t.Fatal("TestStruct failed") + } + + out2 := thrifttest.NewXtruct2() + out2.ByteThing = 1 + out2.StructThing = out + out2.I32Thing = 5 + if r, err := client.TestNest(defaultCtx, out2); !reflect.DeepEqual(r, out2) || err != nil { + t.Fatal("TestNest failed") + } + + mapout := make(map[int32]int32) + for i := int32(0); i < 5; i++ { + mapout[i] = i - 10 + } + if r, err := client.TestMap(defaultCtx, mapout); !reflect.DeepEqual(r, mapout) || err != nil { + t.Fatal("TestMap failed") + } + + mapTestInput := map[string]string{ + "a": "123", "a b": "with spaces ", "same": "same", "0": "numeric key", + "longValue": stringTest, stringTest: "long key", + } + if r, err := client.TestStringMap(defaultCtx, mapTestInput); !reflect.DeepEqual(r, mapTestInput) || err != nil { + t.Fatal("TestStringMap failed") + } + + setTestInput := []int32{1, 2, 3} + if r, err := client.TestSet(defaultCtx, setTestInput); !reflect.DeepEqual(r, setTestInput) || err != nil { + t.Fatal("TestSet failed") + } + + listTest := []int32{1, 2, 3} + if r, err := client.TestList(defaultCtx, listTest); !reflect.DeepEqual(r, listTest) || err != nil { + t.Fatal("TestList failed") + } + + if r, err := client.TestEnum(defaultCtx, thrifttest.Numberz_ONE); r != thrifttest.Numberz_ONE || err != nil { + t.Fatal("TestEnum failed") + } + + if r, err := client.TestTypedef(defaultCtx, 69); r != 69 || err != nil { + t.Fatal("TestTypedef failed") + } + + mapMapTest := map[int32]map[int32]int32{ + 4: {1: 1, 2: 2, 3: 3, 4: 4}, + -4: {-4: -4, -3: -3, -2: -2, -1: -1}, + } + if r, err := client.TestMapMap(defaultCtx, 1); !reflect.DeepEqual(r, mapMapTest) || err != nil { + t.Fatal("TestMapMap failed") + } + + crazyX1 := thrifttest.NewXtruct() + crazyX1.StringThing = "Goodbye4" + crazyX1.ByteThing = 4 + crazyX1.I32Thing = 4 + crazyX1.I64Thing = 4 + + crazyX2 := thrifttest.NewXtruct() + crazyX2.StringThing = "Hello2" + crazyX2.ByteThing = 2 + crazyX2.I32Thing = 2 + crazyX2.I64Thing = 2 + + crazy := thrifttest.NewInsanity() + crazy.UserMap = map[thrifttest.Numberz]thrifttest.UserId{5: 5, 8: 8} + crazy.Xtructs = []*thrifttest.Xtruct{crazyX1, crazyX2} + + crazyEmpty := thrifttest.NewInsanity() + crazyEmpty.UserMap = map[thrifttest.Numberz]thrifttest.UserId{} + crazyEmpty.Xtructs = []*thrifttest.Xtruct{} + + insanity := map[thrifttest.UserId]map[thrifttest.Numberz]*thrifttest.Insanity{ + 1: {thrifttest.Numberz_TWO: crazy, thrifttest.Numberz_THREE: crazy}, + 2: {thrifttest.Numberz_SIX: crazyEmpty}, + } + if r, err := client.TestInsanity(defaultCtx, crazy); !reflect.DeepEqual(r, insanity) || err != nil { + t.Fatal("TestInsanity failed") + } + + if err := client.TestException(defaultCtx, "TException"); err == nil { + t.Fatal("TestException TException failed") + } + + if err, ok := client.TestException(defaultCtx, "Xception").(*thrifttest.Xception); ok == false || err == nil { + t.Fatal("TestException Xception failed") + } else if err.ErrorCode != 1001 || err.Message != "Xception" { + t.Fatal("TestException Xception failed") + } + + if err := client.TestException(defaultCtx, "no Exception"); err != nil { + t.Fatal("TestException no Exception failed") + } + + if err := client.TestOneway(defaultCtx, 0); err != nil { + t.Fatal("TestOneway failed") + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/thrifttest_handler.go b/src/jaegertracing/thrift/lib/go/test/tests/thrifttest_handler.go new file mode 100644 index 000000000..31b9ee23e --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/thrifttest_handler.go @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * 'License'); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "errors" + "thrift" + "thrifttest" + "time" +) + +type SecondServiceHandler struct { +} + +func NewSecondServiceHandler() *SecondServiceHandler { + return &SecondServiceHandler{} +} + +func (p *SecondServiceHandler) BlahBlah(ctx context.Context) (err error) { + return nil +} + +func (p *SecondServiceHandler) SecondtestString(ctx context.Context, thing string) (r string, err error) { + return thing, nil +} + +type ThriftTestHandler struct { +} + +func NewThriftTestHandler() *ThriftTestHandler { + return &ThriftTestHandler{} +} + +func (p *ThriftTestHandler) TestVoid(ctx context.Context) (err error) { + return nil +} + +func (p *ThriftTestHandler) TestString(ctx context.Context, thing string) (r string, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestBool(ctx context.Context, thing bool) (r bool, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestByte(ctx context.Context, thing int8) (r int8, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestI32(ctx context.Context, thing int32) (r int32, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestI64(ctx context.Context, thing int64) (r int64, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestDouble(ctx context.Context, thing float64) (r float64, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestBinary(ctx context.Context, thing []byte) (r []byte, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestStruct(ctx context.Context, thing *thrifttest.Xtruct) (r *thrifttest.Xtruct, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestNest(ctx context.Context, thing *thrifttest.Xtruct2) (r *thrifttest.Xtruct2, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestMap(ctx context.Context, thing map[int32]int32) (r map[int32]int32, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestStringMap(ctx context.Context, thing map[string]string) (r map[string]string, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestSet(ctx context.Context, thing []int32) (r []int32, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestList(ctx context.Context, thing []int32) (r []int32, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestEnum(ctx context.Context, thing thrifttest.Numberz) (r thrifttest.Numberz, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestTypedef(ctx context.Context, thing thrifttest.UserId) (r thrifttest.UserId, err error) { + return thing, nil +} + +func (p *ThriftTestHandler) TestMapMap(ctx context.Context, hello int32) (r map[int32]map[int32]int32, err error) { + r = make(map[int32]map[int32]int32) + pos := make(map[int32]int32) + neg := make(map[int32]int32) + + for i := int32(1); i < 5; i++ { + pos[i] = i + neg[-i] = -i + } + r[4] = pos + r[-4] = neg + + return r, nil +} + +func (p *ThriftTestHandler) TestInsanity(ctx context.Context, argument *thrifttest.Insanity) (r map[thrifttest.UserId]map[thrifttest.Numberz]*thrifttest.Insanity, err error) { + hello := thrifttest.NewXtruct() + hello.StringThing = "Hello2" + hello.ByteThing = 2 + hello.I32Thing = 2 + hello.I64Thing = 2 + + goodbye := thrifttest.NewXtruct() + goodbye.StringThing = "Goodbye4" + goodbye.ByteThing = 4 + goodbye.I32Thing = 4 + goodbye.I64Thing = 4 + + crazy := thrifttest.NewInsanity() + crazy.UserMap = make(map[thrifttest.Numberz]thrifttest.UserId) + crazy.UserMap[thrifttest.Numberz_EIGHT] = 8 + crazy.UserMap[thrifttest.Numberz_FIVE] = 5 + crazy.Xtructs = []*thrifttest.Xtruct{goodbye, hello} + + first_map := make(map[thrifttest.Numberz]*thrifttest.Insanity) + second_map := make(map[thrifttest.Numberz]*thrifttest.Insanity) + + first_map[thrifttest.Numberz_TWO] = crazy + first_map[thrifttest.Numberz_THREE] = crazy + + looney := thrifttest.NewInsanity() + second_map[thrifttest.Numberz_SIX] = looney + + var insane = make(map[thrifttest.UserId]map[thrifttest.Numberz]*thrifttest.Insanity) + insane[1] = first_map + insane[2] = second_map + + return insane, nil +} + +func (p *ThriftTestHandler) TestMulti(ctx context.Context, arg0 int8, arg1 int32, arg2 int64, arg3 map[int16]string, arg4 thrifttest.Numberz, arg5 thrifttest.UserId) (r *thrifttest.Xtruct, err error) { + r = thrifttest.NewXtruct() + r.StringThing = "Hello2" + r.ByteThing = arg0 + r.I32Thing = arg1 + r.I64Thing = arg2 + return r, nil +} + +func (p *ThriftTestHandler) TestException(ctx context.Context, arg string) (err error) { + if arg == "Xception" { + x := thrifttest.NewXception() + x.ErrorCode = 1001 + x.Message = arg + return x + } else if arg == "TException" { + return thrift.TException(errors.New(arg)) + } else { + return nil + } +} + +func (p *ThriftTestHandler) TestMultiException(ctx context.Context, arg0 string, arg1 string) (r *thrifttest.Xtruct, err error) { + if arg0 == "Xception" { + x := thrifttest.NewXception() + x.ErrorCode = 1001 + x.Message = "This is an Xception" + return nil, x + } else if arg0 == "Xception2" { + x2 := thrifttest.NewXception2() + x2.ErrorCode = 2002 + x2.StructThing = thrifttest.NewXtruct() + x2.StructThing.StringThing = "This is an Xception2" + return nil, x2 + } + + res := thrifttest.NewXtruct() + res.StringThing = arg1 + return res, nil +} + +func (p *ThriftTestHandler) TestOneway(ctx context.Context, secondsToSleep int32) (err error) { + time.Sleep(time.Second * time.Duration(secondsToSleep)) + return nil +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/union_binary_test.go b/src/jaegertracing/thrift/lib/go/test/tests/union_binary_test.go new file mode 100644 index 000000000..bdae2cb92 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/union_binary_test.go @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "testing" + "unionbinarytest" +) + + +// See https://issues.apache.org/jira/browse/THRIFT-4573 +func TestUnionBinary(t *testing.T) { + s := unionbinarytest.NewSample() + s.U1 = map[string]string{} + s.U2 = []byte{} + if n := s.CountSetFieldsSample(); n != 2 { + t.Errorf("Expected 2 set fields, got %d!", n) + } +} diff --git a/src/jaegertracing/thrift/lib/go/test/tests/union_default_value_test.go b/src/jaegertracing/thrift/lib/go/test/tests/union_default_value_test.go new file mode 100644 index 000000000..2dcbf4e44 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/test/tests/union_default_value_test.go @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "testing" + "uniondefaultvaluetest" +) + +func TestUnionDefaultValue(t *testing.T) { + s := uniondefaultvaluetest.NewTestStruct() + d := s.GetDescendant() + if d == nil { + t.Error("Default Union value not set!") + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/application_exception.go b/src/jaegertracing/thrift/lib/go/thrift/application_exception.go new file mode 100644 index 000000000..0023c57cf --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/application_exception.go @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +const ( + UNKNOWN_APPLICATION_EXCEPTION = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE_EXCEPTION = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 +) + +var defaultApplicationExceptionMessage = map[int32]string{ + UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception", + UNKNOWN_METHOD: "unknown method", + INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type", + WRONG_METHOD_NAME: "wrong method name", + BAD_SEQUENCE_ID: "bad sequence ID", + MISSING_RESULT: "missing result", + INTERNAL_ERROR: "unknown internal error", + PROTOCOL_ERROR: "unknown protocol error", + INVALID_TRANSFORM: "Invalid transform", + INVALID_PROTOCOL: "Invalid protocol", + UNSUPPORTED_CLIENT_TYPE: "Unsupported client type", +} + +// Application level Thrift exception +type TApplicationException interface { + TException + TypeId() int32 + Read(iprot TProtocol) error + Write(oprot TProtocol) error +} + +type tApplicationException struct { + message string + type_ int32 +} + +func (e tApplicationException) Error() string { + if e.message != "" { + return e.message + } + return defaultApplicationExceptionMessage[e.type_] +} + +func NewTApplicationException(type_ int32, message string) TApplicationException { + return &tApplicationException{message, type_} +} + +func (p *tApplicationException) TypeId() int32 { + return p.type_ +} + +func (p *tApplicationException) Read(iprot TProtocol) error { + // TODO: this should really be generated by the compiler + _, err := iprot.ReadStructBegin() + if err != nil { + return err + } + + message := "" + type_ := int32(UNKNOWN_APPLICATION_EXCEPTION) + + for { + _, ttype, id, err := iprot.ReadFieldBegin() + if err != nil { + return err + } + if ttype == STOP { + break + } + switch id { + case 1: + if ttype == STRING { + if message, err = iprot.ReadString(); err != nil { + return err + } + } else { + if err = SkipDefaultDepth(iprot, ttype); err != nil { + return err + } + } + case 2: + if ttype == I32 { + if type_, err = iprot.ReadI32(); err != nil { + return err + } + } else { + if err = SkipDefaultDepth(iprot, ttype); err != nil { + return err + } + } + default: + if err = SkipDefaultDepth(iprot, ttype); err != nil { + return err + } + } + if err = iprot.ReadFieldEnd(); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(); err != nil { + return err + } + + p.message = message + p.type_ = type_ + + return nil +} + +func (p *tApplicationException) Write(oprot TProtocol) (err error) { + err = oprot.WriteStructBegin("TApplicationException") + if len(p.Error()) > 0 { + err = oprot.WriteFieldBegin("message", STRING, 1) + if err != nil { + return + } + err = oprot.WriteString(p.Error()) + if err != nil { + return + } + err = oprot.WriteFieldEnd() + if err != nil { + return + } + } + err = oprot.WriteFieldBegin("type", I32, 2) + if err != nil { + return + } + err = oprot.WriteI32(p.type_) + if err != nil { + return + } + err = oprot.WriteFieldEnd() + if err != nil { + return + } + err = oprot.WriteFieldStop() + if err != nil { + return + } + err = oprot.WriteStructEnd() + return +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/application_exception_test.go b/src/jaegertracing/thrift/lib/go/thrift/application_exception_test.go new file mode 100644 index 000000000..77433575d --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/application_exception_test.go @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestTApplicationException(t *testing.T) { + exc := NewTApplicationException(UNKNOWN_APPLICATION_EXCEPTION, "") + if exc.Error() != defaultApplicationExceptionMessage[UNKNOWN_APPLICATION_EXCEPTION] { + t.Fatalf("Expected empty string for exception but found '%s'", exc.Error()) + } + if exc.TypeId() != UNKNOWN_APPLICATION_EXCEPTION { + t.Fatalf("Expected type UNKNOWN for exception but found '%v'", exc.TypeId()) + } + exc = NewTApplicationException(WRONG_METHOD_NAME, "junk_method") + if exc.Error() != "junk_method" { + t.Fatalf("Expected 'junk_method' for exception but found '%s'", exc.Error()) + } + if exc.TypeId() != WRONG_METHOD_NAME { + t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%v'", exc.TypeId()) + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/binary_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/binary_protocol.go new file mode 100644 index 000000000..93ae898cf --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/binary_protocol.go @@ -0,0 +1,505 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "math" +) + +type TBinaryProtocol struct { + trans TRichTransport + origTransport TTransport + strictRead bool + strictWrite bool + buffer [64]byte +} + +type TBinaryProtocolFactory struct { + strictRead bool + strictWrite bool +} + +func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol { + return NewTBinaryProtocol(t, false, true) +} + +func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol { + p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite} + if et, ok := t.(TRichTransport); ok { + p.trans = et + } else { + p.trans = NewTRichTransport(t) + } + return p +} + +func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory { + return NewTBinaryProtocolFactory(false, true) +} + +func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory { + return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite} +} + +func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol { + return NewTBinaryProtocol(t, p.strictRead, p.strictWrite) +} + +/** + * Writing Methods + */ + +func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + if p.strictWrite { + version := uint32(VERSION_1) | uint32(typeId) + e := p.WriteI32(int32(version)) + if e != nil { + return e + } + e = p.WriteString(name) + if e != nil { + return e + } + e = p.WriteI32(seqId) + return e + } else { + e := p.WriteString(name) + if e != nil { + return e + } + e = p.WriteByte(int8(typeId)) + if e != nil { + return e + } + e = p.WriteI32(seqId) + return e + } + return nil +} + +func (p *TBinaryProtocol) WriteMessageEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteStructBegin(name string) error { + return nil +} + +func (p *TBinaryProtocol) WriteStructEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + e := p.WriteByte(int8(typeId)) + if e != nil { + return e + } + e = p.WriteI16(id) + return e +} + +func (p *TBinaryProtocol) WriteFieldEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteFieldStop() error { + e := p.WriteByte(STOP) + return e +} + +func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + e := p.WriteByte(int8(keyType)) + if e != nil { + return e + } + e = p.WriteByte(int8(valueType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +func (p *TBinaryProtocol) WriteMapEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error { + e := p.WriteByte(int8(elemType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +func (p *TBinaryProtocol) WriteListEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error { + e := p.WriteByte(int8(elemType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +func (p *TBinaryProtocol) WriteSetEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteBool(value bool) error { + if value { + return p.WriteByte(1) + } + return p.WriteByte(0) +} + +func (p *TBinaryProtocol) WriteByte(value int8) error { + e := p.trans.WriteByte(byte(value)) + return NewTProtocolException(e) +} + +func (p *TBinaryProtocol) WriteI16(value int16) error { + v := p.buffer[0:2] + binary.BigEndian.PutUint16(v, uint16(value)) + _, e := p.trans.Write(v) + return NewTProtocolException(e) +} + +func (p *TBinaryProtocol) WriteI32(value int32) error { + v := p.buffer[0:4] + binary.BigEndian.PutUint32(v, uint32(value)) + _, e := p.trans.Write(v) + return NewTProtocolException(e) +} + +func (p *TBinaryProtocol) WriteI64(value int64) error { + v := p.buffer[0:8] + binary.BigEndian.PutUint64(v, uint64(value)) + _, err := p.trans.Write(v) + return NewTProtocolException(err) +} + +func (p *TBinaryProtocol) WriteDouble(value float64) error { + return p.WriteI64(int64(math.Float64bits(value))) +} + +func (p *TBinaryProtocol) WriteString(value string) error { + e := p.WriteI32(int32(len(value))) + if e != nil { + return e + } + _, err := p.trans.WriteString(value) + return NewTProtocolException(err) +} + +func (p *TBinaryProtocol) WriteBinary(value []byte) error { + e := p.WriteI32(int32(len(value))) + if e != nil { + return e + } + _, err := p.trans.Write(value) + return NewTProtocolException(err) +} + +/** + * Reading methods + */ + +func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + size, e := p.ReadI32() + if e != nil { + return "", typeId, 0, NewTProtocolException(e) + } + if size < 0 { + typeId = TMessageType(size & 0x0ff) + version := int64(int64(size) & VERSION_MASK) + if version != VERSION_1 { + return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin")) + } + name, e = p.ReadString() + if e != nil { + return name, typeId, seqId, NewTProtocolException(e) + } + seqId, e = p.ReadI32() + if e != nil { + return name, typeId, seqId, NewTProtocolException(e) + } + return name, typeId, seqId, nil + } + if p.strictRead { + return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin")) + } + name, e2 := p.readStringBody(size) + if e2 != nil { + return name, typeId, seqId, e2 + } + b, e3 := p.ReadByte() + if e3 != nil { + return name, typeId, seqId, e3 + } + typeId = TMessageType(b) + seqId, e4 := p.ReadI32() + if e4 != nil { + return name, typeId, seqId, e4 + } + return name, typeId, seqId, nil +} + +func (p *TBinaryProtocol) ReadMessageEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) { + return +} + +func (p *TBinaryProtocol) ReadStructEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) { + t, err := p.ReadByte() + typeId = TType(t) + if err != nil { + return name, typeId, seqId, err + } + if t != STOP { + seqId, err = p.ReadI16() + } + return name, typeId, seqId, err +} + +func (p *TBinaryProtocol) ReadFieldEnd() error { + return nil +} + +var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) + +func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) { + k, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + kType = TType(k) + v, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + vType = TType(v) + size32, e := p.ReadI32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + return kType, vType, size, nil +} + +func (p *TBinaryProtocol) ReadMapEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { + b, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + elemType = TType(b) + size32, e := p.ReadI32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + + return +} + +func (p *TBinaryProtocol) ReadListEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { + b, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + elemType = TType(b) + size32, e := p.ReadI32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + return elemType, size, nil +} + +func (p *TBinaryProtocol) ReadSetEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadBool() (bool, error) { + b, e := p.ReadByte() + v := true + if b != 1 { + v = false + } + return v, e +} + +func (p *TBinaryProtocol) ReadByte() (int8, error) { + v, err := p.trans.ReadByte() + return int8(v), err +} + +func (p *TBinaryProtocol) ReadI16() (value int16, err error) { + buf := p.buffer[0:2] + err = p.readAll(buf) + value = int16(binary.BigEndian.Uint16(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadI32() (value int32, err error) { + buf := p.buffer[0:4] + err = p.readAll(buf) + value = int32(binary.BigEndian.Uint32(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadI64() (value int64, err error) { + buf := p.buffer[0:8] + err = p.readAll(buf) + value = int64(binary.BigEndian.Uint64(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadDouble() (value float64, err error) { + buf := p.buffer[0:8] + err = p.readAll(buf) + value = math.Float64frombits(binary.BigEndian.Uint64(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadString() (value string, err error) { + size, e := p.ReadI32() + if e != nil { + return "", e + } + if size < 0 { + err = invalidDataLength + return + } + + return p.readStringBody(size) +} + +func (p *TBinaryProtocol) ReadBinary() ([]byte, error) { + size, e := p.ReadI32() + if e != nil { + return nil, e + } + if size < 0 { + return nil, invalidDataLength + } + + isize := int(size) + buf := make([]byte, isize) + _, err := io.ReadFull(p.trans, buf) + return buf, NewTProtocolException(err) +} + +func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) { + return NewTProtocolException(p.trans.Flush(ctx)) +} + +func (p *TBinaryProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TBinaryProtocol) Transport() TTransport { + return p.origTransport +} + +func (p *TBinaryProtocol) readAll(buf []byte) error { + _, err := io.ReadFull(p.trans, buf) + return NewTProtocolException(err) +} + +const readLimit = 32768 + +func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) { + if size < 0 { + return "", nil + } + + var ( + buf bytes.Buffer + e error + b []byte + ) + + switch { + case int(size) <= len(p.buffer): + b = p.buffer[:size] // avoids allocation for small reads + case int(size) < readLimit: + b = make([]byte, size) + default: + b = make([]byte, readLimit) + } + + for size > 0 { + _, e = io.ReadFull(p.trans, b) + buf.Write(b) + if e != nil { + break + } + size -= readLimit + if size < readLimit && size > 0 { + b = b[:size] + } + } + return buf.String(), NewTProtocolException(e) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/binary_protocol_test.go b/src/jaegertracing/thrift/lib/go/thrift/binary_protocol_test.go new file mode 100644 index 000000000..0462cc79d --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/binary_protocol_test.go @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestReadWriteBinaryProtocol(t *testing.T) { + ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault()) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/buffered_transport.go b/src/jaegertracing/thrift/lib/go/thrift/buffered_transport.go new file mode 100644 index 000000000..96702061b --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/buffered_transport.go @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "context" +) + +type TBufferedTransportFactory struct { + size int +} + +type TBufferedTransport struct { + bufio.ReadWriter + tp TTransport +} + +func (p *TBufferedTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + return NewTBufferedTransport(trans, p.size), nil +} + +func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory { + return &TBufferedTransportFactory{size: bufferSize} +} + +func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport { + return &TBufferedTransport{ + ReadWriter: bufio.ReadWriter{ + Reader: bufio.NewReaderSize(trans, bufferSize), + Writer: bufio.NewWriterSize(trans, bufferSize), + }, + tp: trans, + } +} + +func (p *TBufferedTransport) IsOpen() bool { + return p.tp.IsOpen() +} + +func (p *TBufferedTransport) Open() (err error) { + return p.tp.Open() +} + +func (p *TBufferedTransport) Close() (err error) { + return p.tp.Close() +} + +func (p *TBufferedTransport) Read(b []byte) (int, error) { + n, err := p.ReadWriter.Read(b) + if err != nil { + p.ReadWriter.Reader.Reset(p.tp) + } + return n, err +} + +func (p *TBufferedTransport) Write(b []byte) (int, error) { + n, err := p.ReadWriter.Write(b) + if err != nil { + p.ReadWriter.Writer.Reset(p.tp) + } + return n, err +} + +func (p *TBufferedTransport) Flush(ctx context.Context) error { + if err := p.ReadWriter.Flush(); err != nil { + p.ReadWriter.Writer.Reset(p.tp) + return err + } + return p.tp.Flush(ctx) +} + +func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) { + return p.tp.RemainingBytes() +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/buffered_transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/buffered_transport_test.go new file mode 100644 index 000000000..95ec0cbd2 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/buffered_transport_test.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestBufferedTransport(t *testing.T) { + trans := NewTBufferedTransport(NewTMemoryBuffer(), 10240) + TransportTest(t, trans, trans) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/client.go b/src/jaegertracing/thrift/lib/go/thrift/client.go new file mode 100644 index 000000000..b073a952d --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/client.go @@ -0,0 +1,95 @@ +package thrift + +import ( + "context" + "fmt" +) + +type TClient interface { + Call(ctx context.Context, method string, args, result TStruct) error +} + +type TStandardClient struct { + seqId int32 + iprot, oprot TProtocol +} + +// TStandardClient implements TClient, and uses the standard message format for Thrift. +// It is not safe for concurrent use. +func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient { + return &TStandardClient{ + iprot: inputProtocol, + oprot: outputProtocol, + } +} + +func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error { + // Set headers from context object on THeaderProtocol + if headerProt, ok := oprot.(*THeaderProtocol); ok { + headerProt.ClearWriteHeaders() + for _, key := range GetWriteHeaderList(ctx) { + if value, ok := GetHeader(ctx, key); ok { + headerProt.SetWriteHeader(key, value) + } + } + } + + if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil { + return err + } + if err := args.Write(oprot); err != nil { + return err + } + if err := oprot.WriteMessageEnd(); err != nil { + return err + } + return oprot.Flush(ctx) +} + +func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error { + rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin() + if err != nil { + return err + } + + if method != rMethod { + return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method)) + } else if seqId != rSeqId { + return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method)) + } else if rTypeId == EXCEPTION { + var exception tApplicationException + if err := exception.Read(iprot); err != nil { + return err + } + + if err := iprot.ReadMessageEnd(); err != nil { + return err + } + + return &exception + } else if rTypeId != REPLY { + return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method)) + } + + if err := result.Read(iprot); err != nil { + return err + } + + return iprot.ReadMessageEnd() +} + +func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error { + p.seqId++ + seqId := p.seqId + + if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil { + return err + } + + // method is oneway + if result == nil { + return nil + } + + return p.Recv(p.iprot, seqId, method, result) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/common_test.go b/src/jaegertracing/thrift/lib/go/thrift/common_test.go new file mode 100644 index 000000000..93597ff8a --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/common_test.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import "context" + +type mockProcessor struct { + ProcessFunc func(in, out TProtocol) (bool, TException) +} + +func (m *mockProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) { + return m.ProcessFunc(in, out) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/compact_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/compact_protocol.go new file mode 100644 index 000000000..1900d50c3 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/compact_protocol.go @@ -0,0 +1,810 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "math" +) + +const ( + COMPACT_PROTOCOL_ID = 0x082 + COMPACT_VERSION = 1 + COMPACT_VERSION_MASK = 0x1f + COMPACT_TYPE_MASK = 0x0E0 + COMPACT_TYPE_BITS = 0x07 + COMPACT_TYPE_SHIFT_AMOUNT = 5 +) + +type tCompactType byte + +const ( + COMPACT_BOOLEAN_TRUE = 0x01 + COMPACT_BOOLEAN_FALSE = 0x02 + COMPACT_BYTE = 0x03 + COMPACT_I16 = 0x04 + COMPACT_I32 = 0x05 + COMPACT_I64 = 0x06 + COMPACT_DOUBLE = 0x07 + COMPACT_BINARY = 0x08 + COMPACT_LIST = 0x09 + COMPACT_SET = 0x0A + COMPACT_MAP = 0x0B + COMPACT_STRUCT = 0x0C +) + +var ( + ttypeToCompactType map[TType]tCompactType +) + +func init() { + ttypeToCompactType = map[TType]tCompactType{ + STOP: STOP, + BOOL: COMPACT_BOOLEAN_TRUE, + BYTE: COMPACT_BYTE, + I16: COMPACT_I16, + I32: COMPACT_I32, + I64: COMPACT_I64, + DOUBLE: COMPACT_DOUBLE, + STRING: COMPACT_BINARY, + LIST: COMPACT_LIST, + SET: COMPACT_SET, + MAP: COMPACT_MAP, + STRUCT: COMPACT_STRUCT, + } +} + +type TCompactProtocolFactory struct{} + +func NewTCompactProtocolFactory() *TCompactProtocolFactory { + return &TCompactProtocolFactory{} +} + +func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTCompactProtocol(trans) +} + +type TCompactProtocol struct { + trans TRichTransport + origTransport TTransport + + // Used to keep track of the last field for the current and previous structs, + // so we can do the delta stuff. + lastField []int + lastFieldId int + + // If we encounter a boolean field begin, save the TField here so it can + // have the value incorporated. + booleanFieldName string + booleanFieldId int16 + booleanFieldPending bool + + // If we read a field header, and it's a boolean field, save the boolean + // value here so that readBool can use it. + boolValue bool + boolValueIsNotNull bool + buffer [64]byte +} + +// Create a TCompactProtocol given a TTransport +func NewTCompactProtocol(trans TTransport) *TCompactProtocol { + p := &TCompactProtocol{origTransport: trans, lastField: []int{}} + if et, ok := trans.(TRichTransport); ok { + p.trans = et + } else { + p.trans = NewTRichTransport(trans) + } + + return p + +} + +// +// Public Writing methods. +// + +// Write a message header to the wire. Compact Protocol messages contain the +// protocol version so we can migrate forwards in the future if need be. +func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { + err := p.writeByteDirect(COMPACT_PROTOCOL_ID) + if err != nil { + return NewTProtocolException(err) + } + err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK)) + if err != nil { + return NewTProtocolException(err) + } + _, err = p.writeVarint32(seqid) + if err != nil { + return NewTProtocolException(err) + } + e := p.WriteString(name) + return e + +} + +func (p *TCompactProtocol) WriteMessageEnd() error { return nil } + +// Write a struct begin. This doesn't actually put anything on the wire. We +// use it as an opportunity to put special placeholder markers on the field +// stack so we can get the field id deltas correct. +func (p *TCompactProtocol) WriteStructBegin(name string) error { + p.lastField = append(p.lastField, p.lastFieldId) + p.lastFieldId = 0 + return nil +} + +// Write a struct end. This doesn't actually put anything on the wire. We use +// this as an opportunity to pop the last field from the current struct off +// of the field stack. +func (p *TCompactProtocol) WriteStructEnd() error { + p.lastFieldId = p.lastField[len(p.lastField)-1] + p.lastField = p.lastField[:len(p.lastField)-1] + return nil +} + +func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + if typeId == BOOL { + // we want to possibly include the value, so we'll wait. + p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true + return nil + } + _, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF) + return NewTProtocolException(err) +} + +// The workhorse of writeFieldBegin. It has the option of doing a +// 'type override' of the type header. This is used specifically in the +// boolean field case. +func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) { + // short lastField = lastField_.pop(); + + // if there's a type override, use that. + var typeToWrite byte + if typeOverride == 0xFF { + typeToWrite = byte(p.getCompactType(typeId)) + } else { + typeToWrite = typeOverride + } + // check if we can use delta encoding for the field id + fieldId := int(id) + written := 0 + if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 { + // write them together + err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite) + if err != nil { + return 0, err + } + } else { + // write them separate + err := p.writeByteDirect(typeToWrite) + if err != nil { + return 0, err + } + err = p.WriteI16(id) + written = 1 + 2 + if err != nil { + return 0, err + } + } + + p.lastFieldId = fieldId + // p.lastField.Push(field.id); + return written, nil +} + +func (p *TCompactProtocol) WriteFieldEnd() error { return nil } + +func (p *TCompactProtocol) WriteFieldStop() error { + err := p.writeByteDirect(STOP) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + if size == 0 { + err := p.writeByteDirect(0) + return NewTProtocolException(err) + } + _, err := p.writeVarint32(int32(size)) + if err != nil { + return NewTProtocolException(err) + } + err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType))) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteMapEnd() error { return nil } + +// Write a list header. +func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error { + _, err := p.writeCollectionBegin(elemType, size) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteListEnd() error { return nil } + +// Write a set header. +func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error { + _, err := p.writeCollectionBegin(elemType, size) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteSetEnd() error { return nil } + +func (p *TCompactProtocol) WriteBool(value bool) error { + v := byte(COMPACT_BOOLEAN_FALSE) + if value { + v = byte(COMPACT_BOOLEAN_TRUE) + } + if p.booleanFieldPending { + // we haven't written the field header yet + _, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v) + p.booleanFieldPending = false + return NewTProtocolException(err) + } + // we're not part of a field, so just write the value. + err := p.writeByteDirect(v) + return NewTProtocolException(err) +} + +// Write a byte. Nothing to see here! +func (p *TCompactProtocol) WriteByte(value int8) error { + err := p.writeByteDirect(byte(value)) + return NewTProtocolException(err) +} + +// Write an I16 as a zigzag varint. +func (p *TCompactProtocol) WriteI16(value int16) error { + _, err := p.writeVarint32(p.int32ToZigzag(int32(value))) + return NewTProtocolException(err) +} + +// Write an i32 as a zigzag varint. +func (p *TCompactProtocol) WriteI32(value int32) error { + _, err := p.writeVarint32(p.int32ToZigzag(value)) + return NewTProtocolException(err) +} + +// Write an i64 as a zigzag varint. +func (p *TCompactProtocol) WriteI64(value int64) error { + _, err := p.writeVarint64(p.int64ToZigzag(value)) + return NewTProtocolException(err) +} + +// Write a double to the wire as 8 bytes. +func (p *TCompactProtocol) WriteDouble(value float64) error { + buf := p.buffer[0:8] + binary.LittleEndian.PutUint64(buf, math.Float64bits(value)) + _, err := p.trans.Write(buf) + return NewTProtocolException(err) +} + +// Write a string to the wire with a varint size preceding. +func (p *TCompactProtocol) WriteString(value string) error { + _, e := p.writeVarint32(int32(len(value))) + if e != nil { + return NewTProtocolException(e) + } + if len(value) > 0 { + } + _, e = p.trans.WriteString(value) + return e +} + +// Write a byte array, using a varint for the size. +func (p *TCompactProtocol) WriteBinary(bin []byte) error { + _, e := p.writeVarint32(int32(len(bin))) + if e != nil { + return NewTProtocolException(e) + } + if len(bin) > 0 { + _, e = p.trans.Write(bin) + return NewTProtocolException(e) + } + return nil +} + +// +// Reading methods. +// + +// Read a message header. +func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + + protocolId, err := p.readByteDirect() + if err != nil { + return + } + + if protocolId != COMPACT_PROTOCOL_ID { + e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId) + return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e) + } + + versionAndType, err := p.readByteDirect() + if err != nil { + return + } + + version := versionAndType & COMPACT_VERSION_MASK + typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS) + if version != COMPACT_VERSION { + e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version) + err = NewTProtocolExceptionWithType(BAD_VERSION, e) + return + } + seqId, e := p.readVarint32() + if e != nil { + err = NewTProtocolException(e) + return + } + name, err = p.ReadString() + return +} + +func (p *TCompactProtocol) ReadMessageEnd() error { return nil } + +// Read a struct begin. There's nothing on the wire for this, but it is our +// opportunity to push a new struct begin marker onto the field stack. +func (p *TCompactProtocol) ReadStructBegin() (name string, err error) { + p.lastField = append(p.lastField, p.lastFieldId) + p.lastFieldId = 0 + return +} + +// Doesn't actually consume any wire data, just removes the last field for +// this struct from the field stack. +func (p *TCompactProtocol) ReadStructEnd() error { + // consume the last field we read off the wire. + p.lastFieldId = p.lastField[len(p.lastField)-1] + p.lastField = p.lastField[:len(p.lastField)-1] + return nil +} + +// Read a field header off the wire. +func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { + t, err := p.readByteDirect() + if err != nil { + return + } + + // if it's a stop, then we can return immediately, as the struct is over. + if (t & 0x0f) == STOP { + return "", STOP, 0, nil + } + + // mask off the 4 MSB of the type header. it could contain a field id delta. + modifier := int16((t & 0xf0) >> 4) + if modifier == 0 { + // not a delta. look ahead for the zigzag varint field id. + id, err = p.ReadI16() + if err != nil { + return + } + } else { + // has a delta. add the delta to the last read field id. + id = int16(p.lastFieldId) + modifier + } + typeId, e := p.getTType(tCompactType(t & 0x0f)) + if e != nil { + err = NewTProtocolException(e) + return + } + + // if this happens to be a boolean field, the value is encoded in the type + if p.isBoolType(t) { + // save the boolean value in a special instance variable. + p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE) + p.boolValueIsNotNull = true + } + + // push the new field onto the field stack so we can keep the deltas going. + p.lastFieldId = int(id) + return +} + +func (p *TCompactProtocol) ReadFieldEnd() error { return nil } + +// Read a map header off the wire. If the size is zero, skip reading the key +// and value type. This means that 0-length maps will yield TMaps without the +// "correct" types. +func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { + size32, e := p.readVarint32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + + keyAndValueType := byte(STOP) + if size != 0 { + keyAndValueType, err = p.readByteDirect() + if err != nil { + return + } + } + keyType, _ = p.getTType(tCompactType(keyAndValueType >> 4)) + valueType, _ = p.getTType(tCompactType(keyAndValueType & 0xf)) + return +} + +func (p *TCompactProtocol) ReadMapEnd() error { return nil } + +// Read a list header off the wire. If the list size is 0-14, the size will +// be packed into the element type header. If it's a longer list, the 4 MSB +// of the element type header will be 0xF, and a varint will follow with the +// true size. +func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) { + size_and_type, err := p.readByteDirect() + if err != nil { + return + } + size = int((size_and_type >> 4) & 0x0f) + if size == 15 { + size2, e := p.readVarint32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size2 < 0 { + err = invalidDataLength + return + } + size = int(size2) + } + elemType, e := p.getTType(tCompactType(size_and_type)) + if e != nil { + err = NewTProtocolException(e) + return + } + return +} + +func (p *TCompactProtocol) ReadListEnd() error { return nil } + +// Read a set header off the wire. If the set size is 0-14, the size will +// be packed into the element type header. If it's a longer set, the 4 MSB +// of the element type header will be 0xF, and a varint will follow with the +// true size. +func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) { + return p.ReadListBegin() +} + +func (p *TCompactProtocol) ReadSetEnd() error { return nil } + +// Read a boolean off the wire. If this is a boolean field, the value should +// already have been read during readFieldBegin, so we'll just consume the +// pre-stored value. Otherwise, read a byte. +func (p *TCompactProtocol) ReadBool() (value bool, err error) { + if p.boolValueIsNotNull { + p.boolValueIsNotNull = false + return p.boolValue, nil + } + v, err := p.readByteDirect() + return v == COMPACT_BOOLEAN_TRUE, err +} + +// Read a single byte off the wire. Nothing interesting here. +func (p *TCompactProtocol) ReadByte() (int8, error) { + v, err := p.readByteDirect() + if err != nil { + return 0, NewTProtocolException(err) + } + return int8(v), err +} + +// Read an i16 from the wire as a zigzag varint. +func (p *TCompactProtocol) ReadI16() (value int16, err error) { + v, err := p.ReadI32() + return int16(v), err +} + +// Read an i32 from the wire as a zigzag varint. +func (p *TCompactProtocol) ReadI32() (value int32, err error) { + v, e := p.readVarint32() + if e != nil { + return 0, NewTProtocolException(e) + } + value = p.zigzagToInt32(v) + return value, nil +} + +// Read an i64 from the wire as a zigzag varint. +func (p *TCompactProtocol) ReadI64() (value int64, err error) { + v, e := p.readVarint64() + if e != nil { + return 0, NewTProtocolException(e) + } + value = p.zigzagToInt64(v) + return value, nil +} + +// No magic here - just read a double off the wire. +func (p *TCompactProtocol) ReadDouble() (value float64, err error) { + longBits := p.buffer[0:8] + _, e := io.ReadFull(p.trans, longBits) + if e != nil { + return 0.0, NewTProtocolException(e) + } + return math.Float64frombits(p.bytesToUint64(longBits)), nil +} + +// Reads a []byte (via readBinary), and then UTF-8 decodes it. +func (p *TCompactProtocol) ReadString() (value string, err error) { + length, e := p.readVarint32() + if e != nil { + return "", NewTProtocolException(e) + } + if length < 0 { + return "", invalidDataLength + } + + if length == 0 { + return "", nil + } + var buf []byte + if length <= int32(len(p.buffer)) { + buf = p.buffer[0:length] + } else { + buf = make([]byte, length) + } + _, e = io.ReadFull(p.trans, buf) + return string(buf), NewTProtocolException(e) +} + +// Read a []byte from the wire. +func (p *TCompactProtocol) ReadBinary() (value []byte, err error) { + length, e := p.readVarint32() + if e != nil { + return nil, NewTProtocolException(e) + } + if length == 0 { + return []byte{}, nil + } + if length < 0 { + return nil, invalidDataLength + } + + buf := make([]byte, length) + _, e = io.ReadFull(p.trans, buf) + return buf, NewTProtocolException(e) +} + +func (p *TCompactProtocol) Flush(ctx context.Context) (err error) { + return NewTProtocolException(p.trans.Flush(ctx)) +} + +func (p *TCompactProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TCompactProtocol) Transport() TTransport { + return p.origTransport +} + +// +// Internal writing methods +// + +// Abstract method for writing the start of lists and sets. List and sets on +// the wire differ only by the type indicator. +func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) { + if size <= 14 { + return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType)))) + } + err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType))) + if err != nil { + return 0, err + } + m, err := p.writeVarint32(int32(size)) + return 1 + m, err +} + +// Write an i32 as a varint. Results in 1-5 bytes on the wire. +// TODO(pomack): make a permanent buffer like writeVarint64? +func (p *TCompactProtocol) writeVarint32(n int32) (int, error) { + i32buf := p.buffer[0:5] + idx := 0 + for { + if (n & ^0x7F) == 0 { + i32buf[idx] = byte(n) + idx++ + // p.writeByteDirect(byte(n)); + break + // return; + } else { + i32buf[idx] = byte((n & 0x7F) | 0x80) + idx++ + // p.writeByteDirect(byte(((n & 0x7F) | 0x80))); + u := uint32(n) + n = int32(u >> 7) + } + } + return p.trans.Write(i32buf[0:idx]) +} + +// Write an i64 as a varint. Results in 1-10 bytes on the wire. +func (p *TCompactProtocol) writeVarint64(n int64) (int, error) { + varint64out := p.buffer[0:10] + idx := 0 + for { + if (n & ^0x7F) == 0 { + varint64out[idx] = byte(n) + idx++ + break + } else { + varint64out[idx] = byte((n & 0x7F) | 0x80) + idx++ + u := uint64(n) + n = int64(u >> 7) + } + } + return p.trans.Write(varint64out[0:idx]) +} + +// Convert l into a zigzag long. This allows negative numbers to be +// represented compactly as a varint. +func (p *TCompactProtocol) int64ToZigzag(l int64) int64 { + return (l << 1) ^ (l >> 63) +} + +// Convert l into a zigzag long. This allows negative numbers to be +// represented compactly as a varint. +func (p *TCompactProtocol) int32ToZigzag(n int32) int32 { + return (n << 1) ^ (n >> 31) +} + +func (p *TCompactProtocol) fixedUint64ToBytes(n uint64, buf []byte) { + binary.LittleEndian.PutUint64(buf, n) +} + +func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) { + binary.LittleEndian.PutUint64(buf, uint64(n)) +} + +// Writes a byte without any possibility of all that field header nonsense. +// Used internally by other writing methods that know they need to write a byte. +func (p *TCompactProtocol) writeByteDirect(b byte) error { + return p.trans.WriteByte(b) +} + +// Writes a byte without any possibility of all that field header nonsense. +func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) { + return 1, p.writeByteDirect(byte(n)) +} + +// +// Internal reading methods +// + +// Read an i32 from the wire as a varint. The MSB of each byte is set +// if there is another byte to follow. This can read up to 5 bytes. +func (p *TCompactProtocol) readVarint32() (int32, error) { + // if the wire contains the right stuff, this will just truncate the i64 we + // read and get us the right sign. + v, err := p.readVarint64() + return int32(v), err +} + +// Read an i64 from the wire as a proper varint. The MSB of each byte is set +// if there is another byte to follow. This can read up to 10 bytes. +func (p *TCompactProtocol) readVarint64() (int64, error) { + shift := uint(0) + result := int64(0) + for { + b, err := p.readByteDirect() + if err != nil { + return 0, err + } + result |= int64(b&0x7f) << shift + if (b & 0x80) != 0x80 { + break + } + shift += 7 + } + return result, nil +} + +// Read a byte, unlike ReadByte that reads Thrift-byte that is i8. +func (p *TCompactProtocol) readByteDirect() (byte, error) { + return p.trans.ReadByte() +} + +// +// encoding helpers +// + +// Convert from zigzag int to int. +func (p *TCompactProtocol) zigzagToInt32(n int32) int32 { + u := uint32(n) + return int32(u>>1) ^ -(n & 1) +} + +// Convert from zigzag long to long. +func (p *TCompactProtocol) zigzagToInt64(n int64) int64 { + u := uint64(n) + return int64(u>>1) ^ -(n & 1) +} + +// Note that it's important that the mask bytes are long literals, +// otherwise they'll default to ints, and when you shift an int left 56 bits, +// you just get a messed up int. +func (p *TCompactProtocol) bytesToInt64(b []byte) int64 { + return int64(binary.LittleEndian.Uint64(b)) +} + +// Note that it's important that the mask bytes are long literals, +// otherwise they'll default to ints, and when you shift an int left 56 bits, +// you just get a messed up int. +func (p *TCompactProtocol) bytesToUint64(b []byte) uint64 { + return binary.LittleEndian.Uint64(b) +} + +// +// type testing and converting +// + +func (p *TCompactProtocol) isBoolType(b byte) bool { + return (b&0x0f) == COMPACT_BOOLEAN_TRUE || (b&0x0f) == COMPACT_BOOLEAN_FALSE +} + +// Given a tCompactType constant, convert it to its corresponding +// TType value. +func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) { + switch byte(t) & 0x0f { + case STOP: + return STOP, nil + case COMPACT_BOOLEAN_FALSE, COMPACT_BOOLEAN_TRUE: + return BOOL, nil + case COMPACT_BYTE: + return BYTE, nil + case COMPACT_I16: + return I16, nil + case COMPACT_I32: + return I32, nil + case COMPACT_I64: + return I64, nil + case COMPACT_DOUBLE: + return DOUBLE, nil + case COMPACT_BINARY: + return STRING, nil + case COMPACT_LIST: + return LIST, nil + case COMPACT_SET: + return SET, nil + case COMPACT_MAP: + return MAP, nil + case COMPACT_STRUCT: + return STRUCT, nil + } + return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f)) +} + +// Given a TType value, find the appropriate TCompactProtocol.Types constant. +func (p *TCompactProtocol) getCompactType(t TType) tCompactType { + return ttypeToCompactType[t] +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/compact_protocol_test.go b/src/jaegertracing/thrift/lib/go/thrift/compact_protocol_test.go new file mode 100644 index 000000000..65f77f2c4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/compact_protocol_test.go @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +func TestReadWriteCompactProtocol(t *testing.T) { + ReadWriteProtocolTest(t, NewTCompactProtocolFactory()) + + transports := []TTransport{ + NewTMemoryBuffer(), + NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))), + NewTFramedTransport(NewTMemoryBuffer()), + } + + zlib0, _ := NewTZlibTransport(NewTMemoryBuffer(), 0) + zlib6, _ := NewTZlibTransport(NewTMemoryBuffer(), 6) + zlib9, _ := NewTZlibTransport(NewTFramedTransport(NewTMemoryBuffer()), 9) + transports = append(transports, zlib0, zlib6, zlib9) + + for _, trans := range transports { + p := NewTCompactProtocol(trans) + ReadWriteBool(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteByte(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI16(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI32(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI64(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteDouble(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteString(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteBinary(t, p, trans) + trans.Close() + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/context.go b/src/jaegertracing/thrift/lib/go/thrift/context.go new file mode 100644 index 000000000..d15c1bcf8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/context.go @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import "context" + +var defaultCtx = context.Background() diff --git a/src/jaegertracing/thrift/lib/go/thrift/debug_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/debug_protocol.go new file mode 100644 index 000000000..57943e0f3 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/debug_protocol.go @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "log" +) + +type TDebugProtocol struct { + Delegate TProtocol + LogPrefix string +} + +type TDebugProtocolFactory struct { + Underlying TProtocolFactory + LogPrefix string +} + +func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory { + return &TDebugProtocolFactory{ + Underlying: underlying, + LogPrefix: logPrefix, + } +} + +func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return &TDebugProtocol{ + Delegate: t.Underlying.GetProtocol(trans), + LogPrefix: t.LogPrefix, + } +} + +func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { + err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid) + log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err) + return err +} +func (tdp *TDebugProtocol) WriteMessageEnd() error { + err := tdp.Delegate.WriteMessageEnd() + log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteStructBegin(name string) error { + err := tdp.Delegate.WriteStructBegin(name) + log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err) + return err +} +func (tdp *TDebugProtocol) WriteStructEnd() error { + err := tdp.Delegate.WriteStructEnd() + log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + err := tdp.Delegate.WriteFieldBegin(name, typeId, id) + log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err) + return err +} +func (tdp *TDebugProtocol) WriteFieldEnd() error { + err := tdp.Delegate.WriteFieldEnd() + log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteFieldStop() error { + err := tdp.Delegate.WriteFieldStop() + log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + err := tdp.Delegate.WriteMapBegin(keyType, valueType, size) + log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err) + return err +} +func (tdp *TDebugProtocol) WriteMapEnd() error { + err := tdp.Delegate.WriteMapEnd() + log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error { + err := tdp.Delegate.WriteListBegin(elemType, size) + log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) + return err +} +func (tdp *TDebugProtocol) WriteListEnd() error { + err := tdp.Delegate.WriteListEnd() + log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error { + err := tdp.Delegate.WriteSetBegin(elemType, size) + log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) + return err +} +func (tdp *TDebugProtocol) WriteSetEnd() error { + err := tdp.Delegate.WriteSetEnd() + log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteBool(value bool) error { + err := tdp.Delegate.WriteBool(value) + log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteByte(value int8) error { + err := tdp.Delegate.WriteByte(value) + log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteI16(value int16) error { + err := tdp.Delegate.WriteI16(value) + log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteI32(value int32) error { + err := tdp.Delegate.WriteI32(value) + log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteI64(value int64) error { + err := tdp.Delegate.WriteI64(value) + log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteDouble(value float64) error { + err := tdp.Delegate.WriteDouble(value) + log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteString(value string) error { + err := tdp.Delegate.WriteString(value) + log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteBinary(value []byte) error { + err := tdp.Delegate.WriteBinary(value) + log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} + +func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { + name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin() + log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err) + return +} +func (tdp *TDebugProtocol) ReadMessageEnd() (err error) { + err = tdp.Delegate.ReadMessageEnd() + log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) { + name, err = tdp.Delegate.ReadStructBegin() + log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err) + return +} +func (tdp *TDebugProtocol) ReadStructEnd() (err error) { + err = tdp.Delegate.ReadStructEnd() + log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { + name, typeId, id, err = tdp.Delegate.ReadFieldBegin() + log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err) + return +} +func (tdp *TDebugProtocol) ReadFieldEnd() (err error) { + err = tdp.Delegate.ReadFieldEnd() + log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { + keyType, valueType, size, err = tdp.Delegate.ReadMapBegin() + log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err) + return +} +func (tdp *TDebugProtocol) ReadMapEnd() (err error) { + err = tdp.Delegate.ReadMapEnd() + log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadListBegin() + log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err) + return +} +func (tdp *TDebugProtocol) ReadListEnd() (err error) { + err = tdp.Delegate.ReadListEnd() + log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadSetBegin() + log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err) + return +} +func (tdp *TDebugProtocol) ReadSetEnd() (err error) { + err = tdp.Delegate.ReadSetEnd() + log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadBool() (value bool, err error) { + value, err = tdp.Delegate.ReadBool() + log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadByte() (value int8, err error) { + value, err = tdp.Delegate.ReadByte() + log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadI16() (value int16, err error) { + value, err = tdp.Delegate.ReadI16() + log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadI32() (value int32, err error) { + value, err = tdp.Delegate.ReadI32() + log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadI64() (value int64, err error) { + value, err = tdp.Delegate.ReadI64() + log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) { + value, err = tdp.Delegate.ReadDouble() + log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadString() (value string, err error) { + value, err = tdp.Delegate.ReadString() + log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) { + value, err = tdp.Delegate.ReadBinary() + log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) { + err = tdp.Delegate.Skip(fieldType) + log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err) + return +} +func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) { + err = tdp.Delegate.Flush(ctx) + log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err) + return +} + +func (tdp *TDebugProtocol) Transport() TTransport { + return tdp.Delegate.Transport() +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/deserializer.go b/src/jaegertracing/thrift/lib/go/thrift/deserializer.go new file mode 100644 index 000000000..91a0983a4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/deserializer.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +type TDeserializer struct { + Transport TTransport + Protocol TProtocol +} + +func NewTDeserializer() *TDeserializer { + var transport TTransport + transport = NewTMemoryBufferLen(1024) + + protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport) + + return &TDeserializer{ + transport, + protocol} +} + +func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) { + err = nil + if _, err = t.Transport.Write([]byte(s)); err != nil { + return + } + if err = msg.Read(t.Protocol); err != nil { + return + } + return +} + +func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) { + err = nil + if _, err = t.Transport.Write(b); err != nil { + return + } + if err = msg.Read(t.Protocol); err != nil { + return + } + return +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/exception.go b/src/jaegertracing/thrift/lib/go/thrift/exception.go new file mode 100644 index 000000000..ea8d6f661 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/exception.go @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "errors" +) + +// Generic Thrift exception +type TException interface { + error +} + +// Prepends additional information to an error without losing the Thrift exception interface +func PrependError(prepend string, err error) error { + if t, ok := err.(TTransportException); ok { + return NewTTransportException(t.TypeId(), prepend+t.Error()) + } + if t, ok := err.(TProtocolException); ok { + return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error())) + } + if t, ok := err.(TApplicationException); ok { + return NewTApplicationException(t.TypeId(), prepend+t.Error()) + } + + return errors.New(prepend + err.Error()) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/exception_test.go b/src/jaegertracing/thrift/lib/go/thrift/exception_test.go new file mode 100644 index 000000000..71f5e2c7e --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/exception_test.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "errors" + "testing" +) + +func TestPrependError(t *testing.T) { + err := NewTApplicationException(INTERNAL_ERROR, "original error") + err2, ok := PrependError("Prepend: ", err).(TApplicationException) + if !ok { + t.Fatal("Couldn't cast error TApplicationException") + } + if err2.Error() != "Prepend: original error" { + t.Fatal("Unexpected error string") + } + if err2.TypeId() != INTERNAL_ERROR { + t.Fatal("Unexpected type error") + } + + err3 := NewTProtocolExceptionWithType(INVALID_DATA, errors.New("original error")) + err4, ok := PrependError("Prepend: ", err3).(TProtocolException) + if !ok { + t.Fatal("Couldn't cast error TProtocolException") + } + if err4.Error() != "Prepend: original error" { + t.Fatal("Unexpected error string") + } + if err4.TypeId() != INVALID_DATA { + t.Fatal("Unexpected type error") + } + + err5 := NewTTransportException(TIMED_OUT, "original error") + err6, ok := PrependError("Prepend: ", err5).(TTransportException) + if !ok { + t.Fatal("Couldn't cast error TTransportException") + } + if err6.Error() != "Prepend: original error" { + t.Fatal("Unexpected error string") + } + if err6.TypeId() != TIMED_OUT { + t.Fatal("Unexpected type error") + } + + err7 := errors.New("original error") + err8 := PrependError("Prepend: ", err7) + if err8.Error() != "Prepend: original error" { + t.Fatal("Unexpected error string") + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/field.go b/src/jaegertracing/thrift/lib/go/thrift/field.go new file mode 100644 index 000000000..9d6652550 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/field.go @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Helper class that encapsulates field metadata. +type field struct { + name string + typeId TType + id int +} + +func newField(n string, t TType, i int) *field { + return &field{name: n, typeId: t, id: i} +} + +func (p *field) Name() string { + if p == nil { + return "" + } + return p.name +} + +func (p *field) TypeId() TType { + if p == nil { + return TType(VOID) + } + return p.typeId +} + +func (p *field) Id() int { + if p == nil { + return -1 + } + return p.id +} + +func (p *field) String() string { + if p == nil { + return "<nil>" + } + return "<TField name:'" + p.name + "' type:" + string(p.typeId) + " field-id:" + string(p.id) + ">" +} + +var ANONYMOUS_FIELD *field + +type fieldSlice []field + +func (p fieldSlice) Len() int { + return len(p) +} + +func (p fieldSlice) Less(i, j int) bool { + return p[i].Id() < p[j].Id() +} + +func (p fieldSlice) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +func init() { + ANONYMOUS_FIELD = newField("", STOP, 0) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/framed_transport.go b/src/jaegertracing/thrift/lib/go/thrift/framed_transport.go new file mode 100644 index 000000000..34275b5f4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/framed_transport.go @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "fmt" + "io" +) + +const DEFAULT_MAX_LENGTH = 16384000 + +type TFramedTransport struct { + transport TTransport + buf bytes.Buffer + reader *bufio.Reader + frameSize uint32 //Current remaining size of the frame. if ==0 read next frame header + buffer [4]byte + maxLength uint32 +} + +type tFramedTransportFactory struct { + factory TTransportFactory + maxLength uint32 +} + +func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory { + return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH} +} + +func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory { + return &tFramedTransportFactory{factory: factory, maxLength: maxLength} +} + +func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) { + tt, err := p.factory.GetTransport(base) + if err != nil { + return nil, err + } + return NewTFramedTransportMaxLength(tt, p.maxLength), nil +} + +func NewTFramedTransport(transport TTransport) *TFramedTransport { + return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH} +} + +func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport { + return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength} +} + +func (p *TFramedTransport) Open() error { + return p.transport.Open() +} + +func (p *TFramedTransport) IsOpen() bool { + return p.transport.IsOpen() +} + +func (p *TFramedTransport) Close() error { + return p.transport.Close() +} + +func (p *TFramedTransport) Read(buf []byte) (l int, err error) { + if p.frameSize == 0 { + p.frameSize, err = p.readFrameHeader() + if err != nil { + return + } + } + if p.frameSize < uint32(len(buf)) { + frameSize := p.frameSize + tmp := make([]byte, p.frameSize) + l, err = p.Read(tmp) + copy(buf, tmp) + if err == nil { + // Note: It's important to only return an error when l + // is zero. + // In io.Reader.Read interface, it's perfectly fine to + // return partial data and nil error, which means + // "This is all the data we have right now without + // blocking. If you need the full data, call Read again + // or use io.ReadFull instead". + // Returning partial data with an error actually means + // there's no more data after the partial data just + // returned, which is not true in this case + // (it might be that the other end just haven't written + // them yet). + if l == 0 { + err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf))) + } + return + } + } + got, err := p.reader.Read(buf) + p.frameSize = p.frameSize - uint32(got) + //sanity check + if p.frameSize < 0 { + return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size") + } + return got, NewTTransportExceptionFromError(err) +} + +func (p *TFramedTransport) ReadByte() (c byte, err error) { + if p.frameSize == 0 { + p.frameSize, err = p.readFrameHeader() + if err != nil { + return + } + } + if p.frameSize < 1 { + return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", p.frameSize, 1)) + } + c, err = p.reader.ReadByte() + if err == nil { + p.frameSize-- + } + return +} + +func (p *TFramedTransport) Write(buf []byte) (int, error) { + n, err := p.buf.Write(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TFramedTransport) WriteByte(c byte) error { + return p.buf.WriteByte(c) +} + +func (p *TFramedTransport) WriteString(s string) (n int, err error) { + return p.buf.WriteString(s) +} + +func (p *TFramedTransport) Flush(ctx context.Context) error { + size := p.buf.Len() + buf := p.buffer[:4] + binary.BigEndian.PutUint32(buf, uint32(size)) + _, err := p.transport.Write(buf) + if err != nil { + p.buf.Truncate(0) + return NewTTransportExceptionFromError(err) + } + if size > 0 { + if n, err := p.buf.WriteTo(p.transport); err != nil { + print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n") + p.buf.Truncate(0) + return NewTTransportExceptionFromError(err) + } + } + err = p.transport.Flush(ctx) + return NewTTransportExceptionFromError(err) +} + +func (p *TFramedTransport) readFrameHeader() (uint32, error) { + buf := p.buffer[:4] + if _, err := io.ReadFull(p.reader, buf); err != nil { + return 0, err + } + size := binary.BigEndian.Uint32(buf) + if size < 0 || size > p.maxLength { + return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size)) + } + return size, nil +} + +func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) { + return uint64(p.frameSize) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/framed_transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/framed_transport_test.go new file mode 100644 index 000000000..8f683ef30 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/framed_transport_test.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestFramedTransport(t *testing.T) { + trans := NewTFramedTransport(NewTMemoryBuffer()) + TransportTest(t, trans, trans) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/header_context.go b/src/jaegertracing/thrift/lib/go/thrift/header_context.go new file mode 100644 index 000000000..21e880d66 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/header_context.go @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" +) + +// See https://godoc.org/context#WithValue on why do we need the unexported typedefs. +type ( + headerKey string + headerKeyList int +) + +// Values for headerKeyList. +const ( + headerKeyListRead headerKeyList = iota + headerKeyListWrite +) + +// SetHeader sets a header in the context. +func SetHeader(ctx context.Context, key, value string) context.Context { + return context.WithValue( + ctx, + headerKey(key), + value, + ) +} + +// GetHeader returns a value of the given header from the context. +func GetHeader(ctx context.Context, key string) (value string, ok bool) { + if v := ctx.Value(headerKey(key)); v != nil { + value, ok = v.(string) + } + return +} + +// SetReadHeaderList sets the key list of read THeaders in the context. +func SetReadHeaderList(ctx context.Context, keys []string) context.Context { + return context.WithValue( + ctx, + headerKeyListRead, + keys, + ) +} + +// GetReadHeaderList returns the key list of read THeaders from the context. +func GetReadHeaderList(ctx context.Context) []string { + if v := ctx.Value(headerKeyListRead); v != nil { + if value, ok := v.([]string); ok { + return value + } + } + return nil +} + +// SetWriteHeaderList sets the key list of THeaders to write in the context. +func SetWriteHeaderList(ctx context.Context, keys []string) context.Context { + return context.WithValue( + ctx, + headerKeyListWrite, + keys, + ) +} + +// GetWriteHeaderList returns the key list of THeaders to write from the context. +func GetWriteHeaderList(ctx context.Context) []string { + if v := ctx.Value(headerKeyListWrite); v != nil { + if value, ok := v.([]string); ok { + return value + } + } + return nil +} + +// AddReadTHeaderToContext adds the whole THeader headers into context. +func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) context.Context { + keys := make([]string, 0, len(headers)) + for key, value := range headers { + ctx = SetHeader(ctx, key, value) + keys = append(keys, key) + } + return SetReadHeaderList(ctx, keys) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/header_context_test.go b/src/jaegertracing/thrift/lib/go/thrift/header_context_test.go new file mode 100644 index 000000000..a1ea2d093 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/header_context_test.go @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "reflect" + "testing" +) + +func TestSetGetHeader(t *testing.T) { + const ( + key = "foo" + value = "bar" + ) + ctx := context.Background() + + ctx = SetHeader(ctx, key, value) + + checkGet := func(t *testing.T, ctx context.Context) { + t.Helper() + got, ok := GetHeader(ctx, key) + if !ok { + t.Fatalf("Cannot get header %q back after setting it.", key) + } + if got != value { + t.Fatalf("Header value expected %q, got %q instead", value, got) + } + } + + checkGet(t, ctx) + + t.Run( + "NoConflicts", + func(t *testing.T) { + type otherType string + const otherValue = "bar2" + + ctx = context.WithValue(ctx, otherType(key), otherValue) + checkGet(t, ctx) + }, + ) + + t.Run( + "GetHeaderOnNonExistKey", + func(t *testing.T) { + const otherKey = "foo2" + + if _, ok := GetHeader(ctx, otherKey); ok { + t.Errorf("GetHeader returned ok on non-existing key %q", otherKey) + } + }, + ) +} + +func TestReadKeyList(t *testing.T) { + headers := THeaderMap{ + "key1": "value1", + "key2": "value2", + } + ctx := context.Background() + + ctx = AddReadTHeaderToContext(ctx, headers) + + got := make(THeaderMap) + keys := GetReadHeaderList(ctx) + t.Logf("keys: %+v", keys) + for _, key := range keys { + value, ok := GetHeader(ctx, key) + if ok { + got[key] = value + } else { + t.Errorf("Cannot get key %q from context", key) + } + } + + if !reflect.DeepEqual(headers, got) { + t.Errorf("Expected header map %+v, got %+v", headers, got) + } + + writtenKeys := GetWriteHeaderList(ctx) + if len(writtenKeys) > 0 { + t.Errorf( + "Expected empty GetWriteHeaderList() result, got %+v", + writtenKeys, + ) + } +} + +func TestWriteKeyList(t *testing.T) { + keys := []string{ + "key1", + "key2", + } + ctx := context.Background() + + ctx = SetWriteHeaderList(ctx, keys) + got := GetWriteHeaderList(ctx) + + if !reflect.DeepEqual(keys, got) { + t.Errorf("Expected header keys %+v, got %+v", keys, got) + } + + readKeys := GetReadHeaderList(ctx) + if len(readKeys) > 0 { + t.Errorf( + "Expected empty GetReadHeaderList() result, got %+v", + readKeys, + ) + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/header_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/header_protocol.go new file mode 100644 index 000000000..46205b28b --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/header_protocol.go @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" +) + +// THeaderProtocol is a thrift protocol that implements THeader: +// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md +// +// It supports either binary or compact protocol as the wrapped protocol. +// +// Most of the THeader handlings are happening inside THeaderTransport. +type THeaderProtocol struct { + transport *THeaderTransport + + // Will be initialized on first read/write. + protocol TProtocol +} + +// NewTHeaderProtocol creates a new THeaderProtocol from the underlying +// transport. The passed in transport will be wrapped with THeaderTransport. +// +// Note that THeaderTransport handles frame and zlib by itself, +// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket), +// instead of rich transports like TZlibTransport or TFramedTransport. +func NewTHeaderProtocol(trans TTransport) *THeaderProtocol { + t := NewTHeaderTransport(trans) + p, _ := THeaderProtocolDefault.GetProtocol(t) + return &THeaderProtocol{ + transport: t, + protocol: p, + } +} + +type tHeaderProtocolFactory struct{} + +func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTHeaderProtocol(trans) +} + +// NewTHeaderProtocolFactory creates a factory for THeader. +// +// It's a wrapper for NewTHeaderProtocol +func NewTHeaderProtocolFactory() TProtocolFactory { + return tHeaderProtocolFactory{} +} + +// Transport returns the underlying transport. +// +// It's guaranteed to be of type *THeaderTransport. +func (p *THeaderProtocol) Transport() TTransport { + return p.transport +} + +// GetReadHeaders returns the THeaderMap read from transport. +func (p *THeaderProtocol) GetReadHeaders() THeaderMap { + return p.transport.GetReadHeaders() +} + +// SetWriteHeader sets a header for write. +func (p *THeaderProtocol) SetWriteHeader(key, value string) { + p.transport.SetWriteHeader(key, value) +} + +// ClearWriteHeaders clears all write headers previously set. +func (p *THeaderProtocol) ClearWriteHeaders() { + p.transport.ClearWriteHeaders() +} + +// AddTransform add a transform for writing. +func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error { + return p.transport.AddTransform(transform) +} + +func (p *THeaderProtocol) Flush(ctx context.Context) error { + return p.transport.Flush(ctx) +} + +func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error { + newProto, err := p.transport.Protocol().GetProtocol(p.transport) + if err != nil { + return err + } + p.protocol = newProto + p.transport.SequenceID = seqID + return p.protocol.WriteMessageBegin(name, typeID, seqID) +} + +func (p *THeaderProtocol) WriteMessageEnd() error { + if err := p.protocol.WriteMessageEnd(); err != nil { + return err + } + return p.transport.Flush(context.Background()) +} + +func (p *THeaderProtocol) WriteStructBegin(name string) error { + return p.protocol.WriteStructBegin(name) +} + +func (p *THeaderProtocol) WriteStructEnd() error { + return p.protocol.WriteStructEnd() +} + +func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error { + return p.protocol.WriteFieldBegin(name, typeID, id) +} + +func (p *THeaderProtocol) WriteFieldEnd() error { + return p.protocol.WriteFieldEnd() +} + +func (p *THeaderProtocol) WriteFieldStop() error { + return p.protocol.WriteFieldStop() +} + +func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + return p.protocol.WriteMapBegin(keyType, valueType, size) +} + +func (p *THeaderProtocol) WriteMapEnd() error { + return p.protocol.WriteMapEnd() +} + +func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error { + return p.protocol.WriteListBegin(elemType, size) +} + +func (p *THeaderProtocol) WriteListEnd() error { + return p.protocol.WriteListEnd() +} + +func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error { + return p.protocol.WriteSetBegin(elemType, size) +} + +func (p *THeaderProtocol) WriteSetEnd() error { + return p.protocol.WriteSetEnd() +} + +func (p *THeaderProtocol) WriteBool(value bool) error { + return p.protocol.WriteBool(value) +} + +func (p *THeaderProtocol) WriteByte(value int8) error { + return p.protocol.WriteByte(value) +} + +func (p *THeaderProtocol) WriteI16(value int16) error { + return p.protocol.WriteI16(value) +} + +func (p *THeaderProtocol) WriteI32(value int32) error { + return p.protocol.WriteI32(value) +} + +func (p *THeaderProtocol) WriteI64(value int64) error { + return p.protocol.WriteI64(value) +} + +func (p *THeaderProtocol) WriteDouble(value float64) error { + return p.protocol.WriteDouble(value) +} + +func (p *THeaderProtocol) WriteString(value string) error { + return p.protocol.WriteString(value) +} + +func (p *THeaderProtocol) WriteBinary(value []byte) error { + return p.protocol.WriteBinary(value) +} + +// ReadFrame calls underlying THeaderTransport's ReadFrame function. +func (p *THeaderProtocol) ReadFrame() error { + return p.transport.ReadFrame() +} + +func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) { + if err = p.transport.ReadFrame(); err != nil { + return + } + + var newProto TProtocol + newProto, err = p.transport.Protocol().GetProtocol(p.transport) + if err != nil { + tAppExc, ok := err.(TApplicationException) + if !ok { + return + } + if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil { + return + } + if e := tAppExc.Write(p.protocol); e != nil { + return + } + if e := p.protocol.WriteMessageEnd(); e != nil { + return + } + if e := p.transport.Flush(context.Background()); e != nil { + return + } + return + } + p.protocol = newProto + + return p.protocol.ReadMessageBegin() +} + +func (p *THeaderProtocol) ReadMessageEnd() error { + return p.protocol.ReadMessageEnd() +} + +func (p *THeaderProtocol) ReadStructBegin() (name string, err error) { + return p.protocol.ReadStructBegin() +} + +func (p *THeaderProtocol) ReadStructEnd() error { + return p.protocol.ReadStructEnd() +} + +func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) { + return p.protocol.ReadFieldBegin() +} + +func (p *THeaderProtocol) ReadFieldEnd() error { + return p.protocol.ReadFieldEnd() +} + +func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { + return p.protocol.ReadMapBegin() +} + +func (p *THeaderProtocol) ReadMapEnd() error { + return p.protocol.ReadMapEnd() +} + +func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) { + return p.protocol.ReadListBegin() +} + +func (p *THeaderProtocol) ReadListEnd() error { + return p.protocol.ReadListEnd() +} + +func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) { + return p.protocol.ReadSetBegin() +} + +func (p *THeaderProtocol) ReadSetEnd() error { + return p.protocol.ReadSetEnd() +} + +func (p *THeaderProtocol) ReadBool() (value bool, err error) { + return p.protocol.ReadBool() +} + +func (p *THeaderProtocol) ReadByte() (value int8, err error) { + return p.protocol.ReadByte() +} + +func (p *THeaderProtocol) ReadI16() (value int16, err error) { + return p.protocol.ReadI16() +} + +func (p *THeaderProtocol) ReadI32() (value int32, err error) { + return p.protocol.ReadI32() +} + +func (p *THeaderProtocol) ReadI64() (value int64, err error) { + return p.protocol.ReadI64() +} + +func (p *THeaderProtocol) ReadDouble() (value float64, err error) { + return p.protocol.ReadDouble() +} + +func (p *THeaderProtocol) ReadString() (value string, err error) { + return p.protocol.ReadString() +} + +func (p *THeaderProtocol) ReadBinary() (value []byte, err error) { + return p.protocol.ReadBinary() +} + +func (p *THeaderProtocol) Skip(fieldType TType) error { + return p.protocol.Skip(fieldType) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/header_protocol_test.go b/src/jaegertracing/thrift/lib/go/thrift/header_protocol_test.go new file mode 100644 index 000000000..9b6019bcf --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/header_protocol_test.go @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestReadWriteHeaderProtocol(t *testing.T) { + ReadWriteProtocolTest(t, NewTHeaderProtocolFactory()) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/header_transport.go b/src/jaegertracing/thrift/lib/go/thrift/header_transport.go new file mode 100644 index 000000000..5343ccb46 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/header_transport.go @@ -0,0 +1,723 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "bytes" + "compress/zlib" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" +) + +// Size in bytes for 32-bit ints. +const size32 = 4 + +type headerMeta struct { + MagicFlags uint32 + SequenceID int32 + HeaderLength uint16 +} + +const headerMetaSize = 10 + +type clientType int + +const ( + clientUnknown clientType = iota + clientHeaders + clientFramedBinary + clientUnframedBinary + clientFramedCompact + clientUnframedCompact +) + +// Constants defined in THeader format: +// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md +const ( + THeaderHeaderMagic uint32 = 0x0fff0000 + THeaderHeaderMask uint32 = 0xffff0000 + THeaderFlagsMask uint32 = 0x0000ffff + THeaderMaxFrameSize uint32 = 0x3fffffff +) + +// THeaderMap is the type of the header map in THeader transport. +type THeaderMap map[string]string + +// THeaderProtocolID is the wrapped protocol id used in THeader. +type THeaderProtocolID int32 + +// Supported THeaderProtocolID values. +const ( + THeaderProtocolBinary THeaderProtocolID = 0x00 + THeaderProtocolCompact THeaderProtocolID = 0x02 + THeaderProtocolDefault = THeaderProtocolBinary +) + +// GetProtocol gets the corresponding TProtocol from the wrapped protocol id. +func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) { + switch id { + default: + return nil, NewTApplicationException( + INVALID_PROTOCOL, + fmt.Sprintf("THeader protocol id %d not supported", id), + ) + case THeaderProtocolBinary: + return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil + case THeaderProtocolCompact: + return NewTCompactProtocol(trans), nil + } +} + +// THeaderTransformID defines the numeric id of the transform used. +type THeaderTransformID int32 + +// THeaderTransformID values +const ( + TransformNone THeaderTransformID = iota // 0, no special handling + TransformZlib // 1, zlib + // Rest of the values are not currently supported, namely HMAC and Snappy. +) + +var supportedTransformIDs = map[THeaderTransformID]bool{ + TransformNone: true, + TransformZlib: true, +} + +// TransformReader is an io.ReadCloser that handles transforms reading. +type TransformReader struct { + io.Reader + + closers []io.Closer +} + +var _ io.ReadCloser = (*TransformReader)(nil) + +// NewTransformReaderWithCapacity initializes a TransformReader with expected +// closers capacity. +// +// If you don't know the closers capacity beforehand, just use +// +// &TransformReader{Reader: baseReader} +// +// instead would be sufficient. +func NewTransformReaderWithCapacity(baseReader io.Reader, capacity int) *TransformReader { + return &TransformReader{ + Reader: baseReader, + closers: make([]io.Closer, 0, capacity), + } +} + +// Close calls the underlying closers in appropriate order, +// stops at and returns the first error encountered. +func (tr *TransformReader) Close() error { + // Call closers in reversed order + for i := len(tr.closers) - 1; i >= 0; i-- { + if err := tr.closers[i].Close(); err != nil { + return err + } + } + return nil +} + +// AddTransform adds a transform. +func (tr *TransformReader) AddTransform(id THeaderTransformID) error { + switch id { + default: + return NewTApplicationException( + INVALID_TRANSFORM, + fmt.Sprintf("THeaderTransformID %d not supported", id), + ) + case TransformNone: + // no-op + case TransformZlib: + readCloser, err := zlib.NewReader(tr.Reader) + if err != nil { + return err + } + tr.Reader = readCloser + tr.closers = append(tr.closers, readCloser) + } + return nil +} + +// TransformWriter is an io.WriteCloser that handles transforms writing. +type TransformWriter struct { + io.Writer + + closers []io.Closer +} + +var _ io.WriteCloser = (*TransformWriter)(nil) + +// NewTransformWriter creates a new TransformWriter with base writer and transforms. +func NewTransformWriter(baseWriter io.Writer, transforms []THeaderTransformID) (io.WriteCloser, error) { + writer := &TransformWriter{ + Writer: baseWriter, + closers: make([]io.Closer, 0, len(transforms)), + } + for _, id := range transforms { + if err := writer.AddTransform(id); err != nil { + return nil, err + } + } + return writer, nil +} + +// Close calls the underlying closers in appropriate order, +// stops at and returns the first error encountered. +func (tw *TransformWriter) Close() error { + // Call closers in reversed order + for i := len(tw.closers) - 1; i >= 0; i-- { + if err := tw.closers[i].Close(); err != nil { + return err + } + } + return nil +} + +// AddTransform adds a transform. +func (tw *TransformWriter) AddTransform(id THeaderTransformID) error { + switch id { + default: + return NewTApplicationException( + INVALID_TRANSFORM, + fmt.Sprintf("THeaderTransformID %d not supported", id), + ) + case TransformNone: + // no-op + case TransformZlib: + writeCloser := zlib.NewWriter(tw.Writer) + tw.Writer = writeCloser + tw.closers = append(tw.closers, writeCloser) + } + return nil +} + +// THeaderInfoType is the type id of the info headers. +type THeaderInfoType int32 + +// Supported THeaderInfoType values. +const ( + _ THeaderInfoType = iota // Skip 0 + InfoKeyValue // 1 + // Rest of the info types are not supported. +) + +// THeaderTransport is a Transport mode that implements THeader. +// +// Note that THeaderTransport handles frame and zlib by itself, +// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket), +// instead of rich transports like TZlibTransport or TFramedTransport. +type THeaderTransport struct { + SequenceID int32 + Flags uint32 + + transport TTransport + + // THeaderMap for read and write + readHeaders THeaderMap + writeHeaders THeaderMap + + // Reading related variables. + reader *bufio.Reader + // When frame is detected, we read the frame fully into frameBuffer. + frameBuffer bytes.Buffer + // When it's non-nil, Read should read from frameReader instead of + // reader, and EOF error indicates end of frame instead of end of all + // transport. + frameReader io.ReadCloser + + // Writing related variables + writeBuffer bytes.Buffer + writeTransforms []THeaderTransformID + + clientType clientType + protocolID THeaderProtocolID + + // buffer is used in the following scenarios to avoid repetitive + // allocations, while 4 is big enough for all those scenarios: + // + // * header padding (max size 4) + // * write the frame size (size 4) + buffer [4]byte +} + +var _ TTransport = (*THeaderTransport)(nil) + +// NewTHeaderTransport creates THeaderTransport from the underlying transport. +// +// Please note that THeaderTransport handles framing and zlib by itself, +// so the underlying transport should be the raw socket transports (TSocket or TSSLSocket), +// instead of rich transports like TZlibTransport or TFramedTransport. +// +// If trans is already a *THeaderTransport, it will be returned as is. +func NewTHeaderTransport(trans TTransport) *THeaderTransport { + if ht, ok := trans.(*THeaderTransport); ok { + return ht + } + return &THeaderTransport{ + transport: trans, + reader: bufio.NewReader(trans), + writeHeaders: make(THeaderMap), + protocolID: THeaderProtocolDefault, + } +} + +// Open calls the underlying transport's Open function. +func (t *THeaderTransport) Open() error { + return t.transport.Open() +} + +// IsOpen calls the underlying transport's IsOpen function. +func (t *THeaderTransport) IsOpen() bool { + return t.transport.IsOpen() +} + +// ReadFrame tries to read the frame header, guess the client type, and handle +// unframed clients. +func (t *THeaderTransport) ReadFrame() error { + if !t.needReadFrame() { + // No need to read frame, skipping. + return nil + } + // Peek and handle the first 32 bits. + // They could either be the length field of a framed message, + // or the first bytes of an unframed message. + buf, err := t.reader.Peek(size32) + if err != nil { + return err + } + frameSize := binary.BigEndian.Uint32(buf) + if frameSize&VERSION_MASK == VERSION_1 { + t.clientType = clientUnframedBinary + return nil + } + if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION { + t.clientType = clientUnframedCompact + return nil + } + + // At this point it should be a framed message, + // sanity check on frameSize then discard the peeked part. + if frameSize > THeaderMaxFrameSize { + return NewTProtocolExceptionWithType( + SIZE_LIMIT, + errors.New("frame too large"), + ) + } + t.reader.Discard(size32) + + // Read the frame fully into frameBuffer. + _, err = io.Copy( + &t.frameBuffer, + io.LimitReader(t.reader, int64(frameSize)), + ) + if err != nil { + return err + } + t.frameReader = ioutil.NopCloser(&t.frameBuffer) + + // Peek and handle the next 32 bits. + buf = t.frameBuffer.Bytes()[:size32] + version := binary.BigEndian.Uint32(buf) + if version&THeaderHeaderMask == THeaderHeaderMagic { + t.clientType = clientHeaders + return t.parseHeaders(frameSize) + } + if version&VERSION_MASK == VERSION_1 { + t.clientType = clientFramedBinary + return nil + } + if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION { + t.clientType = clientFramedCompact + return nil + } + if err := t.endOfFrame(); err != nil { + return err + } + return NewTProtocolExceptionWithType( + NOT_IMPLEMENTED, + errors.New("unsupported client transport type"), + ) +} + +// endOfFrame does end of frame handling. +// +// It closes frameReader, and also resets frame related states. +func (t *THeaderTransport) endOfFrame() error { + defer func() { + t.frameBuffer.Reset() + t.frameReader = nil + }() + return t.frameReader.Close() +} + +func (t *THeaderTransport) parseHeaders(frameSize uint32) error { + if t.clientType != clientHeaders { + return nil + } + + var err error + var meta headerMeta + if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil { + return err + } + frameSize -= headerMetaSize + t.Flags = meta.MagicFlags & THeaderFlagsMask + t.SequenceID = meta.SequenceID + headerLength := int64(meta.HeaderLength) * 4 + if int64(frameSize) < headerLength { + return NewTProtocolExceptionWithType( + SIZE_LIMIT, + errors.New("header size is larger than the whole frame"), + ) + } + headerBuf := NewTMemoryBuffer() + _, err = io.Copy(headerBuf, io.LimitReader(&t.frameBuffer, headerLength)) + if err != nil { + return err + } + hp := NewTCompactProtocol(headerBuf) + + // At this point the header is already read into headerBuf, + // and t.frameBuffer starts from the actual payload. + protoID, err := hp.readVarint32() + if err != nil { + return err + } + t.protocolID = THeaderProtocolID(protoID) + var transformCount int32 + transformCount, err = hp.readVarint32() + if err != nil { + return err + } + if transformCount > 0 { + reader := NewTransformReaderWithCapacity( + &t.frameBuffer, + int(transformCount), + ) + t.frameReader = reader + transformIDs := make([]THeaderTransformID, transformCount) + for i := 0; i < int(transformCount); i++ { + id, err := hp.readVarint32() + if err != nil { + return err + } + transformIDs[i] = THeaderTransformID(id) + } + // The transform IDs on the wire was added based on the order of + // writing, so on the reading side we need to reverse the order. + for i := transformCount - 1; i >= 0; i-- { + id := transformIDs[i] + if err := reader.AddTransform(id); err != nil { + return err + } + } + } + + // The info part does not use the transforms yet, so it's + // important to continue using headerBuf. + headers := make(THeaderMap) + for { + infoType, err := hp.readVarint32() + if err == io.EOF { + break + } + if err != nil { + return err + } + if THeaderInfoType(infoType) == InfoKeyValue { + count, err := hp.readVarint32() + if err != nil { + return err + } + for i := 0; i < int(count); i++ { + key, err := hp.ReadString() + if err != nil { + return err + } + value, err := hp.ReadString() + if err != nil { + return err + } + headers[key] = value + } + } else { + // Skip reading info section on the first + // unsupported info type. + break + } + } + t.readHeaders = headers + + return nil +} + +func (t *THeaderTransport) needReadFrame() bool { + if t.clientType == clientUnknown { + // This is a new connection that's never read before. + return true + } + if t.isFramed() && t.frameReader == nil { + // We just finished the last frame. + return true + } + return false +} + +func (t *THeaderTransport) Read(p []byte) (read int, err error) { + err = t.ReadFrame() + if err != nil { + return + } + if t.frameReader != nil { + read, err = t.frameReader.Read(p) + if err == io.EOF { + err = t.endOfFrame() + if err != nil { + return + } + if read < len(p) { + var nextRead int + nextRead, err = t.Read(p[read:]) + read += nextRead + } + } + return + } + return t.reader.Read(p) +} + +// Write writes data to the write buffer. +// +// You need to call Flush to actually write them to the transport. +func (t *THeaderTransport) Write(p []byte) (int, error) { + return t.writeBuffer.Write(p) +} + +// Flush writes the appropriate header and the write buffer to the underlying transport. +func (t *THeaderTransport) Flush(ctx context.Context) error { + if t.writeBuffer.Len() == 0 { + return nil + } + + defer t.writeBuffer.Reset() + + switch t.clientType { + default: + fallthrough + case clientUnknown: + t.clientType = clientHeaders + fallthrough + case clientHeaders: + headers := NewTMemoryBuffer() + hp := NewTCompactProtocol(headers) + if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil { + return NewTTransportExceptionFromError(err) + } + for _, transform := range t.writeTransforms { + if _, err := hp.writeVarint32(int32(transform)); err != nil { + return NewTTransportExceptionFromError(err) + } + } + if len(t.writeHeaders) > 0 { + if _, err := hp.writeVarint32(int32(InfoKeyValue)); err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := hp.writeVarint32(int32(len(t.writeHeaders))); err != nil { + return NewTTransportExceptionFromError(err) + } + for key, value := range t.writeHeaders { + if err := hp.WriteString(key); err != nil { + return NewTTransportExceptionFromError(err) + } + if err := hp.WriteString(value); err != nil { + return NewTTransportExceptionFromError(err) + } + } + } + padding := 4 - headers.Len()%4 + if padding < 4 { + buf := t.buffer[:padding] + for i := range buf { + buf[i] = 0 + } + if _, err := headers.Write(buf); err != nil { + return NewTTransportExceptionFromError(err) + } + } + + var payload bytes.Buffer + meta := headerMeta{ + MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask, + SequenceID: t.SequenceID, + HeaderLength: uint16(headers.Len() / 4), + } + if err := binary.Write(&payload, binary.BigEndian, meta); err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := io.Copy(&payload, headers); err != nil { + return NewTTransportExceptionFromError(err) + } + + writer, err := NewTransformWriter(&payload, t.writeTransforms) + if err != nil { + return NewTTransportExceptionFromError(err) + } + if _, err := io.Copy(writer, &t.writeBuffer); err != nil { + return NewTTransportExceptionFromError(err) + } + if err := writer.Close(); err != nil { + return NewTTransportExceptionFromError(err) + } + + // First write frame length + buf := t.buffer[:size32] + binary.BigEndian.PutUint32(buf, uint32(payload.Len())) + if _, err := t.transport.Write(buf); err != nil { + return NewTTransportExceptionFromError(err) + } + // Then write the payload + if _, err := io.Copy(t.transport, &payload); err != nil { + return NewTTransportExceptionFromError(err) + } + + case clientFramedBinary, clientFramedCompact: + buf := t.buffer[:size32] + binary.BigEndian.PutUint32(buf, uint32(t.writeBuffer.Len())) + if _, err := t.transport.Write(buf); err != nil { + return NewTTransportExceptionFromError(err) + } + fallthrough + case clientUnframedBinary, clientUnframedCompact: + if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil { + return NewTTransportExceptionFromError(err) + } + } + + select { + default: + case <-ctx.Done(): + return NewTTransportExceptionFromError(ctx.Err()) + } + + return t.transport.Flush(ctx) +} + +// Close closes the transport, along with its underlying transport. +func (t *THeaderTransport) Close() error { + if err := t.Flush(context.Background()); err != nil { + return err + } + return t.transport.Close() +} + +// RemainingBytes calls underlying transport's RemainingBytes. +// +// Even in framed cases, because of all the possible compression transforms +// involved, the remaining frame size is likely to be different from the actual +// remaining readable bytes, so we don't bother to keep tracking the remaining +// frame size by ourselves and just use the underlying transport's +// RemainingBytes directly. +func (t *THeaderTransport) RemainingBytes() uint64 { + return t.transport.RemainingBytes() +} + +// GetReadHeaders returns the THeaderMap read from transport. +func (t *THeaderTransport) GetReadHeaders() THeaderMap { + return t.readHeaders +} + +// SetWriteHeader sets a header for write. +func (t *THeaderTransport) SetWriteHeader(key, value string) { + t.writeHeaders[key] = value +} + +// ClearWriteHeaders clears all write headers previously set. +func (t *THeaderTransport) ClearWriteHeaders() { + t.writeHeaders = make(THeaderMap) +} + +// AddTransform add a transform for writing. +func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error { + if !supportedTransformIDs[transform] { + return NewTProtocolExceptionWithType( + NOT_IMPLEMENTED, + fmt.Errorf("THeaderTransformID %d not supported", transform), + ) + } + t.writeTransforms = append(t.writeTransforms, transform) + return nil +} + +// Protocol returns the wrapped protocol id used in this THeaderTransport. +func (t *THeaderTransport) Protocol() THeaderProtocolID { + switch t.clientType { + default: + return t.protocolID + case clientFramedBinary, clientUnframedBinary: + return THeaderProtocolBinary + case clientFramedCompact, clientUnframedCompact: + return THeaderProtocolCompact + } +} + +func (t *THeaderTransport) isFramed() bool { + switch t.clientType { + default: + return false + case clientHeaders, clientFramedBinary, clientFramedCompact: + return true + } +} + +// THeaderTransportFactory is a TTransportFactory implementation to create +// THeaderTransport. +type THeaderTransportFactory struct { + // The underlying factory, could be nil. + Factory TTransportFactory +} + +// NewTHeaderTransportFactory creates a new *THeaderTransportFactory. +func NewTHeaderTransportFactory(factory TTransportFactory) TTransportFactory { + return &THeaderTransportFactory{ + Factory: factory, + } +} + +// GetTransport implements TTransportFactory. +func (f *THeaderTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if f.Factory != nil { + t, err := f.Factory.GetTransport(trans) + if err != nil { + return nil, err + } + return NewTHeaderTransport(t), nil + } + return NewTHeaderTransport(trans), nil +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/header_transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/header_transport_test.go new file mode 100644 index 000000000..e30476802 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/header_transport_test.go @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "io" + "io/ioutil" + "testing" +) + +func TestTHeaderHeadersReadWrite(t *testing.T) { + trans := NewTMemoryBuffer() + reader := NewTHeaderTransport(trans) + writer := NewTHeaderTransport(trans) + + const key1 = "key1" + const value1 = "value1" + const key2 = "key2" + const value2 = "value2" + const payload1 = "hello, world1\n" + const payload2 = "hello, world2\n" + + // Write + if err := writer.AddTransform(TransformZlib); err != nil { + t.Fatalf( + "writer.AddTransform(TransformZlib) returned error: %v", + err, + ) + } + // Use double zlib to make sure that we close them in the right order. + if err := writer.AddTransform(TransformZlib); err != nil { + t.Fatalf( + "writer.AddTransform(TransformZlib) returned error: %v", + err, + ) + } + if err := writer.AddTransform(TransformNone); err != nil { + t.Fatalf( + "writer.AddTransform(TransformNone) returned error: %v", + err, + ) + } + writer.SetWriteHeader(key1, value1) + writer.SetWriteHeader(key2, value2) + if _, err := writer.Write([]byte(payload1)); err != nil { + t.Errorf("writer.Write returned error: %v", err) + } + if err := writer.Flush(context.Background()); err != nil { + t.Errorf("writer.Flush returned error: %v", err) + } + if _, err := writer.Write([]byte(payload2)); err != nil { + t.Errorf("writer.Write returned error: %v", err) + } + if err := writer.Flush(context.Background()); err != nil { + t.Errorf("writer.Flush returned error: %v", err) + } + + // Read + + // Make sure multiple calls to ReadFrame is fine. + if err := reader.ReadFrame(); err != nil { + t.Errorf("reader.ReadFrame returned error: %v", err) + } + if err := reader.ReadFrame(); err != nil { + t.Errorf("reader.ReadFrame returned error: %v", err) + } + read, err := ioutil.ReadAll(reader) + if err != nil { + t.Errorf("Read returned error: %v", err) + } + if err := reader.ReadFrame(); err != nil && err != io.EOF { + t.Errorf("reader.ReadFrame returned error: %v", err) + } + if string(read) != payload1+payload2 { + t.Errorf( + "Read content expected %q, got %q", + payload1+payload2, + read, + ) + } + if prot := reader.Protocol(); prot != THeaderProtocolBinary { + t.Errorf( + "reader.Protocol() expected %d, got %d", + THeaderProtocolBinary, + prot, + ) + } + if reader.clientType != clientHeaders { + t.Errorf( + "reader.clientType expected %d, got %d", + clientHeaders, + reader.clientType, + ) + } + headers := reader.GetReadHeaders() + if len(headers) != 2 || headers[key1] != value1 || headers[key2] != value2 { + t.Errorf( + "reader.GetReadHeaders() expected size 2, actual content: %+v", + headers, + ) + } +} + +func TestTHeaderTransportNoDoubleWrapping(t *testing.T) { + trans := NewTMemoryBuffer() + orig := NewTHeaderTransport(trans) + wrapped := NewTHeaderTransport(orig) + + if wrapped != orig { + t.Errorf("NewTHeaderTransport double wrapped THeaderTransport") + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/http_client.go b/src/jaegertracing/thrift/lib/go/thrift/http_client.go new file mode 100644 index 000000000..5c82bf538 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/http_client.go @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" +) + +// Default to using the shared http client. Library users are +// free to change this global client or specify one through +// THttpClientOptions. +var DefaultHttpClient *http.Client = http.DefaultClient + +type THttpClient struct { + client *http.Client + response *http.Response + url *url.URL + requestBuffer *bytes.Buffer + header http.Header + nsecConnectTimeout int64 + nsecReadTimeout int64 +} + +type THttpClientTransportFactory struct { + options THttpClientOptions + url string +} + +func (p *THttpClientTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if trans != nil { + t, ok := trans.(*THttpClient) + if ok && t.url != nil { + return NewTHttpClientWithOptions(t.url.String(), p.options) + } + } + return NewTHttpClientWithOptions(p.url, p.options) +} + +type THttpClientOptions struct { + // If nil, DefaultHttpClient is used + Client *http.Client +} + +func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory { + return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{}) +} + +func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory { + return &THttpClientTransportFactory{url: url, options: options} +} + +func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) { + parsedURL, err := url.Parse(urlstr) + if err != nil { + return nil, err + } + buf := make([]byte, 0, 1024) + client := options.Client + if client == nil { + client = DefaultHttpClient + } + httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}} + return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil +} + +func NewTHttpClient(urlstr string) (TTransport, error) { + return NewTHttpClientWithOptions(urlstr, THttpClientOptions{}) +} + +// Set the HTTP Header for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// httpTrans.SetHeader("User-Agent","Thrift Client 1.0") +func (p *THttpClient) SetHeader(key string, value string) { + p.header.Add(key, value) +} + +// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// hdrValue := httpTrans.GetHeader("User-Agent") +func (p *THttpClient) GetHeader(key string) string { + return p.header.Get(key) +} + +// Deletes the HTTP Header given a Header Key for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// httpTrans.DelHeader("User-Agent") +func (p *THttpClient) DelHeader(key string) { + p.header.Del(key) +} + +func (p *THttpClient) Open() error { + // do nothing + return nil +} + +func (p *THttpClient) IsOpen() bool { + return p.response != nil || p.requestBuffer != nil +} + +func (p *THttpClient) closeResponse() error { + var err error + if p.response != nil && p.response.Body != nil { + // The docs specify that if keepalive is enabled and the response body is not + // read to completion the connection will never be returned to the pool and + // reused. Errors are being ignored here because if the connection is invalid + // and this fails for some reason, the Close() method will do any remaining + // cleanup. + io.Copy(ioutil.Discard, p.response.Body) + + err = p.response.Body.Close() + } + + p.response = nil + return err +} + +func (p *THttpClient) Close() error { + if p.requestBuffer != nil { + p.requestBuffer.Reset() + p.requestBuffer = nil + } + return p.closeResponse() +} + +func (p *THttpClient) Read(buf []byte) (int, error) { + if p.response == nil { + return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.") + } + n, err := p.response.Body.Read(buf) + if n > 0 && (err == nil || err == io.EOF) { + return n, nil + } + return n, NewTTransportExceptionFromError(err) +} + +func (p *THttpClient) ReadByte() (c byte, err error) { + return readByte(p.response.Body) +} + +func (p *THttpClient) Write(buf []byte) (int, error) { + n, err := p.requestBuffer.Write(buf) + return n, err +} + +func (p *THttpClient) WriteByte(c byte) error { + return p.requestBuffer.WriteByte(c) +} + +func (p *THttpClient) WriteString(s string) (n int, err error) { + return p.requestBuffer.WriteString(s) +} + +func (p *THttpClient) Flush(ctx context.Context) error { + // Close any previous response body to avoid leaking connections. + p.closeResponse() + + req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) + if err != nil { + return NewTTransportExceptionFromError(err) + } + req.Header = p.header + if ctx != nil { + req = req.WithContext(ctx) + } + response, err := p.client.Do(req) + if err != nil { + return NewTTransportExceptionFromError(err) + } + if response.StatusCode != http.StatusOK { + // Close the response to avoid leaking file descriptors. closeResponse does + // more than just call Close(), so temporarily assign it and reuse the logic. + p.response = response + p.closeResponse() + + // TODO(pomack) log bad response + return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode)) + } + p.response = response + return nil +} + +func (p *THttpClient) RemainingBytes() (num_bytes uint64) { + len := p.response.ContentLength + if len >= 0 { + return uint64(len) + } + + const maxSize = ^uint64(0) + return maxSize // the thruth is, we just don't know unless framed is used +} + +// Deprecated: Use NewTHttpClientTransportFactory instead. +func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory { + return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{}) +} + +// Deprecated: Use NewTHttpClientTransportFactoryWithOptions instead. +func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory { + return NewTHttpClientTransportFactoryWithOptions(url, options) +} + +// Deprecated: Use NewTHttpClientWithOptions instead. +func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) { + return NewTHttpClientWithOptions(urlstr, options) +} + +// Deprecated: Use NewTHttpClient instead. +func NewTHttpPostClient(urlstr string) (TTransport, error) { + return NewTHttpClientWithOptions(urlstr, THttpClientOptions{}) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/http_client_test.go b/src/jaegertracing/thrift/lib/go/thrift/http_client_test.go new file mode 100644 index 000000000..453680ace --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/http_client_test.go @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "net/http" + "testing" +) + +func TestHttpClient(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportTest(t, trans, trans) +} + +func TestHttpClientHeaders(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportHeaderTest(t, trans, trans) +} + +func TestHttpCustomClient(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + + httpTransport := &customHttpTransport{} + + trans, err := NewTHttpPostClientWithOptions("http://"+addr.String(), THttpClientOptions{ + Client: &http.Client{ + Transport: httpTransport, + }, + }) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportHeaderTest(t, trans, trans) + + if !httpTransport.hit { + t.Fatalf("Custom client was not used") + } +} + +func TestHttpCustomClientPackageScope(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + httpTransport := &customHttpTransport{} + DefaultHttpClient = &http.Client{ + Transport: httpTransport, + } + + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportHeaderTest(t, trans, trans) + + if !httpTransport.hit { + t.Fatalf("Custom client was not used") + } +} + +type customHttpTransport struct { + hit bool +} + +func (c *customHttpTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.hit = true + return http.DefaultTransport.RoundTrip(req) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/http_transport.go b/src/jaegertracing/thrift/lib/go/thrift/http_transport.go new file mode 100644 index 000000000..66f0f388a --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/http_transport.go @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function +func NewThriftHandlerFunc(processor TProcessor, + inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) { + + return gz(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/x-thrift") + + transport := NewStreamTransport(r.Body, w) + processor.Process(r.Context(), inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport)) + }) +} + +// gz transparently compresses the HTTP response if the client supports it. +func gz(handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + handler(w, r) + return + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + defer gz.Close() + gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w} + handler(gzw, r) + } +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/iostream_transport.go b/src/jaegertracing/thrift/lib/go/thrift/iostream_transport.go new file mode 100644 index 000000000..fea93bcef --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/iostream_transport.go @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "context" + "io" +) + +// StreamTransport is a Transport made of an io.Reader and/or an io.Writer +type StreamTransport struct { + io.Reader + io.Writer + isReadWriter bool + closed bool +} + +type StreamTransportFactory struct { + Reader io.Reader + Writer io.Writer + isReadWriter bool +} + +func (p *StreamTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if trans != nil { + t, ok := trans.(*StreamTransport) + if ok { + if t.isReadWriter { + return NewStreamTransportRW(t.Reader.(io.ReadWriter)), nil + } + if t.Reader != nil && t.Writer != nil { + return NewStreamTransport(t.Reader, t.Writer), nil + } + if t.Reader != nil && t.Writer == nil { + return NewStreamTransportR(t.Reader), nil + } + if t.Reader == nil && t.Writer != nil { + return NewStreamTransportW(t.Writer), nil + } + return &StreamTransport{}, nil + } + } + if p.isReadWriter { + return NewStreamTransportRW(p.Reader.(io.ReadWriter)), nil + } + if p.Reader != nil && p.Writer != nil { + return NewStreamTransport(p.Reader, p.Writer), nil + } + if p.Reader != nil && p.Writer == nil { + return NewStreamTransportR(p.Reader), nil + } + if p.Reader == nil && p.Writer != nil { + return NewStreamTransportW(p.Writer), nil + } + return &StreamTransport{}, nil +} + +func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory { + return &StreamTransportFactory{Reader: reader, Writer: writer, isReadWriter: isReadWriter} +} + +func NewStreamTransport(r io.Reader, w io.Writer) *StreamTransport { + return &StreamTransport{Reader: bufio.NewReader(r), Writer: bufio.NewWriter(w)} +} + +func NewStreamTransportR(r io.Reader) *StreamTransport { + return &StreamTransport{Reader: bufio.NewReader(r)} +} + +func NewStreamTransportW(w io.Writer) *StreamTransport { + return &StreamTransport{Writer: bufio.NewWriter(w)} +} + +func NewStreamTransportRW(rw io.ReadWriter) *StreamTransport { + bufrw := bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw)) + return &StreamTransport{Reader: bufrw, Writer: bufrw, isReadWriter: true} +} + +func (p *StreamTransport) IsOpen() bool { + return !p.closed +} + +// implicitly opened on creation, can't be reopened once closed +func (p *StreamTransport) Open() error { + if !p.closed { + return NewTTransportException(ALREADY_OPEN, "StreamTransport already open.") + } else { + return NewTTransportException(NOT_OPEN, "cannot reopen StreamTransport.") + } +} + +// Closes both the input and output streams. +func (p *StreamTransport) Close() error { + if p.closed { + return NewTTransportException(NOT_OPEN, "StreamTransport already closed.") + } + p.closed = true + closedReader := false + if p.Reader != nil { + c, ok := p.Reader.(io.Closer) + if ok { + e := c.Close() + closedReader = true + if e != nil { + return e + } + } + p.Reader = nil + } + if p.Writer != nil && (!closedReader || !p.isReadWriter) { + c, ok := p.Writer.(io.Closer) + if ok { + e := c.Close() + if e != nil { + return e + } + } + p.Writer = nil + } + return nil +} + +// Flushes the underlying output stream if not null. +func (p *StreamTransport) Flush(ctx context.Context) error { + if p.Writer == nil { + return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream") + } + f, ok := p.Writer.(Flusher) + if ok { + err := f.Flush() + if err != nil { + return NewTTransportExceptionFromError(err) + } + } + return nil +} + +func (p *StreamTransport) Read(c []byte) (n int, err error) { + n, err = p.Reader.Read(c) + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) ReadByte() (c byte, err error) { + f, ok := p.Reader.(io.ByteReader) + if ok { + c, err = f.ReadByte() + } else { + c, err = readByte(p.Reader) + } + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) Write(c []byte) (n int, err error) { + n, err = p.Writer.Write(c) + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) WriteByte(c byte) (err error) { + f, ok := p.Writer.(io.ByteWriter) + if ok { + err = f.WriteByte(c) + } else { + err = writeByte(p.Writer, c) + } + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) WriteString(s string) (n int, err error) { + f, ok := p.Writer.(stringWriter) + if ok { + n, err = f.WriteString(s) + } else { + n, err = p.Writer.Write([]byte(s)) + } + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) RemainingBytes() (num_bytes uint64) { + const maxSize = ^uint64(0) + return maxSize // the thruth is, we just don't know unless framed is used +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/iostream_transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/iostream_transport_test.go new file mode 100644 index 000000000..15a611642 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/iostream_transport_test.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +func TestStreamTransport(t *testing.T) { + trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024))) + TransportTest(t, trans, trans) +} + +func TestStreamTransportOpenClose(t *testing.T) { + trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024))) + if !trans.IsOpen() { + t.Fatal("StreamTransport should be already open") + } + if trans.Open() == nil { + t.Fatal("StreamTransport should return error when open twice") + } + if trans.Close() != nil { + t.Fatal("StreamTransport should not return error when closing open transport") + } + if trans.IsOpen() { + t.Fatal("StreamTransport should not be open after close") + } + if trans.Close() == nil { + t.Fatal("StreamTransport should return error when closing a non open transport") + } + if trans.Open() == nil { + t.Fatal("StreamTransport should not be able to reopen") + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/json_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/json_protocol.go new file mode 100644 index 000000000..800ac22c7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/json_protocol.go @@ -0,0 +1,581 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "encoding/base64" + "fmt" +) + +const ( + THRIFT_JSON_PROTOCOL_VERSION = 1 +) + +// for references to _ParseContext see tsimplejson_protocol.go + +// JSON protocol implementation for thrift. +// Utilizes Simple JSON protocol +// +type TJSONProtocol struct { + *TSimpleJSONProtocol +} + +// Constructor +func NewTJSONProtocol(t TTransport) *TJSONProtocol { + v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)} + v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL)) + v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL)) + return v +} + +// Factory +type TJSONProtocolFactory struct{} + +func (p *TJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTJSONProtocol(trans) +} + +func NewTJSONProtocolFactory() *TJSONProtocolFactory { + return &TJSONProtocolFactory{} +} + +func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + p.resetContextStack() // THRIFT-3735 + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil { + return e + } + if e := p.WriteString(name); e != nil { + return e + } + if e := p.WriteByte(int8(typeId)); e != nil { + return e + } + if e := p.WriteI32(seqId); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) WriteMessageEnd() error { + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteStructBegin(name string) error { + if e := p.OutputObjectBegin(); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) WriteStructEnd() error { + return p.OutputObjectEnd() +} + +func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + if e := p.WriteI16(id); e != nil { + return e + } + if e := p.OutputObjectBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(typeId) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) WriteFieldEnd() error { + return p.OutputObjectEnd() +} + +func (p *TJSONProtocol) WriteFieldStop() error { return nil } + +func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(keyType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + s, e1 = p.TypeIdToString(valueType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + if e := p.WriteI64(int64(size)); e != nil { + return e + } + return p.OutputObjectBegin() +} + +func (p *TJSONProtocol) WriteMapEnd() error { + if e := p.OutputObjectEnd(); e != nil { + return e + } + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TJSONProtocol) WriteListEnd() error { + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TJSONProtocol) WriteSetEnd() error { + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteBool(b bool) error { + if b { + return p.WriteI32(1) + } + return p.WriteI32(0) +} + +func (p *TJSONProtocol) WriteByte(b int8) error { + return p.WriteI32(int32(b)) +} + +func (p *TJSONProtocol) WriteI16(v int16) error { + return p.WriteI32(int32(v)) +} + +func (p *TJSONProtocol) WriteI32(v int32) error { + return p.OutputI64(int64(v)) +} + +func (p *TJSONProtocol) WriteI64(v int64) error { + return p.OutputI64(int64(v)) +} + +func (p *TJSONProtocol) WriteDouble(v float64) error { + return p.OutputF64(v) +} + +func (p *TJSONProtocol) WriteString(v string) error { + return p.OutputString(v) +} + +func (p *TJSONProtocol) WriteBinary(v []byte) error { + // JSON library only takes in a string, + // not an arbitrary byte array, to ensure bytes are transmitted + // efficiently we must convert this into a valid JSON string + // therefore we use base64 encoding to avoid excessive escaping/quoting + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + writer := base64.NewEncoder(base64.StdEncoding, p.writer) + if _, e := writer.Write(v); e != nil { + p.writer.Reset(p.trans) // THRIFT-3735 + return NewTProtocolException(e) + } + if e := writer.Close(); e != nil { + return NewTProtocolException(e) + } + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +// Reading methods. +func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + p.resetContextStack() // THRIFT-3735 + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, typeId, seqId, err + } + version, err := p.ReadI32() + if err != nil { + return name, typeId, seqId, err + } + if version != THRIFT_JSON_PROTOCOL_VERSION { + e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION) + return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e) + + } + if name, err = p.ReadString(); err != nil { + return name, typeId, seqId, err + } + bTypeId, err := p.ReadByte() + typeId = TMessageType(bTypeId) + if err != nil { + return name, typeId, seqId, err + } + if seqId, err = p.ReadI32(); err != nil { + return name, typeId, seqId, err + } + return name, typeId, seqId, nil +} + +func (p *TJSONProtocol) ReadMessageEnd() error { + err := p.ParseListEnd() + return err +} + +func (p *TJSONProtocol) ReadStructBegin() (name string, err error) { + _, err = p.ParseObjectStart() + return "", err +} + +func (p *TJSONProtocol) ReadStructEnd() error { + return p.ParseObjectEnd() +} + +func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { + b, _ := p.reader.Peek(1) + if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] { + return "", STOP, -1, nil + } + fieldId, err := p.ReadI16() + if err != nil { + return "", STOP, fieldId, err + } + if _, err = p.ParseObjectStart(); err != nil { + return "", STOP, fieldId, err + } + sType, err := p.ReadString() + if err != nil { + return "", STOP, fieldId, err + } + fType, err := p.StringToTypeId(sType) + return "", fType, fieldId, err +} + +func (p *TJSONProtocol) ReadFieldEnd() error { + return p.ParseObjectEnd() +} + +func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, VOID, 0, e + } + + // read keyType + sKeyType, e := p.ReadString() + if e != nil { + return keyType, valueType, size, e + } + keyType, e = p.StringToTypeId(sKeyType) + if e != nil { + return keyType, valueType, size, e + } + + // read valueType + sValueType, e := p.ReadString() + if e != nil { + return keyType, valueType, size, e + } + valueType, e = p.StringToTypeId(sValueType) + if e != nil { + return keyType, valueType, size, e + } + + // read size + iSize, e := p.ReadI64() + if e != nil { + return keyType, valueType, size, e + } + size = int(iSize) + + _, e = p.ParseObjectStart() + return keyType, valueType, size, e +} + +func (p *TJSONProtocol) ReadMapEnd() error { + e := p.ParseObjectEnd() + if e != nil { + return e + } + return p.ParseListEnd() +} + +func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TJSONProtocol) ReadListEnd() error { + return p.ParseListEnd() +} + +func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TJSONProtocol) ReadSetEnd() error { + return p.ParseListEnd() +} + +func (p *TJSONProtocol) ReadBool() (bool, error) { + value, err := p.ReadI32() + return (value != 0), err +} + +func (p *TJSONProtocol) ReadByte() (int8, error) { + v, err := p.ReadI64() + return int8(v), err +} + +func (p *TJSONProtocol) ReadI16() (int16, error) { + v, err := p.ReadI64() + return int16(v), err +} + +func (p *TJSONProtocol) ReadI32() (int32, error) { + v, err := p.ReadI64() + return int32(v), err +} + +func (p *TJSONProtocol) ReadI64() (int64, error) { + v, _, err := p.ParseI64() + return v, err +} + +func (p *TJSONProtocol) ReadDouble() (float64, error) { + v, _, err := p.ParseF64() + return v, err +} + +func (p *TJSONProtocol) ReadString() (string, error) { + var v string + if err := p.ParsePreValue(); err != nil { + return v, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 && f[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseStringBody() + v = value + if err != nil { + return v, err + } + } else if len(f) > 0 && f[0] == JSON_NULL[0] { + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return v, NewTProtocolException(err) + } + if string(b) != string(JSON_NULL) { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } else { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TJSONProtocol) ReadBinary() ([]byte, error) { + var v []byte + if err := p.ParsePreValue(); err != nil { + return nil, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 && f[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseBase64EncodedBody() + v = value + if err != nil { + return v, err + } + } else if len(f) > 0 && f[0] == JSON_NULL[0] { + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return v, NewTProtocolException(err) + } + if string(b) != string(JSON_NULL) { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } else { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + + return v, p.ParsePostValue() +} + +func (p *TJSONProtocol) Flush(ctx context.Context) (err error) { + err = p.writer.Flush() + if err == nil { + err = p.trans.Flush(ctx) + } + return NewTProtocolException(err) +} + +func (p *TJSONProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TJSONProtocol) Transport() TTransport { + return p.trans +} + +func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(elemType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + if e := p.WriteI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + sElemType, err := p.ReadString() + if err != nil { + return VOID, size, err + } + elemType, err = p.StringToTypeId(sElemType) + if err != nil { + return elemType, size, err + } + nSize, err2 := p.ReadI64() + size = int(nSize) + return elemType, size, err2 +} + +func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + sElemType, err := p.ReadString() + if err != nil { + return VOID, size, err + } + elemType, err = p.StringToTypeId(sElemType) + if err != nil { + return elemType, size, err + } + nSize, err2 := p.ReadI64() + size = int(nSize) + return elemType, size, err2 +} + +func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(elemType) + if e1 != nil { + return e1 + } + if e := p.OutputString(s); e != nil { + return e + } + if e := p.OutputI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) { + switch byte(fieldType) { + case BOOL: + return "tf", nil + case BYTE: + return "i8", nil + case I16: + return "i16", nil + case I32: + return "i32", nil + case I64: + return "i64", nil + case DOUBLE: + return "dbl", nil + case STRING: + return "str", nil + case STRUCT: + return "rec", nil + case MAP: + return "map", nil + case SET: + return "set", nil + case LIST: + return "lst", nil + } + + e := fmt.Errorf("Unknown fieldType: %d", int(fieldType)) + return "", NewTProtocolExceptionWithType(INVALID_DATA, e) +} + +func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) { + switch fieldType { + case "tf": + return TType(BOOL), nil + case "i8": + return TType(BYTE), nil + case "i16": + return TType(I16), nil + case "i32": + return TType(I32), nil + case "i64": + return TType(I64), nil + case "dbl": + return TType(DOUBLE), nil + case "str": + return TType(STRING), nil + case "rec": + return TType(STRUCT), nil + case "map": + return TType(MAP), nil + case "set": + return TType(SET), nil + case "lst": + return TType(LIST), nil + } + + e := fmt.Errorf("Unknown type identifier: %s", fieldType) + return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/json_protocol_test.go b/src/jaegertracing/thrift/lib/go/thrift/json_protocol_test.go new file mode 100644 index 000000000..59c4d64a2 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/json_protocol_test.go @@ -0,0 +1,650 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "math" + "strconv" + "testing" +) + +func TestWriteJSONProtocolBool(t *testing.T) { + thetype := "boolean" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range BOOL_VALUES { + if e := p.WriteBool(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + expected := "" + if value { + expected = "1" + } else { + expected = "0" + } + if s != expected { + t.Fatalf("Bad value for %s %v: %s expected", thetype, value, s) + } + v := -1 + if err := json.Unmarshal([]byte(s), &v); err != nil || (v != 0) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolBool(t *testing.T) { + thetype := "boolean" + for _, value := range BOOL_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + if value { + trans.Write([]byte{'1'}) // not JSON_TRUE + } else { + trans.Write([]byte{'0'}) // not JSON_FALSE + } + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadBool() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + vv := -1 + if err := json.Unmarshal([]byte(s), &vv); err != nil || (vv != 0) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, vv) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolByte(t *testing.T) { + thetype := "byte" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range BYTE_VALUES { + if e := p.WriteByte(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int8(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolByte(t *testing.T) { + thetype := "byte" + for _, value := range BYTE_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadByte() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolI16(t *testing.T) { + thetype := "int16" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range INT16_VALUES { + if e := p.WriteI16(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int16(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolI16(t *testing.T) { + thetype := "int16" + for _, value := range INT16_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI16() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolI32(t *testing.T) { + thetype := "int32" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range INT32_VALUES { + if e := p.WriteI32(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int32(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolI32(t *testing.T) { + thetype := "int32" + for _, value := range INT32_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI32() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolI64(t *testing.T) { + thetype := "int64" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range INT64_VALUES { + if e := p.WriteI64(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolI64(t *testing.T) { + thetype := "int64" + for _, value := range INT64_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.FormatInt(value, 10)) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI64() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolDouble(t *testing.T) { + thetype := "double" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if math.IsInf(value, 1) { + if s != jsonQuote(JSON_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, -1) { + if s != jsonQuote(JSON_NEGATIVE_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if s != jsonQuote(JSON_NAN) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN)) + } + } else { + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolDouble(t *testing.T) { + thetype := "double" + for _, value := range DOUBLE_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + n := NewNumericFromDouble(value) + trans.WriteString(n.String()) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadDouble() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if math.IsInf(value, 1) { + if !math.IsInf(v, 1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsInf(value, -1) { + if !math.IsInf(v, -1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsNaN(value) { + if !math.IsNaN(v) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else { + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolString(t *testing.T) { + thetype := "string" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range STRING_VALUES { + if e := p.WriteString(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s[0] != '"' || s[len(s)-1] != '"' { + t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\"")) + } + v := new(string) + if err := json.Unmarshal([]byte(s), v); err != nil || *v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolString(t *testing.T) { + thetype := "string" + for _, value := range STRING_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(jsonQuote(value)) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadString() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + if e := p.WriteBinary(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + expectedString := fmt.Sprint("\"", b64String, "\"") + if s != expectedString { + t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString) + } + v1, err := p.ReadBinary() + if err != nil { + t.Fatalf("Unable to read binary: %s", err.Error()) + } + if len(v1) != len(value) { + t.Fatalf("Invalid value for binary\nexpected: \"%v\"\n read: \"%v\"", value, v1) + } + for k, v := range value { + if v1[k] != v { + t.Fatalf("Invalid value for binary at %v\nexpected: \"%v\"\n read: \"%v\"", k, v, v1[k]) + } + } + trans.Close() +} + +func TestReadJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(jsonQuote(b64String)) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadBinary() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if len(v) != len(value) { + t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) + } + for i := 0; i < len(v); i++ { + if v[i] != value[i] { + t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) + } + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() +} + +func TestWriteJSONProtocolList(t *testing.T) { + thetype := "list" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteListEnd() + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("List must be at least of length two to include metadata") + } + if l[0] != "dbl" { + t.Fatal("Invalid type for list, expected: ", STRING, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteJSONProtocolSet(t *testing.T) { + thetype := "set" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteSetEnd() + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("Set must be at least of length two to include metadata") + } + if l[0] != "dbl" { + t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteJSONProtocolMap(t *testing.T) { + thetype := "map" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) + for k, value := range DOUBLE_VALUES { + if e := p.WriteI32(int32(k)); e != nil { + t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) + } + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteMapEnd() + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + if str[0] != '[' || str[len(str)-1] != ']' { + t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES) + } + expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin() + if err != nil { + t.Fatalf("Error while reading map begin: %s", err.Error()) + } + if expectedKeyType != I32 { + t.Fatal("Expected map key type ", I32, ", but was ", expectedKeyType) + } + if expectedValueType != DOUBLE { + t.Fatal("Expected map value type ", DOUBLE, ", but was ", expectedValueType) + } + if expectedSize != len(DOUBLE_VALUES) { + t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize) + } + for k, value := range DOUBLE_VALUES { + ik, err := p.ReadI32() + if err != nil { + t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error()) + } + if int(ik) != k { + t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k) + } + dv, err := p.ReadDouble() + if err != nil { + t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error()) + } + s := strconv.FormatFloat(dv, 'g', 10, 64) + if math.IsInf(value, 1) { + if !math.IsInf(dv, 1) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, 0) { + if !math.IsInf(dv, 0) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if !math.IsNaN(dv) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN)) + } + } else { + expected := strconv.FormatFloat(value, 'g', 10, 64) + if s != expected { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + } + err = p.ReadMapEnd() + if err != nil { + t.Fatalf("Error while reading map end: %s", err.Error()) + } + trans.Close() +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/lowlevel_benchmarks_test.go b/src/jaegertracing/thrift/lib/go/thrift/lowlevel_benchmarks_test.go new file mode 100644 index 000000000..e1736557b --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/lowlevel_benchmarks_test.go @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +var binaryProtoF = NewTBinaryProtocolFactoryDefault() +var compactProtoF = NewTCompactProtocolFactory() + +var buf = bytes.NewBuffer(make([]byte, 0, 1024)) + +var tfv = []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), +} + +func BenchmarkBinaryBool_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkBinaryBool_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkBinaryBool_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary0(b *testing.B) { + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary1(b *testing.B) { + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary2(b *testing.B) { + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/memory_buffer.go b/src/jaegertracing/thrift/lib/go/thrift/memory_buffer.go new file mode 100644 index 000000000..5936d2730 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/memory_buffer.go @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "context" +) + +// Memory buffer-based implementation of the TTransport interface. +type TMemoryBuffer struct { + *bytes.Buffer + size int +} + +type TMemoryBufferTransportFactory struct { + size int +} + +func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if trans != nil { + t, ok := trans.(*TMemoryBuffer) + if ok && t.size > 0 { + return NewTMemoryBufferLen(t.size), nil + } + } + return NewTMemoryBufferLen(p.size), nil +} + +func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory { + return &TMemoryBufferTransportFactory{size: size} +} + +func NewTMemoryBuffer() *TMemoryBuffer { + return &TMemoryBuffer{Buffer: &bytes.Buffer{}, size: 0} +} + +func NewTMemoryBufferLen(size int) *TMemoryBuffer { + buf := make([]byte, 0, size) + return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} +} + +func (p *TMemoryBuffer) IsOpen() bool { + return true +} + +func (p *TMemoryBuffer) Open() error { + return nil +} + +func (p *TMemoryBuffer) Close() error { + p.Buffer.Reset() + return nil +} + +// Flushing a memory buffer is a no-op +func (p *TMemoryBuffer) Flush(ctx context.Context) error { + return nil +} + +func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) { + return uint64(p.Buffer.Len()) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/memory_buffer_test.go b/src/jaegertracing/thrift/lib/go/thrift/memory_buffer_test.go new file mode 100644 index 000000000..af2e8bfe5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/memory_buffer_test.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestMemoryBuffer(t *testing.T) { + trans := NewTMemoryBufferLen(1024) + TransportTest(t, trans, trans) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/messagetype.go b/src/jaegertracing/thrift/lib/go/thrift/messagetype.go new file mode 100644 index 000000000..25ab2e98a --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/messagetype.go @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Message type constants in the Thrift protocol. +type TMessageType int32 + +const ( + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) diff --git a/src/jaegertracing/thrift/lib/go/thrift/multiplexed_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/multiplexed_protocol.go new file mode 100644 index 000000000..d028a30b3 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/multiplexed_protocol.go @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "fmt" + "strings" +) + +/* +TMultiplexedProtocol is a protocol-independent concrete decorator +that allows a Thrift client to communicate with a multiplexing Thrift server, +by prepending the service name to the function name during function calls. + +NOTE: THIS IS NOT USED BY SERVERS. On the server, use TMultiplexedProcessor to handle request +from a multiplexing client. + +This example uses a single socket transport to invoke two services: + +socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) +transport := thrift.NewTFramedTransport(socket) +protocol := thrift.NewTBinaryProtocolTransport(transport) + +mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator") +service := Calculator.NewCalculatorClient(mp) + +mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport") +service2 := WeatherReport.NewWeatherReportClient(mp2) + +err := transport.Open() +if err != nil { + t.Fatal("Unable to open client socket", err) +} + +fmt.Println(service.Add(2,2)) +fmt.Println(service2.GetTemperature()) +*/ + +type TMultiplexedProtocol struct { + TProtocol + serviceName string +} + +const MULTIPLEXED_SEPARATOR = ":" + +func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol { + return &TMultiplexedProtocol{ + TProtocol: protocol, + serviceName: serviceName, + } +} + +func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { + if typeId == CALL || typeId == ONEWAY { + return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid) + } else { + return t.TProtocol.WriteMessageBegin(name, typeId, seqid) + } +} + +/* +TMultiplexedProcessor is a TProcessor allowing +a single TServer to provide multiple services. + +To do so, you instantiate the processor and then register additional +processors with it, as shown in the following example: + +var processor = thrift.NewTMultiplexedProcessor() + +firstProcessor := +processor.RegisterProcessor("FirstService", firstProcessor) + +processor.registerProcessor( + "Calculator", + Calculator.NewCalculatorProcessor(&CalculatorHandler{}), +) + +processor.registerProcessor( + "WeatherReport", + WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}), +) + +serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT) +if err != nil { + t.Fatal("Unable to create server socket", err) +} +server := thrift.NewTSimpleServer2(processor, serverTransport) +server.Serve(); +*/ + +type TMultiplexedProcessor struct { + serviceProcessorMap map[string]TProcessor + DefaultProcessor TProcessor +} + +func NewTMultiplexedProcessor() *TMultiplexedProcessor { + return &TMultiplexedProcessor{ + serviceProcessorMap: make(map[string]TProcessor), + } +} + +func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) { + t.DefaultProcessor = processor +} + +func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) { + if t.serviceProcessorMap == nil { + t.serviceProcessorMap = make(map[string]TProcessor) + } + t.serviceProcessorMap[name] = processor +} + +func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) { + name, typeId, seqid, err := in.ReadMessageBegin() + if err != nil { + return false, err + } + if typeId != CALL && typeId != ONEWAY { + return false, fmt.Errorf("Unexpected message type %v", typeId) + } + //extract the service name + v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2) + if len(v) != 2 { + if t.DefaultProcessor != nil { + smb := NewStoredMessageProtocol(in, name, typeId, seqid) + return t.DefaultProcessor.Process(ctx, smb, out) + } + return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name) + } + actualProcessor, ok := t.serviceProcessorMap[v[0]] + if !ok { + return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0]) + } + smb := NewStoredMessageProtocol(in, v[1], typeId, seqid) + return actualProcessor.Process(ctx, smb, out) +} + +//Protocol that use stored message for ReadMessageBegin +type storedMessageProtocol struct { + TProtocol + name string + typeId TMessageType + seqid int32 +} + +func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol { + return &storedMessageProtocol{protocol, name, typeId, seqid} +} + +func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { + return s.name, s.typeId, s.seqid, nil +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/numeric.go b/src/jaegertracing/thrift/lib/go/thrift/numeric.go new file mode 100644 index 000000000..aa8daa9b5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/numeric.go @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "math" + "strconv" +) + +type Numeric interface { + Int64() int64 + Int32() int32 + Int16() int16 + Byte() byte + Int() int + Float64() float64 + Float32() float32 + String() string + isNull() bool +} + +type numeric struct { + iValue int64 + dValue float64 + sValue string + isNil bool +} + +var ( + INFINITY Numeric + NEGATIVE_INFINITY Numeric + NAN Numeric + ZERO Numeric + NUMERIC_NULL Numeric +) + +func NewNumericFromDouble(dValue float64) Numeric { + if math.IsInf(dValue, 1) { + return INFINITY + } + if math.IsInf(dValue, -1) { + return NEGATIVE_INFINITY + } + if math.IsNaN(dValue) { + return NAN + } + iValue := int64(dValue) + sValue := strconv.FormatFloat(dValue, 'g', 10, 64) + isNil := false + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromI64(iValue int64) Numeric { + dValue := float64(iValue) + sValue := string(iValue) + isNil := false + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromI32(iValue int32) Numeric { + dValue := float64(iValue) + sValue := string(iValue) + isNil := false + return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromString(sValue string) Numeric { + if sValue == INFINITY.String() { + return INFINITY + } + if sValue == NEGATIVE_INFINITY.String() { + return NEGATIVE_INFINITY + } + if sValue == NAN.String() { + return NAN + } + iValue, _ := strconv.ParseInt(sValue, 10, 64) + dValue, _ := strconv.ParseFloat(sValue, 64) + isNil := len(sValue) == 0 + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromJSONString(sValue string, isNull bool) Numeric { + if isNull { + return NewNullNumeric() + } + if sValue == JSON_INFINITY { + return INFINITY + } + if sValue == JSON_NEGATIVE_INFINITY { + return NEGATIVE_INFINITY + } + if sValue == JSON_NAN { + return NAN + } + iValue, _ := strconv.ParseInt(sValue, 10, 64) + dValue, _ := strconv.ParseFloat(sValue, 64) + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNull} +} + +func NewNullNumeric() Numeric { + return &numeric{iValue: 0, dValue: 0.0, sValue: "", isNil: true} +} + +func (p *numeric) Int64() int64 { + return p.iValue +} + +func (p *numeric) Int32() int32 { + return int32(p.iValue) +} + +func (p *numeric) Int16() int16 { + return int16(p.iValue) +} + +func (p *numeric) Byte() byte { + return byte(p.iValue) +} + +func (p *numeric) Int() int { + return int(p.iValue) +} + +func (p *numeric) Float64() float64 { + return p.dValue +} + +func (p *numeric) Float32() float32 { + return float32(p.dValue) +} + +func (p *numeric) String() string { + return p.sValue +} + +func (p *numeric) isNull() bool { + return p.isNil +} + +func init() { + INFINITY = &numeric{iValue: 0, dValue: math.Inf(1), sValue: "Infinity", isNil: false} + NEGATIVE_INFINITY = &numeric{iValue: 0, dValue: math.Inf(-1), sValue: "-Infinity", isNil: false} + NAN = &numeric{iValue: 0, dValue: math.NaN(), sValue: "NaN", isNil: false} + ZERO = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: false} + NUMERIC_NULL = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: true} +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/pointerize.go b/src/jaegertracing/thrift/lib/go/thrift/pointerize.go new file mode 100644 index 000000000..fb564ea81 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/pointerize.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +/////////////////////////////////////////////////////////////////////////////// +// This file is home to helpers that convert from various base types to +// respective pointer types. This is necessary because Go does not permit +// references to constants, nor can a pointer type to base type be allocated +// and initialized in a single expression. +// +// E.g., this is not allowed: +// +// var ip *int = &5 +// +// But this *is* allowed: +// +// func IntPtr(i int) *int { return &i } +// var ip *int = IntPtr(5) +// +// Since pointers to base types are commonplace as [optional] fields in +// exported thrift structs, we factor such helpers here. +/////////////////////////////////////////////////////////////////////////////// + +func Float32Ptr(v float32) *float32 { return &v } +func Float64Ptr(v float64) *float64 { return &v } +func IntPtr(v int) *int { return &v } +func Int8Ptr(v int8) *int8 { return &v } +func Int16Ptr(v int16) *int16 { return &v } +func Int32Ptr(v int32) *int32 { return &v } +func Int64Ptr(v int64) *int64 { return &v } +func StringPtr(v string) *string { return &v } +func Uint32Ptr(v uint32) *uint32 { return &v } +func Uint64Ptr(v uint64) *uint64 { return &v } +func BoolPtr(v bool) *bool { return &v } +func ByteSlicePtr(v []byte) *[]byte { return &v } diff --git a/src/jaegertracing/thrift/lib/go/thrift/processor_factory.go b/src/jaegertracing/thrift/lib/go/thrift/processor_factory.go new file mode 100644 index 000000000..e4b132b30 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/processor_factory.go @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import "context" + +// A processor is a generic object which operates upon an input stream and +// writes to some output stream. +type TProcessor interface { + Process(ctx context.Context, in, out TProtocol) (bool, TException) +} + +type TProcessorFunction interface { + Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException) +} + +// The default processor factory just returns a singleton +// instance. +type TProcessorFactory interface { + GetProcessor(trans TTransport) TProcessor +} + +type tProcessorFactory struct { + processor TProcessor +} + +func NewTProcessorFactory(p TProcessor) TProcessorFactory { + return &tProcessorFactory{processor: p} +} + +func (p *tProcessorFactory) GetProcessor(trans TTransport) TProcessor { + return p.processor +} + +/** + * The default processor factory just returns a singleton + * instance. + */ +type TProcessorFunctionFactory interface { + GetProcessorFunction(trans TTransport) TProcessorFunction +} + +type tProcessorFunctionFactory struct { + processor TProcessorFunction +} + +func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactory { + return &tProcessorFunctionFactory{processor: p} +} + +func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction { + return p.processor +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/protocol.go b/src/jaegertracing/thrift/lib/go/thrift/protocol.go new file mode 100644 index 000000000..2e6bc4b16 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/protocol.go @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "errors" + "fmt" +) + +const ( + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 +) + +type TProtocol interface { + WriteMessageBegin(name string, typeId TMessageType, seqid int32) error + WriteMessageEnd() error + WriteStructBegin(name string) error + WriteStructEnd() error + WriteFieldBegin(name string, typeId TType, id int16) error + WriteFieldEnd() error + WriteFieldStop() error + WriteMapBegin(keyType TType, valueType TType, size int) error + WriteMapEnd() error + WriteListBegin(elemType TType, size int) error + WriteListEnd() error + WriteSetBegin(elemType TType, size int) error + WriteSetEnd() error + WriteBool(value bool) error + WriteByte(value int8) error + WriteI16(value int16) error + WriteI32(value int32) error + WriteI64(value int64) error + WriteDouble(value float64) error + WriteString(value string) error + WriteBinary(value []byte) error + + ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) + ReadMessageEnd() error + ReadStructBegin() (name string, err error) + ReadStructEnd() error + ReadFieldBegin() (name string, typeId TType, id int16, err error) + ReadFieldEnd() error + ReadMapBegin() (keyType TType, valueType TType, size int, err error) + ReadMapEnd() error + ReadListBegin() (elemType TType, size int, err error) + ReadListEnd() error + ReadSetBegin() (elemType TType, size int, err error) + ReadSetEnd() error + ReadBool() (value bool, err error) + ReadByte() (value int8, err error) + ReadI16() (value int16, err error) + ReadI32() (value int32, err error) + ReadI64() (value int64, err error) + ReadDouble() (value float64, err error) + ReadString() (value string, err error) + ReadBinary() (value []byte, err error) + + Skip(fieldType TType) (err error) + Flush(ctx context.Context) (err error) + + Transport() TTransport +} + +// The maximum recursive depth the skip() function will traverse +const DEFAULT_RECURSION_DEPTH = 64 + +// Skips over the next data element from the provided input TProtocol object. +func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) { + return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH) +} + +// Skips over the next data element from the provided input TProtocol object. +func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { + + if maxDepth <= 0 { + return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded")) + } + + switch fieldType { + case BOOL: + _, err = self.ReadBool() + return + case BYTE: + _, err = self.ReadByte() + return + case I16: + _, err = self.ReadI16() + return + case I32: + _, err = self.ReadI32() + return + case I64: + _, err = self.ReadI64() + return + case DOUBLE: + _, err = self.ReadDouble() + return + case STRING: + _, err = self.ReadString() + return + case STRUCT: + if _, err = self.ReadStructBegin(); err != nil { + return err + } + for { + _, typeId, _, _ := self.ReadFieldBegin() + if typeId == STOP { + break + } + err := Skip(self, typeId, maxDepth-1) + if err != nil { + return err + } + self.ReadFieldEnd() + } + return self.ReadStructEnd() + case MAP: + keyType, valueType, size, err := self.ReadMapBegin() + if err != nil { + return err + } + for i := 0; i < size; i++ { + err := Skip(self, keyType, maxDepth-1) + if err != nil { + return err + } + self.Skip(valueType) + } + return self.ReadMapEnd() + case SET: + elemType, size, err := self.ReadSetBegin() + if err != nil { + return err + } + for i := 0; i < size; i++ { + err := Skip(self, elemType, maxDepth-1) + if err != nil { + return err + } + } + return self.ReadSetEnd() + case LIST: + elemType, size, err := self.ReadListBegin() + if err != nil { + return err + } + for i := 0; i < size; i++ { + err := Skip(self, elemType, maxDepth-1) + if err != nil { + return err + } + } + return self.ReadListEnd() + default: + return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType))) + } + return nil +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/protocol_exception.go b/src/jaegertracing/thrift/lib/go/thrift/protocol_exception.go new file mode 100644 index 000000000..29ab75d92 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/protocol_exception.go @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/base64" +) + +// Thrift Protocol exception +type TProtocolException interface { + TException + TypeId() int +} + +const ( + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) + +type tProtocolException struct { + typeId int + message string +} + +func (p *tProtocolException) TypeId() int { + return p.typeId +} + +func (p *tProtocolException) String() string { + return p.message +} + +func (p *tProtocolException) Error() string { + return p.message +} + +func NewTProtocolException(err error) TProtocolException { + if err == nil { + return nil + } + if e, ok := err.(TProtocolException); ok { + return e + } + if _, ok := err.(base64.CorruptInputError); ok { + return &tProtocolException{INVALID_DATA, err.Error()} + } + return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()} +} + +func NewTProtocolExceptionWithType(errType int, err error) TProtocolException { + if err == nil { + return nil + } + return &tProtocolException{errType, err.Error()} +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/protocol_factory.go b/src/jaegertracing/thrift/lib/go/thrift/protocol_factory.go new file mode 100644 index 000000000..c40f796d8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/protocol_factory.go @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Factory interface for constructing protocol instances. +type TProtocolFactory interface { + GetProtocol(trans TTransport) TProtocol +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/protocol_test.go b/src/jaegertracing/thrift/lib/go/thrift/protocol_test.go new file mode 100644 index 000000000..944055c0b --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/protocol_test.go @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "context" + "io/ioutil" + "math" + "net" + "net/http" + "testing" +) + +const PROTOCOL_BINARY_DATA_SIZE = 155 + +var ( + protocol_bdata []byte // test data for writing; same as data + BOOL_VALUES []bool + BYTE_VALUES []int8 + INT16_VALUES []int16 + INT32_VALUES []int32 + INT64_VALUES []int64 + DOUBLE_VALUES []float64 + STRING_VALUES []string +) + +func init() { + protocol_bdata = make([]byte, PROTOCOL_BINARY_DATA_SIZE) + for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ { + protocol_bdata[i] = byte((i + 'a') % 255) + } + BOOL_VALUES = []bool{false, true, false, false, true} + BYTE_VALUES = []int8{117, 0, 1, 32, 127, -128, -1} + INT16_VALUES = []int16{459, 0, 1, -1, -128, 127, 32767, -32768} + INT32_VALUES = []int32{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535} + INT64_VALUES = []int64{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535, 34359738481, -35184372088719, -9223372036854775808, 9223372036854775807} + DOUBLE_VALUES = []float64{459.3, 0.0, -1.0, 1.0, 0.5, 0.3333, 3.14159, 1.537e-38, 1.673e25, 6.02214179e23, -6.02214179e23, INFINITY.Float64(), NEGATIVE_INFINITY.Float64(), NAN.Float64()} + STRING_VALUES = []string{"", "a", "st[uf]f", "st,u:ff with spaces", "stuff\twith\nescape\\characters'...\"lots{of}fun</xml>"} +} + +type HTTPEchoServer struct{} +type HTTPHeaderEchoServer struct{} + +func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write(buf) + } else { + w.WriteHeader(http.StatusOK) + w.Write(buf) + } +} + +func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write(buf) + } else { + w.WriteHeader(http.StatusOK) + w.Write(buf) + } +} + +func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) { + addr, err := FindAvailableTCPServerPort(40000) + if err != nil { + t.Fatalf("Unable to find available tcp port addr: %s", err) + return nil, addr + } + l, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) + return l, addr + } + go http.Serve(l, &HTTPEchoServer{}) + return l, addr +} + +func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) { + addr, err := FindAvailableTCPServerPort(40000) + if err != nil { + t.Fatalf("Unable to find available tcp port addr: %s", err) + return nil, addr + } + l, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) + return l, addr + } + go http.Serve(l, &HTTPHeaderEchoServer{}) + return l, addr +} + +func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + l, addr := HttpClientSetupForTest(t) + defer l.Close() + transports := []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), + NewTZlibTransportFactoryWithFactory(0, NewTMemoryBufferTransportFactory(1024)), + NewTZlibTransportFactoryWithFactory(6, NewTMemoryBufferTransportFactory(1024)), + NewTZlibTransportFactoryWithFactory(9, NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024))), + NewTHttpPostClientTransportFactory("http://" + addr.String()), + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteBool(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteByte(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteI16(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteI32(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteI64(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteDouble(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteString(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteBinary(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + p := protocolFactory.GetProtocol(trans) + ReadWriteI64(t, p, trans) + ReadWriteDouble(t, p, trans) + ReadWriteBinary(t, p, trans) + ReadWriteByte(t, p, trans) + trans.Close() + } +} + +func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(BOOL) + thelen := len(BOOL_VALUES) + err := p.WriteListBegin(thetype, thelen) + if err != nil { + t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype) + } + for k, v := range BOOL_VALUES { + err = p.WriteBool(v) + if err != nil { + t.Errorf("%s: %T %T %v Error writing bool in list at index %v: %v", "ReadWriteBool", p, trans, err, k, v) + } + } + p.WriteListEnd() + if err != nil { + t.Errorf("%s: %T %T %v Error writing list end: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES) + } + p.Flush(context.Background()) + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteBool", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteBool", p, trans, thelen, thelen2) + } + } + for k, v := range BOOL_VALUES { + value, err := p.ReadBool() + if err != nil { + t.Errorf("%s: %T %T %v Error reading bool at index %v: %v", "ReadWriteBool", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: index %v %v %v %v != %v", "ReadWriteBool", k, p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err) + } +} + +func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(BYTE) + thelen := len(BYTE_VALUES) + err := p.WriteListBegin(thetype, thelen) + if err != nil { + t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype) + } + for k, v := range BYTE_VALUES { + err = p.WriteByte(v) + if err != nil { + t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v) + } + } + err = p.WriteListEnd() + if err != nil { + t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) + } + err = p.Flush(context.Background()) + if err != nil { + t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) + } + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteByte", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteByte", p, trans, thelen, thelen2) + } + } + for k, v := range BYTE_VALUES { + value, err := p.ReadByte() + if err != nil { + t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err) + } +} + +func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(I16) + thelen := len(INT16_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range INT16_VALUES { + p.WriteI16(v) + } + p.WriteListEnd() + p.Flush(context.Background()) + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI16", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI16", p, trans, thelen, thelen2) + } + } + for k, v := range INT16_VALUES { + value, err := p.ReadI16() + if err != nil { + t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err) + } +} + +func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(I32) + thelen := len(INT32_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range INT32_VALUES { + p.WriteI32(v) + } + p.WriteListEnd() + p.Flush(context.Background()) + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI32", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI32", p, trans, thelen, thelen2) + } + } + for k, v := range INT32_VALUES { + value, err := p.ReadI32() + if err != nil { + t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteI32", p, trans, v, value) + } + } + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI32", p, trans, err) + } +} + +func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(I64) + thelen := len(INT64_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range INT64_VALUES { + p.WriteI64(v) + } + p.WriteListEnd() + p.Flush(context.Background()) + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI64", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI64", p, trans, thelen, thelen2) + } + } + for k, v := range INT64_VALUES { + value, err := p.ReadI64() + if err != nil { + t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %q != %q", "ReadWriteI64", p, trans, v, value) + } + } + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI64", p, trans, err) + } +} + +func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(DOUBLE) + thelen := len(DOUBLE_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range DOUBLE_VALUES { + p.WriteDouble(v) + } + p.WriteListEnd() + p.Flush(context.Background()) + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES) + } + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteDouble", p, trans, thelen, thelen2) + } + for k, v := range DOUBLE_VALUES { + value, err := p.ReadDouble() + if err != nil { + t.Errorf("%s: %T %T %q Error reading double at index %d: %v", "ReadWriteDouble", p, trans, err, k, v) + } + if math.IsNaN(v) { + if !math.IsNaN(value) { + t.Errorf("%s: %T %T math.IsNaN(%v) != math.IsNaN(%v)", "ReadWriteDouble", p, trans, v, value) + } + } else if v != value { + t.Errorf("%s: %T %T %v != %v", "ReadWriteDouble", p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err) + } +} + +func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(STRING) + thelen := len(STRING_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range STRING_VALUES { + p.WriteString(v) + } + p.WriteListEnd() + p.Flush(context.Background()) + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteString", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %v != len %v", "ReadWriteString", p, trans, thelen, thelen2) + } + } + for k, v := range STRING_VALUES { + value, err := p.ReadString() + if err != nil { + t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %v != %v", "ReadWriteString", p, trans, v, value) + } + } + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteString", p, trans, err) + } +} + +func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) { + v := protocol_bdata + p.WriteBinary(v) + p.Flush(context.Background()) + value, err := p.ReadBinary() + if err != nil { + t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error()) + } + if len(v) != len(value) { + t.Errorf("%s: %T %T len(v) != len(value)... %d != %d", "ReadWriteBinary", p, trans, len(v), len(value)) + } else { + for i := 0; i < len(v); i++ { + if v[i] != value[i] { + t.Errorf("%s: %T %T %s != %s", "ReadWriteBinary", p, trans, v, value) + } + } + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/rich_transport.go b/src/jaegertracing/thrift/lib/go/thrift/rich_transport.go new file mode 100644 index 000000000..4025bebea --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/rich_transport.go @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import "io" + +type RichTransport struct { + TTransport +} + +// Wraps Transport to provide TRichTransport interface +func NewTRichTransport(trans TTransport) *RichTransport { + return &RichTransport{trans} +} + +func (r *RichTransport) ReadByte() (c byte, err error) { + return readByte(r.TTransport) +} + +func (r *RichTransport) WriteByte(c byte) error { + return writeByte(r.TTransport, c) +} + +func (r *RichTransport) WriteString(s string) (n int, err error) { + return r.Write([]byte(s)) +} + +func (r *RichTransport) RemainingBytes() (num_bytes uint64) { + return r.TTransport.RemainingBytes() +} + +func readByte(r io.Reader) (c byte, err error) { + v := [1]byte{0} + n, err := r.Read(v[0:1]) + if n > 0 && (err == nil || err == io.EOF) { + return v[0], nil + } + if n > 0 && err != nil { + return v[0], err + } + if err != nil { + return 0, err + } + return v[0], nil +} + +func writeByte(w io.Writer, c byte) error { + v := [1]byte{c} + _, err := w.Write(v[0:1]) + return err +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/rich_transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/rich_transport_test.go new file mode 100644 index 000000000..25c3fd5aa --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/rich_transport_test.go @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" +) + +func TestEnsureTransportsAreRich(t *testing.T) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + + transports := []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), + NewTHttpPostClientTransportFactory("http://127.0.0.1"), + } + for _, tf := range transports { + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } + _, ok := trans.(TRichTransport) + if !ok { + t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans)) + } + } +} + +// TestReadByte tests whether readByte handles error cases correctly. +func TestReadByte(t *testing.T) { + for i, test := range readByteTests { + v, err := readByte(test.r) + if v != test.v { + t.Fatalf("TestReadByte %d: value differs. Expected %d, got %d", i, test.v, test.r.v) + } + if err != test.err { + t.Fatalf("TestReadByte %d: error differs. Expected %s, got %s", i, test.err, test.r.err) + } + } +} + +var someError = errors.New("Some error") +var readByteTests = []struct { + r *mockReader + v byte + err error +}{ + {&mockReader{0, 55, io.EOF}, 0, io.EOF}, // reader sends EOF w/o data + {&mockReader{0, 55, someError}, 0, someError}, // reader sends some other error + {&mockReader{1, 55, nil}, 55, nil}, // reader sends data w/o error + {&mockReader{1, 55, io.EOF}, 55, nil}, // reader sends data with EOF + {&mockReader{1, 55, someError}, 55, someError}, // reader sends data withsome error +} + +type mockReader struct { + n int + v byte + err error +} + +func (r *mockReader) Read(p []byte) (n int, err error) { + if r.n > 0 { + p[0] = r.v + } + return r.n, r.err +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/serializer.go b/src/jaegertracing/thrift/lib/go/thrift/serializer.go new file mode 100644 index 000000000..1ff4d3754 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/serializer.go @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" +) + +type TSerializer struct { + Transport *TMemoryBuffer + Protocol TProtocol +} + +type TStruct interface { + Write(p TProtocol) error + Read(p TProtocol) error +} + +func NewTSerializer() *TSerializer { + transport := NewTMemoryBufferLen(1024) + protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport) + + return &TSerializer{ + transport, + protocol} +} + +func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) { + t.Transport.Reset() + + if err = msg.Write(t.Protocol); err != nil { + return + } + + if err = t.Protocol.Flush(ctx); err != nil { + return + } + if err = t.Transport.Flush(ctx); err != nil { + return + } + + return t.Transport.String(), nil +} + +func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) { + t.Transport.Reset() + + if err = msg.Write(t.Protocol); err != nil { + return + } + + if err = t.Protocol.Flush(ctx); err != nil { + return + } + + if err = t.Transport.Flush(ctx); err != nil { + return + } + + b = append(b, t.Transport.Bytes()...) + return +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/serializer_test.go b/src/jaegertracing/thrift/lib/go/thrift/serializer_test.go new file mode 100644 index 000000000..32227ef49 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/serializer_test.go @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "errors" + "fmt" + "testing" +) + +type ProtocolFactory interface { + GetProtocol(t TTransport) TProtocol +} + +func compareStructs(m, m1 MyTestStruct) (bool, error) { + switch { + case m.On != m1.On: + return false, errors.New("Boolean not equal") + case m.B != m1.B: + return false, errors.New("Byte not equal") + case m.Int16 != m1.Int16: + return false, errors.New("Int16 not equal") + case m.Int32 != m1.Int32: + return false, errors.New("Int32 not equal") + case m.Int64 != m1.Int64: + return false, errors.New("Int64 not equal") + case m.D != m1.D: + return false, errors.New("Double not equal") + case m.St != m1.St: + return false, errors.New("String not equal") + + case len(m.Bin) != len(m1.Bin): + return false, errors.New("Binary size not equal") + case len(m.Bin) == len(m1.Bin): + for i := range m.Bin { + if m.Bin[i] != m1.Bin[i] { + return false, errors.New("Binary not equal") + } + } + case len(m.StringMap) != len(m1.StringMap): + return false, errors.New("StringMap size not equal") + case len(m.StringList) != len(m1.StringList): + return false, errors.New("StringList size not equal") + case len(m.StringSet) != len(m1.StringSet): + return false, errors.New("StringSet size not equal") + + case m.E != m1.E: + return false, errors.New("MyTestEnum not equal") + + default: + return true, nil + + } + return true, nil +} + +func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) { + t := NewTSerializer() + t.Protocol = pf.GetProtocol(t.Transport) + var m = MyTestStruct{} + m.On = true + m.B = int8(0) + m.Int16 = 1 + m.Int32 = 2 + m.Int64 = 3 + m.D = 4.1 + m.St = "Test" + m.Bin = make([]byte, 10) + m.StringMap = make(map[string]string, 5) + m.StringList = make([]string, 5) + m.StringSet = make(map[string]struct{}, 5) + m.E = 2 + + s, err := t.WriteString(context.Background(), &m) + if err != nil { + return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err)) + } + + t1 := NewTDeserializer() + t1.Protocol = pf.GetProtocol(t1.Transport) + var m1 = MyTestStruct{} + if err = t1.ReadString(&m1, s); err != nil { + return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err)) + + } + + return compareStructs(m, m1) + +} + +func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) { + t := NewTSerializer() + t.Protocol = pf.GetProtocol(t.Transport) + var m = MyTestStruct{} + m.On = false + m.B = int8(0) + m.Int16 = 1 + m.Int32 = 2 + m.Int64 = 3 + m.D = 4.1 + m.St = "Test" + m.Bin = make([]byte, 10) + m.StringMap = make(map[string]string, 5) + m.StringList = make([]string, 5) + m.StringSet = make(map[string]struct{}, 5) + m.E = 2 + + s, err := t.WriteString(context.Background(), &m) + if err != nil { + return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err)) + + } + + t1 := NewTDeserializer() + t1.Protocol = pf.GetProtocol(t1.Transport) + var m1 = MyTestStruct{} + if err = t1.ReadString(&m1, s); err != nil { + return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err)) + + } + + return compareStructs(m, m1) + +} + +func TestSerializer(t *testing.T) { + + var protocol_factories map[string]ProtocolFactory + protocol_factories = make(map[string]ProtocolFactory) + protocol_factories["Binary"] = NewTBinaryProtocolFactoryDefault() + protocol_factories["Compact"] = NewTCompactProtocolFactory() + //protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design + protocol_factories["JSON"] = NewTJSONProtocolFactory() + + var tests map[string]func(*testing.T, ProtocolFactory) (bool, error) + tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error)) + tests["Test 1"] = ProtocolTest1 + tests["Test 2"] = ProtocolTest2 + //tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests + + for name, pf := range protocol_factories { + + for test, f := range tests { + + if s, err := f(t, pf); !s || err != nil { + t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err) + } + + } + } + +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/serializer_types_test.go b/src/jaegertracing/thrift/lib/go/thrift/serializer_types_test.go new file mode 100644 index 000000000..e5472bbff --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/serializer_types_test.go @@ -0,0 +1,633 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Autogenerated by Thrift Compiler (FIXME) +// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + +/* THE FOLLOWING THRIFT FILE WAS USED TO CREATE THIS + +enum MyTestEnum { + FIRST = 1, + SECOND = 2, + THIRD = 3, + FOURTH = 4, +} + +struct MyTestStruct { + 1: bool on, + 2: byte b, + 3: i16 int16, + 4: i32 int32, + 5: i64 int64, + 6: double d, + 7: string st, + 8: binary bin, + 9: map<string, string> stringMap, + 10: list<string> stringList, + 11: set<string> stringSet, + 12: MyTestEnum e, +} +*/ + +import ( + "fmt" +) + +// (needed to ensure safety because of naive import list construction.) +var _ = ZERO +var _ = fmt.Printf + +var GoUnusedProtection__ int + +type MyTestEnum int64 + +const ( + MyTestEnum_FIRST MyTestEnum = 1 + MyTestEnum_SECOND MyTestEnum = 2 + MyTestEnum_THIRD MyTestEnum = 3 + MyTestEnum_FOURTH MyTestEnum = 4 +) + +func (p MyTestEnum) String() string { + switch p { + case MyTestEnum_FIRST: + return "FIRST" + case MyTestEnum_SECOND: + return "SECOND" + case MyTestEnum_THIRD: + return "THIRD" + case MyTestEnum_FOURTH: + return "FOURTH" + } + return "<UNSET>" +} + +func MyTestEnumFromString(s string) (MyTestEnum, error) { + switch s { + case "FIRST": + return MyTestEnum_FIRST, nil + case "SECOND": + return MyTestEnum_SECOND, nil + case "THIRD": + return MyTestEnum_THIRD, nil + case "FOURTH": + return MyTestEnum_FOURTH, nil + } + return MyTestEnum(0), fmt.Errorf("not a valid MyTestEnum string") +} + +func MyTestEnumPtr(v MyTestEnum) *MyTestEnum { return &v } + +type MyTestStruct struct { + On bool `thrift:"on,1" json:"on"` + B int8 `thrift:"b,2" json:"b"` + Int16 int16 `thrift:"int16,3" json:"int16"` + Int32 int32 `thrift:"int32,4" json:"int32"` + Int64 int64 `thrift:"int64,5" json:"int64"` + D float64 `thrift:"d,6" json:"d"` + St string `thrift:"st,7" json:"st"` + Bin []byte `thrift:"bin,8" json:"bin"` + StringMap map[string]string `thrift:"stringMap,9" json:"stringMap"` + StringList []string `thrift:"stringList,10" json:"stringList"` + StringSet map[string]struct{} `thrift:"stringSet,11" json:"stringSet"` + E MyTestEnum `thrift:"e,12" json:"e"` +} + +func NewMyTestStruct() *MyTestStruct { + return &MyTestStruct{} +} + +func (p *MyTestStruct) GetOn() bool { + return p.On +} + +func (p *MyTestStruct) GetB() int8 { + return p.B +} + +func (p *MyTestStruct) GetInt16() int16 { + return p.Int16 +} + +func (p *MyTestStruct) GetInt32() int32 { + return p.Int32 +} + +func (p *MyTestStruct) GetInt64() int64 { + return p.Int64 +} + +func (p *MyTestStruct) GetD() float64 { + return p.D +} + +func (p *MyTestStruct) GetSt() string { + return p.St +} + +func (p *MyTestStruct) GetBin() []byte { + return p.Bin +} + +func (p *MyTestStruct) GetStringMap() map[string]string { + return p.StringMap +} + +func (p *MyTestStruct) GetStringList() []string { + return p.StringList +} + +func (p *MyTestStruct) GetStringSet() map[string]struct{} { + return p.StringSet +} + +func (p *MyTestStruct) GetE() MyTestEnum { + return p.E +} +func (p *MyTestStruct) Read(iprot TProtocol) error { + if _, err := iprot.ReadStructBegin(); err != nil { + return PrependError(fmt.Sprintf("%T read error: ", p), err) + } + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() + if err != nil { + return PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == STOP { + break + } + switch fieldId { + case 1: + if err := p.readField1(iprot); err != nil { + return err + } + case 2: + if err := p.readField2(iprot); err != nil { + return err + } + case 3: + if err := p.readField3(iprot); err != nil { + return err + } + case 4: + if err := p.readField4(iprot); err != nil { + return err + } + case 5: + if err := p.readField5(iprot); err != nil { + return err + } + case 6: + if err := p.readField6(iprot); err != nil { + return err + } + case 7: + if err := p.readField7(iprot); err != nil { + return err + } + case 8: + if err := p.readField8(iprot); err != nil { + return err + } + case 9: + if err := p.readField9(iprot); err != nil { + return err + } + case 10: + if err := p.readField10(iprot); err != nil { + return err + } + case 11: + if err := p.readField11(iprot); err != nil { + return err + } + case 12: + if err := p.readField12(iprot); err != nil { + return err + } + default: + if err := iprot.Skip(fieldTypeId); err != nil { + return err + } + } + if err := iprot.ReadFieldEnd(); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(); err != nil { + return PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return nil +} + +func (p *MyTestStruct) readField1(iprot TProtocol) error { + if v, err := iprot.ReadBool(); err != nil { + return PrependError("error reading field 1: ", err) + } else { + p.On = v + } + return nil +} + +func (p *MyTestStruct) readField2(iprot TProtocol) error { + if v, err := iprot.ReadByte(); err != nil { + return PrependError("error reading field 2: ", err) + } else { + temp := int8(v) + p.B = temp + } + return nil +} + +func (p *MyTestStruct) readField3(iprot TProtocol) error { + if v, err := iprot.ReadI16(); err != nil { + return PrependError("error reading field 3: ", err) + } else { + p.Int16 = v + } + return nil +} + +func (p *MyTestStruct) readField4(iprot TProtocol) error { + if v, err := iprot.ReadI32(); err != nil { + return PrependError("error reading field 4: ", err) + } else { + p.Int32 = v + } + return nil +} + +func (p *MyTestStruct) readField5(iprot TProtocol) error { + if v, err := iprot.ReadI64(); err != nil { + return PrependError("error reading field 5: ", err) + } else { + p.Int64 = v + } + return nil +} + +func (p *MyTestStruct) readField6(iprot TProtocol) error { + if v, err := iprot.ReadDouble(); err != nil { + return PrependError("error reading field 6: ", err) + } else { + p.D = v + } + return nil +} + +func (p *MyTestStruct) readField7(iprot TProtocol) error { + if v, err := iprot.ReadString(); err != nil { + return PrependError("error reading field 7: ", err) + } else { + p.St = v + } + return nil +} + +func (p *MyTestStruct) readField8(iprot TProtocol) error { + if v, err := iprot.ReadBinary(); err != nil { + return PrependError("error reading field 8: ", err) + } else { + p.Bin = v + } + return nil +} + +func (p *MyTestStruct) readField9(iprot TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return PrependError("error reading map begin: ", err) + } + tMap := make(map[string]string, size) + p.StringMap = tMap + for i := 0; i < size; i++ { + var _key0 string + if v, err := iprot.ReadString(); err != nil { + return PrependError("error reading field 0: ", err) + } else { + _key0 = v + } + var _val1 string + if v, err := iprot.ReadString(); err != nil { + return PrependError("error reading field 0: ", err) + } else { + _val1 = v + } + p.StringMap[_key0] = _val1 + } + if err := iprot.ReadMapEnd(); err != nil { + return PrependError("error reading map end: ", err) + } + return nil +} + +func (p *MyTestStruct) readField10(iprot TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return PrependError("error reading list begin: ", err) + } + tSlice := make([]string, 0, size) + p.StringList = tSlice + for i := 0; i < size; i++ { + var _elem2 string + if v, err := iprot.ReadString(); err != nil { + return PrependError("error reading field 0: ", err) + } else { + _elem2 = v + } + p.StringList = append(p.StringList, _elem2) + } + if err := iprot.ReadListEnd(); err != nil { + return PrependError("error reading list end: ", err) + } + return nil +} + +func (p *MyTestStruct) readField11(iprot TProtocol) error { + _, size, err := iprot.ReadSetBegin() + if err != nil { + return PrependError("error reading set begin: ", err) + } + tSet := make(map[string]struct{}, size) + p.StringSet = tSet + for i := 0; i < size; i++ { + var _elem3 string + if v, err := iprot.ReadString(); err != nil { + return PrependError("error reading field 0: ", err) + } else { + _elem3 = v + } + p.StringSet[_elem3] = struct{}{} + } + if err := iprot.ReadSetEnd(); err != nil { + return PrependError("error reading set end: ", err) + } + return nil +} + +func (p *MyTestStruct) readField12(iprot TProtocol) error { + if v, err := iprot.ReadI32(); err != nil { + return PrependError("error reading field 12: ", err) + } else { + temp := MyTestEnum(v) + p.E = temp + } + return nil +} + +func (p *MyTestStruct) Write(oprot TProtocol) error { + if err := oprot.WriteStructBegin("MyTestStruct"); err != nil { + return PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if err := p.writeField1(oprot); err != nil { + return err + } + if err := p.writeField2(oprot); err != nil { + return err + } + if err := p.writeField3(oprot); err != nil { + return err + } + if err := p.writeField4(oprot); err != nil { + return err + } + if err := p.writeField5(oprot); err != nil { + return err + } + if err := p.writeField6(oprot); err != nil { + return err + } + if err := p.writeField7(oprot); err != nil { + return err + } + if err := p.writeField8(oprot); err != nil { + return err + } + if err := p.writeField9(oprot); err != nil { + return err + } + if err := p.writeField10(oprot); err != nil { + return err + } + if err := p.writeField11(oprot); err != nil { + return err + } + if err := p.writeField12(oprot); err != nil { + return err + } + if err := oprot.WriteFieldStop(); err != nil { + return PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(); err != nil { + return PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *MyTestStruct) writeField1(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 1:on: ", p), err) + } + if err := oprot.WriteBool(bool(p.On)); err != nil { + return PrependError(fmt.Sprintf("%T.on (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 1:on: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField2(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 2:b: ", p), err) + } + if err := oprot.WriteByte(int8(p.B)); err != nil { + return PrependError(fmt.Sprintf("%T.b (2) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 2:b: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField3(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 3:int16: ", p), err) + } + if err := oprot.WriteI16(int16(p.Int16)); err != nil { + return PrependError(fmt.Sprintf("%T.int16 (3) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 3:int16: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField4(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 4:int32: ", p), err) + } + if err := oprot.WriteI32(int32(p.Int32)); err != nil { + return PrependError(fmt.Sprintf("%T.int32 (4) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 4:int32: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField5(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 5:int64: ", p), err) + } + if err := oprot.WriteI64(int64(p.Int64)); err != nil { + return PrependError(fmt.Sprintf("%T.int64 (5) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 5:int64: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField6(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 6:d: ", p), err) + } + if err := oprot.WriteDouble(float64(p.D)); err != nil { + return PrependError(fmt.Sprintf("%T.d (6) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 6:d: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField7(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 7:st: ", p), err) + } + if err := oprot.WriteString(string(p.St)); err != nil { + return PrependError(fmt.Sprintf("%T.st (7) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 7:st: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField8(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 8:bin: ", p), err) + } + if err := oprot.WriteBinary(p.Bin); err != nil { + return PrependError(fmt.Sprintf("%T.bin (8) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 8:bin: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField9(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 9:stringMap: ", p), err) + } + if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil { + return PrependError("error writing map begin: ", err) + } + for k, v := range p.StringMap { + if err := oprot.WriteString(string(k)); err != nil { + return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) + } + if err := oprot.WriteString(string(v)); err != nil { + return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) + } + } + if err := oprot.WriteMapEnd(); err != nil { + return PrependError("error writing map end: ", err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 9:stringMap: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField10(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 10:stringList: ", p), err) + } + if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil { + return PrependError("error writing list begin: ", err) + } + for _, v := range p.StringList { + if err := oprot.WriteString(string(v)); err != nil { + return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) + } + } + if err := oprot.WriteListEnd(); err != nil { + return PrependError("error writing list end: ", err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 10:stringList: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 11:stringSet: ", p), err) + } + if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil { + return PrependError("error writing set begin: ", err) + } + for v := range p.StringSet { + if err := oprot.WriteString(string(v)); err != nil { + return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) + } + } + if err := oprot.WriteSetEnd(); err != nil { + return PrependError("error writing set end: ", err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 11:stringSet: ", p), err) + } + return err +} + +func (p *MyTestStruct) writeField12(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("e", I32, 12); err != nil { + return PrependError(fmt.Sprintf("%T write field begin error 12:e: ", p), err) + } + if err := oprot.WriteI32(int32(p.E)); err != nil { + return PrependError(fmt.Sprintf("%T.e (12) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return PrependError(fmt.Sprintf("%T write field end error 12:e: ", p), err) + } + return err +} + +func (p *MyTestStruct) String() string { + if p == nil { + return "<nil>" + } + return fmt.Sprintf("MyTestStruct(%+v)", *p) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/server.go b/src/jaegertracing/thrift/lib/go/thrift/server.go new file mode 100644 index 000000000..f813fa353 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/server.go @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +type TServer interface { + ProcessorFactory() TProcessorFactory + ServerTransport() TServerTransport + InputTransportFactory() TTransportFactory + OutputTransportFactory() TTransportFactory + InputProtocolFactory() TProtocolFactory + OutputProtocolFactory() TProtocolFactory + + // Starts the server + Serve() error + // Stops the server. This is optional on a per-implementation basis. Not + // all servers are required to be cleanly stoppable. + Stop() error +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/server_socket.go b/src/jaegertracing/thrift/lib/go/thrift/server_socket.go new file mode 100644 index 000000000..7dd24ae36 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/server_socket.go @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "net" + "sync" + "time" +) + +type TServerSocket struct { + listener net.Listener + addr net.Addr + clientTimeout time.Duration + + // Protects the interrupted value to make it thread safe. + mu sync.RWMutex + interrupted bool +} + +func NewTServerSocket(listenAddr string) (*TServerSocket, error) { + return NewTServerSocketTimeout(listenAddr, 0) +} + +func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*TServerSocket, error) { + addr, err := net.ResolveTCPAddr("tcp", listenAddr) + if err != nil { + return nil, err + } + return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil +} + +// Creates a TServerSocket from a net.Addr +func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket { + return &TServerSocket{addr: addr, clientTimeout: clientTimeout} +} + +func (p *TServerSocket) Listen() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.IsListening() { + return nil + } + l, err := net.Listen(p.addr.Network(), p.addr.String()) + if err != nil { + return err + } + p.listener = l + return nil +} + +func (p *TServerSocket) Accept() (TTransport, error) { + p.mu.RLock() + interrupted := p.interrupted + p.mu.RUnlock() + + if interrupted { + return nil, errTransportInterrupted + } + + p.mu.Lock() + listener := p.listener + p.mu.Unlock() + if listener == nil { + return nil, NewTTransportException(NOT_OPEN, "No underlying server socket") + } + + conn, err := listener.Accept() + if err != nil { + return nil, NewTTransportExceptionFromError(err) + } + return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil +} + +// Checks whether the socket is listening. +func (p *TServerSocket) IsListening() bool { + return p.listener != nil +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TServerSocket) Open() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.IsListening() { + return NewTTransportException(ALREADY_OPEN, "Server socket already open") + } + if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil { + return err + } else { + p.listener = l + } + return nil +} + +func (p *TServerSocket) Addr() net.Addr { + if p.listener != nil { + return p.listener.Addr() + } + return p.addr +} + +func (p *TServerSocket) Close() error { + var err error + p.mu.Lock() + if p.IsListening() { + err = p.listener.Close() + p.listener = nil + } + p.mu.Unlock() + return err +} + +func (p *TServerSocket) Interrupt() error { + p.mu.Lock() + p.interrupted = true + p.mu.Unlock() + p.Close() + + return nil +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/server_socket_test.go b/src/jaegertracing/thrift/lib/go/thrift/server_socket_test.go new file mode 100644 index 000000000..f1e1983a9 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/server_socket_test.go @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "fmt" + "testing" +) + +func TestSocketIsntListeningAfterInterrupt(t *testing.T) { + host := "127.0.0.1" + port := 9090 + addr := fmt.Sprintf("%s:%d", host, port) + + socket := CreateServerSocket(t, addr) + socket.Listen() + socket.Interrupt() + + newSocket := CreateServerSocket(t, addr) + err := newSocket.Listen() + defer newSocket.Interrupt() + if err != nil { + t.Fatalf("Failed to rebinds: %s", err) + } +} + +func TestSocketConcurrency(t *testing.T) { + host := "127.0.0.1" + port := 9090 + addr := fmt.Sprintf("%s:%d", host, port) + + socket := CreateServerSocket(t, addr) + go func() { socket.Listen() }() + go func() { socket.Interrupt() }() +} + +func CreateServerSocket(t *testing.T, addr string) *TServerSocket { + socket, err := NewTServerSocket(addr) + if err != nil { + t.Fatalf("Failed to create server socket: %s", err) + } + return socket +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/server_test.go b/src/jaegertracing/thrift/lib/go/thrift/server_test.go new file mode 100644 index 000000000..ffaf45702 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/server_test.go @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestNothing(t *testing.T) { + +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/server_transport.go b/src/jaegertracing/thrift/lib/go/thrift/server_transport.go new file mode 100644 index 000000000..51c40b64a --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/server_transport.go @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Server transport. Object which provides client transports. +type TServerTransport interface { + Listen() error + Accept() (TTransport, error) + Close() error + + // Optional method implementation. This signals to the server transport + // that it should break out of any accept() or listen() that it is currently + // blocked on. This method, if implemented, MUST be thread safe, as it may + // be called from a different thread context than the other TServerTransport + // methods. + Interrupt() error +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/simple_json_protocol.go b/src/jaegertracing/thrift/lib/go/thrift/simple_json_protocol.go new file mode 100644 index 000000000..f5e0c05d1 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/simple_json_protocol.go @@ -0,0 +1,1338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "strconv" +) + +type _ParseContext int + +const ( + _CONTEXT_IN_TOPLEVEL _ParseContext = 1 + _CONTEXT_IN_LIST_FIRST _ParseContext = 2 + _CONTEXT_IN_LIST _ParseContext = 3 + _CONTEXT_IN_OBJECT_FIRST _ParseContext = 4 + _CONTEXT_IN_OBJECT_NEXT_KEY _ParseContext = 5 + _CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6 +) + +func (p _ParseContext) String() string { + switch p { + case _CONTEXT_IN_TOPLEVEL: + return "TOPLEVEL" + case _CONTEXT_IN_LIST_FIRST: + return "LIST-FIRST" + case _CONTEXT_IN_LIST: + return "LIST" + case _CONTEXT_IN_OBJECT_FIRST: + return "OBJECT-FIRST" + case _CONTEXT_IN_OBJECT_NEXT_KEY: + return "OBJECT-NEXT-KEY" + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + return "OBJECT-NEXT-VALUE" + } + return "UNKNOWN-PARSE-CONTEXT" +} + +// Simple JSON protocol implementation for thrift. +// +// This protocol produces/consumes a simple output format +// suitable for parsing by scripting languages. It should not be +// confused with the full-featured TJSONProtocol. +// +type TSimpleJSONProtocol struct { + trans TTransport + + parseContextStack []int + dumpContext []int + + writer *bufio.Writer + reader *bufio.Reader +} + +// Constructor +func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { + v := &TSimpleJSONProtocol{trans: t, + writer: bufio.NewWriter(t), + reader: bufio.NewReader(t), + } + v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL)) + v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL)) + return v +} + +// Factory +type TSimpleJSONProtocolFactory struct{} + +func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTSimpleJSONProtocol(trans) +} + +func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory { + return &TSimpleJSONProtocolFactory{} +} + +var ( + JSON_COMMA []byte + JSON_COLON []byte + JSON_LBRACE []byte + JSON_RBRACE []byte + JSON_LBRACKET []byte + JSON_RBRACKET []byte + JSON_QUOTE byte + JSON_QUOTE_BYTES []byte + JSON_NULL []byte + JSON_TRUE []byte + JSON_FALSE []byte + JSON_INFINITY string + JSON_NEGATIVE_INFINITY string + JSON_NAN string + JSON_INFINITY_BYTES []byte + JSON_NEGATIVE_INFINITY_BYTES []byte + JSON_NAN_BYTES []byte + json_nonbase_map_elem_bytes []byte +) + +func init() { + JSON_COMMA = []byte{','} + JSON_COLON = []byte{':'} + JSON_LBRACE = []byte{'{'} + JSON_RBRACE = []byte{'}'} + JSON_LBRACKET = []byte{'['} + JSON_RBRACKET = []byte{']'} + JSON_QUOTE = '"' + JSON_QUOTE_BYTES = []byte{'"'} + JSON_NULL = []byte{'n', 'u', 'l', 'l'} + JSON_TRUE = []byte{'t', 'r', 'u', 'e'} + JSON_FALSE = []byte{'f', 'a', 'l', 's', 'e'} + JSON_INFINITY = "Infinity" + JSON_NEGATIVE_INFINITY = "-Infinity" + JSON_NAN = "NaN" + JSON_INFINITY_BYTES = []byte{'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} + JSON_NEGATIVE_INFINITY_BYTES = []byte{'-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} + JSON_NAN_BYTES = []byte{'N', 'a', 'N'} + json_nonbase_map_elem_bytes = []byte{']', ',', '['} +} + +func jsonQuote(s string) string { + b, _ := json.Marshal(s) + s1 := string(b) + return s1 +} + +func jsonUnquote(s string) (string, bool) { + s1 := new(string) + err := json.Unmarshal([]byte(s), s1) + return *s1, err == nil +} + +func mismatch(expected, actual string) error { + return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual) +} + +func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + p.resetContextStack() // THRIFT-3735 + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteString(name); e != nil { + return e + } + if e := p.WriteByte(int8(typeId)); e != nil { + return e + } + if e := p.WriteI32(seqId); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteMessageEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteStructBegin(name string) error { + if e := p.OutputObjectBegin(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteStructEnd() error { + return p.OutputObjectEnd() +} + +func (p *TSimpleJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + if e := p.WriteString(name); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteFieldEnd() error { + //return p.OutputListEnd() + return nil +} + +func (p *TSimpleJSONProtocol) WriteFieldStop() error { return nil } + +func (p *TSimpleJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteByte(int8(keyType)); e != nil { + return e + } + if e := p.WriteByte(int8(valueType)); e != nil { + return e + } + return p.WriteI32(int32(size)) +} + +func (p *TSimpleJSONProtocol) WriteMapEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteListBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TSimpleJSONProtocol) WriteListEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteSetBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TSimpleJSONProtocol) WriteSetEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteBool(b bool) error { + return p.OutputBool(b) +} + +func (p *TSimpleJSONProtocol) WriteByte(b int8) error { + return p.WriteI32(int32(b)) +} + +func (p *TSimpleJSONProtocol) WriteI16(v int16) error { + return p.WriteI32(int32(v)) +} + +func (p *TSimpleJSONProtocol) WriteI32(v int32) error { + return p.OutputI64(int64(v)) +} + +func (p *TSimpleJSONProtocol) WriteI64(v int64) error { + return p.OutputI64(int64(v)) +} + +func (p *TSimpleJSONProtocol) WriteDouble(v float64) error { + return p.OutputF64(v) +} + +func (p *TSimpleJSONProtocol) WriteString(v string) error { + return p.OutputString(v) +} + +func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { + // JSON library only takes in a string, + // not an arbitrary byte array, to ensure bytes are transmitted + // efficiently we must convert this into a valid JSON string + // therefore we use base64 encoding to avoid excessive escaping/quoting + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + writer := base64.NewEncoder(base64.StdEncoding, p.writer) + if _, e := writer.Write(v); e != nil { + p.writer.Reset(p.trans) // THRIFT-3735 + return NewTProtocolException(e) + } + if e := writer.Close(); e != nil { + return NewTProtocolException(e) + } + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +// Reading methods. +func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + p.resetContextStack() // THRIFT-3735 + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, typeId, seqId, err + } + if name, err = p.ReadString(); err != nil { + return name, typeId, seqId, err + } + bTypeId, err := p.ReadByte() + typeId = TMessageType(bTypeId) + if err != nil { + return name, typeId, seqId, err + } + if seqId, err = p.ReadI32(); err != nil { + return name, typeId, seqId, err + } + return name, typeId, seqId, nil +} + +func (p *TSimpleJSONProtocol) ReadMessageEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadStructBegin() (name string, err error) { + _, err = p.ParseObjectStart() + return "", err +} + +func (p *TSimpleJSONProtocol) ReadStructEnd() error { + return p.ParseObjectEnd() +} + +func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { + if err := p.ParsePreValue(); err != nil { + return "", STOP, 0, err + } + b, _ := p.reader.Peek(1) + if len(b) > 0 { + switch b[0] { + case JSON_RBRACE[0]: + return "", STOP, 0, nil + case JSON_QUOTE: + p.reader.ReadByte() + name, err := p.ParseStringBody() + // simplejson is not meant to be read back into thrift + // - see http://wiki.apache.org/thrift/ThriftUsageJava + // - use JSON instead + if err != nil { + return name, STOP, 0, err + } + return name, STOP, -1, p.ParsePostValue() + /* + if err = p.ParsePostValue(); err != nil { + return name, STOP, 0, err + } + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, STOP, 0, err + } + bType, err := p.ReadByte() + thetype := TType(bType) + if err != nil { + return name, thetype, 0, err + } + id, err := p.ReadI16() + return name, thetype, id, err + */ + } + e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b)) + return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return "", STOP, 0, NewTProtocolException(io.EOF) +} + +func (p *TSimpleJSONProtocol) ReadFieldEnd() error { + return nil + //return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, VOID, 0, e + } + + // read keyType + bKeyType, e := p.ReadByte() + keyType = TType(bKeyType) + if e != nil { + return keyType, valueType, size, e + } + + // read valueType + bValueType, e := p.ReadByte() + valueType = TType(bValueType) + if e != nil { + return keyType, valueType, size, e + } + + // read size + iSize, err := p.ReadI64() + size = int(iSize) + return keyType, valueType, size, err +} + +func (p *TSimpleJSONProtocol) ReadMapEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TSimpleJSONProtocol) ReadListEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TSimpleJSONProtocol) ReadSetEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadBool() (bool, error) { + var value bool + + if err := p.ParsePreValue(); err != nil { + return value, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 { + switch f[0] { + case JSON_TRUE[0]: + b := make([]byte, len(JSON_TRUE)) + _, err := p.reader.Read(b) + if err != nil { + return false, NewTProtocolException(err) + } + if string(b) == string(JSON_TRUE) { + value = true + } else { + e := fmt.Errorf("Expected \"true\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + break + case JSON_FALSE[0]: + b := make([]byte, len(JSON_FALSE)) + _, err := p.reader.Read(b) + if err != nil { + return false, NewTProtocolException(err) + } + if string(b) == string(JSON_FALSE) { + value = false + } else { + e := fmt.Errorf("Expected \"false\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + break + case JSON_NULL[0]: + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return false, NewTProtocolException(err) + } + if string(b) == string(JSON_NULL) { + value = false + } else { + e := fmt.Errorf("Expected \"null\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + default: + e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(f)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + return value, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadByte() (int8, error) { + v, err := p.ReadI64() + return int8(v), err +} + +func (p *TSimpleJSONProtocol) ReadI16() (int16, error) { + v, err := p.ReadI64() + return int16(v), err +} + +func (p *TSimpleJSONProtocol) ReadI32() (int32, error) { + v, err := p.ReadI64() + return int32(v), err +} + +func (p *TSimpleJSONProtocol) ReadI64() (int64, error) { + v, _, err := p.ParseI64() + return v, err +} + +func (p *TSimpleJSONProtocol) ReadDouble() (float64, error) { + v, _, err := p.ParseF64() + return v, err +} + +func (p *TSimpleJSONProtocol) ReadString() (string, error) { + var v string + if err := p.ParsePreValue(); err != nil { + return v, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 && f[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseStringBody() + v = value + if err != nil { + return v, err + } + } else if len(f) > 0 && f[0] == JSON_NULL[0] { + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return v, NewTProtocolException(err) + } + if string(b) != string(JSON_NULL) { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } else { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) { + var v []byte + if err := p.ParsePreValue(); err != nil { + return nil, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 && f[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseBase64EncodedBody() + v = value + if err != nil { + return v, err + } + } else if len(f) > 0 && f[0] == JSON_NULL[0] { + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return v, NewTProtocolException(err) + } + if string(b) != string(JSON_NULL) { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } else { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + + return v, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) { + return NewTProtocolException(p.writer.Flush()) +} + +func (p *TSimpleJSONProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TSimpleJSONProtocol) Transport() TTransport { + return p.trans +} + +func (p *TSimpleJSONProtocol) OutputPreValue() error { + cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) + switch cxt { + case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY: + if _, e := p.write(JSON_COMMA); e != nil { + return NewTProtocolException(e) + } + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + if _, e := p.write(JSON_COLON); e != nil { + return NewTProtocolException(e) + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputPostValue() error { + cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) + switch cxt { + case _CONTEXT_IN_LIST_FIRST: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST)) + break + case _CONTEXT_IN_OBJECT_FIRST: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) + break + case _CONTEXT_IN_OBJECT_NEXT_KEY: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY)) + break + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputBool(value bool) error { + if e := p.OutputPreValue(); e != nil { + return e + } + var v string + if value { + v = string(JSON_TRUE) + } else { + v = string(JSON_FALSE) + } + switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = jsonQuote(v) + default: + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputNull() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_NULL); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputF64(value float64) error { + if e := p.OutputPreValue(); e != nil { + return e + } + var v string + if math.IsNaN(value) { + v = string(JSON_QUOTE) + JSON_NAN + string(JSON_QUOTE) + } else if math.IsInf(value, 1) { + v = string(JSON_QUOTE) + JSON_INFINITY + string(JSON_QUOTE) + } else if math.IsInf(value, -1) { + v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE) + } else { + v = strconv.FormatFloat(value, 'g', -1, 64) + switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = string(JSON_QUOTE) + v + string(JSON_QUOTE) + default: + } + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputI64(value int64) error { + if e := p.OutputPreValue(); e != nil { + return e + } + v := strconv.FormatInt(value, 10) + switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = jsonQuote(v) + default: + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputString(s string) error { + if e := p.OutputPreValue(); e != nil { + return e + } + if e := p.OutputStringData(jsonQuote(s)); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputStringData(s string) error { + _, e := p.write([]byte(s)) + return NewTProtocolException(e) +} + +func (p *TSimpleJSONProtocol) OutputObjectBegin() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_LBRACE); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST)) + return nil +} + +func (p *TSimpleJSONProtocol) OutputObjectEnd() error { + if _, e := p.write(JSON_RBRACE); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + if e := p.OutputPostValue(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputListBegin() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_LBRACKET); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST)) + return nil +} + +func (p *TSimpleJSONProtocol) OutputListEnd() error { + if _, e := p.write(JSON_RBRACKET); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + if e := p.OutputPostValue(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteByte(int8(elemType)); e != nil { + return e + } + if e := p.WriteI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) ParsePreValue() error { + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + b, _ := p.reader.Peek(1) + switch cxt { + case _CONTEXT_IN_LIST: + if len(b) > 0 { + switch b[0] { + case JSON_RBRACKET[0]: + return nil + case JSON_COMMA[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \"]\" or \",\" in list context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + break + case _CONTEXT_IN_OBJECT_NEXT_KEY: + if len(b) > 0 { + switch b[0] { + case JSON_RBRACE[0]: + return nil + case JSON_COMMA[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \"}\" or \",\" in object context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + if len(b) > 0 { + switch b[0] { + case JSON_COLON[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \":\" in object context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) ParsePostValue() error { + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + switch cxt { + case _CONTEXT_IN_LIST_FIRST: + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST)) + break + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY)) + break + } + return nil +} + +func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error { + for { + b, _ := p.reader.Peek(1) + if len(b) < 1 { + return nil + } + switch b[0] { + case ' ', '\r', '\n', '\t': + p.reader.ReadByte() + continue + default: + break + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) ParseStringBody() (string, error) { + line, err := p.reader.ReadString(JSON_QUOTE) + if err != nil { + return "", NewTProtocolException(err) + } + l := len(line) + // count number of escapes to see if we need to keep going + i := 1 + for ; i < l; i++ { + if line[l-i-1] != '\\' { + break + } + } + if i&0x01 == 1 { + v, ok := jsonUnquote(string(JSON_QUOTE) + line) + if !ok { + return "", NewTProtocolException(err) + } + return v, nil + } + s, err := p.ParseQuotedStringBody() + if err != nil { + return "", NewTProtocolException(err) + } + str := string(JSON_QUOTE) + line + s + v, ok := jsonUnquote(str) + if !ok { + e := fmt.Errorf("Unable to parse as JSON string %s", str) + return "", NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, nil +} + +func (p *TSimpleJSONProtocol) ParseQuotedStringBody() (string, error) { + line, err := p.reader.ReadString(JSON_QUOTE) + if err != nil { + return "", NewTProtocolException(err) + } + l := len(line) + // count number of escapes to see if we need to keep going + i := 1 + for ; i < l; i++ { + if line[l-i-1] != '\\' { + break + } + } + if i&0x01 == 1 { + return line, nil + } + s, err := p.ParseQuotedStringBody() + if err != nil { + return "", NewTProtocolException(err) + } + v := line + s + return v, nil +} + +func (p *TSimpleJSONProtocol) ParseBase64EncodedBody() ([]byte, error) { + line, err := p.reader.ReadBytes(JSON_QUOTE) + if err != nil { + return line, NewTProtocolException(err) + } + line2 := line[0 : len(line)-1] + l := len(line2) + if (l % 4) != 0 { + pad := 4 - (l % 4) + fill := [...]byte{'=', '=', '='} + line2 = append(line2, fill[:pad]...) + l = len(line2) + } + output := make([]byte, base64.StdEncoding.DecodedLen(l)) + n, err := base64.StdEncoding.Decode(output, line2) + return output[0:n], NewTProtocolException(err) +} + +func (p *TSimpleJSONProtocol) ParseI64() (int64, bool, error) { + if err := p.ParsePreValue(); err != nil { + return 0, false, err + } + var value int64 + var isnull bool + if p.safePeekContains(JSON_NULL) { + p.reader.Read(make([]byte, len(JSON_NULL))) + isnull = true + } else { + num, err := p.readNumeric() + isnull = (num == nil) + if !isnull { + value = num.Int64() + } + if err != nil { + return value, isnull, err + } + } + return value, isnull, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseF64() (float64, bool, error) { + if err := p.ParsePreValue(); err != nil { + return 0, false, err + } + var value float64 + var isnull bool + if p.safePeekContains(JSON_NULL) { + p.reader.Read(make([]byte, len(JSON_NULL))) + isnull = true + } else { + num, err := p.readNumeric() + isnull = (num == nil) + if !isnull { + value = num.Float64() + } + if err != nil { + return value, isnull, err + } + } + return value, isnull, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) { + if err := p.ParsePreValue(); err != nil { + return false, err + } + var b []byte + b, err := p.reader.Peek(1) + if err != nil { + return false, err + } + if len(b) > 0 && b[0] == JSON_LBRACE[0] { + p.reader.ReadByte() + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST)) + return false, nil + } else if p.safePeekContains(JSON_NULL) { + return true, nil + } + e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b)) + return false, NewTProtocolExceptionWithType(INVALID_DATA, e) +} + +func (p *TSimpleJSONProtocol) ParseObjectEnd() error { + if isNull, err := p.readIfNull(); isNull || err != nil { + return err + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) { + e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + line, err := p.reader.ReadString(JSON_RBRACE[0]) + if err != nil { + return NewTProtocolException(err) + } + for _, char := range line { + switch char { + default: + e := fmt.Errorf("Expecting end of object \"}\", but found: \"%s\"", line) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + case ' ', '\n', '\r', '\t', '}': + break + } + } + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + return p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) { + if e := p.ParsePreValue(); e != nil { + return false, e + } + var b []byte + b, err = p.reader.Peek(1) + if err != nil { + return false, err + } + if len(b) >= 1 && b[0] == JSON_LBRACKET[0] { + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST)) + p.reader.ReadByte() + isNull = false + } else if p.safePeekContains(JSON_NULL) { + isNull = true + } else { + err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b) + } + return isNull, NewTProtocolExceptionWithType(INVALID_DATA, err) +} + +func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + bElemType, err := p.ReadByte() + elemType = TType(bElemType) + if err != nil { + return elemType, size, err + } + nSize, err2 := p.ReadI64() + size = int(nSize) + return elemType, size, err2 +} + +func (p *TSimpleJSONProtocol) ParseListEnd() error { + if isNull, err := p.readIfNull(); isNull || err != nil { + return err + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + if cxt != _CONTEXT_IN_LIST { + e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + line, err := p.reader.ReadString(JSON_RBRACKET[0]) + if err != nil { + return NewTProtocolException(err) + } + for _, char := range line { + switch char { + default: + e := fmt.Errorf("Expecting end of list \"]\", but found: \"%v\"", line) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]): + break + } + } + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL { + return nil + } + return p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) readSingleValue() (interface{}, TType, error) { + e := p.readNonSignificantWhitespace() + if e != nil { + return nil, VOID, NewTProtocolException(e) + } + b, e := p.reader.Peek(1) + if len(b) > 0 { + c := b[0] + switch c { + case JSON_NULL[0]: + buf := make([]byte, len(JSON_NULL)) + _, e := p.reader.Read(buf) + if e != nil { + return nil, VOID, NewTProtocolException(e) + } + if string(JSON_NULL) != string(buf) { + e = mismatch(string(JSON_NULL), string(buf)) + return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return nil, VOID, nil + case JSON_QUOTE: + p.reader.ReadByte() + v, e := p.ParseStringBody() + if e != nil { + return v, UTF8, NewTProtocolException(e) + } + if v == JSON_INFINITY { + return INFINITY, DOUBLE, nil + } else if v == JSON_NEGATIVE_INFINITY { + return NEGATIVE_INFINITY, DOUBLE, nil + } else if v == JSON_NAN { + return NAN, DOUBLE, nil + } + return v, UTF8, nil + case JSON_TRUE[0]: + buf := make([]byte, len(JSON_TRUE)) + _, e := p.reader.Read(buf) + if e != nil { + return true, BOOL, NewTProtocolException(e) + } + if string(JSON_TRUE) != string(buf) { + e := mismatch(string(JSON_TRUE), string(buf)) + return true, BOOL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return true, BOOL, nil + case JSON_FALSE[0]: + buf := make([]byte, len(JSON_FALSE)) + _, e := p.reader.Read(buf) + if e != nil { + return false, BOOL, NewTProtocolException(e) + } + if string(JSON_FALSE) != string(buf) { + e := mismatch(string(JSON_FALSE), string(buf)) + return false, BOOL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return false, BOOL, nil + case JSON_LBRACKET[0]: + _, e := p.reader.ReadByte() + return make([]interface{}, 0), LIST, NewTProtocolException(e) + case JSON_LBRACE[0]: + _, e := p.reader.ReadByte() + return make(map[string]interface{}), STRUCT, NewTProtocolException(e) + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'e', 'E', '.', '+', '-', JSON_INFINITY[0], JSON_NAN[0]: + // assume numeric + v, e := p.readNumeric() + return v, DOUBLE, e + default: + e := fmt.Errorf("Expected element in list but found '%s' while parsing JSON.", string(c)) + return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + e = fmt.Errorf("Cannot read a single element while parsing JSON.") + return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) + +} + +func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { + cont := true + for cont { + b, _ := p.reader.Peek(1) + if len(b) < 1 { + return false, nil + } + switch b[0] { + default: + return false, nil + case JSON_NULL[0]: + cont = false + break + case ' ', '\n', '\r', '\t': + p.reader.ReadByte() + break + } + } + if p.safePeekContains(JSON_NULL) { + p.reader.Read(make([]byte, len(JSON_NULL))) + return true, nil + } + return false, nil +} + +func (p *TSimpleJSONProtocol) readQuoteIfNext() { + b, _ := p.reader.Peek(1) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + } +} + +func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) { + isNull, err := p.readIfNull() + if isNull || err != nil { + return NUMERIC_NULL, err + } + hasDecimalPoint := false + nextCanBeSign := true + hasE := false + MAX_LEN := 40 + buf := bytes.NewBuffer(make([]byte, 0, MAX_LEN)) + continueFor := true + inQuotes := false + for continueFor { + c, err := p.reader.ReadByte() + if err != nil { + if err == io.EOF { + break + } + return NUMERIC_NULL, NewTProtocolException(err) + } + switch c { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + buf.WriteByte(c) + nextCanBeSign = false + case '.': + if hasDecimalPoint { + e := fmt.Errorf("Unable to parse number with multiple decimal points '%s.'", buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if hasE { + e := fmt.Errorf("Unable to parse number with decimal points in the exponent '%s.'", buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + hasDecimalPoint, nextCanBeSign = true, false + case 'e', 'E': + if hasE { + e := fmt.Errorf("Unable to parse number with multiple exponents '%s%c'", buf.String(), c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + hasE, nextCanBeSign = true, true + case '-', '+': + if !nextCanBeSign { + e := fmt.Errorf("Negative sign within number") + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + nextCanBeSign = false + case ' ', 0, '\t', '\n', '\r', JSON_RBRACE[0], JSON_RBRACKET[0], JSON_COMMA[0], JSON_COLON[0]: + p.reader.UnreadByte() + continueFor = false + case JSON_NAN[0]: + if buf.Len() == 0 { + buffer := make([]byte, len(JSON_NAN)) + buffer[0] = c + _, e := p.reader.Read(buffer[1:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_NAN != string(buffer) { + e := mismatch(JSON_NAN, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return NAN, nil + } else { + e := fmt.Errorf("Unable to parse number starting with character '%c'", c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_INFINITY[0]: + if buf.Len() == 0 || (buf.Len() == 1 && buf.Bytes()[0] == '+') { + buffer := make([]byte, len(JSON_INFINITY)) + buffer[0] = c + _, e := p.reader.Read(buffer[1:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_INFINITY != string(buffer) { + e := mismatch(JSON_INFINITY, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return INFINITY, nil + } else if buf.Len() == 1 && buf.Bytes()[0] == JSON_NEGATIVE_INFINITY[0] { + buffer := make([]byte, len(JSON_NEGATIVE_INFINITY)) + buffer[0] = JSON_NEGATIVE_INFINITY[0] + buffer[1] = c + _, e := p.reader.Read(buffer[2:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_NEGATIVE_INFINITY != string(buffer) { + e := mismatch(JSON_NEGATIVE_INFINITY, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return NEGATIVE_INFINITY, nil + } else { + e := fmt.Errorf("Unable to parse number starting with character '%c' due to existing buffer %s", c, buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_QUOTE: + if !inQuotes { + inQuotes = true + } else { + break + } + default: + e := fmt.Errorf("Unable to parse number starting with character '%c'", c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + if buf.Len() == 0 { + e := fmt.Errorf("Unable to parse number from empty string ''") + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return NewNumericFromJSONString(buf.String(), false), nil +} + +// Safely peeks into the buffer, reading only what is necessary +func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool { + for i := 0; i < len(b); i++ { + a, _ := p.reader.Peek(i + 1) + if len(a) < (i+1) || a[i] != b[i] { + return false + } + } + return true +} + +// Reset the context stack to its initial state. +func (p *TSimpleJSONProtocol) resetContextStack() { + p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)} + p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)} +} + +func (p *TSimpleJSONProtocol) write(b []byte) (int, error) { + n, err := p.writer.Write(b) + if err != nil { + p.writer.Reset(p.trans) // THRIFT-3735 + } + return n, err +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/simple_json_protocol_test.go b/src/jaegertracing/thrift/lib/go/thrift/simple_json_protocol_test.go new file mode 100644 index 000000000..0126da0a8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/simple_json_protocol_test.go @@ -0,0 +1,738 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "testing" +) + +func TestWriteSimpleJSONProtocolBool(t *testing.T) { + thetype := "boolean" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range BOOL_VALUES { + if e := p.WriteBool(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := false + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolBool(t *testing.T) { + thetype := "boolean" + for _, value := range BOOL_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + if value { + trans.Write(JSON_TRUE) + } else { + trans.Write(JSON_FALSE) + } + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadBool() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolByte(t *testing.T) { + thetype := "byte" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range BYTE_VALUES { + if e := p.WriteByte(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int8(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolByte(t *testing.T) { + thetype := "byte" + for _, value := range BYTE_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadByte() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolI16(t *testing.T) { + thetype := "int16" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range INT16_VALUES { + if e := p.WriteI16(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int16(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolI16(t *testing.T) { + thetype := "int16" + for _, value := range INT16_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI16() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolI32(t *testing.T) { + thetype := "int32" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range INT32_VALUES { + if e := p.WriteI32(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int32(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolI32(t *testing.T) { + thetype := "int32" + for _, value := range INT32_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI32() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestReadSimpleJSONProtocolI32Null(t *testing.T) { + thetype := "int32" + value := "null" + + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(value) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI32() + + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != 0 { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + trans.Reset() + trans.Close() +} + +func TestWriteSimpleJSONProtocolI64(t *testing.T) { + thetype := "int64" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range INT64_VALUES { + if e := p.WriteI64(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolI64(t *testing.T) { + thetype := "int64" + for _, value := range INT64_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.FormatInt(value, 10)) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI64() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestReadSimpleJSONProtocolI64Null(t *testing.T) { + thetype := "int32" + value := "null" + + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(value) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadI64() + + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != 0 { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + trans.Reset() + trans.Close() +} + +func TestWriteSimpleJSONProtocolDouble(t *testing.T) { + thetype := "double" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if math.IsInf(value, 1) { + if s != jsonQuote(JSON_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, -1) { + if s != jsonQuote(JSON_NEGATIVE_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if s != jsonQuote(JSON_NAN) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN)) + } + } else { + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolDouble(t *testing.T) { + thetype := "double" + for _, value := range DOUBLE_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + n := NewNumericFromDouble(value) + trans.WriteString(n.String()) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadDouble() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if math.IsInf(value, 1) { + if !math.IsInf(v, 1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsInf(value, -1) { + if !math.IsInf(v, -1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsNaN(value) { + if !math.IsNaN(v) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else { + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolString(t *testing.T) { + thetype := "string" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range STRING_VALUES { + if e := p.WriteString(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s[0] != '"' || s[len(s)-1] != '"' { + t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\"")) + } + v := new(string) + if err := json.Unmarshal([]byte(s), v); err != nil || *v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolString(t *testing.T) { + thetype := "string" + for _, value := range STRING_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(jsonQuote(value)) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadString() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() + } +} +func TestReadSimpleJSONProtocolStringNull(t *testing.T) { + thetype := "string" + value := "null" + + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(value) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadString() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != "" { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + trans.Reset() + trans.Close() +} + +func TestWriteSimpleJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + if e := p.WriteBinary(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint("\"", b64String, "\"") { + t.Fatalf("Bad value for %s %v\n wrote: %v\nexpected: %v", thetype, value, s, "\""+b64String+"\"") + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Close() +} + +func TestReadSimpleJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(jsonQuote(b64String)) + trans.Flush(context.Background()) + s := trans.String() + v, e := p.ReadBinary() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if len(v) != len(value) { + t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) + } + for i := 0; i < len(v); i++ { + if v[i] != value[i] { + t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) + } + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() +} + +func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) { + thetype := "binary" + value := "null" + + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(value) + trans.Flush(context.Background()) + s := trans.String() + b, e := p.ReadBinary() + v := string(b) + + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != "" { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + trans.Reset() + trans.Close() +} + +func TestWriteSimpleJSONProtocolList(t *testing.T) { + thetype := "list" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteListEnd() + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("List must be at least of length two to include metadata") + } + if int(l[0].(float64)) != DOUBLE { + t.Fatal("Invalid type for list, expected: ", DOUBLE, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteSimpleJSONProtocolSet(t *testing.T) { + thetype := "set" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteSetEnd() + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("Set must be at least of length two to include metadata") + } + if int(l[0].(float64)) != DOUBLE { + t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteSimpleJSONProtocolMap(t *testing.T) { + thetype := "map" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) + for k, value := range DOUBLE_VALUES { + if e := p.WriteI32(int32(k)); e != nil { + t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) + } + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteMapEnd() + if e := p.Flush(context.Background()); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + if str[0] != '[' || str[len(str)-1] != ']' { + t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES) + } + l := strings.Split(str[1:len(str)-1], ",") + if len(l) < 3 { + t.Fatal("Expected list of at least length 3 for map for metadata, but was of length ", len(l)) + } + expectedKeyType, _ := strconv.Atoi(l[0]) + expectedValueType, _ := strconv.Atoi(l[1]) + expectedSize, _ := strconv.Atoi(l[2]) + if expectedKeyType != I32 { + t.Fatal("Expected map key type ", I32, ", but was ", l[0]) + } + if expectedValueType != DOUBLE { + t.Fatal("Expected map value type ", DOUBLE, ", but was ", l[1]) + } + if expectedSize != len(DOUBLE_VALUES) { + t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", l[2]) + } + for k, value := range DOUBLE_VALUES { + strk := l[k*2+3] + strv := l[k*2+4] + ik, err := strconv.Atoi(strk) + if err != nil { + t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, strk, string(k), err.Error()) + } + if ik != k { + t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v", thetype, k, strk, k) + } + s := strv + if math.IsInf(value, 1) { + if s != jsonQuote(JSON_INFINITY) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, 0) { + if s != jsonQuote(JSON_NEGATIVE_INFINITY) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if s != jsonQuote(JSON_NAN) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN)) + } + } else { + expected := strconv.FormatFloat(value, 'g', 10, 64) + if s != expected { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteSimpleJSONProtocolSafePeek(t *testing.T) { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.Write([]byte{'a', 'b'}) + trans.Flush(context.Background()) + + test1 := p.safePeekContains([]byte{'a', 'b'}) + if !test1 { + t.Fatalf("Should match at test 1") + } + + test2 := p.safePeekContains([]byte{'a', 'b', 'c', 'd'}) + if test2 { + t.Fatalf("Should not match at test 2") + } + + test3 := p.safePeekContains([]byte{'x', 'y'}) + if test3 { + t.Fatalf("Should not match at test 3") + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/simple_server.go b/src/jaegertracing/thrift/lib/go/thrift/simple_server.go new file mode 100644 index 000000000..f8efbed91 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/simple_server.go @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "log" + "runtime/debug" + "sync" + "sync/atomic" +) + +/* + * This is not a typical TSimpleServer as it is not blocked after accept a socket. + * It is more like a TThreadedServer that can handle different connections in different goroutines. + * This will work if golang user implements a conn-pool like thing in client side. + */ +type TSimpleServer struct { + closed int32 + wg sync.WaitGroup + mu sync.Mutex + + processorFactory TProcessorFactory + serverTransport TServerTransport + inputTransportFactory TTransportFactory + outputTransportFactory TTransportFactory + inputProtocolFactory TProtocolFactory + outputProtocolFactory TProtocolFactory + + // Headers to auto forward in THeaderProtocol + forwardHeaders []string +} + +func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer { + return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport) +} + +func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory4(NewTProcessorFactory(processor), + serverTransport, + transportFactory, + protocolFactory, + ) +} + +func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory6(NewTProcessorFactory(processor), + serverTransport, + inputTransportFactory, + outputTransportFactory, + inputProtocolFactory, + outputProtocolFactory, + ) +} + +func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer { + return NewTSimpleServerFactory6(processorFactory, + serverTransport, + NewTTransportFactory(), + NewTTransportFactory(), + NewTBinaryProtocolFactoryDefault(), + NewTBinaryProtocolFactoryDefault(), + ) +} + +func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory6(processorFactory, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + ) +} + +func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { + return &TSimpleServer{ + processorFactory: processorFactory, + serverTransport: serverTransport, + inputTransportFactory: inputTransportFactory, + outputTransportFactory: outputTransportFactory, + inputProtocolFactory: inputProtocolFactory, + outputProtocolFactory: outputProtocolFactory, + } +} + +func (p *TSimpleServer) ProcessorFactory() TProcessorFactory { + return p.processorFactory +} + +func (p *TSimpleServer) ServerTransport() TServerTransport { + return p.serverTransport +} + +func (p *TSimpleServer) InputTransportFactory() TTransportFactory { + return p.inputTransportFactory +} + +func (p *TSimpleServer) OutputTransportFactory() TTransportFactory { + return p.outputTransportFactory +} + +func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory { + return p.inputProtocolFactory +} + +func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory { + return p.outputProtocolFactory +} + +func (p *TSimpleServer) Listen() error { + return p.serverTransport.Listen() +} + +// SetForwardHeaders sets the list of header keys that will be auto forwarded +// while using THeaderProtocol. +// +// "forward" means that when the server is also a client to other upstream +// thrift servers, the context object user gets in the processor functions will +// have both read and write headers set, with write headers being forwarded. +// Users can always override the write headers by calling SetWriteHeaderList +// before calling thrift client functions. +func (p *TSimpleServer) SetForwardHeaders(headers []string) { + size := len(headers) + if size == 0 { + p.forwardHeaders = nil + return + } + + keys := make([]string, size) + copy(keys, headers) + p.forwardHeaders = keys +} + +func (p *TSimpleServer) innerAccept() (int32, error) { + client, err := p.serverTransport.Accept() + p.mu.Lock() + defer p.mu.Unlock() + closed := atomic.LoadInt32(&p.closed) + if closed != 0 { + return closed, nil + } + if err != nil { + return 0, err + } + if client != nil { + p.wg.Add(1) + go func() { + defer p.wg.Done() + if err := p.processRequests(client); err != nil { + log.Println("error processing request:", err) + } + }() + } + return 0, nil +} + +func (p *TSimpleServer) AcceptLoop() error { + for { + closed, err := p.innerAccept() + if err != nil { + return err + } + if closed != 0 { + return nil + } + } +} + +func (p *TSimpleServer) Serve() error { + err := p.Listen() + if err != nil { + return err + } + p.AcceptLoop() + return nil +} + +func (p *TSimpleServer) Stop() error { + p.mu.Lock() + defer p.mu.Unlock() + if atomic.LoadInt32(&p.closed) != 0 { + return nil + } + atomic.StoreInt32(&p.closed, 1) + p.serverTransport.Interrupt() + p.wg.Wait() + return nil +} + +func (p *TSimpleServer) processRequests(client TTransport) error { + processor := p.processorFactory.GetProcessor(client) + inputTransport, err := p.inputTransportFactory.GetTransport(client) + if err != nil { + return err + } + inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) + var outputTransport TTransport + var outputProtocol TProtocol + + // for THeaderProtocol, we must use the same protocol instance for + // input and output so that the response is in the same dialect that + // the server detected the request was in. + headerProtocol, ok := inputProtocol.(*THeaderProtocol) + if ok { + outputProtocol = inputProtocol + } else { + oTrans, err := p.outputTransportFactory.GetTransport(client) + if err != nil { + return err + } + outputTransport = oTrans + outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport) + } + + defer func() { + if e := recover(); e != nil { + log.Printf("panic in processor: %s: %s", e, debug.Stack()) + } + }() + + if inputTransport != nil { + defer inputTransport.Close() + } + if outputTransport != nil { + defer outputTransport.Close() + } + for { + if atomic.LoadInt32(&p.closed) != 0 { + return nil + } + + ctx := defaultCtx + if headerProtocol != nil { + // We need to call ReadFrame here, otherwise we won't + // get any headers on the AddReadTHeaderToContext call. + // + // ReadFrame is safe to be called multiple times so it + // won't break when it's called again later when we + // actually start to read the message. + if err := headerProtocol.ReadFrame(); err != nil { + return err + } + ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders()) + ctx = SetWriteHeaderList(ctx, p.forwardHeaders) + } + + ok, err := processor.Process(ctx, inputProtocol, outputProtocol) + if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE { + return nil + } else if err != nil { + return err + } + if err, ok := err.(TApplicationException); ok && err.TypeId() == UNKNOWN_METHOD { + continue + } + if !ok { + break + } + } + return nil +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/simple_server_test.go b/src/jaegertracing/thrift/lib/go/thrift/simple_server_test.go new file mode 100644 index 000000000..58149a8e6 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/simple_server_test.go @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" + "errors" + "runtime" +) + +type mockServerTransport struct { + ListenFunc func() error + AcceptFunc func() (TTransport, error) + CloseFunc func() error + InterruptFunc func() error +} + +func (m *mockServerTransport) Listen() error { + return m.ListenFunc() +} + +func (m *mockServerTransport) Accept() (TTransport, error) { + return m.AcceptFunc() +} + +func (m *mockServerTransport) Close() error { + return m.CloseFunc() +} + +func (m *mockServerTransport) Interrupt() error { + return m.InterruptFunc() +} + +type mockTTransport struct { + TTransport +} + +func (m *mockTTransport) Close() error { + return nil +} + +func TestMultipleStop(t *testing.T) { + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + return false, nil + }, + } + + var interruptCalled bool + c := make(chan struct{}) + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + <-c + return nil, nil + }, + CloseFunc: func() error { + c <- struct{}{} + return nil + }, + InterruptFunc: func() error { + interruptCalled = true + return nil + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + serv.Stop() + if !interruptCalled { + t.Error("first server transport should have been interrupted") + } + + serv = NewTSimpleServer2(proc, trans) + interruptCalled = false + go serv.Serve() + serv.Stop() + if !interruptCalled { + t.Error("second server transport should have been interrupted") + } +} + +func TestWaitRace(t *testing.T) { + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + return false, nil + }, + } + + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + return &mockTTransport{}, nil + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return nil + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + runtime.Gosched() + serv.Stop() +} + +func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) { + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + return false, nil + }, + } + + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + return nil, errors.New("no sir") + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return nil + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + runtime.Gosched() + serv.Stop() +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/socket.go b/src/jaegertracing/thrift/lib/go/thrift/socket.go new file mode 100644 index 000000000..88b98f591 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/socket.go @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "net" + "time" +) + +type TSocket struct { + conn net.Conn + addr net.Addr + timeout time.Duration +} + +// NewTSocket creates a net.Conn-backed TTransport, given a host and port +// +// Example: +// trans, err := thrift.NewTSocket("localhost:9090") +func NewTSocket(hostPort string) (*TSocket, error) { + return NewTSocketTimeout(hostPort, 0) +} + +// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port +// it also accepts a timeout as a time.Duration +func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, error) { + //conn, err := net.DialTimeout(network, address, timeout) + addr, err := net.ResolveTCPAddr("tcp", hostPort) + if err != nil { + return nil, err + } + return NewTSocketFromAddrTimeout(addr, timeout), nil +} + +// Creates a TSocket from a net.Addr +func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket { + return &TSocket{addr: addr, timeout: timeout} +} + +// Creates a TSocket from an existing net.Conn +func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket { + return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout} +} + +// Sets the socket timeout +func (p *TSocket) SetTimeout(timeout time.Duration) error { + p.timeout = timeout + return nil +} + +func (p *TSocket) pushDeadline(read, write bool) { + var t time.Time + if p.timeout > 0 { + t = time.Now().Add(time.Duration(p.timeout)) + } + if read && write { + p.conn.SetDeadline(t) + } else if read { + p.conn.SetReadDeadline(t) + } else if write { + p.conn.SetWriteDeadline(t) + } +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSocket) Open() error { + if p.IsOpen() { + return NewTTransportException(ALREADY_OPEN, "Socket already connected.") + } + if p.addr == nil { + return NewTTransportException(NOT_OPEN, "Cannot open nil address.") + } + if len(p.addr.Network()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") + } + if len(p.addr.String()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad address.") + } + var err error + if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.timeout); err != nil { + return NewTTransportException(NOT_OPEN, err.Error()) + } + return nil +} + +// Retrieve the underlying net.Conn +func (p *TSocket) Conn() net.Conn { + return p.conn +} + +// Returns true if the connection is open +func (p *TSocket) IsOpen() bool { + if p.conn == nil { + return false + } + return true +} + +// Closes the socket. +func (p *TSocket) Close() error { + // Close the socket + if p.conn != nil { + err := p.conn.Close() + if err != nil { + return err + } + p.conn = nil + } + return nil +} + +//Returns the remote address of the socket. +func (p *TSocket) Addr() net.Addr { + return p.addr +} + +func (p *TSocket) Read(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(true, false) + n, err := p.conn.Read(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TSocket) Write(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(false, true) + return p.conn.Write(buf) +} + +func (p *TSocket) Flush(ctx context.Context) error { + return nil +} + +func (p *TSocket) Interrupt() error { + if !p.IsOpen() { + return nil + } + return p.conn.Close() +} + +func (p *TSocket) RemainingBytes() (num_bytes uint64) { + const maxSize = ^uint64(0) + return maxSize // the truth is, we just don't know unless framed is used +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/ssl_server_socket.go b/src/jaegertracing/thrift/lib/go/thrift/ssl_server_socket.go new file mode 100644 index 000000000..907afca32 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/ssl_server_socket.go @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "crypto/tls" + "net" + "time" +) + +type TSSLServerSocket struct { + listener net.Listener + addr net.Addr + clientTimeout time.Duration + interrupted bool + cfg *tls.Config +} + +func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket, error) { + return NewTSSLServerSocketTimeout(listenAddr, cfg, 0) +} + +func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) { + if cfg.MinVersion == 0 { + cfg.MinVersion = tls.VersionTLS10 + } + addr, err := net.ResolveTCPAddr("tcp", listenAddr) + if err != nil { + return nil, err + } + return &TSSLServerSocket{addr: addr, clientTimeout: clientTimeout, cfg: cfg}, nil +} + +func (p *TSSLServerSocket) Listen() error { + if p.IsListening() { + return nil + } + l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg) + if err != nil { + return err + } + p.listener = l + return nil +} + +func (p *TSSLServerSocket) Accept() (TTransport, error) { + if p.interrupted { + return nil, errTransportInterrupted + } + if p.listener == nil { + return nil, NewTTransportException(NOT_OPEN, "No underlying server socket") + } + conn, err := p.listener.Accept() + if err != nil { + return nil, NewTTransportExceptionFromError(err) + } + return NewTSSLSocketFromConnTimeout(conn, p.cfg, p.clientTimeout), nil +} + +// Checks whether the socket is listening. +func (p *TSSLServerSocket) IsListening() bool { + return p.listener != nil +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSSLServerSocket) Open() error { + if p.IsListening() { + return NewTTransportException(ALREADY_OPEN, "Server socket already open") + } + if l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg); err != nil { + return err + } else { + p.listener = l + } + return nil +} + +func (p *TSSLServerSocket) Addr() net.Addr { + return p.addr +} + +func (p *TSSLServerSocket) Close() error { + defer func() { + p.listener = nil + }() + if p.IsListening() { + return p.listener.Close() + } + return nil +} + +func (p *TSSLServerSocket) Interrupt() error { + p.interrupted = true + return nil +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/ssl_socket.go b/src/jaegertracing/thrift/lib/go/thrift/ssl_socket.go new file mode 100644 index 000000000..ba6337726 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/ssl_socket.go @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "crypto/tls" + "net" + "time" +) + +type TSSLSocket struct { + conn net.Conn + // hostPort contains host:port (e.g. "asdf.com:12345"). The field is + // only valid if addr is nil. + hostPort string + // addr is nil when hostPort is not "", and is only used when the + // TSSLSocket is constructed from a net.Addr. + addr net.Addr + timeout time.Duration + cfg *tls.Config +} + +// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration +// +// Example: +// trans, err := thrift.NewTSSLSocket("localhost:9090", nil) +func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) { + return NewTSSLSocketTimeout(hostPort, cfg, 0) +} + +// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port +// it also accepts a tls Configuration and a timeout as a time.Duration +func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) { + if cfg.MinVersion == 0 { + cfg.MinVersion = tls.VersionTLS10 + } + return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil +} + +// Creates a TSSLSocket from a net.Addr +func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket { + return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg} +} + +// Creates a TSSLSocket from an existing net.Conn +func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket { + return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg} +} + +// Sets the socket timeout +func (p *TSSLSocket) SetTimeout(timeout time.Duration) error { + p.timeout = timeout + return nil +} + +func (p *TSSLSocket) pushDeadline(read, write bool) { + var t time.Time + if p.timeout > 0 { + t = time.Now().Add(time.Duration(p.timeout)) + } + if read && write { + p.conn.SetDeadline(t) + } else if read { + p.conn.SetReadDeadline(t) + } else if write { + p.conn.SetWriteDeadline(t) + } +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSSLSocket) Open() error { + var err error + // If we have a hostname, we need to pass the hostname to tls.Dial for + // certificate hostname checks. + if p.hostPort != "" { + if p.conn, err = tls.DialWithDialer(&net.Dialer{ + Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil { + return NewTTransportException(NOT_OPEN, err.Error()) + } + } else { + if p.IsOpen() { + return NewTTransportException(ALREADY_OPEN, "Socket already connected.") + } + if p.addr == nil { + return NewTTransportException(NOT_OPEN, "Cannot open nil address.") + } + if len(p.addr.Network()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") + } + if len(p.addr.String()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad address.") + } + if p.conn, err = tls.DialWithDialer(&net.Dialer{ + Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil { + return NewTTransportException(NOT_OPEN, err.Error()) + } + } + return nil +} + +// Retrieve the underlying net.Conn +func (p *TSSLSocket) Conn() net.Conn { + return p.conn +} + +// Returns true if the connection is open +func (p *TSSLSocket) IsOpen() bool { + if p.conn == nil { + return false + } + return true +} + +// Closes the socket. +func (p *TSSLSocket) Close() error { + // Close the socket + if p.conn != nil { + err := p.conn.Close() + if err != nil { + return err + } + p.conn = nil + } + return nil +} + +func (p *TSSLSocket) Read(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(true, false) + n, err := p.conn.Read(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TSSLSocket) Write(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(false, true) + return p.conn.Write(buf) +} + +func (p *TSSLSocket) Flush(ctx context.Context) error { + return nil +} + +func (p *TSSLSocket) Interrupt() error { + if !p.IsOpen() { + return nil + } + return p.conn.Close() +} + +func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) { + const maxSize = ^uint64(0) + return maxSize // the thruth is, we just don't know unless framed is used +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/transport.go b/src/jaegertracing/thrift/lib/go/thrift/transport.go new file mode 100644 index 000000000..ba2738a8d --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/transport.go @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "errors" + "io" +) + +var errTransportInterrupted = errors.New("Transport Interrupted") + +type Flusher interface { + Flush() (err error) +} + +type ContextFlusher interface { + Flush(ctx context.Context) (err error) +} + +type ReadSizeProvider interface { + RemainingBytes() (num_bytes uint64) +} + +// Encapsulates the I/O layer +type TTransport interface { + io.ReadWriteCloser + ContextFlusher + ReadSizeProvider + + // Opens the transport for communication + Open() error + + // Returns true if the transport is open + IsOpen() bool +} + +type stringWriter interface { + WriteString(s string) (n int, err error) +} + +// This is "enchanced" transport with extra capabilities. You need to use one of these +// to construct protocol. +// Notably, TSocket does not implement this interface, and it is always a mistake to use +// TSocket directly in protocol. +type TRichTransport interface { + io.ReadWriter + io.ByteReader + io.ByteWriter + stringWriter + ContextFlusher + ReadSizeProvider +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/transport_exception.go b/src/jaegertracing/thrift/lib/go/thrift/transport_exception.go new file mode 100644 index 000000000..9505b4461 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/transport_exception.go @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "errors" + "io" +) + +type timeoutable interface { + Timeout() bool +} + +// Thrift Transport exception +type TTransportException interface { + TException + TypeId() int + Err() error +} + +const ( + UNKNOWN_TRANSPORT_EXCEPTION = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 +) + +type tTransportException struct { + typeId int + err error +} + +func (p *tTransportException) TypeId() int { + return p.typeId +} + +func (p *tTransportException) Error() string { + return p.err.Error() +} + +func (p *tTransportException) Err() error { + return p.err +} + +func NewTTransportException(t int, e string) TTransportException { + return &tTransportException{typeId: t, err: errors.New(e)} +} + +func NewTTransportExceptionFromError(e error) TTransportException { + if e == nil { + return nil + } + + if t, ok := e.(TTransportException); ok { + return t + } + + switch v := e.(type) { + case TTransportException: + return v + case timeoutable: + if v.Timeout() { + return &tTransportException{typeId: TIMED_OUT, err: e} + } + } + + if e == io.EOF { + return &tTransportException{typeId: END_OF_FILE, err: e} + } + + return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e} +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/transport_exception_test.go b/src/jaegertracing/thrift/lib/go/thrift/transport_exception_test.go new file mode 100644 index 000000000..b44314f49 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/transport_exception_test.go @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "fmt" + "io" + + "testing" +) + +type timeout struct{ timedout bool } + +func (t *timeout) Timeout() bool { + return t.timedout +} + +func (t *timeout) Error() string { + return fmt.Sprintf("Timeout: %v", t.timedout) +} + +func TestTExceptionTimeout(t *testing.T) { + timeout := &timeout{true} + exception := NewTTransportExceptionFromError(timeout) + if timeout.Error() != exception.Error() { + t.Fatalf("Error did not match: expected %q, got %q", timeout.Error(), exception.Error()) + } + + if exception.TypeId() != TIMED_OUT { + t.Fatalf("TypeId was not TIMED_OUT: expected %v, got %v", TIMED_OUT, exception.TypeId()) + } +} + +func TestTExceptionEOF(t *testing.T) { + exception := NewTTransportExceptionFromError(io.EOF) + if io.EOF.Error() != exception.Error() { + t.Fatalf("Error did not match: expected %q, got %q", io.EOF.Error(), exception.Error()) + } + + if exception.TypeId() != END_OF_FILE { + t.Fatalf("TypeId was not END_OF_FILE: expected %v, got %v", END_OF_FILE, exception.TypeId()) + } +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/transport_factory.go b/src/jaegertracing/thrift/lib/go/thrift/transport_factory.go new file mode 100644 index 000000000..c80580794 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/transport_factory.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Factory class used to create wrapped instance of Transports. +// This is used primarily in servers, which get Transports from +// a ServerTransport and then may want to mutate them (i.e. create +// a BufferedTransport from the underlying base transport) +type TTransportFactory interface { + GetTransport(trans TTransport) (TTransport, error) +} + +type tTransportFactory struct{} + +// Return a wrapped instance of the base Transport. +func (p *tTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + return trans, nil +} + +func NewTTransportFactory() TTransportFactory { + return &tTransportFactory{} +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/transport_test.go new file mode 100644 index 000000000..01278038e --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/transport_test.go @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "context" + "io" + "net" + "strconv" + "testing" +) + +const TRANSPORT_BINARY_DATA_SIZE = 4096 + +var ( + transport_bdata []byte // test data for writing; same as data + transport_header map[string]string +) + +func init() { + transport_bdata = make([]byte, TRANSPORT_BINARY_DATA_SIZE) + for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ { + transport_bdata[i] = byte((i + 'a') % 255) + } + transport_header = map[string]string{"key": "User-Agent", + "value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"} +} + +func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { + buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) + if !writeTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) + } + if !readTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", readTrans, readTrans) + } + _, err := writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush(context.Background()) + if err != nil { + t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) + } + n, err := io.ReadFull(readTrans, buf) + if err != nil { + t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } + _, err = writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data 2 of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush(context.Background()) + if err != nil { + t.Fatalf("Transport %T cannot flush write binary data 2: %s", writeTrans, err) + } + buf = make([]byte, TRANSPORT_BINARY_DATA_SIZE) + read := 1 + for n = 0; n < TRANSPORT_BINARY_DATA_SIZE && read != 0; { + read, err = readTrans.Read(buf[n:]) + if err != nil { + t.Errorf("Transport %T cannot read binary data 2 of total length %d from offset %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, n, err) + } + n += read + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data 2", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } +} + +func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { + buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) + if !writeTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) + } + if !readTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", readTrans, readTrans) + } + // Need to assert type of TTransport to THttpClient to expose the Setter + httpWPostTrans := writeTrans.(*THttpClient) + httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"]) + + _, err := writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush(context.Background()) + if err != nil { + t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) + } + // Need to assert type of TTransport to THttpClient to expose the Getter + httpRPostTrans := readTrans.(*THttpClient) + readHeader := httpRPostTrans.GetHeader(transport_header["key"]) + if err != nil { + t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans) + } + + if transport_header["value"] != readHeader { + t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader) + } + n, err := io.ReadFull(readTrans, buf) + if err != nil { + t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } +} + +func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) { + err := readTrans.Close() + if err != nil { + t.Errorf("Transport %T cannot close read transport: %s", readTrans, err) + } + if writeTrans != readTrans { + err = writeTrans.Close() + if err != nil { + t.Errorf("Transport %T cannot close write transport: %s", writeTrans, err) + } + } +} + +func FindAvailableTCPServerPort(startPort int) (net.Addr, error) { + for i := startPort; i < 65535; i++ { + s := "127.0.0.1:" + strconv.Itoa(i) + l, err := net.Listen("tcp", s) + if err == nil { + l.Close() + return net.ResolveTCPAddr("tcp", s) + } + } + return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port") +} + +func valueInSlice(value string, slice []string) bool { + for _, v := range slice { + if value == v { + return true + } + } + return false +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/type.go b/src/jaegertracing/thrift/lib/go/thrift/type.go new file mode 100644 index 000000000..4292ffcad --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/type.go @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Type constants in the Thrift protocol +type TType byte + +const ( + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + //BINARY = 18 wrong and unusued +) + +var typeNames = map[int]string{ + STOP: "STOP", + VOID: "VOID", + BOOL: "BOOL", + BYTE: "BYTE", + DOUBLE: "DOUBLE", + I16: "I16", + I32: "I32", + I64: "I64", + STRING: "STRING", + STRUCT: "STRUCT", + MAP: "MAP", + SET: "SET", + LIST: "LIST", + UTF8: "UTF8", + UTF16: "UTF16", +} + +func (p TType) String() string { + if s, ok := typeNames[int(p)]; ok { + return s + } + return "Unknown" +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/zlib_transport.go b/src/jaegertracing/thrift/lib/go/thrift/zlib_transport.go new file mode 100644 index 000000000..f3d42673a --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/zlib_transport.go @@ -0,0 +1,132 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. + */ + +package thrift + +import ( + "compress/zlib" + "context" + "io" + "log" +) + +// TZlibTransportFactory is a factory for TZlibTransport instances +type TZlibTransportFactory struct { + level int + factory TTransportFactory +} + +// TZlibTransport is a TTransport implementation that makes use of zlib compression. +type TZlibTransport struct { + reader io.ReadCloser + transport TTransport + writer *zlib.Writer +} + +// GetTransport constructs a new instance of NewTZlibTransport +func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if p.factory != nil { + // wrap other factory + var err error + trans, err = p.factory.GetTransport(trans) + if err != nil { + return nil, err + } + } + return NewTZlibTransport(trans, p.level) +} + +// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory +func NewTZlibTransportFactory(level int) *TZlibTransportFactory { + return &TZlibTransportFactory{level: level, factory: nil} +} + +// NewTZlibTransportFactory constructs a new instance of TZlibTransportFactory +// as a wrapper over existing transport factory +func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory { + return &TZlibTransportFactory{level: level, factory: factory} +} + +// NewTZlibTransport constructs a new instance of TZlibTransport +func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) { + w, err := zlib.NewWriterLevel(trans, level) + if err != nil { + log.Println(err) + return nil, err + } + + return &TZlibTransport{ + writer: w, + transport: trans, + }, nil +} + +// Close closes the reader and writer (flushing any unwritten data) and closes +// the underlying transport. +func (z *TZlibTransport) Close() error { + if z.reader != nil { + if err := z.reader.Close(); err != nil { + return err + } + } + if err := z.writer.Close(); err != nil { + return err + } + return z.transport.Close() +} + +// Flush flushes the writer and its underlying transport. +func (z *TZlibTransport) Flush(ctx context.Context) error { + if err := z.writer.Flush(); err != nil { + return err + } + return z.transport.Flush(ctx) +} + +// IsOpen returns true if the transport is open +func (z *TZlibTransport) IsOpen() bool { + return z.transport.IsOpen() +} + +// Open opens the transport for communication +func (z *TZlibTransport) Open() error { + return z.transport.Open() +} + +func (z *TZlibTransport) Read(p []byte) (int, error) { + if z.reader == nil { + r, err := zlib.NewReader(z.transport) + if err != nil { + return 0, NewTTransportExceptionFromError(err) + } + z.reader = r + } + + return z.reader.Read(p) +} + +// RemainingBytes returns the size in bytes of the data that is still to be +// read. +func (z *TZlibTransport) RemainingBytes() uint64 { + return z.transport.RemainingBytes() +} + +func (z *TZlibTransport) Write(p []byte) (int, error) { + return z.writer.Write(p) +} diff --git a/src/jaegertracing/thrift/lib/go/thrift/zlib_transport_test.go b/src/jaegertracing/thrift/lib/go/thrift/zlib_transport_test.go new file mode 100644 index 000000000..3c6f11eb5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/go/thrift/zlib_transport_test.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "compress/zlib" + "testing" +) + +func TestZlibTransport(t *testing.T) { + trans, err := NewTZlibTransport(NewTMemoryBuffer(), zlib.BestCompression) + if err != nil { + t.Fatal(err) + } + TransportTest(t, trans, trans) +} + +type DummyTransportFactory struct{} + +func (p *DummyTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + return NewTMemoryBuffer(), nil +} + +func TestZlibFactoryTransportWithFactory(t *testing.T) { + factory := NewTZlibTransportFactoryWithFactory( + zlib.BestCompression, + &DummyTransportFactory{}, + ) + buffer := NewTMemoryBuffer() + trans, err := factory.GetTransport(buffer) + if err != nil { + t.Fatal(err) + } + TransportTest(t, trans, trans) +} + +func TestZlibFactoryTransportWithoutFactory(t *testing.T) { + factory := NewTZlibTransportFactoryWithFactory(zlib.BestCompression, nil) + buffer := NewTMemoryBuffer() + trans, err := factory.GetTransport(buffer) + if err != nil { + t.Fatal(err) + } + TransportTest(t, trans, trans) +} |