diff --git a/all_test.go b/all_test.go index dd8ae18..6457a97 100644 --- a/all_test.go +++ b/all_test.go @@ -6,7 +6,7 @@ import ( "testing" . "gopkg.in/check.v1" - "gopkg.in/xmlpath.v2" + "gopkg.in/abemedia/xmlpath.v2" "strings" ) diff --git a/path.go b/path.go index db38ed5..320cc7a 100644 --- a/path.go +++ b/path.go @@ -134,7 +134,35 @@ func (s *pathStepState) next() bool { func (s *pathStepState) test(pred predicate) bool { switch pred := pred.(type) { case positionPredicate: - if pred.pos == s.pos { + switch pred.op { + case EQ: + if pred.pos == s.pos { + return true + } + case NEQ: + if pred.pos != s.pos { + return true + } + case LT: + if pred.pos > s.pos { + return true + } + case LTEQ: + if pred.pos >= s.pos { + return true + } + case GT: + if pred.pos < s.pos { + return true + } + case GTEQ: + if pred.pos <= s.pos { + return true + } + } + case lastPredicate: + ps := *s + if !ps._next() { return true } case existsPredicate: @@ -363,10 +391,24 @@ func (s *pathStepState) _next() bool { return false } +type Operator int + +const ( + EQ Operator = iota + NEQ + LT + GT + LTEQ + GTEQ +) + type positionPredicate struct { pos int + op Operator } +type lastPredicate struct{} + type existsPredicate struct { path *Path } @@ -399,6 +441,7 @@ type predicate interface { } func (positionPredicate) predicate() {} +func (lastPredicate) predicate() {} func (existsPredicate) predicate() {} func (equalsPredicate) predicate() {} func (containsPredicate) predicate() {} @@ -580,7 +623,19 @@ func (c *pathCompiler) parsePath() (path *Path, err error) { if pos == 0 { return nil, c.errorf("positions start at 1") } - next = positionPredicate{pos} + next = positionPredicate{pos, EQ} + } else if c.skipString("position()") { + c.skipSpaces() + op, ok := c.parseOp() + c.skipSpaces() + pos, ok2 := c.parseInt() + + if (!ok || !ok2) { + return nil, c.errorf("invalid position() predicate") + } + next = positionPredicate{pos, op} + } else if c.skipString("last()") { + next = lastPredicate{} } else if c.skipString("contains(") { path, err := c.parsePath() if err != nil { @@ -729,6 +784,25 @@ func (c *pathCompiler) parseInt() (v int, ok bool) { return v, true } +func (c *pathCompiler) parseOp() (op Operator, ok bool) { + if c.skipByte('=') { + return EQ, true + } else if c.skipByte('!') && c.skipByte('=') { + return NEQ, true + } else if c.skipByte('<') { + if c.skipByte('=') { + return LTEQ, true + } + return LT, true + } else if c.skipByte('>') { + if c.skipByte('=') { + return GTEQ, true + } + return GT, true + } + return -1, false +} + func (c *pathCompiler) skipByte(b byte) bool { if c.i < len(c.path) && c.path[c.i] == b { c.i++