diff --git a/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala b/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala index 07155f8d..3fb7ab2b 100644 --- a/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala +++ b/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala @@ -1,5 +1,7 @@ package sjsonnet +import upickle.core.SimpleVisitor + import java.io.{ BufferedOutputStream, InputStream, @@ -14,53 +16,53 @@ import scala.annotation.unused import scala.util.Try object SjsonnetMainBase { - def resolveImport( + class SimpleImporter( searchRoots0: Seq[Path], // Evaluated in order, first occurrence wins allowedInputs: Option[Set[os.Path]] = None, - debugImporter: Boolean = false): Importer = - new Importer { - def resolve(docBase: Path, importName: String): Option[Path] = - (docBase +: searchRoots0) - .flatMap(base => - os.FilePath(importName) match { - case r: os.SubPath => Some(base.asInstanceOf[OsPath].p / r) - case r: os.RelPath => - if (r.ups > base.segmentCount()) None - else Some(base.asInstanceOf[OsPath].p / r) - case a: os.Path => Some(a) - } - ) - .filter(p => { - val allowed = allowedInputs.fold(true)(_(p)) - if (debugImporter) { - if (allowed) System.err.println(s"[import $importName] candidate $p") - else - System.err.println( - s"[import $importName] excluded $p because it's not in $allowedInputs" - ) - } - allowed - }) - .find(f => os.exists(f) && !os.isDir(f)) - .orElse({ - if (debugImporter) { - System.err.println(s"[import $importName] none of the candidates exist") - } - None - }) - .flatMap(p => { - if (debugImporter) { + debugImporter: Boolean = false) + extends Importer { + def resolve(docBase: Path, importName: String): Option[Path] = + (docBase +: searchRoots0) + .flatMap(base => + os.FilePath(importName) match { + case r: os.SubPath => Some(base.asInstanceOf[OsPath].p / r) + case r: os.RelPath => + if (r.ups > base.segmentCount()) None + else Some(base.asInstanceOf[OsPath].p / r) + case a: os.Path => Some(a) + } + ) + .filter(p => { + val allowed = allowedInputs.fold(true)(_(p)) + if (debugImporter) { + if (allowed) System.err.println(s"[import $importName] candidate $p") + else System.err.println( - s"[import $importName] $p is selected as it exists and is not a directory" + s"[import $importName] excluded $p because it's not in $allowedInputs" ) - } - Some(OsPath(p)) - }) + } + allowed + }) + .find(f => os.exists(f) && !os.isDir(f)) + .orElse({ + if (debugImporter) { + System.err.println(s"[import $importName] none of the candidates exist") + } + None + }) + .flatMap(p => { + if (debugImporter) { + System.err.println( + s"[import $importName] $p is selected as it exists and is not a directory" + ) + } + Some(OsPath(p)) + }) - def read(path: Path, binaryData: Boolean): Option[ResolvedFile] = { - readPath(path, binaryData, debugImporter) - } + def read(path: Path, binaryData: Boolean): Option[ResolvedFile] = { + readPath(path, binaryData, debugImporter) } + } def main0( args: Array[String], @@ -70,7 +72,7 @@ object SjsonnetMainBase { stderr: PrintStream, wd: os.Path, allowedInputs: Option[Set[os.Path]] = None, - importer: Option[(Path, String) => Option[os.Path]] = None, + importer: Option[Importer] = None, std: Val.Obj = sjsonnet.stdlib.StdLibModule.Default.module): Int = { var hasWarnings = false @@ -94,7 +96,28 @@ object SjsonnetMainBase { autoPrintHelpAndExit = None ) file <- Right(config.file) - outputStr <- mainConfigured(file, config, parseCache, wd, allowedInputs, importer, warn, std) + outputStr <- mainConfigured( + file, + config, + new Settings( + preserveOrder = config.preserveOrder.value, + strict = config.strict.value, + throwErrorForInvalidSets = config.throwErrorForInvalidSets.value, + maxParserRecursionDepth = config.maxParserRecursionDepth, + brokenAssertionLogic = config.brokenAssertionLogic.value + ), + parseCache, + wd, + importer.getOrElse { + new SimpleImporter( + config.getOrderedJpaths.map(p => OsPath(os.Path(p, wd))), + allowedInputs, + debugImporter = config.debugImporter.value + ) + }, + warn, + std + ) res <- { if (hasWarnings && config.fatalWarnings.value) Left("") else Right(outputStr) @@ -112,7 +135,6 @@ object SjsonnetMainBase { case Some(f) => os.write.over(os.Path(f, wd), str) } } - 0 } } @@ -124,6 +146,14 @@ object SjsonnetMainBase { indent = config.indent, getCurrentPosition = getCurrentPosition ) + else if (config.expectString.value) + new SimpleVisitor[Writer, Writer] { + val expectedMsg = "expected string result" + override def visitString(s: CharSequence, index: Int): Writer = { + wr.write(s.toString) + wr + } + } else new Renderer(wr, indent = config.indent) private def handleWriteFile[T](f: => T): Either[String, T] = @@ -141,7 +171,6 @@ object SjsonnetMainBase { case None => val sw = new StringWriter materialize(sw).map(_ => sw.toString) - case Some(f) => handleWriteFile( os.write.over.outputStream(os.Path(f, wd), createFolders = config.createDirs.value) @@ -157,11 +186,6 @@ object SjsonnetMainBase { } } - private def expectString(v: ujson.Value) = v match { - case ujson.Str(s) => Right(s) - case _ => Left("expected string result, got: " + v.getClass) - } - private def renderNormal( config: Config, interp: Interpreter, @@ -169,36 +193,27 @@ object SjsonnetMainBase { path: os.Path, wd: os.Path, getCurrentPosition: () => Position) = { - writeToFile(config, wd)(writer => - if (config.expectString.value) { - val res = interp.interpret(jsonnetCode, OsPath(path)).flatMap(expectString) - res match { - case Right(s) => writer.write(s) - case _ => - } - res - } else { - val renderer = rendererForConfig(writer, config, getCurrentPosition) - val res = interp.interpret0(jsonnetCode, OsPath(path), renderer) - if (config.yamlOut.value) writer.write('\n') - res - } - ) + writeToFile(config, wd) { writer => + val renderer = rendererForConfig(writer, config, getCurrentPosition) + val res = interp.interpret0(jsonnetCode, OsPath(path), renderer) + if (config.yamlOut.value) writer.write('\n') + res + } } private def isScalar(v: ujson.Value) = !v.isInstanceOf[ujson.Arr] && !v.isInstanceOf[ujson.Obj] - private def parseBindings( + def parseBindings( strs: Seq[String], strFiles: Seq[String], codes: Seq[String], codeFiles: Seq[String], - wd: os.Path) = { + wd: os.Path): Map[String, String] = { def split(s: String) = s.split("=", 2) match { case Array(x) => (x, System.getenv(x)) case Array(x, v) => (x, v) - case _ => ??? + case _ => throw new IllegalArgumentException("invalid binding: " + s) } def splitMap(s: Seq[String], f: String => String) = @@ -223,15 +238,16 @@ object SjsonnetMainBase { * Right(str) if there's some string that needs to be printed to stdout or --output-file, * Left(err) if there is an error to be reported */ - private def mainConfigured( + def mainConfigured( file: String, config: Config, + settings: Settings, parseCache: ParseCache, wd: os.Path, - allowedInputs: Option[Set[os.Path]], - importer: Option[(Path, String) => Option[os.Path]], - warnLogger: (Boolean, String) => Unit, - std: Val.Obj): Either[String, String] = { + importer: Importer, + warnLogger: Evaluator.Logger, + std: Val.Obj, + evaluatorOverride: Option[Evaluator] = None): Either[String, String] = { val (jsonnetCode, path) = if (config.exec.value) (file, wd / Util.wrapInLessThanGreaterThan("exec")) @@ -265,35 +281,23 @@ object SjsonnetMainBase { queryExtVar = (key: String) => extBinding.get(key).map(ExternalVariable.code), queryTlaVar = (key: String) => tlaBinding.get(key).map(ExternalVariable.code), OsPath(wd), - importer = importer match { - case Some(i) => - new Importer { - def resolve(docBase: Path, importName: String): Option[Path] = - i(docBase, importName).map(OsPath.apply) - def read(path: Path, binaryData: Boolean): Option[ResolvedFile] = { - readPath(path, binaryData) - } - } - case None => - resolveImport( - config.getOrderedJpaths.map(os.Path(_, wd)).map(OsPath.apply), - allowedInputs, - config.debugImporter.value - ) - }, + importer = importer, parseCache, - settings = new Settings( - preserveOrder = config.preserveOrder.value, - strict = config.strict.value, - throwErrorForInvalidSets = config.throwErrorForInvalidSets.value, - maxParserRecursionDepth = config.maxParserRecursionDepth, - brokenAssertionLogic = config.brokenAssertionLogic.value - ), + settings = settings, storePos = (position: Position) => if (config.yamlDebug.value) currentPos = position else (), logger = warnLogger, std = std, variableResolver = _ => None - ) + ) { + override def createEvaluator( + resolver: CachedResolver, + extVars: String => Option[Expr], + wd: Path, + settings: Settings): Evaluator = + evaluatorOverride.getOrElse( + super.createEvaluator(resolver, extVars, wd, settings) + ) + } (config.multi, config.yamlStream.value) match { case (Some(multiPath), _) => @@ -303,14 +307,10 @@ object SjsonnetMainBase { obj.value.toSeq.map { case (f, v) => for { rendered <- { - if (config.expectString.value) { - expectString(v) - } else { - val writer = new StringWriter() - val renderer = rendererForConfig(writer, config, () => currentPos) - ujson.transform(v, renderer) - Right(writer.toString) - } + val writer = new StringWriter() + val renderer = rendererForConfig(writer, config, () => currentPos) + ujson.transform(v, renderer) + Right(writer.toString) } relPath = (os.FilePath(multiPath) / os.RelPath(f)).asInstanceOf[os.FilePath] _ <- writeFile(config, relPath.resolveFrom(wd), rendered) @@ -376,7 +376,7 @@ object SjsonnetMainBase { private def readPath( path: Path, binaryData: Boolean, - debugImporter: Boolean = false): Option[ResolvedFile] = { + debugImporter: Boolean): Option[ResolvedFile] = { val osPath = path.asInstanceOf[OsPath].p if (os.exists(osPath) && !os.isDir(osPath)) { Some( diff --git a/sjsonnet/test/src-jvm-native/sjsonnet/BaseFileTests.scala b/sjsonnet/test/src-jvm-native/sjsonnet/BaseFileTests.scala index 300f8c5f..d759df8d 100644 --- a/sjsonnet/test/src-jvm-native/sjsonnet/BaseFileTests.scala +++ b/sjsonnet/test/src-jvm-native/sjsonnet/BaseFileTests.scala @@ -51,7 +51,7 @@ abstract class BaseFileTests extends TestSuite { ), Map("var1" -> "\"test\"", "var2" -> """{"x": 1, "y": 2}"""), OsPath(testSuiteRoot / testSuite), - importer = sjsonnet.SjsonnetMainBase.resolveImport(Array.empty[Path].toIndexedSeq), + importer = new sjsonnet.SjsonnetMainBase.SimpleImporter(Array.empty[Path].toIndexedSeq), parseCache = new DefaultParseCache, logger = (isTrace: Boolean, msg: String) => { if (isTrace) { diff --git a/sjsonnet/test/src-jvm-native/sjsonnet/PrettyYamlRendererTests.scala b/sjsonnet/test/src-jvm-native/sjsonnet/PrettyYamlRendererTests.scala index 03ef2878..def630eb 100644 --- a/sjsonnet/test/src-jvm-native/sjsonnet/PrettyYamlRendererTests.scala +++ b/sjsonnet/test/src-jvm-native/sjsonnet/PrettyYamlRendererTests.scala @@ -11,7 +11,7 @@ object PrettyYamlRendererTests extends TestSuite { Map(), Map(), OsPath(testSuiteRoot), - importer = sjsonnet.SjsonnetMainBase.resolveImport(Array(OsPath(testSuiteRoot)).toIndexedSeq), + importer = new SjsonnetMainBase.SimpleImporter(Array(OsPath(testSuiteRoot)).toIndexedSeq), parseCache = new DefaultParseCache, storePos = if (comments) currentPos = _ else null )