Skip to content

Commit 12449ca

Browse files
committed
fix: higher order safe operators
1 parent 1c64682 commit 12449ca

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicRegression"
22
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "1.5.0"
4+
version = "1.5.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/Operators.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ const Dual = ForwardDiff.Dual
2626
#binary: mod
2727
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign.
2828

29-
const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}
29+
const FloatOrDual = Union{AbstractFloat,Dual}
30+
# Note that a complex dual is Complex{<:Dual}, so we are safe to use this signature.
3031

3132
# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
3233
# Define allowed operators. Any julia operator can also be used.

test/test_composable_expression.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,41 @@ end
340340
X = stack(([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]); dims=1)
341341
@test expr(X) [1.0, 2.0] .- sin.([3.0, 4.0] .- [5.0, 6.0]) .+ 2.5
342342
end
343+
344+
@testitem "Test higher-order derivatives of safe_log with DynamicDiff" tags = [:part3] begin
345+
using SymbolicRegression
346+
using SymbolicRegression: D, safe_log, ValidVector
347+
using DynamicExpressions: OperatorEnum
348+
using ForwardDiff: DimensionMismatch
349+
350+
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(safe_log,))
351+
variable_names = ["x"]
352+
x = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)
353+
354+
# Test first and second derivatives of log(x)
355+
structure = TemplateStructure{(:f,)}(
356+
((; f), (x,)) ->
357+
ValidVector([(f(x).x[1], D(f, 1)(x).x[1], D(D(f, 1), 1)(x).x[1])], true),
358+
)
359+
expr = TemplateExpression((; f=log(x)); structure, operators, variable_names)
360+
361+
# Test at x = 2.0 where log(x) is well-defined
362+
X = [2.0]'
363+
result = only(expr(X))
364+
@test result !== nothing
365+
@test result[1] == log(2.0) # function value
366+
@test result[2] == 1 / 2.0 # first derivative
367+
@test result[3] == -1 / 4.0 # second derivative
368+
369+
# We handle invalid ranges gracefully:
370+
X_invalid = [-1.0]'
371+
result = only(expr(X_invalid))
372+
@test result !== nothing
373+
@test isnan(result[1])
374+
@test result[2] == 0.0
375+
@test result[3] == 0.0
376+
377+
# Eventually we want to support complex numbers:
378+
X_complex = [-1.0 - 1.0im]'
379+
@test_throws DimensionMismatch expr(X_complex)
380+
end

0 commit comments

Comments
 (0)