diff options
Diffstat (limited to 'src/jaegertracing/thrift/lib/lua')
21 files changed, 4314 insertions, 0 deletions
diff --git a/src/jaegertracing/thrift/lib/lua/Makefile.am b/src/jaegertracing/thrift/lib/lua/Makefile.am new file mode 100644 index 000000000..3b272f56c --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/Makefile.am @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +AUTOMAKE_OPTIONS = subdir-objects nostdinc + +SUBDIRS = . + +lib_LTLIBRARIES = \ + libluasocket.la \ + liblualongnumber.la \ + libluabpack.la \ + libluabitwise.la + +libluasocket_la_SOURCES = \ + src/luasocket.c \ + src/usocket.c + +nobase_include_HEADERS = src/socket.h + +libluasocket_la_CPPFLAGS = $(AM_CPPFLAGS) $(LUA_INCLUDE) -DLUA_COMPAT_MODULE +libluasocket_la_LDFLAGS = $(AM_LDFLAGS) +libluasocket_la_LIBADD = $(LUA_LIB) -lm + +libluabpack_la_SOURCES = src/luabpack.c + +libluabpack_la_CPPFLAGS = $(AM_CPPFLAGS) $(LUA_INCLUDE) -DLUA_COMPAT_MODULE +libluabpack_la_LDFLAGS = $(AM_LDFLAGS) +libluabpack_la_LIBADD = liblualongnumber.la $(LUA_LIB) -lm + +libluabitwise_la_SOURCES = src/luabitwise.c + +libluabitwise_la_CPPFLAGS = $(AM_CPPFLAGS) $(LUA_INCLUDE) -DLUA_COMPAT_MODULE +libluabitwise_la_LDFLAGS = $(AM_LDFLAGS) +libluabitwise_la_LIBADD = $(LUA_LIB) -lm + +liblualongnumber_la_SOURCES = \ + src/lualongnumber.c \ + src/longnumberutils.c + +liblualongnumber_la_CPPFLAGS = $(AM_CPPFLAGS) $(LUA_INCLUDE) -DLUA_COMPAT_MODULE +liblualongnumber_la_LDFLAGS = $(AM_LDFLAGS) +liblualongnumber_la_LIBADD = $(LUA_LIB) -lm + +EXTRA_DIST = \ + coding_standards.md \ + TBinaryProtocol.lua \ + TBufferedTransport.lua \ + TCompactProtocol.lua \ + TFramedTransport.lua \ + Thrift.lua \ + THttpTransport.lua \ + TJsonProtocol.lua \ + TMemoryBuffer.lua \ + TProtocol.lua \ + TServer.lua \ + TSocket.lua \ + TTransport.lua diff --git a/src/jaegertracing/thrift/lib/lua/TBinaryProtocol.lua b/src/jaegertracing/thrift/lib/lua/TBinaryProtocol.lua new file mode 100644 index 000000000..4b8e98a9d --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TBinaryProtocol.lua @@ -0,0 +1,264 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TProtocol' +require 'libluabpack' +require 'libluabitwise' + +TBinaryProtocol = __TObject.new(TProtocolBase, { + __type = 'TBinaryProtocol', + VERSION_MASK = -65536, -- 0xffff0000 + VERSION_1 = -2147418112, -- 0x80010000 + TYPE_MASK = 0x000000ff, + strictRead = false, + strictWrite = true +}) + +function TBinaryProtocol:writeMessageBegin(name, ttype, seqid) + if self.strictWrite then + self:writeI32(libluabitwise.bor(TBinaryProtocol.VERSION_1, ttype)) + self:writeString(name) + self:writeI32(seqid) + else + self:writeString(name) + self:writeByte(ttype) + self:writeI32(seqid) + end +end + +function TBinaryProtocol:writeMessageEnd() +end + +function TBinaryProtocol:writeStructBegin(name) +end + +function TBinaryProtocol:writeStructEnd() +end + +function TBinaryProtocol:writeFieldBegin(name, ttype, id) + self:writeByte(ttype) + self:writeI16(id) +end + +function TBinaryProtocol:writeFieldEnd() +end + +function TBinaryProtocol:writeFieldStop() + self:writeByte(TType.STOP); +end + +function TBinaryProtocol:writeMapBegin(ktype, vtype, size) + self:writeByte(ktype) + self:writeByte(vtype) + self:writeI32(size) +end + +function TBinaryProtocol:writeMapEnd() +end + +function TBinaryProtocol:writeListBegin(etype, size) + self:writeByte(etype) + self:writeI32(size) +end + +function TBinaryProtocol:writeListEnd() +end + +function TBinaryProtocol:writeSetBegin(etype, size) + self:writeByte(etype) + self:writeI32(size) +end + +function TBinaryProtocol:writeSetEnd() +end + +function TBinaryProtocol:writeBool(bool) + if bool then + self:writeByte(1) + else + self:writeByte(0) + end +end + +function TBinaryProtocol:writeByte(byte) + local buff = libluabpack.bpack('c', byte) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI16(i16) + local buff = libluabpack.bpack('s', i16) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI32(i32) + local buff = libluabpack.bpack('i', i32) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI64(i64) + local buff = libluabpack.bpack('l', i64) + self.trans:write(buff) +end + +function TBinaryProtocol:writeDouble(dub) + local buff = libluabpack.bpack('d', dub) + self.trans:write(buff) +end + +function TBinaryProtocol:writeString(str) + -- Should be utf-8 + self:writeI32(string.len(str)) + self.trans:write(str) +end + +function TBinaryProtocol:readMessageBegin() + local sz, ttype, name, seqid = self:readI32() + if sz < 0 then + local version = libluabitwise.band(sz, TBinaryProtocol.VERSION_MASK) + if version ~= TBinaryProtocol.VERSION_1 then + terror(TProtocolException:new{ + message = 'Bad version in readMessageBegin: ' .. sz + }) + end + ttype = libluabitwise.band(sz, TBinaryProtocol.TYPE_MASK) + name = self:readString() + seqid = self:readI32() + else + if self.strictRead then + terror(TProtocolException:new{message = 'No protocol version header'}) + end + name = self.trans:readAll(sz) + ttype = self:readByte() + seqid = self:readI32() + end + return name, ttype, seqid +end + +function TBinaryProtocol:readMessageEnd() +end + +function TBinaryProtocol:readStructBegin() + return nil +end + +function TBinaryProtocol:readStructEnd() +end + +function TBinaryProtocol:readFieldBegin() + local ttype = self:readByte() + if ttype == TType.STOP then + return nil, ttype, 0 + end + local id = self:readI16() + return nil, ttype, id +end + +function TBinaryProtocol:readFieldEnd() +end + +function TBinaryProtocol:readMapBegin() + local ktype = self:readByte() + local vtype = self:readByte() + local size = self:readI32() + return ktype, vtype, size +end + +function TBinaryProtocol:readMapEnd() +end + +function TBinaryProtocol:readListBegin() + local etype = self:readByte() + local size = self:readI32() + return etype, size +end + +function TBinaryProtocol:readListEnd() +end + +function TBinaryProtocol:readSetBegin() + local etype = self:readByte() + local size = self:readI32() + return etype, size +end + +function TBinaryProtocol:readSetEnd() +end + +function TBinaryProtocol:readBool() + local byte = self:readByte() + if byte == 0 then + return false + end + return true +end + +function TBinaryProtocol:readByte() + local buff = self.trans:readAll(1) + local val = libluabpack.bunpack('c', buff) + return val +end + +function TBinaryProtocol:readI16() + local buff = self.trans:readAll(2) + local val = libluabpack.bunpack('s', buff) + return val +end + +function TBinaryProtocol:readI32() + local buff = self.trans:readAll(4) + local val = libluabpack.bunpack('i', buff) + return val +end + +function TBinaryProtocol:readI64() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('l', buff) + return val +end + +function TBinaryProtocol:readDouble() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('d', buff) + return val +end + +function TBinaryProtocol:readString() + local len = self:readI32() + local str = self.trans:readAll(len) + return str +end + +TBinaryProtocolFactory = TProtocolFactory:new{ + __type = 'TBinaryProtocolFactory', + strictRead = false +} + +function TBinaryProtocolFactory:getProtocol(trans) + -- TODO Enforce that this must be a transport class (ie not a bool) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TBinaryProtocol:new{ + trans = trans, + strictRead = self.strictRead, + strictWrite = true + } +end diff --git a/src/jaegertracing/thrift/lib/lua/TBufferedTransport.lua b/src/jaegertracing/thrift/lib/lua/TBufferedTransport.lua new file mode 100644 index 000000000..45ef4b1c7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TBufferedTransport.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +TBufferedTransport = TTransportBase:new{ + __type = 'TBufferedTransport', + rBufSize = 2048, + wBufSize = 2048, + wBuf = '', + rBuf = '' +} + +function TBufferedTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase.new(self, obj) +end + +function TBufferedTransport:isOpen() + return self.trans:isOpen() +end + +function TBufferedTransport:open() + return self.trans:open() +end + +function TBufferedTransport:close() + return self.trans:close() +end + +function TBufferedTransport:read(len) + return self.trans:read(len) +end + +function TBufferedTransport:readAll(len) + return self.trans:readAll(len) +end + +function TBufferedTransport:write(buf) + self.wBuf = self.wBuf .. buf + if string.len(self.wBuf) >= self.wBufSize then + self.trans:write(self.wBuf) + self.wBuf = '' + end +end + +function TBufferedTransport:flush() + if string.len(self.wBuf) > 0 then + self.trans:write(self.wBuf) + self.wBuf = '' + end +end + +TBufferedTransportFactory = TTransportFactoryBase:new{ + __type = 'TBufferedTransportFactory' +} + +function TBufferedTransportFactory:getTransport(trans) + if not trans then + terror(TTransportException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TBufferedTransport:new{ + trans = trans + } +end diff --git a/src/jaegertracing/thrift/lib/lua/TCompactProtocol.lua b/src/jaegertracing/thrift/lib/lua/TCompactProtocol.lua new file mode 100644 index 000000000..877595a5d --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TCompactProtocol.lua @@ -0,0 +1,457 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TProtocol' +require 'libluabpack' +require 'libluabitwise' +require 'liblualongnumber' + +TCompactProtocol = __TObject.new(TProtocolBase, { + __type = 'TCompactProtocol', + COMPACT_PROTOCOL_ID = 0x82, + COMPACT_VERSION = 1, + COMPACT_VERSION_MASK = 0x1f, + COMPACT_TYPE_MASK = 0xE0, + COMPACT_TYPE_BITS = 0x07, + COMPACT_TYPE_SHIFT_AMOUNT = 5, + + -- Used to keep track of the last field for the current and previous structs, + -- so we can do the delta stuff. + lastField = {}, + lastFieldId = 0, + lastFieldIndex = 1, + + -- If we encounter a boolean field begin, save the TField here so it can + -- have the value incorporated. + booleanFieldName = "", + booleanFieldId = 0, + booleanFieldPending = false, + + -- If we read a field header, and it's a boolean field, save the boolean + -- value here so that readBool can use it. + boolValue = false, + boolValueIsNotNull = false, +}) + +TCompactType = { + COMPACT_BOOLEAN_TRUE = 0x01, + COMPACT_BOOLEAN_FALSE = 0x02, + COMPACT_BYTE = 0x03, + COMPACT_I16 = 0x04, + COMPACT_I32 = 0x05, + COMPACT_I64 = 0x06, + COMPACT_DOUBLE = 0x07, + COMPACT_BINARY = 0x08, + COMPACT_LIST = 0x09, + COMPACT_SET = 0x0A, + COMPACT_MAP = 0x0B, + COMPACT_STRUCT = 0x0C +} + +TTypeToCompactType = {} +TTypeToCompactType[TType.STOP] = TType.STOP +TTypeToCompactType[TType.BOOL] = TCompactType.COMPACT_BOOLEAN_TRUE +TTypeToCompactType[TType.BYTE] = TCompactType.COMPACT_BYTE +TTypeToCompactType[TType.I16] = TCompactType.COMPACT_I16 +TTypeToCompactType[TType.I32] = TCompactType.COMPACT_I32 +TTypeToCompactType[TType.I64] = TCompactType.COMPACT_I64 +TTypeToCompactType[TType.DOUBLE] = TCompactType.COMPACT_DOUBLE +TTypeToCompactType[TType.STRING] = TCompactType.COMPACT_BINARY +TTypeToCompactType[TType.LIST] = TCompactType.COMPACT_LIST +TTypeToCompactType[TType.SET] = TCompactType.COMPACT_SET +TTypeToCompactType[TType.MAP] = TCompactType.COMPACT_MAP +TTypeToCompactType[TType.STRUCT] = TCompactType.COMPACT_STRUCT + +CompactTypeToTType = {} +CompactTypeToTType[TType.STOP] = TType.STOP +CompactTypeToTType[TCompactType.COMPACT_BOOLEAN_TRUE] = TType.BOOL +CompactTypeToTType[TCompactType.COMPACT_BOOLEAN_FALSE] = TType.BOOL +CompactTypeToTType[TCompactType.COMPACT_BYTE] = TType.BYTE +CompactTypeToTType[TCompactType.COMPACT_I16] = TType.I16 +CompactTypeToTType[TCompactType.COMPACT_I32] = TType.I32 +CompactTypeToTType[TCompactType.COMPACT_I64] = TType.I64 +CompactTypeToTType[TCompactType.COMPACT_DOUBLE] = TType.DOUBLE +CompactTypeToTType[TCompactType.COMPACT_BINARY] = TType.STRING +CompactTypeToTType[TCompactType.COMPACT_LIST] = TType.LIST +CompactTypeToTType[TCompactType.COMPACT_SET] = TType.SET +CompactTypeToTType[TCompactType.COMPACT_MAP] = TType.MAP +CompactTypeToTType[TCompactType.COMPACT_STRUCT] = TType.STRUCT + +function TCompactProtocol:resetLastField() + self.lastField = {} + self.lastFieldId = 0 + self.lastFieldIndex = 1 +end + +function TCompactProtocol:packCompactType(ktype, vtype) + return libluabitwise.bor(libluabitwise.shiftl(ktype, 4), vtype) +end + +function TCompactProtocol:writeMessageBegin(name, ttype, seqid) + self:writeByte(TCompactProtocol.COMPACT_PROTOCOL_ID) + self:writeByte(libluabpack.packMesgType(TCompactProtocol.COMPACT_VERSION, + TCompactProtocol.COMPACT_VERSION_MASK,ttype, + TCompactProtocol.COMPACT_TYPE_SHIFT_AMOUNT, + TCompactProtocol.COMPACT_TYPE_MASK)) + self:writeVarint32(seqid) + self:writeString(name) + self:resetLastField() +end + +function TCompactProtocol:writeMessageEnd() +end + +function TCompactProtocol:writeStructBegin(name) + self.lastFieldIndex = self.lastFieldIndex + 1 + self.lastField[self.lastFieldIndex] = self.lastFieldId + self.lastFieldId = 0 +end + +function TCompactProtocol:writeStructEnd() + self.lastFieldIndex = self.lastFieldIndex - 1 + self.lastFieldId = self.lastField[self.lastFieldIndex] +end + +function TCompactProtocol:writeFieldBegin(name, ttype, id) + if ttype == TType.BOOL then + self.booleanFieldName = name + self.booleanFieldId = id + self.booleanFieldPending = true + else + self:writeFieldBeginInternal(name, ttype, id, -1) + end +end + +function TCompactProtocol:writeFieldEnd() +end + +function TCompactProtocol:writeFieldStop() + self:writeByte(TType.STOP); +end + +function TCompactProtocol:writeMapBegin(ktype, vtype, size) + if size == 0 then + self:writeByte(0) + else + self:writeVarint32(size) + self:writeByte(self:packCompactType(TTypeToCompactType[ktype], TTypeToCompactType[vtype])) + end +end + +function TCompactProtocol:writeMapEnd() +end + +function TCompactProtocol:writeListBegin(etype, size) + self:writeCollectionBegin(etype, size) +end + +function TCompactProtocol:writeListEnd() +end + +function TCompactProtocol:writeSetBegin(etype, size) + self:writeCollectionBegin(etype, size) +end + +function TCompactProtocol:writeSetEnd() +end + +function TCompactProtocol:writeBool(bool) + local value = TCompactType.COMPACT_BOOLEAN_FALSE + if bool then + value = TCompactType.COMPACT_BOOLEAN_TRUE + end + print(value,self.booleanFieldPending,self.booleanFieldId) + if self.booleanFieldPending then + self:writeFieldBeginInternal(self.booleanFieldName, TType.BOOL, self.booleanFieldId, value) + self.booleanFieldPending = false + else + self:writeByte(value) + end +end + +function TCompactProtocol:writeByte(byte) + local buff = libluabpack.bpack('c', byte) + self.trans:write(buff) +end + +function TCompactProtocol:writeI16(i16) + self:writeVarint32(libluabpack.i32ToZigzag(i16)) +end + +function TCompactProtocol:writeI32(i32) + self:writeVarint32(libluabpack.i32ToZigzag(i32)) +end + +function TCompactProtocol:writeI64(i64) + self:writeVarint64(libluabpack.i64ToZigzag(i64)) +end + +function TCompactProtocol:writeDouble(dub) + local buff = libluabpack.bpack('d', dub) + self.trans:write(buff) +end + +function TCompactProtocol:writeString(str) + -- Should be utf-8 + self:writeBinary(str) +end + +function TCompactProtocol:writeBinary(str) + -- Should be utf-8 + self:writeVarint32(string.len(str)) + self.trans:write(str) +end + +function TCompactProtocol:writeFieldBeginInternal(name, ttype, id, typeOverride) + if typeOverride == -1 then + typeOverride = TTypeToCompactType[ttype] + end + local offset = id - self.lastFieldId + if id > self.lastFieldId and offset <= 15 then + self:writeByte(libluabitwise.bor(libluabitwise.shiftl(offset, 4), typeOverride)) + else + self:writeByte(typeOverride) + self:writeI16(id) + end + self.lastFieldId = id +end + +function TCompactProtocol:writeCollectionBegin(etype, size) + if size <= 14 then + self:writeByte(libluabitwise.bor(libluabitwise.shiftl(size, 4), TTypeToCompactType[etype])) + else + self:writeByte(libluabitwise.bor(0xf0, TTypeToCompactType[etype])) + self:writeVarint32(size) + end +end + +function TCompactProtocol:writeVarint32(i32) + -- Should be utf-8 + local str = libluabpack.toVarint32(i32) + self.trans:write(str) +end + +function TCompactProtocol:writeVarint64(i64) + -- Should be utf-8 + local str = libluabpack.toVarint64(i64) + self.trans:write(str) +end + +function TCompactProtocol:readMessageBegin() + local protocolId = self:readSignByte() + if protocolId ~= self.COMPACT_PROTOCOL_ID then + terror(TProtocolException:new{ + message = "Expected protocol id " .. self.COMPACT_PROTOCOL_ID .. " but got " .. protocolId}) + end + local versionAndType = self:readSignByte() + local version = libluabitwise.band(versionAndType, self.COMPACT_VERSION_MASK) + local ttype = libluabitwise.band(libluabitwise.shiftr(versionAndType, + self.COMPACT_TYPE_SHIFT_AMOUNT), self.COMPACT_TYPE_BITS) + if version ~= self.COMPACT_VERSION then + terror(TProtocolException:new{ + message = "Expected version " .. self.COMPACT_VERSION .. " but got " .. version}) + end + local seqid = self:readVarint32() + local name = self:readString() + return name, ttype, seqid +end + +function TCompactProtocol:readMessageEnd() +end + +function TCompactProtocol:readStructBegin() + self.lastField[self.lastFieldIndex] = self.lastFieldId + self.lastFieldIndex = self.lastFieldIndex + 1 + self.lastFieldId = 0 + return nil +end + +function TCompactProtocol:readStructEnd() + self.lastFieldIndex = self.lastFieldIndex - 1 + self.lastFieldId = self.lastField[self.lastFieldIndex] +end + +function TCompactProtocol:readFieldBegin() + local field_and_ttype = self:readSignByte() + local ttype = self:getTType(field_and_ttype) + if ttype == TType.STOP then + return nil, ttype, 0 + end + -- mask off the 4 MSB of the type header. it could contain a field id delta. + local modifier = libluabitwise.shiftr(libluabitwise.band(field_and_ttype, 0xf0), 4) + local id = 0 + if modifier == 0 then + id = self:readI16() + else + id = self.lastFieldId + modifier + end + if ttype == TType.BOOL then + boolValue = libluabitwise.band(field_and_ttype, 0x0f) == TCompactType.COMPACT_BOOLEAN_TRUE + boolValueIsNotNull = true + end + self.lastFieldId = id + return nil, ttype, id +end + +function TCompactProtocol:readFieldEnd() +end + +function TCompactProtocol:readMapBegin() + local size = self:readVarint32() + if size < 0 then + return nil,nil,nil + end + local kvtype = self:readSignByte() + local ktype = self:getTType(libluabitwise.shiftr(kvtype, 4)) + local vtype = self:getTType(kvtype) + return ktype, vtype, size +end + +function TCompactProtocol:readMapEnd() +end + +function TCompactProtocol:readListBegin() + local size_and_type = self:readSignByte() + local size = libluabitwise.band(libluabitwise.shiftr(size_and_type, 4), 0x0f) + if size == 15 then + size = self:readVarint32() + end + if size < 0 then + return nil,nil + end + local etype = self:getTType(libluabitwise.band(size_and_type, 0x0f)) + return etype, size +end + +function TCompactProtocol:readListEnd() +end + +function TCompactProtocol:readSetBegin() + return self:readListBegin() +end + +function TCompactProtocol:readSetEnd() +end + +function TCompactProtocol:readBool() + if boolValueIsNotNull then + boolValueIsNotNull = true + return boolValue + end + local val = self:readSignByte() + if val == TCompactType.COMPACT_BOOLEAN_TRUE then + return true + end + return false +end + +function TCompactProtocol:readByte() + local buff = self.trans:readAll(1) + local val = libluabpack.bunpack('c', buff) + return val +end + +function TCompactProtocol:readSignByte() + local buff = self.trans:readAll(1) + local val = libluabpack.bunpack('C', buff) + return val +end + +function TCompactProtocol:readI16() + return self:readI32() +end + +function TCompactProtocol:readI32() + local v = self:readVarint32() + local value = libluabpack.zigzagToI32(v) + return value +end + +function TCompactProtocol:readI64() + local value = self:readVarint64() + return value +end + +function TCompactProtocol:readDouble() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('d', buff) + return val +end + +function TCompactProtocol:readString() + return self:readBinary() +end + +function TCompactProtocol:readBinary() + local size = self:readVarint32() + if size <= 0 then + return "" + end + return self.trans:readAll(size) +end + +function TCompactProtocol:readVarint32() + local shiftl = 0 + local result = 0 + while true do + b = self:readByte() + result = libluabitwise.bor(result, + libluabitwise.shiftl(libluabitwise.band(b, 0x7f), shiftl)) + if libluabitwise.band(b, 0x80) ~= 0x80 then + break + end + shiftl = shiftl + 7 + end + return result +end + +function TCompactProtocol:readVarint64() + local result = liblualongnumber.new + local data = result(0) + local shiftl = 0 + while true do + b = self:readByte() + endFlag, data = libluabpack.fromVarint64(b, shiftl, data) + shiftl = shiftl + 7 + if endFlag == 0 then + break + end + end + return data +end + +function TCompactProtocol:getTType(ctype) + return CompactTypeToTType[libluabitwise.band(ctype, 0x0f)] +end + +TCompactProtocolFactory = TProtocolFactory:new{ + __type = 'TCompactProtocolFactory', +} + +function TCompactProtocolFactory:getProtocol(trans) + -- TODO Enforce that this must be a transport class (ie not a bool) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TCompactProtocol:new{ + trans = trans + } +end diff --git a/src/jaegertracing/thrift/lib/lua/TFramedTransport.lua b/src/jaegertracing/thrift/lib/lua/TFramedTransport.lua new file mode 100644 index 000000000..768e2d997 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TFramedTransport.lua @@ -0,0 +1,118 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' +require 'libluabpack' + +TFramedTransport = TTransportBase:new{ + __type = 'TFramedTransport', + doRead = true, + doWrite = true, + wBuf = '', + rBuf = '' +} + +function TFramedTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase.new(self, obj) +end + +function TFramedTransport:isOpen() + return self.trans:isOpen() +end + +function TFramedTransport:open() + return self.trans:open() +end + +function TFramedTransport:close() + return self.trans:close() +end + +function TFramedTransport:read(len) + if string.len(self.rBuf) == 0 then + self:__readFrame() + end + + if self.doRead == false then + return self.trans:read(len) + end + + if len > string.len(self.rBuf) then + local val = self.rBuf + self.rBuf = '' + return val + end + + local val = string.sub(self.rBuf, 0, len) + self.rBuf = string.sub(self.rBuf, len+1) + return val +end + +function TFramedTransport:__readFrame() + local buf = self.trans:readAll(4) + local frame_len = libluabpack.bunpack('i', buf) + self.rBuf = self.trans:readAll(frame_len) +end + + +function TFramedTransport:write(buf, len) + if self.doWrite == false then + return self.trans:write(buf, len) + end + + if len and len < string.len(buf) then + buf = string.sub(buf, 0, len) + end + self.wBuf = self.wBuf .. buf +end + +function TFramedTransport:flush() + if self.doWrite == false then + return self.trans:flush() + end + + -- If the write fails we still want wBuf to be clear + local tmp = self.wBuf + self.wBuf = '' + local frame_len_buf = libluabpack.bpack("i", string.len(tmp)) + tmp = frame_len_buf .. tmp + self.trans:write(tmp) + self.trans:flush() +end + +TFramedTransportFactory = TTransportFactoryBase:new{ + __type = 'TFramedTransportFactory' +} +function TFramedTransportFactory:getTransport(trans) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TFramedTransport:new{trans = trans} +end diff --git a/src/jaegertracing/thrift/lib/lua/THttpTransport.lua b/src/jaegertracing/thrift/lib/lua/THttpTransport.lua new file mode 100644 index 000000000..2951db79f --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/THttpTransport.lua @@ -0,0 +1,182 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +THttpTransport = TTransportBase:new{ + __type = 'THttpTransport', + path = '/', + wBuf = '', + rBuf = '', + CRLF = '\r\n', + VERSION = version, + isServer = true +} + +function THttpTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase.new(self, obj) +end + +function THttpTransport:isOpen() + return self.trans:isOpen() +end + +function THttpTransport:open() + return self.trans:open() +end + +function THttpTransport:close() + return self.trans:close() +end + +function THttpTransport:readAll(len) + return self:read(len) +end + +function THttpTransport:read(len) + if string.len(self.rBuf) == 0 then + self:_readMsg() + end + if len > string.len(self.rBuf) then + local val = self.rBuf + self.rBuf = '' + return val + end + + local val = string.sub(self.rBuf, 0, len) + self.rBuf = string.sub(self.rBuf, len+1) + return val +end + +function THttpTransport:_readMsg() + while true do + self.rBuf = self.rBuf .. self.trans:read(4) + if string.find(self.rBuf, self.CRLF .. self.CRLF) then + break + end + end + if not self.rBuf then + self.rBuf = "" + return + end + self:getLine() + local headers = self:_parseHeaders() + if not headers then + self.rBuf = "" + return + end + + local length = tonumber(headers["Content-Length"]) + if length then + length = length - string.len(self.rBuf) + self.rBuf = self.rBuf .. self.trans:readAll(length) + end + if self.rBuf == nil then + self.rBuf = "" + end +end + +function THttpTransport:getLine() + local a,b = string.find(self.rBuf, self.CRLF) + local line = "" + if a and b then + line = string.sub(self.rBuf, 0, a-1) + self.rBuf = string.sub(self.rBuf, b+1) + end + return line +end + +function THttpTransport:_parseHeaders() + local headers = {} + + repeat + local line = self:getLine() + for key, val in string.gmatch(line, "([%w%-]+)%s*:%s*(.+)") do + if headers[key] then + local delimiter = ", " + if key == "Set-Cookie" then + delimiter = "; " + end + headers[key] = headers[key] .. delimiter .. tostring(val) + else + headers[key] = tostring(val) + end + end + until string.find(line, "^%s*$") + + return headers +end + +function THttpTransport:write(buf, len) + if len and len < string.len(buf) then + buf = string.sub(buf, 0, len) + end + self.wBuf = self.wBuf .. buf +end + +function THttpTransport:writeHttpHeader(content_len) + if self.isServer then + local header = "HTTP/1.1 200 OK" .. self.CRLF + .. "Server: Thrift/" .. self.VERSION .. self.CRLF + .. "Access-Control-Allow-Origin: *" .. self.CRLF + .. "Content-Type: application/x-thrift" .. self.CRLF + .. "Content-Length: " .. content_len .. self.CRLF + .. "Connection: Keep-Alive" .. self.CRLF .. self.CRLF + self.trans:write(header) + else + local header = "POST " .. self.path .. " HTTP/1.1" .. self.CRLF + .. "Host: " .. self.trans.host .. self.CRLF + .. "Content-Type: application/x-thrift" .. self.CRLF + .. "Content-Length: " .. content_len .. self.CRLF + .. "Accept: application/x-thrift " .. self.CRLF + .. "User-Agent: Thrift/" .. self.VERSION .. " (Lua/THttpClient)" + .. self.CRLF .. self.CRLF + self.trans:write(header) + end +end + +function THttpTransport:flush() + -- If the write fails we still want wBuf to be clear + local tmp = self.wBuf + self.wBuf = '' + self:writeHttpHeader(string.len(tmp)) + self.trans:write(tmp) + self.trans:flush() +end + +THttpTransportFactory = TTransportFactoryBase:new{ + __type = 'THttpTransportFactory' +} +function THttpTransportFactory:getTransport(trans) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return THttpTransport:new{trans = trans} +end diff --git a/src/jaegertracing/thrift/lib/lua/TJsonProtocol.lua b/src/jaegertracing/thrift/lib/lua/TJsonProtocol.lua new file mode 100644 index 000000000..db08eecf1 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TJsonProtocol.lua @@ -0,0 +1,727 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"), you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TProtocol' +require 'libluabpack' +require 'libluabitwise' + +TJSONProtocol = __TObject.new(TProtocolBase, { + __type = 'TJSONProtocol', + THRIFT_JSON_PROTOCOL_VERSION = 1, + jsonContext = {}, + jsonContextVal = {first = true, colon = true, ttype = 2, null = true}, + jsonContextIndex = 1, + hasReadByte = "" +}) + +TTypeToString = {} +TTypeToString[TType.BOOL] = "tf" +TTypeToString[TType.BYTE] = "i8" +TTypeToString[TType.I16] = "i16" +TTypeToString[TType.I32] = "i32" +TTypeToString[TType.I64] = "i64" +TTypeToString[TType.DOUBLE] = "dbl" +TTypeToString[TType.STRING] = "str" +TTypeToString[TType.STRUCT] = "rec" +TTypeToString[TType.LIST] = "lst" +TTypeToString[TType.SET] = "set" +TTypeToString[TType.MAP] = "map" + +StringToTType = { + tf = TType.BOOL, + i8 = TType.BYTE, + i16 = TType.I16, + i32 = TType.I32, + i64 = TType.I64, + dbl = TType.DOUBLE, + str = TType.STRING, + rec = TType.STRUCT, + map = TType.MAP, + set = TType.SET, + lst = TType.LIST +} + +JSONNode = { + ObjectBegin = '{', + ObjectEnd = '}', + ArrayBegin = '[', + ArrayEnd = ']', + PairSeparator = ':', + ElemSeparator = ',', + Backslash = '\\', + StringDelimiter = '"', + ZeroChar = '0', + EscapeChar = 'u', + Nan = 'NaN', + Infinity = 'Infinity', + NegativeInfinity = '-Infinity', + EscapeChars = "\"\\bfnrt", + EscapePrefix = "\\u00" +} + +EscapeCharVals = { + '"', '\\', '\b', '\f', '\n', '\r', '\t' +} + +JSONCharTable = { + --0 1 2 3 4 5 6 7 8 9 A B C D E F + 0, 0, 0, 0, 0, 0, 0, 0, 98,116,110, 0,102,114, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1,34, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, +} + +-- character table string +local b='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/' + +-- encoding +function base64_encode(data) + return ((data:gsub('.', function(x) + local r,b='',x:byte() + for i=8,1,-1 do r=r..(b%2^i-b%2^(i-1)>0 and '1' or '0') end + return r; + end)..'0000'):gsub('%d%d%d?%d?%d?%d?', function(x) + if (#x < 6) then return '' end + local c=0 + for i=1,6 do c=c+(x:sub(i,i)=='1' and 2^(6-i) or 0) end + return b:sub(c+1,c+1) + end)..({ '', '==', '=' })[#data%3+1]) +end + +-- decoding +function base64_decode(data) + data = string.gsub(data, '[^'..b..'=]', '') + return (data:gsub('.', function(x) + if (x == '=') then return '' end + local r,f='',(b:find(x)-1) + for i=6,1,-1 do r=r..(f%2^i-f%2^(i-1)>0 and '1' or '0') end + return r; + end):gsub('%d%d%d?%d?%d?%d?%d?%d?', function(x) + if (#x ~= 8) then return '' end + local c=0 + for i=1,8 do c=c+(x:sub(i,i)=='1' and 2^(8-i) or 0) end + return string.char(c) + end)) +end + +function TJSONProtocol:resetContext() + self.jsonContext = {} + self.jsonContextVal = {first = true, colon = true, ttype = 2, null = true} + self.jsonContextIndex = 1 +end + +function TJSONProtocol:contextPush(context) + self.jsonContextIndex = self.jsonContextIndex + 1 + self.jsonContext[self.jsonContextIndex] = self.jsonContextVal + self.jsonContextVal = context +end + +function TJSONProtocol:contextPop() + self.jsonContextVal = self.jsonContext[self.jsonContextIndex] + self.jsonContextIndex = self.jsonContextIndex - 1 +end + +function TJSONProtocol:escapeNum() + if self.jsonContextVal.ttype == 1 then + return self.jsonContextVal.colon + else + return false + end +end + +function TJSONProtocol:writeElemSeparator() + if self.jsonContextVal.null then + return + end + if self.jsonContextVal.first then + self.jsonContextVal.first = false + else + if self.jsonContextVal.ttype == 1 then + if self.jsonContextVal.colon then + self.trans:write(JSONNode.PairSeparator) + self.jsonContextVal.colon = false + else + self.trans:write(JSONNode.ElemSeparator) + self.jsonContextVal.colon = true + end + else + self.trans:write(JSONNode.ElemSeparator) + end + end +end + +function TJSONProtocol:hexChar(val) + val = libluabitwise.band(val, 0x0f) + if val < 10 then + return val + 48 + else + return val + 87 + end +end + +function TJSONProtocol:writeJSONEscapeChar(ch) + self.trans:write(JSONNode.EscapePrefix) + local outCh = hexChar(libluabitwise.shiftr(ch, 4)) + local buff = libluabpack.bpack('c', outCh) + self.trans:write(buff) + outCh = hexChar(ch) + buff = libluabpack.bpack('c', outCh) + self.trans:write(buff) +end + +function TJSONProtocol:writeJSONChar(byte) + ch = string.byte(byte) + if ch >= 0x30 then + if ch == JSONNode.Backslash then + self.trans:write(JSONNode.Backslash) + self.trans:write(JSONNode.Backslash) + else + self.trans:write(byte) + end + else + local outCh = JSONCharTable[ch+1] + if outCh == 1 then + self.trans:write(byte) + elseif outCh > 1 then + self.trans:write(JSONNode.Backslash) + local buff = libluabpack.bpack('c', outCh) + self.trans:write(buff) + else + self:writeJSONEscapeChar(ch) + end + end +end + +function TJSONProtocol:writeJSONString(str) + self:writeElemSeparator() + self.trans:write(JSONNode.StringDelimiter) + -- TODO escape special characters + local length = string.len(str) + local ii = 1 + while ii <= length do + self:writeJSONChar(string.sub(str, ii, ii)) + ii = ii + 1 + end + self.trans:write(JSONNode.StringDelimiter) +end + +function TJSONProtocol:writeJSONBase64(str) + self:writeElemSeparator() + self.trans:write(JSONNode.StringDelimiter) + local length = string.len(str) + local offset = 1 + while length >= 3 do + -- Encode 3 bytes at a time + local bytes = base64_encode(string.sub(str, offset, offset+3)) + self.trans:write(bytes) + length = length - 3 + offset = offset + 3 + end + if length > 0 then + local bytes = base64_encode(string.sub(str, offset, offset+length)) + self.trans:write(bytes) + end + self.trans:write(JSONNode.StringDelimiter) +end + +function TJSONProtocol:writeJSONInteger(num) + self:writeElemSeparator() + if self:escapeNum() then + self.trans:write(JSONNode.StringDelimiter) + end + local numstr = "" .. num + numstr = string.sub(numstr, string.find(numstr, "^[+-]?%d+")) + self.trans:write(numstr) + if self:escapeNum() then + self.trans:write(JSONNode.StringDelimiter) + end +end + +function TJSONProtocol:writeJSONDouble(dub) + self:writeElemSeparator() + local val = "" .. dub + local prefix = string.sub(val, 1, 1) + local special = false + if prefix == 'N' or prefix == 'n' then + val = JSONNode.Nan + special = true + elseif prefix == 'I' or prefix == 'i' then + val = JSONNode.Infinity + special = true + elseif prefix == '-' then + local secondByte = string.sub(val, 2, 2) + if secondByte == 'I' or secondByte == 'i' then + val = JSONNode.NegativeInfinity + special = true + end + end + + if special or self:escapeNum() then + self.trans:write(JSONNode.StringDelimiter) + end + self.trans:write(val) + if special or self:escapeNum() then + self.trans:write(JSONNode.StringDelimiter) + end +end + +function TJSONProtocol:writeJSONObjectBegin() + self:writeElemSeparator() + self.trans:write(JSONNode.ObjectBegin) + self:contextPush({first = true, colon = true, ttype = 1, null = false}) +end + +function TJSONProtocol:writeJSONObjectEnd() + self:contextPop() + self.trans:write(JSONNode.ObjectEnd) +end + +function TJSONProtocol:writeJSONArrayBegin() + self:writeElemSeparator() + self.trans:write(JSONNode.ArrayBegin) + self:contextPush({first = true, colon = true, ttype = 2, null = false}) +end + +function TJSONProtocol:writeJSONArrayEnd() + self:contextPop() + self.trans:write(JSONNode.ArrayEnd) +end + +function TJSONProtocol:writeMessageBegin(name, ttype, seqid) + self:resetContext() + self:writeJSONArrayBegin() + self:writeJSONInteger(TJSONProtocol.THRIFT_JSON_PROTOCOL_VERSION) + self:writeJSONString(name) + self:writeJSONInteger(ttype) + self:writeJSONInteger(seqid) +end + +function TJSONProtocol:writeMessageEnd() + self:writeJSONArrayEnd() +end + +function TJSONProtocol:writeStructBegin(name) + self:writeJSONObjectBegin() +end + +function TJSONProtocol:writeStructEnd() + self:writeJSONObjectEnd() +end + +function TJSONProtocol:writeFieldBegin(name, ttype, id) + self:writeJSONInteger(id) + self:writeJSONObjectBegin() + self:writeJSONString(TTypeToString[ttype]) +end + +function TJSONProtocol:writeFieldEnd() + self:writeJSONObjectEnd() +end + +function TJSONProtocol:writeFieldStop() +end + +function TJSONProtocol:writeMapBegin(ktype, vtype, size) + self:writeJSONArrayBegin() + self:writeJSONString(TTypeToString[ktype]) + self:writeJSONString(TTypeToString[vtype]) + self:writeJSONInteger(size) + return self:writeJSONObjectBegin() +end + +function TJSONProtocol:writeMapEnd() + self:writeJSONObjectEnd() + self:writeJSONArrayEnd() +end + +function TJSONProtocol:writeListBegin(etype, size) + self:writeJSONArrayBegin() + self:writeJSONString(TTypeToString[etype]) + self:writeJSONInteger(size) +end + +function TJSONProtocol:writeListEnd() + self:writeJSONArrayEnd() +end + +function TJSONProtocol:writeSetBegin(etype, size) + self:writeJSONArrayBegin() + self:writeJSONString(TTypeToString[etype]) + self:writeJSONInteger(size) +end + +function TJSONProtocol:writeSetEnd() + self:writeJSONArrayEnd() +end + +function TJSONProtocol:writeBool(bool) + if bool then + self:writeJSONInteger(1) + else + self:writeJSONInteger(0) + end +end + +function TJSONProtocol:writeByte(byte) + local buff = libluabpack.bpack('c', byte) + local val = libluabpack.bunpack('c', buff) + self:writeJSONInteger(val) +end + +function TJSONProtocol:writeI16(i16) + local buff = libluabpack.bpack('s', i16) + local val = libluabpack.bunpack('s', buff) + self:writeJSONInteger(val) +end + +function TJSONProtocol:writeI32(i32) + local buff = libluabpack.bpack('i', i32) + local val = libluabpack.bunpack('i', buff) + self:writeJSONInteger(val) +end + +function TJSONProtocol:writeI64(i64) + local buff = libluabpack.bpack('l', i64) + local val = libluabpack.bunpack('l', buff) + self:writeJSONInteger(tostring(val)) +end + +function TJSONProtocol:writeDouble(dub) + self:writeJSONDouble(string.format("%.16f", dub)) +end + +function TJSONProtocol:writeString(str) + self:writeJSONString(str) +end + +function TJSONProtocol:writeBinary(str) + -- Should be utf-8 + self:writeJSONBase64(str) +end + +function TJSONProtocol:readJSONSyntaxChar(ch) + local ch2 = "" + if self.hasReadByte ~= "" then + ch2 = self.hasReadByte + self.hasReadByte = "" + else + ch2 = self.trans:readAll(1) + end + if ch2 ~= ch then + terror(TProtocolException:new{message = "Expected ".. ch .. ", got " .. ch2}) + end +end + +function TJSONProtocol:readElemSeparator() + if self.jsonContextVal.null then + return + end + if self.jsonContextVal.first then + self.jsonContextVal.first = false + else + if self.jsonContextVal.ttype == 1 then + if self.jsonContextVal.colon then + self:readJSONSyntaxChar(JSONNode.PairSeparator) + self.jsonContextVal.colon = false + else + self:readJSONSyntaxChar(JSONNode.ElemSeparator) + self.jsonContextVal.colon = true + end + else + self:readJSONSyntaxChar(JSONNode.ElemSeparator) + end + end +end + +function TJSONProtocol:hexVal(ch) + local val = string.byte(ch) + if val >= 48 and val <= 57 then + return val - 48 + elseif val >= 97 and val <= 102 then + return val - 87 + else + terror(TProtocolException:new{message = "Expected hex val ([0-9a-f]); got " .. ch}) + end +end + +function TJSONProtocol:readJSONEscapeChar(ch) + self:readJSONSyntaxChar(JSONNode.ZeroChar) + self:readJSONSyntaxChar(JSONNode.ZeroChar) + local b1 = self.trans:readAll(1) + local b2 = self.trans:readAll(1) + return libluabitwise.shiftl(self:hexVal(b1), 4) + self:hexVal(b2) +end + + +function TJSONProtocol:readJSONString() + self:readElemSeparator() + self:readJSONSyntaxChar(JSONNode.StringDelimiter) + local result = "" + while true do + local ch = self.trans:readAll(1) + if ch == JSONNode.StringDelimiter then + break + end + if ch == JSONNode.Backslash then + ch = self.trans:readAll(1) + if ch == JSONNode.EscapeChar then + self:readJSONEscapeChar(ch) + else + local pos, _ = string.find(JSONNode.EscapeChars, ch) + if pos == nil then + terror(TProtocolException:new{message = "Expected control char, got " .. ch}) + end + ch = EscapeCharVals[pos] + end + end + result = result .. ch + end + return result +end + +function TJSONProtocol:readJSONBase64() + local result = self:readJSONString() + local length = string.len(result) + local str = "" + local offset = 1 + while length >= 4 do + local bytes = string.sub(result, offset, offset+4) + str = str .. base64_decode(bytes) + offset = offset + 4 + length = length - 4 + end + if length >= 0 then + str = str .. base64_decode(string.sub(result, offset, offset + length)) + end + return str +end + +function TJSONProtocol:readJSONNumericChars() + local result = "" + while true do + local ch = self.trans:readAll(1) + if string.find(ch, '[-+0-9.Ee]') then + result = result .. ch + else + self.hasReadByte = ch + break + end + end + return result +end + +function TJSONProtocol:readJSONLongInteger() + self:readElemSeparator() + if self:escapeNum() then + self:readJSONSyntaxChar(JSONNode.StringDelimiter) + end + local result = self:readJSONNumericChars() + if self:escapeNum() then + self:readJSONSyntaxChar(JSONNode.StringDelimiter) + end + return result +end + +function TJSONProtocol:readJSONInteger() + return tonumber(self:readJSONLongInteger()) +end + +function TJSONProtocol:readJSONDouble() + self:readElemSeparator() + local delimiter = self.trans:readAll(1) + local num = 0.0 + if delimiter == JSONNode.StringDelimiter then + local str = self:readJSONString() + if str == JSONNode.Nan then + num = 1.0 + elseif str == JSONNode.Infinity then + num = math.maxinteger + elseif str == JSONNode.NegativeInfinity then + num = math.mininteger + else + num = tonumber(str) + end + else + if self:escapeNum() then + self:readJSONSyntaxChar(JSONNode.StringDelimiter) + end + local result = self:readJSONNumericChars() + num = tonumber(delimiter.. result) + end + return num +end + +function TJSONProtocol:readJSONObjectBegin() + self:readElemSeparator() + self:readJSONSyntaxChar(JSONNode.ObjectBegin) + self:contextPush({first = true, colon = true, ttype = 1, null = false}) +end + +function TJSONProtocol:readJSONObjectEnd() + self:readJSONSyntaxChar(JSONNode.ObjectEnd) + self:contextPop() +end + +function TJSONProtocol:readJSONArrayBegin() + self:readElemSeparator() + self:readJSONSyntaxChar(JSONNode.ArrayBegin) + self:contextPush({first = true, colon = true, ttype = 2, null = false}) +end + +function TJSONProtocol:readJSONArrayEnd() + self:readJSONSyntaxChar(JSONNode.ArrayEnd) + self:contextPop() +end + +function TJSONProtocol:readMessageBegin() + self:resetContext() + self:readJSONArrayBegin() + local version = self:readJSONInteger() + if version ~= self.THRIFT_JSON_PROTOCOL_VERSION then + terror(TProtocolException:new{message = "Message contained bad version."}) + end + local name = self:readJSONString() + local ttype = self:readJSONInteger() + local seqid = self:readJSONInteger() + return name, ttype, seqid +end + +function TJSONProtocol:readMessageEnd() + self:readJSONArrayEnd() +end + +function TJSONProtocol:readStructBegin() + self:readJSONObjectBegin() + return nil +end + +function TJSONProtocol:readStructEnd() + self:readJSONObjectEnd() +end + +function TJSONProtocol:readFieldBegin() + local ttype = TType.STOP + local id = 0 + local ch = self.trans:readAll(1) + self.hasReadByte = ch + if ch ~= JSONNode.ObjectEnd then + id = self:readJSONInteger() + self:readJSONObjectBegin() + local typeName = self:readJSONString() + ttype = StringToTType[typeName] + end + return nil, ttype, id +end + +function TJSONProtocol:readFieldEnd() + self:readJSONObjectEnd() +end + +function TJSONProtocol:readMapBegin() + self:readJSONArrayBegin() + local typeName = self:readJSONString() + local ktype = StringToTType[typeName] + typeName = self:readJSONString() + local vtype = StringToTType[typeName] + local size = self:readJSONInteger() + self:readJSONObjectBegin() + return ktype, vtype, size +end + +function TJSONProtocol:readMapEnd() + self:readJSONObjectEnd() + self:readJSONArrayEnd() +end + +function TJSONProtocol:readListBegin() + self:readJSONArrayBegin() + local typeName = self:readJSONString() + local etype = StringToTType[typeName] + local size = self:readJSONInteger() + return etype, size +end + +function TJSONProtocol:readListEnd() + return self:readJSONArrayEnd() +end + +function TJSONProtocol:readSetBegin() + return self:readListBegin() +end + +function TJSONProtocol:readSetEnd() + return self:readJSONArrayEnd() +end + +function TJSONProtocol:readBool() + local result = self:readJSONInteger() + if result == 1 then + return true + else + return false + end +end + +function TJSONProtocol:readByte() + local result = self:readJSONInteger() + if result >= 256 then + terror(TProtocolException:new{message = "UnExpected Byte " .. result}) + end + return result +end + +function TJSONProtocol:readI16() + return self:readJSONInteger() +end + +function TJSONProtocol:readI32() + return self:readJSONInteger() +end + +function TJSONProtocol:readI64() + local long = liblualongnumber.new + return long(self:readJSONLongInteger()) +end + +function TJSONProtocol:readDouble() + return self:readJSONDouble() +end + +function TJSONProtocol:readString() + return self:readJSONString() +end + +function TJSONProtocol:readBinary() + return self:readJSONBase64() +end + +TJSONProtocolFactory = TProtocolFactory:new{ + __type = 'TJSONProtocolFactory', +} + +function TJSONProtocolFactory:getProtocol(trans) + -- TODO Enforce that this must be a transport class (ie not a bool) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TJSONProtocol:new{ + trans = trans + } +end diff --git a/src/jaegertracing/thrift/lib/lua/TMemoryBuffer.lua b/src/jaegertracing/thrift/lib/lua/TMemoryBuffer.lua new file mode 100644 index 000000000..78b2f5cf0 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TMemoryBuffer.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +TMemoryBuffer = TTransportBase:new{ + __type = 'TMemoryBuffer', + buffer = '', + bufferSize = 1024, + wPos = 0, + rPos = 0 +} +function TMemoryBuffer:isOpen() + return 1 +end +function TMemoryBuffer:open() end +function TMemoryBuffer:close() end + +function TMemoryBuffer:peak() + return self.rPos < self.wPos +end + +function TMemoryBuffer:getBuffer() + return self.buffer +end + +function TMemoryBuffer:resetBuffer(buf) + if buf then + self.buffer = buf + self.bufferSize = string.len(buf) + else + self.buffer = '' + self.bufferSize = 1024 + end + self.wPos = string.len(buf) + self.rPos = 0 +end + +function TMemoryBuffer:available() + return self.wPos - self.rPos +end + +function TMemoryBuffer:read(len) + local avail = self:available() + if avail == 0 then + return '' + end + + if avail < len then + len = avail + end + + local val = string.sub(self.buffer, self.rPos + 1, self.rPos + len) + self.rPos = self.rPos + len + return val +end + +function TMemoryBuffer:readAll(len) + local avail = self:available() + + if avail < len then + local msg = string.format('Attempt to readAll(%d) found only %d available', + len, avail) + terror(TTransportException:new{message = msg}) + end + -- read should block so we don't need a loop here + return self:read(len) +end + +function TMemoryBuffer:write(buf) + self.buffer = self.buffer .. buf + self.wPos = self.wPos + string.len(buf) +end + +function TMemoryBuffer:flush() end diff --git a/src/jaegertracing/thrift/lib/lua/TProtocol.lua b/src/jaegertracing/thrift/lib/lua/TProtocol.lua new file mode 100644 index 000000000..1306fb3d8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TProtocol.lua @@ -0,0 +1,164 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' + +TProtocolException = TException:new { + UNKNOWN = 0, + INVALID_DATA = 1, + NEGATIVE_SIZE = 2, + SIZE_LIMIT = 3, + BAD_VERSION = 4, + INVALID_PROTOCOL = 5, + DEPTH_LIMIT = 6, + errorCode = 0, + __type = 'TProtocolException' +} +function TProtocolException:__errorCodeToString() + if self.errorCode == self.INVALID_DATA then + return 'Invalid data' + elseif self.errorCode == self.NEGATIVE_SIZE then + return 'Negative size' + elseif self.errorCode == self.SIZE_LIMIT then + return 'Size limit' + elseif self.errorCode == self.BAD_VERSION then + return 'Bad version' + elseif self.errorCode == self.INVALID_PROTOCOL then + return 'Invalid protocol' + elseif self.errorCode == self.DEPTH_LIMIT then + return 'Exceeded size limit' + else + return 'Default (unknown)' + end +end + +TProtocolBase = __TObject:new{ + __type = 'TProtocolBase', + trans +} + +function TProtocolBase:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return __TObject.new(self, obj) +end + +function TProtocolBase:writeMessageBegin(name, ttype, seqid) end +function TProtocolBase:writeMessageEnd() end +function TProtocolBase:writeStructBegin(name) end +function TProtocolBase:writeStructEnd() end +function TProtocolBase:writeFieldBegin(name, ttype, id) end +function TProtocolBase:writeFieldEnd() end +function TProtocolBase:writeFieldStop() end +function TProtocolBase:writeMapBegin(ktype, vtype, size) end +function TProtocolBase:writeMapEnd() end +function TProtocolBase:writeListBegin(ttype, size) end +function TProtocolBase:writeListEnd() end +function TProtocolBase:writeSetBegin(ttype, size) end +function TProtocolBase:writeSetEnd() end +function TProtocolBase:writeBool(bool) end +function TProtocolBase:writeByte(byte) end +function TProtocolBase:writeI16(i16) end +function TProtocolBase:writeI32(i32) end +function TProtocolBase:writeI64(i64) end +function TProtocolBase:writeDouble(dub) end +function TProtocolBase:writeString(str) end +function TProtocolBase:readMessageBegin() end +function TProtocolBase:readMessageEnd() end +function TProtocolBase:readStructBegin() end +function TProtocolBase:readStructEnd() end +function TProtocolBase:readFieldBegin() end +function TProtocolBase:readFieldEnd() end +function TProtocolBase:readMapBegin() end +function TProtocolBase:readMapEnd() end +function TProtocolBase:readListBegin() end +function TProtocolBase:readListEnd() end +function TProtocolBase:readSetBegin() end +function TProtocolBase:readSetEnd() end +function TProtocolBase:readBool() end +function TProtocolBase:readByte() end +function TProtocolBase:readI16() end +function TProtocolBase:readI32() end +function TProtocolBase:readI64() end +function TProtocolBase:readDouble() end +function TProtocolBase:readString() end + +function TProtocolBase:skip(ttype) + if ttype == TType.BOOL then + self:readBool() + elseif ttype == TType.BYTE then + self:readByte() + elseif ttype == TType.I16 then + self:readI16() + elseif ttype == TType.I32 then + self:readI32() + elseif ttype == TType.I64 then + self:readI64() + elseif ttype == TType.DOUBLE then + self:readDouble() + elseif ttype == TType.STRING then + self:readString() + elseif ttype == TType.STRUCT then + local name = self:readStructBegin() + while true do + local name, ttype, id = self:readFieldBegin() + if ttype == TType.STOP then + break + end + self:skip(ttype) + self:readFieldEnd() + end + self:readStructEnd() + elseif ttype == TType.MAP then + local kttype, vttype, size = self:readMapBegin() + for i = 1, size, 1 do + self:skip(kttype) + self:skip(vttype) + end + self:readMapEnd() + elseif ttype == TType.SET then + local ettype, size = self:readSetBegin() + for i = 1, size, 1 do + self:skip(ettype) + end + self:readSetEnd() + elseif ttype == TType.LIST then + local ettype, size = self:readListBegin() + for i = 1, size, 1 do + self:skip(ettype) + end + self:readListEnd() + else + terror(TProtocolException:new{ + message = 'Invalid data' + }) + end +end + +TProtocolFactory = __TObject:new{ + __type = 'TProtocolFactory', +} +function TProtocolFactory:getProtocol(trans) end diff --git a/src/jaegertracing/thrift/lib/lua/TServer.lua b/src/jaegertracing/thrift/lib/lua/TServer.lua new file mode 100644 index 000000000..4e37d5871 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TServer.lua @@ -0,0 +1,140 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' +require 'TFramedTransport' +require 'TBinaryProtocol' + +-- TServer +TServer = __TObject:new{ + __type = 'TServer' +} + +-- 2 possible constructors +-- 1. {processor, serverTransport} +-- 2. {processor, serverTransport, transportFactory, protocolFactory} +function TServer:new(args) + if ttype(args) ~= 'table' then + error('TServer must be initialized with a table') + end + if args.processor == nil then + terror('You must provide ' .. ttype(self) .. ' with a processor') + end + if args.serverTransport == nil then + terror('You must provide ' .. ttype(self) .. ' with a serverTransport') + end + + -- Create the object + local obj = __TObject.new(self, args) + + if obj.transportFactory then + obj.inputTransportFactory = obj.transportFactory + obj.outputTransportFactory = obj.transportFactory + obj.transportFactory = nil + else + obj.inputTransportFactory = TFramedTransportFactory:new{} + obj.outputTransportFactory = obj.inputTransportFactory + end + + if obj.protocolFactory then + obj.inputProtocolFactory = obj.protocolFactory + obj.outputProtocolFactory = obj.protocolFactory + obj.protocolFactory = nil + else + obj.inputProtocolFactory = TBinaryProtocolFactory:new{} + obj.outputProtocolFactory = obj.inputProtocolFactory + end + + -- Set the __server variable in the handler so we can stop the server + obj.processor.handler.__server = self + + return obj +end + +function TServer:setServerEventHandler(handler) + self.serverEventHandler = handler +end + +function TServer:_clientBegin(content, iprot, oprot) + if self.serverEventHandler and + type(self.serverEventHandler.clientBegin) == 'function' then + self.serverEventHandler:clientBegin(iprot, oprot) + end +end + +function TServer:_preServe() + if self.serverEventHandler and + type(self.serverEventHandler.preServe) == 'function' then + self.serverEventHandler:preServe(self.serverTransport:getSocketInfo()) + end +end + +function TServer:_handleException(err) + if string.find(err, 'TTransportException') == nil then + print(err) + end +end + +function TServer:serve() end +function TServer:handle(client) + local itrans, otrans = + self.inputTransportFactory:getTransport(client), + self.outputTransportFactory:getTransport(client) + local iprot, oprot = + self.inputProtocolFactory:getProtocol(itrans), + self.outputProtocolFactory:getProtocol(otrans) + + self:_clientBegin(iprot, oprot) + while true do + local ret, err = pcall(self.processor.process, self.processor, iprot, oprot) + if ret == false and err then + if not string.find(err, "TTransportException") then + self:_handleException(err) + end + break + end + end + itrans:close() + otrans:close() +end + +function TServer:close() + self.serverTransport:close() +end + +-- TSimpleServer +-- Single threaded server that handles one transport (connection) +TSimpleServer = __TObject:new(TServer, { + __type = 'TSimpleServer', + __stop = false +}) + +function TSimpleServer:serve() + self.serverTransport:listen() + self:_preServe() + while not self.__stop do + client = self.serverTransport:accept() + self:handle(client) + end + self:close() +end + +function TSimpleServer:stop() + self.__stop = true +end diff --git a/src/jaegertracing/thrift/lib/lua/TSocket.lua b/src/jaegertracing/thrift/lib/lua/TSocket.lua new file mode 100644 index 000000000..d71fc1f98 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TSocket.lua @@ -0,0 +1,132 @@ +---- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' +require 'libluasocket' + +-- TSocketBase +TSocketBase = TTransportBase:new{ + __type = 'TSocketBase', + timeout = 1000, + host = 'localhost', + port = 9090, + handle +} + +function TSocketBase:close() + if self.handle then + self.handle:destroy() + self.handle = nil + end +end + +-- Returns a table with the fields host and port +function TSocketBase:getSocketInfo() + if self.handle then + return self.handle:getsockinfo() + end + terror(TTransportException:new{errorCode = TTransportException.NOT_OPEN}) +end + +function TSocketBase:setTimeout(timeout) + if timeout and ttype(timeout) == 'number' then + if self.handle then + self.handle:settimeout(timeout) + end + self.timeout = timeout + end +end + +-- TSocket +TSocket = TSocketBase:new{ + __type = 'TSocket', + host = 'localhost', + port = 9090 +} + +function TSocket:isOpen() + if self.handle then + return true + end + return false +end + +function TSocket:open() + if self.handle then + self:close() + end + + -- Create local handle + local sock, err = luasocket.create_and_connect( + self.host, self.port, self.timeout) + if err == nil then + self.handle = sock + end + + if err then + terror(TTransportException:new{ + message = 'Could not connect to ' .. self.host .. ':' .. self.port + .. ' (' .. err .. ')' + }) + end +end + +function TSocket:read(len) + local buf = self.handle:receive(self.handle, len) + if not buf or string.len(buf) ~= len then + terror(TTransportException:new{errorCode = TTransportException.UNKNOWN}) + end + return buf +end + +function TSocket:write(buf) + self.handle:send(self.handle, buf) +end + +function TSocket:flush() +end + +-- TServerSocket +TServerSocket = TSocketBase:new{ + __type = 'TServerSocket', + host = 'localhost', + port = 9090 +} + +function TServerSocket:listen() + if self.handle then + self:close() + end + + local sock, err = luasocket.create(self.host, self.port) + if not err then + self.handle = sock + else + terror(err) + end + self.handle:settimeout(self.timeout) + self.handle:listen() +end + +function TServerSocket:accept() + local client, err = self.handle:accept() + if err then + terror(err) + end + return TSocket:new({handle = client}) +end diff --git a/src/jaegertracing/thrift/lib/lua/TTransport.lua b/src/jaegertracing/thrift/lib/lua/TTransport.lua new file mode 100644 index 000000000..01c7e5979 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/TTransport.lua @@ -0,0 +1,93 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' + +TTransportException = TException:new { + UNKNOWN = 0, + NOT_OPEN = 1, + ALREADY_OPEN = 2, + TIMED_OUT = 3, + END_OF_FILE = 4, + INVALID_FRAME_SIZE = 5, + INVALID_TRANSFORM = 6, + INVALID_CLIENT_TYPE = 7, + errorCode = 0, + __type = 'TTransportException' +} + +function TTransportException:__errorCodeToString() + if self.errorCode == self.NOT_OPEN then + return 'Transport not open' + elseif self.errorCode == self.ALREADY_OPEN then + return 'Transport already open' + elseif self.errorCode == self.TIMED_OUT then + return 'Transport timed out' + elseif self.errorCode == self.END_OF_FILE then + return 'End of file' + elseif self.errorCode == self.INVALID_FRAME_SIZE then + return 'Invalid frame size' + elseif self.errorCode == self.INVALID_TRANSFORM then + return 'Invalid transform' + elseif self.errorCode == self.INVALID_CLIENT_TYPE then + return 'Invalid client type' + else + return 'Default (unknown)' + end +end + +TTransportBase = __TObject:new{ + __type = 'TTransportBase' +} + +function TTransportBase:isOpen() end +function TTransportBase:open() end +function TTransportBase:close() end +function TTransportBase:read(len) end +function TTransportBase:readAll(len) + local buf, have, chunk = '', 0 + while have < len do + chunk = self:read(len - have) + have = have + string.len(chunk) + buf = buf .. chunk + + if string.len(chunk) == 0 then + terror(TTransportException:new{ + errorCode = TTransportException.END_OF_FILE + }) + end + end + return buf +end +function TTransportBase:write(buf) end +function TTransportBase:flush() end + +TServerTransportBase = __TObject:new{ + __type = 'TServerTransportBase' +} +function TServerTransportBase:listen() end +function TServerTransportBase:accept() end +function TServerTransportBase:close() end + +TTransportFactoryBase = __TObject:new{ + __type = 'TTransportFactoryBase' +} +function TTransportFactoryBase:getTransport(trans) + return trans +end diff --git a/src/jaegertracing/thrift/lib/lua/Thrift.lua b/src/jaegertracing/thrift/lib/lua/Thrift.lua new file mode 100644 index 000000000..a948b3dcb --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/Thrift.lua @@ -0,0 +1,281 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +---- namespace thrift +--thrift = {} +--setmetatable(thrift, {__index = _G}) --> perf hit for accessing global methods +--setfenv(1, thrift) + +package.cpath = package.cpath .. ';bin/?.so' -- TODO FIX +function ttype(obj) + if type(obj) == 'table' and + obj.__type and + type(obj.__type) == 'string' then + return obj.__type + end + return type(obj) +end + +function terror(e) + if e and e.__tostring then + error(e:__tostring()) + return + end + error(e) +end + +function ttable_size(t) + local count = 0 + for k, v in pairs(t) do + count = count + 1 + end + return count +end + +version = '0.13.0' + +TType = { + STOP = 0, + VOID = 1, + BOOL = 2, + BYTE = 3, + I08 = 3, + DOUBLE = 4, + I16 = 6, + I32 = 8, + I64 = 10, + STRING = 11, + UTF7 = 11, + STRUCT = 12, + MAP = 13, + SET = 14, + LIST = 15, + UTF8 = 16, + UTF16 = 17 +} + +TMessageType = { + CALL = 1, + REPLY = 2, + EXCEPTION = 3, + ONEWAY = 4 +} + +-- Recursive __index function to achieve inheritance +function __tobj_index(self, key) + local v = rawget(self, key) + if v ~= nil then + return v + end + + local p = rawget(self, '__parent') + if p then + return __tobj_index(p, key) + end + + return nil +end + +-- Basic Thrift-Lua Object +__TObject = { + __type = '__TObject', + __mt = { + __index = __tobj_index + } +} +function __TObject:new(init_obj) + local obj = {} + if ttype(obj) == 'table' then + obj = init_obj + end + + -- Use the __parent key and the __index function to achieve inheritance + obj.__parent = self + setmetatable(obj, __TObject.__mt) + return obj +end + +-- Return a string representation of any lua variable +function thrift_print_r(t) + local ret = '' + local ltype = type(t) + if (ltype == 'table') then + ret = ret .. '{ ' + for key,value in pairs(t) do + ret = ret .. tostring(key) .. '=' .. thrift_print_r(value) .. ' ' + end + ret = ret .. '}' + elseif ltype == 'string' then + ret = ret .. "'" .. tostring(t) .. "'" + else + ret = ret .. tostring(t) + end + return ret +end + +-- Basic Exception +TException = __TObject:new{ + message, + errorCode, + __type = 'TException' +} +function TException:__tostring() + if self.message then + return string.format('%s: %s', self.__type, self.message) + else + local message + if self.errorCode and self.__errorCodeToString then + message = string.format('%d: %s', self.errorCode, self:__errorCodeToString()) + else + message = thrift_print_r(self) + end + return string.format('%s:%s', self.__type, message) + end +end + +TApplicationException = TException:new{ + UNKNOWN = 0, + UNKNOWN_METHOD = 1, + INVALID_MESSAGE_TYPE = 2, + WRONG_METHOD_NAME = 3, + BAD_SEQUENCE_ID = 4, + MISSING_RESULT = 5, + INTERNAL_ERROR = 6, + PROTOCOL_ERROR = 7, + INVALID_TRANSFORM = 8, + INVALID_PROTOCOL = 9, + UNSUPPORTED_CLIENT_TYPE = 10, + errorCode = 0, + __type = 'TApplicationException' +} + +function TApplicationException:__errorCodeToString() + if self.errorCode == self.UNKNOWN_METHOD then + return 'Unknown method' + elseif self.errorCode == self.INVALID_MESSAGE_TYPE then + return 'Invalid message type' + elseif self.errorCode == self.WRONG_METHOD_NAME then + return 'Wrong method name' + elseif self.errorCode == self.BAD_SEQUENCE_ID then + return 'Bad sequence ID' + elseif self.errorCode == self.MISSING_RESULT then + return 'Missing result' + elseif self.errorCode == self.INTERNAL_ERROR then + return 'Internal error' + elseif self.errorCode == self.PROTOCOL_ERROR then + return 'Protocol error' + elseif self.errorCode == self.INVALID_TRANSFORM then + return 'Invalid transform' + elseif self.errorCode == self.INVALID_PROTOCOL then + return 'Invalid protocol' + elseif self.errorCode == self.UNSUPPORTED_CLIENT_TYPE then + return 'Unsupported client type' + else + return 'Default (unknown)' + end +end + +function TException:read(iprot) + iprot:readStructBegin() + while true do + local fname, ftype, fid = iprot:readFieldBegin() + if ftype == TType.STOP then + break + elseif fid == 1 then + if ftype == TType.STRING then + self.message = iprot:readString() + else + iprot:skip(ftype) + end + elseif fid == 2 then + if ftype == TType.I32 then + self.errorCode = iprot:readI32() + else + iprot:skip(ftype) + end + else + iprot:skip(ftype) + end + iprot:readFieldEnd() + end + iprot:readStructEnd() +end + +function TException:write(oprot) + oprot:writeStructBegin('TApplicationException') + if self.message then + oprot:writeFieldBegin('message', TType.STRING, 1) + oprot:writeString(self.message) + oprot:writeFieldEnd() + end + if self.errorCode then + oprot:writeFieldBegin('type', TType.I32, 2) + oprot:writeI32(self.errorCode) + oprot:writeFieldEnd() + end + oprot:writeFieldStop() + oprot:writeStructEnd() +end + +-- Basic Client (used in generated lua code) +__TClient = __TObject:new{ + __type = '__TClient', + _seqid = 0 +} +function __TClient:new(obj) + if ttype(obj) ~= 'table' then + error('TClient must be initialized with a table') + end + + -- Set iprot & oprot + if obj.protocol then + obj.iprot = obj.protocol + obj.oprot = obj.protocol + obj.protocol = nil + elseif not obj.iprot then + error('You must provide ' .. ttype(self) .. ' with an iprot') + end + if not obj.oprot then + obj.oprot = obj.iprot + end + + return __TObject.new(self, obj) +end + +function __TClient:close() + self.iprot.trans:close() + self.oprot.trans:close() +end + +-- Basic Processor (used in generated lua code) +__TProcessor = __TObject:new{ + __type = '__TProcessor' +} +function __TProcessor:new(obj) + if ttype(obj) ~= 'table' then + error('TProcessor must be initialized with a table') + end + + -- Ensure a handler is provided + if not obj.handler then + error('You must provide ' .. ttype(self) .. ' with a handler') + end + + return __TObject.new(self, obj) +end diff --git a/src/jaegertracing/thrift/lib/lua/coding_standards.md b/src/jaegertracing/thrift/lib/lua/coding_standards.md new file mode 100644 index 000000000..fa0390bb5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/coding_standards.md @@ -0,0 +1 @@ +Please follow [General Coding Standards](/doc/coding_standards.md) diff --git a/src/jaegertracing/thrift/lib/lua/src/longnumberutils.c b/src/jaegertracing/thrift/lib/lua/src/longnumberutils.c new file mode 100644 index 000000000..fbc678900 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/longnumberutils.c @@ -0,0 +1,47 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> +#include <stdlib.h> +#include <inttypes.h> + +const char * LONG_NUM_TYPE = "__thrift_longnumber"; +int64_t lualongnumber_checklong(lua_State *L, int index) { + switch (lua_type(L, index)) { + case LUA_TNUMBER: + return (int64_t)lua_tonumber(L, index); + case LUA_TSTRING: + return atoll(lua_tostring(L, index)); + default: + return *((int64_t *)luaL_checkudata(L, index, LONG_NUM_TYPE)); + } +} + +// Creates a new longnumber and pushes it onto the statck +int64_t * lualongnumber_pushlong(lua_State *L, int64_t *val) { + int64_t *data = (int64_t *)lua_newuserdata(L, sizeof(int64_t)); // longnum + luaL_getmetatable(L, LONG_NUM_TYPE); // longnum, mt + lua_setmetatable(L, -2); // longnum + if (val) { + *data = *val; + } + return data; +} + diff --git a/src/jaegertracing/thrift/lib/lua/src/luabitwise.c b/src/jaegertracing/thrift/lib/lua/src/luabitwise.c new file mode 100644 index 000000000..2e07e1724 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/luabitwise.c @@ -0,0 +1,83 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> + +static int l_not(lua_State *L) { + int a = luaL_checkinteger(L, 1); + a = ~a; + lua_pushnumber(L, a); + return 1; +} + +static int l_xor(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a ^= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_and(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a &= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_or(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a |= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_shiftr(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a = a >> b; + lua_pushnumber(L, a); + return 1; +} + +static int l_shiftl(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a = a << b; + lua_pushnumber(L, a); + return 1; +} + +static const struct luaL_Reg funcs[] = { + {"band", l_and}, + {"bor", l_or}, + {"bxor", l_xor}, + {"bnot", l_not}, + {"shiftl", l_shiftl}, + {"shiftr", l_shiftr}, + {NULL, NULL} +}; + +int luaopen_libluabitwise(lua_State *L) { + luaL_register(L, "libluabitwise", funcs); + return 1; +} diff --git a/src/jaegertracing/thrift/lib/lua/src/luabpack.c b/src/jaegertracing/thrift/lib/lua/src/luabpack.c new file mode 100644 index 000000000..077b6aa07 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/luabpack.c @@ -0,0 +1,308 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> +#include <string.h> +#include <inttypes.h> +#include <netinet/in.h> + +extern int64_t lualongnumber_checklong(lua_State *L, int index); +extern int64_t lualongnumber_pushlong(lua_State *L, int64_t *val); + +// host order to network order (64-bit) +static int64_t T_htonll(uint64_t data) { + uint32_t d1 = htonl((uint32_t)data); + uint32_t d2 = htonl((uint32_t)(data >> 32)); + return ((uint64_t)d1 << 32) + (uint64_t)d2; +} + +// network order to host order (64-bit) +static int64_t T_ntohll(uint64_t data) { + uint32_t d1 = ntohl((uint32_t)data); + uint32_t d2 = ntohl((uint32_t)(data >> 32)); + return ((uint64_t)d1 << 32) + (uint64_t)d2; +} + +/** + * bpack(type, data) + * c - Signed Byte + * s - Signed Short + * i - Signed Int + * l - Signed Long + * d - Double + */ +static int l_bpack(lua_State *L) { + const char *code = luaL_checkstring(L, 1); + luaL_argcheck(L, code[1] == '\0', 0, "Format code must be one character."); + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + switch (code[0]) { + case 'c': { + int8_t data = luaL_checknumber(L, 2); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 's': { + int16_t data = luaL_checknumber(L, 2); + data = (int16_t)htons(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'i': { + int32_t data = luaL_checkinteger(L, 2); + data = (int32_t)htonl(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'l': { + int64_t data = lualongnumber_checklong(L, 2); + data = (int64_t)T_htonll(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'd': { + double data = luaL_checknumber(L, 2); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + default: + luaL_argcheck(L, 0, 0, "Invalid format code."); + } + + luaL_pushresult(&buf); + return 1; +} + +/** + * bunpack(type, data) + * c - Signed Byte + * C - Unsigned Byte + * s - Signed Short + * i - Signed Int + * l - Signed Long + * d - Double + */ +static int l_bunpack(lua_State *L) { + const char *code = luaL_checkstring(L, 1); + luaL_argcheck(L, code[1] == '\0', 0, "Format code must be one character."); + const char *data = luaL_checkstring(L, 2); +#if LUA_VERSION_NUM >= 502 + size_t len = lua_rawlen(L, 2); +#else + size_t len = lua_objlen(L, 2); +#endif + + switch (code[0]) { + case 'c': { + int8_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + /** + * unpack unsigned Byte. + */ + case 'C': { + uint8_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + case 's': { + int16_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int16_t)ntohs(val); + lua_pushnumber(L, val); + break; + } + case 'i': { + int32_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int32_t)ntohl(val); + lua_pushnumber(L, val); + break; + } + case 'l': { + int64_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int64_t)T_ntohll(val); + lualongnumber_pushlong(L, &val); + break; + } + case 'd': { + double val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + default: + luaL_argcheck(L, 0, 0, "Invalid format code."); + } + return 1; +} + +/** + * Convert l into a zigzag long. This allows negative numbers to be + * represented compactly as a varint. + */ +static int l_i64ToZigzag(lua_State *L) { + int64_t n = lualongnumber_checklong(L, 1); + int64_t result = (n << 1) ^ (n >> 63); + lualongnumber_pushlong(L, &result); + return 1; +} +/** + * Convert n into a zigzag int. This allows negative numbers to be + * represented compactly as a varint. + */ +static int l_i32ToZigzag(lua_State *L) { + int32_t n = luaL_checkinteger(L, 1); + uint32_t result = (uint32_t)(n << 1) ^ (n >> 31); + lua_pushnumber(L, result); + return 1; +} + +/** + * Convert from zigzag int to int. + */ +static int l_zigzagToI32(lua_State *L) { + uint32_t n = luaL_checkinteger(L, 1); + int32_t result = (int32_t)(n >> 1) ^ (uint32_t)(-(int32_t)(n & 1)); + lua_pushnumber(L, result); + return 1; +} + +/** + * Convert from zigzag long to long. + */ +static int l_zigzagToI64(lua_State *L) { + int64_t n = lualongnumber_checklong(L, 1); + int64_t result = (int64_t)(n >> 1) ^ (uint64_t)(-(int64_t)(n & 1)); + lualongnumber_pushlong(L, &result); + return 1; +} + +/** + * Convert an i32 to a varint. Results in 1-5 bytes on the buffer. + */ +static int l_toVarint32(lua_State *L) { + uint8_t buf[5]; + uint32_t n = luaL_checkinteger(L, 1); + uint32_t wsize = 0; + + while (1) { + if ((n & ~0x7F) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + lua_pushlstring(L, buf, wsize); + return 1; +} + +/** + * Convert an i64 to a varint. Results in 1-10 bytes on the buffer. + */ +static int l_toVarint64(lua_State *L) { + uint8_t data[10]; + uint64_t n = lualongnumber_checklong(L, 1); + uint32_t wsize = 0; + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + while (1) { + if ((n & ~0x7FL) == 0) { + data[wsize++] = (int8_t)n; + break; + } else { + data[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + + luaL_addlstring(&buf, (void*)&data, wsize); + luaL_pushresult(&buf); + return 1; +} + +/** + * Convert a varint to i64. + */ +static int l_fromVarint64(lua_State *L) { + int64_t result; + uint8_t byte = luaL_checknumber(L, 1); + int32_t shift = luaL_checknumber(L, 2); + uint64_t n = (uint64_t)lualongnumber_checklong(L, 3); + n |= (uint64_t)(byte & 0x7f) << shift; + + if (!(byte & 0x80)) { + result = (int64_t)(n >> 1) ^ (uint64_t)(-(int64_t)(n & 1)); + lua_pushnumber(L, 0); + } else { + result = n; + lua_pushnumber(L, 1); + } + lualongnumber_pushlong(L, &result); + return 2; +} + +/** + * To pack message type of compact protocol. + */ +static int l_packMesgType(lua_State *L) { + int32_t version_n = luaL_checkinteger(L, 1); + int32_t version_mask = luaL_checkinteger(L, 2); + int32_t messagetype = luaL_checkinteger(L, 3); + int32_t type_shift_amount = luaL_checkinteger(L, 4); + int32_t type_mask = luaL_checkinteger(L, 5); + int32_t to_mesg_type = (version_n & version_mask) | + (((int32_t)messagetype << type_shift_amount) & type_mask); + lua_pushnumber(L, to_mesg_type); + return 1; +} + +static const struct luaL_Reg lua_bpack[] = { + {"bpack", l_bpack}, + {"bunpack", l_bunpack}, + {"i32ToZigzag", l_i32ToZigzag}, + {"i64ToZigzag", l_i64ToZigzag}, + {"zigzagToI32", l_zigzagToI32}, + {"zigzagToI64", l_zigzagToI64}, + {"toVarint32", l_toVarint32}, + {"toVarint64", l_toVarint64}, + {"fromVarint64", l_fromVarint64}, + {"packMesgType", l_packMesgType}, + {NULL, NULL} +}; + +int luaopen_libluabpack(lua_State *L) { + luaL_register(L, "libluabpack", lua_bpack); + return 1; +} diff --git a/src/jaegertracing/thrift/lib/lua/src/lualongnumber.c b/src/jaegertracing/thrift/lib/lua/src/lualongnumber.c new file mode 100644 index 000000000..9001e4a90 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/lualongnumber.c @@ -0,0 +1,228 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> +#include <stdlib.h> +#include <math.h> +#include <inttypes.h> +#include <string.h> + +extern const char * LONG_NUM_TYPE; +extern int64_t lualongnumber_checklong(lua_State *L, int index); +extern int64_t lualongnumber_pushlong(lua_State *L, int64_t *val); + +//////////////////////////////////////////////////////////////////////////////// + +static void l_serialize(char *buf, int len, int64_t val) { + snprintf(buf, len, "%"PRId64, val); +} + +static int64_t l_deserialize(const char *buf) { + int64_t data; + int rv; + // Support hex prefixed with '0x' + if (strstr(buf, "0x") == buf) { + rv = sscanf(buf, "%"PRIx64, &data); + } else { + rv = sscanf(buf, "%"PRId64, &data); + } + if (rv == 1) { + return data; + } + return 0; // Failed +} + +//////////////////////////////////////////////////////////////////////////////// + +static int l_new(lua_State *L) { + int64_t val; + const char *str = NULL; + if (lua_type(L, 1) == LUA_TSTRING) { + str = lua_tostring(L, 1); + val = l_deserialize(str); + } else if (lua_type(L, 1) == LUA_TNUMBER) { + val = (int64_t)lua_tonumber(L, 1); + str = (const char *)1; + } + lualongnumber_pushlong(L, (str ? &val : NULL)); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +// a + b +static int l_add(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a + b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a / b +static int l_div(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a / b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a == b (both a and b are lualongnumber's) +static int l_eq(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a == b ? 1 : 0)); + return 1; +} + +// garbage collection +static int l_gc(lua_State *L) { + lua_pushnil(L); + lua_setmetatable(L, 1); + return 0; +} + +// a < b +static int l_lt(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a < b ? 1 : 0)); + return 1; +} + +// a <= b +static int l_le(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a <= b ? 1 : 0)); + return 1; +} + +// a % b +static int l_mod(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a % b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a * b +static int l_mul(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a * b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a ^ b +static int l_pow(lua_State *L) { + long double a, b; + int64_t c; + a = (long double)lualongnumber_checklong(L, 1); + b = (long double)lualongnumber_checklong(L, 2); + c = (int64_t)pow(a, b); + lualongnumber_pushlong(L, &c); + return 1; +} + +// a - b +static int l_sub(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a - b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// tostring() +static int l_tostring(lua_State *L) { + int64_t a; + char str[256]; + l_serialize(str, 256, lualongnumber_checklong(L, 1)); + lua_pushstring(L, str); + return 1; +} + +// -a +static int l_unm(lua_State *L) { + int64_t a, c; + a = lualongnumber_checklong(L, 1); + c = -a; + lualongnumber_pushlong(L, &c); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +static const luaL_Reg methods[] = { + {"__add", l_add}, + {"__div", l_div}, + {"__eq", l_eq}, + {"__gc", l_gc}, + {"__lt", l_lt}, + {"__le", l_le}, + {"__mod", l_mod}, + {"__mul", l_mul}, + {"__pow", l_pow}, + {"__sub", l_sub}, + {"__tostring", l_tostring}, + {"__unm", l_unm}, + {NULL, NULL}, +}; + +static const luaL_Reg funcs[] = { + {"new", l_new}, + {NULL, NULL} +}; + +//////////////////////////////////////////////////////////////////////////////// + +static void set_methods(lua_State *L, + const char *metatablename, + const struct luaL_Reg *methods) { + luaL_getmetatable(L, metatablename); // mt + // No need for a __index table since everything is __* + for (; methods->name; methods++) { + lua_pushstring(L, methods->name); // mt, "name" + lua_pushcfunction(L, methods->func); // mt, "name", func + lua_rawset(L, -3); // mt + } + lua_pop(L, 1); +} + +LUALIB_API int luaopen_liblualongnumber(lua_State *L) { + luaL_newmetatable(L, LONG_NUM_TYPE); + lua_pop(L, 1); + set_methods(L, LONG_NUM_TYPE, methods); + + luaL_register(L, "liblualongnumber", funcs); + return 1; +} diff --git a/src/jaegertracing/thrift/lib/lua/src/luasocket.c b/src/jaegertracing/thrift/lib/lua/src/luasocket.c new file mode 100644 index 000000000..d48351077 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/luasocket.c @@ -0,0 +1,380 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <lua.h> +#include <lauxlib.h> + +#include <unistd.h> +#include "string.h" +#include "socket.h" + +//////////////////////////////////////////////////////////////////////////////// + +static const char *SOCKET_ANY = "__thrift_socket_any"; +static const char *SOCKET_CONN = "__thrift_socket_connected"; + +static const char *SOCKET_GENERIC = "__thrift_socket_generic"; +static const char *SOCKET_CLIENT = "__thrift_socket_client"; +static const char *SOCKET_SERVER = "__thrift_socket_server"; + +static const char *DEFAULT_HOST = "localhost"; + +typedef struct __t_tcp { + t_socket sock; + int timeout; // Milliseconds +} t_tcp; +typedef t_tcp *p_tcp; + +//////////////////////////////////////////////////////////////////////////////// +// Util + +static void throw_argerror(lua_State *L, int index, const char *expected) { + char msg[256]; + sprintf(msg, "%s expected, got %s", expected, luaL_typename(L, index)); + luaL_argerror(L, index, msg); +} + +static void *checkgroup(lua_State *L, int index, const char *groupname) { + if (!lua_getmetatable(L, index)) { + throw_argerror(L, index, groupname); + } + + lua_pushstring(L, groupname); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) { + lua_pop(L, 2); + throw_argerror(L, index, groupname); + } else { + lua_pop(L, 2); + return lua_touserdata(L, index); + } + return NULL; // Not reachable +} + +static void *checktype(lua_State *L, int index, const char *typename) { + if (strcmp(typename, SOCKET_ANY) == 0 || + strcmp(typename, SOCKET_CONN) == 0) { + return checkgroup(L, index, typename); + } else { + return luaL_checkudata(L, index, typename); + } +} + +static void settype(lua_State *L, int index, const char *typename) { + luaL_getmetatable(L, typename); + lua_setmetatable(L, index); +} + +#define LUA_SUCCESS_RETURN(L) \ + lua_pushnumber(L, 1); \ + return 1 + +#define LUA_CHECK_RETURN(L, err) \ + if (err) { \ + lua_pushnil(L); \ + lua_pushstring(L, err); \ + return 2; \ + } \ + LUA_SUCCESS_RETURN(L) + +//////////////////////////////////////////////////////////////////////////////// + +static int l_socket_create(lua_State *L); +static int l_socket_destroy(lua_State *L); +static int l_socket_settimeout(lua_State *L); +static int l_socket_getsockinfo(lua_State *L); + +static int l_socket_accept(lua_State *L); +static int l_socket_listen(lua_State *L); + +static int l_socket_create_and_connect(lua_State *L); +static int l_socket_connect(lua_State *L); +static int l_socket_send(lua_State *L); +static int l_socket_receive(lua_State *L); + +//////////////////////////////////////////////////////////////////////////////// + +static const struct luaL_Reg methods_generic[] = { + {"destroy", l_socket_destroy}, + {"settimeout", l_socket_settimeout}, + {"getsockinfo", l_socket_getsockinfo}, + {"listen", l_socket_listen}, + {"connect", l_socket_connect}, + {NULL, NULL} +}; + +static const struct luaL_Reg methods_server[] = { + {"destroy", l_socket_destroy}, + {"getsockinfo", l_socket_getsockinfo}, + {"accept", l_socket_accept}, + {"send", l_socket_send}, + {"receive", l_socket_receive}, + {NULL, NULL} +}; + +static const struct luaL_Reg methods_client[] = { + {"destroy", l_socket_destroy}, + {"settimeout", l_socket_settimeout}, + {"getsockinfo", l_socket_getsockinfo}, + {"send", l_socket_send}, + {"receive", l_socket_receive}, + {NULL, NULL} +}; + +static const struct luaL_Reg funcs_luasocket[] = { + {"create", l_socket_create}, + {"create_and_connect", l_socket_create_and_connect}, + {NULL, NULL} +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Check/enforce inheritance +static void add_to_group(lua_State *L, + const char *metatablename, + const char *groupname) { + luaL_getmetatable(L, metatablename); // mt + lua_pushstring(L, groupname); // mt, "name" + lua_pushboolean(L, 1); // mt, "name", true + lua_rawset(L, -3); // mt + lua_pop(L, 1); +} + +static void set_methods(lua_State *L, + const char *metatablename, + const struct luaL_Reg *methods) { + luaL_getmetatable(L, metatablename); // mt + // Create the __index table + lua_pushstring(L, "__index"); // mt, "__index" + lua_newtable(L); // mt, "__index", t + for (; methods->name; methods++) { + lua_pushstring(L, methods->name); // mt, "__index", t, "name" + lua_pushcfunction(L, methods->func); // mt, "__index", t, "name", func + lua_rawset(L, -3); // mt, "__index", t + } + lua_rawset(L, -3); // mt + lua_pop(L, 1); +} + +int luaopen_libluasocket(lua_State *L) { + luaL_newmetatable(L, SOCKET_GENERIC); + luaL_newmetatable(L, SOCKET_CLIENT); + luaL_newmetatable(L, SOCKET_SERVER); + lua_pop(L, 3); + add_to_group(L, SOCKET_GENERIC, SOCKET_ANY); + add_to_group(L, SOCKET_CLIENT, SOCKET_ANY); + add_to_group(L, SOCKET_SERVER, SOCKET_ANY); + add_to_group(L, SOCKET_CLIENT, SOCKET_CONN); + add_to_group(L, SOCKET_SERVER, SOCKET_CONN); + set_methods(L, SOCKET_GENERIC, methods_generic); + set_methods(L, SOCKET_CLIENT, methods_client); + set_methods(L, SOCKET_SERVER, methods_server); + + luaL_register(L, "luasocket", funcs_luasocket); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// General + +// sock,err create(bind_host, bind_port) +// sock,err create(bind_host) -> any port +// sock,err create() -> any port on localhost +static int l_socket_create(lua_State *L) { + const char *err; + t_socket sock; + const char *addr = lua_tostring(L, 1); + if (!addr) { + addr = DEFAULT_HOST; + } + unsigned short port = lua_tonumber(L, 2); + err = tcp_create(&sock); + if (!err) { + err = tcp_bind(&sock, addr, port); // bind on create + if (err) { + tcp_destroy(&sock); + } else { + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, -2, SOCKET_GENERIC); + socket_setnonblocking(&sock); + tcp->sock = sock; + tcp->timeout = 0; + return 1; // Return userdata + } + } + LUA_CHECK_RETURN(L, err); +} + +// destroy() +static int l_socket_destroy(lua_State *L) { + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_ANY); + const char *err = tcp_destroy(&tcp->sock); + LUA_CHECK_RETURN(L, err); +} + +// send(socket, data) +static int l_socket_send(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_CONN); + p_tcp tcp = (p_tcp) checktype(L, 2, SOCKET_CONN); + size_t len; + const char *data = luaL_checklstring(L, 3, &len); + const char *err = + tcp_send(&tcp->sock, data, len, tcp->timeout); + LUA_CHECK_RETURN(L, err); +} + +#define LUA_READ_STEP 8192 +static int l_socket_receive(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_CONN); + p_tcp handle = (p_tcp) checktype(L, 2, SOCKET_CONN); + size_t len = luaL_checknumber(L, 3); + char buf[LUA_READ_STEP]; + const char *err = NULL; + int received; + size_t got = 0, step = 0; + luaL_Buffer b; + + luaL_buffinit(L, &b); + do { + step = (LUA_READ_STEP < len - got ? LUA_READ_STEP : len - got); + err = tcp_raw_receive(&handle->sock, buf, step, self->timeout, &received); + if (err == NULL) { + luaL_addlstring(&b, buf, received); + got += received; + } + } while (err == NULL && got < len); + + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + luaL_pushresult(&b); + return 1; +} + +// settimeout(timeout) +static int l_socket_settimeout(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_ANY); + int timeout = luaL_checknumber(L, 2); + self->timeout = timeout; + LUA_SUCCESS_RETURN(L); +} + +// table getsockinfo() +static int l_socket_getsockinfo(lua_State *L) { + char buf[256]; + short port = 0; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_ANY); + if (socket_get_info(&tcp->sock, &port, buf, 256) == SUCCESS) { + lua_newtable(L); // t + lua_pushstring(L, "host"); // t, "host" + lua_pushstring(L, buf); // t, "host", buf + lua_rawset(L, -3); // t + lua_pushstring(L, "port"); // t, "port" + lua_pushnumber(L, port); // t, "port", port + lua_rawset(L, -3); // t + return 1; + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////// +// Server + +// accept() +static int l_socket_accept(lua_State *L) { + const char *err; + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_SERVER); + t_socket sock; + err = tcp_accept(&self->sock, &sock, self->timeout); + if (!err) { // Success + // Create a reference to the client + p_tcp client = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, 2, SOCKET_CLIENT); + socket_setnonblocking(&sock); + client->sock = sock; + client->timeout = self->timeout; + return 1; + } + LUA_CHECK_RETURN(L, err); +} + +static int l_socket_listen(lua_State *L) { + const char* err; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_GENERIC); + int backlog = 10; + err = tcp_listen(&tcp->sock, backlog); + if (!err) { + // Set the current as a server + settype(L, 1, SOCKET_SERVER); // Now a server + } + LUA_CHECK_RETURN(L, err); +} + +//////////////////////////////////////////////////////////////////////////////// +// Client + +// create_and_connect(host, port, timeout) +extern double __gettime(); +static int l_socket_create_and_connect(lua_State *L) { + const char* err = NULL; + double end; + t_socket sock; + const char *host = luaL_checkstring(L, 1); + unsigned short port = luaL_checknumber(L, 2); + int timeout = luaL_checknumber(L, 3); + + // Create and connect loop for timeout milliseconds + end = __gettime() + timeout/1000; + do { + // Create the socket + err = tcp_create(&sock); + if (!err) { + // Connect + err = tcp_connect(&sock, host, port, timeout); + if (err) { + tcp_destroy(&sock); + usleep(100000); // sleep for 100ms + } else { + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, -2, SOCKET_CLIENT); + socket_setnonblocking(&sock); + tcp->sock = sock; + tcp->timeout = timeout; + return 1; // Return userdata + } + } + } while (err && __gettime() < end); + + LUA_CHECK_RETURN(L, err); +} + +// connect(host, port) +static int l_socket_connect(lua_State *L) { + const char *err; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_GENERIC); + const char *host = luaL_checkstring(L, 2); + unsigned short port = luaL_checknumber(L, 3); + err = tcp_connect(&tcp->sock, host, port, tcp->timeout); + if (!err) { + settype(L, 1, SOCKET_CLIENT); // Now a client + } + LUA_CHECK_RETURN(L, err); +} diff --git a/src/jaegertracing/thrift/lib/lua/src/socket.h b/src/jaegertracing/thrift/lib/lua/src/socket.h new file mode 100644 index 000000000..afb827e47 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/socket.h @@ -0,0 +1,78 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#ifndef LUA_THRIFT_SOCKET_H +#define LUA_THRIFT_SOCKET_H + +#include <sys/socket.h> + +#ifdef _WIN32 +// SOL +#else +typedef int t_socket; +typedef t_socket* p_socket; +#endif + +// Error Codes +enum { + SUCCESS = 0, + TIMEOUT = -1, + CLOSED = -2, +}; +typedef int T_ERRCODE; + +static const char * TIMEOUT_MSG = "Timeout"; +static const char * CLOSED_MSG = "Connection Closed"; + +typedef struct sockaddr t_sa; +typedef t_sa * p_sa; + +T_ERRCODE socket_create(p_socket sock, int domain, int type, int protocol); +T_ERRCODE socket_destroy(p_socket sock); +T_ERRCODE socket_bind(p_socket sock, p_sa addr, int addr_len); +T_ERRCODE socket_get_info(p_socket sock, short *port, char *buf, size_t len); +T_ERRCODE socket_send(p_socket sock, const char *data, size_t len, int timeout); +T_ERRCODE socket_recv(p_socket sock, char *data, size_t len, int timeout, + int *received); + +T_ERRCODE socket_setblocking(p_socket sock); +T_ERRCODE socket_setnonblocking(p_socket sock); + +T_ERRCODE socket_accept(p_socket sock, p_socket sibling, + p_sa addr, socklen_t *addr_len, int timeout); +T_ERRCODE socket_listen(p_socket sock, int backlog); + +T_ERRCODE socket_connect(p_socket sock, p_sa addr, int addr_len, int timeout); + +const char * tcp_create(p_socket sock); +const char * tcp_destroy(p_socket sock); +const char * tcp_bind(p_socket sock, const char *host, unsigned short port); +const char * tcp_send(p_socket sock, const char *data, size_t w_len, + int timeout); +const char * tcp_receive(p_socket sock, char *data, size_t r_len, int timeout); +const char * tcp_raw_receive(p_socket sock, char * data, size_t r_len, + int timeout, int *received); + +const char * tcp_listen(p_socket sock, int backlog); +const char * tcp_accept(p_socket sock, p_socket client, int timeout); + +const char * tcp_connect(p_socket sock, const char *host, unsigned short port, + int timeout); + +#endif diff --git a/src/jaegertracing/thrift/lib/lua/src/usocket.c b/src/jaegertracing/thrift/lib/lua/src/usocket.c new file mode 100644 index 000000000..1a1b549a0 --- /dev/null +++ b/src/jaegertracing/thrift/lib/lua/src/usocket.c @@ -0,0 +1,376 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include <sys/time.h> +#include <sys/types.h> +#include <arpa/inet.h> +#include <netdb.h> +#include <string.h> +#include <unistd.h> +#include <fcntl.h> +#include <errno.h> + +#include <stdio.h> // TODO REMOVE + +#include "socket.h" + +//////////////////////////////////////////////////////////////////////////////// +// Private + +// Num seconds since Jan 1 1970 (UTC) +#ifdef _WIN32 +// SOL +#else + double __gettime() { + struct timeval v; + gettimeofday(&v, (struct timezone*) NULL); + return v.tv_sec + v.tv_usec/1.0e6; + } +#endif + +#define WAIT_MODE_R 1 +#define WAIT_MODE_W 2 +#define WAIT_MODE_C (WAIT_MODE_R|WAIT_MODE_W) +T_ERRCODE socket_wait(p_socket sock, int mode, int timeout) { + int ret = 0; + fd_set rfds, wfds; + struct timeval tv; + double end, t; + if (!timeout) { + return TIMEOUT; + } + + end = __gettime() + timeout/1000; + do { + FD_ZERO(&rfds); + FD_ZERO(&wfds); + + // Specify what I/O operations we care about + if (mode & WAIT_MODE_R) { + FD_SET(*sock, &rfds); + } + if (mode & WAIT_MODE_W) { + FD_SET(*sock, &wfds); + } + + // Check for timeout + t = end - __gettime(); + if (t < 0.0) { + break; + } + + // Wait + tv.tv_sec = (int)t; + tv.tv_usec = (int)((t - tv.tv_sec) * 1.0e6); + ret = select(*sock+1, &rfds, &wfds, NULL, &tv); + } while (ret == -1 && errno == EINTR); + if (ret == -1) { + return errno; + } + + // Check for timeout + if (ret == 0) { + return TIMEOUT; + } + + // Verify that we can actually read from the remote host + if (mode & WAIT_MODE_C && FD_ISSET(*sock, &rfds) && + recv(*sock, (char*) &rfds, 0, 0) != 0) { + return errno; + } + + return SUCCESS; +} + +//////////////////////////////////////////////////////////////////////////////// +// General + +T_ERRCODE socket_create(p_socket sock, int domain, int type, int protocol) { + *sock = socket(domain, type, protocol); + if (*sock > 0) { + return SUCCESS; + } else { + return errno; + } +} + +T_ERRCODE socket_destroy(p_socket sock) { + // TODO Figure out if I should be free-ing this + if (*sock > 0) { + (void)socket_setblocking(sock); + close(*sock); + *sock = -1; + } + return SUCCESS; +} + +T_ERRCODE socket_bind(p_socket sock, p_sa addr, int addr_len) { + int ret = socket_setblocking(sock); + if (ret != SUCCESS) { + return ret; + } + if (bind(*sock, addr, addr_len)) { + ret = errno; + } + int ret2 = socket_setnonblocking(sock); + return ret == SUCCESS ? ret2 : ret; +} + +T_ERRCODE socket_get_info(p_socket sock, short *port, char *buf, size_t len) { + struct sockaddr_storage sa; + memset(&sa, 0, sizeof(sa)); + socklen_t addrlen = sizeof(sa); + int rc = getsockname(*sock, (struct sockaddr*)&sa, &addrlen); + if (!rc) { + if (sa.ss_family == AF_INET6) { + struct sockaddr_in6* sin = (struct sockaddr_in6*)(&sa); + if (!inet_ntop(AF_INET6, &sin->sin6_addr, buf, len)) { + return errno; + } + *port = ntohs(sin->sin6_port); + } else { + struct sockaddr_in* sin = (struct sockaddr_in*)(&sa); + if (!inet_ntop(AF_INET, &sin->sin_addr, buf, len)) { + return errno; + } + *port = ntohs(sin->sin_port); + } + return SUCCESS; + } + return errno; +} + +//////////////////////////////////////////////////////////////////////////////// +// Server + +T_ERRCODE socket_accept(p_socket sock, p_socket client, + p_sa addr, socklen_t *addrlen, int timeout) { + int err; + if (*sock < 0) { + return CLOSED; + } + do { + *client = accept(*sock, addr, addrlen); + if (*client > 0) { + return SUCCESS; + } + } while ((err = errno) == EINTR); + + if (err == EAGAIN || err == ECONNABORTED) { + return socket_wait(sock, WAIT_MODE_R, timeout); + } + + return err; +} + +T_ERRCODE socket_listen(p_socket sock, int backlog) { + int ret = socket_setblocking(sock); + if (ret != SUCCESS) { + return ret; + } + if (listen(*sock, backlog)) { + ret = errno; + } + int ret2 = socket_setnonblocking(sock); + return ret == SUCCESS ? ret2 : ret; +} + +//////////////////////////////////////////////////////////////////////////////// +// Client + +T_ERRCODE socket_connect(p_socket sock, p_sa addr, int addr_len, int timeout) { + int err; + if (*sock < 0) { + return CLOSED; + } + + do { + if (connect(*sock, addr, addr_len) == 0) { + return SUCCESS; + } + } while ((err = errno) == EINTR); + if (err != EINPROGRESS && err != EAGAIN) { + return err; + } + return socket_wait(sock, WAIT_MODE_C, timeout); +} + +T_ERRCODE socket_send( + p_socket sock, const char *data, size_t len, int timeout) { + int err, put = 0; + if (*sock < 0) { + return CLOSED; + } + do { + put = send(*sock, data, len, 0); + if (put > 0) { + return SUCCESS; + } + } while ((err = errno) == EINTR); + + if (err == EAGAIN) { + return socket_wait(sock, WAIT_MODE_W, timeout); + } + + return err; +} + +T_ERRCODE socket_recv( + p_socket sock, char *data, size_t len, int timeout, int *received) { + int err, got = 0; + if (*sock < 0) { + return CLOSED; + } + *received = 0; + + do { + got = recv(*sock, data, len, 0); + if (got > 0) { + *received = got; + return SUCCESS; + } + err = errno; + + // Connection has been closed by peer + if (got == 0) { + return CLOSED; + } + } while (err == EINTR); + + if (err == EAGAIN) { + return socket_wait(sock, WAIT_MODE_R, timeout); + } + + return err; +} + +//////////////////////////////////////////////////////////////////////////////// +// Util + +T_ERRCODE socket_setnonblocking(p_socket sock) { + int flags = fcntl(*sock, F_GETFL, 0); + flags |= O_NONBLOCK; + return fcntl(*sock, F_SETFL, flags) != -1 ? SUCCESS : errno; +} + +T_ERRCODE socket_setblocking(p_socket sock) { + int flags = fcntl(*sock, F_GETFL, 0); + flags &= (~(O_NONBLOCK)); + return fcntl(*sock, F_SETFL, flags) != -1 ? SUCCESS : errno; +} + +//////////////////////////////////////////////////////////////////////////////// +// TCP + +#define ERRORSTR_RETURN(err) \ + if (err == SUCCESS) { \ + return NULL; \ + } else if (err == TIMEOUT) { \ + return TIMEOUT_MSG; \ + } else if (err == CLOSED) { \ + return CLOSED_MSG; \ + } \ + return strerror(err) + +const char * tcp_create(p_socket sock) { + int err = socket_create(sock, AF_INET, SOCK_STREAM, 0); + ERRORSTR_RETURN(err); +} + +const char * tcp_destroy(p_socket sock) { + int err = socket_destroy(sock); + ERRORSTR_RETURN(err); +} + +const char * tcp_bind(p_socket sock, const char *host, unsigned short port) { + int err; + struct hostent *h; + struct sockaddr_in local; + memset(&local, 0, sizeof(local)); + local.sin_family = AF_INET; + local.sin_addr.s_addr = htonl(INADDR_ANY); + local.sin_port = htons(port); + if (strcmp(host, "*") && !inet_aton(host, &local.sin_addr)) { + h = gethostbyname(host); + if (!h) { + return hstrerror(h_errno); + } + memcpy(&local.sin_addr, + (struct in_addr *)h->h_addr_list[0], + sizeof(struct in_addr)); + } + err = socket_bind(sock, (p_sa) &local, sizeof(local)); + ERRORSTR_RETURN(err); +} + +const char * tcp_listen(p_socket sock, int backlog) { + int err = socket_listen(sock, backlog); + ERRORSTR_RETURN(err); +} + +const char * tcp_accept(p_socket sock, p_socket client, int timeout) { + int err = socket_accept(sock, client, NULL, NULL, timeout); + ERRORSTR_RETURN(err); +} + +const char * tcp_connect(p_socket sock, + const char *host, + unsigned short port, + int timeout) { + int err; + struct hostent *h; + struct sockaddr_in remote; + memset(&remote, 0, sizeof(remote)); + remote.sin_family = AF_INET; + remote.sin_port = htons(port); + if (strcmp(host, "*") && !inet_aton(host, &remote.sin_addr)) { + h = gethostbyname(host); + if (!h) { + return hstrerror(h_errno); + } + memcpy(&remote.sin_addr, + (struct in_addr *)h->h_addr_list[0], + sizeof(struct in_addr)); + } + err = socket_connect(sock, (p_sa) &remote, sizeof(remote), timeout); + ERRORSTR_RETURN(err); +} + +#define WRITE_STEP 8192 +const char * tcp_send( + p_socket sock, const char * data, size_t w_len, int timeout) { + int err; + size_t put = 0, step; + if (!w_len) { + return NULL; + } + + do { + step = (WRITE_STEP < w_len - put ? WRITE_STEP : w_len - put); + err = socket_send(sock, data + put, step, timeout); + put += step; + } while (err == SUCCESS && put < w_len); + ERRORSTR_RETURN(err); +} + +const char * tcp_raw_receive( + p_socket sock, char * data, size_t r_len, int timeout, int *received) { + int err = socket_recv(sock, data, r_len, timeout, received); + ERRORSTR_RETURN(err); +} |