From 330a89d8773e87b83cb984b7a69a0d893b30b4bd Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Fri, 22 May 2026 08:17:04 +0200 Subject: [PATCH 1/7] #850 Add support of boolean expressions in expressions evaluator. --- ...luator.scala => ExpressionEvaluator.scala} | 17 ++-- .../cobol/parser/expression/lexer/Lexer.scala | 1 + .../cobol/parser/expression/lexer/Token.scala | 4 + ...Impl.scala => ExpressionBuilderImpl.scala} | 78 ++++++++++++++----- .../parser/ExtractVariablesBuilder.scala | 2 + .../expression/parser/NumExprBuilder.scala | 1 + .../parser/expression/parser/Parser.scala | 3 + ...thRecordLengthExprRawRecordExtractor.scala | 2 +- .../iterator/RecordLengthExpression.scala | 4 +- .../reader/iterator/RecordLengthField.scala | 2 +- .../validator/ReaderParametersValidator.scala | 4 +- .../expression/ExpressionEvaluatorSuite.scala | 44 +++++++++-- 12 files changed, 125 insertions(+), 37 deletions(-) rename cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/{NumberExprEvaluator.scala => ExpressionEvaluator.scala} (80%) rename cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/{NumExprBuilderImpl.scala => ExpressionBuilderImpl.scala} (61%) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/NumberExprEvaluator.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala similarity index 80% rename from cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/NumberExprEvaluator.scala rename to cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala index 85a9494f0..dbda5489d 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/NumberExprEvaluator.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala @@ -17,7 +17,7 @@ package za.co.absa.cobrix.cobol.parser.expression import za.co.absa.cobrix.cobol.parser.expression.lexer.Lexer -import za.co.absa.cobrix.cobol.parser.expression.parser.{ExtractVariablesBuilder, NumExprBuilderImpl, Parser} +import za.co.absa.cobrix.cobol.parser.expression.parser.{ExtractVariablesBuilder, ExpressionBuilderImpl, Parser} import scala.collection.mutable @@ -34,7 +34,7 @@ import scala.collection.mutable * assert(evaluator.eval() == 549) * }}} */ -class NumberExprEvaluator(expr: String) { +class ExpressionEvaluator(expr: String) { private val tokens = new Lexer(expr).lex() private val vars = mutable.HashMap[String, Int]() @@ -50,10 +50,17 @@ class NumberExprEvaluator(expr: String) { exprBuilder.getResult } - def eval(): Int = { - val exprBuilder = new NumExprBuilderImpl(vars.toMap, expr) + def evalInt(): Int = { + val exprBuilder = new ExpressionBuilderImpl(vars.toMap, expr) Parser.parse(tokens, exprBuilder) - exprBuilder.getResult + exprBuilder.getIntResult + } + + def evalBool(): Boolean = { + val exprBuilder = new ExpressionBuilderImpl(vars.toMap, expr) + Parser.parse(tokens, exprBuilder) + + exprBuilder.getBoolResult } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala index 47c7e4d45..c5e70c3b8 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala @@ -54,6 +54,7 @@ class Lexer(expression: String) { case '-' => Some(MINUS(pos)) case '*' => Some(MULT(pos)) case '/' => Some(DIV(pos)) + case '=' => Some(EQ(pos)) case _ => None } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala index 81a0b4f7c..a8c091aba 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala @@ -54,6 +54,10 @@ object Token { override def toString = "/" } + case class EQ(pos: Int) extends Token { + override def toString = "=" + } + case class NAME(pos: Int, s: String) extends Token { override def toString: String = s diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilderImpl.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala similarity index 61% rename from cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilderImpl.scala rename to cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala index ad23a2283..a116b4028 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilderImpl.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala @@ -21,15 +21,16 @@ import za.co.absa.cobrix.cobol.parser.expression.exception.ExprSyntaxError import scala.annotation.tailrec import scala.collection.mutable.ListBuffer -class NumExprBuilderImpl(vars: Map[String, Int], expr: String) extends NumExprBuilder { +class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExprBuilder { val ops = new ListBuffer[String] - val values = new ListBuffer[Int] + val valuesInt = new ListBuffer[Int] + val valuesBool = new ListBuffer[Boolean] override def openParen(pos: Int): Unit = ops += "(" override def closeParen(pos: Int): Unit = { if (ops.isEmpty) { - if (values.size != 1) { + if (valuesInt.size != 1) { throw new ExprSyntaxError(s"Empty expression at $pos in '$expr'.") } } else { @@ -68,11 +69,18 @@ class NumExprBuilderImpl(vars: Map[String, Int], expr: String) extends NumExprBu ops += "/" } + override def addOperationEquals(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=".contains(ops.last)) { + eval() + } + ops += "=" + } + override def addVariable(name: String, pos: Int): Unit = { if (!vars.contains(name)) { throw new ExprSyntaxError(s"Unset variable '$name' used.") } else { - values += vars(name) + valuesInt += vars(name) } } @@ -81,19 +89,36 @@ class NumExprBuilderImpl(vars: Map[String, Int], expr: String) extends NumExprBu } override def addNumLiteral(num: Int, pos: Int): Unit = { - values += num + valuesInt += num + } + + def getIntResult: Int = { + while (ops.nonEmpty) { + eval() + } + if (valuesInt.isEmpty && valuesBool.isEmpty) { + throw new ExprSyntaxError(s"Empty expressions are not supported in '$expr'.") + } else if (valuesInt.isEmpty) { + throw new ExprSyntaxError(s"The expression does not return a number in '$expr'.") + } else if (valuesInt.size > 1 || (valuesInt.nonEmpty && valuesBool.nonEmpty)) { + throw new ExprSyntaxError(s"Malformed expression: '$expr'.") + } else { + valuesInt.head + } } - def getResult: Int = { + def getBoolResult: Boolean = { while (ops.nonEmpty) { eval() } - if (values.isEmpty) { + if (valuesInt.isEmpty && valuesBool.isEmpty) { throw new ExprSyntaxError(s"Empty expressions are not supported in '$expr'.") - } else if (values.size > 1) { + } else if (valuesBool.isEmpty) { + throw new ExprSyntaxError(s"The expression does not return a boolean in '$expr'.") + } else if (valuesBool.size > 1 || (valuesInt.nonEmpty && valuesBool.nonEmpty)) { throw new ExprSyntaxError(s"Malformed expression: '$expr'.") } else { - values.head + valuesBool.head } } @@ -105,37 +130,48 @@ class NumExprBuilderImpl(vars: Map[String, Int], expr: String) extends NumExprBu op match { case "(" => if (ops.nonEmpty && ops.last != "(") eval() case "+" => - expectArguments(2) + expectIntArguments(2) val b = getInt val a = getInt - values += a + b + valuesInt += a + b case "-" => - expectArguments(2) + expectIntArguments(2) val b = getInt val a = getInt - values += a - b + valuesInt += a - b case "*" => - expectArguments(2) + expectIntArguments(2) val b = getInt val a = getInt - values += a * b + valuesInt += a * b case "/" => - expectArguments(2) + expectIntArguments(2) val b = getInt val a = getInt - values += a / b + valuesInt += a / b + case "=" => + expectIntArguments(2) + val b = getInt + val a = getInt + valuesBool += a == b case f => throw new ExprSyntaxError(s"Unsupported function '$f' in '$expr'.") } } - private def expectArguments(n: Int): Unit = { - if (values.size < n) + private def expectIntArguments(n: Int): Unit = { + if (valuesInt.size < n) throw new ExprSyntaxError(s"Expected more arguments in '$expr'.") } private def getInt: Int = { - val a = values.last - values.remove(values.size - 1) + val a = valuesInt.last + valuesInt.remove(valuesInt.size - 1) + a + } + + private def getBool: Boolean = { + val a = valuesBool.last + valuesBool.remove(valuesBool.size - 1) a } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala index dd76a2e0b..3d15c1182 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala @@ -33,6 +33,8 @@ class ExtractVariablesBuilder(expr: String) extends NumExprBuilder { override def addOperationDivide(pos: Int): Unit = {} + override def addOperationEquals(pos: Int): Unit = {} + override def addVariable(name: String, pos: Int): Unit = { variables += name } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala index e95b9a4d3..79108638b 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala @@ -23,6 +23,7 @@ trait NumExprBuilder { def addOperationMinus(pos: Int): Unit def addOperationMultiply(pos: Int): Unit def addOperationDivide(pos: Int): Unit + def addOperationEquals(pos: Int): Unit def addVariable(name: String, pos: Int): Unit def addFunction(name: String, pos: Int): Unit diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala index bc6252cd2..5af6b1149 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala @@ -91,6 +91,9 @@ object Parser { case DIV(pos) => builder.addOperationDivide(pos) state = STATE0 + case EQ(pos) => + builder.addOperationEquals(pos) + state = STATE0 case NAME(pos, s) => builder.addFunction(s, pos) case NUM_LITERAL(pos, s) => diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/raw/FixedWithRecordLengthExprRawRecordExtractor.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/raw/FixedWithRecordLengthExprRawRecordExtractor.scala index dd53c91c3..d4270bba8 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/raw/FixedWithRecordLengthExprRawRecordExtractor.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/raw/FixedWithRecordLengthExprRawRecordExtractor.scala @@ -180,7 +180,7 @@ class FixedWithRecordLengthExprRawRecordExtractor(ctx: RawRecordContext, } } - val recordLength = evaluator.eval() + val recordLength = evaluator.evalInt() val restOfDataLength = recordLength - lengthFieldBlock + readerProperties.endOffset diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthExpression.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthExpression.scala index 51b36fd34..4df9fcf68 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthExpression.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthExpression.scala @@ -17,11 +17,11 @@ package za.co.absa.cobrix.cobol.reader.iterator import za.co.absa.cobrix.cobol.parser.ast.Primitive -import za.co.absa.cobrix.cobol.parser.expression.NumberExprEvaluator +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator case class RecordLengthExpression( expr: String, - evaluator: NumberExprEvaluator, + evaluator: ExpressionEvaluator, fields: Map[String, Primitive], requiredBytesToread: Int ) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthField.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthField.scala index 378c2c64a..f6e68a9cd 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthField.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/iterator/RecordLengthField.scala @@ -17,7 +17,7 @@ package za.co.absa.cobrix.cobol.reader.iterator import za.co.absa.cobrix.cobol.parser.ast.Primitive -import za.co.absa.cobrix.cobol.parser.expression.NumberExprEvaluator +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator case class RecordLengthField( field: Primitive, diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/validator/ReaderParametersValidator.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/validator/ReaderParametersValidator.scala index a51f55f30..c8a1a404a 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/validator/ReaderParametersValidator.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/validator/ReaderParametersValidator.scala @@ -19,7 +19,7 @@ package za.co.absa.cobrix.cobol.reader.validator import org.slf4j.LoggerFactory import za.co.absa.cobrix.cobol.parser.Copybook import za.co.absa.cobrix.cobol.parser.ast.Primitive -import za.co.absa.cobrix.cobol.parser.expression.NumberExprEvaluator +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import za.co.absa.cobrix.cobol.reader.iterator.{RecordLengthExpression, RecordLengthField} import za.co.absa.cobrix.cobol.reader.parameters.MultisegmentParameters @@ -65,7 +65,7 @@ object ReaderParametersValidator { @throws(classOf[IllegalStateException]) def getLengthFieldExpr(recordLengthFieldExpr: String, recordLengthMap: Map[String, Int], cobolSchema: Copybook): Option[RecordLengthExpression] = { - val evaluator = new NumberExprEvaluator(recordLengthFieldExpr) + val evaluator = new ExpressionEvaluator(recordLengthFieldExpr) val vars = evaluator.getVariables val fields = vars.map { field => val primitive = getLengthField(field, recordLengthMap, cobolSchema) diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala index 1da25e628..30272dee9 100644 --- a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala +++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala @@ -17,13 +17,14 @@ package za.co.absa.cobrix.cobol.expression import org.scalatest.wordspec.AnyWordSpec -import za.co.absa.cobrix.cobol.parser.expression.NumberExprEvaluator +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator +import za.co.absa.cobrix.cobol.parser.expression.exception.ExprSyntaxError class ExpressionEvaluatorSuite extends AnyWordSpec { "getVariables()" should { "return the list of variables in expressions" in { val expr = "d + b * c - a / d" - val vars = new NumberExprEvaluator(expr).getVariables + val vars = new ExpressionEvaluator(expr).getVariables assert(vars == Seq("a", "b", "c", "d")) } } @@ -46,7 +47,7 @@ class ExpressionEvaluatorSuite extends AnyWordSpec { exprs.foreach { case (expr, expectedResult) => s"$expr" in { - val actualResult = new NumberExprEvaluator(expr).eval() + val actualResult = new ExpressionEvaluator(expr).evalInt() assert(actualResult == expectedResult) } } @@ -54,12 +55,45 @@ class ExpressionEvaluatorSuite extends AnyWordSpec { "evaluate expressions with variables" in { val expr = "10 * (a1 + 5) * bcd" - val evaluator = new NumberExprEvaluator(expr) + val evaluator = new ExpressionEvaluator(expr) evaluator.setValue("a1", 2) evaluator.setValue("bcd", 3) - val actualResult = evaluator.eval() + val actualResult = evaluator.evalInt() assert(actualResult == 210) } + + "evaluate boolean expressions" in { + val expr = "a1*2 = 4" + val evaluator = new ExpressionEvaluator(expr) + evaluator.setValue("a1", 2) + + val actualResult = evaluator.evalBool() + assert(actualResult) + } + + "fail when int expected but boolean returned" in { + val expr = "a1*2 = 4" + val evaluator = new ExpressionEvaluator(expr) + evaluator.setValue("a1", 2) + + val ex = intercept[ExprSyntaxError] { + evaluator.evalInt() + } + + assert(ex.getMessage == "The expression does not return a number in 'a1*2 = 4'.") + } + + "fail when bool expected but int returned" in { + val expr = "a1*2" + val evaluator = new ExpressionEvaluator(expr) + evaluator.setValue("a1", 2) + + val ex = intercept[ExprSyntaxError] { + evaluator.evalBool() + } + + assert(ex.getMessage == "The expression does not return a boolean in 'a1*2'.") + } } } From 80e45d7ab78ff3a3045b86f39a1859bcbf3b4da3 Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Tue, 26 May 2026 08:52:43 +0200 Subject: [PATCH 2/7] #850 Add Configuration parameters and AST boilerplate code for rule-based expressions attached to schema fields. --- .../cobol/parser/antlr/ParserVisitor.scala | 1 + .../absa/cobrix/cobol/parser/ast/Group.scala | 13 ++++- .../cobrix/cobol/parser/ast/Primitive.scala | 13 ++++- .../cobrix/cobol/parser/ast/Statement.scala | 8 +++ .../expression/ExpressionEvaluator.scala | 2 +- .../reader/parameters/CobolParameters.scala | 2 + .../parameters/CobolParametersParser.scala | 53 +++++++++++++++++++ .../reader/parameters/ReaderParameters.scala | 3 ++ .../parser/extract/BinaryExtractorSpec.scala | 2 +- .../cobol/source/ParametersParsingSpec.scala | 11 ++++ 10 files changed, 104 insertions(+), 4 deletions(-) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala index 3bb1ed853..1ae409cbf 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala @@ -859,6 +859,7 @@ class ParserVisitor(enc: Encoding, Map(), isDependee = false, identifier.toUpperCase() == Constants.FILLER, + None, DecoderSelector.getDecoder(pic.value, stringTrimmingPolicy, isDisplayAlwaysString, effectiveEbcdicCodePage, effectiveAsciiCharset, isUtf16BigEndian = isUtf16BigEndian, floatingPointFormat, strictSignOverpunch = strictSignOverpunch, improvedNullDetection = improvedNullDetection, strictIntegralPrecision = strictIntegralPrecision), EncoderSelector.getEncoder(pic.value, effectiveEbcdicCodePage, effectiveAsciiCharset) )(Some(parent)) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Group.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Group.scala index ca1711b60..29ce3a3b4 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Group.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Group.scala @@ -17,6 +17,7 @@ package za.co.absa.cobrix.cobol.parser.ast import za.co.absa.cobrix.cobol.parser.ast.datatype.Usage +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import scala.collection.mutable @@ -57,6 +58,7 @@ case class Group( isFiller: Boolean = false, groupUsage: Option[Usage] = None, nonFillerSize: Int = 0, + ruleExpression: Option[ExpressionEvaluator] = None, binaryProperties: BinaryProperties = BinaryProperties(0, 0, 0) ) (val parent: Option[Group] = None) @@ -81,7 +83,10 @@ case class Group( } /** Returns true if the field is a child segment */ - def isChildSegment: Boolean = parentSegment.nonEmpty + override def isChildSegment: Boolean = parentSegment.nonEmpty + + /** Returns true if the field is enabled for the input binary record. Uses the rule expression to determine that. */ + override def enabledForRecord(record: Array[Byte]): Boolean = true /** Returns the original Group with updated children */ def withUpdatedChildren(newChildren: mutable.ArrayBuffer[Statement]): Group = { @@ -108,10 +113,16 @@ case class Group( copy(parentSegment = newParentSegmentOpt)(parent) } + /** Returns the original field with updated `dependingOnHandlers` */ def withUpdatedDependingOnHandlers(newDependingOnHandlers: Map[String, Int]): Group = { copy(dependingOnHandlers = newDependingOnHandlers)(parent) } + /** Returns the original field with updated `ruleExpression` */ + def withUpdatedRuleExpression(newRuleExpression: Option[ExpressionEvaluator]): Group = { + copy(ruleExpression = newRuleExpression)(parent) + } + } object Group { diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala index e1627c626..4ccef1e23 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala @@ -19,6 +19,7 @@ package za.co.absa.cobrix.cobol.parser.ast import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP3, CobolType, Decimal, Integral} import za.co.absa.cobrix.cobol.parser.decoders.{BinaryUtils, DecoderSelector} import za.co.absa.cobrix.cobol.parser.encoding.{ASCII, EBCDIC, EncoderSelector} +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator /** An abstraction of the statements describing fields of primitive data types in the COBOL copybook * @@ -53,6 +54,7 @@ case class Primitive( dependingOnHandlers: Map[String, Int] = Map(), isDependee: Boolean = false, isFiller: Boolean = false, + ruleExpression: Option[ExpressionEvaluator] = None, decode: DecoderSelector.Decoder, encode: Option[EncoderSelector.Encoder], binaryProperties: BinaryProperties = BinaryProperties(0, 0, 0) @@ -107,7 +109,10 @@ case class Primitive( } /** Returns true if the field is a child segment */ - def isChildSegment: Boolean = false + override def isChildSegment: Boolean = false + + /** Returns true if the field is enabled for the input binary record. Uses the rule expression to determine that. */ + override def enabledForRecord(record: Array[Byte]): Boolean = true /** Returns the original field with updated binary properties */ def withUpdatedBinaryProperties(newBinaryProperties: BinaryProperties): Primitive = { @@ -124,10 +129,16 @@ case class Primitive( copy(isDependee = newIsDependee)(parent) } + /** Returns the original field with updated `dependingOnHandlers` */ def withUpdatedDependingOnHandlers(newDependingOnHandlers: Map[String, Int]): Primitive = { copy(dependingOnHandlers = newDependingOnHandlers)(parent) } + /** Returns the original field with updated `ruleExpression` */ + def withUpdatedRuleExpression(newRuleExpression: Option[ExpressionEvaluator]): Primitive = { + copy(ruleExpression = newRuleExpression)(parent) + } + /** Returns the binary size in bits for the field */ def getBinarySizeBytes: Int = { dataType match { diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Statement.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Statement.scala index cc833ccec..013ee950e 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Statement.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Statement.scala @@ -16,6 +16,8 @@ package za.co.absa.cobrix.cobol.parser.ast +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator + /** Trait for Cobol copybook AST element (a statement). */ trait Statement { /** Returns the level of the AST element */ @@ -82,6 +84,12 @@ trait Statement { /** Returns true if the field is a child segment */ def isChildSegment: Boolean + /** The expression for the field enablement. Usually used for redefined fields. */ + def ruleExpression: Option[ExpressionEvaluator] + + /** Returns true if the field is enabled for the input binary record. Uses the rule expression to determine that. */ + def enabledForRecord(record: Array[Byte]): Boolean + /** A binary properties of a field */ val binaryProperties: BinaryProperties diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala index dbda5489d..1900223ab 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala @@ -34,7 +34,7 @@ import scala.collection.mutable * assert(evaluator.eval() == 549) * }}} */ -class ExpressionEvaluator(expr: String) { +class ExpressionEvaluator(val expr: String) { private val tokens = new Lexer(expr).lex() private val vars = mutable.HashMap[String, Int]() diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala index 9924a724f..a716a00ab 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParameters.scala @@ -45,6 +45,7 @@ import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy.SchemaReten * @param minimumRecordLength Minium record length for which the record is considered valid. * @param maximumRecordLength Maximum record length for which the record is considered valid. * @param variableLengthParams VariableLengthParameters containing the specifications for the consumption of variable-length Cobol records. + * @param redefineRuleExpressions A map of REDEFINE field names to expressions that determine which redefine alternative to use when parsing records. * @param variableSizeOccurs Specifies how to handle OCCURS DEPENDING ON when the actual number of elements in arrays is less than the maximum array size * @param generateRecordBytes Generate 'record_bytes' field containing raw bytes of the original record * @param generateCorruptFields Generate '_corrupt_fields' field for fields that haven't converted successfully @@ -88,6 +89,7 @@ case class CobolParameters( minimumRecordLength: Option[Int], maximumRecordLength: Option[Int], variableLengthParams: Option[VariableLengthParameters], + redefineRuleExpressions: Map[String, String], variableSizeOccurs: VariableSizeOccursPolicy, generateRecordBytes: Boolean, generateCorruptFields: Boolean, diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala index bc6c230c5..261ce060d 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala @@ -21,6 +21,7 @@ import za.co.absa.cobrix.cobol.parser.CopybookParser import za.co.absa.cobrix.cobol.parser.antlr.ParserJson import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat.FloatingPointFormat +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import za.co.absa.cobrix.cobol.parser.policies.DebugFieldsPolicy.DebugFieldsPolicy import za.co.absa.cobrix.cobol.parser.policies.StringTrimmingPolicy.StringTrimmingPolicy import za.co.absa.cobrix.cobol.parser.policies._ @@ -122,6 +123,8 @@ object CobolParametersParser extends Logging { val PARAM_INPUT_FILE_COLUMN = "with_input_file_name_col" val PARAM_SEGMENT_REDEFINE_PREFIX = "redefine_segment_id_map" val PARAM_SEGMENT_REDEFINE_PREFIX_ALT = "redefine-segment-id-map" + val PARAM_REDEFINE_RULE_PREFIX = "redefine_rule" + val PARAM_REDEFINE_RULE_PREFIX_ALT = "redefine-rule" // Indexed multisegment file processing val PARAM_ENABLE_INDEXES = "enable_indexes" @@ -303,6 +306,7 @@ object CobolParametersParser extends Logging { params.get(PARAM_MINIMUM_RECORD_LENGTH).map(_.toInt), params.get(PARAM_MAXIMUM_RECORD_LENGTH).map(_.toInt), variableLengthParams, + getRedefineRuleExpressionMapping(params), variableSizeOccursPolicy, params.getOrElse(PARAM_GENERATE_RECORD_BYTES, "false").toBoolean, params.getOrElse(PARAM_CORRUPT_FIELDS, "false").toBoolean, @@ -443,6 +447,12 @@ object CobolParametersParser extends Logging { val recordsToExclude = (parameters.fileHeaderField.map(n => CopybookParser.transformIdentifier(n).toUpperCase).toSet ++ parameters.fileTrailerField.map(n => CopybookParser.transformIdentifier(n).toUpperCase).toSet) + + val ruleExpressionMap = parameters.redefineRuleExpressions map { + case (field, exprStr) => + val expr = new ExpressionEvaluator(exprStr) + (field, expr) + } ReaderParameters( recordFormat = parameters.recordFormat, @@ -454,6 +464,7 @@ object CobolParametersParser extends Logging { fieldCodePage = parameters.fieldCodePage, isUtf16BigEndian = parameters.isUtf16BigEndian, floatingPointFormat = parameters.floatingPointFormat, + redefineRuleExpressions = ruleExpressionMap, variableSizeOccurs = parameters.variableSizeOccurs, recordLength = parameters.recordLength, minimumRecordLength = parameters.minimumRecordLength.getOrElse(1), @@ -747,6 +758,48 @@ object CobolParametersParser extends Logging { } } + /** + * Parses the list of redefines rules their corresponding expressions. + * + * Example: + * For + * {{{ + * sprak.read + * .option("redefine-rule:1", "COMPANY => RECORD_TYPE = 1") + * .option("redefine-rule:2", "CONTACT => RECORD_TYPE = 2") + * }}} + * + * The corresponding mapping will be: + * + * {{{ + * "RECORD_TYPE = 1" -> "COMPANY" + * "RECORD_TYPE = 2" -> "COMPANY" + * }}} + * + * @param params Parameters provided by spark.read.option(...) + * @return Returns a sequence of redefine rules + */ + @throws(classOf[IllegalArgumentException]) + def getRedefineRuleExpressionMapping(params: Parameters): Map[String, String] = { + params.getMap.flatMap { + case (k, v) => + val keyNoCase = k.toLowerCase + if (keyNoCase.startsWith(PARAM_REDEFINE_RULE_PREFIX) || + keyNoCase.startsWith(PARAM_REDEFINE_RULE_PREFIX_ALT)) { + params.markUsed(k) + val splitVal = v.split("\\=\\>") + if (splitVal.lengthCompare(2) != 0) { + throw new IllegalArgumentException(s"Illegal argument for the '$PARAM_REDEFINE_RULE_PREFIX' option: '$v'.") + } + val redefine = splitVal(0).trim + val rule = splitVal(1).trim + Option((CopybookParser.transformIdentifier(redefine), rule)) + } else { + None + } + } + } + /** * Parses the list of sergent redefine fields and their children for a hierarchical data. * Produces a mapping between redefined fields and their parents. diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala index 5204cb11e..4f486cdd0 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/ReaderParameters.scala @@ -18,6 +18,7 @@ package za.co.absa.cobrix.cobol.reader.parameters import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat.FloatingPointFormat +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import za.co.absa.cobrix.cobol.parser.policies.DebugFieldsPolicy.DebugFieldsPolicy import za.co.absa.cobrix.cobol.parser.policies.StringTrimmingPolicy.StringTrimmingPolicy import za.co.absa.cobrix.cobol.parser.policies._ @@ -37,6 +38,7 @@ import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy.SchemaReten * @param fieldCodePage Specifies a mapping between a field name and the code page * @param isUtf16BigEndian If true UTF-16 strings are considered big-endian. * @param floatingPointFormat A format of floating-point numbers + * @param redefineRuleExpressions A map of REDEFINE field names to expressions that determine which redefine alternative to use when parsing records. * @param variableSizeOccurs Specifies how to handle OCCURS DEPENDING ON when the actual number of elements in arrays is less than the maximum array size * @param recordLength Specifies the length of the record disregarding the copybook record size. Implied the file has fixed record length. * @param minimumRecordLength Minium record length for which the record is considered valid. @@ -92,6 +94,7 @@ case class ReaderParameters( fieldCodePage: Map[String, String] = Map.empty[String, String], isUtf16BigEndian: Boolean = true, floatingPointFormat: FloatingPointFormat = FloatingPointFormat.IBM, + redefineRuleExpressions: Map[String, ExpressionEvaluator] = Map.empty, variableSizeOccurs: VariableSizeOccursPolicy = VariableSizeOccursPolicy.MaxSize, recordLength: Option[Int] = None, minimumRecordLength: Int = 1, diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala index 43dbc850b..be94b990e 100644 --- a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala +++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala @@ -161,7 +161,7 @@ class BinaryExtractorSpec extends AnyFunSuite { val binaryProperties: BinaryProperties = BinaryProperties(2, 10, 10) val primitive: Primitive = Primitive(level, name, name, lineNumber, dataType, redefines, isRedefined, - occurs, to, dependingOn, Map(), isDependee, isFiller, DecoderSelector.getDecoder(dataType), EncoderSelector.getEncoder(dataType), binaryProperties)(None) + occurs, to, dependingOn, Map(), isDependee, isFiller, None, DecoderSelector.getDecoder(dataType), EncoderSelector.getEncoder(dataType), binaryProperties)(None) val result2: Any = Copybook.extractPrimitiveField(primitive, bytes, startOffset) assert(result2.asInstanceOf[String] === "EXAMPLE4") } diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala index 5534f5d2f..0a9dc06cc 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala @@ -35,6 +35,17 @@ class ParametersParsingSpec extends AnyFunSuite { assert(segmentIdMapping.get("Q").isEmpty) } + test("Test redefine rule expression mapping") { + val config = HashMap[String,String] ( + "redefine-rule:1" -> "COMPANY => RECORD_TYPE = 1", + "redefine_rule:2" -> "CONTACT => RECORD_TYPE = 2") + + val ruleExpressions = CobolParametersParser.getRedefineRuleExpressionMapping(new Parameters(config)) + + assert(ruleExpressions("COMPANY") == "RECORD_TYPE = 1") + assert(ruleExpressions("CONTACT") == "RECORD_TYPE = 2") + } + test("Test field - parent field mapping") { val config = HashMap[String,String] ("is_record_sequence"-> "true", "segment-children:1" -> "COMPANY => DEPT,CUSTOMER", From 965a831e40b369828f4cd278523f536b32eb62c2 Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Tue, 26 May 2026 15:44:26 +0200 Subject: [PATCH 3/7] #850 Add minimum implementation of redefine rules based on numeric distinguisher records. --- .../cobrix/cobol/parser/CopybookParser.scala | 13 +- .../cobol/parser/antlr/ParserVisitor.scala | 3 +- .../cobrix/cobol/parser/ast/Primitive.scala | 7 + .../asttransform/RuleExpressionSetter.scala | 83 +++++++++++ .../expression/ExpressionEvaluator.scala | 2 +- .../extractors/record/RecordExtractors.scala | 114 +++++++++++---- .../parameters/CobolParametersParser.scala | 3 +- .../cobol/reader/schema/CobolSchema.scala | 2 + .../parser/extract/BinaryExtractorSpec.scala | 3 +- .../integration/Test42RedefineRulesSpec.scala | 136 ++++++++++++++++++ 10 files changed, 333 insertions(+), 33 deletions(-) create mode 100644 cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala create mode 100644 spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/CopybookParser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/CopybookParser.scala index fb352a1e0..fda49abf3 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/CopybookParser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/CopybookParser.scala @@ -25,6 +25,7 @@ import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat.FloatingPoint import za.co.absa.cobrix.cobol.parser.encoding.codepage.{CodePage, CodePageCommon} import za.co.absa.cobrix.cobol.parser.encoding.{EBCDIC, Encoding} import za.co.absa.cobrix.cobol.parser.exceptions.SyntaxErrorException +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import za.co.absa.cobrix.cobol.parser.policies.DebugFieldsPolicy.DebugFieldsPolicy import za.co.absa.cobrix.cobol.parser.policies.StringTrimmingPolicy.StringTrimmingPolicy import za.co.absa.cobrix.cobol.parser.policies.{CommentPolicy, DebugFieldsPolicy, FillerNamingPolicy, StringTrimmingPolicy} @@ -124,6 +125,7 @@ object CopybookParser extends Logging { * @param isUtf16BigEndian If true UTF-16 strings are considered big-endian. * @param floatingPointFormat A format of floating-point numbers (IBM/IEEE754). * @param nonTerminals A list of non-terminals that should be extracted as strings. + * @param redefineRuleExpressions A map of REDEFINE field names to expressions that determine which redefine alternative to use when parsing records. * @param debugFieldsPolicy Specifies if debugging fields need to be added and what should they contain (false, hex, raw). * @return Seq[Group] where a group is a record inside the copybook. */ @@ -147,6 +149,7 @@ object CopybookParser extends Logging { floatingPointFormat: FloatingPointFormat = FloatingPointFormat.IBM, nonTerminals: Seq[String] = Nil, occursHandlers: Map[String, Map[String, Int]] = Map(), + redefineRuleExpressions: Map[String, ExpressionEvaluator] = Map.empty, debugFieldsPolicy: DebugFieldsPolicy = DebugFieldsPolicy.NoDebug, fieldCodePageMap: Map[String, String] = Map.empty[String, String]): Copybook = { parseTree(dataEncoding, @@ -169,6 +172,7 @@ object CopybookParser extends Logging { floatingPointFormat, nonTerminals, occursHandlers, + redefineRuleExpressions, debugFieldsPolicy, fieldCodePageMap) } @@ -192,6 +196,7 @@ object CopybookParser extends Logging { * @param isUtf16BigEndian If true UTF-16 strings are considered big-endian. * @param floatingPointFormat A format of floating-point numbers (IBM/IEEE754) * @param nonTerminals A list of non-terminals that should be extracted as strings + * @param redefineRuleExpressions A map of REDEFINE field names to expressions that determine which redefine alternative to use when parsing records. * @param debugFieldsPolicy Specifies if debugging fields need to be added and what should they contain (false, hex, raw). * @return Seq[Group] where a group is a record inside the copybook */ @@ -214,6 +219,7 @@ object CopybookParser extends Logging { floatingPointFormat: FloatingPointFormat = FloatingPointFormat.IBM, nonTerminals: Seq[String] = Nil, occursHandlers: Map[String, Map[String, Int]] = Map(), + redefineRuleExpressions: Map[String, ExpressionEvaluator] = Map.empty, debugFieldsPolicy: DebugFieldsPolicy = DebugFieldsPolicy.NoDebug, fieldCodePageMap: Map[String, String] = Map.empty[String, String]): Copybook = { parseTree(EBCDIC, @@ -236,6 +242,7 @@ object CopybookParser extends Logging { floatingPointFormat, nonTerminals, occursHandlers, + redefineRuleExpressions, debugFieldsPolicy, fieldCodePageMap) } @@ -259,6 +266,7 @@ object CopybookParser extends Logging { * @param isUtf16BigEndian If true UTF-16 strings are considered big-endian. * @param floatingPointFormat A format of floating-point numbers (IBM/IEEE754) * @param nonTerminals A list of non-terminals that should be extracted as strings + * @param redefineRuleExpressions A map of REDEFINE field names to expressions that determine which redefine alternative to use when parsing records. * @param debugFieldsPolicy Specifies if debugging fields need to be added and what should they contain (false, hex, raw). * @return Seq[Group] where a group is a record inside the copybook */ @@ -283,6 +291,7 @@ object CopybookParser extends Logging { floatingPointFormat: FloatingPointFormat, nonTerminals: Seq[String], occursHandlers: Map[String, Map[String, Int]], + redefineRuleExpressions: Map[String, ExpressionEvaluator], debugFieldsPolicy: DebugFieldsPolicy, fieldCodePageMap: Map[String, String]): Copybook = { @@ -313,7 +322,9 @@ object CopybookParser extends Logging { // Add debugging fields if debug mode is enabled. DebugFieldsAdder(debugFieldsPolicy), // For each group calculates the number of non-filler items. - NonFillerCountSetter() + NonFillerCountSetter(), + // Sets isUsedInRules and rule expressions for each field + RuleExpressionSetter(redefineRuleExpressions) ) val transformedAst = transformers.foldLeft(schemaANTLR) { diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala index 1ae409cbf..80f785d4d 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/antlr/ParserVisitor.scala @@ -858,7 +858,8 @@ class ParserVisitor(enc: Encoding, if (occurs.isDefined) occurs.get.dep else None, Map(), isDependee = false, - identifier.toUpperCase() == Constants.FILLER, + isUsedInRules = false, + isFiller = identifier.toUpperCase() == Constants.FILLER, None, DecoderSelector.getDecoder(pic.value, stringTrimmingPolicy, isDisplayAlwaysString, effectiveEbcdicCodePage, effectiveAsciiCharset, isUtf16BigEndian = isUtf16BigEndian, floatingPointFormat, strictSignOverpunch = strictSignOverpunch, improvedNullDetection = improvedNullDetection, strictIntegralPrecision = strictIntegralPrecision), EncoderSelector.getEncoder(pic.value, effectiveEbcdicCodePage, effectiveAsciiCharset) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala index 4ccef1e23..c67f8d7e1 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/ast/Primitive.scala @@ -34,6 +34,7 @@ import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator * @param dependingOn A field which specifies size of the array in a record * @param dependingOnHandlers A map of handlers for the dependingOn field * @param isDependee A flag indicating if the field is a dependee + * @param isUsedInRules If true, the variable is used in redefine rule expressions * @param isFiller A flag indicating if the field is a filler * @param decode A decoder for the field to convert from raw data to a JVM data type * @param encode An optional encoder for the field to convert from a JVM data type to raw data @@ -53,6 +54,7 @@ case class Primitive( dependingOn: Option[String] = None, dependingOnHandlers: Map[String, Int] = Map(), isDependee: Boolean = false, + isUsedInRules: Boolean = false, isFiller: Boolean = false, ruleExpression: Option[ExpressionEvaluator] = None, decode: DecoderSelector.Decoder, @@ -129,6 +131,11 @@ case class Primitive( copy(isDependee = newIsDependee)(parent) } + /** Returns the original field with updated `isUsedInRules` flag */ + def withUpdatedIsUsedInRules(newIsUsedInRules: Boolean): Primitive = { + copy(isUsedInRules = newIsUsedInRules)(parent) + } + /** Returns the original field with updated `dependingOnHandlers` */ def withUpdatedDependingOnHandlers(newDependingOnHandlers: Map[String, Int]): Primitive = { copy(dependingOnHandlers = newDependingOnHandlers)(parent) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala new file mode 100644 index 000000000..a477b337b --- /dev/null +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala @@ -0,0 +1,83 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.cobol.parser.asttransform + +import org.slf4j.LoggerFactory +import za.co.absa.cobrix.cobol.parser.CopybookParser.CopybookAST +import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral} +import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive, Statement} +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +class RuleExpressionSetter( + redefineRuleExpressions: Map[String, ExpressionEvaluator] + ) extends AstTransformer { + private val log = LoggerFactory.getLogger(this.getClass) + + /** + * Sets isDependee attribute for fields in the schema which are used by other fields in DEPENDING ON clause + * + * @param ast An AST as a set of copybook records + * @return The same AST with binary properties set for every field + */ + final override def transform(ast: CopybookAST): CopybookAST = { + val ruleDrivenRedefines = redefineRuleExpressions.keys.toSet + val dependeeFields = redefineRuleExpressions.values.flatMap { expr => + expr.getVariables + }.toSet + + def markRuleForGroup(group: Group): Group = { + val newChildren = markRuleFields(group) + val newGroup = if (redefineRuleExpressions contains group.name) { + group.copy(children = newChildren.children, ruleExpression = redefineRuleExpressions.get(group.name))(group.parent) + } else { + group.copy(children = newChildren.children)(group.parent) + } + newGroup + } + + def markRuleFields(group: CopybookAST): CopybookAST = { + val newChildren = for (field <- group.children) yield { + val newField: Statement = field match { + case grp: Group => markRuleForGroup(grp) + case primitive: Primitive => + val newPrimitive1 = if (redefineRuleExpressions contains primitive.name) { + primitive.withUpdatedRuleExpression(redefineRuleExpressions.get(primitive.name)) + } else { + primitive + } + val newPrimitive2 = if (dependeeFields contains primitive.name) { + newPrimitive1.withUpdatedIsUsedInRules(newIsUsedInRules = true) + } else { + newPrimitive1 + } + newPrimitive2 + } + newField + } + group.copy(children = newChildren)(group.parent) + } + + markRuleFields(ast) + } +} + +object RuleExpressionSetter { + def apply(redefineRuleExpressions: Map[String, ExpressionEvaluator]): RuleExpressionSetter = new RuleExpressionSetter(redefineRuleExpressions) +} diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala index 1900223ab..0fba8c3d5 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala @@ -34,7 +34,7 @@ import scala.collection.mutable * assert(evaluator.eval() == 549) * }}} */ -class ExpressionEvaluator(val expr: String) { +class ExpressionEvaluator(val expr: String) extends Serializable { private val tokens = new Lexer(expr).lex() private val vars = mutable.HashMap[String, Int]() diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala index 13b1d2791..fe84f47a3 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala @@ -77,6 +77,7 @@ object RecordExtractors { ): Seq[Any] = { val dependFields = scala.collection.mutable.HashMap.empty[String, Either[Int, String]] val corruptFields = new ArrayBuffer[CorruptField] + val variables = new mutable.HashMap[String, Any]() val isAstFlat = ast.children.exists(_.isInstanceOf[Primitive]) @@ -143,23 +144,36 @@ object RecordExtractors { if (grp.isSegmentRedefine && grp.name.compareToIgnoreCase(activeSegmentRedefine) != 0) { (grp.binaryProperties.actualSize, null) } else { - getGroupValues(useOffset, grp) + val extract = canExtract(grp, variables) + if (extract) { + getGroupValues(useOffset, grp) + } else { + (grp.binaryProperties.actualSize, null) + } } case st: Primitive => - val value = st.decodeTypeValue(useOffset, data) - if (value == null && generateCorruptFields && !st.isEmpty(useOffset, data)) { - corruptFields += CorruptField(field.name, st.getRawValue(useOffset,data)) - } - if (value != null && st.isDependee) { - val intStringVal: Either[Int, String] = value match { - case v: Int => Left(v) - case v: Number => Left(v.intValue()) - case v: String => Right(v) - case v => throw new IllegalStateException(s"Field ${st.name} is an a DEPENDING ON field of an OCCURS, should be integral or 'occurs_mapping' should be defined, found ${v.getClass}.") + val extract = canExtract(st, variables) + if (extract) { + val value = st.decodeTypeValue(useOffset, data) + if (value == null && generateCorruptFields && !st.isEmpty(useOffset, data)) { + corruptFields += CorruptField(field.name, st.getRawValue(useOffset, data)) + } + if (st.isUsedInRules) { + variables += st.name -> value + } + if (value != null && st.isDependee) { + val intStringVal: Either[Int, String] = value match { + case v: Int => Left(v) + case v: Number => Left(v.intValue()) + case v: String => Right(v) + case v => throw new IllegalStateException(s"Field ${st.name} is an a DEPENDING ON field of an OCCURS, should be integral or 'occurs_mapping' should be defined, found ${v.getClass}.") + } + dependFields += st.name -> intStringVal } - dependFields += st.name -> intStringVal + (st.binaryProperties.actualSize, value) + } else { + (st.binaryProperties.actualSize, null) } - (st.binaryProperties.actualSize, value) } } @@ -208,12 +222,22 @@ object RecordExtractors { val records: ListBuffer[T] = ListBuffer.empty[T] - for (record <- rootRecords if !recordsToExclude.contains(record.name.toUpperCase)) yield { - val (size, values) = getGroupValues(nextOffset, record.asInstanceOf[Group]) - if (!record.isRedefined) { - nextOffset += size + if (recordsToExclude.isEmpty) { + for (record <- rootRecords) yield { + val (size, values) = getGroupValues(nextOffset, record.asInstanceOf[Group]) + if (!record.isRedefined) { + nextOffset += size + } + records += values + } + } else { + for (record <- rootRecords if !recordsToExclude.contains(record.name.toUpperCase)) yield { + val (size, values) = getGroupValues(nextOffset, record.asInstanceOf[Group]) + if (!record.isRedefined) { + nextOffset += size + } + records += values } - records += values } val effectiveSchemaRetentionPolicy = if (isAstFlat) { @@ -269,6 +293,7 @@ object RecordExtractors { recordsToExclude: Set[String] = Set.empty ): Seq[Any] = { val isAstFlat = ast.children.exists(_.isInstanceOf[Primitive]) + val variables = new mutable.HashMap[String, Any]() val dependFields = scala.collection.mutable.HashMap.empty[String, Either[Int, String]] @@ -325,19 +350,32 @@ object RecordExtractors { def extractValue(field: Statement, useOffset: Int, data: Array[Byte], currentIndex: Int, parentSegmentIds: List[String]): (Int, Any) = { field match { case grp: Group => - getGroupValues(useOffset, grp, data, currentIndex, parentSegmentIds) + val extract = canExtract(grp, variables) + if (extract) { + getGroupValues(useOffset, grp, data, currentIndex, parentSegmentIds) + } else { + (grp.binaryProperties.actualSize, null) + } case st: Primitive => - val value = st.decodeTypeValue(useOffset, data) - if (value != null && st.isDependee) { - val intStringVal: Either[Int, String] = value match { - case v: Int => Left(v) - case v: Number => Left(v.intValue()) - case v: String => Right(v) - case v => throw new IllegalStateException(s"Field ${st.name} is an a DEPENDING ON field of an OCCURS, should be integral or 'occurs_mapping' should be defined, found ${v.getClass}.") + val extract = canExtract(st, variables) + if (extract) { + val value = st.decodeTypeValue(useOffset, data) + if (st.isUsedInRules) { + variables += st.name -> value } - dependFields += st.name -> intStringVal + if (value != null && st.isDependee) { + val intStringVal: Either[Int, String] = value match { + case v: Int => Left(v) + case v: Number => Left(v.intValue()) + case v: String => Right(v) + case v => throw new IllegalStateException(s"Field ${st.name} is an a DEPENDING ON field of an OCCURS, should be integral or 'occurs_mapping' should be defined, found ${v.getClass}.") + } + dependFields += st.name -> intStringVal + } + (st.binaryProperties.actualSize, value) + } else { + (st.binaryProperties.actualSize, null) } - (st.binaryProperties.actualSize, value) } } @@ -556,4 +594,24 @@ object RecordExtractors { Group(10, Constants.corruptFieldsField, Constants.corruptFieldsField, 0, children = corruptFieldsInGroup, occurs = Some(10))(None) } + + /** Applies redefine expression rules if defined. */ + def canExtract(field: Statement, variables: mutable.Map[String, Any]): Boolean = { + field.ruleExpression match { + case Some(expr) => + variables.foreach { + case (k, v) => + if (v == null) + return false + expr.setValue(k, v.toString.toInt) + } + if (expr.evalBool()) { + true + } else { + false + } + case None => + true + } + } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala index 261ce060d..74a9bb1b0 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala @@ -451,7 +451,8 @@ object CobolParametersParser extends Logging { val ruleExpressionMap = parameters.redefineRuleExpressions map { case (field, exprStr) => val expr = new ExpressionEvaluator(exprStr) - (field, expr) + val fixedField = CopybookParser.transformIdentifier(field).toUpperCase + (fixedField, expr) } ReaderParameters( diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/schema/CobolSchema.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/schema/CobolSchema.scala index db0be84dd..e29381e94 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/schema/CobolSchema.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/schema/CobolSchema.scala @@ -119,6 +119,7 @@ object CobolSchema { readerParameters.floatingPointFormat, readerParameters.nonTerminals, readerParameters.occursMappings, + readerParameters.redefineRuleExpressions, readerParameters.debugFieldsPolicy, readerParameters.fieldCodePage) else @@ -143,6 +144,7 @@ object CobolSchema { readerParameters.floatingPointFormat, nonTerminals = readerParameters.nonTerminals, readerParameters.occursMappings, + readerParameters.redefineRuleExpressions, readerParameters.debugFieldsPolicy, readerParameters.fieldCodePage) )) diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala index be94b990e..8438c452c 100644 --- a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala +++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/extract/BinaryExtractorSpec.scala @@ -157,11 +157,12 @@ class BinaryExtractorSpec extends AnyFunSuite { val to: Option[Int] = None val dependingOn: Option[String] = None val isDependee: Boolean = false + val isUsedInRules: Boolean = false val isFiller: Boolean = false val binaryProperties: BinaryProperties = BinaryProperties(2, 10, 10) val primitive: Primitive = Primitive(level, name, name, lineNumber, dataType, redefines, isRedefined, - occurs, to, dependingOn, Map(), isDependee, isFiller, None, DecoderSelector.getDecoder(dataType), EncoderSelector.getEncoder(dataType), binaryProperties)(None) + occurs, to, dependingOn, Map(), isDependee, isUsedInRules, isFiller, None, DecoderSelector.getDecoder(dataType), EncoderSelector.getEncoder(dataType), binaryProperties)(None) val result2: Any = Copybook.extractPrimitiveField(primitive, bytes, startOffset) assert(result2.asInstanceOf[String] === "EXAMPLE4") } diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala new file mode 100644 index 000000000..02a347e71 --- /dev/null +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala @@ -0,0 +1,136 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.spark.cobol.source.integration + +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase +import za.co.absa.cobrix.spark.cobol.source.fixtures.{BinaryFileFixture, TextComparisonFixture} +import za.co.absa.cobrix.spark.cobol.utils.SparkUtils + +import java.nio.charset.StandardCharsets + +class Test42RedefineRulesSpec extends AnyWordSpec with SparkTestBase with BinaryFileFixture with TextComparisonFixture { + private val copybook = + """ 01 R. + 03 ID PIC 9(1). + 03 G1. + 04 F1 PIC S9(2). + 03 G2 REDEFINES G1. + 04 F2 PIC X(2). + 03 G3 REDEFINES G1. + 04 F3 PIC 9(1). + """ + + "Files with redefine rules" should { + "extract data according to the rules" when { + val dataAscii = "111\n222\n33\n" + withTempTextFile("redefine_rules1", ".dat", StandardCharsets.UTF_8, dataAscii) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("record_format", "D") + .option("redefine-rule:1", "G1 => ID = 1") + .option("redefine-rule:2", "G2 => ID = 2") + .option("redefine-rule:3", "G3 => ID = 3") + .option("pedantic", "true") + .load(tmpFileName) + + val actualSchema = df.schema.treeString + val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]")) + + "schema should match" in { + val expectedSchema = + """root + | |-- ID: integer (nullable = true) + | |-- G1: struct (nullable = true) + | | |-- F1: integer (nullable = true) + | |-- G2: struct (nullable = true) + | | |-- F2: string (nullable = true) + | |-- G3: struct (nullable = true) + | | |-- F3: integer (nullable = true) + |""".stripMargin + + + compareTextVertical(actualSchema, expectedSchema) + } + + "data should match" in { + val expectedData = + """[ { + | "ID" : 1, + | "G1" : { + | "F1" : 11 + | } + |}, { + | "ID" : 2, + | "G2" : { + | "F2" : "22" + | } + |}, { + | "ID" : 3, + | "G3" : { + | "F3" : 3 + | } + |} ]""".stripMargin + + compareTextVertical(actualData, expectedData) + } + } + } + + "extract data according to the rules with null input fields" when { + val data = Array( + 0xF1, 0xF1, 0xF1, + 0xF2, 0xF2, 0xF2, + 0x00, 0xF3, 0xF3 + ).map(_.toByte) + + withTempBinFile("redefine_rules2", ".dat", data) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("record_format", "F") + .option("redefine-rule:1", "G1 => ID = 1") + .option("redefine-rule:2", "G2 => ID = 2") + .option("redefine-rule:3", "G3 => ID = 3") + .option("pedantic", "true") + .load(tmpFileName) + + val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]")) + + "data should match" in { + val expectedData = + """[ { + | "ID" : 1, + | "G1" : { + | "F1" : 11 + | } + |}, { + | "ID" : 2, + | "G2" : { + | "F2" : "22" + | } + |}, { } ]""".stripMargin + + compareTextVertical(actualData, expectedData) + } + } + } + } +} From 340e33599081d07c87e36b40cf6dbb634a07edab Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Wed, 27 May 2026 08:41:35 +0200 Subject: [PATCH 4/7] #850 Add support for string literals, null values, comparison/logical operators, and in()/if() functions in expression evaluator. --- .../expression/ExpressionEvaluator.scala | 39 +- .../cobol/parser/expression/lexer/Lexer.scala | 71 +++- .../cobol/parser/expression/lexer/Token.scala | 42 ++ ...rBuilder.scala => ExpressionBuilder.scala} | 13 +- .../parser/ExpressionBuilderImpl.scala | 371 ++++++++++++++++- .../parser/ExtractVariablesBuilder.scala | 24 +- .../parser/expression/parser/Parser.scala | 67 ++- .../extractors/record/RecordExtractors.scala | 9 +- .../expression/ExpressionEvaluatorSuite.scala | 386 ++++++++++++++++++ .../integration/Test42RedefineRulesSpec.scala | 50 +++ 10 files changed, 1050 insertions(+), 22 deletions(-) rename cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/{NumExprBuilder.scala => ExpressionBuilder.scala} (68%) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala index 0fba8c3d5..4621b6bb1 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala @@ -38,11 +38,39 @@ class ExpressionEvaluator(val expr: String) extends Serializable { private val tokens = new Lexer(expr).lex() private val vars = mutable.HashMap[String, Int]() + private val stringVars = mutable.HashMap[String, String]() + private val nullVars = mutable.HashSet[String]() def setValue(varName: String, value: Int): Unit = { + nullVars -= varName + stringVars -= varName vars += varName -> value } + def setValue(varName: String, value: java.lang.Integer): Unit = { + if (value == null) { + setNullValue(varName) + } else { + setValue(varName, value.intValue()) + } + } + + def setStringValue(varName: String, value: String): Unit = { + if (value == null) { + setNullValue(varName) + } else { + nullVars -= varName + vars -= varName + stringVars += varName -> value + } + } + + def setNullValue(varName: String): Unit = { + vars -= varName + stringVars -= varName + nullVars += varName + } + def getVariables: Seq[String] = { val exprBuilder = new ExtractVariablesBuilder(expr) Parser.parse(tokens, exprBuilder) @@ -51,16 +79,23 @@ class ExpressionEvaluator(val expr: String) extends Serializable { } def evalInt(): Int = { - val exprBuilder = new ExpressionBuilderImpl(vars.toMap, expr) + val exprBuilder = new ExpressionBuilderImpl(vars.toMap, stringVars.toMap, nullVars.toSet, expr) Parser.parse(tokens, exprBuilder) exprBuilder.getIntResult } def evalBool(): Boolean = { - val exprBuilder = new ExpressionBuilderImpl(vars.toMap, expr) + val exprBuilder = new ExpressionBuilderImpl(vars.toMap, stringVars.toMap, nullVars.toSet, expr) Parser.parse(tokens, exprBuilder) exprBuilder.getBoolResult } + + def evalString(): String = { + val exprBuilder = new ExpressionBuilderImpl(vars.toMap, stringVars.toMap, nullVars.toSet, expr) + Parser.parse(tokens, exprBuilder) + + exprBuilder.getStringResult + } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala index c5e70c3b8..56b475d97 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala @@ -35,7 +35,7 @@ class Lexer(expression: String) { tokens.clear() while (pos < expression.length) { - val ok = findOneCharTokens() || findWhiteSpace() || findName() || findNumLiteral() + val ok = findStringLiteral() || findTwoCharTokens() || findOneCharTokens() || findWhiteSpace() || findName() || findNumLiteral() if (!ok) { throw new ExprSyntaxError(s"Unexpected character '${expression(pos)}' at position: $pos") } @@ -43,6 +43,63 @@ class Lexer(expression: String) { tokens.toArray } + def findStringLiteral(): Boolean = { + if (expression(pos) != '\'') { + return false + } + + val startPos = pos + pos += 1 + val sb = new StringBuilder + + while (pos < expression.length) { + if (expression(pos) == '\'') { + if (pos + 1 < expression.length && expression(pos + 1) == '\'') { + // Escaped quote: '' becomes ' + sb.append('\'') + pos += 2 + } else { + // End of string + pos += 1 + tokens += STRING_LITERAL(startPos, sb.toString()) + return true + } + } else { + sb.append(expression(pos)) + pos += 1 + } + } + + throw new ExprSyntaxError(s"Unterminated string literal starting at position: $startPos") + } + + def findTwoCharTokens(): Boolean = { + if (pos >= expression.length - 1) { + return false + } + + val c1 = expression(pos) + val c2 = expression(pos + 1) + + val found: Option[Token] = (c1, c2) match { + case ('>', '=') => Some(GTE(pos)) + case ('<', '=') => Some(LTE(pos)) + case ('!', '=') => Some(NE(pos)) + case ('&', '&') => Some(AND(pos)) + case ('|', '|') => Some(OR(pos)) + case _ => None + } + + found match { + case Some(t) => + tokens += t + pos += 2 + true + case None => + false + } + } + def findOneCharTokens(): Boolean = { val c = expression(pos) @@ -55,6 +112,11 @@ class Lexer(expression: String) { case '*' => Some(MULT(pos)) case '/' => Some(DIV(pos)) case '=' => Some(EQ(pos)) + case '>' => Some(GT(pos)) + case '<' => Some(LT(pos)) + case '!' => Some(NOT(pos)) + case '&' => throw new ExprSyntaxError(s"Unexpected character '&' at position $pos. Did you mean '&&'?") + case '|' => throw new ExprSyntaxError(s"Unexpected character '|' at position $pos. Did you mean '||'?") case _ => None } @@ -89,7 +151,12 @@ class Lexer(expression: String) { while (pos2 < expression.length && nameMidChars.contains(expression(pos2))) { pos2 += 1 } - val token = NAME(pos, expression.substring(pos, pos2)) + val name = expression.substring(pos, pos2) + val token = if (name.toLowerCase == "null") { + NULL_LITERAL(pos) + } else { + NAME(pos, name) + } tokens += token pos = pos2 true diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala index a8c091aba..da62eaed1 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala @@ -58,6 +58,38 @@ object Token { override def toString = "=" } + case class GT(pos: Int) extends Token { + override def toString = ">" + } + + case class LT(pos: Int) extends Token { + override def toString = "<" + } + + case class GTE(pos: Int) extends Token { + override def toString = ">=" + } + + case class LTE(pos: Int) extends Token { + override def toString = "<=" + } + + case class NE(pos: Int) extends Token { + override def toString = "!=" + } + + case class AND(pos: Int) extends Token { + override def toString = "&&" + } + + case class OR(pos: Int) extends Token { + override def toString = "||" + } + + case class NOT(pos: Int) extends Token { + override def toString = "!" + } + case class NAME(pos: Int, s: String) extends Token { override def toString: String = s @@ -67,4 +99,14 @@ object Token { { override def toString: String = s } + + case class STRING_LITERAL(pos: Int, s: String) extends Token + { + override def toString: String = s"'$s'" + } + + case class NULL_LITERAL(pos: Int) extends Token + { + override def toString: String = "null" + } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala similarity index 68% rename from cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala rename to cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala index 79108638b..1a56d80c2 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/NumExprBuilder.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala @@ -16,16 +16,27 @@ package za.co.absa.cobrix.cobol.parser.expression.parser -trait NumExprBuilder { +trait ExpressionBuilder { def openParen(pos: Int): Unit def closeParen(pos: Int): Unit + def addComma(pos: Int): Unit def addOperationPlus(pos: Int): Unit def addOperationMinus(pos: Int): Unit def addOperationMultiply(pos: Int): Unit def addOperationDivide(pos: Int): Unit def addOperationEquals(pos: Int): Unit + def addOperationGreaterThan(pos: Int): Unit + def addOperationLessThan(pos: Int): Unit + def addOperationGreaterThanOrEqual(pos: Int): Unit + def addOperationLessThanOrEqual(pos: Int): Unit + def addOperationNotEqual(pos: Int): Unit + def addOperationAnd(pos: Int): Unit + def addOperationOr(pos: Int): Unit + def addOperationNot(pos: Int): Unit def addVariable(name: String, pos: Int): Unit def addFunction(name: String, pos: Int): Unit def addNumLiteral(num: Int, pos: Int): Unit + def addStringLiteral(s: String, pos: Int): Unit + def addNullLiteral(pos: Int): Unit } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala index a116b4028..cd7016e5c 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala @@ -21,12 +21,24 @@ import za.co.absa.cobrix.cobol.parser.expression.exception.ExprSyntaxError import scala.annotation.tailrec import scala.collection.mutable.ListBuffer -class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExprBuilder { +class ExpressionBuilderImpl(vars: Map[String, Int], stringVars: Map[String, String], nullVars: Set[String], expr: String) extends ExpressionBuilder { + sealed trait ValueType + case object IntType extends ValueType + case object BoolType extends ValueType + case object StringType extends ValueType + case object NullType extends ValueType + val ops = new ListBuffer[String] val valuesInt = new ListBuffer[Int] val valuesBool = new ListBuffer[Boolean] + val valuesString = new ListBuffer[String] + val valueTypes = new ListBuffer[ValueType] + val argCounts = new ListBuffer[Int] // Track argument counts for functions - override def openParen(pos: Int): Unit = ops += "(" + override def openParen(pos: Int): Unit = { + ops += "(" + argCounts += 1 // Start counting arguments (at least 1 if not empty) + } override def closeParen(pos: Int): Unit = { if (ops.isEmpty) { @@ -37,7 +49,37 @@ class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExp while (ops.last != "(") { eval() } - ops.remove(ops.size - 1) + ops.remove(ops.size - 1) // Remove "(" + + val currentArgCount = if (argCounts.nonEmpty) { + val count = argCounts.last + argCounts.remove(argCounts.size - 1) + count + } else 0 + + // Check if there's a function to evaluate + if (ops.nonEmpty && !isOperator(ops.last)) { + evalFunction(ops.last, currentArgCount) + ops.remove(ops.size - 1) + } + } + } + + private def isOperator(s: String): Boolean = { + s match { + case "(" | "+" | "-" | "*" | "/" | "=" | "!=" | ">" | "<" | ">=" | "<=" | "&&" | "||" | "!" => true + case _ => false + } + } + + override def addComma(pos: Int): Unit = { + // Evaluate pending operators within the current parentheses + while (ops.nonEmpty && ops.last != "(") { + eval() + } + // Increment argument count for current function call + if (argCounts.nonEmpty) { + argCounts(argCounts.size - 1) += 1 } } @@ -70,17 +112,76 @@ class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExp } override def addOperationEquals(pos: Int): Unit = { - while (ops.nonEmpty && "+-*/=".contains(ops.last)) { + while (ops.nonEmpty && "+-*/=!><".contains(ops.last.head)) { eval() } ops += "=" } + override def addOperationGreaterThan(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=!><".contains(ops.last.head)) { + eval() + } + ops += ">" + } + + override def addOperationLessThan(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=!><".contains(ops.last.head)) { + eval() + } + ops += "<" + } + + override def addOperationGreaterThanOrEqual(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=!><".contains(ops.last.head)) { + eval() + } + ops += ">=" + } + + override def addOperationLessThanOrEqual(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=!><".contains(ops.last.head)) { + eval() + } + ops += "<=" + } + + override def addOperationNotEqual(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=!><".contains(ops.last.head)) { + eval() + } + ops += "!=" + } + + override def addOperationAnd(pos: Int): Unit = { + while (ops.nonEmpty && "+-*/=!><&".contains(ops.last.head)) { + eval() + } + ops += "&&" + } + + override def addOperationOr(pos: Int): Unit = { + while (ops.nonEmpty && ops.last != "||" && ops.last != "(") { + eval() + } + ops += "||" + } + + override def addOperationNot(pos: Int): Unit = { + ops += "!" + } + override def addVariable(name: String, pos: Int): Unit = { - if (!vars.contains(name)) { - throw new ExprSyntaxError(s"Unset variable '$name' used.") - } else { + if (nullVars.contains(name)) { + valueTypes += NullType + } else if (vars.contains(name)) { valuesInt += vars(name) + valueTypes += IntType + } else if (stringVars.contains(name)) { + valuesString += stringVars(name) + valueTypes += StringType + } else { + throw new ExprSyntaxError(s"Unset variable '$name' used.") } } @@ -90,6 +191,16 @@ class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExp override def addNumLiteral(num: Int, pos: Int): Unit = { valuesInt += num + valueTypes += IntType + } + + override def addStringLiteral(s: String, pos: Int): Unit = { + valuesString += s + valueTypes += StringType + } + + override def addNullLiteral(pos: Int): Unit = { + valueTypes += NullType } def getIntResult: Int = { @@ -111,17 +222,32 @@ class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExp while (ops.nonEmpty) { eval() } - if (valuesInt.isEmpty && valuesBool.isEmpty) { + if (valuesInt.isEmpty && valuesBool.isEmpty && valuesString.isEmpty) { throw new ExprSyntaxError(s"Empty expressions are not supported in '$expr'.") } else if (valuesBool.isEmpty) { throw new ExprSyntaxError(s"The expression does not return a boolean in '$expr'.") - } else if (valuesBool.size > 1 || (valuesInt.nonEmpty && valuesBool.nonEmpty)) { + } else if (valuesBool.size > 1 || (valuesInt.nonEmpty && valuesBool.nonEmpty) || (valuesString.nonEmpty && valuesBool.nonEmpty)) { throw new ExprSyntaxError(s"Malformed expression: '$expr'.") } else { valuesBool.head } } + def getStringResult: String = { + while (ops.nonEmpty) { + eval() + } + if (valuesInt.isEmpty && valuesBool.isEmpty && valuesString.isEmpty) { + throw new ExprSyntaxError(s"Empty expressions are not supported in '$expr'.") + } else if (valuesString.isEmpty) { + throw new ExprSyntaxError(s"The expression does not return a string in '$expr'.") + } else if (valuesString.size > 1 || (valuesInt.nonEmpty && valuesString.nonEmpty) || (valuesBool.nonEmpty && valuesString.nonEmpty)) { + throw new ExprSyntaxError(s"Malformed expression: '$expr'.") + } else { + valuesString.head + } + } + @tailrec private def eval(): Unit = { val op = ops.last @@ -131,33 +257,206 @@ class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExp case "(" => if (ops.nonEmpty && ops.last != "(") eval() case "+" => expectIntArguments(2) + valueTypes.remove(valueTypes.size - 1) + valueTypes.remove(valueTypes.size - 1) val b = getInt val a = getInt valuesInt += a + b + valueTypes += IntType case "-" => expectIntArguments(2) + valueTypes.remove(valueTypes.size - 1) + valueTypes.remove(valueTypes.size - 1) val b = getInt val a = getInt valuesInt += a - b + valueTypes += IntType case "*" => expectIntArguments(2) + valueTypes.remove(valueTypes.size - 1) + valueTypes.remove(valueTypes.size - 1) val b = getInt val a = getInt valuesInt += a * b + valueTypes += IntType case "/" => expectIntArguments(2) + valueTypes.remove(valueTypes.size - 1) + valueTypes.remove(valueTypes.size - 1) val b = getInt val a = getInt valuesInt += a / b + valueTypes += IntType case "=" => - expectIntArguments(2) - val b = getInt - val a = getInt + val (b, a, vType) = getTwoComparableValues() valuesBool += a == b + valueTypes += BoolType + case "!=" => + val (b, a, vType) = getTwoComparableValues() + valuesBool += a != b + valueTypes += BoolType + case ">" => + val (b, a, vType) = getTwoComparableValues() + valuesBool += (vType match { + case IntType => a.asInstanceOf[Int] > b.asInstanceOf[Int] + case StringType => a.asInstanceOf[String] > b.asInstanceOf[String] + case _ => throw new ExprSyntaxError(s"Cannot use > operator with $vType in '$expr'.") + }) + valueTypes += BoolType + case "<" => + val (b, a, vType) = getTwoComparableValues() + valuesBool += (vType match { + case IntType => a.asInstanceOf[Int] < b.asInstanceOf[Int] + case StringType => a.asInstanceOf[String] < b.asInstanceOf[String] + case _ => throw new ExprSyntaxError(s"Cannot use < operator with $vType in '$expr'.") + }) + valueTypes += BoolType + case ">=" => + val (b, a, vType) = getTwoComparableValues() + valuesBool += (vType match { + case IntType => a.asInstanceOf[Int] >= b.asInstanceOf[Int] + case StringType => a.asInstanceOf[String] >= b.asInstanceOf[String] + case _ => throw new ExprSyntaxError(s"Cannot use >= operator with $vType in '$expr'.") + }) + valueTypes += BoolType + case "<=" => + val (b, a, vType) = getTwoComparableValues() + valuesBool += (vType match { + case IntType => a.asInstanceOf[Int] <= b.asInstanceOf[Int] + case StringType => a.asInstanceOf[String] <= b.asInstanceOf[String] + case _ => throw new ExprSyntaxError(s"Cannot use <= operator with $vType in '$expr'.") + }) + valueTypes += BoolType + case "&&" => + expectBoolArguments(2) + valueTypes.remove(valueTypes.size - 1) + valueTypes.remove(valueTypes.size - 1) + val b = getBool + val a = getBool + valuesBool += a && b + valueTypes += BoolType + case "||" => + expectBoolArguments(2) + valueTypes.remove(valueTypes.size - 1) + valueTypes.remove(valueTypes.size - 1) + val b = getBool + val a = getBool + valuesBool += a || b + valueTypes += BoolType + case "!" => + expectBoolArguments(1) + valueTypes.remove(valueTypes.size - 1) + val a = getBool + valuesBool += !a + valueTypes += BoolType case f => throw new ExprSyntaxError(s"Unsupported function '$f' in '$expr'.") } } + private def evalFunction(funcName: String, argCount: Int): Unit = { + funcName.toLowerCase match { + case "in" => evalInFunction(argCount) + case "if" => evalIfFunction(argCount) + case f => throw new ExprSyntaxError(s"Unsupported function '$f' in '$expr'.") + } + } + + private def evalInFunction(argCount: Int): Unit = { + if (argCount < 2) + throw new ExprSyntaxError(s"Function 'in' requires at least 2 arguments in '$expr'.") + + // Get the options (all arguments except the first one) + val optionCount = argCount - 1 + val options = new ListBuffer[(Any, ValueType)] + + for (_ <- 0 until optionCount) { + val vType = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + val value: Any = vType match { + case IntType => getInt + case StringType => getString + case NullType => null + case BoolType => getBool + } + options.prepend((value, vType)) + } + + // Get the value to check + val checkType = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + val checkValue: Any = checkType match { + case IntType => getInt + case StringType => getString + case NullType => null + case BoolType => getBool + } + + // Check if checkValue is in options + val result = options.exists { case (optValue, optType) => + // Allow null comparisons + if (checkType == NullType || optType == NullType) { + checkValue == optValue + } else if (checkType == optType) { + checkValue == optValue + } else { + false + } + } + + valuesBool += result + valueTypes += BoolType + } + + private def evalIfFunction(argCount: Int): Unit = { + if (argCount != 3) + throw new ExprSyntaxError(s"Function 'if' requires exactly 3 arguments in '$expr'.") + + // Get false value (third argument) + val falseType = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + val falseValue: Any = falseType match { + case IntType => getInt + case StringType => getString + case NullType => null + case BoolType => getBool + } + + // Get true value (second argument) + val trueType = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + val trueValue: Any = trueType match { + case IntType => getInt + case StringType => getString + case NullType => null + case BoolType => getBool + } + + // Get condition (first argument) + val condType = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + if (condType != BoolType) + throw new ExprSyntaxError(s"First argument of 'if' must be a boolean expression in '$expr'.") + val condition = getBool + + // Return the appropriate value based on condition + val resultType = if (condition) trueType else falseType + val resultValue = if (condition) trueValue else falseValue + + resultType match { + case IntType => + valuesInt += resultValue.asInstanceOf[Int] + valueTypes += IntType + case StringType => + valuesString += resultValue.asInstanceOf[String] + valueTypes += StringType + case BoolType => + valuesBool += resultValue.asInstanceOf[Boolean] + valueTypes += BoolType + case NullType => + valueTypes += NullType + } + } + private def expectIntArguments(n: Int): Unit = { if (valuesInt.size < n) throw new ExprSyntaxError(s"Expected more arguments in '$expr'.") @@ -174,4 +473,52 @@ class ExpressionBuilderImpl(vars: Map[String, Int], expr: String) extends NumExp valuesBool.remove(valuesBool.size - 1) a } + + private def getString: String = { + val a = valuesString.last + valuesString.remove(valuesString.size - 1) + a + } + + private def expectBoolArguments(n: Int): Unit = { + if (valuesBool.size < n) + throw new ExprSyntaxError(s"Expected boolean arguments in '$expr'.") + } + + private def getTwoComparableValues(): (Any, Any, ValueType) = { + if (valueTypes.size < 2) + throw new ExprSyntaxError(s"Expected more arguments in '$expr'.") + + val type2 = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + val type1 = valueTypes.last + valueTypes.remove(valueTypes.size - 1) + + // Handle null comparisons + if (type1 == NullType || type2 == NullType) { + val val1: Any = type1 match { + case NullType => null + case IntType => getInt + case StringType => getString + case BoolType => getBool + } + val val2: Any = type2 match { + case NullType => null + case IntType => getInt + case StringType => getString + case BoolType => getBool + } + return (val2, val1, NullType) + } + + if (type1 != type2) + throw new ExprSyntaxError(s"Cannot compare $type1 with $type2 in '$expr'.") + + type1 match { + case IntType => (getInt, getInt, IntType) + case StringType => (getString, getString, StringType) + case BoolType => throw new ExprSyntaxError(s"Cannot compare boolean values in '$expr'.") + case NullType => (null, null, NullType) // both nulls + } + } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala index 3d15c1182..d8be51766 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala @@ -18,13 +18,15 @@ package za.co.absa.cobrix.cobol.parser.expression.parser import scala.collection.mutable.ListBuffer -class ExtractVariablesBuilder(expr: String) extends NumExprBuilder { +class ExtractVariablesBuilder(expr: String) extends ExpressionBuilder { private val variables = new ListBuffer[String] override def openParen(pos: Int): Unit = { } override def closeParen(pos: Int): Unit = { } + override def addComma(pos: Int): Unit = { } + override def addOperationPlus(pos: Int): Unit = {} override def addOperationMinus(pos: Int): Unit = {} @@ -35,6 +37,22 @@ class ExtractVariablesBuilder(expr: String) extends NumExprBuilder { override def addOperationEquals(pos: Int): Unit = {} + override def addOperationGreaterThan(pos: Int): Unit = {} + + override def addOperationLessThan(pos: Int): Unit = {} + + override def addOperationGreaterThanOrEqual(pos: Int): Unit = {} + + override def addOperationLessThanOrEqual(pos: Int): Unit = {} + + override def addOperationNotEqual(pos: Int): Unit = {} + + override def addOperationAnd(pos: Int): Unit = {} + + override def addOperationOr(pos: Int): Unit = {} + + override def addOperationNot(pos: Int): Unit = {} + override def addVariable(name: String, pos: Int): Unit = { variables += name } @@ -43,6 +61,10 @@ class ExtractVariablesBuilder(expr: String) extends NumExprBuilder { override def addNumLiteral(num: Int, pos: Int): Unit = { } + override def addStringLiteral(s: String, pos: Int): Unit = { } + + override def addNullLiteral(pos: Int): Unit = { } + def getResult: Seq[String] = { variables.distinct.sorted.toSeq } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala index 5af6b1149..1678e804f 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala @@ -23,10 +23,11 @@ import za.co.absa.cobrix.cobol.parser.expression.lexer.Token._ import scala.collection.mutable.ListBuffer object Parser { - def parse(tokens: Array[Token], builder: NumExprBuilder): Unit = { + def parse(tokens: Array[Token], builder: ExpressionBuilder): Unit = { val STATE0 = 0 val STATE1 = 1 val MINUS_NUM = 3 + val NOT_OP = 4 var state = STATE0 @@ -52,6 +53,8 @@ object Parser { throw new ExprSyntaxError(s"Unexpected '+' at pos $pos") case MINUS(_) => state = MINUS_NUM + case NOT(_) => + state = NOT_OP case NAME(pos, s) => if (i == tokens.length - 1 || !tokens(i + 1).isInstanceOf[OPEN_PARAN]) { builder.addVariable(s, pos) @@ -62,11 +65,18 @@ object Parser { case NUM_LITERAL(pos, s) => builder.addNumLiteral(s.toInt, pos) state = STATE1 + case STRING_LITERAL(pos, s) => + builder.addStringLiteral(s, pos) + state = STATE1 + case NULL_LITERAL(pos) => + builder.addNullLiteral(pos) + state = STATE1 case _ => new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } } else if (state == STATE1) { token match { - case COMMA(_) => + case COMMA(pos) => + builder.addComma(pos) state = STATE0 case OPEN_PARAN(pos) => paranPos += pos @@ -94,6 +104,27 @@ object Parser { case EQ(pos) => builder.addOperationEquals(pos) state = STATE0 + case GT(pos) => + builder.addOperationGreaterThan(pos) + state = STATE0 + case LT(pos) => + builder.addOperationLessThan(pos) + state = STATE0 + case GTE(pos) => + builder.addOperationGreaterThanOrEqual(pos) + state = STATE0 + case LTE(pos) => + builder.addOperationLessThanOrEqual(pos) + state = STATE0 + case NE(pos) => + builder.addOperationNotEqual(pos) + state = STATE0 + case AND(pos) => + builder.addOperationAnd(pos) + state = STATE0 + case OR(pos) => + builder.addOperationOr(pos) + state = STATE0 case NAME(pos, s) => builder.addFunction(s, pos) case NUM_LITERAL(pos, s) => @@ -116,6 +147,38 @@ object Parser { state = STATE1 case _ => new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } + } else if (state == NOT_OP) { + token match { + case NOT(_) => + builder.addOperationNot(pos = token.pos) + case OPEN_PARAN(pos) => + paranPos += pos + builder.addOperationNot(pos) + builder.openParen(pos) + state = STATE0 + case NAME(pos, s) => + builder.addOperationNot(pos) + if (i == tokens.length - 1 || !tokens(i + 1).isInstanceOf[OPEN_PARAN]) { + builder.addVariable(s, pos) + state = STATE1 + } else { + builder.addFunction(s, pos) + state = STATE0 + } + case NUM_LITERAL(pos, s) => + builder.addOperationNot(pos) + builder.addNumLiteral(s.toInt, pos) + state = STATE1 + case STRING_LITERAL(pos, s) => + builder.addOperationNot(pos) + builder.addStringLiteral(s, pos) + state = STATE1 + case NULL_LITERAL(pos) => + builder.addOperationNot(pos) + builder.addNullLiteral(pos) + state = STATE1 + case _ => throw new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") + } } i += 1 } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala index fe84f47a3..de18cb653 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala @@ -602,8 +602,13 @@ object RecordExtractors { variables.foreach { case (k, v) => if (v == null) - return false - expr.setValue(k, v.toString.toInt) + expr.setNullValue(k) + else { + v match { + case s: String => expr.setStringValue(k, s) + case _ => expr.setValue(k, v.toString.toInt) + } + } } if (expr.evalBool()) { true diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala index 30272dee9..cd7434b06 100644 --- a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala +++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala @@ -95,5 +95,391 @@ class ExpressionEvaluatorSuite extends AnyWordSpec { assert(ex.getMessage == "The expression does not return a boolean in 'a1*2'.") } + + "evaluate string literal expressions" in { + assert(new ExpressionEvaluator("'hello' = 'hello'").evalBool()) + assert(new ExpressionEvaluator("'abc' != 'def'").evalBool()) + assert(!new ExpressionEvaluator("'test' = 'Test'").evalBool()) + } + + "handle escaped quotes in strings" in { + assert(new ExpressionEvaluator("'It''s' = 'It''s'").evalBool()) + val evaluator = new ExpressionEvaluator("s = 'It''s'") + evaluator.setStringValue("s", "It's") + assert(evaluator.evalBool()) + } + + "handle empty strings" in { + assert(new ExpressionEvaluator("'' = ''").evalBool()) + assert(new ExpressionEvaluator("'' != 'a'").evalBool()) + } + + "evaluate comparison operators with integers" in { + val e1 = new ExpressionEvaluator("a > 5") + e1.setValue("a", 10) + assert(e1.evalBool()) + + val e2 = new ExpressionEvaluator("a > 5") + e2.setValue("a", 3) + assert(!e2.evalBool()) + + assert(new ExpressionEvaluator("3 < 5").evalBool()) + assert(!new ExpressionEvaluator("5 < 3").evalBool()) + assert(new ExpressionEvaluator("5 >= 5").evalBool()) + assert(new ExpressionEvaluator("6 >= 5").evalBool()) + assert(!new ExpressionEvaluator("4 >= 5").evalBool()) + assert(new ExpressionEvaluator("5 <= 5").evalBool()) + assert(!new ExpressionEvaluator("5 <= 4").evalBool()) + assert(new ExpressionEvaluator("5 != 6").evalBool()) + assert(!new ExpressionEvaluator("5 != 5").evalBool()) + } + + "compare strings with operators" in { + assert(new ExpressionEvaluator("'abc' < 'def'").evalBool()) + assert(new ExpressionEvaluator("'xyz' > 'abc'").evalBool()) + assert(new ExpressionEvaluator("'abc' <= 'abc'").evalBool()) + assert(new ExpressionEvaluator("'abc' >= 'abc'").evalBool()) + } + + "compare string variables" in { + val evaluator = new ExpressionEvaluator("status = 'ACTIVE'") + evaluator.setStringValue("status", "ACTIVE") + assert(evaluator.evalBool()) + + val evaluator2 = new ExpressionEvaluator("name > 'John'") + evaluator2.setStringValue("name", "Mary") + assert(evaluator2.evalBool()) + } + + "evaluate AND operations" in { + val e1 = new ExpressionEvaluator("a > 5 && b < 10") + e1.setValue("a", 7) + e1.setValue("b", 8) + assert(e1.evalBool()) + + val e2 = new ExpressionEvaluator("a > 5 && b < 10") + e2.setValue("a", 3) + e2.setValue("b", 8) + assert(!e2.evalBool()) + + val e3 = new ExpressionEvaluator("a > 5 && b < 10") + e3.setValue("a", 7) + e3.setValue("b", 15) + assert(!e3.evalBool()) + } + + "evaluate OR operations" in { + val e1 = new ExpressionEvaluator("a > 100 || b = 5") + e1.setValue("a", 10) + e1.setValue("b", 5) + assert(e1.evalBool()) + + val e2 = new ExpressionEvaluator("a > 100 || b = 5") + e2.setValue("a", 150) + e2.setValue("b", 10) + assert(e2.evalBool()) + + val e3 = new ExpressionEvaluator("a > 100 || b = 5") + e3.setValue("a", 10) + e3.setValue("b", 10) + assert(!e3.evalBool()) + } + + "evaluate NOT operations" in { + val e1 = new ExpressionEvaluator("!(a = 5)") + e1.setValue("a", 10) + assert(e1.evalBool()) + + val e2 = new ExpressionEvaluator("!(a = 5)") + e2.setValue("a", 5) + assert(!e2.evalBool()) + + val e3 = new ExpressionEvaluator("!(a > 5)") + e3.setValue("a", 3) + assert(e3.evalBool()) + } + + "evaluate double NOT operations" in { + val e1 = new ExpressionEvaluator("!!(a = 5)") + e1.setValue("a", 5) + assert(e1.evalBool()) + + val e2 = new ExpressionEvaluator("!!(a = 5)") + e2.setValue("a", 10) + assert(!e2.evalBool()) + } + + "evaluate complex boolean expressions" in { + val e1 = new ExpressionEvaluator("(a > 5 && b < 10) || c = 100") + e1.setValue("a", 3) + e1.setValue("b", 20) + e1.setValue("c", 100) + assert(e1.evalBool()) + + val e2 = new ExpressionEvaluator("(a > 5 && b < 10) || c = 100") + e2.setValue("a", 7) + e2.setValue("b", 8) + e2.setValue("c", 50) + assert(e2.evalBool()) + + val e3 = new ExpressionEvaluator("(a > 5 && b < 10) || c = 100") + e3.setValue("a", 3) + e3.setValue("b", 20) + e3.setValue("c", 50) + assert(!e3.evalBool()) + } + + "respect operator precedence" in { + val e1 = new ExpressionEvaluator("a = 1 || b = 2 && c = 3") + e1.setValue("a", 0) + e1.setValue("b", 2) + e1.setValue("c", 3) + assert(e1.evalBool()) // (a=1) || ((b=2) && (c=3)) = false || (true && true) = true + + val e2 = new ExpressionEvaluator("a = 1 || b = 2 && c = 3") + e2.setValue("a", 0) + e2.setValue("b", 2) + e2.setValue("c", 0) + assert(!e2.evalBool()) // (a=1) || ((b=2) && (c=3)) = false || (true && false) = false + } + + "combine string and numeric comparisons" in { + val e = new ExpressionEvaluator("status = 'ACTIVE' && amount > 100") + e.setStringValue("status", "ACTIVE") + e.setValue("amount", 150) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("status = 'ACTIVE' && amount > 100") + e2.setStringValue("status", "INACTIVE") + e2.setValue("amount", 150) + assert(!e2.evalBool()) + } + + "throw error on type mismatch in comparison" in { + val e = new ExpressionEvaluator("a > 'string'") + e.setValue("a", 5) + val ex = intercept[ExprSyntaxError] { + e.evalBool() + } + assert(ex.getMessage.contains("Cannot compare")) + } + + "throw error on single & character" in { + val ex = intercept[ExprSyntaxError] { + new ExpressionEvaluator("a & b") + } + assert(ex.getMessage.contains("'&&'")) + } + + "throw error on single | character" in { + val ex = intercept[ExprSyntaxError] { + new ExpressionEvaluator("a | b") + } + assert(ex.getMessage.contains("'||'")) + } + + "throw error on unterminated string" in { + val ex = intercept[ExprSyntaxError] { + new ExpressionEvaluator("'unterminated") + } + assert(ex.getMessage.contains("Unterminated string")) + } + + // Null literal tests + "evaluate null comparisons" in { + assert(new ExpressionEvaluator("null = null").evalBool()) + assert(!new ExpressionEvaluator("null != null").evalBool()) + assert(!new ExpressionEvaluator("'test' = null").evalBool()) + assert(new ExpressionEvaluator("'test' != null").evalBool()) + assert(!new ExpressionEvaluator("5 = null").evalBool()) + assert(new ExpressionEvaluator("5 != null").evalBool()) + } + + "evaluate null with variables" in { + val e = new ExpressionEvaluator("a = 'test' || a = null") + e.setStringValue("a", "test") + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("a = 'test' || a = null") + e2.setStringValue("a", "other") + assert(!e2.evalBool()) + } + + "allow setting variable to null via setStringValue" in { + val e = new ExpressionEvaluator("a = null") + e.setStringValue("a", null) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("a != null") + e2.setStringValue("a", null) + assert(!e2.evalBool()) + } + + "allow setting variable to null via setValue with Integer" in { + val e = new ExpressionEvaluator("a = null") + e.setValue("a", null: java.lang.Integer) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("a != null") + e2.setValue("a", null: java.lang.Integer) + assert(!e2.evalBool()) + } + + "allow setting variable to null via setNullValue" in { + val e = new ExpressionEvaluator("a = null") + e.setNullValue("a") + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("a = 'test' || a = null") + e2.setNullValue("a") + assert(e2.evalBool()) + + val e3 = new ExpressionEvaluator("a = 'test' || a = null") + e3.setStringValue("a", "test") + assert(e3.evalBool()) + + val e4 = new ExpressionEvaluator("a = 'test' || a = null") + e4.setStringValue("a", "test1") + assert(!e4.evalBool()) + } + + "allow overwriting variable value with null" in { + val e = new ExpressionEvaluator("a = null") + e.setStringValue("a", "test") + assert(!e.evalBool()) + + val e2 = new ExpressionEvaluator("a = null") + e2.setStringValue("a", "test") + e2.setStringValue("a", null) + assert(e2.evalBool()) + } + + "use null variable in in() function" in { + val e = new ExpressionEvaluator("in(a, 'X', null, 'Y')") + e.setNullValue("a") + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("in(a, 'X', 'Y', 'Z')") + e2.setNullValue("a") + assert(!e2.evalBool()) + } + + "support NULL in any case" in { + assert(new ExpressionEvaluator("NULL = null").evalBool()) + assert(new ExpressionEvaluator("Null = null").evalBool()) + } + + // 'in' function tests + "evaluate in() function with strings" in { + val e = new ExpressionEvaluator("in(a, 'A', 'B', 'F')") + e.setStringValue("a", "A") + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("in(a, 'A', 'B', 'F')") + e2.setStringValue("a", "B") + assert(e2.evalBool()) + + val e3 = new ExpressionEvaluator("in(a, 'A', 'B', 'F')") + e3.setStringValue("a", "C") + assert(!e3.evalBool()) + } + + "evaluate in() function with integers" in { + val e = new ExpressionEvaluator("in(b, 1, 23, 55)") + e.setValue("b", 23) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("in(b, 1, 23, 55)") + e2.setValue("b", 100) + assert(!e2.evalBool()) + } + + "evaluate in() function with literal value" in { + assert(new ExpressionEvaluator("in('X', 'A', 'B', 'X')").evalBool()) + assert(!new ExpressionEvaluator("in('Y', 'A', 'B', 'X')").evalBool()) + assert(new ExpressionEvaluator("in(5, 1, 5, 10)").evalBool()) + assert(!new ExpressionEvaluator("in(7, 1, 5, 10)").evalBool()) + } + + "evaluate in() function with null" in { + assert(new ExpressionEvaluator("in(null, 'A', null, 'B')").evalBool()) + assert(!new ExpressionEvaluator("in(null, 'A', 'B', 'C')").evalBool()) + } + + "throw error for in() with less than 2 arguments" in { + val e = new ExpressionEvaluator("in(a)") + e.setValue("a", 10) + val ex = intercept[ExprSyntaxError] { + e.evalBool() + } + assert(ex.getMessage.contains("at least 2 arguments")) + } + + // 'if' function tests + "evaluate if() function returning integer" in { + val e = new ExpressionEvaluator("if(a > 2, 5, 1)") + e.setValue("a", 10) + assert(e.evalInt() == 5) + + val e2 = new ExpressionEvaluator("if(a > 2, 5, 1)") + e2.setValue("a", 1) + assert(e2.evalInt() == 1) + } + + "evaluate if() function returning string" in { + val e = new ExpressionEvaluator("if(a = 'yes', 'approved', 'rejected')") + e.setStringValue("a", "yes") + assert(e.evalString() == "approved") + + val e2 = new ExpressionEvaluator("if(a = 'yes', 'approved', 'rejected')") + e2.setStringValue("a", "no") + assert(e2.evalString() == "rejected") + } + + "evaluate nested if() functions" in { + val e = new ExpressionEvaluator("if(a > 10, 100, if(a > 5, 50, 10))") + e.setValue("a", 15) + assert(e.evalInt() == 100) + + val e2 = new ExpressionEvaluator("if(a > 10, 100, if(a > 5, 50, 10))") + e2.setValue("a", 7) + assert(e2.evalInt() == 50) + + val e3 = new ExpressionEvaluator("if(a > 10, 100, if(a > 5, 50, 10))") + e3.setValue("a", 3) + assert(e3.evalInt() == 10) + } + + "evaluate if() with complex condition" in { + val e = new ExpressionEvaluator("if(a > 5 && b < 10, 1, 0)") + e.setValue("a", 7) + e.setValue("b", 8) + assert(e.evalInt() == 1) + + val e2 = new ExpressionEvaluator("if(a > 5 && b < 10, 1, 0)") + e2.setValue("a", 3) + e2.setValue("b", 8) + assert(e2.evalInt() == 0) + } + + "throw error for if() without exactly 3 arguments" in { + val ex = intercept[ExprSyntaxError] { + val e = new ExpressionEvaluator("if(a > 2, 5)") + e.setValue("a", 3) + e.evalInt() + } + assert(ex.getMessage.contains("exactly 3 arguments")) + } + + // Combined tests + "combine in() and if() functions" in { + val e = new ExpressionEvaluator("if(in(status, 'A', 'B'), 100, 0)") + e.setStringValue("status", "A") + assert(e.evalInt() == 100) + + val e2 = new ExpressionEvaluator("if(in(status, 'A', 'B'), 100, 0)") + e2.setStringValue("status", "C") + assert(e2.evalInt() == 0) + } } } diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala index 02a347e71..4bbce8c73 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala @@ -132,5 +132,55 @@ class Test42RedefineRulesSpec extends AnyWordSpec with SparkTestBase with Binary } } } + + "extract data according to the rules with allowed null values" when { + val data = Array( + 0xF1, 0xF1, 0xF1, + 0xF2, 0xF2, 0xF2, + 0x00, 0xF3, 0xF3 + ).map(_.toByte) + + withTempBinFile("redefine_rules3", ".dat", data) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("record_format", "F") + .option("redefine-rule:1", "G1 => in(ID, 1, null)") + .option("redefine-rule:2", "G2 => ID = 2 || ID = null") + .option("redefine-rule:3", "G3 => ID = 3 || ID = null") + .option("pedantic", "true") + .load(tmpFileName) + + val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]")) + + "data should match" in { + val expectedData = + """[ { + | "ID" : 1, + | "G1" : { + | "F1" : 11 + | } + |}, { + | "ID" : 2, + | "G2" : { + | "F2" : "22" + | } + |}, { + | "G1" : { + | "F1" : 33 + | }, + | "G2" : { + | "F2" : "33" + | }, + | "G3" : { + | "F3" : 3 + | } + |} ]""".stripMargin + + compareTextVertical(actualData, expectedData) + } + } + } } } From 4d7a01fc0c058abf9e05d897230471554bcabd77 Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Wed, 27 May 2026 08:52:38 +0200 Subject: [PATCH 5/7] #850 Add support for true/false boolean literals in expression evaluator. --- .../cobol/parser/expression/lexer/Lexer.scala | 9 +- .../cobol/parser/expression/lexer/Token.scala | 10 ++ .../expression/parser/ExpressionBuilder.scala | 2 + .../parser/ExpressionBuilderImpl.scala | 10 ++ .../parser/ExtractVariablesBuilder.scala | 4 + .../parser/expression/parser/Parser.scala | 14 +++ .../expression/ExpressionEvaluatorSuite.scala | 94 +++++++++++++++++++ 7 files changed, 139 insertions(+), 4 deletions(-) diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala index 56b475d97..7fc5224ec 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Lexer.scala @@ -152,10 +152,11 @@ class Lexer(expression: String) { pos2 += 1 } val name = expression.substring(pos, pos2) - val token = if (name.toLowerCase == "null") { - NULL_LITERAL(pos) - } else { - NAME(pos, name) + val token = name.toLowerCase match { + case "null" => NULL_LITERAL(pos) + case "true" => TRUE_LITERAL(pos) + case "false" => FALSE_LITERAL(pos) + case _ => NAME(pos, name) } tokens += token pos = pos2 diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala index da62eaed1..478cbd80b 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/lexer/Token.scala @@ -109,4 +109,14 @@ object Token { { override def toString: String = "null" } + + case class TRUE_LITERAL(pos: Int) extends Token + { + override def toString: String = "true" + } + + case class FALSE_LITERAL(pos: Int) extends Token + { + override def toString: String = "false" + } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala index 1a56d80c2..290bd22c1 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilder.scala @@ -39,4 +39,6 @@ trait ExpressionBuilder { def addNumLiteral(num: Int, pos: Int): Unit def addStringLiteral(s: String, pos: Int): Unit def addNullLiteral(pos: Int): Unit + def addTrueLiteral(pos: Int): Unit + def addFalseLiteral(pos: Int): Unit } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala index cd7016e5c..fa13bc8ab 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExpressionBuilderImpl.scala @@ -203,6 +203,16 @@ class ExpressionBuilderImpl(vars: Map[String, Int], stringVars: Map[String, Stri valueTypes += NullType } + override def addTrueLiteral(pos: Int): Unit = { + valuesBool += true + valueTypes += BoolType + } + + override def addFalseLiteral(pos: Int): Unit = { + valuesBool += false + valueTypes += BoolType + } + def getIntResult: Int = { while (ops.nonEmpty) { eval() diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala index d8be51766..c7ff16937 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/ExtractVariablesBuilder.scala @@ -65,6 +65,10 @@ class ExtractVariablesBuilder(expr: String) extends ExpressionBuilder { override def addNullLiteral(pos: Int): Unit = { } + override def addTrueLiteral(pos: Int): Unit = { } + + override def addFalseLiteral(pos: Int): Unit = { } + def getResult: Seq[String] = { variables.distinct.sorted.toSeq } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala index 1678e804f..888bacbee 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala @@ -71,6 +71,12 @@ object Parser { case NULL_LITERAL(pos) => builder.addNullLiteral(pos) state = STATE1 + case TRUE_LITERAL(pos) => + builder.addTrueLiteral(pos) + state = STATE1 + case FALSE_LITERAL(pos) => + builder.addFalseLiteral(pos) + state = STATE1 case _ => new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } } else if (state == STATE1) { @@ -177,6 +183,14 @@ object Parser { builder.addOperationNot(pos) builder.addNullLiteral(pos) state = STATE1 + case TRUE_LITERAL(pos) => + builder.addOperationNot(pos) + builder.addTrueLiteral(pos) + state = STATE1 + case FALSE_LITERAL(pos) => + builder.addOperationNot(pos) + builder.addFalseLiteral(pos) + state = STATE1 case _ => throw new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } } diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala index cd7434b06..6645aa559 100644 --- a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala +++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/expression/ExpressionEvaluatorSuite.scala @@ -481,5 +481,99 @@ class ExpressionEvaluatorSuite extends AnyWordSpec { e2.setStringValue("status", "C") assert(e2.evalInt() == 0) } + + // Boolean literal tests + "evaluate true and false literals" in { + assert(new ExpressionEvaluator("true").evalBool()) + assert(!new ExpressionEvaluator("false").evalBool()) + assert(new ExpressionEvaluator("TRUE").evalBool()) + assert(!new ExpressionEvaluator("FALSE").evalBool()) + } + + "use boolean literals in boolean operations" in { + assert(new ExpressionEvaluator("true && true").evalBool()) + assert(!new ExpressionEvaluator("true && false").evalBool()) + assert(new ExpressionEvaluator("true || false").evalBool()) + assert(!new ExpressionEvaluator("false || false").evalBool()) + } + + "use NOT with boolean literals" in { + assert(!new ExpressionEvaluator("!true").evalBool()) + assert(new ExpressionEvaluator("!false").evalBool()) + assert(new ExpressionEvaluator("!!true").evalBool()) + } + + "use boolean literals in if() function" in { + val e = new ExpressionEvaluator("if(a > 10, true, false)") + e.setValue("a", 15) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("if(a > 10, true, false)") + e2.setValue("a", 5) + assert(!e2.evalBool()) + } + + "use boolean literals with mixed expressions in if()" in { + val e = new ExpressionEvaluator("if(a > 10, true, b < 100)") + e.setValue("a", 5) + e.setValue("b", 50) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("if(a > 10, true, b < 100)") + e2.setValue("a", 5) + e2.setValue("b", 200) + assert(!e2.evalBool()) + + val e3 = new ExpressionEvaluator("if(a > 10, true, b < 100)") + e3.setValue("a", 15) + e3.setValue("b", 200) + assert(e3.evalBool()) + } + + "combine boolean literals with comparisons" in { + val e = new ExpressionEvaluator("(a = 5 && true) || false") + e.setValue("a", 5) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("(a = 5 && false) || false") + e2.setValue("a", 5) + assert(!e2.evalBool()) + } + + "return boolean literal from if() based on condition" in { + val e = new ExpressionEvaluator("if(status = 'OK', true, false)") + e.setStringValue("status", "OK") + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("if(status = 'OK', true, false)") + e2.setStringValue("status", "ERROR") + assert(!e2.evalBool()) + } + + "use boolean literals in nested if() functions" in { + val e = new ExpressionEvaluator("if(a > 10, true, if(a > 5, false, true))") + e.setValue("a", 15) + assert(e.evalBool()) + + val e2 = new ExpressionEvaluator("if(a > 10, true, if(a > 5, false, true))") + e2.setValue("a", 7) + assert(!e2.evalBool()) + + val e3 = new ExpressionEvaluator("if(a > 10, true, if(a > 5, false, true))") + e3.setValue("a", 3) + assert(e3.evalBool()) + } + + "evaluate expressions with boolean logical operations" in { + val e = new ExpressionEvaluator("(a > 5) && (b < 10)") + e.setValue("a", 10) + e.setValue("b", 5) + assert(e.evalBool()) // both true, so true = true + + val e2 = new ExpressionEvaluator("(a > 5) || (b < 10)") + e2.setValue("a", 10) + e2.setValue("b", 15) + assert(e2.evalBool()) // true || false => true + } } } From 41b959038460dd8b8807e747208b3cffccc791f7 Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Thu, 28 May 2026 09:00:10 +0200 Subject: [PATCH 6/7] #850 Add schema validation for rule expressions and document redefine rules feature. --- README.md | 73 ++++++ .../asttransform/RuleExpressionSetter.scala | 52 +++- .../RuleExpressionParsingException.scala | 30 +++ .../RuleExpressionSetterSuite.scala | 234 ++++++++++++++++++ .../integration/Test42RedefineRulesSpec.scala | 57 +++++ 5 files changed, 443 insertions(+), 3 deletions(-) create mode 100644 cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/exceptions/RuleExpressionParsingException.scala create mode 100644 cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetterSuite.scala diff --git a/README.md b/README.md index 0f7f519ce..49c15b0ac 100644 --- a/README.md +++ b/README.md @@ -1011,6 +1011,79 @@ df.show(10) In the above example invalid fields became `null` and the parsing is done faster because Cobrix does not need to process every redefine for each record. +## Automatic filtering of arbitrary redefines + +Arbitrary redefines can be resolved using rule expressions. This doesn't have to be segment redefines, just any redefines. +For example, for a copybook that looks like this: +```cobol + 01 COMPANY-DETAILS. +************** RECORD-TYPE CAN BE 'C' for company, and 'P' or 'E' for person. + 05 RECORD-TYPE PIC X(1). + 05 COMPANY-ID PIC X(10). + 05 COMPANY. + 10 COMPANY-NAME PIC X(15). + 10 ADDRESS PIC X(25). + 05 PERSON REDEFINES COMPANY. + 10 FIRST-NAME PIC X(30). + 10 LAST-NAME PIC X(30). +``` + +The syntax is as follows: + +``` + .option("redefine_rule:1", "COMPANY => RECORD_TYPE = 'C'") + .option("redefine_rule:2", "PERSON => in(RECORD_TYPE, 'P', 'E')") +``` + +For the above example the load options will lok like this (last 2 options): +```scala +val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("record_format", "V") + .option("redefine_rule:1", "COMPANY => RECORD_TYPE = 'C'") + .option("redefine_rule:2", "PERSON => in(RECORD_TYPE, 'P', 'E')") + .load("examples/multisegment_data/COMP.DETAILS.SEP30.DATA.dat") +``` + +The filtered data will look like this: +``` +df.show(10) ++-----------+----------+--------------------+--------------------+ +|RECORD_TYPE|COMPANY_ID| COMPANY| PERSON| ++-----------+----------+--------------------+--------------------+ +| C|9377942526|[Joan Q & Z,10 Sa...| | +| P|9377942526| | [John, Smith]| +| C|3483483977|[Robotrd Inc.,2 P...| | +| E|3483483977| | [Jane, Wanson]| +| E|3483483977| | [Alex,Johnson]| ++-----------+----------+--------------------+--------------------+ +``` + +#### Notes +- Variable names in rule expressions are case-sensitive. +- Variable names are required to be used after column sanitization (e.g. replacement of special characters with underscores), + otherwise expression `F-A = 1` is ambiguous since it is not clear if `F-A` is a variable name or an expression of subtraction. + In this case the variable name should be `F_A` and the expression should be `F_A = 1`. +- You can only reference variables that go _before_ the redefine field. This is because record decoding is forward only. +- Use only field names themselves, not full paths, e.g. `COMPANY` instead of `RECODD.DETAILS.COMPANY` . +- Only integral numeric literals are supported. Decimals are not supported. +- The expression should return a boolean. For example: + - `RECORD_TYPE = 'C'` is valid since it returns true for company records and false for person records. + - `in(RECORD_TYPE, 'P', 'E')` is valid since it returns true for person records and false for company records. + - `COMPANY_ID > 1000` is valid since it returns true for records with company id greater than 1000 and false otherwise. + +### Expressions supported +- Comparison operators: `=`, `!=`, `>`, `<`, `>=`, `<=`. +- Boolean logic: `&&` (and), `||` (or), `!` (not). +- Integral literals: `123`, `0`, `-456` are valid, but `123.45` or `-123.45` are not valid. +- String literals: `'abc'`, `'abc'`, `'123'` are valid. Always use single quote character. +- Boolean literals: `true`, `false`. +- Null literal: `null`. For example: `RECORD_TYPE = 'C' || RECORD_TYPE = null` is valid. +- Functions: + - `in()` (one of from the list), for example: `in(RECORD_TYPE, 'P', 'E')` is valid since it returns true for person records and false for company records. + - `if()` (conditional function with 3 arguments), for example: `if(RECORD_TYPE = 'C', true, false)` is valid. ## Group Filler dropping diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala index a477b337b..8c74f98db 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetter.scala @@ -18,12 +18,11 @@ package za.co.absa.cobrix.cobol.parser.asttransform import org.slf4j.LoggerFactory import za.co.absa.cobrix.cobol.parser.CopybookParser.CopybookAST -import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral} import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive, Statement} +import za.co.absa.cobrix.cobol.parser.exceptions.RuleExpressionParsingException import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import scala.collection.mutable -import scala.collection.mutable.ListBuffer class RuleExpressionSetter( redefineRuleExpressions: Map[String, ExpressionEvaluator] @@ -37,7 +36,31 @@ class RuleExpressionSetter( * @return The same AST with binary properties set for every field */ final override def transform(ast: CopybookAST): CopybookAST = { - val ruleDrivenRedefines = redefineRuleExpressions.keys.toSet + // Collect all field names from the schema + val allFieldNames = collectAllFieldNames(ast) + + // Validate that all target fields exist in the schema + val invalidTargets = redefineRuleExpressions.keys.toSet.diff(allFieldNames).toSeq.sorted + if (invalidTargets.nonEmpty) { + throw new RuleExpressionParsingException( + None, + msg = s"Target field(s) not found in schema: ${invalidTargets.mkString(", ")}" + ) + } + + // Validate that all variables used in rule expressions exist in the schema + val allVariables = redefineRuleExpressions.values.flatMap(_.getVariables).toSet + val invalidVariables = allVariables.diff(allFieldNames).toSeq.sorted + if (invalidVariables.nonEmpty) { + val field = redefineRuleExpressions.find { case (_, expr) => + expr.getVariables.exists(invalidVariables.contains) + }.map(_._1) + throw new RuleExpressionParsingException( + fieldOpt = field, + msg = s"Rule expression variable(s) not found in schema: ${invalidVariables.mkString(", ")}" + ) + } + val dependeeFields = redefineRuleExpressions.values.flatMap { expr => expr.getVariables }.toSet @@ -76,6 +99,29 @@ class RuleExpressionSetter( markRuleFields(ast) } + + /** + * Collects all field names from the AST recursively + * + * @param group The AST group to traverse + * @return Set of all field names in the schema + */ + private def collectAllFieldNames(group: CopybookAST): Set[String] = { + val names = mutable.Set[String]() + + def collectNames(statement: Statement): Unit = { + statement match { + case grp: Group => + names += grp.name + grp.children.foreach(collectNames) + case primitive: Primitive => + names += primitive.name + } + } + + group.children.foreach(collectNames) + names.toSet + } } object RuleExpressionSetter { diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/exceptions/RuleExpressionParsingException.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/exceptions/RuleExpressionParsingException.scala new file mode 100644 index 000000000..bf5f21361 --- /dev/null +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/exceptions/RuleExpressionParsingException.scala @@ -0,0 +1,30 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.cobol.parser.exceptions + +class RuleExpressionParsingException(val fieldOpt: Option[String], val msg: String) + extends Exception(RuleExpressionParsingException.constructErrorMessage(fieldOpt, msg)) { +} + +object RuleExpressionParsingException { + private def constructErrorMessage(fieldOpt: Option[String], msg: String): String = { + fieldOpt match { + case Some(field) => s"Error in the rule expression for '$field': $msg" + case None => s"Error in the rule expression definition(s): $msg" + } + } +} diff --git a/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetterSuite.scala b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetterSuite.scala new file mode 100644 index 000000000..6b4d77d9f --- /dev/null +++ b/cobol-parser/src/test/scala/za/co/absa/cobrix/cobol/parser/asttransform/RuleExpressionSetterSuite.scala @@ -0,0 +1,234 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.cobol.parser.asttransform + +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.cobrix.cobol.parser.CopybookParser +import za.co.absa.cobrix.cobol.parser.exceptions.{RuleExpressionParsingException, SyntaxErrorException} +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator + +class RuleExpressionSetterSuite extends AnyWordSpec { + "RuleExpressionSetter" should { + "accept valid rule expressions with existing fields" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC 9(5). + | 05 FIELD-C PIC X(10). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "FIELD_C" -> new ExpressionEvaluator("FIELD_A > 100") + ) + + val setter = RuleExpressionSetter(rules) + // Should not throw exception + val result = setter.transform(schema.ast) + assert(result != null) + } + + "accept rule expressions with multiple variables" in { + val copybook = + """ 01 RECORD. + | 05 STATUS PIC X(2). + | 05 AMOUNT PIC 9(5). + | 05 PRIORITY PIC 9(1). + | 05 RESULT PIC X(10). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "RESULT" -> new ExpressionEvaluator("STATUS = 'OK' && AMOUNT > 100") + ) + + val setter = RuleExpressionSetter(rules) + // Should not throw exception + val result = setter.transform(schema.ast) + assert(result != null) + } + + "throw exception when target field does not exist" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC 9(5). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "FIELD-C" -> new ExpressionEvaluator("FIELD-A > 100") + ) + + val setter = RuleExpressionSetter(rules) + val ex = intercept[RuleExpressionParsingException] { + setter.transform(schema.ast) + } + assert(ex.getMessage.contains("Target field(s) not found in schema")) + assert(ex.getMessage.contains("FIELD-C")) + } + + "throw exception when variable in expression does not exist" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC 9(5). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "FIELD_B" -> new ExpressionEvaluator("FIELD_X > 100") + ) + + val setter = RuleExpressionSetter(rules) + val ex = intercept[RuleExpressionParsingException] { + setter.transform(schema.ast) + } + assert(ex.getMessage.contains("variable(s) not found in schema")) + assert(ex.getMessage.contains("FIELD_X")) + } + + "throw exception for multiple invalid variables" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC 9(5). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "FIELD_B" -> new ExpressionEvaluator("FIELD_X > 100 && FIELD_Y < 50") + ) + + val setter = RuleExpressionSetter(rules) + val ex = intercept[RuleExpressionParsingException] { + setter.transform(schema.ast) + } + assert(ex.getMessage.contains("variable(s) not found in schema")) + assert(ex.getMessage.contains("FIELD_X")) + assert(ex.getMessage.contains("FIELD_Y")) + } + + "throw exception for multiple invalid target fields" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC 9(5). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "FIELD-X" -> new ExpressionEvaluator("FIELD-A > 100"), + "FIELD-Y" -> new ExpressionEvaluator("FIELD-B < 50") + ) + + val setter = RuleExpressionSetter(rules) + val ex = intercept[RuleExpressionParsingException] { + setter.transform(schema.ast) + } + assert(ex.getMessage.contains("Target field(s) not found in schema")) + // Should contain both invalid targets + assert(ex.getMessage.contains("FIELD-X")) + assert(ex.getMessage.contains("FIELD-Y")) + } + + "accept rules for nested fields" in { + val copybook = + """ 01 RECORD. + | 05 HEADER. + | 10 STATUS PIC X(2). + | 10 CODE PIC 9(3). + | 05 BODY. + | 10 AMOUNT PIC 9(5). + | 10 RESULT PIC X(10). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "RESULT" -> new ExpressionEvaluator("STATUS = 'OK' && AMOUNT > 100") + ) + + val setter = RuleExpressionSetter(rules) + // Should not throw exception + val result = setter.transform(schema.ast) + assert(result != null) + } + + "accept rules with in() and if() functions" in { + val copybook = + """ 01 RECORD. + | 05 STATUS PIC X(2). + | 05 PRIORITY PIC 9(1). + | 05 RESULT PIC X(10). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "RESULT" -> new ExpressionEvaluator("if(in(STATUS, 'OK', 'AC'), PRIORITY > 5, false)") + ) + + val setter = RuleExpressionSetter(rules) + // Should not throw exception + val result = setter.transform(schema.ast) + assert(result != null) + } + + "accept empty rule map" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC 9(5). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map.empty[String, ExpressionEvaluator] + + val setter = RuleExpressionSetter(rules) + // Should not throw exception + val result = setter.transform(schema.ast) + assert(result != null) + } + + "accept rules with literals that don't reference variables" in { + val copybook = + """ 01 RECORD. + | 05 FIELD-A PIC 9(5). + | 05 FIELD-B PIC X(10). + |""".stripMargin + + val schema = CopybookParser.parseTree(copybook) + + val rules = Map( + "FIELD_B" -> new ExpressionEvaluator("100 > 50") + ) + + val setter = RuleExpressionSetter(rules) + // Should not throw exception (no variables, just literals) + val result = setter.transform(schema.ast) + assert(result != null) + } + } +} diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala index 4bbce8c73..b192415f2 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala @@ -182,5 +182,62 @@ class Test42RedefineRulesSpec extends AnyWordSpec with SparkTestBase with Binary } } } + "extract data according to the rules with string values" when { + val copybook = + """ 01 R. + 03 ID PIC X(1). + 03 G1. + 04 F1 PIC S9(2). + 03 G2 REDEFINES G1. + 04 F2 PIC X(2). + 03 G3 REDEFINES G1. + 04 F3 PIC 9(1). + """ + + val data = Array( + 0xC1, 0xF1, 0xF1, + 0xC2, 0xF2, 0xF2, + 0x00, 0xF3, 0xF3 + ).map(_.toByte) + + withTempBinFile("redefine_rules3", ".dat", data) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("record_format", "F") + .option("redefine-rule:1", "G1 => in(ID, 'A', null)") + .option("redefine-rule:2", "G2 => ID='B' && true") + .option("redefine-rule:3", "G3 => ID = 'C' || ID = null") + .option("pedantic", "true") + .load(tmpFileName) + + val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]")) + + "data should match" in { + val expectedData = + """[ { + | "ID" : "A", + | "G1" : { + | "F1" : 11 + | } + |}, { + | "ID" : "B", + | "G2" : { + | "F2" : "22" + | } + |}, { + | "G1" : { + | "F1" : 33 + | }, + | "G3" : { + | "F3" : 3 + | } + |} ]""".stripMargin + + compareTextVertical(actualData, expectedData) + } + } + } } } From 05ad32339aed84edae0825b50cf12d0b6f00ddb8 Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Thu, 28 May 2026 11:04:35 +0200 Subject: [PATCH 7/7] #850 Fix expression evaluator bugs, add validation for redefine rules, and support decimal values in rule expressions (Thanks @coderabbitai). --- README.md | 8 +-- .../expression/ExpressionEvaluator.scala | 20 ++++-- .../parser/expression/parser/Parser.scala | 8 +-- .../extractors/record/RecordExtractors.scala | 49 +++++++++----- .../parameters/CobolParametersParser.scala | 20 ++++-- .../cobol/source/ParametersParsingSpec.scala | 40 +++++++++++- .../integration/Test42RedefineRulesSpec.scala | 65 +++++++++++++++++++ 7 files changed, 177 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 49c15b0ac..d79450f37 100644 --- a/README.md +++ b/README.md @@ -974,7 +974,7 @@ segment id. This way Cobrix will parse only relevant segment redefined fields an .option("redefine-segment-id-map:1", "REDEFINED_FIELD2 => SegmentId10,SegmentId11,...") ``` -For the above example the load options will lok like this (last 2 options): +For the above example the load options will look like this (last 2 options): ```scala val df = spark .read @@ -1030,12 +1030,12 @@ For example, for a copybook that looks like this: The syntax is as follows: -``` +```scala .option("redefine_rule:1", "COMPANY => RECORD_TYPE = 'C'") .option("redefine_rule:2", "PERSON => in(RECORD_TYPE, 'P', 'E')") ``` -For the above example the load options will lok like this (last 2 options): +For the above example the load options will look like this (last 2 options): ```scala val df = spark .read @@ -1067,7 +1067,7 @@ df.show(10) otherwise expression `F-A = 1` is ambiguous since it is not clear if `F-A` is a variable name or an expression of subtraction. In this case the variable name should be `F_A` and the expression should be `F_A = 1`. - You can only reference variables that go _before_ the redefine field. This is because record decoding is forward only. -- Use only field names themselves, not full paths, e.g. `COMPANY` instead of `RECODD.DETAILS.COMPANY` . +- Use only field names themselves, not full paths, e.g. `COMPANY` instead of `RECORD.DETAILS.COMPANY` . - Only integral numeric literals are supported. Decimals are not supported. - The expression should return a boolean. For example: - `RECORD_TYPE = 'C'` is valid since it returns true for company records and false for person records. diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala index 4621b6bb1..4f8a07fa6 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/ExpressionEvaluator.scala @@ -17,7 +17,7 @@ package za.co.absa.cobrix.cobol.parser.expression import za.co.absa.cobrix.cobol.parser.expression.lexer.Lexer -import za.co.absa.cobrix.cobol.parser.expression.parser.{ExtractVariablesBuilder, ExpressionBuilderImpl, Parser} +import za.co.absa.cobrix.cobol.parser.expression.parser.{ExpressionBuilderImpl, ExtractVariablesBuilder, Parser} import scala.collection.mutable @@ -82,20 +82,32 @@ class ExpressionEvaluator(val expr: String) extends Serializable { val exprBuilder = new ExpressionBuilderImpl(vars.toMap, stringVars.toMap, nullVars.toSet, expr) Parser.parse(tokens, exprBuilder) - exprBuilder.getIntResult + val i = exprBuilder.getIntResult + clearValues() + i } def evalBool(): Boolean = { val exprBuilder = new ExpressionBuilderImpl(vars.toMap, stringVars.toMap, nullVars.toSet, expr) Parser.parse(tokens, exprBuilder) - exprBuilder.getBoolResult + val b = exprBuilder.getBoolResult + clearValues() + b } def evalString(): String = { val exprBuilder = new ExpressionBuilderImpl(vars.toMap, stringVars.toMap, nullVars.toSet, expr) Parser.parse(tokens, exprBuilder) - exprBuilder.getStringResult + val s = exprBuilder.getStringResult + clearValues() + s + } + + private def clearValues(): Unit = { + vars.clear() + stringVars.clear() + nullVars.clear() } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala index 888bacbee..d3c9fb798 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/expression/parser/Parser.scala @@ -77,7 +77,7 @@ object Parser { case FALSE_LITERAL(pos) => builder.addFalseLiteral(pos) state = STATE1 - case _ => new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") + case _ => throw new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } } else if (state == STATE1) { token match { @@ -90,7 +90,7 @@ object Parser { state = STATE0 case CLOSE_PARAN(pos) => if (paranPos.isEmpty) { - throw new ExprSyntaxError(s"Unmatched ')' at pos $pos") + throw throw new ExprSyntaxError(s"Unmatched ')' at pos $pos") } paranPos.remove(paranPos.size - 1) builder.closeParen(pos) @@ -135,7 +135,7 @@ object Parser { builder.addFunction(s, pos) case NUM_LITERAL(pos, s) => builder.addNumLiteral(s.toInt, pos) - case _ => new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") + case _ => throw new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } } else if (state == MINUS_NUM) { token match { @@ -151,7 +151,7 @@ object Parser { case NUM_LITERAL(pos, s) => builder.addNumLiteral(-s.toInt, pos) state = STATE1 - case _ => new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") + case _ => throw new ExprSyntaxError(s"Unexpected '$token' at pos ${token.pos}") } } else if (state == NOT_OP) { token match { diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala index de18cb653..1257c7def 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/extractors/record/RecordExtractors.scala @@ -21,6 +21,7 @@ import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP4} import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive, Statement} import za.co.absa.cobrix.cobol.parser.common.Constants import za.co.absa.cobrix.cobol.parser.encoding.RAW +import za.co.absa.cobrix.cobol.parser.expression.ExpressionEvaluator import za.co.absa.cobrix.cobol.parser.policies.VariableSizeOccursPolicy import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy import za.co.absa.cobrix.cobol.reader.policies.SchemaRetentionPolicy.SchemaRetentionPolicy @@ -29,6 +30,7 @@ import za.co.absa.cobrix.cobol.utils.StringUtils import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.reflect.ClassTag +import scala.util.control.NonFatal object RecordExtractors { @@ -599,24 +601,41 @@ object RecordExtractors { def canExtract(field: Statement, variables: mutable.Map[String, Any]): Boolean = { field.ruleExpression match { case Some(expr) => - variables.foreach { - case (k, v) => - if (v == null) - expr.setNullValue(k) - else { - v match { - case s: String => expr.setStringValue(k, s) - case _ => expr.setValue(k, v.toString.toInt) - } - } - } - if (expr.evalBool()) { - true - } else { - false + expr.getVariables.foreach { k => + variables.get(k) match { + case Some(v: Any) => safeSetValue(expr, k, v) + case None | Some(null) => expr.setNullValue(k) + } } + expr.evalBool() case None => true } } + + private def safeSetValue(expr: ExpressionEvaluator, varName: String, value: Any): Unit = { + try { + if (value == null) { + expr.setNullValue(varName) + } else { + value match { + case s: String => expr.setStringValue(varName, s) + case i: Int => expr.setValue(varName, i) + case i: java.lang.Integer => expr.setValue(varName, i) + case l: Long => expr.setValue(varName, l.toInt) + case l: java.lang.Long => expr.setValue(varName, l.toInt) + case d: java.math.BigDecimal => expr.setValue(varName, d.intValue()) + case d: scala.math.BigDecimal => expr.setValue(varName, d.intValue) + case f: Float => expr.setValue(varName, f.toInt) + case f: Double => expr.setValue(varName, f.toInt) + case _ => expr.setValue(varName, value.toString.toInt) + } + } + } catch { + case NonFatal(_) => + // No logging here because this runs inside RDD and on the critical path. Replacing + // unparsable values with null is a standard Spark behavior in this case. + expr.setNullValue(varName) + } + } } diff --git a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala index 74a9bb1b0..1b0fbced2 100644 --- a/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala +++ b/cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/reader/parameters/CobolParametersParser.scala @@ -782,7 +782,8 @@ object CobolParametersParser extends Logging { */ @throws(classOf[IllegalArgumentException]) def getRedefineRuleExpressionMapping(params: Parameters): Map[String, String] = { - params.getMap.flatMap { + val redefineRules = new mutable.HashMap[String, String] + params.getMap.foreach { case (k, v) => val keyNoCase = k.toLowerCase if (keyNoCase.startsWith(PARAM_REDEFINE_RULE_PREFIX) || @@ -794,11 +795,22 @@ object CobolParametersParser extends Logging { } val redefine = splitVal(0).trim val rule = splitVal(1).trim - Option((CopybookParser.transformIdentifier(redefine), rule)) - } else { - None + if (redefine.isEmpty || rule.isEmpty) { + throw new IllegalArgumentException( + s"Illegal argument for the '$PARAM_REDEFINE_RULE_PREFIX' option: '$v'. " + + s"Both redefine field and rule expression must be non-empty." + ) + } + val key = CopybookParser.transformIdentifier(redefine) + if (redefineRules.contains(key)) { + throw new IllegalArgumentException( + s"Duplicate redefine rule for field '$key' is not allowed." + ) + } + redefineRules.put(key, rule) } } + redefineRules.toMap } /** diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala index 0a9dc06cc..7a41154af 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/ParametersParsingSpec.scala @@ -32,7 +32,7 @@ class ParametersParsingSpec extends AnyFunSuite { assert(segmentIdMapping("C") == "COMPANY") assert(segmentIdMapping("D") == "COMPANY") assert(segmentIdMapping("P") == "CONTACT") - assert(segmentIdMapping.get("Q").isEmpty) + assert(!segmentIdMapping.contains("Q")) } test("Test redefine rule expression mapping") { @@ -46,6 +46,42 @@ class ParametersParsingSpec extends AnyFunSuite { assert(ruleExpressions("CONTACT") == "RECORD_TYPE = 2") } + test("Test redefine rule expression must not duplicate") { + val config = HashMap[String,String] ( + "redefine-rule:1" -> "COMPANY => RECORD_TYPE = 1", + "redefine_rule:2" -> "COMPANY => RECORD_TYPE = 2") + + val ex = intercept[IllegalArgumentException] { + CobolParametersParser.getRedefineRuleExpressionMapping(new Parameters(config)) + } + + assert(ex.getMessage.contains("Duplicate redefine rule for field 'COMPANY' is not allowed.")) + } + + test("Test redefine rule target fields must not be empty") { + val config = HashMap[String,String] ( + "redefine-rule:1" -> " => RECORD_TYPE = 1", + "redefine_rule:2" -> "CONTACT => RECORD_TYPE = 2") + + val ex = intercept[IllegalArgumentException] { + CobolParametersParser.getRedefineRuleExpressionMapping(new Parameters(config)) + } + + assert(ex.getMessage.contains("Illegal argument for the 'redefine_rule' option: ' => RECORD_TYPE = 1'")) + } + + test("Test redefine rule expressions must not me empty") { + val config = HashMap[String,String] ( + "redefine-rule:1" -> "COMPANY => RECORD_TYPE = 1", + "redefine_rule:2" -> "CONTACT => ") + + val ex = intercept[IllegalArgumentException] { + CobolParametersParser.getRedefineRuleExpressionMapping(new Parameters(config)) + } + + assert(ex.getMessage.contains("Illegal argument for the 'redefine_rule' option: 'CONTACT => '")) + } + test("Test field - parent field mapping") { val config = HashMap[String,String] ("is_record_sequence"-> "true", "segment-children:1" -> "COMPANY => DEPT,CUSTOMER", @@ -60,7 +96,7 @@ class ParametersParsingSpec extends AnyFunSuite { assert(fieldParents("OFFICE") == "DEPT") assert(fieldParents("CONTACT") == "CUSTOMER") assert(fieldParents("CONTRACT") == "CUSTOMER") - assert(fieldParents.get("COMPANY").isEmpty) + assert(!fieldParents.contains("COMPANY")) } test("Test field - parent field mapping (split)") { diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala index b192415f2..6a6ace0c6 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/source/integration/Test42RedefineRulesSpec.scala @@ -182,6 +182,7 @@ class Test42RedefineRulesSpec extends AnyWordSpec with SparkTestBase with Binary } } } + "extract data according to the rules with string values" when { val copybook = """ 01 R. @@ -239,5 +240,69 @@ class Test42RedefineRulesSpec extends AnyWordSpec with SparkTestBase with Binary } } } + + "extract data according to the rules with decimal values" when { + val copybook = + """ 01 R. + 03 ID PIC 9V9. + 03 G1. + 04 F1 PIC S9(2). + 03 G2 REDEFINES G1. + 04 F2 PIC X(2). + 03 G3 REDEFINES G1. + 04 F3 PIC 9(1). + """ + + val data = Array( + 0xF1, 0xF1, 0xF1, 0xF1, + 0xF2, 0xF2, 0xF2, 0xF2, + 0xF3, 0xF0, 0xF3, 0xF3, + 0x00, 0x00, 0xF4, 0xF4 + ).map(_.toByte) + + withTempBinFile("redefine_rules3", ".dat", data) { tmpFileName => + val df = spark + .read + .format("cobol") + .option("copybook_contents", copybook) + .option("record_format", "F") + .option("redefine-rule:1", "G1 => in(ID, 1, null)") + .option("redefine-rule:2", "G2 => ID=2 && true") + .option("redefine-rule:3", "G3 => ID = 3 || ID = null") + .option("pedantic", "true") + .load(tmpFileName) + + val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]")) + + "data should match" in { + val expectedData = + """[ { + | "ID" : 1.1, + | "G1" : { + | "F1" : 11 + | } + |}, { + | "ID" : 2.2, + | "G2" : { + | "F2" : "22" + | } + |}, { + | "ID" : 3.0, + | "G3" : { + | "F3" : 3 + | } + |}, { + | "G1" : { + | "F1" : 44 + | }, + | "G3" : { + | "F3" : 4 + | } + |} ]""".stripMargin + + compareTextVertical(actualData, expectedData) + } + } + } } }