CSP-202309-3-梯度求解
- 作为一个算法小白,本人第一次接触大模拟的题,本题的算法参考自:【CSP】202309-3 梯度求解
解题思路
1.输入处理
-
getchar();
:从标准输入读取一个字符。这里它的作用可能是用来“吃掉”(消耗)前一个输入后留下的换行符。确保getline
能正确读取到下一行文本。 -
getline(cin, op);
:从标准输入流cin
中读取一行文本,直到遇到换行符(用户按下回车键),然后将读取到的文本(不包括换行符)存储到之前声明的字符串变量中。
2.逐个处理表达式中的元素
(1)变量 x i x_i xi的表示
struct elem {int index; // 变量索引,即x的下标long long value; // 变量值long long derivative; // 对应变量的导数值
};
- 注意,这里因为最后求的是导函数的值,而无需记录导数的形式。例如,对于 f ( x ) = x 2 , x = 1 f(x)=x^2,x=1 f(x)=x2,x=1,其导函数 f ′ ( x ) = 2 x f'(x)=2x f′(x)=2x,这里我们直接记录 f ′ ( 1 ) = 2 f'(1)=2 f′(1)=2,即
derivative=2
。
(2)将数字字符串转换为长整型
long long str2ll(string a) {int sign = 1; // 判断是正数还是负数long long ans = 0;if (a[0] == '-')sign = -1;for (int i = 0; i < a.length(); i++) {if (a[i] != '-')ans = 10 * ans + (a[i] - '0');}return ans * sign;
}
(3)使用 stringstream
逐个处理 op
字符串中的元素
-
std::stringstream
是一个流类,可以像输入/输出流一样操作字符串。允许把字符串分割成多个部分,根据空白字符(如空格、制表符等)来拆分原始字符串。 -
循环的作用是逐个读取
op
字符串中的每个以空格分隔的子字符串,并在每次迭代中处理这些子字符串。这种处理方式对于解析和执行基于逆波兰表示法(RPN)的算术表达式非常有效,因为它允许程序按照操作的顺序(从左到右)逐步计算表达式的结果。
stringstream ss(op);
string s;
while (ss >> s) {}
(4)逆波兰式的处理逻辑
- 题目所给的字符串只涉及:
x,x的索引,运算符+-*,常数
。 - 整理来看可以分为两类:变量+运算符,这也就明确了处理的逻辑。
-
变量 x i x_i xi,存入
elem
中。if (s[0] == 'x') {elem a;a.index = str2ll(s.substr(1, s.length() - 1)); // 得到变量下标a.derivative = xIndex == a.index ? 1 : 0; // 该变量是否要被求偏导(导数是 1,否则为 0)a.value = value[a.index]; // 变量在给定的值数组中的值st.push(a); // 将包含变量信息的结构体 a 压入栈中,以便后续计算表达式的值和导数(和数字运算一样) }
-
运算符,由于求导运算本质上还是算数运算,并且是给定了变量值,本质上还是我们之前遇到过的算数运算的规则:遇到运算符,移出栈顶的两个操作数,进行对应的运算,
res
用于保存运算结果。elem op2 = st.top(); st.pop(); elem op1 = st.top(); st.pop(); elem res;
-
由于求的是导数,这里的
+-*
不再是普通意义上的+-*
,要符合导数运算的规则。switch (s[0]) { case '+': {res.value = ((op1.value + op2.value) % MOD + MOD) % MOD;res.derivative = ((op1.derivative + op2.derivative) % MOD + MOD) % MOD;break; } case '-': {res.value = ((op1.value - op2.value) % MOD + MOD) % MOD;res.derivative = ((op1.derivative - op2.derivative) % MOD + MOD) % MOD;break; } case '*': {res.value = ((op1.value * op2.value) % MOD + MOD) % MOD;res.derivative = ((op1.derivative * op2.value + op1.value * op2.derivative) % MOD + MOD) % MOD; } } st.push(res);
-
常数,类似于非变量的 x i x_i xi。
else {elem a;a.value = str2ll(s);a.derivative = 0;st.push(a); }
-
- 最终,栈顶的结果即为运算结果。
3.完善代码
#include <iostream>
#include <vector>
#include <stack>
#include <sstream>using namespace std;// 定义一个结构体 elem,用于表示表达式中的元素
struct elem {int index; // 变量索引long long value; // 变量值long long derivative; // 对应变量的导数
};const long long MOD = 1000000007; // 模数// 将字符串转换为长整型
long long str2ll(string a) {int sign = 1; // 判断是正数还是负数long long ans = 0;if (a[0] == '-')sign = -1;for (int i = 0; i < a.length(); i++) {if (a[i] != '-')ans = 10 * ans + (a[i] - '0');}return ans * sign;
}int main() {int n, m;cin >> n >> m; // 输入变量个数和表达式数量string op;getchar();getline(cin, op); // 获取表达式字符串vector<elem> expr;stack<elem> st;for (int i = 0; i < m; i++) {int xIndex;vector<long long> value(n + 1);cin >> xIndex; // 输入变量xi,其余均视为常量for (int j = 1; j <= n; j++)cin >> value[j]; // 输入每个变量的值stringstream ss(op);string s;while (ss >> s) {// 判断是否是变量if (s[0] == 'x') {elem a;a.index = str2ll(s.substr(1, s.length() - 1)); // 得到变量的索引a.derivative = xIndex == a.index ? 1 : 0; // 变量对目标变量的导数是 1,否则为 0a.value = value[a.index]; // 变量在给定的值数组中的值st.push(a); // 将包含变量信息的结构体 a 压入栈中,以便后续计算表达式的值和导数}// 检查当前读取到的字符串 s 是否只有一个字符且为加号、减号或乘号。如果是这三个运算符之一,就执行相应的运算逻辑else if (s.length() == 1 && (s[0] == '+' || s[0] == '-' || s[0] == '*')) {// 处理运算符的逻辑elem op2 = st.top();st.pop();elem op1 = st.top();st.pop();elem res;switch (s[0]) {case '+': {res.value = ((op1.value + op2.value) % MOD + MOD) % MOD;res.derivative = ((op1.derivative + op2.derivative) % MOD + MOD) % MOD;break;}case '-': {res.value = ((op1.value - op2.value) % MOD + MOD) % MOD;res.derivative = ((op1.derivative - op2.derivative) % MOD + MOD) % MOD;break;}case '*': {res.value = ((op1.value * op2.value) % MOD + MOD) % MOD;res.derivative = ((op1.derivative * op2.value + op1.value * op2.derivative) % MOD + MOD) % MOD;}}st.push(res);}else {elem a;a.value = str2ll(s);a.derivative = 0;st.push(a);}}long long ans = st.top().derivative;cout << ((ans % MOD) + MOD) % MOD << endl; // 输出结果取模}return 0;
}