🚀 AI 一键生成 joinquant 策略代码
立即体验
👉 AI 生成代码 👈

在JoinQuant中使用Scipy进行三次样条拟合及提取多项式表达式指南

joinquant | 发布时间: 2026/6/2
以下内容由 EasyQuant 生成。

问题描述

关于曲线拟合

各位大佬,请问有人了解scipy库吗?scipy.interpolate.splrep(x1,y1,k=3)三次样条曲线拟合后我该怎么获取不同区间段多项式的表达式呢?有现成的函数吗?我百度了半天没搞懂,都是教怎么拟合曲线的,没说怎么求多项式表达式.....有大神知道怎么求吗?

解决方案

在JoinQuant(聚宽)的研究环境中,我们经常需要使用 scipy 库进行数据平滑和曲线拟合(例如拟合收益率曲线、波动率曲面等)。

关于你提到的 scipy.interpolate.splrep,它返回的是一个 B样条 (B-Spline) 的表示形式(包含节点、系数和阶数的元组 tck),而不是直接的分段多项式系数。要获取不同区间段的多项式表达式,我们需要将这个 B样条 转换为 分段多项式 (Piecewise Polynomial, PPoly) 形式。

以下是具体的实现步骤和现成的函数调用方法:

核心解决思路

  1. 使用 splrep 得到 tck
  2. 使用 scipy.interpolate.BSplinetck 包装成 B样条对象。
  3. 使用 scipy.interpolate.PPoly.from_spline 将 B样条对象转换为分段多项式对象。
  4. PPoly 对象中提取断点(区间)和多项式系数。

完整代码示例

你可以在 JoinQuant 的研究环境(Jupyter Notebook)中直接运行以下代码:

import numpy as np
from scipy.interpolate import splrep, BSpline, PPoly

# 1. 准备示例数据
x1 = np.array([0, 1, 2, 3, 4, 5])
y1 = np.sin(x1)

# 2. 进行三次样条拟合 (k=3)
tck = splrep(x1, y1, k=3)

# 3. 将 tck 转换为 BSpline 对象
spline_obj = BSpline(tck[0], tck[1], tck[2])

# 4. 将 BSpline 转换为分段多项式 (PPoly) 对象
ppoly_obj = PPoly.from_spline(spline_obj)

# 5. 提取区间和系数
# ppoly_obj.x 包含了所有的断点 (区间边界)
# ppoly_obj.c 包含了多项式的系数,形状为 (k+1, 数量)
breakpoints = ppoly_obj.x
coefficients = ppoly_obj.c

print("断点 (区间边界):", breakpoints)
print("\n系数矩阵形状:", coefficients.shape)

# 6. 打印每个区间段的多项式表达式
print("\n各区间段的多项式表达式:")
# 注意:PPoly 的区间数量比断点数量少 1
num_intervals = len(breakpoints) - 1 

for i in range(num_intervals):
    # 过滤掉长度为0的无效区间(splrep会在两端生成重复节点)
    if breakpoints[i] == breakpoints[i+1]:
        continue
        
    print(f"\n区间 [{breakpoints[i]:.2f}, {breakpoints[i+1]:.2f}]:")
    
    # 提取第 i 个区间的系数 (从高次到低次)
    # 对于三次样条,系数依次为 a, b, c, d
    coefs = coefficients[:, i]
    
    # 构造表达式字符串
    # PPoly 的公式形式为: sum( c[m, i] * (x - breakpoints[i])**(k-m) )
    expr = f"y = "
    expr += f"({coefs[0]:.4f}) * (x - {breakpoints[i]:.2f})^3 + "
    expr += f"({coefs[1]:.4f}) * (x - {breakpoints[i]:.2f})^2 + "
    expr += f"({coefs[2]:.4f}) * (x - {breakpoints[i]:.2f}) + "
    expr += f"({coefs[3]:.4f})"
    
    print(expr)

原理解析

  • ppoly_obj.x:这是一个一维数组,表示分段多项式的断点(即每个区间的起始和结束位置)。注意,splrep 为了满足边界条件,会在两端生成重复的节点,我们在遍历时需要跳过长度为0的区间(breakpoints[i] == breakpoints[i+1])。
  • ppoly_obj.c:这是一个二维数组,形状为 (k+1, n_intervals)。对于三次样条(k=3),它有 4 行。每一列对应一个区间。
  • 数学表达式:在第 i 个区间 $[x_i, x_{i+1}]$ 内,多项式的形式是基于局部坐标 $(x - x_i)$ 的展开:
    $$ y(x) = c_{0,i}(x - x_i)^3 + c_{1,i}(x - x_i)^2 + c_{2,i}(x - x_i) + c_{3,i} $$
    其中 $c_{m,i}$ 就是 coefficients[m, i]

在量化交易中的应用提示

在 JoinQuant 编写策略时,如果你需要对离散的因子值(如不同期限的远期升贴水率)进行平滑,获取多项式表达式后,你可以非常方便地计算曲线的导数(例如求斜率来判断期限结构的陡峭程度)或者积分PPoly 对象本身也提供了 .derivative().antiderivative() 方法,比手动解析表达式更加高效。