Commit
fix: parsing DOCTYPE directive, trim bytes, and simplify logic (#14)
muktihari authored Jun 18, 2024
1 parent 70cbab6 commit 71c4417
Showing 5 changed files with 381 additions and 121 deletions.
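
Editor's note: the headline fix concerns DOCTYPE directives that carry an internal subset, which end at "]>" rather than at the first '>'. The rewritten RawToken handles this by matching '<'/'>' pairs instead of special-casing the DTD prefix. Below is a minimal usage sketch against such a document; the module path github.com/muktihari/xmltokenizer and the len check on token.Name.Full are assumptions inferred from this diff, not verified API documentation.

package main

import (
	"fmt"
	"strings"

	"github.com/muktihari/xmltokenizer" // assumed module path
)

func main() {
	// The internal subset contains its own '<...>' markup, so the
	// directive only ends at "]>": the case this commit fixes.
	const doc = `<!DOCTYPE note [<!ENTITY title "hi">]><note>body</note>`

	tok := xmltokenizer.New(strings.NewReader(doc))
	for {
		token, err := tok.Token()
		if len(token.Name.Full) > 0 {
			fmt.Printf("name=%q data=%q\n", token.Name.Full, token.Data)
		}
		if err != nil { // io.EOF at end of input; the last token may arrive with it
			break
		}
	}
}
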
4 changes: 2 additions & 2 deletions benchmark_test.go
@@ -25,15 +25,15 @@ func BenchmarkToken(b *testing.B) {
 		var err error
 		for i := 0; i < b.N; i++ {
 			if err = unmarshalWithStdlibXML(path); err != nil {
-				b.Fatal(err)
+				b.Skipf("could not unmarshal: %v", err)
 			}
 		}
 	})
 	b.Run(fmt.Sprintf("xmltokenizer:%q", name), func(b *testing.B) {
 		var err error
 		for i := 0; i < b.N; i++ {
 			if err = unmarshalWithXMLTokenizer(path); err != nil {
-				b.Fatal(err)
+				b.Skipf("could not unmarshal: %v", err)
 			}
 		}
 	})
5 changes: 5 additions & 0 deletions testdata/corrupted/cdata_truncated.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<content>
+<data>
+<![CDATA[
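
Editor's note: this fixture exercises the new io.ErrUnexpectedEOF path, since the scan for the closing "]]>" runs out of input mid-token. A sketch of the behavior this enables (module path assumed as above; run from the repository root):

package main

import (
	"errors"
	"fmt"
	"io"
	"os"

	"github.com/muktihari/xmltokenizer" // assumed module path
)

func main() {
	f, err := os.Open("testdata/corrupted/cdata_truncated.xml")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	tok := xmltokenizer.New(f)
	for {
		if _, err := tok.Token(); err != nil {
			// A CDATA section cut off by EOF now surfaces as io.ErrUnexpectedEOF
			// (wrapped with the byte position) instead of a plain io.EOF that
			// looks like a clean end of document.
			fmt.Println(errors.Is(err, io.ErrUnexpectedEOF)) // expected: true
			return
		}
	}
}
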
224 changes: 111 additions & 113 deletions tokenizer.go
@@ -22,12 +22,13 @@ const (

 // Tokenizer is a XML tokenizer.
 type Tokenizer struct {
-	r         io.Reader // reader provided by the client
-	options   options   // tokenizer's options
-	buf       []byte    // buffer that will grow as needed, large enough to hold a token (default max limit: 1MB)
-	cur, last int       // cur and last bytes positions
-	err       error     // last encountered error
-	token     Token     // shared token
+	r       io.Reader // reader provided by the client
+	n       int64     // the n read bytes counter
+	options options   // tokenizer's options
+	buf     []byte    // buffer that will grow as needed, large enough to hold a token (default max limit: 1MB)
+	cur     int       // cursor byte position
+	err     error     // last encountered error
+	token   Token     // shared token
 }
 
 type options struct {
@@ -85,7 +86,7 @@ func New(r io.Reader, opts ...Option) *Tokenizer {
 // future tokenization to reduce memory alloc.
 func (t *Tokenizer) Reset(r io.Reader, opts ...Option) {
 	t.r, t.err = r, nil
-	t.cur, t.last = 0, 0
+	t.n, t.cur = 0, 0
 
 	t.options = defaultOptions()
 	for i := range opts {
@@ -119,9 +120,12 @@ func (t *Tokenizer) Token() (token Token, err error) {
 	b, err := t.RawToken()
 	if err != nil {
 		if !errors.Is(err, io.EOF) {
-			return token, err
+			err = fmt.Errorf("byte pos %d: %w", t.n, err)
 		}
-		t.err = io.EOF
+		if len(b) == 0 || errors.Is(err, io.ErrUnexpectedEOF) {
+			return
+		}
+		t.err = err
 	}
 
 	t.clearToken()
@@ -145,136 +149,121 @@
 }
 
 // RawToken returns token in its raw bytes. At the end,
-// it may returns last token bytes and an io.EOF error.
+// it may returns last token bytes and an error.
 // The returned token bytes is only valid before next
 // Token or RawToken method invocation.
 func (t *Tokenizer) RawToken() (b []byte, err error) {
 	if t.err != nil {
-		return nil, err
+		return nil, t.err
 	}
 
-	pos := t.cur
-	var off int
+	var pivot, pos = t.cur, t.cur
+	var openclose int // zero means open '<' and close '>' is matched.
 	for {
-		if pos >= t.last {
-			off, pos = t.memmoveRemainingBytes(off)
+		if pos >= len(t.buf) {
+			pivot, pos = t.memmoveRemainingBytes(pivot)
 			if err = t.manageBuffer(); err != nil {
+				if openclose != 0 && errors.Is(err, io.EOF) {
+					err = io.ErrUnexpectedEOF
+				}
 				t.err = err
-				return t.buf[off:pos], err
+				return t.buf[pivot:pos], err
 			}
 		}
 		switch t.buf[pos] {
 		case '<':
-			off = pos
-
-			// Check if tag represents Document Type Definition (DTD)
-			const prefix, _ = "<!DOCTYPE", "]>"
-			dtdOff := 0
-			var k int = 1
-			for i := pos + 1; ; i++ {
-				if i >= t.last {
-					prevLast := t.last
-					off, i = t.memmoveRemainingBytes(off)
-					if dtdOff != 0 {
-						dtdOff = dtdOff - (prevLast - t.last)
-					}
-					if err = t.manageBuffer(); err != nil {
-						t.err = err
-						break
-					}
-				}
-				if k < len(prefix) {
-					if t.buf[i] != prefix[k] {
-						break
-					}
-					k++
-					continue
-				}
-				switch t.buf[i] {
-				case ']':
-					dtdOff = i
-				case '>':
-					if t.buf[dtdOff] == ']' {
-						buf := trim(t.buf[off : i+1 : cap(t.buf)])
-						t.cur = i + 1
-						return buf, err
-					}
-				}
-			}
+			if openclose == 0 {
+				pivot = pos
+			}
+			openclose++
 		case '>':
-			// If next char represents CharData, include it.
-			for i := pos + 1; ; i++ {
-				if i >= t.last {
-					off, i = t.memmoveRemainingBytes(off)
-					pos = i - 1
-					if err = t.manageBuffer(); err != nil {
-						t.err = err
-						break
-					}
-				}
-				if t.buf[i] == '<' {
-					pos = i - 1
-					// Might be in the form of <![CDATA[ CharData ]]>
-					const prefix, suffix = "<![CDATA[", "]]>"
-					var k int = 1
-					for j := i + 1; ; j++ {
-						if j >= t.last {
-							prevLast := t.last
-							off, j = t.memmoveRemainingBytes(off)
-							pos = pos - (prevLast - t.last)
-							if err = t.manageBuffer(); err != nil {
-								t.err = err
-								break
-							}
-						}
-						if k < len(prefix) {
-							if t.buf[j] != prefix[k] {
-								break
-							}
-							k++
-							continue
-						}
-						if t.buf[j] == '>' && string(t.buf[j-2:j+1]) == suffix {
-							pos = j
-							break
-						}
-					}
-					break
-				}
-			}
-			buf := trim(t.buf[off : pos+1 : cap(t.buf)])
+			if openclose--; openclose != 0 {
+				break
+			}
+
+			switch t.buf[pivot+1] {
+			case '?', '!': // Maybe a ProcInst "<?target", a Directive "<!DOCTYPE" or a Comment "<!--"
+				buf := trim(t.buf[pivot : pos+1 : cap(t.buf)])
+				t.cur = pos + 1
+				return buf, err
+			}
+
+			// Regular tag, check if next char represents CharData, include it.
+			pivot, pos = t.parseCharData(pivot, pos)
+
+			buf := trim(t.buf[pivot : pos+1 : cap(t.buf)])
 			t.cur = pos + 1
 			return buf, err
 		}
 		pos++
 	}
 }
 
-func (t *Tokenizer) clearToken() {
-	t.token.Name.Space = nil
-	t.token.Name.Local = nil
-	t.token.Name.Full = nil
-	t.token.Attrs = t.token.Attrs[:0]
-	t.token.Data = nil
-	t.token.SelfClosing = false
-}
+// parseCharData parses the next character sequence and if it represents
+// CharData or <![CDATA[ CharData ]]>, this method will include it in the previous token.
+// It returns the new pivot and new position.
+func (t *Tokenizer) parseCharData(pivot, pos int) (newPivot, newPos int) {
+	for i := pos + 1; ; i++ {
+		if i >= len(t.buf) {
+			pivot, i = t.memmoveRemainingBytes(pivot)
+			pos = i - 1
+			if t.err = t.manageBuffer(); t.err != nil {
+				break
+			}
+		}
+		if t.buf[i] != '<' {
+			continue
+		}
+
+		pos = i - 1
+		// Might be in the form of <![CDATA[ CharData ]]>
+		const prefix, suffix = "<![CDATA[", "]]>"
+		var k int = 1
+		for j := i + 1; ; j++ {
+			if j >= len(t.buf) {
+				prevLast := len(t.buf)
+				pivot, j = t.memmoveRemainingBytes(pivot)
+				pos = pos - (prevLast - len(t.buf))
+				if t.err = t.manageBuffer(); t.err != nil {
+					if errors.Is(t.err, io.EOF) {
+						t.err = io.ErrUnexpectedEOF
+					}
+					break
+				}
+			}
+			if k < len(prefix) {
+				if t.buf[j] != prefix[k] {
+					break
+				}
+				k++
+				continue
+			}
+			if t.buf[j] == '>' && string(t.buf[j-2:j+1]) == suffix {
+				pos = j
+				break
+			}
+		}
+		break
+	}
+	return pivot, pos
+}

-func (t *Tokenizer) memmoveRemainingBytes(off int) (cur, last int) {
-	if off == 0 {
-		return t.cur, t.last
+func (t *Tokenizer) memmoveRemainingBytes(pivot int) (cur, last int) {
+	if pivot == 0 {
+		return t.cur, len(t.buf)
 	}
-	n := copy(t.buf, t.buf[off:])
+	n := copy(t.buf, t.buf[pivot:])
 	t.buf = t.buf[:n:cap(t.buf)]
-	t.cur, t.last = 0, n
-	return t.cur, t.last
+	t.cur = 0
+	return t.cur, len(t.buf)
 }
 
 func (t *Tokenizer) manageBuffer() error {
-	var start, end int
-	switch growSize := t.last + t.options.readBufferSize; {
+	growSize := len(t.buf) + t.options.readBufferSize
+	start, end := len(t.buf), growSize
+	switch {
 	case growSize <= cap(t.buf): // Grow by reslice
 		t.buf = t.buf[:growSize:cap(t.buf)]
-		start, end = t.last, growSize
 	default: // Grow by make new alloc
 		if growSize > t.options.autoGrowBufferMaxLimitSize {
 			return fmt.Errorf("could not grow buffer to %d, max limit is set to %d: %w",
@@ -288,11 +277,20 @@ func (t *Tokenizer) manageBuffer() error {

 	n, err := io.ReadAtLeast(t.r, t.buf[start:end], 1)
 	t.buf = t.buf[: start+n : cap(t.buf)]
-	t.last = len(t.buf)
+	t.n += int64(n)
 
 	return err
 }
 
+func (t *Tokenizer) clearToken() {
+	t.token.Name.Space = nil
+	t.token.Name.Local = nil
+	t.token.Name.Full = nil
+	t.token.Attrs = t.token.Attrs[:0]
+	t.token.Data = nil
+	t.token.SelfClosing = false
+}
+
 // consumeNonTagIdentifier consumes identifier starts with "<?" or "<!", make it raw data.
 func (t *Tokenizer) consumeNonTagIdentifier(b []byte) []byte {
 	if len(b) < 2 || (string(b[:2]) != "<?" && string(b[:2]) != "<!") {
@@ -343,7 +341,7 @@ func (t *Tokenizer) consumeAttrs(b []byte) []byte {
 		case '"':
 			inquote = !inquote
 			if !inquote {
-				if full == nil {
+				if len(full) == 0 { // Ignore malformed attr
 					continue
 				}
 				t.token.Attrs = append(t.token.Attrs, Attr{
@@ -390,13 +388,13 @@ func trimPrefix(b []byte) []byte {
 			start += 2
 			i++
 		}
-	case '\n', ' ':
+	case '\n', ' ', '\t':
 		start++
 	default:
 		return b[start:]
 	}
 	}
-	return b
+	return b[start:]
 }
 
 func trimSuffix(b []byte) []byte {
@@ -408,11 +406,11 @@ func trimSuffix(b []byte) []byte {
 			if i-1 > 0 && b[i-1] == '\r' {
 				end--
 			}
-	case ' ':
+	case ' ', '\t':
 		end--
 	default:
 		return b[:end]
 	}
 	}
-	return b
+	return b[:end]
 }
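
Editor's note: the heart of the simplification above is the openclose counter. A token is complete only when every '<' has been matched by a '>', which makes a DOCTYPE with an internal subset terminate at "]>" without any DTD-specific scanning. A standalone sketch of the idea, stripped of the tokenizer's buffering (hypothetical helper, not part of this package):

// directiveEnd returns the index just past the token starting at start,
// or -1 if input ends first (the tokenizer reports io.ErrUnexpectedEOF).
// Like the real code, it does not account for '>' inside quoted literals.
func directiveEnd(buf []byte, start int) int {
	openclose := 0 // zero means every '<' has a matching '>'
	for i := start; i < len(buf); i++ {
		switch buf[i] {
		case '<':
			openclose++
		case '>':
			if openclose--; openclose == 0 {
				return i + 1
			}
		}
	}
	return -1
}

For `<!DOCTYPE note [<!ENTITY title "hi">]>` the counter rises to 2 inside the internal subset and only returns to zero at the final '>', so the whole directive is emitted as one raw token.
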
6 changes: 3 additions & 3 deletions tokenizer_internal_test.go
@@ -133,9 +133,9 @@ func TestReset(t *testing.T) {
 		t.Fatalf("expected cap(t.buf): %d, got: %d", expected, cap(tok.buf))
 	}
 
-	if tok.cur != 0 && tok.last != 0 {
-		t.Fatalf("expected cur: %d, last: %d, got: cur: %d, last: %d",
-			0, 0, tok.cur, tok.last)
+	if tok.cur != 0 {
+		t.Fatalf("expected cur: %d, got: cur: %d",
+			0, tok.cur)
 	}
 
 	newBufferSize := 2000 << 10