Skip to content

Commit

Permalink
perf: early validate all messages before encoding (#454)
Browse files Browse the repository at this point in the history
* perf: early validate all messages before encoding

* protocol validator now can validate message

* protocol validator early validate message as well

* chore: add verbose error message

* select protocol version to use before validating

* chore: update encoder test code docs

* test: test protocol version must be selected on encode

* chore: clean up code
  • Loading branch information
muktihari authored Sep 18, 2024
1 parent e3d3e16 commit c0b1138
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 84 deletions.
48 changes: 33 additions & 15 deletions encoder/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ func (e *Encoder) reset() {
//
// Encode chooses which strategy to use for encoding the data based on given writer.
func (e *Encoder) Encode(fit *proto.FIT) (err error) {
e.selectProtocolVersion(&fit.FileHeader)
if err = e.validateMessages(fit.Messages); err != nil {
return err
}
switch e.w.(type) {
case io.WriterAt, io.WriteSeeker:
err = e.encodeWithDirectUpdateStrategy(fit)
Expand All @@ -267,6 +271,31 @@ func (e *Encoder) Encode(fit *proto.FIT) (err error) {
return
}

func (e *Encoder) selectProtocolVersion(fileHeader *proto.FileHeader) {
if e.options.protocolVersion != 0 { // Override regardless the value in FileHeader.
fileHeader.ProtocolVersion = e.options.protocolVersion
} else if fileHeader.ProtocolVersion == 0 { // Default when not specified in FileHeader.
fileHeader.ProtocolVersion = proto.V1
}
e.protocolValidator.ProtocolVersion = fileHeader.ProtocolVersion
}

func (e *Encoder) validateMessages(messages []proto.Message) (err error) {
defer e.options.messageValidator.Reset()
for i := range messages {
mesg := &messages[i] // Must use pointer reference since validator may update the message.
if err = e.protocolValidator.ValidateMessage(mesg); err != nil {
return fmt.Errorf("protocol validation failed: message index: %d, num: %d (%s): %w",
i, mesg.Num, mesg.Num.String(), err)
}
if err = e.options.messageValidator.Validate(mesg); err != nil {
return fmt.Errorf("message validation failed: message index: %d, num: %d (%s): %w",
i, mesg.Num, mesg.Num.String(), err)
}
}
return nil
}

// encodeWithDirectUpdateStrategy encodes all data to file, after completing,
// it updates the actual size of the messages that being written to the proto.
func (e *Encoder) encodeWithDirectUpdateStrategy(fit *proto.FIT) error {
Expand Down Expand Up @@ -316,13 +345,6 @@ func (e *Encoder) encodeFileHeader(header *proto.FileHeader) error {
header.ProfileVersion = profile.Version
}

if e.options.protocolVersion != 0 { // Override regardless the value in FileHeader.
header.ProtocolVersion = e.options.protocolVersion
} else if header.ProtocolVersion == 0 { // Default when not specified in FileHeader.
header.ProtocolVersion = proto.V1
}
e.protocolValidator.ProtocolVersion = header.ProtocolVersion

header.DataType = proto.DataTypeFIT
header.CRC = 0 // recalculated

Expand Down Expand Up @@ -436,10 +458,6 @@ func (e *Encoder) encodeMessages(messages []proto.Message) error {
func (e *Encoder) encodeMessage(mesg *proto.Message) (err error) {
mesg.Header = proto.MesgNormalHeaderMask

if err = e.options.messageValidator.Validate(mesg); err != nil {
return fmt.Errorf("message validation failed: %w", err)
}

var compressed bool
if e.options.headerOption == HeaderOptionCompressedTimestamp {
if e.w == io.Discard {
Expand Down Expand Up @@ -467,10 +485,6 @@ func (e *Encoder) encodeMessage(mesg *proto.Message) (err error) {
}

mesgDef := e.newMessageDefinition(mesg)
if err := e.protocolValidator.ValidateMessageDefinition(mesgDef); err != nil {
return err
}

b, _ := mesgDef.MarshalAppend(e.buf[:0])
localMesgNum, isNewMesgDef := e.localMesgNumLRU.Put(b) // This might alloc memory since we need to copy the item.

Expand Down Expand Up @@ -578,6 +592,10 @@ func (e *Encoder) encodeCRC() error {

// EncodeWithContext is similar to Encode but with respect to context propagation.
func (e *Encoder) EncodeWithContext(ctx context.Context, fit *proto.FIT) (err error) {
e.selectProtocolVersion(&fit.FileHeader)
if err = e.validateMessages(fit.Messages); err != nil {
return err
}
switch e.w.(type) {
case io.WriterAt, io.WriteSeeker:
err = e.encodeWithDirectUpdateStrategyWithContext(ctx, fit)
Expand Down
162 changes: 93 additions & 69 deletions encoder/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,40 +362,115 @@ var (
)

func TestEncode(t *testing.T) {
fitOK := proto.FIT{
Messages: []proto.Message{
{Num: mesgnum.FileId, Fields: []proto.Field{
factory.CreateField(mesgnum.FileId, fieldnum.FileIdType).WithValue(typedef.FileActivity.Byte()),
}},
},
}

tt := []struct {
name string
w io.Writer
fit *proto.FIT
err error
}{
{name: "encode with nil", w: nil, err: ErrNilWriter},
{name: "encode with writer", w: fnWriteOK},
{name: "encode with writerAt", w: mockWriterAt{fnWriteOK, fnWriteAtOK}},
{name: "encode with writeSeeker", w: mockWriteSeeker{fnWriteOK, fnSeekOK}},
}

fit := proto.FIT{
Messages: []proto.Message{
{Num: mesgnum.FileId, Fields: []proto.Field{
factory.CreateField(mesgnum.FileId, fieldnum.FileIdType).WithValue(typedef.FileActivity.Byte()),
}},
{name: "encode with nil", w: nil, fit: &fitOK, err: ErrNilWriter},
{name: "encode with writer", w: fnWriteOK, fit: &fitOK},
{name: "encode with writerAt", w: mockWriterAt{fnWriteOK, fnWriteAtOK}, fit: &fitOK},
{name: "encode with writeSeeker", w: mockWriteSeeker{fnWriteOK, fnSeekOK}, fit: &fitOK},
{
name: "encode return error from validation",
fit: &proto.FIT{
Messages: []proto.Message{
{Num: mesgnum.Record, Fields: []proto.Field{
factory.CreateField(mesgnum.Record, fieldnum.RecordSpeed1S).WithValue(make([]uint8, 256)), // Exceed max allowed
}},
},
},
w: fnWriteOK,
err: ErrExceedMaxAllowed,
},
{
name: "encode return error protocol violation since proto.V1 does not allow Int64",
fit: &proto.FIT{
FileHeader: proto.FileHeader{ProtocolVersion: proto.V1},
Messages: []proto.Message{
{Num: mesgnum.Record, Fields: []proto.Field{
{FieldBase: &proto.FieldBase{BaseType: basetype.Sint64}},
}},
},
},
w: fnWriteOK,
err: proto.ErrProtocolViolation,
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
for i, tc := range tt {
if i != len(tt)-1 {
continue
}
t.Run(fmt.Sprintf("[%d] %s", i, tc.name), func(t *testing.T) {
enc := New(tc.w)
err := enc.Encode(&fit)
err := enc.Encode(tc.fit)
if !errors.Is(err, tc.err) {
t.Fatalf("expected error: %v, got: %v", tc.err, err)
}
})
}

// Test same logic for EncodeWithContext
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
for i, tc := range tt {
t.Run(fmt.Sprintf("[%d] %s", i, tc.name), func(t *testing.T) {
enc := New(tc.w)
err := enc.EncodeWithContext(context.Background(), &fit)
err := enc.EncodeWithContext(context.Background(), tc.fit)
if !errors.Is(err, tc.err) {
t.Fatalf("expected error: %v, got: %v", tc.err, err)
}
})
}
}

func TestValidateMessages(t *testing.T) {
tt := []struct {
name string
protocolVersion proto.Version
messages []proto.Message
err error
}{
{
name: "happy flow",
protocolVersion: proto.V1,
messages: []proto.Message{{Num: mesgnum.FileId, Fields: []proto.Field{
factory.CreateField(mesgnum.FileId, fieldnum.FileIdManufacturer).WithValue(typedef.ManufacturerDevelopment),
}}},
},
{
name: "protocol validation failed",
protocolVersion: proto.V1,
messages: []proto.Message{{Num: mesgnum.Record, Fields: []proto.Field{
factory.CreateField(mesgnum.Record, fieldnum.RecordSpeed).WithValue(uint16(1000)),
}, DeveloperFields: []proto.DeveloperField{{}}}},
err: proto.ErrProtocolViolation,
},
{
name: "message validation failed",
protocolVersion: proto.V1,
messages: []proto.Message{{Num: mesgnum.Record, Fields: []proto.Field{
factory.CreateField(mesgnum.Record, fieldnum.RecordSpeed1S).WithValue(make([]uint8, 256)),
}}},
err: ErrExceedMaxAllowed,
},
}

for i, tc := range tt {
t.Run(fmt.Sprintf("[%d] %s", i, tc.name), func(t *testing.T) {
enc := New(nil)
// Protocol Version is now selected by selectProtocolVersion method as we allow dynamic protocol version
// based on FileHeader. This by pass it since we don't encode file header.
enc.protocolValidator.ProtocolVersion = tc.protocolVersion
err := enc.validateMessages(tc.messages)
if !errors.Is(err, tc.err) {
t.Fatalf("expected error: %v, got: %v", tc.err, err)
}
Expand Down Expand Up @@ -868,6 +943,7 @@ func TestEncodeHeader(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
bytebuf := new(bytes.Buffer)
enc := New(bytebuf, append(tc.opts, WithWriteBufferSize(0))...)
enc.selectProtocolVersion(&tc.header)
_ = enc.encodeFileHeader(&tc.header)

if diff := cmp.Diff(bytebuf.Bytes(), tc.b); diff != "" {
Expand Down Expand Up @@ -927,49 +1003,6 @@ func TestEncodeMessage(t *testing.T) {
}},
w: fnWriteOK,
},
{
name: "message validator's validate return error",
mesg: proto.Message{},
w: nil,
err: ErrNoFields,
},
{
name: "normal header: protocol validator's validate message definition return error",
opts: []Option{
WithProtocolVersion(proto.V1),
},
mesg: proto.Message{Fields: []proto.Field{
{
FieldBase: &proto.FieldBase{
Name: factory.NameUnknown,
Type: profile.Sint64, // int64 type is ilegal for protocol v1.0
BaseType: profile.Sint64.BaseType(),
},
Value: proto.Int64(1234),
},
}},
w: nil,
err: proto.ErrProtocolViolation,
},
{
name: "compressed timestamp header: protocol validator's validate message definition return error",
opts: []Option{
WithProtocolVersion(proto.V1),
WithHeaderOption(HeaderOptionCompressedTimestamp, 0),
},
mesg: proto.Message{Fields: []proto.Field{
{
FieldBase: &proto.FieldBase{
Name: factory.NameUnknown,
Type: profile.Sint64, // int64 type is ilegal for protocol v1.0
BaseType: profile.Sint64.BaseType(),
},
Value: proto.Int64(1234),
},
}},
w: nil,
err: proto.ErrProtocolViolation,
},
{
name: "write message definition return error",
opts: []Option{
Expand Down Expand Up @@ -1014,9 +1047,6 @@ func TestEncodeMessage(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
tc.opts = append(tc.opts, WithWriteBufferSize(0))
enc := New(tc.w, tc.opts...)
// Protocol Version now is set on encodeFileHeader as we allow dynamic protocol version
// based on FileHeader. This by pass it since we don't encode file header.
enc.protocolValidator.ProtocolVersion = enc.options.protocolVersion
err := enc.encodeMessage(&tc.mesg)
if !errors.Is(err, tc.err) {
t.Fatalf("expected: %v, got: %v", tc.err, err)
Expand Down Expand Up @@ -1168,12 +1198,6 @@ func makeEncodeMessagesTableTest() []encodeMessagesTestCase {
mesgs: []proto.Message{},
err: ErrEmptyMessages,
},
{
name: "encode messages return error",
mesgValidator: fnValidateErr,
mesgs: []proto.Message{{}},
err: ErrNoFields, // Validator error since the first mesg is invalid.
},
{
name: "missing file_id mesg",
mesgs: []proto.Message{{Num: mesgnum.Record}},
Expand Down
7 changes: 7 additions & 0 deletions encoder/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ type StreamEncoder struct {
// - This method is called right after SequenceCompleted method has been called.
func (e *StreamEncoder) WriteMessage(mesg *proto.Message) error {
if !e.fileHeaderWritten {
e.enc.selectProtocolVersion(&e.fileHeader)
if err := e.enc.encodeFileHeader(&e.fileHeader); err != nil {
return fmt.Errorf("could not encode file header: %w", err)
}
e.fileHeaderWritten = true
}
if err := e.enc.protocolValidator.ValidateMessage(mesg); err != nil {
return fmt.Errorf("protocol validate message failed: %d (%s): %w", mesg.Num, mesg.Num, err)
}
if err := e.enc.options.messageValidator.Validate(mesg); err != nil {
return fmt.Errorf("message validation failed: mesgNum: %d (%s): %w", mesg.Num, mesg.Num, err)
}
if err := e.enc.encodeMessage(mesg); err != nil {
return fmt.Errorf("could not encode mesg: mesgNum: %d (%q): %w", mesg.Num, mesg.Num, err)
}
Expand Down
19 changes: 19 additions & 0 deletions encoder/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,25 @@ func TestStreamEncoderUnhappyFlow(t *testing.T) {
if !errors.Is(err, io.EOF) {
t.Fatalf("expected err: %v, got: %v", io.EOF, err)
}

// Protocol validation error
streamEnc, _ = New(mockWriterAt{}).StreamEncoder()
streamEnc.enc.protocolValidator.ProtocolVersion = proto.V1
err = streamEnc.WriteMessage(&proto.Message{Fields: []proto.Field{
factory.CreateField(mesgnum.Record, fieldnum.RecordSpeed1S).WithValue(make([]uint8, 256)),
}, DeveloperFields: []proto.DeveloperField{{}}})
if !errors.Is(err, proto.ErrProtocolViolation) {
t.Fatalf("expected err: %v, got: %v", proto.ErrProtocolViolation, err)
}

// Message validation error
streamEnc, _ = New(mockWriterAt{}).StreamEncoder()
err = streamEnc.WriteMessage(&proto.Message{Fields: []proto.Field{
factory.CreateField(mesgnum.Record, fieldnum.RecordSpeed1S).WithValue(make([]uint8, 256))}},
)
if !errors.Is(err, ErrExceedMaxAllowed) {
t.Fatalf("expected err: %v, got: %v", ErrExceedMaxAllowed, err)
}
}

func TestStreamEncoderWithoutWriteBuffer(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions encoder/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,11 @@ func TestMessageValidatorValidate(t *testing.T) {
},
},
},
{
name: "error no fields",
mesgs: []proto.Message{{Fields: []proto.Field{}}},
errs: []error{ErrNoFields},
},
}

for i, tc := range tt {
Expand Down
16 changes: 16 additions & 0 deletions proto/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ func (p *Validator) ValidateMessageDefinition(mesgDef *MessageDefinition) error
}
return nil
}

// ValidateMessage validates whether the message contains unsupported data for the targeted version.
func (p *Validator) ValidateMessage(mesg *Message) error {
if p.ProtocolVersion == V1 {
if len(mesg.DeveloperFields) > 0 {
return fmt.Errorf("protocol version 1.0 do not support developer fields: %w", ErrProtocolViolation)
}
for i := range mesg.Fields {
field := &mesg.Fields[i]
if field.BaseType&basetype.BaseTypeNumMask > basetype.Byte&basetype.BaseTypeNumMask { // byte was the last type added in 1.0
return fmt.Errorf("protocol version 1.0 do not support type '%s': %w", field.BaseType, ErrProtocolViolation)
}
}
}
return nil
}
Loading

0 comments on commit c0b1138

Please sign in to comment.