50 std::vector< std::vector<double> > m_upper;
51 std::vector< std::vector<double> > m_lower;
54 band_matrix(
int dim,
int n_u,
int n_l);
56 void resize(
int dim,
int n_u,
int n_l);
60 return m_upper.size()-1;
64 return m_lower.size()-1;
67 double & operator () (
int i,
int j);
68 double operator () (
int i,
int j)
const;
70 double& saved_diag(
int i);
71 double saved_diag(
int i)
const;
73 std::vector<double> r_solve(
const std::vector<double>& b)
const;
74 std::vector<double> l_solve(
const std::vector<double>& b)
const;
75 std::vector<double> lu_solve(
const std::vector<double>& b,
76 bool is_lu_decomposed=
false);
91 std::vector<double> m_x,m_y;
94 std::vector<double> m_a,m_b,m_c;
96 bd_type m_left, m_right;
97 double m_left_value, m_right_value;
98 bool m_force_linear_extrapolation;
102 spline(): m_left(second_deriv), m_right(second_deriv),
103 m_left_value(0.0), m_right_value(0.0),
104 m_force_linear_extrapolation(false)
110 void set_boundary(bd_type left,
double left_value,
111 bd_type right,
double right_value,
112 bool force_linear_extrapolation=
false);
113 void set_points(
const std::vector<double>& x,
114 const std::vector<double>& y,
bool cubic_spline=
true);
115 double operator() (
double x)
const;
128 band_matrix::band_matrix(
int dim,
int n_u,
int n_l)
130 resize(dim, n_u, n_l);
132 void band_matrix::resize(
int dim,
int n_u,
int n_l)
137 m_upper.resize(n_u+1);
138 m_lower.resize(n_l+1);
139 for(
size_t i=0; i<m_upper.size(); i++) {
140 m_upper[i].resize(dim);
142 for(
size_t i=0; i<m_lower.size(); i++) {
143 m_lower[i].resize(dim);
146 int band_matrix::dim()
const 148 if(m_upper.size()>0) {
149 return m_upper[0].size();
158 double & band_matrix::operator () (
int i,
int j)
161 assert( (i>=0) && (i<dim()) && (j>=0) && (j<dim()) );
162 assert( (-num_lower()<=k) && (k<=num_upper()) );
164 if(k>=0)
return m_upper[k][i];
165 else return m_lower[-k][i];
167 double band_matrix::operator () (
int i,
int j)
const 170 assert( (i>=0) && (i<dim()) && (j>=0) && (j<dim()) );
171 assert( (-num_lower()<=k) && (k<=num_upper()) );
173 if(k>=0)
return m_upper[k][i];
174 else return m_lower[-k][i];
177 double band_matrix::saved_diag(
int i)
const 179 assert( (i>=0) && (i<dim()) );
180 return m_lower[0][i];
182 double & band_matrix::saved_diag(
int i)
184 assert( (i>=0) && (i<dim()) );
185 return m_lower[0][i];
189 void band_matrix::lu_decompose()
197 for(
int i=0; i<this->dim(); i++) {
198 assert(this->
operator()(i,i)!=0.0);
199 this->saved_diag(i)=1.0/this->operator()(i,i);
200 j_min=std::max(0,i-this->num_lower());
201 j_max=std::min(this->dim()-1,i+this->num_upper());
202 for(
int j=j_min; j<=j_max; j++) {
203 this->operator()(i,j) *= this->saved_diag(i);
205 this->operator()(i,i)=1.0;
209 for(
int k=0; k<this->dim(); k++) {
210 i_max=std::min(this->dim()-1,k+this->num_lower());
211 for(
int i=k+1; i<=i_max; i++) {
212 assert(this->
operator()(k,k)!=0.0);
213 x=-this->operator()(i,k)/this->operator()(k,k);
214 this->operator()(i,k)=-x;
215 j_max=std::min(this->dim()-1,k+this->num_upper());
216 for(
int j=k+1; j<=j_max; j++) {
218 this->operator()(i,j)=this->operator()(i,j)+x*this->operator()(k,j);
224 std::vector<double> band_matrix::l_solve(
const std::vector<double>& b)
const 226 assert( this->dim()==(
int)b.size() );
227 std::vector<double> x(this->dim());
230 for(
int i=0; i<this->dim(); i++) {
232 j_start=std::max(0,i-this->num_lower());
233 for(
int j=j_start; j<i; j++) sum += this->
operator()(i,j)*x[j];
234 x[i]=(b[i]*this->saved_diag(i)) - sum;
239 std::vector<double> band_matrix::r_solve(
const std::vector<double>& b)
const 241 assert( this->dim()==(
int)b.size() );
242 std::vector<double> x(this->dim());
245 for(
int i=this->dim()-1; i>=0; i--) {
247 j_stop=std::min(this->dim()-1,i+this->num_upper());
248 for(
int j=i+1; j<=j_stop; j++) sum += this->
operator()(i,j)*x[j];
249 x[i]=( b[i] - sum ) / this->
operator()(i,i);
254 std::vector<double> band_matrix::lu_solve(
const std::vector<double>& b,
255 bool is_lu_decomposed)
257 assert( this->dim()==(
int)b.size() );
258 std::vector<double> x,y;
259 if(is_lu_decomposed==
false) {
260 this->lu_decompose();
273 void spline::set_boundary(spline::bd_type left,
double left_value,
274 spline::bd_type right,
double right_value,
275 bool force_linear_extrapolation)
277 assert(m_x.size()==0);
280 m_left_value=left_value;
281 m_right_value=right_value;
282 m_force_linear_extrapolation=force_linear_extrapolation;
286 void spline::set_points(
const std::vector<double>& x,
287 const std::vector<double>& y,
bool cubic_spline)
289 assert(x.size()==y.size());
295 for(
int i=0; i<n-1; i++) {
296 assert(m_x[i]<m_x[i+1]);
299 if(cubic_spline==
true) {
302 band_matrix A(n,1,1);
303 std::vector<double> rhs(n);
304 for(
int i=1; i<n-1; i++) {
305 A(i,i-1)=1.0/3.0*(x[i]-x[i-1]);
306 A(i,i)=2.0/3.0*(x[i+1]-x[i-1]);
307 A(i,i+1)=1.0/3.0*(x[i+1]-x[i]);
308 rhs[i]=(y[i+1]-y[i])/(x[i+1]-x[i]) - (y[i]-y[i-1])/(x[i]-x[i-1]);
311 if(m_left == spline::second_deriv) {
316 }
else if(m_left == spline::first_deriv) {
319 A(0,0)=2.0*(x[1]-x[0]);
320 A(0,1)=1.0*(x[1]-x[0]);
321 rhs[0]=3.0*((y[1]-y[0])/(x[1]-x[0])-m_left_value);
325 if(m_right == spline::second_deriv) {
329 rhs[n-1]=m_right_value;
330 }
else if(m_right == spline::first_deriv) {
334 A(n-1,n-1)=2.0*(x[n-1]-x[n-2]);
335 A(n-1,n-2)=1.0*(x[n-1]-x[n-2]);
336 rhs[n-1]=3.0*(m_right_value-(y[n-1]-y[n-2])/(x[n-1]-x[n-2]));
347 for(
int i=0; i<n-1; i++) {
348 m_a[i]=1.0/3.0*(m_b[i+1]-m_b[i])/(x[i+1]-x[i]);
349 m_c[i]=(y[i+1]-y[i])/(x[i+1]-x[i])
350 - 1.0/3.0*(2.0*m_b[i]+m_b[i+1])*(x[i+1]-x[i]);
356 for(
int i=0; i<n-1; i++) {
359 m_c[i]=(m_y[i+1]-m_y[i])/(m_x[i+1]-m_x[i]);
364 m_b0 = (m_force_linear_extrapolation==
false) ? m_b[0] : 0.0;
369 double h=x[n-1]-x[n-2];
372 m_c[n-1]=3.0*m_a[n-2]*h*h+2.0*m_b[n-2]*h+m_c[n-2];
373 if(m_force_linear_extrapolation==
true)
377 double spline::operator() (
double x)
const 381 std::vector<double>::const_iterator it;
382 it=std::lower_bound(m_x.begin(),m_x.end(),x);
383 int idx=std::max(
int(it-m_x.begin())-1, 0);
389 interpol=(m_b0*h + m_c0)*h + m_y[0];
390 }
else if(x>m_x[n-1]) {
392 interpol=(m_b[n-1]*h + m_c[n-1])*h + m_y[n-1];
395 interpol=((m_a[idx]*h + m_b[idx])*h + m_c[idx])*h + m_y[idx];