问题描述
各位大佬,请问有人了解scipy库吗?scipy.interpolate.splrep(x1,y1,k=3)三次样条曲线拟合后我该怎么获取不同区间段多项式的表达式呢?有现成的函数吗?我百度了半天没搞懂,都是教怎么拟合曲线的,没说怎么求多项式表达式.....有大神知道怎么求吗?
解决方案
在JoinQuant(聚宽)的研究环境中,我们经常需要使用 scipy 库进行数据平滑和曲线拟合(例如拟合收益率曲线、波动率曲面等)。
关于你提到的 scipy.interpolate.splrep,它返回的是一个 B样条 (B-Spline) 的表示形式(包含节点、系数和阶数的元组 tck),而不是直接的分段多项式系数。要获取不同区间段的多项式表达式,我们需要将这个 B样条 转换为 分段多项式 (Piecewise Polynomial, PPoly) 形式。
以下是具体的实现步骤和现成的函数调用方法:
核心解决思路
- 使用
splrep得到tck。 - 使用
scipy.interpolate.BSpline将tck包装成 B样条对象。 - 使用
scipy.interpolate.PPoly.from_spline将 B样条对象转换为分段多项式对象。 - 从
PPoly对象中提取断点(区间)和多项式系数。
完整代码示例
你可以在 JoinQuant 的研究环境(Jupyter Notebook)中直接运行以下代码:
import numpy as np
from scipy.interpolate import splrep, BSpline, PPoly
# 1. 准备示例数据
x1 = np.array([0, 1, 2, 3, 4, 5])
y1 = np.sin(x1)
# 2. 进行三次样条拟合 (k=3)
tck = splrep(x1, y1, k=3)
# 3. 将 tck 转换为 BSpline 对象
spline_obj = BSpline(tck[0], tck[1], tck[2])
# 4. 将 BSpline 转换为分段多项式 (PPoly) 对象
ppoly_obj = PPoly.from_spline(spline_obj)
# 5. 提取区间和系数
# ppoly_obj.x 包含了所有的断点 (区间边界)
# ppoly_obj.c 包含了多项式的系数,形状为 (k+1, 数量)
breakpoints = ppoly_obj.x
coefficients = ppoly_obj.c
print("断点 (区间边界):", breakpoints)
print("\n系数矩阵形状:", coefficients.shape)
# 6. 打印每个区间段的多项式表达式
print("\n各区间段的多项式表达式:")
# 注意:PPoly 的区间数量比断点数量少 1
num_intervals = len(breakpoints) - 1
for i in range(num_intervals):
# 过滤掉长度为0的无效区间(splrep会在两端生成重复节点)
if breakpoints[i] == breakpoints[i+1]:
continue
print(f"\n区间 [{breakpoints[i]:.2f}, {breakpoints[i+1]:.2f}]:")
# 提取第 i 个区间的系数 (从高次到低次)
# 对于三次样条,系数依次为 a, b, c, d
coefs = coefficients[:, i]
# 构造表达式字符串
# PPoly 的公式形式为: sum( c[m, i] * (x - breakpoints[i])**(k-m) )
expr = f"y = "
expr += f"({coefs[0]:.4f}) * (x - {breakpoints[i]:.2f})^3 + "
expr += f"({coefs[1]:.4f}) * (x - {breakpoints[i]:.2f})^2 + "
expr += f"({coefs[2]:.4f}) * (x - {breakpoints[i]:.2f}) + "
expr += f"({coefs[3]:.4f})"
print(expr)
原理解析
ppoly_obj.x:这是一个一维数组,表示分段多项式的断点(即每个区间的起始和结束位置)。注意,splrep为了满足边界条件,会在两端生成重复的节点,我们在遍历时需要跳过长度为0的区间(breakpoints[i] == breakpoints[i+1])。ppoly_obj.c:这是一个二维数组,形状为(k+1, n_intervals)。对于三次样条(k=3),它有 4 行。每一列对应一个区间。- 数学表达式:在第
i个区间 $[x_i, x_{i+1}]$ 内,多项式的形式是基于局部坐标 $(x - x_i)$ 的展开:
$$ y(x) = c_{0,i}(x - x_i)^3 + c_{1,i}(x - x_i)^2 + c_{2,i}(x - x_i) + c_{3,i} $$
其中 $c_{m,i}$ 就是coefficients[m, i]。
在量化交易中的应用提示
在 JoinQuant 编写策略时,如果你需要对离散的因子值(如不同期限的远期升贴水率)进行平滑,获取多项式表达式后,你可以非常方便地计算曲线的导数(例如求斜率来判断期限结构的陡峭程度)或者积分。PPoly 对象本身也提供了 .derivative() 和 .antiderivative() 方法,比手动解析表达式更加高效。