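Automatic Differentiation in Ceres
Introduction to the algorithm
1. Functions
In Ceres, automatic differentiation is implemented in jet.h, which defines overloads of many common functions (cos, sin, log, and so on) for the Jet type. All of them are templates, because differentiating compositions of multivariate functions has to work for any number of unknowns, which is not known in advance.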
// atan2(b + db, a + da) ~= atan2(b, a) + (- b da + a db) / (a^2 + b^2)
//
// In words: the rate of change of theta is 1/r times the rate of
// change of (x, y) in the positive angular direction.
template <typename T, int N> inline
Jet<T, N> atan2(const Jet<T, N>& g, const Jet<T, N>& f) {
  // Note order of arguments:
  //
  //   f = a + da
  //   g = b + db
  T const tmp = T(1.0) / (f.a * f.a + g.a * g.a);
  return Jet<T, N>(atan2(g.a, f.a), tmp * (- g.a * f.v + f.a * g.v));
}
// pow -- base is a differentiable function, exponent is a constant.
// (a+da)^p ~= a^p + p*a^(p-1) da
template <typename T, int N> inline
Jet<T, N> pow(const Jet<T, N>& f, double g) {
  T const tmp = g * pow(f.a, g - T(1.0));
  return Jet<T, N>(pow(f.a, g), tmp * f.v);
}
// pow -- base is a constant, exponent is a differentiable function.
// We have various special cases, see the comment for pow(Jet, Jet) for
// analysis:
//
// 1. For f > 0 we have: (f)^(g + dg) ~= f^g + f^g log(f) dg
//
// 2. For f == 0 and g > 0 we have: (f)^(g + dg) ~= f^g
//
// 3. For f < 0 and integer g we have: (f)^(g + dg) ~= f^g but if dg
// != 0, the derivatives are not defined and we return NaN.
template <typename T, int N> inline
Jet<T, N> pow(double f, const Jet<T, N>& g) {
  if (f == 0 && g.a > 0) {
    // Handle case 2.
    return Jet<T, N>(T(0.0));
  }
  if (f < 0 && g.a == floor(g.a)) {
    // Handle case 3.
    Jet<T, N> ret(pow(f, g.a));
    for (int i = 0; i < N; i++) {
      if (g.v[i] != T(0.0)) {
        // Return a NaN when g.v != 0.
        ret.v[i] = std::numeric_limits<T>::quiet_NaN();
      }
    }
    return ret;
  }
  // Handle case 1.
  T const tmp = pow(f, g.a);
  return Jet<T, N>(tmp, log(f) * tmp * g.v);
}
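To see these rules in isolation, here is a minimal sketch that does not depend on Ceres. MiniJet, MakeVariable, Pow and Atan2 are hypothetical names made up for this illustration; only the update formulas are taken from the jet.h comments above. The derivative part is stored in a std::array rather than the vector type Ceres uses, which is an assumption made to keep the sketch self-contained.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// MiniJet is a hypothetical, stripped-down stand-in for ceres::Jet<T, N>:
// `a` is the function value, `v` holds the N partial derivatives.
template <typename T, int N>
struct MiniJet {
  T a{};
  std::array<T, N> v{};
};

// Seeds independent variable number k: value `a`, derivative 1 in slot k.
template <typename T, int N>
MiniJet<T, N> MakeVariable(T a, int k) {
  assert(k >= 0 && k < N);
  MiniJet<T, N> x;
  x.a = a;
  x.v[k] = T(1);
  return x;
}

// pow with a constant exponent: (a + da)^p ~= a^p + p * a^(p-1) da.
template <typename T, int N>
MiniJet<T, N> Pow(const MiniJet<T, N>& f, double p) {
  MiniJet<T, N> r;
  r.a = std::pow(f.a, p);
  const T scale = p * std::pow(f.a, p - 1.0);
  for (int i = 0; i < N; ++i) r.v[i] = scale * f.v[i];
  return r;
}

// atan2(g, f) with f = a + da and g = b + db:
// d(theta) = (-b da + a db) / (a^2 + b^2).
template <typename T, int N>
MiniJet<T, N> Atan2(const MiniJet<T, N>& g, const MiniJet<T, N>& f) {
  MiniJet<T, N> r;
  r.a = std::atan2(g.a, f.a);
  const T inv = T(1) / (f.a * f.a + g.a * g.a);
  for (int i = 0; i < N; ++i) r.v[i] = inv * (-g.a * f.v[i] + f.a * g.v[i]);
  return r;
}
```

With x = MakeVariable<double, 2>(2.0, 0) and y = MakeVariable<double, 2>(1.0, 1), Pow(x, 3.0) carries the value 8 and the partial derivative 12 with respect to x, and Atan2(y, x) carries the partials -1/5 and 2/5. Making N a template parameter is what lets the same function body serve any number of unknowns.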
2. Differentiating composite functions
The core of the algorithm is the chain rule for composite functions.
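For reference, the rule can be written out as follows. The symbols f, u and v are illustrative and are not taken from the Ceres source; the last line restates the rule in the first-order notation that the jet.h comments use.

```latex
% Single intermediate quantity, z = f(u(x)):
\frac{dz}{dx} = \frac{df}{du} \cdot \frac{du}{dx}

% Two intermediate quantities, z = f(u(x), v(x)):
\frac{dz}{dx} = \frac{\partial f}{\partial u} \frac{du}{dx}
              + \frac{\partial f}{\partial v} \frac{dv}{dx}

% The same rule as a first-order expansion:
f(a + da) \approx f(a) + f'(a)\, da
```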
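3. Forward-mode differentiation
Ceres uses forward-mode differentiation; the example in the next section shows a concrete implementation.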
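As a rough illustration of the forward-mode idea, the sketch below propagates a value and its derivative together through each operation. Dual, F, DerivativeOfF and CheckDerivative are hypothetical names for this example and are not part of Ceres; Dual plays the role of a one-variable Jet.

```cpp
#include <cassert>
#include <cmath>

// Dual is a hypothetical one-variable analogue of ceres::Jet<double, 1>:
// `a` carries the value, `v` carries the derivative with respect to the input.
struct Dual {
  double a;
  double v;
};

// Sum rule and product rule, applied alongside the ordinary arithmetic.
inline Dual operator+(Dual f, Dual g) { return {f.a + g.a, f.v + g.v}; }
inline Dual operator*(Dual f, Dual g) {
  return {f.a * g.a, f.a * g.v + f.v * g.a};
}

// Chain rule for sin: d sin(f) = cos(f) df.
inline Dual sin(Dual f) { return {std::sin(f.a), std::cos(f.a) * f.v}; }

// A function written against a generic type works for double and for Dual.
template <typename T>
T F(const T& x) {
  using std::sin;  // plain doubles use std::sin, Dual is found by ADL
  return x * x + sin(x);
}

// Seeds the derivative slot with 1 (dx/dx = 1) and reads F'(x) back out.
inline double DerivativeOfF(double x) { return F(Dual{x, 1.0}).v; }

// Usage check: the propagated derivative matches the closed form 2x + cos(x).
inline void CheckDerivative(double x) {
  assert(std::fabs(DerivativeOfF(x) - (2.0 * x + std::cos(x))) < 1e-12);
}
```

The derivative comes out of the same single pass that computes the value, which is the defining property of forward mode. A Jet with N > 1 follows the same pattern with N derivative slots instead of one.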
4. Example
Differentiate the following expression:
$$\sqrt{x^2 + y^2}$$
template <typename T, int N>
inline Jet<T, N> hypot(const Jet<T, N>& x, const Jet<T, N>& y) {
  // d/da sqrt(a) = 0.5 / sqrt(a)
  // d/dx x^2 + y^2 = 2x
  // So by the chain rule:
  // d/dx sqrt(x^2 + y^2) = 0.5 / sqrt(x^2 + y^2) * 2x = x / sqrt(x^2 + y^2)
  // d/dy sqrt(x^2 + y^2) = y / sqrt(x^2 + y^2)
  const T tmp = hypot(x.a, y.a);
  return Jet<T, N>(tmp, x.a / tmp * x.v + y.a / tmp * y.v);
}
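Let $w = x^2 + y^2$, so that the expression is $\sqrt{w}$.
Differentiating with respect to x gives
$$\frac{1}{2\sqrt{w}} \cdot 2x = \frac{x}{\sqrt{w}}$$
Differentiating with respect to y gives
$$\frac{1}{2\sqrt{w}} \cdot 2y = \frac{y}{\sqrt{w}}$$
Here w is a value already known from the previous step: in the code, tmp = hypot(x.a, y.a) is $\sqrt{w}$, computed once and reused in both partial derivatives.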
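The derivation can be checked numerically with a small standalone sketch. Jet2, Hypot and HypotAt are hypothetical names for this example; the update rule is the one from the hypot overload above, and the assert on tmp is an addition of this sketch, not something the Ceres code does.

```cpp
#include <cassert>
#include <cmath>

// Jet2 is a hypothetical two-variable stand-in for ceres::Jet<double, 2>.
struct Jet2 {
  double a;     // value
  double v[2];  // partial derivatives with respect to x and y
};

// Same rule as the hypot overload above: tmp = sqrt(w) is computed once
// and reused for the value and for both partial derivatives.
inline Jet2 Hypot(const Jet2& x, const Jet2& y) {
  const double tmp = std::hypot(x.a, y.a);
  // Assumption of this sketch: (x, y) != (0, 0), where the derivative
  // is undefined.
  assert(tmp > 0.0);
  Jet2 r;
  r.a = tmp;
  for (int i = 0; i < 2; ++i) {
    r.v[i] = x.a / tmp * x.v[i] + y.a / tmp * y.v[i];
  }
  return r;
}

// Evaluates sqrt(x*x + y*y) with x seeded as variable 0 and y as variable 1.
inline Jet2 HypotAt(double x, double y) {
  return Hypot(Jet2{x, {1.0, 0.0}}, Jet2{y, {0.0, 1.0}});
}
```

At x = 3, y = 4 this gives the value 5 and the partial derivatives 3/5 and 4/5, matching x / sqrt(w) and y / sqrt(w) from the derivation.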