From 0e0ff03e2a0361e75d55041772c93950861c17c3 Mon Sep 17 00:00:00 2001 From: Sebastien Binet Date: Wed, 4 Jul 2018 17:13:12 +0200 Subject: [PATCH] zmq4: evolve Security interface to support PLAIN and CURVE Updates #25. Updates #26. --- conn.go | 2 +- security.go | 6 ++++-- security/null/null.go | 2 +- security/null/null_test.go | 4 ++-- security_test.go | 2 +- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 572f4c1..a998133 100644 --- a/conn.go +++ b/conn.go @@ -80,7 +80,7 @@ func (conn *Conn) init(sec Security, md map[string]string) error { return errors.Wrapf(err, "zmq4: could not exchange greetings") } - err = conn.sec.Handshake() + err = conn.sec.Handshake(conn, conn.server) if err != nil { return errors.Wrapf(err, "zmq4: could not perform security handshake") } diff --git a/security.go b/security.go index 762fe56..ddbdc57 100644 --- a/security.go +++ b/security.go @@ -10,14 +10,16 @@ import ( // Security is an interface for ZMTP security mechanisms type Security interface { + // Type returns the security mechanism type. Type() SecurityType + // Handshake implements the ZMTP security handshake according to // this security mechanism. // see: // https://rfc.zeromq.org/spec:23/ZMTP/ // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ - Handshake() error + Handshake(conn *Conn, server bool) error // Encrypt writes the encrypted form of data to w. Encrypt(w io.Writer, data []byte) (int, error) @@ -58,7 +60,7 @@ func (nullSecurity) Type() SecurityType { // https://rfc.zeromq.org/spec:23/ZMTP/ // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ -func (nullSecurity) Handshake() error { +func (nullSecurity) Handshake(conn *Conn, server bool) error { return nil } diff --git a/security/null/null.go b/security/null/null.go index 4b75f26..8b3f6c8 100644 --- a/security/null/null.go +++ b/security/null/null.go @@ -30,7 +30,7 @@ func (security) Type() zmq4.SecurityType { // https://rfc.zeromq.org/spec:23/ZMTP/ // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ -func (security) Handshake() error { +func (security) Handshake(conn *zmq4.Conn, server bool) error { return nil } diff --git a/security/null/null_test.go b/security/null/null_test.go index 66a2956..5fec5ec 100644 --- a/security/null/null_test.go +++ b/security/null/null_test.go @@ -12,13 +12,13 @@ import ( "github.com/go-zeromq/zmq4/security/null" ) -func TestNullSecurity(t *testing.T) { +func TestSecurity(t *testing.T) { sec := null.Security() if got, want := sec.Type(), zmq4.NullSecurity; got != want { t.Fatalf("got=%v, want=%v", got, want) } - err := sec.Handshake() + err := sec.Handshake(nil, false) if err != nil { t.Fatalf("error doing handshake: %v", err) } diff --git a/security_test.go b/security_test.go index 9a391d8..262d520 100644 --- a/security_test.go +++ b/security_test.go @@ -15,7 +15,7 @@ func TestNullSecurity(t *testing.T) { t.Fatalf("got=%v, want=%v", got, want) } - err := sec.Handshake() + err := sec.Handshake(nil, false) if err != nil { t.Fatalf("error doing handshake: %v", err) }