用 Python 拟合拉格朗日多项式

在遇到很无聊的傻福小学数学找规律题目时,我们可以用一个很通用的方法来解决这种唐的没边的问题:拉格朗日差值公式。

公式形式如下: $$ L(x) := \sum_{j=0}^k y_j \left( \prod_{i=0,i\neq j}^{k} \frac{x-x_i}{x_j-x_i} \right) $$

这个公式的原理很简单,右侧括号内的乘积用于保证在变量 $x$ 取到 $x_j$ 时值为 $1$ ,而如果是剩下的 $x_i$ 则值为 $0$ ,这时候乘上一个 $y$ 就能得到点 $(x_j,y_j)$ 了。

显然我们不喜欢手算这玩意,于是我们寻求计算机的帮助。我们希望写一个程序,我们输入数组 x[k]y[k] ,程序帮我们计算并输出对应的公式(LaTeX代码)。于是初步判断我们会用到以下库:

import numpy as np
import matplotlib.pyplot as plt
from sympy import symbols, simplify, latex, lambdify, Rational, expand

处理输入部分:我准备把所有的 x 和所有的 y 都一股脑输进去,而不是输入一对一对的 (x,y) ,会很慢。这里我的代码没有要求先输入数组长度,而是判断了一下输入的两个数组长度是否一致。如果不一致就报 ERROR 。绘图用浮点,公式用分数,讲究!。

arr_x = input("请输入 x 数组(用空格分隔,例如:1 2.5 3.75):\n")
arr_y = input("请输入 y 数组(用空格分隔,例如:2 3.1 4.8):\n")

x_strs = arr_x.split()
y_strs = arr_y.split()

if len(x_strs) == 0 or len(y_strs) == 0:
    raise ValueError("未输入任何 x 或 y 数据。")
if len(x_strs) != len(y_strs):
    raise ValueError("x 和 y 的数量必须一致!")

# 转为 float 用于绘图
x_values_float = [float(s) for s in x_strs]
y_values_float = [float(s) for s in y_strs]

if len(set(x_values_float)) != len(x_values_float):
    raise ValueError("x 数组中不能有重复值!")

# 转为 SymPy Rational(精确,分数形式)
# 使用 Rational(s) 可以把十进制小数字符串精确转成分数
x_values_sym = [Rational(s) for s in x_strs]
y_values_sym = [Rational(s) for s in y_strs]

这一段是拉格朗日插值公式的本体,其中 li 就是刚刚提到的那个多项式乘积。 sympy 会自动把含有刚刚定义过的 x 的式子用公式的形式保存起来(而不是像 python 本体一样直接进行计算)。使用 simplify() 来化简刚刚整出来的多项式。

def lagrange_interpolation(x_vals_sym, y_vals_sym):
    lagrange_poly = 0
    n = len(x_vals_sym)
    for i in range(n):
        li = 1
        for j in range(n):
            if i != j:
                li *= (x - x_vals_sym[j]) / (x_vals_sym[i] - x_vals_sym[j])
        lagrange_poly += li * y_vals_sym[i]
    lagrange_poly = simplify(lagrange_poly)
    return lagrange_poly

计算......

lagrange_poly = lagrange_interpolation(x_values_sym, y_values_sym)
lagrange_poly_expanded = expand(lagrange_poly)

输出 LaTeX

expanded_latex = latex(lagrange_poly_expanded)
print("展开后的多项式:")
print(expanded_latex)

因为使用的是 JupyterLab ,尝试启用 LaTeX 渲染。

use_tex = True
try:
    plt.rcParams['text.usetex'] = True
    plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'
    # 小绘制检测(若系统没有 TeX 会在这里抛错)
    fig_test = plt.figure()
    fig_test.text(0.5, 0.5, r"$\mathrm{test}$")
    plt.close(fig_test)
except Exception:
    use_tex = False
    plt.rcParams['text.usetex'] = False
    if 'text.latex.preamble' in plt.rcParams:
        del plt.rcParams['text.latex.preamble']
    print("提示:系统未检测到完整 LaTeX(或渲染失败),将使用 matplotlib mathtext。")
# 将精确的符号多项式数值化(lambdify 会把 Rational 转成可计算的浮点函数)
f = lambdify(x, lagrange_poly, 'numpy')

接下来就是保存与绘图了,正常使用 matplotlib 即可。

# 范围
x_plot = np.linspace(min(x_values_float) - 0.5, max(x_values_float) + 0.5, 600)
y_plot = f(x_plot)  # 由精确多项式数值化得到的曲线(浮点)

# 点和曲线
plt.figure(figsize=(10, 6), dpi=150)
plt.plot(x_values_float, y_values_float, 'o', label='Data points', markersize=6)
plt.plot(x_plot, y_plot, '-', label='Lagrange polynomial')
plt.legend()
plt.grid(True)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Lagrange Interpolation')

# plt.savefig("lagrange_polynomial_graph.png", dpi=300)
# print("函数图像已保存为 lagrange_polynomial_graph.png")
plt.show()

然后来绘制公式,其实也不难,按照如下方法即可。在这里给公式添加了 f(x) = 让他更加优雅。

formula_with_y = r"f(x) = %s" % expanded_latex

plt.text(
    0.5, 0.5,
    r"$%s$" % formula_with_y,
    fontsize=20,
    ha='center',
    va='center',
    transform=plt.gca().transAxes
)

plt.savefig("lagrange_formula.png", dpi=300, bbox_inches='tight')
print("公式图片已保存为 lagrange_formula.png")
plt.show()

举个例子,输入:

x = [1, 2, 3, 4, 5]
y = [1.1, 2.2, 3.3, 4.4, 114.514]

发现程序计算输出了:

\frac{18169 x^{4}}{4000} - \frac{18169 x^{3}}{400} + \frac{127183 x^{2}}{800} - \frac{18081 x}{80} + \frac{54507}{500}

也就是

$$ f(x) = \frac{18169 x^{4}}{4000} - \frac{18169 x^{3}}{400} + \frac{127183 x^{2}}{800} - \frac{18081 x}{80} + \frac{54507}{500} $$