diff --git a/decode.go b/decode.go index d6e9dae..45eb695 100644 --- a/decode.go +++ b/decode.go @@ -21,7 +21,8 @@ func Unmarshal(data []byte, v interface{}) error { // A Decoder reads and decodes fixed width data from an input stream. type Decoder struct { - data *bufio.Reader + scanner *bufio.Scanner + separator string done bool useCodepointIndices bool @@ -31,9 +32,12 @@ type Decoder struct { // NewDecoder returns a new decoder that reads from r. func NewDecoder(r io.Reader) *Decoder { - return &Decoder{ - data: bufio.NewReader(r), + dec := &Decoder{ + scanner: bufio.NewScanner(r), + separator: "", } + dec.scanner.Split(dec.Scan) + return dec } // An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. @@ -178,19 +182,44 @@ func findFirstMultiByteChar(data string) int { return len(data) } -func (d *Decoder) readLine(v reflect.Value) (err error, ok bool) { - line, err := d.data.ReadString('\n') - if err != nil && err != io.EOF { - return err, false +func (d *Decoder) SetSeparator(separator string) { + d.separator = separator +} + +func (d Decoder) Separator() string { + if d.separator != "" { + return d.separator } - if err == io.EOF { - d.done = true - if len(line) <= 0 || line[0] == '\n' { - // skip last empty lines - return nil, false - } + return "\n" +} + +func (d *Decoder) Scan(data []byte, atEOF bool) (advance int, token []byte, err error) { + sep := []byte(d.Separator()) + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.Index(data, sep); i >= 0 { + // We have a full newline-terminated line. + return i + len(sep), data[0:i], nil } + // If we're at EOF, we have a final, non-terminated line. Return it. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil +} + +func (d *Decoder) readLine(v reflect.Value) (err error, ok bool) { + ok = d.scanner.Scan() + if !ok { + d.done = true + return nil, false + } + + line := string(d.scanner.Bytes()) + rawValue, err := newRawValue(line, d.useCodepointIndices) if err != nil { return diff --git a/decode_test.go b/decode_test.go index 5ce24ce..dad5812 100644 --- a/decode_test.go +++ b/decode_test.go @@ -332,3 +332,79 @@ func TestNewRawValue(t *testing.T) { }) } } + +func TestLineSeparator(t *testing.T) { + // allTypes contains a field with all current supported types. + type allTypes struct { + String string `fixed:"1,5"` + Int int `fixed:"6,10"` + Float float64 `fixed:"11,15"` + TextUnmarshaler EncodableString `fixed:"16,20"` + } + for _, tt := range []struct { + name string + rawValue []byte + target interface{} + expected interface{} + shouldErr bool + separator string + }{ + { + name: "CR line endings", + rawValue: []byte("foo 123 1.2 bar" + "\n" + "bar 321 2.1 foo"), + target: &[]allTypes{}, + expected: &[]allTypes{ + {"foo", 123, 1.2, EncodableString{"bar", nil}}, + {"bar", 321, 2.1, EncodableString{"foo", nil}}, + }, + shouldErr: false, + separator: "", + }, + { + name: "CR line endings", + rawValue: []byte("f\ro 123 1.2 bar" + "\n" + "bar 321 2.1 foo"), + target: &[]allTypes{}, + expected: &[]allTypes{ + {"f\ro", 123, 1.2, EncodableString{"bar", nil}}, + {"bar", 321, 2.1, EncodableString{"foo", nil}}, + }, + shouldErr: false, + separator: "\n", + }, + { + name: "CRLF line endings", + rawValue: []byte("f\no 123 1.2 bar" + "\r\n" + "bar 321 2.1 foo"), + target: &[]allTypes{}, + expected: &[]allTypes{ + {"f\no", 123, 1.2, EncodableString{"bar", nil}}, + {"bar", 321, 2.1, EncodableString{"foo", nil}}, + }, + shouldErr: false, + separator: "\r\n", + }, + { + name: "LF line endings", + rawValue: []byte("f\no 123 1.2 bar" + "\r" + "bar 321 2.1 foo"), + target: &[]allTypes{}, + expected: &[]allTypes{ + {"f\no", 123, 1.2, EncodableString{"bar", nil}}, + {"bar", 321, 2.1, EncodableString{"foo", nil}}, + }, + shouldErr: false, + separator: "\r", + }, + } { + t.Run(tt.name, func(t *testing.T) { + dec := NewDecoder(bytes.NewReader(tt.rawValue)) + dec.SetSeparator(tt.separator) + err := dec.Decode(tt.target) + if tt.shouldErr != (err != nil) { + t.Errorf("Unmarshal() err want %v, have %v (%v)", tt.shouldErr, err != nil, err) + } + if !tt.shouldErr && !reflect.DeepEqual(tt.target, tt.expected) { + t.Errorf("Unmarshal() want %+v, have %+v", tt.expected, tt.target) + } + + }) + } +}