文章目录
- 05 Ceres
- 5.0 仿函数
- 5.1 Ceres 简介
- 5.2 代码示例
05 Ceres
5.0 仿函数
简单来说,仿函数就是重载了 () 操作符的类,可以实现类似函数调用的过程,所以叫做仿函数。
struct MyPlus
{int operator()(const int &a , const int &b) const{return a + b;}
};int main()
{MyPlus a;cout << MyPlus()(1,2) << endl; //1、通过产生临时对象调用重载运算符cout << a.operator()(1,2) << endl; //2、通过对象显示调用重载运算符cout << a(1,2) << endl; //3、通过对象类似函数调用 隐示地调用重载运算符return 0;
}
其次,关于构造函数和仿函数的调用:
- 构造函数无返回值,而
operator()
是可以有返回值的; - 构造函数是声明对象,而仿函数则需要声明好的对象进行调用。
- 构造函数和仿函数在形式上是可以相同的,但用法和目的不同。
例如
#include <iostream>// 定义一个带有构造函数的仿函数类
class MyFunctor {
public:// 构造函数,接受一个整数参数MyFunctor(int config) : config_(config) { }// 仿函数的函数调用运算符,接受两个整数并返回它们的和,加上构造函数参数int operator()(int a, int b) {return a + b + config_;}private:int config_;
};int main() {// 创建一个带有构造函数参数的仿函数对象MyFunctor addWithConfig(10);// 使用仿函数对象调用它int result = addWithConfig(3, 4);std::cout << "Result: " << result << std::endl;return 0;
}
构造函数和仿函数可以是相同的(接收同样的参数)
#include <iostream>// 定义一个类,同时拥有构造函数和仿函数
class MyFunction {
public:// 构造函数,用于对象的初始化MyFunction(int initialValue) : value(initialValue) {std::cout << "Constructor called. Value: " << value << std::endl;}// 仿函数的函数调用运算符,用于执行特定的操作int operator()(int x) {std::cout << "Function called with argument: " << x << std::endl;return value + x;}private:int value;
};int main() {// 创建一个对象并调用构造函数MyFunction obj1(10);// 使用仿函数的方式调用对象int result = obj1(5);std::cout << "Result: " << result << std::endl;return 0;
}
5.1 Ceres 简介
Ceres 可以解决带有约束条件的非线性最小二乘问题,数学表达如下:
min x 1 2 ∑ i ρ i ( ∥ f i ( x i 1 , … , x i k ) ∥ 2 ) s . t . l j ⩽ x j ⩽ u j . \min_{x}\quad\frac12\sum_{\mathrm{i}}\rho_{\mathrm{i}}\left(\|f_{\mathrm{i}}\left(x_{\mathrm{i}1},\ldots,x_{\mathrm{i}\mathrm{k}}\right.)\|^2\right) \\ \mathrm{s.t.~}l_j\leqslant x_j\leqslant u_j. xmin21i∑ρi(∥fi(xi1,…,xik)∥2)s.t. lj⩽xj⩽uj.
其中
-
ρ i ( ∥ f i ( x i 1 , … , x i k ) ∥ 2 ) \rho_{\mathrm{i}}\left(\|f_{\mathrm{i}}\left(x_{\mathrm{i}1},\ldots,x_{\mathrm{i}\mathrm{k}}\right.)\|^2\right) ρi(∥fi(xi1,…,xik)∥2) 为残差块;
-
f ( ⋅ ) f(\cdot) f(⋅) 为代价函数,也即误差项;
-
ρ ( ⋅ ) \rho(\cdot) ρ(⋅) 是核函数(也称损失函数),它属于标量函数,为了减小异常值对非线性优化的影响,一般就取恒等函数(如 “ ρ ( x ) = x \rho(x)=x ρ(x)=x”);
-
l j , u j l_j, u_j lj,uj 为 x j x_j xj 的上下界。
特殊情况:当损失函数 ρ i = x \rho_i=x ρi=x,$l_j=- \infty , u_j=\infty $,那么得到了一个常见的非线性优化函数:
min x 1 2 ∑ i ( ∥ f i ( x i 1 , … , x i k ) ∥ 2 ) \min_{x}\quad\frac12\sum_{\mathrm{i}}\left(\|f_{\mathrm{i}}\left(x_{\mathrm{i}1},\ldots,x_{\mathrm{i}\mathrm{k}}\right.)\|^2\right) xmin21i∑(∥fi(xi1,…,xik)∥2)
Ceres 求解的一般步骤
-
定义Cost Function模型,即代价函数 f ( ⋅ ) f(\cdot) f(⋅)。也就是我们要寻找的最优目标,这里我们用到了仿函数或称为拟函数(functor)。做法是写一个类,然后在仿函数中重载()运算符。
-
使用定义的代价函数构建待求解的优化问题。即调用AddResidualBlock将误差项,添加到目标函数中。由于优化需要梯度,我们有几种选择: ① 使用ceres自动求导(Auto Diff)② 使用数值求导(Numeric Diff)3)自行推导解析形式,提供给ceres。
-
配置求解器参数并求解问题。配置项options比较丰富,可以查看options的定义。
5.2 代码示例
拟合函数 y = exp ( a x 2 + b x + c ) y=\exp(ax^2+bx+c) y=exp(ax2+bx+c),自动求导。
/*********************************************************** *
* Time: 2023/8/29
* Author: xiaocong
* Function: Ceres
***********************************************************/#include <iostream>
#include <ceres/ceres.h>const int N = 100; // 数据点个数using namespace std;// 定义代价函数即 f()
// 仿函数
class CurveFittingCost
{
public:// 构造函数CurveFittingCost(double x, double y) : _x(x), _y(y) {}// 计算残差template <typename T>bool operator()(const T* const abc, // 待优化变量,三维T* residual) const // 残差{residual[0] = T(_y) - ceres::exp(abc[0] * T(_x) * T(_x) + abc[1] * T(_x) + abc[2]); // y-exp(ax^2+bx+c)return true;}private:const double _x, _y;
};int main()
{double ar = 1.0, br = 2.0, cr = 1.0; // 真实参数值double abc[3] = { 0,0,0 }; // 初始估计值// 生成数据vector<double> x_data, y_data;for (int i = 0; i < N; i++){double xi = i / 100.0; // [0~1]double sigma = 0.02 * (rand() % 1000) / 1000.0 - 0.01; // 随机噪声,[-0.01, 0.01]double yi = exp(ar * xi * xi + br * xi + cr) + sigma;x_data.push_back(xi);y_data.push_back(yi);}// 构建最小二乘问题ceres::Problem problem;for (int i = 0;i < N;i++) // 添加残差块{// 使用自动求导,模板参数:误差类型,输出维度1,输入维度3,维数要与前面 class 中一致problem.AddResidualBlock(new ceres::AutoDiffCostFunction<CurveFittingCost, 1, 3>(new CurveFittingCost(x_data[i], y_data[i])), nullptr, abc);}// 配置求解器ceres::Solver::Options options;options.linear_solver_type = ceres::DENSE_QR; // 求解增量方程,QR 分解options.minimizer_progress_to_stdout = true; // 输出到coutceres::Solver::Summary summary; // 优化信息ceres::Solve(options, &problem, &summary); // 开始优化// 输出结果cout << summary.BriefReport() << endl;cout << "estimated a,b,c = ";for (auto a : abc)cout << a << " ";cout << endl;return 0;
}
结果
estimated a,b,c = 0.999313 2.00098 0.999658