用 Python 拟合拉格朗日多项式
在遇到很无聊的傻福小学数学找规律题目时,我们可以用一个很通用的方法来解决这种唐的没边的问题:拉格朗日差值公式。
公式形式如下: $$ L(x) := \sum_{j=0}^k y_j \left( \prod_{i=0,i\neq j}^{k} \frac{x-x_i}{x_j-x_i} \right) $$
这个公式的原理很简单,右侧括号内的乘积用于保证在变量 $x$ 取到 $x_j$ 时值为 $1$ ,而如果是剩下的 $x_i$ 则值为 $0$ ,这时候乘上一个 $y$ 就能得到点 $(x_j,y_j)$ 了。
显然我们不喜欢手算这玩意,于是我们寻求计算机的帮助。我们希望写一个程序,我们输入数组 x[k] 与 y[k] ,程序帮我们计算并输出对应的公式(LaTeX代码)。于是初步判断我们会用到以下库:
import numpy as np
import matplotlib.pyplot as plt
from sympy import symbols, simplify, latex, lambdify, Rational, expand
处理输入部分:我准备把所有的 x 和所有的 y 都一股脑输进去,而不是输入一对一对的 (x,y) ,会很慢。这里我的代码没有要求先输入数组长度,而是判断了一下输入的两个数组长度是否一致。如果不一致就报 ERROR 。绘图用浮点,公式用分数,讲究!。
arr_x = input("请输入 x 数组(用空格分隔,例如:1 2.5 3.75):\n")
arr_y = input("请输入 y 数组(用空格分隔,例如:2 3.1 4.8):\n")
x_strs = arr_x.split()
y_strs = arr_y.split()
if len(x_strs) == 0 or len(y_strs) == 0:
raise ValueError("未输入任何 x 或 y 数据。")
if len(x_strs) != len(y_strs):
raise ValueError("x 和 y 的数量必须一致!")
# 转为 float 用于绘图
x_values_float = [float(s) for s in x_strs]
y_values_float = [float(s) for s in y_strs]
if len(set(x_values_float)) != len(x_values_float):
raise ValueError("x 数组中不能有重复值!")
# 转为 SymPy Rational(精确,分数形式)
# 使用 Rational(s) 可以把十进制小数字符串精确转成分数
x_values_sym = [Rational(s) for s in x_strs]
y_values_sym = [Rational(s) for s in y_strs]
这一段是拉格朗日插值公式的本体,其中 li 就是刚刚提到的那个多项式乘积。 sympy 会自动把含有刚刚定义过的 x 的式子用公式的形式保存起来(而不是像 python 本体一样直接进行计算)。使用 simplify() 来化简刚刚整出来的多项式。
def lagrange_interpolation(x_vals_sym, y_vals_sym):
lagrange_poly = 0
n = len(x_vals_sym)
for i in range(n):
li = 1
for j in range(n):
if i != j:
li *= (x - x_vals_sym[j]) / (x_vals_sym[i] - x_vals_sym[j])
lagrange_poly += li * y_vals_sym[i]
lagrange_poly = simplify(lagrange_poly)
return lagrange_poly
计算......
lagrange_poly = lagrange_interpolation(x_values_sym, y_values_sym)
lagrange_poly_expanded = expand(lagrange_poly)
输出 LaTeX :
expanded_latex = latex(lagrange_poly_expanded)
print("展开后的多项式:")
print(expanded_latex)
因为使用的是 JupyterLab ,尝试启用 LaTeX 渲染。
use_tex = True
try:
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'
# 小绘制检测(若系统没有 TeX 会在这里抛错)
fig_test = plt.figure()
fig_test.text(0.5, 0.5, r"$\mathrm{test}$")
plt.close(fig_test)
except Exception:
use_tex = False
plt.rcParams['text.usetex'] = False
if 'text.latex.preamble' in plt.rcParams:
del plt.rcParams['text.latex.preamble']
print("提示:系统未检测到完整 LaTeX(或渲染失败),将使用 matplotlib mathtext。")
# 将精确的符号多项式数值化(lambdify 会把 Rational 转成可计算的浮点函数)
f = lambdify(x, lagrange_poly, 'numpy')
接下来就是保存与绘图了,正常使用 matplotlib 即可。
# 范围
x_plot = np.linspace(min(x_values_float) - 0.5, max(x_values_float) + 0.5, 600)
y_plot = f(x_plot) # 由精确多项式数值化得到的曲线(浮点)
# 点和曲线
plt.figure(figsize=(10, 6), dpi=150)
plt.plot(x_values_float, y_values_float, 'o', label='Data points', markersize=6)
plt.plot(x_plot, y_plot, '-', label='Lagrange polynomial')
plt.legend()
plt.grid(True)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Lagrange Interpolation')
# plt.savefig("lagrange_polynomial_graph.png", dpi=300)
# print("函数图像已保存为 lagrange_polynomial_graph.png")
plt.show()
然后来绘制公式,其实也不难,按照如下方法即可。在这里给公式添加了 f(x) = 让他更加优雅。
formula_with_y = r"f(x) = %s" % expanded_latex
plt.text(
0.5, 0.5,
r"$%s$" % formula_with_y,
fontsize=20,
ha='center',
va='center',
transform=plt.gca().transAxes
)
plt.savefig("lagrange_formula.png", dpi=300, bbox_inches='tight')
print("公式图片已保存为 lagrange_formula.png")
plt.show()
举个例子,输入:
x = [1, 2, 3, 4, 5]
y = [1.1, 2.2, 3.3, 4.4, 114.514]
发现程序计算输出了:
\frac{18169 x^{4}}{4000} - \frac{18169 x^{3}}{400} + \frac{127183 x^{2}}{800} - \frac{18081 x}{80} + \frac{54507}{500}
也就是
$$ f(x) = \frac{18169 x^{4}}{4000} - \frac{18169 x^{3}}{400} + \frac{127183 x^{2}}{800} - \frac{18081 x}{80} + \frac{54507}{500} $$