From b67c0f1977029982a7a70af90b8b936c19d04397 Mon Sep 17 00:00:00 2001 From: wangfan Date: Wed, 25 Sep 2019 14:51:49 +0800 Subject: [PATCH] fix thrift conflict --- client/TalosClientFactory.go | 2 +- client/TalosHttpClient.go | 3 +- example/admin/TalosAdminDemo.go | 3 +- .../TalosSimpleProducerDemo.go | 2 +- go.mod | 1 - go.sum | 2 - test/producer/partitionSender_test.go | 2 +- test/producer/simpleProducer_test.go | 2 +- test/producer/talosProducer_test.go | 2 +- test/producer/userMessage_test.go | 2 +- thrift/auth/constants.go | 3 +- thrift/auth/ttypes.go | 3 +- thrift/authorization/constants.go | 3 +- thrift/authorization/ttypes.go | 3 +- thrift/common/constants.go | 3 +- .../talos_base_service-remote.go | 3 +- thrift/common/talosbaseservice.go | 3 +- thrift/common/ttypes.go | 3 +- thrift/consumer/constants.go | 3 +- .../consumer_service-remote.go | 3 +- thrift/consumer/consumerservice.go | 3 +- thrift/consumer/ttypes.go | 3 +- thrift/message/constants.go | 3 +- .../message_service-remote.go | 3 +- thrift/message/messageservice.go | 3 +- thrift/message/ttypes.go | 3 +- thrift/quota/constants.go | 3 +- .../quota_service-remote.go | 3 +- thrift/quota/quotaservice.go | 3 +- thrift/quota/ttypes.go | 3 +- thrift/thrift/application_exception.go | 142 ++ thrift/thrift/application_exception_test.go | 41 + thrift/thrift/binary_protocol.go | 484 +++++++ thrift/thrift/binary_protocol_test.go | 28 + thrift/thrift/buffered_transport.go | 70 + thrift/thrift/buffered_transport_test.go | 29 + thrift/thrift/compact_protocol.go | 797 ++++++++++ thrift/thrift/compact_protocol_test.go | 53 + thrift/thrift/debug_protocol.go | 269 ++++ thrift/thrift/deserializer.go | 58 + thrift/thrift/exception.go | 25 + thrift/thrift/field.go | 79 + thrift/thrift/framed_transport.go | 151 ++ thrift/thrift/framed_transport_test.go | 29 + thrift/thrift/http_client.go | 195 +++ thrift/thrift/http_client_test.go | 50 + thrift/thrift/iostream_transport.go | 205 +++ thrift/thrift/iostream_transport_test.go | 30 + thrift/thrift/json_protocol.go | 566 ++++++++ thrift/thrift/json_protocol_test.go | 646 +++++++++ thrift/thrift/lowlevel_benchmarks_test.go | 396 +++++ thrift/thrift/memory_buffer.go | 79 + thrift/thrift/memory_buffer_test.go | 29 + thrift/thrift/messagetype.go | 31 + thrift/thrift/multiplexed_protocol.go | 169 +++ thrift/thrift/numeric.go | 164 +++ thrift/thrift/pointerize.go | 50 + thrift/thrift/processor.go | 30 + thrift/thrift/processor_factory.go | 58 + thrift/thrift/protocol.go | 154 ++ thrift/thrift/protocol_exception.go | 77 + thrift/thrift/protocol_factory.go | 25 + thrift/thrift/protocol_test.go | 479 +++++++ thrift/thrift/rich_transport.go | 64 + thrift/thrift/rich_transport_test.go | 85 ++ thrift/thrift/serializer.go | 75 + thrift/thrift/serializer_test.go | 169 +++ thrift/thrift/serializer_types.go | 595 ++++++++ thrift/thrift/server.go | 35 + thrift/thrift/server_socket.go | 127 ++ thrift/thrift/server_test.go | 28 + thrift/thrift/server_transport.go | 34 + thrift/thrift/simple_json_protocol.go | 1277 +++++++++++++++++ thrift/thrift/simple_json_protocol_test.go | 632 ++++++++ thrift/thrift/simple_server.go | 188 +++ thrift/thrift/socket.go | 159 ++ thrift/thrift/ssl_server_socket.go | 109 ++ thrift/thrift/ssl_socket.go | 161 +++ thrift/thrift/transport.go | 59 + thrift/thrift/transport_exception.go | 90 ++ thrift/thrift/transport_exception_test.go | 60 + thrift/thrift/transport_factory.go | 39 + thrift/thrift/transport_test.go | 176 +++ thrift/thrift/type.go | 68 + thrift/topic/constants.go | 3 +- .../topic_service-remote.go | 3 +- thrift/topic/topicservice.go | 3 +- thrift/topic/ttypes.go | 2 +- utils/Utils.go | 2 +- 89 files changed, 9974 insertions(+), 38 deletions(-) create mode 100644 thrift/thrift/application_exception.go create mode 100644 thrift/thrift/application_exception_test.go create mode 100644 thrift/thrift/binary_protocol.go create mode 100644 thrift/thrift/binary_protocol_test.go create mode 100644 thrift/thrift/buffered_transport.go create mode 100644 thrift/thrift/buffered_transport_test.go create mode 100644 thrift/thrift/compact_protocol.go create mode 100644 thrift/thrift/compact_protocol_test.go create mode 100644 thrift/thrift/debug_protocol.go create mode 100644 thrift/thrift/deserializer.go create mode 100644 thrift/thrift/exception.go create mode 100644 thrift/thrift/field.go create mode 100644 thrift/thrift/framed_transport.go create mode 100644 thrift/thrift/framed_transport_test.go create mode 100644 thrift/thrift/http_client.go create mode 100644 thrift/thrift/http_client_test.go create mode 100644 thrift/thrift/iostream_transport.go create mode 100644 thrift/thrift/iostream_transport_test.go create mode 100644 thrift/thrift/json_protocol.go create mode 100644 thrift/thrift/json_protocol_test.go create mode 100644 thrift/thrift/lowlevel_benchmarks_test.go create mode 100644 thrift/thrift/memory_buffer.go create mode 100644 thrift/thrift/memory_buffer_test.go create mode 100644 thrift/thrift/messagetype.go create mode 100644 thrift/thrift/multiplexed_protocol.go create mode 100644 thrift/thrift/numeric.go create mode 100644 thrift/thrift/pointerize.go create mode 100644 thrift/thrift/processor.go create mode 100644 thrift/thrift/processor_factory.go create mode 100644 thrift/thrift/protocol.go create mode 100644 thrift/thrift/protocol_exception.go create mode 100644 thrift/thrift/protocol_factory.go create mode 100644 thrift/thrift/protocol_test.go create mode 100644 thrift/thrift/rich_transport.go create mode 100644 thrift/thrift/rich_transport_test.go create mode 100644 thrift/thrift/serializer.go create mode 100644 thrift/thrift/serializer_test.go create mode 100644 thrift/thrift/serializer_types.go create mode 100644 thrift/thrift/server.go create mode 100644 thrift/thrift/server_socket.go create mode 100644 thrift/thrift/server_test.go create mode 100644 thrift/thrift/server_transport.go create mode 100644 thrift/thrift/simple_json_protocol.go create mode 100644 thrift/thrift/simple_json_protocol_test.go create mode 100644 thrift/thrift/simple_server.go create mode 100644 thrift/thrift/socket.go create mode 100644 thrift/thrift/ssl_server_socket.go create mode 100644 thrift/thrift/ssl_socket.go create mode 100644 thrift/thrift/transport.go create mode 100644 thrift/thrift/transport_exception.go create mode 100644 thrift/thrift/transport_exception_test.go create mode 100644 thrift/thrift/transport_factory.go create mode 100644 thrift/thrift/transport_test.go create mode 100644 thrift/thrift/type.go diff --git a/client/TalosClientFactory.go b/client/TalosClientFactory.go index f1604ce..bac5c51 100644 --- a/client/TalosClientFactory.go +++ b/client/TalosClientFactory.go @@ -21,7 +21,7 @@ import ( "github.com/XiaoMi/talos-sdk-golang/thrift/topic" "github.com/XiaoMi/talos-sdk-golang/utils" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) //var log = utils.Instance diff --git a/client/TalosHttpClient.go b/client/TalosHttpClient.go index 8c53ebe..fe6f86b 100644 --- a/client/TalosHttpClient.go +++ b/client/TalosHttpClient.go @@ -36,8 +36,7 @@ import ( "github.com/XiaoMi/talos-sdk-golang/thrift/auth" "github.com/XiaoMi/talos-sdk-golang/thrift/common" - - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/gofrs/uuid" log "github.com/sirupsen/logrus" ) diff --git a/example/admin/TalosAdminDemo.go b/example/admin/TalosAdminDemo.go index 55144ce..a03c350 100644 --- a/example/admin/TalosAdminDemo.go +++ b/example/admin/TalosAdminDemo.go @@ -12,10 +12,9 @@ import ( "github.com/XiaoMi/talos-sdk-golang/thrift/auth" "github.com/XiaoMi/talos-sdk-golang/thrift/authorization" "github.com/XiaoMi/talos-sdk-golang/thrift/message" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" "github.com/XiaoMi/talos-sdk-golang/utils" - - "git.apache.org/thrift.git/lib/go/thrift" log "github.com/sirupsen/logrus" ) diff --git a/example/simple_producer/TalosSimpleProducerDemo.go b/example/simple_producer/TalosSimpleProducerDemo.go index 9b330ee..5b89a12 100644 --- a/example/simple_producer/TalosSimpleProducerDemo.go +++ b/example/simple_producer/TalosSimpleProducerDemo.go @@ -15,7 +15,7 @@ import ( "github.com/XiaoMi/talos-sdk-golang/thrift/message" "github.com/XiaoMi/talos-sdk-golang/utils" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" log "github.com/sirupsen/logrus" ) diff --git a/go.mod b/go.mod index 68c7735..68057ff 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/XiaoMi/talos-sdk-golang go 1.12 require ( - git.apache.org/thrift.git v0.0.0-20141105021220-591e20f9636c github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 github.com/gofrs/uuid v3.2.0+incompatible github.com/golang/mock v1.3.1 diff --git a/go.sum b/go.sum index f7e024b..5f7f99a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -git.apache.org/thrift.git v0.0.0-20141105021220-591e20f9636c h1:8Z1BqdpEgzGJNfAhMtW8SP/zdw6uoLLJZSSALQcCuak= -git.apache.org/thrift.git v0.0.0-20141105021220-591e20f9636c/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/test/producer/partitionSender_test.go b/test/producer/partitionSender_test.go index 2618225..a0a17a2 100644 --- a/test/producer/partitionSender_test.go +++ b/test/producer/partitionSender_test.go @@ -21,7 +21,7 @@ import ( "talos-sdk-golang/testos-sdk-golang/test/mock_message" "talos-sdk-golang/testos-sdk-golang/test/mock_producer" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) diff --git a/test/producer/simpleProducer_test.go b/test/producer/simpleProducer_test.go index 19be0f9..d0440f1 100644 --- a/test/producer/simpleProducer_test.go +++ b/test/producer/simpleProducer_test.go @@ -19,7 +19,7 @@ import ( "talos-sdk-golang/producer" "talos-sdk-golang/testos-sdk-golang/test/mock_message" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/golang/mock/gomock" log4go "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" diff --git a/test/producer/talosProducer_test.go b/test/producer/talosProducer_test.go index a962648..e49ded4 100644 --- a/test/producer/talosProducer_test.go +++ b/test/producer/talosProducer_test.go @@ -30,7 +30,7 @@ import ( "talos-sdk-golang/thrift/auth" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/golang/mock/gomock" log4go "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" diff --git a/test/producer/userMessage_test.go b/test/producer/userMessage_test.go index ca0671f..c86b98e 100644 --- a/test/producer/userMessage_test.go +++ b/test/producer/userMessage_test.go @@ -14,7 +14,7 @@ import ( "talos-sdk-golang/producer" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/stretchr/testify/assert" ) diff --git a/thrift/auth/constants.go b/thrift/auth/constants.go index b1d7ea3..ace7d55 100644 --- a/thrift/auth/constants.go +++ b/thrift/auth/constants.go @@ -6,8 +6,9 @@ package auth import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/auth/ttypes.go b/thrift/auth/ttypes.go index 2681c7f..60a25b6 100644 --- a/thrift/auth/ttypes.go +++ b/thrift/auth/ttypes.go @@ -6,8 +6,9 @@ package auth import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/authorization/constants.go b/thrift/authorization/constants.go index 4587499..4535206 100644 --- a/thrift/authorization/constants.go +++ b/thrift/authorization/constants.go @@ -6,7 +6,8 @@ package authorization import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/authorization/ttypes.go b/thrift/authorization/ttypes.go index 3764f48..0e37af8 100644 --- a/thrift/authorization/ttypes.go +++ b/thrift/authorization/ttypes.go @@ -6,7 +6,8 @@ package authorization import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/common/constants.go b/thrift/common/constants.go index 2252213..f174bd7 100644 --- a/thrift/common/constants.go +++ b/thrift/common/constants.go @@ -6,7 +6,8 @@ package common import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/common/talos_base_service-remote/talos_base_service-remote.go b/thrift/common/talos_base_service-remote/talos_base_service-remote.go index 0a7d9b6..6fe503c 100755 --- a/thrift/common/talos_base_service-remote/talos_base_service-remote.go +++ b/thrift/common/talos_base_service-remote/talos_base_service-remote.go @@ -6,7 +6,6 @@ package main import ( "flag" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "math" "net" "net/url" @@ -14,6 +13,8 @@ import ( "strconv" "strings" "thrift/common" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) func Usage() { diff --git a/thrift/common/talosbaseservice.go b/thrift/common/talosbaseservice.go index dd4e8ef..532ca13 100644 --- a/thrift/common/talosbaseservice.go +++ b/thrift/common/talosbaseservice.go @@ -6,7 +6,8 @@ package common import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/common/ttypes.go b/thrift/common/ttypes.go index 1dad25a..8485add 100644 --- a/thrift/common/ttypes.go +++ b/thrift/common/ttypes.go @@ -6,7 +6,8 @@ package common import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/consumer/constants.go b/thrift/consumer/constants.go index 2827f6c..80027e8 100644 --- a/thrift/consumer/constants.go +++ b/thrift/consumer/constants.go @@ -6,8 +6,9 @@ package consumer import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" ) diff --git a/thrift/consumer/consumer_service-remote/consumer_service-remote.go b/thrift/consumer/consumer_service-remote/consumer_service-remote.go index 1b5e3b0..066c9e0 100755 --- a/thrift/consumer/consumer_service-remote/consumer_service-remote.go +++ b/thrift/consumer/consumer_service-remote/consumer_service-remote.go @@ -6,7 +6,6 @@ package main import ( "flag" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "math" "net" "net/url" @@ -14,6 +13,8 @@ import ( "strconv" "strings" "thrift/consumer" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) func Usage() { diff --git a/thrift/consumer/consumerservice.go b/thrift/consumer/consumerservice.go index d6e799d..d01d89d 100644 --- a/thrift/consumer/consumerservice.go +++ b/thrift/consumer/consumerservice.go @@ -6,8 +6,9 @@ package consumer import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" ) diff --git a/thrift/consumer/ttypes.go b/thrift/consumer/ttypes.go index 97a2a70..e4e265a 100644 --- a/thrift/consumer/ttypes.go +++ b/thrift/consumer/ttypes.go @@ -6,8 +6,9 @@ package consumer import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" ) diff --git a/thrift/message/constants.go b/thrift/message/constants.go index 423446b..7aed2a8 100644 --- a/thrift/message/constants.go +++ b/thrift/message/constants.go @@ -6,8 +6,9 @@ package message import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" ) diff --git a/thrift/message/message_service-remote/message_service-remote.go b/thrift/message/message_service-remote/message_service-remote.go index c70a061..f0a1181 100755 --- a/thrift/message/message_service-remote/message_service-remote.go +++ b/thrift/message/message_service-remote/message_service-remote.go @@ -6,7 +6,6 @@ package main import ( "flag" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "math" "net" "net/url" @@ -14,6 +13,8 @@ import ( "strconv" "strings" "thrift/message" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) func Usage() { diff --git a/thrift/message/messageservice.go b/thrift/message/messageservice.go index 02e0838..04ce354 100644 --- a/thrift/message/messageservice.go +++ b/thrift/message/messageservice.go @@ -6,8 +6,9 @@ package message import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" ) diff --git a/thrift/message/ttypes.go b/thrift/message/ttypes.go index 314181b..f10a4f9 100644 --- a/thrift/message/ttypes.go +++ b/thrift/message/ttypes.go @@ -6,8 +6,9 @@ package message import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" ) diff --git a/thrift/quota/constants.go b/thrift/quota/constants.go index 2183e26..38fd8ec 100644 --- a/thrift/quota/constants.go +++ b/thrift/quota/constants.go @@ -6,8 +6,9 @@ package quota import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/quota/quota_service-remote/quota_service-remote.go b/thrift/quota/quota_service-remote/quota_service-remote.go index 2ba5028..77920a5 100755 --- a/thrift/quota/quota_service-remote/quota_service-remote.go +++ b/thrift/quota/quota_service-remote/quota_service-remote.go @@ -6,7 +6,6 @@ package main import ( "flag" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "math" "net" "net/url" @@ -14,6 +13,8 @@ import ( "strconv" "strings" "thrift/quota" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) func Usage() { diff --git a/thrift/quota/quotaservice.go b/thrift/quota/quotaservice.go index cf1aea4..da44f2c 100644 --- a/thrift/quota/quotaservice.go +++ b/thrift/quota/quotaservice.go @@ -6,8 +6,9 @@ package quota import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/quota/ttypes.go b/thrift/quota/ttypes.go index 92d0735..3857f11 100644 --- a/thrift/quota/ttypes.go +++ b/thrift/quota/ttypes.go @@ -6,8 +6,9 @@ package quota import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/common" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/thrift/application_exception.go b/thrift/thrift/application_exception.go new file mode 100644 index 0000000..6655cc5 --- /dev/null +++ b/thrift/thrift/application_exception.go @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +const ( + UNKNOWN_APPLICATION_EXCEPTION = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE_EXCEPTION = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 +) + +// Application level Thrift exception +type TApplicationException interface { + TException + TypeId() int32 + Read(iprot TProtocol) (TApplicationException, error) + Write(oprot TProtocol) error +} + +type tApplicationException struct { + message string + type_ int32 +} + +func (e tApplicationException) Error() string { + return e.message +} + +func NewTApplicationException(type_ int32, message string) TApplicationException { + return &tApplicationException{message, type_} +} + +func (p *tApplicationException) TypeId() int32 { + return p.type_ +} + +func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, error) { + _, err := iprot.ReadStructBegin() + if err != nil { + return nil, err + } + + message := "" + type_ := int32(UNKNOWN_APPLICATION_EXCEPTION) + + for { + _, ttype, id, err := iprot.ReadFieldBegin() + if err != nil { + return nil, err + } + if ttype == STOP { + break + } + switch id { + case 1: + if ttype == STRING { + if message, err = iprot.ReadString(); err != nil { + return nil, err + } + } else { + if err = SkipDefaultDepth(iprot, ttype); err != nil { + return nil, err + } + } + case 2: + if ttype == I32 { + if type_, err = iprot.ReadI32(); err != nil { + return nil, err + } + } else { + if err = SkipDefaultDepth(iprot, ttype); err != nil { + return nil, err + } + } + default: + if err = SkipDefaultDepth(iprot, ttype); err != nil { + return nil, err + } + } + if err = iprot.ReadFieldEnd(); err != nil { + return nil, err + } + } + return NewTApplicationException(type_, message), iprot.ReadStructEnd() +} + +func (p *tApplicationException) Write(oprot TProtocol) (err error) { + err = oprot.WriteStructBegin("TApplicationException") + if len(p.Error()) > 0 { + err = oprot.WriteFieldBegin("message", STRING, 1) + if err != nil { + return + } + err = oprot.WriteString(p.Error()) + if err != nil { + return + } + err = oprot.WriteFieldEnd() + if err != nil { + return + } + } + err = oprot.WriteFieldBegin("type", I32, 2) + if err != nil { + return + } + err = oprot.WriteI32(p.type_) + if err != nil { + return + } + err = oprot.WriteFieldEnd() + if err != nil { + return + } + err = oprot.WriteFieldStop() + if err != nil { + return + } + err = oprot.WriteStructEnd() + return +} diff --git a/thrift/thrift/application_exception_test.go b/thrift/thrift/application_exception_test.go new file mode 100644 index 0000000..7010f86 --- /dev/null +++ b/thrift/thrift/application_exception_test.go @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestTApplicationException(t *testing.T) { + exc := NewTApplicationException(UNKNOWN_APPLICATION_EXCEPTION, "") + if exc.Error() != "" { + t.Fatalf("Expected empty string for exception but found '%s'", exc.Error()) + } + if exc.TypeId() != UNKNOWN_APPLICATION_EXCEPTION { + t.Fatalf("Expected type UNKNOWN for exception but found '%s'", exc.TypeId()) + } + exc = NewTApplicationException(WRONG_METHOD_NAME, "junk_method") + if exc.Error() != "junk_method" { + t.Fatalf("Expected 'junk_method' for exception but found '%s'", exc.Error()) + } + if exc.TypeId() != WRONG_METHOD_NAME { + t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%s'", exc.TypeId()) + } +} diff --git a/thrift/thrift/binary_protocol.go b/thrift/thrift/binary_protocol.go new file mode 100644 index 0000000..09f94d4 --- /dev/null +++ b/thrift/thrift/binary_protocol.go @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" +) + +type TBinaryProtocol struct { + trans TRichTransport + origTransport TTransport + reader io.Reader + writer io.Writer + strictRead bool + strictWrite bool + buffer [64]byte +} + +type TBinaryProtocolFactory struct { + strictRead bool + strictWrite bool +} + +func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol { + return NewTBinaryProtocol(t, false, true) +} + +func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol { + p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite} + if et, ok := t.(TRichTransport); ok { + p.trans = et + } else { + p.trans = NewTRichTransport(t) + } + p.reader = p.trans + p.writer = p.trans + return p +} + +func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory { + return NewTBinaryProtocolFactory(false, true) +} + +func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory { + return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite} +} + +func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol { + return NewTBinaryProtocol(t, p.strictRead, p.strictWrite) +} + +/** + * Writing Methods + */ + +func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + if p.strictWrite { + version := uint32(VERSION_1) | uint32(typeId) + e := p.WriteI32(int32(version)) + if e != nil { + return e + } + e = p.WriteString(name) + if e != nil { + return e + } + e = p.WriteI32(seqId) + return e + } else { + e := p.WriteString(name) + if e != nil { + return e + } + e = p.WriteByte(byte(typeId)) + if e != nil { + return e + } + e = p.WriteI32(seqId) + return e + } + return nil +} + +func (p *TBinaryProtocol) WriteMessageEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteStructBegin(name string) error { + return nil +} + +func (p *TBinaryProtocol) WriteStructEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + e := p.WriteByte(byte(typeId)) + if e != nil { + return e + } + e = p.WriteI16(id) + return e +} + +func (p *TBinaryProtocol) WriteFieldEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteFieldStop() error { + e := p.WriteByte(STOP) + return e +} + +func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + e := p.WriteByte(byte(keyType)) + if e != nil { + return e + } + e = p.WriteByte(byte(valueType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +func (p *TBinaryProtocol) WriteMapEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error { + e := p.WriteByte(byte(elemType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +func (p *TBinaryProtocol) WriteListEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error { + e := p.WriteByte(byte(elemType)) + if e != nil { + return e + } + e = p.WriteI32(int32(size)) + return e +} + +func (p *TBinaryProtocol) WriteSetEnd() error { + return nil +} + +func (p *TBinaryProtocol) WriteBool(value bool) error { + if value { + return p.WriteByte(1) + } + return p.WriteByte(0) +} + +func (p *TBinaryProtocol) WriteByte(value byte) error { + e := p.trans.WriteByte(value) + return NewTProtocolException(e) +} + +func (p *TBinaryProtocol) WriteI16(value int16) error { + v := p.buffer[0:2] + binary.BigEndian.PutUint16(v, uint16(value)) + _, e := p.writer.Write(v) + return NewTProtocolException(e) +} + +func (p *TBinaryProtocol) WriteI32(value int32) error { + v := p.buffer[0:4] + binary.BigEndian.PutUint32(v, uint32(value)) + _, e := p.writer.Write(v) + return NewTProtocolException(e) +} + +func (p *TBinaryProtocol) WriteI64(value int64) error { + v := p.buffer[0:8] + binary.BigEndian.PutUint64(v, uint64(value)) + _, err := p.writer.Write(v) + return NewTProtocolException(err) +} + +func (p *TBinaryProtocol) WriteDouble(value float64) error { + return p.WriteI64(int64(math.Float64bits(value))) +} + +func (p *TBinaryProtocol) WriteString(value string) error { + e := p.WriteI32(int32(len(value))) + if e != nil { + return e + } + _, err := p.trans.WriteString(value) + return NewTProtocolException(err) +} + +func (p *TBinaryProtocol) WriteBinary(value []byte) error { + e := p.WriteI32(int32(len(value))) + if e != nil { + return e + } + _, err := p.writer.Write(value) + return NewTProtocolException(err) +} + +/** + * Reading methods + */ + +func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + size, e := p.ReadI32() + if e != nil { + return "", typeId, 0, NewTProtocolException(e) + } + if size < 0 { + typeId = TMessageType(size & 0x0ff) + version := int64(int64(size) & VERSION_MASK) + if version != VERSION_1 { + return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin")) + } + name, e = p.ReadString() + if e != nil { + return name, typeId, seqId, NewTProtocolException(e) + } + seqId, e = p.ReadI32() + if e != nil { + return name, typeId, seqId, NewTProtocolException(e) + } + return name, typeId, seqId, nil + } + if p.strictRead { + return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin")) + } + name, e2 := p.readStringBody(int(size)) + if e2 != nil { + return name, typeId, seqId, e2 + } + b, e3 := p.ReadByte() + if e3 != nil { + return name, typeId, seqId, e3 + } + typeId = TMessageType(b) + seqId, e4 := p.ReadI32() + if e4 != nil { + return name, typeId, seqId, e4 + } + return name, typeId, seqId, nil +} + +func (p *TBinaryProtocol) ReadMessageEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) { + return +} + +func (p *TBinaryProtocol) ReadStructEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) { + t, err := p.ReadByte() + typeId = TType(t) + if err != nil { + return name, typeId, seqId, err + } + if t != STOP { + seqId, err = p.ReadI16() + } + return name, typeId, seqId, err +} + +func (p *TBinaryProtocol) ReadFieldEnd() error { + return nil +} + +var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) + +func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) { + k, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + kType = TType(k) + v, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + vType = TType(v) + size32, e := p.ReadI32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + return kType, vType, size, nil +} + +func (p *TBinaryProtocol) ReadMapEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { + b, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + elemType = TType(b) + size32, e := p.ReadI32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + + return +} + +func (p *TBinaryProtocol) ReadListEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { + b, e := p.ReadByte() + if e != nil { + err = NewTProtocolException(e) + return + } + elemType = TType(b) + size32, e := p.ReadI32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + return elemType, size, nil +} + +func (p *TBinaryProtocol) ReadSetEnd() error { + return nil +} + +func (p *TBinaryProtocol) ReadBool() (bool, error) { + b, e := p.ReadByte() + v := true + if b != 1 { + v = false + } + return v, e +} + +func (p *TBinaryProtocol) ReadByte() (value byte, err error) { + return p.trans.ReadByte() +} + +func (p *TBinaryProtocol) ReadI16() (value int16, err error) { + buf := p.buffer[0:2] + err = p.readAll(buf) + value = int16(binary.BigEndian.Uint16(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadI32() (value int32, err error) { + buf := p.buffer[0:4] + err = p.readAll(buf) + value = int32(binary.BigEndian.Uint32(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadI64() (value int64, err error) { + buf := p.buffer[0:8] + err = p.readAll(buf) + value = int64(binary.BigEndian.Uint64(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadDouble() (value float64, err error) { + buf := p.buffer[0:8] + err = p.readAll(buf) + value = math.Float64frombits(binary.BigEndian.Uint64(buf)) + return value, err +} + +func (p *TBinaryProtocol) ReadString() (value string, err error) { + size, e := p.ReadI32() + if e != nil { + return "", e + } + if size < 0 { + err = invalidDataLength + return + } + + return p.readStringBody(int(size)) +} + +func (p *TBinaryProtocol) ReadBinary() ([]byte, error) { + size, e := p.ReadI32() + if e != nil { + return nil, e + } + if size < 0 { + return nil, invalidDataLength + } + + isize := int(size) + buf := make([]byte, isize) + _, err := io.ReadFull(p.trans, buf) + return buf, NewTProtocolException(err) +} + +func (p *TBinaryProtocol) Flush() (err error) { + return NewTProtocolException(p.trans.Flush()) +} + +func (p *TBinaryProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TBinaryProtocol) Transport() TTransport { + return p.origTransport +} + +func (p *TBinaryProtocol) readAll(buf []byte) error { + _, err := io.ReadFull(p.reader, buf) + return NewTProtocolException(err) +} + +func (p *TBinaryProtocol) readStringBody(size int) (value string, err error) { + if size < 0 { + return "", nil + } + var buf []byte + if size <= len(p.buffer) { + buf = p.buffer[0:size] + } else { + buf = make([]byte, size) + } + _, e := io.ReadFull(p.trans, buf) + return string(buf), NewTProtocolException(e) +} diff --git a/thrift/thrift/binary_protocol_test.go b/thrift/thrift/binary_protocol_test.go new file mode 100644 index 0000000..0462cc7 --- /dev/null +++ b/thrift/thrift/binary_protocol_test.go @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestReadWriteBinaryProtocol(t *testing.T) { + ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault()) +} diff --git a/thrift/thrift/buffered_transport.go b/thrift/thrift/buffered_transport.go new file mode 100644 index 0000000..d258b70 --- /dev/null +++ b/thrift/thrift/buffered_transport.go @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" +) + +type TBufferedTransportFactory struct { + size int +} + +type TBufferedTransport struct { + bufio.ReadWriter + tp TTransport +} + +func (p *TBufferedTransportFactory) GetTransport(trans TTransport) TTransport { + return NewTBufferedTransport(trans, p.size) +} + +func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory { + return &TBufferedTransportFactory{size: bufferSize} +} + +func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport { + return &TBufferedTransport{ + ReadWriter: bufio.ReadWriter{ + Reader: bufio.NewReaderSize(trans, bufferSize), + Writer: bufio.NewWriterSize(trans, bufferSize), + }, + tp: trans, + } +} + +func (p *TBufferedTransport) IsOpen() bool { + return p.tp.IsOpen() +} + +func (p *TBufferedTransport) Open() (err error) { + return p.tp.Open() +} + +func (p *TBufferedTransport) Close() (err error) { + return p.tp.Close() +} + +func (p *TBufferedTransport) Flush() error { + if err := p.ReadWriter.Flush(); err != nil { + return err + } + return p.tp.Flush() +} diff --git a/thrift/thrift/buffered_transport_test.go b/thrift/thrift/buffered_transport_test.go new file mode 100644 index 0000000..95ec0cb --- /dev/null +++ b/thrift/thrift/buffered_transport_test.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestBufferedTransport(t *testing.T) { + trans := NewTBufferedTransport(NewTMemoryBuffer(), 10240) + TransportTest(t, trans, trans) +} diff --git a/thrift/thrift/compact_protocol.go b/thrift/thrift/compact_protocol.go new file mode 100644 index 0000000..0857a7a --- /dev/null +++ b/thrift/thrift/compact_protocol.go @@ -0,0 +1,797 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/binary" + "fmt" + "io" + "math" +) + +const ( + COMPACT_PROTOCOL_ID = 0x082 + COMPACT_VERSION = 1 + COMPACT_VERSION_MASK = 0x1f + COMPACT_TYPE_MASK = 0x0E0 + COMPACT_TYPE_BITS = 0x07 + COMPACT_TYPE_SHIFT_AMOUNT = 5 +) + +type tCompactType byte + +const ( + COMPACT_BOOLEAN_TRUE = 0x01 + COMPACT_BOOLEAN_FALSE = 0x02 + COMPACT_BYTE = 0x03 + COMPACT_I16 = 0x04 + COMPACT_I32 = 0x05 + COMPACT_I64 = 0x06 + COMPACT_DOUBLE = 0x07 + COMPACT_BINARY = 0x08 + COMPACT_LIST = 0x09 + COMPACT_SET = 0x0A + COMPACT_MAP = 0x0B + COMPACT_STRUCT = 0x0C +) + +var ( + ttypeToCompactType map[TType]tCompactType +) + +func init() { + ttypeToCompactType = map[TType]tCompactType{ + STOP: STOP, + BOOL: COMPACT_BOOLEAN_TRUE, + BYTE: COMPACT_BYTE, + I16: COMPACT_I16, + I32: COMPACT_I32, + I64: COMPACT_I64, + DOUBLE: COMPACT_DOUBLE, + STRING: COMPACT_BINARY, + LIST: COMPACT_LIST, + SET: COMPACT_SET, + MAP: COMPACT_MAP, + STRUCT: COMPACT_STRUCT, + } +} + +type TCompactProtocolFactory struct{} + +func NewTCompactProtocolFactory() *TCompactProtocolFactory { + return &TCompactProtocolFactory{} +} + +func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTCompactProtocol(trans) +} + +type TCompactProtocol struct { + trans TRichTransport + origTransport TTransport + + // Used to keep track of the last field for the current and previous structs, + // so we can do the delta stuff. + lastField []int + lastFieldId int + + // If we encounter a boolean field begin, save the TField here so it can + // have the value incorporated. + booleanFieldName string + booleanFieldId int16 + booleanFieldPending bool + + // If we read a field header, and it's a boolean field, save the boolean + // value here so that readBool can use it. + boolValue bool + boolValueIsNotNull bool + buffer [64]byte +} + +// Create a TCompactProtocol given a TTransport +func NewTCompactProtocol(trans TTransport) *TCompactProtocol { + p := &TCompactProtocol{origTransport: trans, lastField: []int{}} + if et, ok := trans.(TRichTransport); ok { + p.trans = et + } else { + p.trans = NewTRichTransport(trans) + } + + return p + +} + +// +// Public Writing methods. +// + +// Write a message header to the wire. Compact Protocol messages contain the +// protocol version so we can migrate forwards in the future if need be. +func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { + err := p.writeByteDirect(COMPACT_PROTOCOL_ID) + if err != nil { + return NewTProtocolException(err) + } + err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK)) + if err != nil { + return NewTProtocolException(err) + } + _, err = p.writeVarint32(seqid) + if err != nil { + return NewTProtocolException(err) + } + e := p.WriteString(name) + return e + +} + +func (p *TCompactProtocol) WriteMessageEnd() error { return nil } + +// Write a struct begin. This doesn't actually put anything on the wire. We +// use it as an opportunity to put special placeholder markers on the field +// stack so we can get the field id deltas correct. +func (p *TCompactProtocol) WriteStructBegin(name string) error { + p.lastField = append(p.lastField, p.lastFieldId) + p.lastFieldId = 0 + return nil +} + +// Write a struct end. This doesn't actually put anything on the wire. We use +// this as an opportunity to pop the last field from the current struct off +// of the field stack. +func (p *TCompactProtocol) WriteStructEnd() error { + p.lastFieldId = p.lastField[len(p.lastField)-1] + p.lastField = p.lastField[:len(p.lastField)-1] + return nil +} + +func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + if typeId == BOOL { + // we want to possibly include the value, so we'll wait. + p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true + return nil + } + _, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF) + return NewTProtocolException(err) +} + +// The workhorse of writeFieldBegin. It has the option of doing a +// 'type override' of the type header. This is used specifically in the +// boolean field case. +func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) { + // short lastField = lastField_.pop(); + + // if there's a type override, use that. + var typeToWrite byte + if typeOverride == 0xFF { + typeToWrite = byte(p.getCompactType(typeId)) + } else { + typeToWrite = typeOverride + } + // check if we can use delta encoding for the field id + fieldId := int(id) + written := 0 + if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 { + // write them together + err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite) + if err != nil { + return 0, err + } + } else { + // write them separate + err := p.writeByteDirect(typeToWrite) + if err != nil { + return 0, err + } + err = p.WriteI16(id) + written = 1 + 2 + if err != nil { + return 0, err + } + } + + p.lastFieldId = fieldId + // p.lastField.Push(field.id); + return written, nil +} + +func (p *TCompactProtocol) WriteFieldEnd() error { return nil } + +func (p *TCompactProtocol) WriteFieldStop() error { + err := p.writeByteDirect(STOP) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + if size == 0 { + err := p.writeByteDirect(0) + return NewTProtocolException(err) + } + _, err := p.writeVarint32(int32(size)) + if err != nil { + return NewTProtocolException(err) + } + err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType))) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteMapEnd() error { return nil } + +// Write a list header. +func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error { + _, err := p.writeCollectionBegin(elemType, size) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteListEnd() error { return nil } + +// Write a set header. +func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error { + _, err := p.writeCollectionBegin(elemType, size) + return NewTProtocolException(err) +} + +func (p *TCompactProtocol) WriteSetEnd() error { return nil } + +func (p *TCompactProtocol) WriteBool(value bool) error { + v := byte(COMPACT_BOOLEAN_FALSE) + if value { + v = byte(COMPACT_BOOLEAN_TRUE) + } + if p.booleanFieldPending { + // we haven't written the field header yet + _, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v) + p.booleanFieldPending = false + return NewTProtocolException(err) + } + // we're not part of a field, so just write the value. + err := p.writeByteDirect(v) + return NewTProtocolException(err) +} + +// Write a byte. Nothing to see here! +func (p *TCompactProtocol) WriteByte(value byte) error { + err := p.writeByteDirect(value) + return NewTProtocolException(err) +} + +// Write an I16 as a zigzag varint. +func (p *TCompactProtocol) WriteI16(value int16) error { + _, err := p.writeVarint32(p.int32ToZigzag(int32(value))) + return NewTProtocolException(err) +} + +// Write an i32 as a zigzag varint. +func (p *TCompactProtocol) WriteI32(value int32) error { + _, err := p.writeVarint32(p.int32ToZigzag(value)) + return NewTProtocolException(err) +} + +// Write an i64 as a zigzag varint. +func (p *TCompactProtocol) WriteI64(value int64) error { + _, err := p.writeVarint64(p.int64ToZigzag(value)) + return NewTProtocolException(err) +} + +// Write a double to the wire as 8 bytes. +func (p *TCompactProtocol) WriteDouble(value float64) error { + buf := p.buffer[0:8] + binary.LittleEndian.PutUint64(buf, math.Float64bits(value)) + _, err := p.trans.Write(buf) + return NewTProtocolException(err) +} + +// Write a string to the wire with a varint size preceeding. +func (p *TCompactProtocol) WriteString(value string) error { + _, e := p.writeVarint32(int32(len(value))) + if e != nil { + return NewTProtocolException(e) + } + if len(value) > 0 { + } + _, e = p.trans.WriteString(value) + return e +} + +// Write a byte array, using a varint for the size. +func (p *TCompactProtocol) WriteBinary(bin []byte) error { + _, e := p.writeVarint32(int32(len(bin))) + if e != nil { + return NewTProtocolException(e) + } + if len(bin) > 0 { + _, e = p.trans.Write(bin) + return NewTProtocolException(e) + } + return nil +} + +// +// Reading methods. +// + +// Read a message header. +func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + protocolId, err := p.ReadByte() + if protocolId != COMPACT_PROTOCOL_ID { + e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId) + return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e) + } + versionAndType, err := p.ReadByte() + version := versionAndType & COMPACT_VERSION_MASK + typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS) + if err != nil { + return + } + if version != COMPACT_VERSION { + e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version) + err = NewTProtocolExceptionWithType(BAD_VERSION, e) + return + } + seqId, e := p.readVarint32() + if e != nil { + err = NewTProtocolException(e) + return + } + name, err = p.ReadString() + return +} + +func (p *TCompactProtocol) ReadMessageEnd() error { return nil } + +// Read a struct begin. There's nothing on the wire for this, but it is our +// opportunity to push a new struct begin marker onto the field stack. +func (p *TCompactProtocol) ReadStructBegin() (name string, err error) { + p.lastField = append(p.lastField, p.lastFieldId) + p.lastFieldId = 0 + return +} + +// Doesn't actually consume any wire data, just removes the last field for +// this struct from the field stack. +func (p *TCompactProtocol) ReadStructEnd() error { + // consume the last field we read off the wire. + p.lastFieldId = p.lastField[len(p.lastField)-1] + p.lastField = p.lastField[:len(p.lastField)-1] + return nil +} + +// Read a field header off the wire. +func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { + t, err := p.ReadByte() + if err != nil { + return + } + + // if it's a stop, then we can return immediately, as the struct is over. + if (t & 0x0f) == STOP { + return "", STOP, 0, nil + } + + // mask off the 4 MSB of the type header. it could contain a field id delta. + modifier := int16((t & 0xf0) >> 4) + if modifier == 0 { + // not a delta. look ahead for the zigzag varint field id. + id, err = p.ReadI16() + if err != nil { + return + } + } else { + // has a delta. add the delta to the last read field id. + id = int16(p.lastFieldId) + modifier + } + typeId, e := p.getTType(tCompactType(t & 0x0f)) + if e != nil { + err = NewTProtocolException(e) + return + } + + // if this happens to be a boolean field, the value is encoded in the type + if p.isBoolType(t) { + // save the boolean value in a special instance variable. + p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE) + p.boolValueIsNotNull = true + } + + // push the new field onto the field stack so we can keep the deltas going. + p.lastFieldId = int(id) + return +} + +func (p *TCompactProtocol) ReadFieldEnd() error { return nil } + +// Read a map header off the wire. If the size is zero, skip reading the key +// and value type. This means that 0-length maps will yield TMaps without the +// "correct" types. +func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { + size32, e := p.readVarint32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size32 < 0 { + err = invalidDataLength + return + } + size = int(size32) + + keyAndValueType := byte(STOP) + if size != 0 { + keyAndValueType, err = p.ReadByte() + if err != nil { + return + } + } + keyType, _ = p.getTType(tCompactType(keyAndValueType >> 4)) + valueType, _ = p.getTType(tCompactType(keyAndValueType & 0xf)) + return +} + +func (p *TCompactProtocol) ReadMapEnd() error { return nil } + +// Read a list header off the wire. If the list size is 0-14, the size will +// be packed into the element type header. If it's a longer list, the 4 MSB +// of the element type header will be 0xF, and a varint will follow with the +// true size. +func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) { + size_and_type, err := p.ReadByte() + if err != nil { + return + } + size = int((size_and_type >> 4) & 0x0f) + if size == 15 { + size2, e := p.readVarint32() + if e != nil { + err = NewTProtocolException(e) + return + } + if size2 < 0 { + err = invalidDataLength + return + } + size = int(size2) + } + elemType, e := p.getTType(tCompactType(size_and_type)) + if e != nil { + err = NewTProtocolException(e) + return + } + return +} + +func (p *TCompactProtocol) ReadListEnd() error { return nil } + +// Read a set header off the wire. If the set size is 0-14, the size will +// be packed into the element type header. If it's a longer set, the 4 MSB +// of the element type header will be 0xF, and a varint will follow with the +// true size. +func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) { + return p.ReadListBegin() +} + +func (p *TCompactProtocol) ReadSetEnd() error { return nil } + +// Read a boolean off the wire. If this is a boolean field, the value should +// already have been read during readFieldBegin, so we'll just consume the +// pre-stored value. Otherwise, read a byte. +func (p *TCompactProtocol) ReadBool() (value bool, err error) { + if p.boolValueIsNotNull { + p.boolValueIsNotNull = false + return p.boolValue, nil + } + v, err := p.ReadByte() + return v == COMPACT_BOOLEAN_TRUE, err +} + +// Read a single byte off the wire. Nothing interesting here. +func (p *TCompactProtocol) ReadByte() (value byte, err error) { + value, err = p.trans.ReadByte() + if err != nil { + return 0, NewTProtocolException(err) + } + return +} + +// Read an i16 from the wire as a zigzag varint. +func (p *TCompactProtocol) ReadI16() (value int16, err error) { + v, err := p.ReadI32() + return int16(v), err +} + +// Read an i32 from the wire as a zigzag varint. +func (p *TCompactProtocol) ReadI32() (value int32, err error) { + v, e := p.readVarint32() + if e != nil { + return 0, NewTProtocolException(e) + } + value = p.zigzagToInt32(v) + return value, nil +} + +// Read an i64 from the wire as a zigzag varint. +func (p *TCompactProtocol) ReadI64() (value int64, err error) { + v, e := p.readVarint64() + if e != nil { + return 0, NewTProtocolException(e) + } + value = p.zigzagToInt64(v) + return value, nil +} + +// No magic here - just read a double off the wire. +func (p *TCompactProtocol) ReadDouble() (value float64, err error) { + longBits := p.buffer[0:8] + _, e := io.ReadFull(p.trans, longBits) + if e != nil { + return 0.0, NewTProtocolException(e) + } + return math.Float64frombits(p.bytesToUint64(longBits)), nil +} + +// Reads a []byte (via readBinary), and then UTF-8 decodes it. +func (p *TCompactProtocol) ReadString() (value string, err error) { + length, e := p.readVarint32() + if e != nil { + return "", NewTProtocolException(e) + } + if length < 0 { + return "", invalidDataLength + } + + if length == 0 { + return "", nil + } + var buf []byte + if length <= int32(len(p.buffer)) { + buf = p.buffer[0:length] + } else { + buf = make([]byte, length) + } + _, e = io.ReadFull(p.trans, buf) + return string(buf), NewTProtocolException(e) +} + +// Read a []byte from the wire. +func (p *TCompactProtocol) ReadBinary() (value []byte, err error) { + length, e := p.readVarint32() + if e != nil { + return nil, NewTProtocolException(e) + } + if length == 0 { + return []byte{}, nil + } + if length < 0 { + return nil, invalidDataLength + } + + buf := make([]byte, length) + _, e = io.ReadFull(p.trans, buf) + return buf, NewTProtocolException(e) +} + +func (p *TCompactProtocol) Flush() (err error) { + return NewTProtocolException(p.trans.Flush()) +} + +func (p *TCompactProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TCompactProtocol) Transport() TTransport { + return p.origTransport +} + +// +// Internal writing methods +// + +// Abstract method for writing the start of lists and sets. List and sets on +// the wire differ only by the type indicator. +func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) { + if size <= 14 { + return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType)))) + } + err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType))) + if err != nil { + return 0, err + } + m, err := p.writeVarint32(int32(size)) + return 1 + m, err +} + +// Write an i32 as a varint. Results in 1-5 bytes on the wire. +// TODO(pomack): make a permanent buffer like writeVarint64? +func (p *TCompactProtocol) writeVarint32(n int32) (int, error) { + i32buf := p.buffer[0:5] + idx := 0 + for { + if (n & ^0x7F) == 0 { + i32buf[idx] = byte(n) + idx++ + // p.writeByteDirect(byte(n)); + break + // return; + } else { + i32buf[idx] = byte((n & 0x7F) | 0x80) + idx++ + // p.writeByteDirect(byte(((n & 0x7F) | 0x80))); + u := uint32(n) + n = int32(u >> 7) + } + } + return p.trans.Write(i32buf[0:idx]) +} + +// Write an i64 as a varint. Results in 1-10 bytes on the wire. +func (p *TCompactProtocol) writeVarint64(n int64) (int, error) { + varint64out := p.buffer[0:10] + idx := 0 + for { + if (n & ^0x7F) == 0 { + varint64out[idx] = byte(n) + idx++ + break + } else { + varint64out[idx] = byte((n & 0x7F) | 0x80) + idx++ + u := uint64(n) + n = int64(u >> 7) + } + } + return p.trans.Write(varint64out[0:idx]) +} + +// Convert l into a zigzag long. This allows negative numbers to be +// represented compactly as a varint. +func (p *TCompactProtocol) int64ToZigzag(l int64) int64 { + return (l << 1) ^ (l >> 63) +} + +// Convert l into a zigzag long. This allows negative numbers to be +// represented compactly as a varint. +func (p *TCompactProtocol) int32ToZigzag(n int32) int32 { + return (n << 1) ^ (n >> 31) +} + +func (p *TCompactProtocol) fixedUint64ToBytes(n uint64, buf []byte) { + binary.LittleEndian.PutUint64(buf, n) +} + +func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) { + binary.LittleEndian.PutUint64(buf, uint64(n)) +} + +// Writes a byte without any possiblity of all that field header nonsense. +// Used internally by other writing methods that know they need to write a byte. +func (p *TCompactProtocol) writeByteDirect(b byte) error { + return p.trans.WriteByte(b) +} + +// Writes a byte without any possiblity of all that field header nonsense. +func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) { + return 1, p.writeByteDirect(byte(n)) +} + +// +// Internal reading methods +// + +// Read an i32 from the wire as a varint. The MSB of each byte is set +// if there is another byte to follow. This can read up to 5 bytes. +func (p *TCompactProtocol) readVarint32() (int32, error) { + // if the wire contains the right stuff, this will just truncate the i64 we + // read and get us the right sign. + v, err := p.readVarint64() + return int32(v), err +} + +// Read an i64 from the wire as a proper varint. The MSB of each byte is set +// if there is another byte to follow. This can read up to 10 bytes. +func (p *TCompactProtocol) readVarint64() (int64, error) { + shift := uint(0) + result := int64(0) + for { + b, err := p.ReadByte() + if err != nil { + return 0, err + } + result |= int64(b&0x7f) << shift + if (b & 0x80) != 0x80 { + break + } + shift += 7 + } + return result, nil +} + +// +// encoding helpers +// + +// Convert from zigzag int to int. +func (p *TCompactProtocol) zigzagToInt32(n int32) int32 { + u := uint32(n) + return int32(u>>1) ^ -(n & 1) +} + +// Convert from zigzag long to long. +func (p *TCompactProtocol) zigzagToInt64(n int64) int64 { + u := uint64(n) + return int64(u>>1) ^ -(n & 1) +} + +// Note that it's important that the mask bytes are long literals, +// otherwise they'll default to ints, and when you shift an int left 56 bits, +// you just get a messed up int. +func (p *TCompactProtocol) bytesToInt64(b []byte) int64 { + return int64(binary.LittleEndian.Uint64(b)) +} + +// Note that it's important that the mask bytes are long literals, +// otherwise they'll default to ints, and when you shift an int left 56 bits, +// you just get a messed up int. +func (p *TCompactProtocol) bytesToUint64(b []byte) uint64 { + return binary.LittleEndian.Uint64(b) +} + +// +// type testing and converting +// + +func (p *TCompactProtocol) isBoolType(b byte) bool { + return (b&0x0f) == COMPACT_BOOLEAN_TRUE || (b&0x0f) == COMPACT_BOOLEAN_FALSE +} + +// Given a tCompactType constant, convert it to its corresponding +// TType value. +func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) { + switch byte(t) & 0x0f { + case STOP: + return STOP, nil + case COMPACT_BOOLEAN_FALSE, COMPACT_BOOLEAN_TRUE: + return BOOL, nil + case COMPACT_BYTE: + return BYTE, nil + case COMPACT_I16: + return I16, nil + case COMPACT_I32: + return I32, nil + case COMPACT_I64: + return I64, nil + case COMPACT_DOUBLE: + return DOUBLE, nil + case COMPACT_BINARY: + return STRING, nil + case COMPACT_LIST: + return LIST, nil + case COMPACT_SET: + return SET, nil + case COMPACT_MAP: + return MAP, nil + case COMPACT_STRUCT: + return STRUCT, nil + } + return STOP, TException(fmt.Errorf("don't know what type: %s", t&0x0f)) +} + +// Given a TType value, find the appropriate TCompactProtocol.Types constant. +func (p *TCompactProtocol) getCompactType(t TType) tCompactType { + return ttypeToCompactType[t] +} diff --git a/thrift/thrift/compact_protocol_test.go b/thrift/thrift/compact_protocol_test.go new file mode 100644 index 0000000..f940b4e --- /dev/null +++ b/thrift/thrift/compact_protocol_test.go @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +func TestReadWriteCompactProtocol(t *testing.T) { + ReadWriteProtocolTest(t, NewTCompactProtocolFactory()) + transports := []TTransport{ + NewTMemoryBuffer(), + NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))), + NewTFramedTransport(NewTMemoryBuffer()), + } + for _, trans := range transports { + p := NewTCompactProtocol(trans) + ReadWriteBool(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteByte(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI16(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI32(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI64(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteDouble(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteString(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteBinary(t, p, trans) + trans.Close() + } +} diff --git a/thrift/thrift/debug_protocol.go b/thrift/thrift/debug_protocol.go new file mode 100644 index 0000000..ee341b2 --- /dev/null +++ b/thrift/thrift/debug_protocol.go @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "log" +) + +type TDebugProtocol struct { + Delegate TProtocol + LogPrefix string +} + +type TDebugProtocolFactory struct { + Underlying TProtocolFactory + LogPrefix string +} + +func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory { + return &TDebugProtocolFactory{ + Underlying: underlying, + LogPrefix: logPrefix, + } +} + +func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return &TDebugProtocol{ + Delegate: t.Underlying.GetProtocol(trans), + LogPrefix: t.LogPrefix, + } +} + +func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { + err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid) + log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err) + return err +} +func (tdp *TDebugProtocol) WriteMessageEnd() error { + err := tdp.Delegate.WriteMessageEnd() + log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteStructBegin(name string) error { + err := tdp.Delegate.WriteStructBegin(name) + log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err) + return err +} +func (tdp *TDebugProtocol) WriteStructEnd() error { + err := tdp.Delegate.WriteStructEnd() + log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + err := tdp.Delegate.WriteFieldBegin(name, typeId, id) + log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err) + return err +} +func (tdp *TDebugProtocol) WriteFieldEnd() error { + err := tdp.Delegate.WriteFieldEnd() + log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteFieldStop() error { + err := tdp.Delegate.WriteFieldStop() + log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + err := tdp.Delegate.WriteMapBegin(keyType, valueType, size) + log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err) + return err +} +func (tdp *TDebugProtocol) WriteMapEnd() error { + err := tdp.Delegate.WriteMapEnd() + log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error { + err := tdp.Delegate.WriteListBegin(elemType, size) + log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) + return err +} +func (tdp *TDebugProtocol) WriteListEnd() error { + err := tdp.Delegate.WriteListEnd() + log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error { + err := tdp.Delegate.WriteSetBegin(elemType, size) + log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) + return err +} +func (tdp *TDebugProtocol) WriteSetEnd() error { + err := tdp.Delegate.WriteSetEnd() + log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err) + return err +} +func (tdp *TDebugProtocol) WriteBool(value bool) error { + err := tdp.Delegate.WriteBool(value) + log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteByte(value byte) error { + err := tdp.Delegate.WriteByte(value) + log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteI16(value int16) error { + err := tdp.Delegate.WriteI16(value) + log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteI32(value int32) error { + err := tdp.Delegate.WriteI32(value) + log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteI64(value int64) error { + err := tdp.Delegate.WriteI64(value) + log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteDouble(value float64) error { + err := tdp.Delegate.WriteDouble(value) + log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteString(value string) error { + err := tdp.Delegate.WriteString(value) + log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} +func (tdp *TDebugProtocol) WriteBinary(value []byte) error { + err := tdp.Delegate.WriteBinary(value) + log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err) + return err +} + +func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { + name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin() + log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err) + return +} +func (tdp *TDebugProtocol) ReadMessageEnd() (err error) { + err = tdp.Delegate.ReadMessageEnd() + log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) { + name, err = tdp.Delegate.ReadStructBegin() + log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err) + return +} +func (tdp *TDebugProtocol) ReadStructEnd() (err error) { + err = tdp.Delegate.ReadStructEnd() + log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { + name, typeId, id, err = tdp.Delegate.ReadFieldBegin() + log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err) + return +} +func (tdp *TDebugProtocol) ReadFieldEnd() (err error) { + err = tdp.Delegate.ReadFieldEnd() + log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { + keyType, valueType, size, err = tdp.Delegate.ReadMapBegin() + log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err) + return +} +func (tdp *TDebugProtocol) ReadMapEnd() (err error) { + err = tdp.Delegate.ReadMapEnd() + log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadListBegin() + log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err) + return +} +func (tdp *TDebugProtocol) ReadListEnd() (err error) { + err = tdp.Delegate.ReadListEnd() + log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadSetBegin() + log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err) + return +} +func (tdp *TDebugProtocol) ReadSetEnd() (err error) { + err = tdp.Delegate.ReadSetEnd() + log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err) + return +} +func (tdp *TDebugProtocol) ReadBool() (value bool, err error) { + value, err = tdp.Delegate.ReadBool() + log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadByte() (value byte, err error) { + value, err = tdp.Delegate.ReadByte() + log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadI16() (value int16, err error) { + value, err = tdp.Delegate.ReadI16() + log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadI32() (value int32, err error) { + value, err = tdp.Delegate.ReadI32() + log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadI64() (value int64, err error) { + value, err = tdp.Delegate.ReadI64() + log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) { + value, err = tdp.Delegate.ReadDouble() + log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadString() (value string, err error) { + value, err = tdp.Delegate.ReadString() + log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) { + value, err = tdp.Delegate.ReadBinary() + log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) + return +} +func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) { + err = tdp.Delegate.Skip(fieldType) + log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err) + return +} +func (tdp *TDebugProtocol) Flush() (err error) { + err = tdp.Delegate.Flush() + log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err) + return +} + +func (tdp *TDebugProtocol) Transport() TTransport { + return tdp.Delegate.Transport() +} diff --git a/thrift/thrift/deserializer.go b/thrift/thrift/deserializer.go new file mode 100644 index 0000000..91a0983 --- /dev/null +++ b/thrift/thrift/deserializer.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +type TDeserializer struct { + Transport TTransport + Protocol TProtocol +} + +func NewTDeserializer() *TDeserializer { + var transport TTransport + transport = NewTMemoryBufferLen(1024) + + protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport) + + return &TDeserializer{ + transport, + protocol} +} + +func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) { + err = nil + if _, err = t.Transport.Write([]byte(s)); err != nil { + return + } + if err = msg.Read(t.Protocol); err != nil { + return + } + return +} + +func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) { + err = nil + if _, err = t.Transport.Write(b); err != nil { + return + } + if err = msg.Read(t.Protocol); err != nil { + return + } + return +} diff --git a/thrift/thrift/exception.go b/thrift/thrift/exception.go new file mode 100644 index 0000000..e08ffc0 --- /dev/null +++ b/thrift/thrift/exception.go @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Generic Thrift exception +type TException interface { + error +} diff --git a/thrift/thrift/field.go b/thrift/thrift/field.go new file mode 100644 index 0000000..9d66525 --- /dev/null +++ b/thrift/thrift/field.go @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Helper class that encapsulates field metadata. +type field struct { + name string + typeId TType + id int +} + +func newField(n string, t TType, i int) *field { + return &field{name: n, typeId: t, id: i} +} + +func (p *field) Name() string { + if p == nil { + return "" + } + return p.name +} + +func (p *field) TypeId() TType { + if p == nil { + return TType(VOID) + } + return p.typeId +} + +func (p *field) Id() int { + if p == nil { + return -1 + } + return p.id +} + +func (p *field) String() string { + if p == nil { + return "" + } + return "" +} + +var ANONYMOUS_FIELD *field + +type fieldSlice []field + +func (p fieldSlice) Len() int { + return len(p) +} + +func (p fieldSlice) Less(i, j int) bool { + return p[i].Id() < p[j].Id() +} + +func (p fieldSlice) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +func init() { + ANONYMOUS_FIELD = newField("", STOP, 0) +} diff --git a/thrift/thrift/framed_transport.go b/thrift/thrift/framed_transport.go new file mode 100644 index 0000000..bfecbe8 --- /dev/null +++ b/thrift/thrift/framed_transport.go @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" +) + +const DEFAULT_MAX_LENGTH = 16384000 + +type TFramedTransport struct { + transport TTransport + buf bytes.Buffer + reader *bufio.Reader + frameSize int //Current remaining size of the frame. if ==0 read next frame header + buffer [4]byte + maxLength int +} + +type tFramedTransportFactory struct { + factory TTransportFactory + maxLength int +} + +func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory { + return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH} +} + +func (p *tFramedTransportFactory) GetTransport(base TTransport) TTransport { + return NewTFramedTransportMaxLength(p.factory.GetTransport(base), p.maxLength) +} + +func NewTFramedTransport(transport TTransport) *TFramedTransport { + return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH} +} + +func NewTFramedTransportMaxLength(transport TTransport, maxLength int) *TFramedTransport { + return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength} +} + +func (p *TFramedTransport) Open() error { + return p.transport.Open() +} + +func (p *TFramedTransport) IsOpen() bool { + return p.transport.IsOpen() +} + +func (p *TFramedTransport) Close() error { + return p.transport.Close() +} + +func (p *TFramedTransport) Read(buf []byte) (l int, err error) { + if p.frameSize == 0 { + p.frameSize, err = p.readFrameHeader() + if err != nil { + return + } + } + if p.frameSize < len(buf) { + return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enought frame size %d to read %d bytes", p.frameSize, len(buf))) + } + got, err := p.reader.Read(buf) + p.frameSize = p.frameSize - got + //sanity check + if p.frameSize < 0 { + return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size") + } + return got, NewTTransportExceptionFromError(err) +} + +func (p *TFramedTransport) ReadByte() (c byte, err error) { + if p.frameSize == 0 { + p.frameSize, err = p.readFrameHeader() + if err != nil { + return + } + } + if p.frameSize < 1 { + return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enought frame size %d to read %d bytes", p.frameSize, 1)) + } + c, err = p.reader.ReadByte() + if err == nil { + p.frameSize-- + } + return +} + +func (p *TFramedTransport) Write(buf []byte) (int, error) { + n, err := p.buf.Write(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TFramedTransport) WriteByte(c byte) error { + return p.buf.WriteByte(c) +} + +func (p *TFramedTransport) WriteString(s string) (n int, err error) { + return p.buf.WriteString(s) +} + +func (p *TFramedTransport) Flush() error { + size := p.buf.Len() + buf := p.buffer[:4] + binary.BigEndian.PutUint32(buf, uint32(size)) + _, err := p.transport.Write(buf) + if err != nil { + return NewTTransportExceptionFromError(err) + } + if size > 0 { + if n, err := p.buf.WriteTo(p.transport); err != nil { + print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n") + return NewTTransportExceptionFromError(err) + } + } + err = p.transport.Flush() + return NewTTransportExceptionFromError(err) +} + +func (p *TFramedTransport) readFrameHeader() (int, error) { + buf := p.buffer[:4] + if _, err := io.ReadFull(p.reader, buf); err != nil { + return 0, err + } + size := int(binary.BigEndian.Uint32(buf)) + if size < 0 || size > p.maxLength { + return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size)) + } + return size, nil +} diff --git a/thrift/thrift/framed_transport_test.go b/thrift/thrift/framed_transport_test.go new file mode 100644 index 0000000..8f683ef --- /dev/null +++ b/thrift/thrift/framed_transport_test.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestFramedTransport(t *testing.T) { + trans := NewTFramedTransport(NewTMemoryBuffer()) + TransportTest(t, trans, trans) +} diff --git a/thrift/thrift/http_client.go b/thrift/thrift/http_client.go new file mode 100644 index 0000000..df66897 --- /dev/null +++ b/thrift/thrift/http_client.go @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "io" + "net/http" + "net/url" + "strconv" +) + +type THttpClient struct { + response *http.Response + url *url.URL + requestBuffer *bytes.Buffer + header http.Header + nsecConnectTimeout int64 + nsecReadTimeout int64 +} + +type THttpClientTransportFactory struct { + url string + isPost bool +} + +func (p *THttpClientTransportFactory) GetTransport(trans TTransport) TTransport { + if trans != nil { + t, ok := trans.(*THttpClient) + if ok && t.url != nil { + if t.requestBuffer != nil { + t2, _ := NewTHttpPostClient(t.url.String()) + return t2 + } + t2, _ := NewTHttpClient(t.url.String()) + return t2 + } + } + if p.isPost { + s, _ := NewTHttpPostClient(p.url) + return s + } + s, _ := NewTHttpClient(p.url) + return s +} + +func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory { + return &THttpClientTransportFactory{url: url, isPost: false} +} + +func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory { + return &THttpClientTransportFactory{url: url, isPost: true} +} + +func NewTHttpClient(urlstr string) (TTransport, error) { + parsedURL, err := url.Parse(urlstr) + if err != nil { + return nil, err + } + response, err := http.Get(urlstr) + if err != nil { + return nil, err + } + return &THttpClient{response: response, url: parsedURL}, nil +} + +func NewTHttpPostClient(urlstr string) (TTransport, error) { + parsedURL, err := url.Parse(urlstr) + if err != nil { + return nil, err + } + buf := make([]byte, 0, 1024) + return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil +} + +// Set the HTTP Header for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// httpTrans.SetHeader("User-Agent","Thrift Client 1.0") +func (p *THttpClient) SetHeader(key string, value string) { + p.header.Add(key, value) +} + +// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// hdrValue := httpTrans.GetHeader("User-Agent") +func (p *THttpClient) GetHeader(key string) string { + return p.header.Get(key) +} + +// Deletes the HTTP Header given a Header Key for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// httpTrans.DelHeader("User-Agent") +func (p *THttpClient) DelHeader(key string) { + p.header.Del(key) +} + +func (p *THttpClient) Open() error { + // do nothing + return nil +} + +func (p *THttpClient) IsOpen() bool { + return p.response != nil || p.requestBuffer != nil +} + +func (p *THttpClient) Peek() bool { + return p.IsOpen() +} + +func (p *THttpClient) Close() error { + if p.response != nil && p.response.Body != nil { + err := p.response.Body.Close() + p.response = nil + return err + } + if p.requestBuffer != nil { + p.requestBuffer.Reset() + p.requestBuffer = nil + } + return nil +} + +func (p *THttpClient) Read(buf []byte) (int, error) { + if p.response == nil { + return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.") + } + n, err := p.response.Body.Read(buf) + if n > 0 && (err == nil || err == io.EOF) { + return n, nil + } + return n, NewTTransportExceptionFromError(err) +} + +func (p *THttpClient) ReadByte() (c byte, err error) { + return readByte(p.response.Body) +} + +func (p *THttpClient) Write(buf []byte) (int, error) { + n, err := p.requestBuffer.Write(buf) + return n, err +} + +func (p *THttpClient) WriteByte(c byte) error { + return p.requestBuffer.WriteByte(c) +} + +func (p *THttpClient) WriteString(s string) (n int, err error) { + return p.requestBuffer.WriteString(s) +} + +func (p *THttpClient) Flush() error { + client := &http.Client{} + req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) + if err != nil { + return NewTTransportExceptionFromError(err) + } + p.header.Add("Content-Type", "application/x-thrift") + req.Header = p.header + response, err := client.Do(req) + if err != nil { + return NewTTransportExceptionFromError(err) + } + if response.StatusCode != http.StatusOK { + // TODO(pomack) log bad response + return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode)) + } + p.response = response + return nil +} diff --git a/thrift/thrift/http_client_test.go b/thrift/thrift/http_client_test.go new file mode 100644 index 0000000..0c2cb28 --- /dev/null +++ b/thrift/thrift/http_client_test.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestHttpClient(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportTest(t, trans, trans) +} + +func TestHttpClientHeaders(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportHeaderTest(t, trans, trans) +} diff --git a/thrift/thrift/iostream_transport.go b/thrift/thrift/iostream_transport.go new file mode 100644 index 0000000..314eaa6 --- /dev/null +++ b/thrift/thrift/iostream_transport.go @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "io" +) + +// StreamTransport is a Transport made of an io.Reader and/or an io.Writer +type StreamTransport struct { + io.Reader + io.Writer + isReadWriter bool +} + +type StreamTransportFactory struct { + Reader io.Reader + Writer io.Writer + isReadWriter bool +} + +func (p *StreamTransportFactory) GetTransport(trans TTransport) TTransport { + if trans != nil { + t, ok := trans.(*StreamTransport) + if ok { + if t.isReadWriter { + return NewStreamTransportRW(t.Reader.(io.ReadWriter)) + } + if t.Reader != nil && t.Writer != nil { + return NewStreamTransport(t.Reader, t.Writer) + } + if t.Reader != nil && t.Writer == nil { + return NewStreamTransportR(t.Reader) + } + if t.Reader == nil && t.Writer != nil { + return NewStreamTransportW(t.Writer) + } + return &StreamTransport{} + } + } + if p.isReadWriter { + return NewStreamTransportRW(p.Reader.(io.ReadWriter)) + } + if p.Reader != nil && p.Writer != nil { + return NewStreamTransport(p.Reader, p.Writer) + } + if p.Reader != nil && p.Writer == nil { + return NewStreamTransportR(p.Reader) + } + if p.Reader == nil && p.Writer != nil { + return NewStreamTransportW(p.Writer) + } + return &StreamTransport{} +} + +func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory { + return &StreamTransportFactory{Reader: reader, Writer: writer, isReadWriter: isReadWriter} +} + +func NewStreamTransport(r io.Reader, w io.Writer) *StreamTransport { + return &StreamTransport{Reader: bufio.NewReader(r), Writer: bufio.NewWriter(w)} +} + +func NewStreamTransportR(r io.Reader) *StreamTransport { + return &StreamTransport{Reader: bufio.NewReader(r)} +} + +func NewStreamTransportW(w io.Writer) *StreamTransport { + return &StreamTransport{Writer: bufio.NewWriter(w)} +} + +func NewStreamTransportRW(rw io.ReadWriter) *StreamTransport { + bufrw := bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw)) + return &StreamTransport{Reader: bufrw, Writer: bufrw, isReadWriter: true} +} + +// (The streams must already be open at construction time, so this should +// always return true.) +func (p *StreamTransport) IsOpen() bool { + return true +} + +// (The streams must already be open. This method does nothing.) +func (p *StreamTransport) Open() error { + return nil +} + +// func (p *StreamTransport) Peek() bool { +// return p.IsOpen() +// } + +// Closes both the input and output streams. +func (p *StreamTransport) Close() error { + closedReader := false + if p.Reader != nil { + c, ok := p.Reader.(io.Closer) + if ok { + e := c.Close() + closedReader = true + if e != nil { + return e + } + } + p.Reader = nil + } + if p.Writer != nil && (!closedReader || !p.isReadWriter) { + c, ok := p.Writer.(io.Closer) + if ok { + e := c.Close() + if e != nil { + return e + } + } + p.Writer = nil + } + return nil +} + +// Flushes the underlying output stream if not null. +func (p *StreamTransport) Flush() error { + if p.Writer == nil { + return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream") + } + f, ok := p.Writer.(Flusher) + if ok { + err := f.Flush() + if err != nil { + return NewTTransportExceptionFromError(err) + } + } + return nil +} + +func (p *StreamTransport) Read(c []byte) (n int, err error) { + n, err = p.Reader.Read(c) + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) ReadByte() (c byte, err error) { + f, ok := p.Reader.(io.ByteReader) + if ok { + c, err = f.ReadByte() + } else { + c, err = readByte(p.Reader) + } + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) Write(c []byte) (n int, err error) { + n, err = p.Writer.Write(c) + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) WriteByte(c byte) (err error) { + f, ok := p.Writer.(io.ByteWriter) + if ok { + err = f.WriteByte(c) + } else { + err = writeByte(p.Writer, c) + } + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} + +func (p *StreamTransport) WriteString(s string) (n int, err error) { + f, ok := p.Writer.(stringWriter) + if ok { + n, err = f.WriteString(s) + } else { + n, err = p.Writer.Write([]byte(s)) + } + if err != nil { + err = NewTTransportExceptionFromError(err) + } + return +} diff --git a/thrift/thrift/iostream_transport_test.go b/thrift/thrift/iostream_transport_test.go new file mode 100644 index 0000000..15ea2d4 --- /dev/null +++ b/thrift/thrift/iostream_transport_test.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +func TestStreamTransport(t *testing.T) { + trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024))) + TransportTest(t, trans, trans) +} diff --git a/thrift/thrift/json_protocol.go b/thrift/thrift/json_protocol.go new file mode 100644 index 0000000..10d8ca7 --- /dev/null +++ b/thrift/thrift/json_protocol.go @@ -0,0 +1,566 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/base64" + "fmt" +) + +const ( + THRIFT_JSON_PROTOCOL_VERSION = 1 +) + +// for references to _ParseContext see tsimplejson_protocol.go + +// JSON protocol implementation for thrift. +// +// This protocol produces/consumes a simple output format +// suitable for parsing by scripting languages. It should not be +// confused with the full-featured TJSONProtocol. +// +type TJSONProtocol struct { + *TSimpleJSONProtocol +} + +// Constructor +func NewTJSONProtocol(t TTransport) *TJSONProtocol { + v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)} + v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL)) + v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL)) + return v +} + +// Factory +type TJSONProtocolFactory struct{} + +func (p *TJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTJSONProtocol(trans) +} + +func NewTJSONProtocolFactory() *TJSONProtocolFactory { + return &TJSONProtocolFactory{} +} + +func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil { + return e + } + if e := p.WriteString(name); e != nil { + return e + } + if e := p.WriteByte(byte(typeId)); e != nil { + return e + } + if e := p.WriteI32(seqId); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) WriteMessageEnd() error { + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteStructBegin(name string) error { + if e := p.OutputObjectBegin(); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) WriteStructEnd() error { + return p.OutputObjectEnd() +} + +func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + if e := p.WriteI16(id); e != nil { + return e + } + if e := p.OutputObjectBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(typeId) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) WriteFieldEnd() error { + return p.OutputObjectEnd() +} + +func (p *TJSONProtocol) WriteFieldStop() error { return nil } + +func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(keyType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + s, e1 = p.TypeIdToString(valueType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + if e := p.WriteI64(int64(size)); e != nil { + return e + } + return p.OutputObjectBegin() +} + +func (p *TJSONProtocol) WriteMapEnd() error { + if e := p.OutputObjectEnd(); e != nil { + return e + } + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TJSONProtocol) WriteListEnd() error { + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TJSONProtocol) WriteSetEnd() error { + return p.OutputListEnd() +} + +func (p *TJSONProtocol) WriteBool(b bool) error { + if b { + return p.WriteI32(1) + } + return p.WriteI32(0) +} + +func (p *TJSONProtocol) WriteByte(b byte) error { + return p.WriteI32(int32(b)) +} + +func (p *TJSONProtocol) WriteI16(v int16) error { + return p.WriteI32(int32(v)) +} + +func (p *TJSONProtocol) WriteI32(v int32) error { + return p.OutputI64(int64(v)) +} + +func (p *TJSONProtocol) WriteI64(v int64) error { + return p.OutputI64(int64(v)) +} + +func (p *TJSONProtocol) WriteDouble(v float64) error { + return p.OutputF64(v) +} + +func (p *TJSONProtocol) WriteString(v string) error { + return p.OutputString(v) +} + +func (p *TJSONProtocol) WriteBinary(v []byte) error { + // JSON library only takes in a string, + // not an arbitrary byte array, to ensure bytes are transmitted + // efficiently we must convert this into a valid JSON string + // therefore we use base64 encoding to avoid excessive escaping/quoting + if e := p.OutputPreValue(); e != nil { + return e + } + p.writer.Write(JSON_QUOTE_BYTES) + writer := base64.NewEncoder(base64.StdEncoding, p.writer) + if _, e := writer.Write(v); e != nil { + return NewTProtocolException(e) + } + writer.Close() + p.writer.Write(JSON_QUOTE_BYTES) + return p.OutputPostValue() +} + +// Reading methods. + +func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, typeId, seqId, err + } + version, err := p.ReadI32() + if err != nil { + return name, typeId, seqId, err + } + if version != THRIFT_JSON_PROTOCOL_VERSION { + e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION) + return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e) + + } + if name, err = p.ReadString(); err != nil { + return name, typeId, seqId, err + } + bTypeId, err := p.ReadByte() + typeId = TMessageType(bTypeId) + if err != nil { + return name, typeId, seqId, err + } + if seqId, err = p.ReadI32(); err != nil { + return name, typeId, seqId, err + } + return name, typeId, seqId, nil +} + +func (p *TJSONProtocol) ReadMessageEnd() error { + err := p.ParseListEnd() + return err +} + +func (p *TJSONProtocol) ReadStructBegin() (name string, err error) { + _, err = p.ParseObjectStart() + return "", err +} + +func (p *TJSONProtocol) ReadStructEnd() error { + return p.ParseObjectEnd() +} + +func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { + if p.reader.Buffered() < 1 { + return "", STOP, -1, nil + } + b, _ := p.reader.Peek(1) + if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] { + return "", STOP, -1, nil + } + fieldId, err := p.ReadI16() + if err != nil { + return "", STOP, fieldId, err + } + if _, err = p.ParseObjectStart(); err != nil { + return "", STOP, fieldId, err + } + sType, err := p.ReadString() + if err != nil { + return "", STOP, fieldId, err + } + fType, err := p.StringToTypeId(sType) + return "", fType, fieldId, err +} + +func (p *TJSONProtocol) ReadFieldEnd() error { + return p.ParseObjectEnd() +} + +func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, VOID, 0, e + } + + // read keyType + sKeyType, e := p.ReadString() + if e != nil { + return keyType, valueType, size, e + } + keyType, e = p.StringToTypeId(sKeyType) + if e != nil { + return keyType, valueType, size, e + } + + // read valueType + sValueType, e := p.ReadString() + if e != nil { + return keyType, valueType, size, e + } + valueType, e = p.StringToTypeId(sValueType) + if e != nil { + return keyType, valueType, size, e + } + + // read size + iSize, err := p.ReadI64() + if err != nil { + return keyType, valueType, size, err + } + size = int(iSize) + _, err = p.ParseObjectStart() + return keyType, valueType, size, err +} + +func (p *TJSONProtocol) ReadMapEnd() error { + if err := p.ParseObjectEnd(); err != nil { + return err + } + return p.ParseListEnd() +} + +func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TJSONProtocol) ReadListEnd() error { + return p.ParseListEnd() +} + +func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TJSONProtocol) ReadSetEnd() error { + return p.ParseListEnd() +} + +func (p *TJSONProtocol) ReadBool() (bool, error) { + value, err := p.ReadI32() + return (value != 0), err +} + +func (p *TJSONProtocol) ReadByte() (byte, error) { + v, err := p.ReadI64() + return byte(v), err +} + +func (p *TJSONProtocol) ReadI16() (int16, error) { + v, err := p.ReadI64() + return int16(v), err +} + +func (p *TJSONProtocol) ReadI32() (int32, error) { + v, err := p.ReadI64() + return int32(v), err +} + +func (p *TJSONProtocol) ReadI64() (int64, error) { + v, _, err := p.ParseI64() + return v, err +} + +func (p *TJSONProtocol) ReadDouble() (float64, error) { + v, _, err := p.ParseF64() + return v, err +} + +func (p *TJSONProtocol) ReadString() (string, error) { + var v string + if err := p.ParsePreValue(); err != nil { + return v, err + } + b, _ := p.reader.Peek(len(JSON_NULL)) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseStringBody() + v = value + if err != nil { + return v, err + } + } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) { + _, err := p.reader.Read(b[0:len(JSON_NULL)]) + if err != nil { + return v, NewTProtocolException(err) + } + } else { + e := fmt.Errorf("Expected a JSON string, found %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TJSONProtocol) ReadBinary() ([]byte, error) { + var v []byte + if err := p.ParsePreValue(); err != nil { + return nil, err + } + b, _ := p.reader.Peek(len(JSON_NULL)) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseBase64EncodedBody() + v = value + if err != nil { + return v, err + } + } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) { + _, err := p.reader.Read(b[0:len(JSON_NULL)]) + if err != nil { + return v, NewTProtocolException(err) + } + } else { + e := fmt.Errorf("Expected a JSON string, found %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TJSONProtocol) Flush() (err error) { + err = p.writer.Flush() + if err == nil { + err = p.trans.Flush() + } + // we nee pass back the sds specific transport error + return err +} + +func (p *TJSONProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TJSONProtocol) Transport() TTransport { + return p.trans +} + +func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(elemType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { + return e + } + if e := p.WriteI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + sElemType, err := p.ReadString() + if err != nil { + return VOID, size, err + } + elemType, err = p.StringToTypeId(sElemType) + if err != nil { + return elemType, size, err + } + nSize, err2 := p.ReadI64() + size = int(nSize) + return elemType, size, err2 +} + +func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + sElemType, err := p.ReadString() + if err != nil { + return VOID, size, err + } + elemType, err = p.StringToTypeId(sElemType) + if err != nil { + return elemType, size, err + } + nSize, err2 := p.ReadI64() + size = int(nSize) + return elemType, size, err2 +} + +func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + s, e1 := p.TypeIdToString(elemType) + if e1 != nil { + return e1 + } + if e := p.OutputString(s); e != nil { + return e + } + if e := p.OutputI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) { + switch byte(fieldType) { + case BOOL: + return "tf", nil + case BYTE: + return "i8", nil + case I16: + return "i16", nil + case I32: + return "i32", nil + case I64: + return "i64", nil + case DOUBLE: + return "dbl", nil + case STRING: + return "str", nil + case STRUCT: + return "rec", nil + case MAP: + return "map", nil + case SET: + return "set", nil + case LIST: + return "lst", nil + } + + e := fmt.Errorf("Unknown fieldType: %d", int(fieldType)) + return "", NewTProtocolExceptionWithType(INVALID_DATA, e) +} + +func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) { + switch fieldType { + case "tf": + return TType(BOOL), nil + case "i8": + return TType(BYTE), nil + case "i16": + return TType(I16), nil + case "i32": + return TType(I32), nil + case "i64": + return TType(I64), nil + case "dbl": + return TType(DOUBLE), nil + case "str": + return TType(STRING), nil + case "rec": + return TType(STRUCT), nil + case "map": + return TType(MAP), nil + case "set": + return TType(SET), nil + case "lst": + return TType(LIST), nil + } + + e := fmt.Errorf("Unknown type identifier: %s", fieldType) + return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e) +} diff --git a/thrift/thrift/json_protocol_test.go b/thrift/thrift/json_protocol_test.go new file mode 100644 index 0000000..8542a96 --- /dev/null +++ b/thrift/thrift/json_protocol_test.go @@ -0,0 +1,646 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "math" + "strconv" + "testing" +) + +func TestWriteJSONProtocolBool(t *testing.T) { + thetype := "boolean" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range BOOL_VALUES { + if e := p.WriteBool(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + expected := "" + if value { + expected = "1" + } else { + expected = "0" + } + if s != expected { + t.Fatalf("Bad value for %s %v: %s expected", thetype, value, s) + } + v := -1 + if err := json.Unmarshal([]byte(s), &v); err != nil || (v != 0) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolBool(t *testing.T) { + thetype := "boolean" + for _, value := range BOOL_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + if value { + trans.Write([]byte{'1'}) // not JSON_TRUE + } else { + trans.Write([]byte{'0'}) // not JSON_FALSE + } + trans.Flush() + s := trans.String() + v, e := p.ReadBool() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + vv := -1 + if err := json.Unmarshal([]byte(s), &vv); err != nil || (vv != 0) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, vv) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolByte(t *testing.T) { + thetype := "byte" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range BYTE_VALUES { + if e := p.WriteByte(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := byte(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolByte(t *testing.T) { + thetype := "byte" + for _, value := range BYTE_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush() + s := trans.String() + v, e := p.ReadByte() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolI16(t *testing.T) { + thetype := "int16" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range INT16_VALUES { + if e := p.WriteI16(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int16(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolI16(t *testing.T) { + thetype := "int16" + for _, value := range INT16_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush() + s := trans.String() + v, e := p.ReadI16() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolI32(t *testing.T) { + thetype := "int32" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range INT32_VALUES { + if e := p.WriteI32(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int32(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolI32(t *testing.T) { + thetype := "int32" + for _, value := range INT32_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush() + s := trans.String() + v, e := p.ReadI32() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolI64(t *testing.T) { + thetype := "int64" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range INT64_VALUES { + if e := p.WriteI64(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolI64(t *testing.T) { + thetype := "int64" + for _, value := range INT64_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(strconv.FormatInt(value, 10)) + trans.Flush() + s := trans.String() + v, e := p.ReadI64() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolDouble(t *testing.T) { + thetype := "double" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if math.IsInf(value, 1) { + if s != jsonQuote(JSON_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, -1) { + if s != jsonQuote(JSON_NEGATIVE_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if s != jsonQuote(JSON_NAN) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN)) + } + } else { + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolDouble(t *testing.T) { + thetype := "double" + for _, value := range DOUBLE_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + n := NewNumericFromDouble(value) + trans.WriteString(n.String()) + trans.Flush() + s := trans.String() + v, e := p.ReadDouble() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if math.IsInf(value, 1) { + if !math.IsInf(v, 1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsInf(value, -1) { + if !math.IsInf(v, -1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsNaN(value) { + if !math.IsNaN(v) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else { + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolString(t *testing.T) { + thetype := "string" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + for _, value := range STRING_VALUES { + if e := p.WriteString(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s[0] != '"' || s[len(s)-1] != '"' { + t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\"")) + } + v := new(string) + if err := json.Unmarshal([]byte(s), v); err != nil || *v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadJSONProtocolString(t *testing.T) { + thetype := "string" + for _, value := range STRING_VALUES { + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(jsonQuote(value)) + trans.Flush() + s := trans.String() + v, e := p.ReadString() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + if e := p.WriteBinary(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + expectedString := fmt.Sprint("\"", b64String, "\"") + if s != expectedString { + t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString) + } + v1, err := p.ReadBinary() + if err != nil { + t.Fatalf("Unable to read binary: %s", err.Error()) + } + if len(v1) != len(value) { + t.Fatalf("Invalid value for binary\nexpected: \"%v\"\n read: \"%v\"", value, v1) + } + for k, v := range value { + if v1[k] != v { + t.Fatalf("Invalid value for binary at %v\nexpected: \"%v\"\n read: \"%v\"", k, v, v1[k]) + } + } + trans.Close() +} + +func TestReadJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + trans.WriteString(jsonQuote(b64String)) + trans.Flush() + s := trans.String() + v, e := p.ReadBinary() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if len(v) != len(value) { + t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) + } + for i := 0; i < len(v); i++ { + if v[i] != value[i] { + t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) + } + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() +} + +func TestWriteJSONProtocolList(t *testing.T) { + thetype := "list" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteListEnd() + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("List must be at least of length two to include metadata") + } + if l[0] != "dbl" { + t.Fatal("Invalid type for list, expected: ", STRING, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteJSONProtocolSet(t *testing.T) { + thetype := "set" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteSetEnd() + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("Set must be at least of length two to include metadata") + } + if l[0] != "dbl" { + t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteJSONProtocolMap(t *testing.T) { + thetype := "map" + trans := NewTMemoryBuffer() + p := NewTJSONProtocol(trans) + p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) + for k, value := range DOUBLE_VALUES { + if e := p.WriteI32(int32(k)); e != nil { + t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) + } + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteMapEnd() + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + if str[0] != '[' || str[len(str)-1] != ']' { + t.Fatalf("Bad value for %s, wrote: %q, in go: %q", thetype, str, DOUBLE_VALUES) + } + expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin() + if err != nil { + t.Fatalf("Error while reading map begin: %s", err.Error()) + } + if expectedKeyType != I32 { + t.Fatal("Expected map key type ", I32, ", but was ", expectedKeyType) + } + if expectedValueType != DOUBLE { + t.Fatal("Expected map value type ", DOUBLE, ", but was ", expectedValueType) + } + if expectedSize != len(DOUBLE_VALUES) { + t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize) + } + for k, value := range DOUBLE_VALUES { + ik, err := p.ReadI32() + if err != nil { + t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error()) + } + if int(ik) != k { + t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k) + } + dv, err := p.ReadDouble() + if err != nil { + t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error()) + } + s := strconv.FormatFloat(dv, 'g', 10, 64) + if math.IsInf(value, 1) { + if !math.IsInf(dv, 1) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, 0) { + if !math.IsInf(dv, 0) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if !math.IsNaN(dv) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN)) + } + } else { + expected := strconv.FormatFloat(value, 'g', 10, 64) + if s != expected { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} diff --git a/thrift/thrift/lowlevel_benchmarks_test.go b/thrift/thrift/lowlevel_benchmarks_test.go new file mode 100644 index 0000000..a5094ae --- /dev/null +++ b/thrift/thrift/lowlevel_benchmarks_test.go @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +var binaryProtoF = NewTBinaryProtocolFactoryDefault() +var compactProtoF = NewTCompactProtocolFactory() + +var buf = bytes.NewBuffer(make([]byte, 0, 1024)) + +var tfv = []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), +} + +func BenchmarkBinaryBool_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkBinaryBool_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkBinaryBool_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} diff --git a/thrift/thrift/memory_buffer.go b/thrift/thrift/memory_buffer.go new file mode 100644 index 0000000..c48e089 --- /dev/null +++ b/thrift/thrift/memory_buffer.go @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" +) + +// Memory buffer-based implementation of the TTransport interface. +type TMemoryBuffer struct { + *bytes.Buffer + size int +} + +type TMemoryBufferTransportFactory struct { + size int +} + +func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) TTransport { + if trans != nil { + t, ok := trans.(*TMemoryBuffer) + if ok && t.size > 0 { + return NewTMemoryBufferLen(t.size) + } + } + return NewTMemoryBufferLen(p.size) +} + +func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory { + return &TMemoryBufferTransportFactory{size: size} +} + +func NewTMemoryBuffer() *TMemoryBuffer { + return &TMemoryBuffer{Buffer: &bytes.Buffer{}, size: 0} +} + +func NewTMemoryBufferLen(size int) *TMemoryBuffer { + buf := make([]byte, 0, size) + return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} +} + +func (p *TMemoryBuffer) IsOpen() bool { + return true +} + +func (p *TMemoryBuffer) Open() error { + return nil +} + +func (p *TMemoryBuffer) Peek() bool { + return p.IsOpen() +} + +func (p *TMemoryBuffer) Close() error { + p.Buffer.Reset() + return nil +} + +// Flushing a memory buffer is a no-op +func (p *TMemoryBuffer) Flush() error { + return nil +} diff --git a/thrift/thrift/memory_buffer_test.go b/thrift/thrift/memory_buffer_test.go new file mode 100644 index 0000000..af2e8bf --- /dev/null +++ b/thrift/thrift/memory_buffer_test.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestMemoryBuffer(t *testing.T) { + trans := NewTMemoryBufferLen(1024) + TransportTest(t, trans, trans) +} diff --git a/thrift/thrift/messagetype.go b/thrift/thrift/messagetype.go new file mode 100644 index 0000000..25ab2e9 --- /dev/null +++ b/thrift/thrift/messagetype.go @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Message type constants in the Thrift protocol. +type TMessageType int32 + +const ( + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) diff --git a/thrift/thrift/multiplexed_protocol.go b/thrift/thrift/multiplexed_protocol.go new file mode 100644 index 0000000..3157e0d --- /dev/null +++ b/thrift/thrift/multiplexed_protocol.go @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "fmt" + "strings" +) + +/* +TMultiplexedProtocol is a protocol-independent concrete decorator +that allows a Thrift client to communicate with a multiplexing Thrift server, +by prepending the service name to the function name during function calls. + +NOTE: THIS IS NOT USED BY SERVERS. On the server, use TMultiplexedProcessor to handle request +from a multiplexing client. + +This example uses a single socket transport to invoke two services: + +socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) +transport := thrift.NewTFramedTransport(socket) +protocol := thrift.NewTBinaryProtocolTransport(transport) + +mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator") +service := Calculator.NewCalculatorClient(mp) + +mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport") +service2 := WeatherReport.NewWeatherReportClient(mp2) + +err := transport.Open() +if err != nil { + t.Fatal("Unable to open client socket", err) +} + +fmt.Println(service.Add(2,2)) +fmt.Println(service2.GetTemperature()) +*/ + +type TMultiplexedProtocol struct { + TProtocol + serviceName string +} + +const MULTIPLEXED_SEPARATOR = ":" + +func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol { + return &TMultiplexedProtocol{ + TProtocol: protocol, + serviceName: serviceName, + } +} + +func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { + if typeId == CALL || typeId == ONEWAY { + return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid) + } else { + return t.TProtocol.WriteMessageBegin(name, typeId, seqid) + } +} + +/* +TMultiplexedProcessor is a TProcessor allowing +a single TServer to provide multiple services. + +To do so, you instantiate the processor and then register additional +processors with it, as shown in the following example: + +var processor = thrift.NewTMultiplexedProcessor() + +firstProcessor := +processor.RegisterProcessor("FirstService", firstProcessor) + +processor.registerProcessor( + "Calculator", + Calculator.NewCalculatorProcessor(&CalculatorHandler{}), +) + +processor.registerProcessor( + "WeatherReport", + WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}), +) + +serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT) +if err != nil { + t.Fatal("Unable to create server socket", err) +} +server := thrift.NewTSimpleServer2(processor, serverTransport) +server.Serve(); +*/ + +type TMultiplexedProcessor struct { + serviceProcessorMap map[string]TProcessor + DefaultProcessor TProcessor +} + +func NewTMultiplexedProcessor() *TMultiplexedProcessor { + return &TMultiplexedProcessor{ + serviceProcessorMap: make(map[string]TProcessor), + } +} + +func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) { + t.DefaultProcessor = processor +} + +func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) { + if t.serviceProcessorMap == nil { + t.serviceProcessorMap = make(map[string]TProcessor) + } + t.serviceProcessorMap[name] = processor +} + +func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) { + name, typeId, seqid, err := in.ReadMessageBegin() + if err != nil { + return false, err + } + if typeId != CALL && typeId != ONEWAY { + return false, fmt.Errorf("Unexpected message type %v", typeId) + } + //extract the service name + v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2) + if len(v) != 2 { + if t.DefaultProcessor != nil { + smb := NewStoredMessageProtocol(in, name, typeId, seqid) + return t.DefaultProcessor.Process(smb, out) + } + return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name) + } + actualProcessor, ok := t.serviceProcessorMap[v[0]] + if !ok { + return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0]) + } + smb := NewStoredMessageProtocol(in, v[1], typeId, seqid) + return actualProcessor.Process(smb, out) +} + +//Protocol that use stored message for ReadMessageBegin +type storedMessageProtocol struct { + TProtocol + name string + typeId TMessageType + seqid int32 +} + +func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol { + return &storedMessageProtocol{protocol, name, typeId, seqid} +} + +func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { + return s.name, s.typeId, s.seqid, nil +} diff --git a/thrift/thrift/numeric.go b/thrift/thrift/numeric.go new file mode 100644 index 0000000..aa8daa9 --- /dev/null +++ b/thrift/thrift/numeric.go @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "math" + "strconv" +) + +type Numeric interface { + Int64() int64 + Int32() int32 + Int16() int16 + Byte() byte + Int() int + Float64() float64 + Float32() float32 + String() string + isNull() bool +} + +type numeric struct { + iValue int64 + dValue float64 + sValue string + isNil bool +} + +var ( + INFINITY Numeric + NEGATIVE_INFINITY Numeric + NAN Numeric + ZERO Numeric + NUMERIC_NULL Numeric +) + +func NewNumericFromDouble(dValue float64) Numeric { + if math.IsInf(dValue, 1) { + return INFINITY + } + if math.IsInf(dValue, -1) { + return NEGATIVE_INFINITY + } + if math.IsNaN(dValue) { + return NAN + } + iValue := int64(dValue) + sValue := strconv.FormatFloat(dValue, 'g', 10, 64) + isNil := false + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromI64(iValue int64) Numeric { + dValue := float64(iValue) + sValue := string(iValue) + isNil := false + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromI32(iValue int32) Numeric { + dValue := float64(iValue) + sValue := string(iValue) + isNil := false + return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromString(sValue string) Numeric { + if sValue == INFINITY.String() { + return INFINITY + } + if sValue == NEGATIVE_INFINITY.String() { + return NEGATIVE_INFINITY + } + if sValue == NAN.String() { + return NAN + } + iValue, _ := strconv.ParseInt(sValue, 10, 64) + dValue, _ := strconv.ParseFloat(sValue, 64) + isNil := len(sValue) == 0 + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} +} + +func NewNumericFromJSONString(sValue string, isNull bool) Numeric { + if isNull { + return NewNullNumeric() + } + if sValue == JSON_INFINITY { + return INFINITY + } + if sValue == JSON_NEGATIVE_INFINITY { + return NEGATIVE_INFINITY + } + if sValue == JSON_NAN { + return NAN + } + iValue, _ := strconv.ParseInt(sValue, 10, 64) + dValue, _ := strconv.ParseFloat(sValue, 64) + return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNull} +} + +func NewNullNumeric() Numeric { + return &numeric{iValue: 0, dValue: 0.0, sValue: "", isNil: true} +} + +func (p *numeric) Int64() int64 { + return p.iValue +} + +func (p *numeric) Int32() int32 { + return int32(p.iValue) +} + +func (p *numeric) Int16() int16 { + return int16(p.iValue) +} + +func (p *numeric) Byte() byte { + return byte(p.iValue) +} + +func (p *numeric) Int() int { + return int(p.iValue) +} + +func (p *numeric) Float64() float64 { + return p.dValue +} + +func (p *numeric) Float32() float32 { + return float32(p.dValue) +} + +func (p *numeric) String() string { + return p.sValue +} + +func (p *numeric) isNull() bool { + return p.isNil +} + +func init() { + INFINITY = &numeric{iValue: 0, dValue: math.Inf(1), sValue: "Infinity", isNil: false} + NEGATIVE_INFINITY = &numeric{iValue: 0, dValue: math.Inf(-1), sValue: "-Infinity", isNil: false} + NAN = &numeric{iValue: 0, dValue: math.NaN(), sValue: "NaN", isNil: false} + ZERO = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: false} + NUMERIC_NULL = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: true} +} diff --git a/thrift/thrift/pointerize.go b/thrift/thrift/pointerize.go new file mode 100644 index 0000000..8d6b2c2 --- /dev/null +++ b/thrift/thrift/pointerize.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +/////////////////////////////////////////////////////////////////////////////// +// This file is home to helpers that convert from various base types to +// respective pointer types. This is necessary because Go does not permit +// references to constants, nor can a pointer type to base type be allocated +// and initialized in a single expression. +// +// E.g., this is not allowed: +// +// var ip *int = &5 +// +// But this *is* allowed: +// +// func IntPtr(i int) *int { return &i } +// var ip *int = IntPtr(5) +// +// Since pointers to base types are commonplace as [optional] fields in +// exported thrift structs, we factor such helpers here. +/////////////////////////////////////////////////////////////////////////////// + +func Float32Ptr(v float32) *float32 { return &v } +func Float64Ptr(v float64) *float64 { return &v } +func IntPtr(v int) *int { return &v } +func Int32Ptr(v int32) *int32 { return &v } +func Int64Ptr(v int64) *int64 { return &v } +func StringPtr(v string) *string { return &v } +func Uint32Ptr(v uint32) *uint32 { return &v } +func Uint64Ptr(v uint64) *uint64 { return &v } +func BoolPtr(v bool) *bool { return &v } +func ByteSlicePtr(v []byte) *[]byte { return &v } diff --git a/thrift/thrift/processor.go b/thrift/thrift/processor.go new file mode 100644 index 0000000..ca0d3fa --- /dev/null +++ b/thrift/thrift/processor.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// A processor is a generic object which operates upon an input stream and +// writes to some output stream. +type TProcessor interface { + Process(in, out TProtocol) (bool, TException) +} + +type TProcessorFunction interface { + Process(seqId int32, in, out TProtocol) (bool, TException) +} diff --git a/thrift/thrift/processor_factory.go b/thrift/thrift/processor_factory.go new file mode 100644 index 0000000..9d645df --- /dev/null +++ b/thrift/thrift/processor_factory.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// The default processor factory just returns a singleton +// instance. +type TProcessorFactory interface { + GetProcessor(trans TTransport) TProcessor +} + +type tProcessorFactory struct { + processor TProcessor +} + +func NewTProcessorFactory(p TProcessor) TProcessorFactory { + return &tProcessorFactory{processor: p} +} + +func (p *tProcessorFactory) GetProcessor(trans TTransport) TProcessor { + return p.processor +} + +/** + * The default processor factory just returns a singleton + * instance. + */ +type TProcessorFunctionFactory interface { + GetProcessorFunction(trans TTransport) TProcessorFunction +} + +type tProcessorFunctionFactory struct { + processor TProcessorFunction +} + +func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactory { + return &tProcessorFunctionFactory{processor: p} +} + +func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction { + return p.processor +} diff --git a/thrift/thrift/protocol.go b/thrift/thrift/protocol.go new file mode 100644 index 0000000..87ceaad --- /dev/null +++ b/thrift/thrift/protocol.go @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +const ( + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 +) + +type TProtocol interface { + WriteMessageBegin(name string, typeId TMessageType, seqid int32) error + WriteMessageEnd() error + WriteStructBegin(name string) error + WriteStructEnd() error + WriteFieldBegin(name string, typeId TType, id int16) error + WriteFieldEnd() error + WriteFieldStop() error + WriteMapBegin(keyType TType, valueType TType, size int) error + WriteMapEnd() error + WriteListBegin(elemType TType, size int) error + WriteListEnd() error + WriteSetBegin(elemType TType, size int) error + WriteSetEnd() error + WriteBool(value bool) error + WriteByte(value byte) error + WriteI16(value int16) error + WriteI32(value int32) error + WriteI64(value int64) error + WriteDouble(value float64) error + WriteString(value string) error + WriteBinary(value []byte) error + + ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) + ReadMessageEnd() error + ReadStructBegin() (name string, err error) + ReadStructEnd() error + ReadFieldBegin() (name string, typeId TType, id int16, err error) + ReadFieldEnd() error + ReadMapBegin() (keyType TType, valueType TType, size int, err error) + ReadMapEnd() error + ReadListBegin() (elemType TType, size int, err error) + ReadListEnd() error + ReadSetBegin() (elemType TType, size int, err error) + ReadSetEnd() error + ReadBool() (value bool, err error) + ReadByte() (value byte, err error) + ReadI16() (value int16, err error) + ReadI32() (value int32, err error) + ReadI64() (value int64, err error) + ReadDouble() (value float64, err error) + ReadString() (value string, err error) + ReadBinary() (value []byte, err error) + + Skip(fieldType TType) (err error) + Flush() (err error) + + Transport() TTransport +} + +// The maximum recursive depth the skip() function will traverse +var MaxSkipDepth = 1<<31 - 1 + +// Skips over the next data element from the provided input TProtocol object. +func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) { + return Skip(prot, typeId, MaxSkipDepth) +} + +// Skips over the next data element from the provided input TProtocol object. +func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { + switch fieldType { + case STOP: + return + case BOOL: + _, err = self.ReadBool() + return + case BYTE: + _, err = self.ReadByte() + return + case I16: + _, err = self.ReadI16() + return + case I32: + _, err = self.ReadI32() + return + case I64: + _, err = self.ReadI64() + return + case DOUBLE: + _, err = self.ReadDouble() + return + case STRING: + _, err = self.ReadString() + return + case STRUCT: + if _, err = self.ReadStructBegin(); err != nil { + return err + } + for { + _, typeId, _, _ := self.ReadFieldBegin() + if typeId == STOP { + break + } + Skip(self, typeId, maxDepth-1) + self.ReadFieldEnd() + } + return self.ReadStructEnd() + case MAP: + keyType, valueType, size, err := self.ReadMapBegin() + if err != nil { + return err + } + for i := 0; i < size; i++ { + Skip(self, keyType, maxDepth-1) + self.Skip(valueType) + } + return self.ReadMapEnd() + case SET: + elemType, size, err := self.ReadSetBegin() + if err != nil { + return err + } + for i := 0; i < size; i++ { + Skip(self, elemType, maxDepth-1) + } + return self.ReadSetEnd() + case LIST: + elemType, size, err := self.ReadListBegin() + if err != nil { + return err + } + for i := 0; i < size; i++ { + Skip(self, elemType, maxDepth-1) + } + return self.ReadListEnd() + } + return nil +} diff --git a/thrift/thrift/protocol_exception.go b/thrift/thrift/protocol_exception.go new file mode 100644 index 0000000..29ab75d --- /dev/null +++ b/thrift/thrift/protocol_exception.go @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/base64" +) + +// Thrift Protocol exception +type TProtocolException interface { + TException + TypeId() int +} + +const ( + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) + +type tProtocolException struct { + typeId int + message string +} + +func (p *tProtocolException) TypeId() int { + return p.typeId +} + +func (p *tProtocolException) String() string { + return p.message +} + +func (p *tProtocolException) Error() string { + return p.message +} + +func NewTProtocolException(err error) TProtocolException { + if err == nil { + return nil + } + if e, ok := err.(TProtocolException); ok { + return e + } + if _, ok := err.(base64.CorruptInputError); ok { + return &tProtocolException{INVALID_DATA, err.Error()} + } + return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()} +} + +func NewTProtocolExceptionWithType(errType int, err error) TProtocolException { + if err == nil { + return nil + } + return &tProtocolException{errType, err.Error()} +} diff --git a/thrift/thrift/protocol_factory.go b/thrift/thrift/protocol_factory.go new file mode 100644 index 0000000..c40f796 --- /dev/null +++ b/thrift/thrift/protocol_factory.go @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Factory interface for constructing protocol instances. +type TProtocolFactory interface { + GetProtocol(trans TTransport) TProtocol +} diff --git a/thrift/thrift/protocol_test.go b/thrift/thrift/protocol_test.go new file mode 100644 index 0000000..7e7950d --- /dev/null +++ b/thrift/thrift/protocol_test.go @@ -0,0 +1,479 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "io/ioutil" + "math" + "net" + "net/http" + "testing" +) + +const PROTOCOL_BINARY_DATA_SIZE = 155 + +var ( + data string // test data for writing + protocol_bdata []byte // test data for writing; same as data + BOOL_VALUES []bool + BYTE_VALUES []byte + INT16_VALUES []int16 + INT32_VALUES []int32 + INT64_VALUES []int64 + DOUBLE_VALUES []float64 + STRING_VALUES []string +) + +func init() { + protocol_bdata = make([]byte, PROTOCOL_BINARY_DATA_SIZE) + for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ { + protocol_bdata[i] = byte((i + 'a') % 255) + } + data = string(protocol_bdata) + BOOL_VALUES = []bool{false, true, false, false, true} + BYTE_VALUES = []byte{117, 0, 1, 32, 127, 128, 255} + INT16_VALUES = []int16{459, 0, 1, -1, -128, 127, 32767, -32768} + INT32_VALUES = []int32{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535} + INT64_VALUES = []int64{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535, 34359738481, -35184372088719, -9223372036854775808, 9223372036854775807} + DOUBLE_VALUES = []float64{459.3, 0.0, -1.0, 1.0, 0.5, 0.3333, 3.14159, 1.537e-38, 1.673e25, 6.02214179e23, -6.02214179e23, INFINITY.Float64(), NEGATIVE_INFINITY.Float64(), NAN.Float64()} + STRING_VALUES = []string{"", "a", "st[uf]f", "st,u:ff with spaces", "stuff\twith\nescape\\characters'...\"lots{of}fun"} +} + +type HTTPEchoServer struct{} +type HTTPHeaderEchoServer struct{} + +func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write(buf) + } else { + w.WriteHeader(http.StatusOK) + w.Write(buf) + } +} + +func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write(buf) + } else { + w.WriteHeader(http.StatusOK) + w.Write(buf) + } +} + +func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) { + addr, err := FindAvailableTCPServerPort(40000) + if err != nil { + t.Fatalf("Unable to find available tcp port addr: %s", err) + return nil, addr + } + l, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) + return l, addr + } + go http.Serve(l, &HTTPEchoServer{}) + return l, addr +} + +func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) { + addr, err := FindAvailableTCPServerPort(40000) + if err != nil { + t.Fatalf("Unable to find available tcp port addr: %s", err) + return nil, addr + } + l, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) + return l, addr + } + go http.Serve(l, &HTTPHeaderEchoServer{}) + return l, addr +} + +func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + l, addr := HttpClientSetupForTest(t) + defer l.Close() + transports := []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), + NewTHttpPostClientTransportFactory("http://" + addr.String()), + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteBool(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteByte(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteI16(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteI32(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteI64(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteDouble(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteString(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteBinary(t, p, trans) + trans.Close() + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteI64(t, p, trans) + ReadWriteDouble(t, p, trans) + ReadWriteBinary(t, p, trans) + ReadWriteByte(t, p, trans) + trans.Close() + } +} + +func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(BOOL) + thelen := len(BOOL_VALUES) + err := p.WriteListBegin(thetype, thelen) + if err != nil { + t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype) + } + for k, v := range BOOL_VALUES { + err = p.WriteBool(v) + if err != nil { + t.Errorf("%s: %T %T %q Error writing bool in list at index %d: %q", "ReadWriteBool", p, trans, err, k, v) + } + } + p.WriteListEnd() + if err != nil { + t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteBool", p, trans, err, BOOL_VALUES) + } + p.Flush() + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteBool", p, trans, err, BOOL_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteBool", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteBool", p, trans, thelen, thelen2) + } + } + for k, v := range BOOL_VALUES { + value, err := p.ReadBool() + if err != nil { + t.Errorf("%s: %T %T %q Error reading bool at index %d: %q", "ReadWriteBool", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: index %d %q %q %q != %q", "ReadWriteBool", k, p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err) + } +} + +func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(BYTE) + thelen := len(BYTE_VALUES) + err := p.WriteListBegin(thetype, thelen) + if err != nil { + t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype) + } + for k, v := range BYTE_VALUES { + err = p.WriteByte(v) + if err != nil { + t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v) + } + } + err = p.WriteListEnd() + if err != nil { + t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) + } + err = p.Flush() + if err != nil { + t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) + } + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteByte", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteByte", p, trans, thelen, thelen2) + } + } + for k, v := range BYTE_VALUES { + value, err := p.ReadByte() + if err != nil { + t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err) + } +} + +func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(I16) + thelen := len(INT16_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range INT16_VALUES { + p.WriteI16(v) + } + p.WriteListEnd() + p.Flush() + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI16", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI16", p, trans, thelen, thelen2) + } + } + for k, v := range INT16_VALUES { + value, err := p.ReadI16() + if err != nil { + t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err) + } +} + +func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(I32) + thelen := len(INT32_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range INT32_VALUES { + p.WriteI32(v) + } + p.WriteListEnd() + p.Flush() + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI32", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI32", p, trans, thelen, thelen2) + } + } + for k, v := range INT32_VALUES { + value, err := p.ReadI32() + if err != nil { + t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteI32", p, trans, v, value) + } + } + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI32", p, trans, err) + } +} + +func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(I64) + thelen := len(INT64_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range INT64_VALUES { + p.WriteI64(v) + } + p.WriteListEnd() + p.Flush() + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI64", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI64", p, trans, thelen, thelen2) + } + } + for k, v := range INT64_VALUES { + value, err := p.ReadI64() + if err != nil { + t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %q != %q", "ReadWriteI64", p, trans, v, value) + } + } + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI64", p, trans, err) + } +} + +func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(DOUBLE) + thelen := len(DOUBLE_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range DOUBLE_VALUES { + p.WriteDouble(v) + } + p.WriteListEnd() + p.Flush() + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES) + } + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteDouble", p, trans, thelen, thelen2) + } + for k, v := range DOUBLE_VALUES { + value, err := p.ReadDouble() + if err != nil { + t.Errorf("%s: %T %T %q Error reading double at index %d: %q", "ReadWriteDouble", p, trans, err, k, v) + } + if math.IsNaN(v) { + if !math.IsNaN(value) { + t.Errorf("%s: %T %T math.IsNaN(%q) != math.IsNaN(%q)", "ReadWriteDouble", p, trans, v, value) + } + } else if v != value { + t.Errorf("%s: %T %T %v != %q", "ReadWriteDouble", p, trans, v, value) + } + } + err = p.ReadListEnd() + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err) + } +} + +func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { + thetype := TType(STRING) + thelen := len(STRING_VALUES) + p.WriteListBegin(thetype, thelen) + for _, v := range STRING_VALUES { + p.WriteString(v) + } + p.WriteListEnd() + p.Flush() + thetype2, thelen2, err := p.ReadListBegin() + if err != nil { + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES) + } + _, ok := p.(*TSimpleJSONProtocol) + if !ok { + if thetype != thetype2 { + t.Errorf("%s: %T %T type %s != type %s", "ReadWriteString", p, trans, thetype, thetype2) + } + if thelen != thelen2 { + t.Errorf("%s: %T %T len %s != len %s", "ReadWriteString", p, trans, thelen, thelen2) + } + } + for k, v := range STRING_VALUES { + value, err := p.ReadString() + if err != nil { + t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v) + } + if v != value { + t.Errorf("%s: %T %T %d != %d", "ReadWriteString", p, trans, v, value) + } + } + if err != nil { + t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteString", p, trans, err) + } +} + +func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) { + v := protocol_bdata + p.WriteBinary(v) + p.Flush() + value, err := p.ReadBinary() + if err != nil { + t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error()) + } + if len(v) != len(value) { + t.Errorf("%s: %T %T len(v) != len(value)... %d != %d", "ReadWriteBinary", p, trans, len(v), len(value)) + } else { + for i := 0; i < len(v); i++ { + if v[i] != value[i] { + t.Errorf("%s: %T %T %s != %s", "ReadWriteBinary", p, trans, v, value) + } + } + } +} diff --git a/thrift/thrift/rich_transport.go b/thrift/thrift/rich_transport.go new file mode 100644 index 0000000..d7268d9 --- /dev/null +++ b/thrift/thrift/rich_transport.go @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import "io" + +type RichTransport struct { + TTransport +} + +// Wraps Transport to provide TRichTransport interface +func NewTRichTransport(trans TTransport) *RichTransport { + return &RichTransport{trans} +} + +func (r *RichTransport) ReadByte() (c byte, err error) { + return readByte(r.TTransport) +} + +func (r *RichTransport) WriteByte(c byte) error { + return writeByte(r.TTransport, c) +} + +func (r *RichTransport) WriteString(s string) (n int, err error) { + return r.Write([]byte(s)) +} + +func readByte(r io.Reader) (c byte, err error) { + v := [1]byte{0} + n, err := r.Read(v[0:1]) + if n > 0 && (err == nil || err == io.EOF) { + return v[0], nil + } + if n > 0 && err != nil { + return v[0], err + } + if err != nil { + return 0, err + } + return v[0], nil +} + +func writeByte(w io.Writer, c byte) error { + v := [1]byte{c} + _, err := w.Write(v[0:1]) + return err +} diff --git a/thrift/thrift/rich_transport_test.go b/thrift/thrift/rich_transport_test.go new file mode 100644 index 0000000..41513f8 --- /dev/null +++ b/thrift/thrift/rich_transport_test.go @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" +) + +func TestEnsureTransportsAreRich(t *testing.T) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + + transports := []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), + NewTHttpPostClientTransportFactory("http://127.0.0.1"), + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + _, ok := trans.(TRichTransport) + if !ok { + t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans)) + } + } +} + +// TestReadByte tests whether readByte handles error cases correctly. +func TestReadByte(t *testing.T) { + for i, test := range readByteTests { + v, err := readByte(test.r) + if v != test.v { + t.Fatalf("TestReadByte %d: value differs. Expected %d, got %d", i, test.v, test.r.v) + } + if err != test.err { + t.Fatalf("TestReadByte %d: error differs. Expected %s, got %s", i, test.err, test.r.err) + } + } +} + +var someError = errors.New("Some error") +var readByteTests = []struct { + r *mockReader + v byte + err error +}{ + {&mockReader{0, 55, io.EOF}, 0, io.EOF}, // reader sends EOF w/o data + {&mockReader{0, 55, someError}, 0, someError}, // reader sends some other error + {&mockReader{1, 55, nil}, 55, nil}, // reader sends data w/o error + {&mockReader{1, 55, io.EOF}, 55, nil}, // reader sends data with EOF + {&mockReader{1, 55, someError}, 55, someError}, // reader sends data withsome error +} + +type mockReader struct { + n int + v byte + err error +} + +func (r *mockReader) Read(p []byte) (n int, err error) { + if r.n > 0 { + p[0] = r.v + } + return r.n, r.err +} diff --git a/thrift/thrift/serializer.go b/thrift/thrift/serializer.go new file mode 100644 index 0000000..7712229 --- /dev/null +++ b/thrift/thrift/serializer.go @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +type TSerializer struct { + Transport *TMemoryBuffer + Protocol TProtocol +} + +type TStruct interface { + Write(p TProtocol) error + Read(p TProtocol) error +} + +func NewTSerializer() *TSerializer { + transport := NewTMemoryBufferLen(1024) + protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport) + + return &TSerializer{ + transport, + protocol} +} + +func (t *TSerializer) WriteString(msg TStruct) (s string, err error) { + t.Transport.Reset() + + if err = msg.Write(t.Protocol); err != nil { + return + } + + if err = t.Protocol.Flush(); err != nil { + return + } + if err = t.Transport.Flush(); err != nil { + return + } + + return t.Transport.String(), nil +} + +func (t *TSerializer) Write(msg TStruct) (b []byte, err error) { + t.Transport.Reset() + + if err = msg.Write(t.Protocol); err != nil { + return + } + + if err = t.Protocol.Flush(); err != nil { + return + } + + if err = t.Transport.Flush(); err != nil { + return + } + + b = append(b, t.Transport.Bytes()...) + return +} diff --git a/thrift/thrift/serializer_test.go b/thrift/thrift/serializer_test.go new file mode 100644 index 0000000..0f3f7d7 --- /dev/null +++ b/thrift/thrift/serializer_test.go @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "errors" + "fmt" + "testing" +) + +type ProtocolFactory interface { + GetProtocol(t TTransport) TProtocol +} + +func compareStructs(m, m1 TestStruct) (bool, error) { + switch { + case m.On != m1.On: + return false, errors.New("Boolean not equal") + case m.B != m1.B: + return false, errors.New("Byte not equal") + case m.Int16 != m1.Int16: + return false, errors.New("Int16 not equal") + case m.Int32 != m1.Int32: + return false, errors.New("Int32 not equal") + case m.Int64 != m1.Int64: + return false, errors.New("Int64 not equal") + case m.D != m1.D: + return false, errors.New("Double not equal") + case m.St != m1.St: + return false, errors.New("String not equal") + + case len(m.Bin) != len(m1.Bin): + return false, errors.New("Binary size not equal") + case len(m.Bin) == len(m1.Bin): + for i := range m.Bin { + if m.Bin[i] != m1.Bin[i] { + return false, errors.New("Binary not equal") + } + } + case len(m.StringMap) != len(m1.StringMap): + return false, errors.New("StringMap size not equal") + case len(m.StringList) != len(m1.StringList): + return false, errors.New("StringList size not equal") + case len(m.StringSet) != len(m1.StringSet): + return false, errors.New("StringSet size not equal") + + case m.E != m1.E: + return false, errors.New("TestEnum not equal") + + default: + return true, nil + + } + return true, nil +} + +func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) { + t := NewTSerializer() + t.Protocol = pf.GetProtocol(t.Transport) + var m = TestStruct{} + m.On = true + m.B = int8(0) + m.Int16 = 1 + m.Int32 = 2 + m.Int64 = 3 + m.D = 4.1 + m.St = "Test" + m.Bin = make([]byte, 10) + m.StringMap = make(map[string]string, 5) + m.StringList = make([]string, 5) + m.StringSet = make(map[string]bool, 5) + m.E = 2 + + s, err := t.WriteString(&m) + if err != nil { + return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err)) + } + + t1 := NewTDeserializer() + t1.Protocol = pf.GetProtocol(t1.Transport) + var m1 = TestStruct{} + if err = t1.ReadString(&m1, s); err != nil { + return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err)) + + } + + return compareStructs(m, m1) + +} + +func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) { + t := NewTSerializer() + t.Protocol = pf.GetProtocol(t.Transport) + var m = TestStruct{} + m.On = false + m.B = int8(0) + m.Int16 = 1 + m.Int32 = 2 + m.Int64 = 3 + m.D = 4.1 + m.St = "Test" + m.Bin = make([]byte, 10) + m.StringMap = make(map[string]string, 5) + m.StringList = make([]string, 5) + m.StringSet = make(map[string]bool, 5) + m.E = 2 + + s, err := t.WriteString(&m) + if err != nil { + return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err)) + + } + + t1 := NewTDeserializer() + t1.Protocol = pf.GetProtocol(t1.Transport) + var m1 = TestStruct{} + if err = t1.ReadString(&m1, s); err != nil { + return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err)) + + } + + return compareStructs(m, m1) + +} + +func TestSerializer(t *testing.T) { + + var protocol_factories map[string]ProtocolFactory + protocol_factories = make(map[string]ProtocolFactory) + protocol_factories["Binary"] = NewTBinaryProtocolFactoryDefault() + protocol_factories["Compact"] = NewTCompactProtocolFactory() + //protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design + protocol_factories["JSON"] = NewTJSONProtocolFactory() + + var tests map[string]func(*testing.T, ProtocolFactory) (bool, error) + tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error)) + tests["Test 1"] = ProtocolTest1 + tests["Test 2"] = ProtocolTest2 + //tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests + + for name, pf := range protocol_factories { + + for test, f := range tests { + + if s, err := f(t, pf); !s || err != nil { + t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err) + } + + } + } + +} diff --git a/thrift/thrift/serializer_types.go b/thrift/thrift/serializer_types.go new file mode 100644 index 0000000..efbcde8 --- /dev/null +++ b/thrift/thrift/serializer_types.go @@ -0,0 +1,595 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Autogenerated by Thrift Compiler (1.0.0-dev) +// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + +/* THE FOLLOWING THRIFT FILE WAS USED TO CREATE THIS + +enum TestEnum { + FIRST = 1, + SECOND = 2, + THIRD = 3, + FOURTH = 4, +} + +struct TestStruct { + 1: bool on, + 2: byte b, + 3: i16 int16, + 4: i32 int32, + 5: i64 int64, + 6: double d, + 7: string st, + 8: binary bin, + 9: map stringMap, + 10: list stringList, + 11: set stringSet, + 12: TestEnum e, +} +*/ + +import ( + "fmt" +) + +// (needed to ensure safety because of naive import list construction.) +var _ = ZERO +var _ = fmt.Printf + +var GoUnusedProtection__ int + +type TestEnum int64 + +const ( + TestEnum_FIRST TestEnum = 1 + TestEnum_SECOND TestEnum = 2 + TestEnum_THIRD TestEnum = 3 + TestEnum_FOURTH TestEnum = 4 +) + +func (p TestEnum) String() string { + switch p { + case TestEnum_FIRST: + return "TestEnum_FIRST" + case TestEnum_SECOND: + return "TestEnum_SECOND" + case TestEnum_THIRD: + return "TestEnum_THIRD" + case TestEnum_FOURTH: + return "TestEnum_FOURTH" + } + return "" +} + +func TestEnumFromString(s string) (TestEnum, error) { + switch s { + case "TestEnum_FIRST": + return TestEnum_FIRST, nil + case "TestEnum_SECOND": + return TestEnum_SECOND, nil + case "TestEnum_THIRD": + return TestEnum_THIRD, nil + case "TestEnum_FOURTH": + return TestEnum_FOURTH, nil + } + return TestEnum(0), fmt.Errorf("not a valid TestEnum string") +} + +func TestEnumPtr(v TestEnum) *TestEnum { return &v } + +type TestStruct struct { + On bool `thrift:"on,1"` + B int8 `thrift:"b,2"` + Int16 int16 `thrift:"int16,3"` + Int32 int32 `thrift:"int32,4"` + Int64 int64 `thrift:"int64,5"` + D float64 `thrift:"d,6"` + St string `thrift:"st,7"` + Bin []byte `thrift:"bin,8"` + StringMap map[string]string `thrift:"stringMap,9"` + StringList []string `thrift:"stringList,10"` + StringSet map[string]bool `thrift:"stringSet,11"` + E TestEnum `thrift:"e,12"` +} + +func NewTestStruct() *TestStruct { + rval := &TestStruct{} + return rval +} + +func (p *TestStruct) Read(iprot TProtocol) error { + if _, err := iprot.ReadStructBegin(); err != nil { + return fmt.Errorf("%T read error: %s", p, err) + } + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() + if err != nil { + return fmt.Errorf("%T field %d read error: %s", p, fieldId, err) + } + if fieldTypeId == STOP { + break + } + switch fieldId { + case 1: + if err := p.readField1(iprot); err != nil { + return err + } + case 2: + if err := p.readField2(iprot); err != nil { + return err + } + case 3: + if err := p.readField3(iprot); err != nil { + return err + } + case 4: + if err := p.readField4(iprot); err != nil { + return err + } + case 5: + if err := p.readField5(iprot); err != nil { + return err + } + case 6: + if err := p.readField6(iprot); err != nil { + return err + } + case 7: + if err := p.readField7(iprot); err != nil { + return err + } + case 8: + if err := p.readField8(iprot); err != nil { + return err + } + case 9: + if err := p.readField9(iprot); err != nil { + return err + } + case 10: + if err := p.readField10(iprot); err != nil { + return err + } + case 11: + if err := p.readField11(iprot); err != nil { + return err + } + case 12: + if err := p.readField12(iprot); err != nil { + return err + } + default: + if err := iprot.Skip(fieldTypeId); err != nil { + return err + } + } + if err := iprot.ReadFieldEnd(); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(); err != nil { + return fmt.Errorf("%T read struct end error: %s", p, err) + } + return nil +} + +func (p *TestStruct) readField1(iprot TProtocol) error { + if v, err := iprot.ReadBool(); err != nil { + return fmt.Errorf("error reading field 1: %s", err) + } else { + p.On = v + } + return nil +} + +func (p *TestStruct) readField2(iprot TProtocol) error { + if v, err := iprot.ReadByte(); err != nil { + return fmt.Errorf("error reading field 2: %s", err) + } else { + temp := int8(v) + p.B = temp + } + return nil +} + +func (p *TestStruct) readField3(iprot TProtocol) error { + if v, err := iprot.ReadI16(); err != nil { + return fmt.Errorf("error reading field 3: %s", err) + } else { + p.Int16 = v + } + return nil +} + +func (p *TestStruct) readField4(iprot TProtocol) error { + if v, err := iprot.ReadI32(); err != nil { + return fmt.Errorf("error reading field 4: %s", err) + } else { + p.Int32 = v + } + return nil +} + +func (p *TestStruct) readField5(iprot TProtocol) error { + if v, err := iprot.ReadI64(); err != nil { + return fmt.Errorf("error reading field 5: %s", err) + } else { + p.Int64 = v + } + return nil +} + +func (p *TestStruct) readField6(iprot TProtocol) error { + if v, err := iprot.ReadDouble(); err != nil { + return fmt.Errorf("error reading field 6: %s", err) + } else { + p.D = v + } + return nil +} + +func (p *TestStruct) readField7(iprot TProtocol) error { + if v, err := iprot.ReadString(); err != nil { + return fmt.Errorf("error reading field 7: %s", err) + } else { + p.St = v + } + return nil +} + +func (p *TestStruct) readField8(iprot TProtocol) error { + if v, err := iprot.ReadBinary(); err != nil { + return fmt.Errorf("error reading field 8: %s", err) + } else { + p.Bin = v + } + return nil +} + +func (p *TestStruct) readField9(iprot TProtocol) error { + _, _, size, err := iprot.ReadMapBegin() + if err != nil { + return fmt.Errorf("error reading map begin: %s") + } + tMap := make(map[string]string, size) + p.StringMap = tMap + for i := 0; i < size; i++ { + var _key0 string + if v, err := iprot.ReadString(); err != nil { + return fmt.Errorf("error reading field 0: %s", err) + } else { + _key0 = v + } + var _val1 string + if v, err := iprot.ReadString(); err != nil { + return fmt.Errorf("error reading field 0: %s", err) + } else { + _val1 = v + } + p.StringMap[_key0] = _val1 + } + if err := iprot.ReadMapEnd(); err != nil { + return fmt.Errorf("error reading map end: %s") + } + return nil +} + +func (p *TestStruct) readField10(iprot TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return fmt.Errorf("error reading list begin: %s") + } + tSlice := make([]string, 0, size) + p.StringList = tSlice + for i := 0; i < size; i++ { + var _elem2 string + if v, err := iprot.ReadString(); err != nil { + return fmt.Errorf("error reading field 0: %s", err) + } else { + _elem2 = v + } + p.StringList = append(p.StringList, _elem2) + } + if err := iprot.ReadListEnd(); err != nil { + return fmt.Errorf("error reading list end: %s") + } + return nil +} + +func (p *TestStruct) readField11(iprot TProtocol) error { + _, size, err := iprot.ReadSetBegin() + if err != nil { + return fmt.Errorf("error reading set begin: %s") + } + tSet := make(map[string]bool, size) + p.StringSet = tSet + for i := 0; i < size; i++ { + var _elem3 string + if v, err := iprot.ReadString(); err != nil { + return fmt.Errorf("error reading field 0: %s", err) + } else { + _elem3 = v + } + p.StringSet[_elem3] = true + } + if err := iprot.ReadSetEnd(); err != nil { + return fmt.Errorf("error reading set end: %s") + } + return nil +} + +func (p *TestStruct) readField12(iprot TProtocol) error { + if v, err := iprot.ReadI32(); err != nil { + return fmt.Errorf("error reading field 12: %s", err) + } else { + temp := TestEnum(v) + p.E = temp + } + return nil +} + +func (p *TestStruct) Write(oprot TProtocol) error { + if err := oprot.WriteStructBegin("TestStruct"); err != nil { + return fmt.Errorf("%T write struct begin error: %s", p, err) + } + if err := p.writeField1(oprot); err != nil { + return err + } + if err := p.writeField2(oprot); err != nil { + return err + } + if err := p.writeField3(oprot); err != nil { + return err + } + if err := p.writeField4(oprot); err != nil { + return err + } + if err := p.writeField5(oprot); err != nil { + return err + } + if err := p.writeField6(oprot); err != nil { + return err + } + if err := p.writeField7(oprot); err != nil { + return err + } + if err := p.writeField8(oprot); err != nil { + return err + } + if err := p.writeField9(oprot); err != nil { + return err + } + if err := p.writeField10(oprot); err != nil { + return err + } + if err := p.writeField11(oprot); err != nil { + return err + } + if err := p.writeField12(oprot); err != nil { + return err + } + if err := oprot.WriteFieldStop(); err != nil { + return fmt.Errorf("write field stop error: %s", err) + } + if err := oprot.WriteStructEnd(); err != nil { + return fmt.Errorf("write struct stop error: %s", err) + } + return nil +} + +func (p *TestStruct) writeField1(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil { + return fmt.Errorf("%T write field begin error 1:on: %s", p, err) + } + if err := oprot.WriteBool(bool(p.On)); err != nil { + return fmt.Errorf("%T.on (1) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 1:on: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField2(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil { + return fmt.Errorf("%T write field begin error 2:b: %s", p, err) + } + if err := oprot.WriteByte(byte(p.B)); err != nil { + return fmt.Errorf("%T.b (2) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 2:b: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField3(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil { + return fmt.Errorf("%T write field begin error 3:int16: %s", p, err) + } + if err := oprot.WriteI16(int16(p.Int16)); err != nil { + return fmt.Errorf("%T.int16 (3) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 3:int16: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField4(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil { + return fmt.Errorf("%T write field begin error 4:int32: %s", p, err) + } + if err := oprot.WriteI32(int32(p.Int32)); err != nil { + return fmt.Errorf("%T.int32 (4) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 4:int32: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField5(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil { + return fmt.Errorf("%T write field begin error 5:int64: %s", p, err) + } + if err := oprot.WriteI64(int64(p.Int64)); err != nil { + return fmt.Errorf("%T.int64 (5) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 5:int64: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField6(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil { + return fmt.Errorf("%T write field begin error 6:d: %s", p, err) + } + if err := oprot.WriteDouble(float64(p.D)); err != nil { + return fmt.Errorf("%T.d (6) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 6:d: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField7(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil { + return fmt.Errorf("%T write field begin error 7:st: %s", p, err) + } + if err := oprot.WriteString(string(p.St)); err != nil { + return fmt.Errorf("%T.st (7) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 7:st: %s", p, err) + } + return err +} + +func (p *TestStruct) writeField8(oprot TProtocol) (err error) { + if p.Bin != nil { + if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil { + return fmt.Errorf("%T write field begin error 8:bin: %s", p, err) + } + if err := oprot.WriteBinary(p.Bin); err != nil { + return fmt.Errorf("%T.bin (8) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 8:bin: %s", p, err) + } + } + return err +} + +func (p *TestStruct) writeField9(oprot TProtocol) (err error) { + if p.StringMap != nil { + if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil { + return fmt.Errorf("%T write field begin error 9:stringMap: %s", p, err) + } + if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil { + return fmt.Errorf("error writing map begin: %s") + } + for k, v := range p.StringMap { + if err := oprot.WriteString(string(k)); err != nil { + return fmt.Errorf("%T. (0) field write error: %s", p, err) + } + if err := oprot.WriteString(string(v)); err != nil { + return fmt.Errorf("%T. (0) field write error: %s", p, err) + } + } + if err := oprot.WriteMapEnd(); err != nil { + return fmt.Errorf("error writing map end: %s") + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 9:stringMap: %s", p, err) + } + } + return err +} + +func (p *TestStruct) writeField10(oprot TProtocol) (err error) { + if p.StringList != nil { + if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil { + return fmt.Errorf("%T write field begin error 10:stringList: %s", p, err) + } + if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil { + return fmt.Errorf("error writing list begin: %s") + } + for _, v := range p.StringList { + if err := oprot.WriteString(string(v)); err != nil { + return fmt.Errorf("%T. (0) field write error: %s", p, err) + } + } + if err := oprot.WriteListEnd(); err != nil { + return fmt.Errorf("error writing list end: %s") + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 10:stringList: %s", p, err) + } + } + return err +} + +func (p *TestStruct) writeField11(oprot TProtocol) (err error) { + if p.StringSet != nil { + if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil { + return fmt.Errorf("%T write field begin error 11:stringSet: %s", p, err) + } + if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil { + return fmt.Errorf("error writing set begin: %s") + } + for v, _ := range p.StringSet { + if err := oprot.WriteString(string(v)); err != nil { + return fmt.Errorf("%T. (0) field write error: %s", p, err) + } + } + if err := oprot.WriteSetEnd(); err != nil { + return fmt.Errorf("error writing set end: %s") + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 11:stringSet: %s", p, err) + } + } + return err +} + +func (p *TestStruct) writeField12(oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin("e", I32, 12); err != nil { + return fmt.Errorf("%T write field begin error 12:e: %s", p, err) + } + if err := oprot.WriteI32(int32(p.E)); err != nil { + return fmt.Errorf("%T.e (12) field write error: %s", p, err) + } + if err := oprot.WriteFieldEnd(); err != nil { + return fmt.Errorf("%T write field end error 12:e: %s", p, err) + } + return err +} + +func (p *TestStruct) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("TestStruct(%+v)", *p) +} diff --git a/thrift/thrift/server.go b/thrift/thrift/server.go new file mode 100644 index 0000000..f813fa3 --- /dev/null +++ b/thrift/thrift/server.go @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +type TServer interface { + ProcessorFactory() TProcessorFactory + ServerTransport() TServerTransport + InputTransportFactory() TTransportFactory + OutputTransportFactory() TTransportFactory + InputProtocolFactory() TProtocolFactory + OutputProtocolFactory() TProtocolFactory + + // Starts the server + Serve() error + // Stops the server. This is optional on a per-implementation basis. Not + // all servers are required to be cleanly stoppable. + Stop() error +} diff --git a/thrift/thrift/server_socket.go b/thrift/thrift/server_socket.go new file mode 100644 index 0000000..936eb2e --- /dev/null +++ b/thrift/thrift/server_socket.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "net" + "sync" + "time" +) + +type TServerSocket struct { + listener net.Listener + addr net.Addr + clientTimeout time.Duration + + // Protects the interrupted value to make it thread safe. + mu sync.RWMutex + interrupted bool + + //Size of buffer to use for socket. Defaults to 1024. + //Set to 0 to disable bufferring server transport altogether. + BufferSize int +} + +func NewTServerSocket(listenAddr string) (*TServerSocket, error) { + return NewTServerSocketTimeout(listenAddr, 0) +} + +func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*TServerSocket, error) { + addr, err := net.ResolveTCPAddr("tcp", listenAddr) + if err != nil { + return nil, err + } + return &TServerSocket{addr: addr, clientTimeout: clientTimeout, BufferSize: 1024}, nil +} + +func (p *TServerSocket) Listen() error { + if p.IsListening() { + return nil + } + l, err := net.Listen(p.addr.Network(), p.addr.String()) + if err != nil { + return err + } + p.listener = l + return nil +} + +func (p *TServerSocket) Accept() (TTransport, error) { + p.mu.RLock() + interrupted := p.interrupted + p.mu.RUnlock() + + if interrupted { + return nil, errTransportInterrupted + } + if p.listener == nil { + return nil, NewTTransportException(NOT_OPEN, "No underlying server socket") + } + conn, err := p.listener.Accept() + if err != nil { + return nil, NewTTransportExceptionFromError(err) + } + var trans TTransport + trans = NewTSocketFromConnTimeout(conn, p.clientTimeout) + if p.BufferSize != 0 { + trans = NewTBufferedTransport(trans, p.BufferSize) + } + return trans, nil +} + +// Checks whether the socket is listening. +func (p *TServerSocket) IsListening() bool { + return p.listener != nil +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TServerSocket) Open() error { + if p.IsListening() { + return NewTTransportException(ALREADY_OPEN, "Server socket already open") + } + if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil { + return err + } else { + p.listener = l + } + return nil +} + +func (p *TServerSocket) Addr() net.Addr { + return p.addr +} + +func (p *TServerSocket) Close() error { + defer func() { + p.listener = nil + }() + if p.IsListening() { + return p.listener.Close() + } + return nil +} + +func (p *TServerSocket) Interrupt() error { + p.mu.Lock() + p.interrupted = true + p.mu.Unlock() + + return nil +} diff --git a/thrift/thrift/server_test.go b/thrift/thrift/server_test.go new file mode 100644 index 0000000..ffaf457 --- /dev/null +++ b/thrift/thrift/server_test.go @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "testing" +) + +func TestNothing(t *testing.T) { + +} diff --git a/thrift/thrift/server_transport.go b/thrift/thrift/server_transport.go new file mode 100644 index 0000000..51c40b6 --- /dev/null +++ b/thrift/thrift/server_transport.go @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Server transport. Object which provides client transports. +type TServerTransport interface { + Listen() error + Accept() (TTransport, error) + Close() error + + // Optional method implementation. This signals to the server transport + // that it should break out of any accept() or listen() that it is currently + // blocked on. This method, if implemented, MUST be thread safe, as it may + // be called from a different thread context than the other TServerTransport + // methods. + Interrupt() error +} diff --git a/thrift/thrift/simple_json_protocol.go b/thrift/thrift/simple_json_protocol.go new file mode 100644 index 0000000..bb4a5c2 --- /dev/null +++ b/thrift/thrift/simple_json_protocol.go @@ -0,0 +1,1277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "strconv" +) + +type _ParseContext int + +const ( + _CONTEXT_IN_TOPLEVEL _ParseContext = 1 + _CONTEXT_IN_LIST_FIRST _ParseContext = 2 + _CONTEXT_IN_LIST _ParseContext = 3 + _CONTEXT_IN_OBJECT_FIRST _ParseContext = 4 + _CONTEXT_IN_OBJECT_NEXT_KEY _ParseContext = 5 + _CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6 +) + +func (p _ParseContext) String() string { + switch p { + case _CONTEXT_IN_TOPLEVEL: + return "TOPLEVEL" + case _CONTEXT_IN_LIST_FIRST: + return "LIST-FIRST" + case _CONTEXT_IN_LIST: + return "LIST" + case _CONTEXT_IN_OBJECT_FIRST: + return "OBJECT-FIRST" + case _CONTEXT_IN_OBJECT_NEXT_KEY: + return "OBJECT-NEXT-KEY" + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + return "OBJECT-NEXT-VALUE" + } + return "UNKNOWN-PARSE-CONTEXT" +} + +// JSON protocol implementation for thrift. +// +// This protocol produces/consumes a simple output format +// suitable for parsing by scripting languages. It should not be +// confused with the full-featured TJSONProtocol. +// +type TSimpleJSONProtocol struct { + trans TTransport + + parseContextStack []int + dumpContext []int + + writer *bufio.Writer + reader *bufio.Reader +} + +// Constructor +func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { + v := &TSimpleJSONProtocol{trans: t, + writer: bufio.NewWriter(t), + reader: bufio.NewReader(t), + } + v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL)) + v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL)) + return v +} + +// Factory +type TSimpleJSONProtocolFactory struct{} + +func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTSimpleJSONProtocol(trans) +} + +func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory { + return &TSimpleJSONProtocolFactory{} +} + +var ( + JSON_COMMA []byte + JSON_COLON []byte + JSON_LBRACE []byte + JSON_RBRACE []byte + JSON_LBRACKET []byte + JSON_RBRACKET []byte + JSON_QUOTE byte + JSON_QUOTE_BYTES []byte + JSON_NULL []byte + JSON_TRUE []byte + JSON_FALSE []byte + JSON_INFINITY string + JSON_NEGATIVE_INFINITY string + JSON_NAN string + JSON_INFINITY_BYTES []byte + JSON_NEGATIVE_INFINITY_BYTES []byte + JSON_NAN_BYTES []byte + json_nonbase_map_elem_bytes []byte +) + +func init() { + JSON_COMMA = []byte{','} + JSON_COLON = []byte{':'} + JSON_LBRACE = []byte{'{'} + JSON_RBRACE = []byte{'}'} + JSON_LBRACKET = []byte{'['} + JSON_RBRACKET = []byte{']'} + JSON_QUOTE = '"' + JSON_QUOTE_BYTES = []byte{'"'} + JSON_NULL = []byte{'n', 'u', 'l', 'l'} + JSON_TRUE = []byte{'t', 'r', 'u', 'e'} + JSON_FALSE = []byte{'f', 'a', 'l', 's', 'e'} + JSON_INFINITY = "Infinity" + JSON_NEGATIVE_INFINITY = "-Infinity" + JSON_NAN = "NaN" + JSON_INFINITY_BYTES = []byte{'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} + JSON_NEGATIVE_INFINITY_BYTES = []byte{'-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} + JSON_NAN_BYTES = []byte{'N', 'a', 'N'} + json_nonbase_map_elem_bytes = []byte{']', ',', '['} +} + +func jsonQuote(s string) string { + b, _ := json.Marshal(s) + s1 := string(b) + return s1 +} + +func jsonUnquote(s string) (string, bool) { + s1 := new(string) + err := json.Unmarshal([]byte(s), s1) + return *s1, err == nil +} + +func mismatch(expected, actual string) error { + return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual) +} + +func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteString(name); e != nil { + return e + } + if e := p.WriteByte(byte(typeId)); e != nil { + return e + } + if e := p.WriteI32(seqId); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteMessageEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteStructBegin(name string) error { + if e := p.OutputObjectBegin(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteStructEnd() error { + return p.OutputObjectEnd() +} + +func (p *TSimpleJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { + if e := p.WriteString(name); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteFieldEnd() error { + //return p.OutputListEnd() + return nil +} + +func (p *TSimpleJSONProtocol) WriteFieldStop() error { return nil } + +func (p *TSimpleJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteByte(byte(keyType)); e != nil { + return e + } + if e := p.WriteByte(byte(valueType)); e != nil { + return e + } + return p.WriteI32(int32(size)) +} + +func (p *TSimpleJSONProtocol) WriteMapEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteListBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TSimpleJSONProtocol) WriteListEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteSetBegin(elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TSimpleJSONProtocol) WriteSetEnd() error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteBool(b bool) error { + return p.OutputBool(b) +} + +func (p *TSimpleJSONProtocol) WriteByte(b byte) error { + return p.WriteI32(int32(b)) +} + +func (p *TSimpleJSONProtocol) WriteI16(v int16) error { + return p.WriteI32(int32(v)) +} + +func (p *TSimpleJSONProtocol) WriteI32(v int32) error { + return p.OutputI64(int64(v)) +} + +func (p *TSimpleJSONProtocol) WriteI64(v int64) error { + return p.OutputI64(int64(v)) +} + +func (p *TSimpleJSONProtocol) WriteDouble(v float64) error { + return p.OutputF64(v) +} + +func (p *TSimpleJSONProtocol) WriteString(v string) error { + return p.OutputString(v) +} + +func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { + // JSON library only takes in a string, + // not an arbitrary byte array, to ensure bytes are transmitted + // efficiently we must convert this into a valid JSON string + // therefore we use base64 encoding to avoid excessive escaping/quoting + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + writer := base64.NewEncoder(base64.StdEncoding, p.writer) + if _, e := writer.Write(v); e != nil { + return NewTProtocolException(e) + } + if e := writer.Close(); e != nil { + return NewTProtocolException(e) + } + if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +// Reading methods. + +func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, typeId, seqId, err + } + if name, err = p.ReadString(); err != nil { + return name, typeId, seqId, err + } + bTypeId, err := p.ReadByte() + typeId = TMessageType(bTypeId) + if err != nil { + return name, typeId, seqId, err + } + if seqId, err = p.ReadI32(); err != nil { + return name, typeId, seqId, err + } + return name, typeId, seqId, nil +} + +func (p *TSimpleJSONProtocol) ReadMessageEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadStructBegin() (name string, err error) { + _, err = p.ParseObjectStart() + return "", err +} + +func (p *TSimpleJSONProtocol) ReadStructEnd() error { + return p.ParseObjectEnd() +} + +func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { + if err := p.ParsePreValue(); err != nil { + return "", STOP, 0, err + } + b, _ := p.reader.Peek(1) + if len(b) > 0 { + switch b[0] { + case JSON_RBRACE[0]: + return "", STOP, 0, nil + case JSON_QUOTE: + p.reader.ReadByte() + name, err := p.ParseStringBody() + // simplejson is not meant to be read back into thrift + // - see http://wiki.apache.org/thrift/ThriftUsageJava + // - use JSON instead + if err != nil { + return name, STOP, 0, err + } + return name, STOP, -1, p.ParsePostValue() + /* + if err = p.ParsePostValue(); err != nil { + return name, STOP, 0, err + } + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, STOP, 0, err + } + bType, err := p.ReadByte() + thetype := TType(bType) + if err != nil { + return name, thetype, 0, err + } + id, err := p.ReadI16() + return name, thetype, id, err + */ + } + e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b)) + return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return "", STOP, 0, NewTProtocolException(io.EOF) +} + +func (p *TSimpleJSONProtocol) ReadFieldEnd() error { + return nil + //return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, VOID, 0, e + } + + // read keyType + bKeyType, e := p.ReadByte() + keyType = TType(bKeyType) + if e != nil { + return keyType, valueType, size, e + } + + // read valueType + bValueType, e := p.ReadByte() + valueType = TType(bValueType) + if e != nil { + return keyType, valueType, size, e + } + + // read size + iSize, err := p.ReadI64() + size = int(iSize) + return keyType, valueType, size, err +} + +func (p *TSimpleJSONProtocol) ReadMapEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TSimpleJSONProtocol) ReadListEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TSimpleJSONProtocol) ReadSetEnd() error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadBool() (bool, error) { + var value bool + if err := p.ParsePreValue(); err != nil { + return value, err + } + b, _ := p.reader.Peek(len(JSON_TRUE)) + if len(b) > 0 { + switch b[0] { + case JSON_TRUE[0]: + if string(b) == string(JSON_TRUE) { + p.reader.Read(b[0:len(JSON_TRUE)]) + value = true + } else { + e := fmt.Errorf("Expected \"true\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + break + case JSON_FALSE[0]: + if string(b) == string(JSON_FALSE[:len(b)]) { + p.reader.Read(b[0:len(JSON_FALSE)]) + value = false + } else { + e := fmt.Errorf("Expected \"false\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + break + case JSON_NULL[0]: + if string(b) == string(JSON_NULL) { + p.reader.Read(b[0:len(JSON_NULL)]) + value = false + } else { + e := fmt.Errorf("Expected \"null\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + default: + e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + return value, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadByte() (byte, error) { + v, err := p.ReadI64() + return byte(v), err +} + +func (p *TSimpleJSONProtocol) ReadI16() (int16, error) { + v, err := p.ReadI64() + return int16(v), err +} + +func (p *TSimpleJSONProtocol) ReadI32() (int32, error) { + v, err := p.ReadI64() + return int32(v), err +} + +func (p *TSimpleJSONProtocol) ReadI64() (int64, error) { + v, _, err := p.ParseI64() + return v, err +} + +func (p *TSimpleJSONProtocol) ReadDouble() (float64, error) { + v, _, err := p.ParseF64() + return v, err +} + +func (p *TSimpleJSONProtocol) ReadString() (string, error) { + var v string + if err := p.ParsePreValue(); err != nil { + return v, err + } + var b []byte + b, _ = p.reader.Peek(len(JSON_NULL)) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseStringBody() + v = value + if err != nil { + return v, err + } + } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) { + _, err := p.reader.Read(b[0:len(JSON_NULL)]) + if err != nil { + return v, NewTProtocolException(err) + } + } else { + e := fmt.Errorf("Expected a JSON string, found %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) { + var v []byte + if err := p.ParsePreValue(); err != nil { + return nil, err + } + b, _ := p.reader.Peek(len(JSON_NULL)) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseBase64EncodedBody() + v = value + if err != nil { + return v, err + } + } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) { + _, err := p.reader.Read(b[0:len(JSON_NULL)]) + if err != nil { + return v, NewTProtocolException(err) + } + } else { + e := fmt.Errorf("Expected a JSON string, found %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) Flush() (err error) { + return NewTProtocolException(p.writer.Flush()) +} + +func (p *TSimpleJSONProtocol) Skip(fieldType TType) (err error) { + return SkipDefaultDepth(p, fieldType) +} + +func (p *TSimpleJSONProtocol) Transport() TTransport { + return p.trans +} + +func (p *TSimpleJSONProtocol) OutputPreValue() error { + cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) + switch cxt { + case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY: + if _, e := p.writer.Write(JSON_COMMA); e != nil { + return NewTProtocolException(e) + } + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + if _, e := p.writer.Write(JSON_COLON); e != nil { + return NewTProtocolException(e) + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputPostValue() error { + cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) + switch cxt { + case _CONTEXT_IN_LIST_FIRST: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST)) + break + case _CONTEXT_IN_OBJECT_FIRST: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) + break + case _CONTEXT_IN_OBJECT_NEXT_KEY: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY)) + break + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputBool(value bool) error { + if e := p.OutputPreValue(); e != nil { + return e + } + var v string + if value { + v = string(JSON_TRUE) + } else { + v = string(JSON_FALSE) + } + switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = jsonQuote(v) + default: + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputNull() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.writer.Write(JSON_NULL); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputF64(value float64) error { + if e := p.OutputPreValue(); e != nil { + return e + } + var v string + if math.IsNaN(value) { + v = string(JSON_QUOTE) + JSON_NAN + string(JSON_QUOTE) + } else if math.IsInf(value, 1) { + v = string(JSON_QUOTE) + JSON_INFINITY + string(JSON_QUOTE) + } else if math.IsInf(value, -1) { + v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE) + } else { + v = strconv.FormatFloat(value, 'g', -1, 64) + switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = string(JSON_QUOTE) + v + string(JSON_QUOTE) + default: + } + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputI64(value int64) error { + if e := p.OutputPreValue(); e != nil { + return e + } + v := strconv.FormatInt(value, 10) + switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = jsonQuote(v) + default: + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputString(s string) error { + if e := p.OutputPreValue(); e != nil { + return e + } + if e := p.OutputStringData(jsonQuote(s)); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputStringData(s string) error { + _, e := p.writer.Write([]byte(s)) + return NewTProtocolException(e) +} + +func (p *TSimpleJSONProtocol) OutputObjectBegin() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.writer.Write(JSON_LBRACE); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST)) + return nil +} + +func (p *TSimpleJSONProtocol) OutputObjectEnd() error { + if _, e := p.writer.Write(JSON_RBRACE); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + if e := p.OutputPostValue(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputListBegin() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.writer.Write(JSON_LBRACKET); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST)) + return nil +} + +func (p *TSimpleJSONProtocol) OutputListEnd() error { + if _, e := p.writer.Write(JSON_RBRACKET); e != nil { + return NewTProtocolException(e) + } + p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] + if e := p.OutputPostValue(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteByte(byte(elemType)); e != nil { + return e + } + if e := p.WriteI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) ParsePreValue() error { + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + b, _ := p.reader.Peek(1) + switch cxt { + case _CONTEXT_IN_LIST: + if len(b) > 0 { + switch b[0] { + case JSON_RBRACKET[0]: + return nil + case JSON_COMMA[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \"]\" or \",\" in list context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + break + case _CONTEXT_IN_OBJECT_NEXT_KEY: + if len(b) > 0 { + switch b[0] { + case JSON_RBRACE[0]: + return nil + case JSON_COMMA[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \"}\" or \",\" in object context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + if len(b) > 0 { + switch b[0] { + case JSON_COLON[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \":\" in object context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) ParsePostValue() error { + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + switch cxt { + case _CONTEXT_IN_LIST_FIRST: + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST)) + break + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) + break + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY)) + break + } + return nil +} + +func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error { + for { + b, _ := p.reader.Peek(1) + if len(b) < 1 { + return nil + } + switch b[0] { + case ' ', '\r', '\n', '\t': + p.reader.ReadByte() + continue + default: + break + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) ParseStringBody() (string, error) { + line, err := p.reader.ReadString(JSON_QUOTE) + if err != nil { + return "", NewTProtocolException(err) + } + l := len(line) + // count number of escapes to see if we need to keep going + i := 1 + for ; i < l; i++ { + if line[l-i-1] != '\\' { + break + } + } + if i&0x01 == 1 { + v, ok := jsonUnquote(string(JSON_QUOTE) + line) + if !ok { + return "", NewTProtocolException(err) + } + return v, nil + } + s, err := p.ParseQuotedStringBody() + if err != nil { + return "", NewTProtocolException(err) + } + str := string(JSON_QUOTE) + line + s + v, ok := jsonUnquote(str) + if !ok { + e := fmt.Errorf("Unable to parse as JSON string %s", str) + return "", NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, nil +} + +func (p *TSimpleJSONProtocol) ParseQuotedStringBody() (string, error) { + line, err := p.reader.ReadString(JSON_QUOTE) + if err != nil { + return "", NewTProtocolException(err) + } + l := len(line) + // count number of escapes to see if we need to keep going + i := 1 + for ; i < l; i++ { + if line[l-i-1] != '\\' { + break + } + } + if i&0x01 == 1 { + return line, nil + } + s, err := p.ParseQuotedStringBody() + if err != nil { + return "", NewTProtocolException(err) + } + v := line + s + return v, nil +} + +func (p *TSimpleJSONProtocol) ParseBase64EncodedBody() ([]byte, error) { + line, err := p.reader.ReadBytes(JSON_QUOTE) + if err != nil { + return line, NewTProtocolException(err) + } + line2 := line[0 : len(line)-1] + l := len(line2) + output := make([]byte, base64.StdEncoding.DecodedLen(l)) + n, err := base64.StdEncoding.Decode(output, line2) + return output[0:n], NewTProtocolException(err) +} + +func (p *TSimpleJSONProtocol) ParseI64() (int64, bool, error) { + if err := p.ParsePreValue(); err != nil { + return 0, false, err + } + var value int64 + var isnull bool + b, _ := p.reader.Peek(len(JSON_NULL)) + if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) { + p.reader.Read(b[0:len(JSON_NULL)]) + isnull = true + } else { + num, err := p.readNumeric() + isnull = (num == nil) + if !isnull { + value = num.Int64() + } + if err != nil { + return value, isnull, err + } + } + return value, isnull, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseF64() (float64, bool, error) { + if err := p.ParsePreValue(); err != nil { + return 0, false, err + } + var value float64 + var isnull bool + b, _ := p.reader.Peek(len(JSON_NULL)) + if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) { + p.reader.Read(b[0:len(JSON_NULL)]) + isnull = true + } else { + num, err := p.readNumeric() + isnull = (num == nil) + if !isnull { + value = num.Float64() + } + if err != nil { + return value, isnull, err + } + } + return value, isnull, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) { + if err := p.ParsePreValue(); err != nil { + return false, err + } + var b []byte + b, _ = p.reader.Peek(len(JSON_NULL)) + if len(b) > 0 && b[0] == JSON_LBRACE[0] { + p.reader.ReadByte() + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST)) + return false, nil + } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) { + return true, nil + } + e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b)) + return false, NewTProtocolExceptionWithType(INVALID_DATA, e) +} + +func (p *TSimpleJSONProtocol) ParseObjectEnd() error { + if isNull, err := p.readIfNull(); isNull || err != nil { + return err + } + cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) + if cxt != _CONTEXT_IN_OBJECT_FIRST && cxt != _CONTEXT_IN_OBJECT_NEXT_KEY { + e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context") + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + line, err := p.reader.ReadString(JSON_RBRACE[0]) + if err != nil { + return NewTProtocolException(err) + } + for _, char := range line { + switch char { + default: + e := fmt.Errorf("Expecting end of object \"}\", but found: \"%s\"", line) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + case ' ', '\n', '\r', '\t', '}': + break + } + } + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + return p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) { + if e := p.ParsePreValue(); e != nil { + return false, e + } + var b []byte + b, err = p.reader.Peek(len(JSON_NULL)) + if err != nil { + return false, err + } + if len(b) >= 1 && b[0] == JSON_LBRACKET[0] { + p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST)) + p.reader.ReadByte() + isNull = false + } else if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) { + isNull = true + } else { + err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b) + } + return isNull, NewTProtocolExceptionWithType(INVALID_DATA, err) +} + +func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + bElemType, err := p.ReadByte() + elemType = TType(bElemType) + if err != nil { + return elemType, size, err + } + nSize, err2 := p.ReadI64() + size = int(nSize) + return elemType, size, err2 +} + +func (p *TSimpleJSONProtocol) ParseListEnd() error { + if isNull, err := p.readIfNull(); isNull || err != nil { + return err + } + if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) != _CONTEXT_IN_LIST { + e := fmt.Errorf("Expected to be in the List Context, but not in List Context") + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + line, err := p.reader.ReadString(JSON_RBRACKET[0]) + if err != nil { + return NewTProtocolException(err) + } + for _, char := range line { + switch char { + default: + e := fmt.Errorf("Expecting end of list \"]\", but found: \"", line, "\"") + return NewTProtocolExceptionWithType(INVALID_DATA, e) + case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]): + break + } + } + p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] + return p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) readSingleValue() (interface{}, TType, error) { + e := p.readNonSignificantWhitespace() + if e != nil { + return nil, VOID, NewTProtocolException(e) + } + b, e := p.reader.Peek(10) + if len(b) > 0 { + c := b[0] + switch c { + case JSON_NULL[0]: + buf := make([]byte, len(JSON_NULL)) + _, e := p.reader.Read(buf) + if e != nil { + return nil, VOID, NewTProtocolException(e) + } + if string(JSON_NULL) != string(buf) { + e = mismatch(string(JSON_NULL), string(buf)) + return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return nil, VOID, nil + case JSON_QUOTE: + p.reader.ReadByte() + v, e := p.ParseStringBody() + if e != nil { + return v, UTF8, NewTProtocolException(e) + } + if v == JSON_INFINITY { + return INFINITY, DOUBLE, nil + } else if v == JSON_NEGATIVE_INFINITY { + return NEGATIVE_INFINITY, DOUBLE, nil + } else if v == JSON_NAN { + return NAN, DOUBLE, nil + } + return v, UTF8, nil + case JSON_TRUE[0]: + buf := make([]byte, len(JSON_TRUE)) + _, e := p.reader.Read(buf) + if e != nil { + return true, BOOL, NewTProtocolException(e) + } + if string(JSON_TRUE) != string(buf) { + e := mismatch(string(JSON_TRUE), string(buf)) + return true, BOOL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return true, BOOL, nil + case JSON_FALSE[0]: + buf := make([]byte, len(JSON_FALSE)) + _, e := p.reader.Read(buf) + if e != nil { + return false, BOOL, NewTProtocolException(e) + } + if string(JSON_FALSE) != string(buf) { + e := mismatch(string(JSON_FALSE), string(buf)) + return false, BOOL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return false, BOOL, nil + case JSON_LBRACKET[0]: + _, e := p.reader.ReadByte() + return make([]interface{}, 0), LIST, NewTProtocolException(e) + case JSON_LBRACE[0]: + _, e := p.reader.ReadByte() + return make(map[string]interface{}), STRUCT, NewTProtocolException(e) + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'e', 'E', '.', '+', '-', JSON_INFINITY[0], JSON_NAN[0]: + // assume numeric + v, e := p.readNumeric() + return v, DOUBLE, e + default: + e := fmt.Errorf("Expected element in list but found '%s' while parsing JSON.", string(c)) + return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + e = fmt.Errorf("Cannot read a single element while parsing JSON.") + return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) + +} + +func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { + cont := true + for cont { + b, _ := p.reader.Peek(1) + if len(b) < 1 { + return false, nil + } + switch b[0] { + default: + return false, nil + case JSON_NULL[0]: + cont = false + break + case ' ', '\n', '\r', '\t': + p.reader.ReadByte() + break + } + } + b, _ := p.reader.Peek(len(JSON_NULL)) + if string(b) == string(JSON_NULL) { + p.reader.Read(b[0:len(JSON_NULL)]) + return true, nil + } + return false, nil +} + +func (p *TSimpleJSONProtocol) readQuoteIfNext() { + b, _ := p.reader.Peek(1) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + } +} + +func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) { + isNull, err := p.readIfNull() + if isNull || err != nil { + return NUMERIC_NULL, err + } + hasDecimalPoint := false + nextCanBeSign := true + hasE := false + MAX_LEN := 40 + buf := bytes.NewBuffer(make([]byte, 0, MAX_LEN)) + continueFor := true + inQuotes := false + for continueFor { + c, err := p.reader.ReadByte() + if err != nil { + if err == io.EOF { + break + } + return NUMERIC_NULL, NewTProtocolException(err) + } + switch c { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + buf.WriteByte(c) + nextCanBeSign = false + case '.': + if hasDecimalPoint { + e := fmt.Errorf("Unable to parse number with multiple decimal points '%s.'", buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if hasE { + e := fmt.Errorf("Unable to parse number with decimal points in the exponent '%s.'", buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + hasDecimalPoint, nextCanBeSign = true, false + case 'e', 'E': + if hasE { + e := fmt.Errorf("Unable to parse number with multiple exponents '%s%c'", buf.String(), c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + hasE, nextCanBeSign = true, true + case '-', '+': + if !nextCanBeSign { + e := fmt.Errorf("Negative sign within number") + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + nextCanBeSign = false + case ' ', 0, '\t', '\n', '\r', JSON_RBRACE[0], JSON_RBRACKET[0], JSON_COMMA[0], JSON_COLON[0]: + p.reader.UnreadByte() + continueFor = false + case JSON_NAN[0]: + if buf.Len() == 0 { + buffer := make([]byte, len(JSON_NAN)) + buffer[0] = c + _, e := p.reader.Read(buffer[1:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_NAN != string(buffer) { + e := mismatch(JSON_NAN, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return NAN, nil + } else { + e := fmt.Errorf("Unable to parse number starting with character '%c'", c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_INFINITY[0]: + if buf.Len() == 0 || (buf.Len() == 1 && buf.Bytes()[0] == '+') { + buffer := make([]byte, len(JSON_INFINITY)) + buffer[0] = c + _, e := p.reader.Read(buffer[1:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_INFINITY != string(buffer) { + e := mismatch(JSON_INFINITY, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return INFINITY, nil + } else if buf.Len() == 1 && buf.Bytes()[0] == JSON_NEGATIVE_INFINITY[0] { + buffer := make([]byte, len(JSON_NEGATIVE_INFINITY)) + buffer[0] = JSON_NEGATIVE_INFINITY[0] + buffer[1] = c + _, e := p.reader.Read(buffer[2:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_NEGATIVE_INFINITY != string(buffer) { + e := mismatch(JSON_NEGATIVE_INFINITY, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return NEGATIVE_INFINITY, nil + } else { + e := fmt.Errorf("Unable to parse number starting with character '%c' due to existing buffer %s", c, buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_QUOTE: + if !inQuotes { + inQuotes = true + } else { + break + } + default: + e := fmt.Errorf("Unable to parse number starting with character '%c'", c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + if buf.Len() == 0 { + e := fmt.Errorf("Unable to parse number from empty string ''") + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return NewNumericFromJSONString(buf.String(), false), nil +} diff --git a/thrift/thrift/simple_json_protocol_test.go b/thrift/thrift/simple_json_protocol_test.go new file mode 100644 index 0000000..87a5c64 --- /dev/null +++ b/thrift/thrift/simple_json_protocol_test.go @@ -0,0 +1,632 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" + "testing" +) + +func TestWriteSimpleJSONProtocolBool(t *testing.T) { + thetype := "boolean" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range BOOL_VALUES { + if e := p.WriteBool(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := false + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolBool(t *testing.T) { + thetype := "boolean" + for _, value := range BOOL_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + if value { + trans.Write(JSON_TRUE) + } else { + trans.Write(JSON_FALSE) + } + trans.Flush() + s := trans.String() + v, e := p.ReadBool() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolByte(t *testing.T) { + thetype := "byte" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range BYTE_VALUES { + if e := p.WriteByte(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := byte(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolByte(t *testing.T) { + thetype := "byte" + for _, value := range BYTE_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush() + s := trans.String() + v, e := p.ReadByte() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolI16(t *testing.T) { + thetype := "int16" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range INT16_VALUES { + if e := p.WriteI16(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int16(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolI16(t *testing.T) { + thetype := "int16" + for _, value := range INT16_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush() + s := trans.String() + v, e := p.ReadI16() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolI32(t *testing.T) { + thetype := "int32" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range INT32_VALUES { + if e := p.WriteI32(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int32(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolI32(t *testing.T) { + thetype := "int32" + for _, value := range INT32_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.Itoa(int(value))) + trans.Flush() + s := trans.String() + v, e := p.ReadI32() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolI64(t *testing.T) { + thetype := "int64" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range INT64_VALUES { + if e := p.WriteI64(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := int64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolI64(t *testing.T) { + thetype := "int64" + for _, value := range INT64_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(strconv.FormatInt(value, 10)) + trans.Flush() + s := trans.String() + v, e := p.ReadI64() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolDouble(t *testing.T) { + thetype := "double" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if math.IsInf(value, 1) { + if s != jsonQuote(JSON_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, -1) { + if s != jsonQuote(JSON_NEGATIVE_INFINITY) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if s != jsonQuote(JSON_NAN) { + t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN)) + } + } else { + if s != fmt.Sprint(value) { + t.Fatalf("Bad value for %s %v: %s", thetype, value, s) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolDouble(t *testing.T) { + thetype := "double" + for _, value := range DOUBLE_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + n := NewNumericFromDouble(value) + trans.WriteString(n.String()) + trans.Flush() + s := trans.String() + v, e := p.ReadDouble() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if math.IsInf(value, 1) { + if !math.IsInf(v, 1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsInf(value, -1) { + if !math.IsInf(v, -1) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else if math.IsNaN(value) { + if !math.IsNaN(v) { + t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) + } + } else { + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolString(t *testing.T) { + thetype := "string" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + for _, value := range STRING_VALUES { + if e := p.WriteString(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s[0] != '"' || s[len(s)-1] != '"' { + t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\"")) + } + v := new(string) + if err := json.Unmarshal([]byte(s), v); err != nil || *v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v) + } + trans.Reset() + } + trans.Close() +} + +func TestReadSimpleJSONProtocolString(t *testing.T) { + thetype := "string" + for _, value := range STRING_VALUES { + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(jsonQuote(value)) + trans.Flush() + s := trans.String() + v, e := p.ReadString() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if v != value { + t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() + } +} + +func TestWriteSimpleJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + if e := p.WriteBinary(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) + } + s := trans.String() + if s != fmt.Sprint("\"", b64String, "\"") { + t.Fatalf("Bad value for %s %v\n wrote: %v\nexpected: %v", thetype, value, s, "\""+b64String+"\"") + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Close() +} + +func TestReadSimpleJSONProtocolBinary(t *testing.T) { + thetype := "binary" + value := protocol_bdata + b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) + base64.StdEncoding.Encode(b64value, value) + b64String := string(b64value) + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + trans.WriteString(jsonQuote(b64String)) + trans.Flush() + s := trans.String() + v, e := p.ReadBinary() + if e != nil { + t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) + } + if len(v) != len(value) { + t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) + } + for i := 0; i < len(v); i++ { + if v[i] != value[i] { + t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) + } + } + v1 := new(string) + if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) + } + trans.Reset() + trans.Close() +} + +func TestWriteSimpleJSONProtocolList(t *testing.T) { + thetype := "list" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteListEnd() + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("List must be at least of length two to include metadata") + } + if int(l[0].(float64)) != DOUBLE { + t.Fatal("Invalid type for list, expected: ", DOUBLE, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteSimpleJSONProtocolSet(t *testing.T) { + thetype := "set" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + for _, value := range DOUBLE_VALUES { + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteSetEnd() + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + str1 := new([]interface{}) + err := json.Unmarshal([]byte(str), str1) + if err != nil { + t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) + } + l := *str1 + if len(l) < 2 { + t.Fatalf("Set must be at least of length two to include metadata") + } + if int(l[0].(float64)) != DOUBLE { + t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0]) + } + if int(l[1].(float64)) != len(DOUBLE_VALUES) { + t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) + } + for k, value := range DOUBLE_VALUES { + s := l[k+2] + if math.IsInf(value, 1) { + if s.(string) != JSON_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) + } + } else if math.IsInf(value, 0) { + if s.(string) != JSON_NEGATIVE_INFINITY { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) + } + } else if math.IsNaN(value) { + if s.(string) != JSON_NAN { + t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) + } + } else { + if s.(float64) != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) + } + } + trans.Reset() + } + trans.Close() +} + +func TestWriteSimpleJSONProtocolMap(t *testing.T) { + thetype := "map" + trans := NewTMemoryBuffer() + p := NewTSimpleJSONProtocol(trans) + p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) + for k, value := range DOUBLE_VALUES { + if e := p.WriteI32(int32(k)); e != nil { + t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) + } + if e := p.WriteDouble(value); e != nil { + t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) + } + } + p.WriteMapEnd() + if e := p.Flush(); e != nil { + t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) + } + str := trans.String() + if str[0] != '[' || str[len(str)-1] != ']' { + t.Fatalf("Bad value for %s, wrote: %q, in go: %q", thetype, str, DOUBLE_VALUES) + } + l := strings.Split(str[1:len(str)-1], ",") + if len(l) < 3 { + t.Fatal("Expected list of at least length 3 for map for metadata, but was of length ", len(l)) + } + expectedKeyType, _ := strconv.Atoi(l[0]) + expectedValueType, _ := strconv.Atoi(l[1]) + expectedSize, _ := strconv.Atoi(l[2]) + if expectedKeyType != I32 { + t.Fatal("Expected map key type ", I32, ", but was ", l[0]) + } + if expectedValueType != DOUBLE { + t.Fatal("Expected map value type ", DOUBLE, ", but was ", l[1]) + } + if expectedSize != len(DOUBLE_VALUES) { + t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", l[2]) + } + for k, value := range DOUBLE_VALUES { + strk := l[k*2+3] + strv := l[k*2+4] + ik, err := strconv.Atoi(strk) + if err != nil { + t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, strk, string(k), err.Error()) + } + if ik != k { + t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v", thetype, k, strk, k) + } + s := strv + if math.IsInf(value, 1) { + if s != jsonQuote(JSON_INFINITY) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY)) + } + } else if math.IsInf(value, 0) { + if s != jsonQuote(JSON_NEGATIVE_INFINITY) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) + } + } else if math.IsNaN(value) { + if s != jsonQuote(JSON_NAN) { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN)) + } + } else { + expected := strconv.FormatFloat(value, 'g', 10, 64) + if s != expected { + t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected) + } + v := float64(0) + if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { + t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) + } + } + trans.Reset() + } + trans.Close() +} diff --git a/thrift/thrift/simple_server.go b/thrift/thrift/simple_server.go new file mode 100644 index 0000000..9a27215 --- /dev/null +++ b/thrift/thrift/simple_server.go @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "log" + "runtime/debug" +) + +// Simple, non-concurrent server for testing. +type TSimpleServer struct { + quit chan struct{} + + processorFactory TProcessorFactory + serverTransport TServerTransport + inputTransportFactory TTransportFactory + outputTransportFactory TTransportFactory + inputProtocolFactory TProtocolFactory + outputProtocolFactory TProtocolFactory +} + +func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer { + return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport) +} + +func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory4(NewTProcessorFactory(processor), + serverTransport, + transportFactory, + protocolFactory, + ) +} + +func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory6(NewTProcessorFactory(processor), + serverTransport, + inputTransportFactory, + outputTransportFactory, + inputProtocolFactory, + outputProtocolFactory, + ) +} + +func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer { + return NewTSimpleServerFactory6(processorFactory, + serverTransport, + NewTTransportFactory(), + NewTTransportFactory(), + NewTBinaryProtocolFactoryDefault(), + NewTBinaryProtocolFactoryDefault(), + ) +} + +func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory6(processorFactory, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + ) +} + +func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { + return &TSimpleServer{ + processorFactory: processorFactory, + serverTransport: serverTransport, + inputTransportFactory: inputTransportFactory, + outputTransportFactory: outputTransportFactory, + inputProtocolFactory: inputProtocolFactory, + outputProtocolFactory: outputProtocolFactory, + quit: make(chan struct{}, 1), + } +} + +func (p *TSimpleServer) ProcessorFactory() TProcessorFactory { + return p.processorFactory +} + +func (p *TSimpleServer) ServerTransport() TServerTransport { + return p.serverTransport +} + +func (p *TSimpleServer) InputTransportFactory() TTransportFactory { + return p.inputTransportFactory +} + +func (p *TSimpleServer) OutputTransportFactory() TTransportFactory { + return p.outputTransportFactory +} + +func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory { + return p.inputProtocolFactory +} + +func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory { + return p.outputProtocolFactory +} + +func (p *TSimpleServer) Listen() error { + return p.serverTransport.Listen() +} + +func (p *TSimpleServer) AcceptLoop() error { + for { + select { + case <-p.quit: + return nil + default: + } + + client, err := p.serverTransport.Accept() + if err != nil { + log.Println("Accept err: ", err) + } + if client != nil { + go func() { + if err := p.processRequests(client); err != nil { + log.Println("error processing request:", err) + } + }() + } + } +} + +func (p *TSimpleServer) Serve() error { + err := p.Listen() + if err != nil { + return err + } + p.AcceptLoop() + return nil +} + +func (p *TSimpleServer) Stop() error { + p.quit <- struct{}{} + p.serverTransport.Interrupt() + return nil +} + +func (p *TSimpleServer) processRequests(client TTransport) error { + processor := p.processorFactory.GetProcessor(client) + inputTransport := p.inputTransportFactory.GetTransport(client) + outputTransport := p.outputTransportFactory.GetTransport(client) + inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) + outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport) + defer func() { + if e := recover(); e != nil { + log.Printf("panic in processor: %s: %s", e, debug.Stack()) + } + }() + if inputTransport != nil { + defer inputTransport.Close() + } + if outputTransport != nil { + defer outputTransport.Close() + } + for { + ok, err := processor.Process(inputProtocol, outputProtocol) + if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE { + return nil + } else if err != nil { + log.Printf("error processing request: %s", err) + return err + } + if !ok { + break + } + } + return nil +} diff --git a/thrift/thrift/socket.go b/thrift/thrift/socket.go new file mode 100644 index 0000000..a381ea2 --- /dev/null +++ b/thrift/thrift/socket.go @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "net" + "time" +) + +type TSocket struct { + conn net.Conn + addr net.Addr + timeout time.Duration +} + +// NewTSocket creates a net.Conn-backed TTransport, given a host and port +// +// Example: +// trans, err := thrift.NewTSocket("localhost:9090") +func NewTSocket(hostPort string) (*TSocket, error) { + return NewTSocketTimeout(hostPort, 0) +} + +// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port +// it also accepts a timeout as a time.Duration +func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, error) { + //conn, err := net.DialTimeout(network, address, timeout) + addr, err := net.ResolveTCPAddr("tcp", hostPort) + if err != nil { + return nil, err + } + return NewTSocketFromAddrTimeout(addr, timeout), nil +} + +// Creates a TSocket from a net.Addr +func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket { + return &TSocket{addr: addr, timeout: timeout} +} + +// Creates a TSocket from an existing net.Conn +func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket { + return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout} +} + +// Sets the socket timeout +func (p *TSocket) SetTimeout(timeout time.Duration) error { + p.timeout = timeout + return nil +} + +func (p *TSocket) pushDeadline(read, write bool) { + var t time.Time + if p.timeout > 0 { + t = time.Now().Add(time.Duration(p.timeout)) + } + if read && write { + p.conn.SetDeadline(t) + } else if read { + p.conn.SetReadDeadline(t) + } else if write { + p.conn.SetWriteDeadline(t) + } +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSocket) Open() error { + if p.IsOpen() { + return NewTTransportException(ALREADY_OPEN, "Socket already connected.") + } + if p.addr == nil { + return NewTTransportException(NOT_OPEN, "Cannot open nil address.") + } + if len(p.addr.Network()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") + } + if len(p.addr.String()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad address.") + } + var err error + if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.timeout); err != nil { + return NewTTransportException(NOT_OPEN, err.Error()) + } + return nil +} + +// Retreive the underlying net.Conn +func (p *TSocket) Conn() net.Conn { + return p.conn +} + +// Returns true if the connection is open +func (p *TSocket) IsOpen() bool { + if p.conn == nil { + return false + } + return true +} + +// Closes the socket. +func (p *TSocket) Close() error { + // Close the socket + if p.conn != nil { + err := p.conn.Close() + if err != nil { + return err + } + p.conn = nil + } + return nil +} + +func (p *TSocket) Read(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(true, false) + n, err := p.conn.Read(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TSocket) Write(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(false, true) + return p.conn.Write(buf) +} + +func (p *TSocket) Peek() bool { + return p.IsOpen() +} + +func (p *TSocket) Flush() error { + return nil +} + +func (p *TSocket) Interrupt() error { + if !p.IsOpen() { + return nil + } + return p.conn.Close() +} diff --git a/thrift/thrift/ssl_server_socket.go b/thrift/thrift/ssl_server_socket.go new file mode 100644 index 0000000..0615528 --- /dev/null +++ b/thrift/thrift/ssl_server_socket.go @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "crypto/tls" + "net" + "time" +) + +type TSSLServerSocket struct { + listener net.Listener + addr net.Addr + clientTimeout time.Duration + interrupted bool + cfg *tls.Config +} + +func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket, error) { + return NewTSSLServerSocketTimeout(listenAddr, cfg, 0) +} + +func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) { + addr, err := net.ResolveTCPAddr("tcp", listenAddr) + if err != nil { + return nil, err + } + return &TSSLServerSocket{addr: addr, clientTimeout: clientTimeout, cfg: cfg}, nil +} + +func (p *TSSLServerSocket) Listen() error { + if p.IsListening() { + return nil + } + l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg) + if err != nil { + return err + } + p.listener = l + return nil +} + +func (p *TSSLServerSocket) Accept() (TTransport, error) { + if p.interrupted { + return nil, errTransportInterrupted + } + if p.listener == nil { + return nil, NewTTransportException(NOT_OPEN, "No underlying server socket") + } + conn, err := p.listener.Accept() + if err != nil { + return nil, NewTTransportExceptionFromError(err) + } + return NewTSSLSocketFromConnTimeout(conn, p.cfg, p.clientTimeout), nil +} + +// Checks whether the socket is listening. +func (p *TSSLServerSocket) IsListening() bool { + return p.listener != nil +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSSLServerSocket) Open() error { + if p.IsListening() { + return NewTTransportException(ALREADY_OPEN, "Server socket already open") + } + if l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg); err != nil { + return err + } else { + p.listener = l + } + return nil +} + +func (p *TSSLServerSocket) Addr() net.Addr { + return p.addr +} + +func (p *TSSLServerSocket) Close() error { + defer func() { + p.listener = nil + }() + if p.IsListening() { + return p.listener.Close() + } + return nil +} + +func (p *TSSLServerSocket) Interrupt() error { + p.interrupted = true + return nil +} diff --git a/thrift/thrift/ssl_socket.go b/thrift/thrift/ssl_socket.go new file mode 100644 index 0000000..f831e18 --- /dev/null +++ b/thrift/thrift/ssl_socket.go @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "crypto/tls" + "net" + "time" +) + +type TSSLSocket struct { + conn net.Conn + addr net.Addr + timeout time.Duration + cfg *tls.Config +} + +// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration +// +// Example: +// trans, err := thrift.NewTSocket("localhost:9090") +func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) { + return NewTSSLSocketTimeout(hostPort, cfg, 0) +} + +// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port +// it also accepts a tls Configuration and a timeout as a time.Duration +func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) { + //conn, err := net.DialTimeout(network, address, timeout) + addr, err := net.ResolveTCPAddr("tcp", hostPort) + if err != nil { + return nil, err + } + return NewTSSLSocketFromAddrTimeout(addr, cfg, timeout), nil +} + +// Creates a TSSLSocket from a net.Addr +func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket { + return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg} +} + +// Creates a TSSLSocket from an existing net.Conn +func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket { + return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg} +} + +// Sets the socket timeout +func (p *TSSLSocket) SetTimeout(timeout time.Duration) error { + p.timeout = timeout + return nil +} + +func (p *TSSLSocket) pushDeadline(read, write bool) { + var t time.Time + if p.timeout > 0 { + t = time.Now().Add(time.Duration(p.timeout)) + } + if read && write { + p.conn.SetDeadline(t) + } else if read { + p.conn.SetReadDeadline(t) + } else if write { + p.conn.SetWriteDeadline(t) + } +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSSLSocket) Open() error { + if p.IsOpen() { + return NewTTransportException(ALREADY_OPEN, "Socket already connected.") + } + if p.addr == nil { + return NewTTransportException(NOT_OPEN, "Cannot open nil address.") + } + if len(p.addr.Network()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") + } + if len(p.addr.String()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad address.") + } + var err error + if p.conn, err = tls.Dial(p.addr.Network(), p.addr.String(), p.cfg); err != nil { + return NewTTransportException(NOT_OPEN, err.Error()) + } + return nil +} + +// Retreive the underlying net.Conn +func (p *TSSLSocket) Conn() net.Conn { + return p.conn +} + +// Returns true if the connection is open +func (p *TSSLSocket) IsOpen() bool { + if p.conn == nil { + return false + } + return true +} + +// Closes the socket. +func (p *TSSLSocket) Close() error { + // Close the socket + if p.conn != nil { + err := p.conn.Close() + if err != nil { + return err + } + p.conn = nil + } + return nil +} + +func (p *TSSLSocket) Read(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(true, false) + n, err := p.conn.Read(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TSSLSocket) Write(buf []byte) (int, error) { + if !p.IsOpen() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(false, true) + return p.conn.Write(buf) +} + +func (p *TSSLSocket) Peek() bool { + return p.IsOpen() +} + +func (p *TSSLSocket) Flush() error { + return nil +} + +func (p *TSSLSocket) Interrupt() error { + if !p.IsOpen() { + return nil + } + return p.conn.Close() +} diff --git a/thrift/thrift/transport.go b/thrift/thrift/transport.go new file mode 100644 index 0000000..8c0622d --- /dev/null +++ b/thrift/thrift/transport.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "errors" + "io" +) + +var errTransportInterrupted = errors.New("Transport Interrupted") + +type Flusher interface { + Flush() (err error) +} + +// Encapsulates the I/O layer +type TTransport interface { + io.ReadWriteCloser + Flusher + + // Opens the transport for communication + Open() error + + // Returns true if the transport is open + IsOpen() bool +} + +type stringWriter interface { + WriteString(s string) (n int, err error) +} + +// This is "enchanced" transport with extra capabilities. You need to use one of these +// to construct protocol. +// Notably, TSocket does not implement this interface, and it is always a mistake to use +// TSocket directly in protocol. +type TRichTransport interface { + io.ReadWriter + io.ByteReader + io.ByteWriter + stringWriter + Flusher +} diff --git a/thrift/thrift/transport_exception.go b/thrift/thrift/transport_exception.go new file mode 100644 index 0000000..9505b44 --- /dev/null +++ b/thrift/thrift/transport_exception.go @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "errors" + "io" +) + +type timeoutable interface { + Timeout() bool +} + +// Thrift Transport exception +type TTransportException interface { + TException + TypeId() int + Err() error +} + +const ( + UNKNOWN_TRANSPORT_EXCEPTION = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 +) + +type tTransportException struct { + typeId int + err error +} + +func (p *tTransportException) TypeId() int { + return p.typeId +} + +func (p *tTransportException) Error() string { + return p.err.Error() +} + +func (p *tTransportException) Err() error { + return p.err +} + +func NewTTransportException(t int, e string) TTransportException { + return &tTransportException{typeId: t, err: errors.New(e)} +} + +func NewTTransportExceptionFromError(e error) TTransportException { + if e == nil { + return nil + } + + if t, ok := e.(TTransportException); ok { + return t + } + + switch v := e.(type) { + case TTransportException: + return v + case timeoutable: + if v.Timeout() { + return &tTransportException{typeId: TIMED_OUT, err: e} + } + } + + if e == io.EOF { + return &tTransportException{typeId: END_OF_FILE, err: e} + } + + return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e} +} diff --git a/thrift/thrift/transport_exception_test.go b/thrift/thrift/transport_exception_test.go new file mode 100644 index 0000000..b44314f --- /dev/null +++ b/thrift/thrift/transport_exception_test.go @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "fmt" + "io" + + "testing" +) + +type timeout struct{ timedout bool } + +func (t *timeout) Timeout() bool { + return t.timedout +} + +func (t *timeout) Error() string { + return fmt.Sprintf("Timeout: %v", t.timedout) +} + +func TestTExceptionTimeout(t *testing.T) { + timeout := &timeout{true} + exception := NewTTransportExceptionFromError(timeout) + if timeout.Error() != exception.Error() { + t.Fatalf("Error did not match: expected %q, got %q", timeout.Error(), exception.Error()) + } + + if exception.TypeId() != TIMED_OUT { + t.Fatalf("TypeId was not TIMED_OUT: expected %v, got %v", TIMED_OUT, exception.TypeId()) + } +} + +func TestTExceptionEOF(t *testing.T) { + exception := NewTTransportExceptionFromError(io.EOF) + if io.EOF.Error() != exception.Error() { + t.Fatalf("Error did not match: expected %q, got %q", io.EOF.Error(), exception.Error()) + } + + if exception.TypeId() != END_OF_FILE { + t.Fatalf("TypeId was not END_OF_FILE: expected %v, got %v", END_OF_FILE, exception.TypeId()) + } +} diff --git a/thrift/thrift/transport_factory.go b/thrift/thrift/transport_factory.go new file mode 100644 index 0000000..533d1b4 --- /dev/null +++ b/thrift/thrift/transport_factory.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Factory class used to create wrapped instance of Transports. +// This is used primarily in servers, which get Transports from +// a ServerTransport and then may want to mutate them (i.e. create +// a BufferedTransport from the underlying base transport) +type TTransportFactory interface { + GetTransport(trans TTransport) TTransport +} + +type tTransportFactory struct{} + +// Return a wrapped instance of the base Transport. +func (p *tTransportFactory) GetTransport(trans TTransport) TTransport { + return trans +} + +func NewTTransportFactory() TTransportFactory { + return &tTransportFactory{} +} diff --git a/thrift/thrift/transport_test.go b/thrift/thrift/transport_test.go new file mode 100644 index 0000000..864958a --- /dev/null +++ b/thrift/thrift/transport_test.go @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "io" + "net" + "strconv" + "testing" +) + +const TRANSPORT_BINARY_DATA_SIZE = 4096 + +var ( + transport_bdata []byte // test data for writing; same as data + transport_header map[string]string +) + +func init() { + transport_bdata = make([]byte, TRANSPORT_BINARY_DATA_SIZE) + for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ { + transport_bdata[i] = byte((i + 'a') % 255) + } + transport_header = map[string]string{"key": "User-Agent", + "value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"} +} + +func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { + buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) + if !writeTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) + } + if !readTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", readTrans, readTrans) + } + _, err := writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush() + if err != nil { + t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) + } + n, err := io.ReadFull(readTrans, buf) + if err != nil { + t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } + _, err = writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data 2 of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush() + if err != nil { + t.Fatalf("Transport %T cannot flush write binary data 2: %s", writeTrans, err) + } + buf = make([]byte, TRANSPORT_BINARY_DATA_SIZE) + read := 1 + for n = 0; n < TRANSPORT_BINARY_DATA_SIZE && read != 0; { + read, err = readTrans.Read(buf[n:]) + if err != nil { + t.Errorf("Transport %T cannot read binary data 2 of total length %d from offset %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, n, err) + } + n += read + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data 2", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } +} + +func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { + buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) + if !writeTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) + } + if !readTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", readTrans, readTrans) + } + // Need to assert type of TTransport to THttpClient to expose the Setter + httpWPostTrans := writeTrans.(*THttpClient) + httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"]) + + _, err := writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush() + if err != nil { + t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) + } + // Need to assert type of TTransport to THttpClient to expose the Getter + httpRPostTrans := readTrans.(*THttpClient) + readHeader := httpRPostTrans.GetHeader(transport_header["key"]) + if err != nil { + t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans) + } + + if transport_header["value"] != readHeader { + t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader) + } + n, err := io.ReadFull(readTrans, buf) + if err != nil { + t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } +} + +func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) { + err := readTrans.Close() + if err != nil { + t.Errorf("Transport %T cannot close read transport: %s", readTrans, err) + } + if writeTrans != readTrans { + err = writeTrans.Close() + if err != nil { + t.Errorf("Transport %T cannot close write transport: %s", writeTrans, err) + } + } +} + +func FindAvailableTCPServerPort(startPort int) (net.Addr, error) { + for i := startPort; i < 65535; i++ { + s := "127.0.0.1:" + strconv.Itoa(i) + l, err := net.Listen("tcp", s) + if err == nil { + l.Close() + return net.ResolveTCPAddr("tcp", s) + } + } + return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port") +} + +func valueInSlice(value string, slice []string) bool { + for _, v := range slice { + if value == v { + return true + } + } + return false +} diff --git a/thrift/thrift/type.go b/thrift/thrift/type.go new file mode 100644 index 0000000..7c68c2b --- /dev/null +++ b/thrift/thrift/type.go @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +// Type constants in the Thrift protocol +type TType byte + +const ( + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + BINARY = 18 +) + +var typeNames = map[int]string{ + STOP: "STOP", + VOID: "VOID", + BOOL: "BOOL", + BYTE: "BYTE", + I16: "I16", + I32: "I32", + I64: "I64", + STRING: "STRING", + STRUCT: "STRUCT", + MAP: "MAP", + SET: "SET", + LIST: "LIST", + UTF8: "UTF8", + UTF16: "UTF16", +} + +func (p TType) String() string { + if s, ok := typeNames[int(p)]; ok { + return s + } + return "Unknown" +} diff --git a/thrift/topic/constants.go b/thrift/topic/constants.go index 72da329..034ab10 100644 --- a/thrift/topic/constants.go +++ b/thrift/topic/constants.go @@ -6,10 +6,11 @@ package topic import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/authorization" "github.com/XiaoMi/talos-sdk-golang/thrift/common" "github.com/XiaoMi/talos-sdk-golang/thrift/quota" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/topic/topic_service-remote/topic_service-remote.go b/thrift/topic/topic_service-remote/topic_service-remote.go index 7fb0cd4..5ab71df 100755 --- a/thrift/topic/topic_service-remote/topic_service-remote.go +++ b/thrift/topic/topic_service-remote/topic_service-remote.go @@ -6,7 +6,6 @@ package main import ( "flag" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "math" "net" "net/url" @@ -14,6 +13,8 @@ import ( "strconv" "strings" "thrift/topic" + + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) func Usage() { diff --git a/thrift/topic/topicservice.go b/thrift/topic/topicservice.go index 8a14301..74b5b92 100644 --- a/thrift/topic/topicservice.go +++ b/thrift/topic/topicservice.go @@ -6,10 +6,11 @@ package topic import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" + "github.com/XiaoMi/talos-sdk-golang/thrift/authorization" "github.com/XiaoMi/talos-sdk-golang/thrift/common" "github.com/XiaoMi/talos-sdk-golang/thrift/quota" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/thrift/topic/ttypes.go b/thrift/topic/ttypes.go index 40e5660..0cc3212 100644 --- a/thrift/topic/ttypes.go +++ b/thrift/topic/ttypes.go @@ -6,10 +6,10 @@ package topic import ( "bytes" "fmt" - "git.apache.org/thrift.git/lib/go/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/authorization" "github.com/XiaoMi/talos-sdk-golang/thrift/common" "github.com/XiaoMi/talos-sdk-golang/thrift/quota" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" ) // (needed to ensure safety because of naive import list construction.) diff --git a/utils/Utils.go b/utils/Utils.go index b9c2395..f93aa31 100644 --- a/utils/Utils.go +++ b/utils/Utils.go @@ -16,10 +16,10 @@ import ( "sync/atomic" "time" - "git.apache.org/thrift.git/lib/go/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/auth" "github.com/XiaoMi/talos-sdk-golang/thrift/common" "github.com/XiaoMi/talos-sdk-golang/thrift/message" + "github.com/XiaoMi/talos-sdk-golang/thrift/thrift" "github.com/XiaoMi/talos-sdk-golang/thrift/topic" "github.com/gofrs/uuid" log "github.com/sirupsen/logrus"