diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0b1c250b9..dfc02f659 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -219,10 +219,12 @@ function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} end function setlogp!!(vi::AbstractVarInfo, logp::Number) - return error(""" - `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use - `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. - """) + return error( + """ + `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use + `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. + """ + ) end """ diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 7c0be506b..1097bac6f 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -14,8 +14,10 @@ export check_model, has_static_constraints An accumulator which checks calls at each tilde-statement for potential errors. -Right now this accumulator only checks for `NaN` values on the left-hand side of observe -statements, and partially `missing` values on the left-hand side of observe statements. +Right now this accumulator checks for `NaN` and `±Inf` values on the left-hand +side of observe statements, and partially `missing` values on the left-hand side +of observe statements. + Other checks in `check_model` are accomplished via different accumulators. """ @@ -72,6 +74,15 @@ _has_nans(x::NamedTuple) = any(_has_nans, x) _has_nans(x::AbstractArray) = any(_has_nans, x) _has_nans(x) = isnan(x) _has_nans(::Missing) = false +""" + _has_infs(x) + +Check if `x` is `Inf` or `-Inf`, or contains any such values. +""" +_has_infs(x::NamedTuple) = any(_has_infs, x) +_has_infs(x::AbstractArray) = any(_has_infs, x) +_has_infs(x) = isinf(x) +_has_infs(::Missing) = false function DynamicPPL.accumulate_assume!!( acc::DebugAccumulator, val, tval, logjac, vn::VarName, right::Distribution, template @@ -107,6 +118,16 @@ function DynamicPPL.accumulate_observe!!( @warn msg failed = true end + # Check for Inf values, but only warn if the logpdf at that value is -Inf + # (i.e., Inf is not in the support of the distribution) + if _has_infs(val) && isinf(logpdf(right, val)) + msg = + "Encountered an infinite value on the left-hand side of an" * + " observe statement; this may indicate that your data" * + " contain Inf or -Inf values." + @warn msg + failed = true + end return DebugAccumulator(failed) end @@ -130,6 +151,7 @@ needed. - Repeated usage of the same or overlapping VarNames - `NaN` on the left-hand side of observe statements +- `±Inf` on the left-hand side of observe statements (when the value is not in the support of the distribution) - (if `fail_if_discrete` is set) Usage of discrete distributions diff --git a/test/compiler.jl b/test/compiler.jl index 3d34594cc..18a631ba6 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -372,12 +372,14 @@ end x .~ [Normal(), Normal()] return x end - expected_error = ArgumentError(""" - As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ - Please use `product_distribution` instead, or write a loop if necessary. \ - See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ - details.\ - """) + expected_error = ArgumentError( + """ +As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ +Please use `product_distribution` instead, or write a loop if necessary. \ +See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ +details.\ +""" + ) @test_throws expected_error (vector_dot_tilde()(); true) end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 551cf0323..6e3faf402 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -122,6 +122,22 @@ end @test_throws ErrorException check_model(m; error_on_failure=true) end + @testset "Inf in data" begin + @model function demo_inf_in_data(x) + a ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(a) + end + end + m = demo_inf_in_data([1.0, Inf]) + @test_throws ErrorException check_model(m; error_on_failure=true) + m2 = demo_inf_in_data([1.0, -Inf]) + @test_throws ErrorException check_model(m2; error_on_failure=true) + # Finite data should pass + m3 = demo_inf_in_data([1.0, 2.0]) + @test check_model(m3; error_on_failure=true) + end + @testset "incorrect use of condition" begin @testset "missing in multivariate" begin @model function demo_missing_in_multivariate(x)