From 210550007e04a096e0e600c6586da677c33d6d90 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 6 Oct 2023 16:49:16 -0400 Subject: [PATCH] MAINT Replace double with float64_t inside tree submodule (#27539) Signed-off-by: Adam Li --- sklearn/tree/_criterion.pxd | 68 +++++----- sklearn/tree/_criterion.pyx | 258 ++++++++++++++++++------------------ sklearn/tree/_splitter.pxd | 30 ++--- sklearn/tree/_splitter.pyx | 82 ++++++------ sklearn/tree/_tree.pxd | 10 +- sklearn/tree/_tree.pyx | 104 ++++++++------- sklearn/tree/_utils.pxd | 6 +- sklearn/tree/_utils.pyx | 16 +-- 8 files changed, 289 insertions(+), 285 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index d5341c4db3be1..6538b9b824a79 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -32,11 +32,11 @@ cdef class Criterion: cdef intp_t n_outputs # Number of outputs cdef intp_t n_samples # Number of samples cdef intp_t n_node_samples # Number of samples in the node (end-start) - cdef double weighted_n_samples # Weighted number of samples (in total) - cdef double weighted_n_node_samples # Weighted number of samples in the node - cdef double weighted_n_left # Weighted number of samples in the left node - cdef double weighted_n_right # Weighted number of samples in the right node - cdef double weighted_n_missing # Weighted number of samples that are missing + cdef float64_t weighted_n_samples # Weighted number of samples (in total) + cdef float64_t weighted_n_node_samples # Weighted number of samples in the node + cdef float64_t weighted_n_left # Weighted number of samples in the left node + cdef float64_t weighted_n_right # Weighted number of samples in the right node + cdef float64_t weighted_n_missing # Weighted number of samples that are missing # The criterion object is maintained such that left and right collected # statistics correspond to samples[start:pos] and samples[pos:end]. 
@@ -46,7 +46,7 @@ cdef class Criterion: self, const float64_t[:, ::1] y, const float64_t[:] sample_weight, - double weighted_n_samples, + float64_t weighted_n_samples, const intp_t[:] sample_indices, intp_t start, intp_t end @@ -56,43 +56,43 @@ cdef class Criterion: cdef int reset(self) except -1 nogil cdef int reverse_reset(self) except -1 nogil cdef int update(self, intp_t new_pos) except -1 nogil - cdef double node_impurity(self) noexcept nogil + cdef float64_t node_impurity(self) noexcept nogil cdef void children_impurity( self, - double* impurity_left, - double* impurity_right + float64_t* impurity_left, + float64_t* impurity_right ) noexcept nogil cdef void node_value( self, - double* dest + float64_t* dest ) noexcept nogil cdef void clip_node_value( self, - double* dest, - double lower_bound, - double upper_bound + float64_t* dest, + float64_t lower_bound, + float64_t upper_bound ) noexcept nogil - cdef double middle_value(self) noexcept nogil - cdef double impurity_improvement( + cdef float64_t middle_value(self) noexcept nogil + cdef float64_t impurity_improvement( self, - double impurity_parent, - double impurity_left, - double impurity_right + float64_t impurity_parent, + float64_t impurity_left, + float64_t impurity_right ) noexcept nogil - cdef double proxy_impurity_improvement(self) noexcept nogil + cdef float64_t proxy_impurity_improvement(self) noexcept nogil cdef bint check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) noexcept nogil cdef inline bint _check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, - double sum_left, - double sum_right, + float64_t lower_bound, + float64_t upper_bound, + float64_t sum_left, + float64_t sum_right, ) noexcept nogil cdef class ClassificationCriterion(Criterion): @@ -101,17 +101,17 @@ cdef class ClassificationCriterion(Criterion): cdef intp_t[::1] n_classes cdef intp_t 
max_n_classes - cdef double[:, ::1] sum_total # The sum of the weighted count of each label. - cdef double[:, ::1] sum_left # Same as above, but for the left side of the split - cdef double[:, ::1] sum_right # Same as above, but for the right side of the split - cdef double[:, ::1] sum_missing # Same as above, but for missing values in X + cdef float64_t[:, ::1] sum_total # The sum of the weighted count of each label. + cdef float64_t[:, ::1] sum_left # Same as above, but for the left side of the split + cdef float64_t[:, ::1] sum_right # Same as above, but for the right side of the split + cdef float64_t[:, ::1] sum_missing # Same as above, but for missing values in X cdef class RegressionCriterion(Criterion): """Abstract regression criterion.""" - cdef double sq_sum_total + cdef float64_t sq_sum_total - cdef double[::1] sum_total # The sum of w*y. - cdef double[::1] sum_left # Same as above, but for the left side of the split - cdef double[::1] sum_right # Same as above, but for the right side of the split - cdef double[::1] sum_missing # Same as above, but for missing values in X + cdef float64_t[::1] sum_total # The sum of w*y. + cdef float64_t[::1] sum_left # Same as above, but for the left side of the split + cdef float64_t[::1] sum_right # Same as above, but for the right side of the split + cdef float64_t[::1] sum_missing # Same as above, but for missing values in X diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 13d18888db79f..89a7639f9bbcf 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -26,7 +26,7 @@ from ._utils cimport log from ._utils cimport WeightedMedianCalculator # EPSILON is used in the Poisson criterion -cdef double EPSILON = 10 * np.finfo('double').eps +cdef float64_t EPSILON = 10 * np.finfo('double').eps cdef class Criterion: """Interface for impurity criteria. 
@@ -44,7 +44,7 @@ cdef class Criterion: self, const float64_t[:, ::1] y, const float64_t[:] sample_weight, - double weighted_n_samples, + float64_t weighted_n_samples, const intp_t[:] sample_indices, intp_t start, intp_t end, @@ -61,7 +61,7 @@ cdef class Criterion: stored as a Cython memoryview. sample_weight : ndarray, dtype=float64_t The weight of each sample stored as a Cython memoryview. - weighted_n_samples : double + weighted_n_samples : float64_t The total weight of the samples being considered sample_indices : ndarray, dtype=intp_t A mask on the samples. Indices of the samples in X and y we want to use, @@ -115,7 +115,7 @@ cdef class Criterion: """ pass - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Placeholder for calculating the impurity of the node. Placeholder for a method which will evaluate the impurity of @@ -125,8 +125,8 @@ cdef class Criterion: """ pass - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: """Placeholder for calculating the impurity of children. Placeholder for a method which evaluates the impurity in @@ -135,16 +135,16 @@ cdef class Criterion: Parameters ---------- - impurity_left : double pointer + impurity_left : float64_t pointer The memory address where the impurity of the left child should be stored. - impurity_right : double pointer + impurity_right : float64_t pointer The memory address where the impurity of the right child should be stored """ pass - cdef void node_value(self, double* dest) noexcept nogil: + cdef void node_value(self, float64_t* dest) noexcept nogil: """Placeholder for storing the node value. 
Placeholder for a method which will compute the node value @@ -152,22 +152,22 @@ cdef class Criterion: Parameters ---------- - dest : double pointer + dest : float64_t pointer The memory address where the node value should be stored. """ pass - cdef void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + cdef void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: pass - cdef double middle_value(self) noexcept nogil: + cdef float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints This method is implemented in ClassificationCriterion and RegressionCriterion. """ pass - cdef double proxy_impurity_improvement(self) noexcept nogil: + cdef float64_t proxy_impurity_improvement(self) noexcept nogil: """Compute a proxy of the impurity reduction. This method is used to speed up the search for the best split. @@ -178,16 +178,16 @@ cdef class Criterion: The absolute impurity improvement is only computed by the impurity_improvement method once the best split has been found. """ - cdef double impurity_left - cdef double impurity_right + cdef float64_t impurity_left + cdef float64_t impurity_right self.children_impurity(&impurity_left, &impurity_right) return (- self.weighted_n_right * impurity_right - self.weighted_n_left * impurity_left) - cdef double impurity_improvement(self, double impurity_parent, - double impurity_left, - double impurity_right) noexcept nogil: + cdef float64_t impurity_improvement(self, float64_t impurity_parent, + float64_t impurity_left, + float64_t impurity_right) noexcept nogil: """Compute the improvement in impurity. This method computes the improvement in impurity when a split occurs. 
@@ -202,18 +202,18 @@ cdef class Criterion: Parameters ---------- - impurity_parent : double + impurity_parent : float64_t The initial impurity of the parent node before the split - impurity_left : double + impurity_left : float64_t The impurity of the left child - impurity_right : double + impurity_right : float64_t The impurity of the right child Return ------ - double : improvement in impurity after the split occurs + float64_t : improvement in impurity after the split occurs """ return ((self.weighted_n_node_samples / self.weighted_n_samples) * (impurity_parent - (self.weighted_n_right / @@ -224,18 +224,18 @@ cdef class Criterion: cdef bint check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) noexcept nogil: pass cdef inline bint _check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, - double value_left, - double value_right, + float64_t lower_bound, + float64_t upper_bound, + float64_t value_left, + float64_t value_right, ) noexcept nogil: cdef: bint check_lower_bound = ( @@ -256,10 +256,10 @@ cdef class Criterion: cdef inline void _move_sums_classification( ClassificationCriterion criterion, - double[:, ::1] sum_1, - double[:, ::1] sum_2, - double* weighted_n_1, - double* weighted_n_2, + float64_t[:, ::1] sum_1, + float64_t[:, ::1] sum_2, + float64_t* weighted_n_1, + float64_t* weighted_n_2, bint put_missing_in_1, ) noexcept nogil: """Distribute sum_total and sum_missing into sum_1 and sum_2. 
@@ -276,7 +276,7 @@ cdef inline void _move_sums_classification( cdef intp_t k, c, n_bytes if criterion.n_missing != 0 and put_missing_in_1: for k in range(criterion.n_outputs): - n_bytes = criterion.n_classes[k] * sizeof(double) + n_bytes = criterion.n_classes[k] * sizeof(float64_t) memcpy(&sum_1[k, 0], &criterion.sum_missing[k, 0], n_bytes) for k in range(criterion.n_outputs): @@ -288,7 +288,7 @@ cdef inline void _move_sums_classification( else: # Assigning sum_2 = sum_total for all outputs. for k in range(criterion.n_outputs): - n_bytes = criterion.n_classes[k] * sizeof(double) + n_bytes = criterion.n_classes[k] * sizeof(float64_t) memset(&sum_1[k, 0], 0, n_bytes) memcpy(&sum_2[k, 0], &criterion.sum_total[k, 0], n_bytes) @@ -351,7 +351,7 @@ cdef class ClassificationCriterion(Criterion): self, const float64_t[:, ::1] y, const float64_t[:] sample_weight, - double weighted_n_samples, + float64_t weighted_n_samples, const intp_t[:] sample_indices, intp_t start, intp_t end @@ -370,7 +370,7 @@ cdef class ClassificationCriterion(Criterion): The target stored as a buffer for memory efficiency. sample_weight : ndarray, dtype=float64_t The weight of each sample stored as a Cython memoryview. - weighted_n_samples : double + weighted_n_samples : float64_t The total weight of all samples sample_indices : ndarray, dtype=intp_t A mask on the samples. 
Indices of the samples in X and y we want to use, @@ -396,7 +396,7 @@ cdef class ClassificationCriterion(Criterion): cdef float64_t w = 1.0 for k in range(self.n_outputs): - memset(&self.sum_total[k, 0], 0, self.n_classes[k] * sizeof(double)) + memset(&self.sum_total[k, 0], 0, self.n_classes[k] * sizeof(float64_t)) for p in range(start, end): i = sample_indices[p] @@ -434,7 +434,7 @@ cdef class ClassificationCriterion(Criterion): if n_missing == 0: return - memset(&self.sum_missing[0, 0], 0, self.max_n_classes * self.n_outputs * sizeof(double)) + memset(&self.sum_missing[0, 0], 0, self.max_n_classes * self.n_outputs * sizeof(float64_t)) self.weighted_n_missing = 0.0 @@ -553,28 +553,28 @@ cdef class ClassificationCriterion(Criterion): self.pos = new_pos return 0 - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: pass - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: pass - cdef void node_value(self, double* dest) noexcept nogil: + cdef void node_value(self, float64_t* dest) noexcept nogil: """Compute the node value of sample_indices[start:end] and save it into dest. Parameters ---------- - dest : double pointer + dest : float64_t pointer The memory address which we will save the node value into. """ cdef intp_t k for k in range(self.n_outputs): - memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) + memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(float64_t)) dest += self.max_n_classes - cdef void clip_node_value(self, double * dest, double lower_bound, double upper_bound) noexcept nogil: + cdef void clip_node_value(self, float64_t * dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: """Clip the value in dest between lower_bound and upper_bound for monotonic constraints. 
Note that monotonicity constraints are only supported for: @@ -589,7 +589,7 @@ cdef class ClassificationCriterion(Criterion): # Class proportions for binary classification must sum to 1. dest[1] = 1 - dest[0] - cdef inline double middle_value(self) noexcept nogil: + cdef inline float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. @@ -605,13 +605,13 @@ cdef class ClassificationCriterion(Criterion): cdef inline bint check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) noexcept nogil: """Check monotonicity constraint is satisfied at the current classification split""" cdef: - double value_left = self.sum_left[0][0] / self.weighted_n_left - double value_right = self.sum_right[0][0] / self.weighted_n_right + float64_t value_left = self.sum_left[0][0] / self.weighted_n_left + float64_t value_right = self.sum_right[0][0] / self.weighted_n_right return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) @@ -632,15 +632,15 @@ cdef class Entropy(ClassificationCriterion): cross-entropy = -\sum_{k=0}^{K-1} count_k log(count_k) """ - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. Evaluate the cross-entropy criterion as impurity of the current node, i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. 
""" - cdef double entropy = 0.0 - cdef double count_k + cdef float64_t entropy = 0.0 + cdef float64_t count_k cdef intp_t k cdef intp_t c @@ -653,8 +653,8 @@ cdef class Entropy(ClassificationCriterion): return entropy / self.n_outputs - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: """Evaluate the impurity in children nodes. i.e. the impurity of the left child (sample_indices[start:pos]) and the @@ -662,14 +662,14 @@ cdef class Entropy(ClassificationCriterion): Parameters ---------- - impurity_left : double pointer + impurity_left : float64_t pointer The memory address to save the impurity of the left node - impurity_right : double pointer + impurity_right : float64_t pointer The memory address to save the impurity of the right node """ - cdef double entropy_left = 0.0 - cdef double entropy_right = 0.0 - cdef double count_k + cdef float64_t entropy_left = 0.0 + cdef float64_t entropy_right = 0.0 + cdef float64_t count_k cdef intp_t k cdef intp_t c @@ -706,16 +706,16 @@ cdef class Gini(ClassificationCriterion): = 1 - \sum_{k=0}^{K-1} count_k ** 2 """ - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. Evaluate the Gini criterion as impurity of the current node, i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. 
""" - cdef double gini = 0.0 - cdef double sq_count - cdef double count_k + cdef float64_t gini = 0.0 + cdef float64_t sq_count + cdef float64_t count_k cdef intp_t k cdef intp_t c @@ -731,8 +731,8 @@ cdef class Gini(ClassificationCriterion): return gini / self.n_outputs - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: """Evaluate the impurity in children nodes. i.e. the impurity of the left child (sample_indices[start:pos]) and the @@ -740,16 +740,16 @@ cdef class Gini(ClassificationCriterion): Parameters ---------- - impurity_left : double pointer + impurity_left : float64_t pointer The memory address to save the impurity of the left node to - impurity_right : double pointer + impurity_right : float64_t pointer The memory address to save the impurity of the right node to """ - cdef double gini_left = 0.0 - cdef double gini_right = 0.0 - cdef double sq_count_left - cdef double sq_count_right - cdef double count_k + cdef float64_t gini_left = 0.0 + cdef float64_t gini_right = 0.0 + cdef float64_t sq_count_left + cdef float64_t sq_count_right + cdef float64_t count_k cdef intp_t k cdef intp_t c @@ -776,10 +776,10 @@ cdef class Gini(ClassificationCriterion): cdef inline void _move_sums_regression( RegressionCriterion criterion, - double[::1] sum_1, - double[::1] sum_2, - double* weighted_n_1, - double* weighted_n_2, + float64_t[::1] sum_1, + float64_t[::1] sum_2, + float64_t* weighted_n_1, + float64_t* weighted_n_2, bint put_missing_in_1, ) noexcept nogil: """Distribute sum_total and sum_missing into sum_1 and sum_2. 
@@ -795,7 +795,7 @@ cdef inline void _move_sums_regression( """ cdef: intp_t i - intp_t n_bytes = criterion.n_outputs * sizeof(double) + intp_t n_bytes = criterion.n_outputs * sizeof(float64_t) bint has_missing = criterion.n_missing != 0 if has_missing and put_missing_in_1: @@ -861,7 +861,7 @@ cdef class RegressionCriterion(Criterion): self, const float64_t[:, ::1] y, const float64_t[:] sample_weight, - double weighted_n_samples, + float64_t weighted_n_samples, const intp_t[:] sample_indices, intp_t start, intp_t end, @@ -888,7 +888,7 @@ cdef class RegressionCriterion(Criterion): cdef float64_t w_y_ik cdef float64_t w = 1.0 self.sq_sum_total = 0.0 - memset(&self.sum_total[0], 0, self.n_outputs * sizeof(double)) + memset(&self.sum_total[0], 0, self.n_outputs * sizeof(float64_t)) for p in range(start, end): i = sample_indices[p] @@ -927,7 +927,7 @@ cdef class RegressionCriterion(Criterion): if n_missing == 0: return - memset(&self.sum_missing[0], 0, self.n_outputs * sizeof(double)) + memset(&self.sum_missing[0], 0, self.n_outputs * sizeof(float64_t)) self.weighted_n_missing = 0.0 @@ -1026,28 +1026,28 @@ cdef class RegressionCriterion(Criterion): self.pos = new_pos return 0 - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: pass - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: pass - cdef void node_value(self, double* dest) noexcept nogil: + cdef void node_value(self, float64_t* dest) noexcept nogil: """Compute the node value of sample_indices[start:end] into dest.""" cdef intp_t k for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples - cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, 
float64_t upper_bound) noexcept nogil: """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" if dest[0] < lower_bound: dest[0] = lower_bound elif dest[0] > upper_bound: dest[0] = upper_bound - cdef double middle_value(self) noexcept nogil: + cdef float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. @@ -1062,13 +1062,13 @@ cdef class RegressionCriterion(Criterion): cdef bint check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) noexcept nogil: """Check monotonicity constraint is satisfied at the current regression split""" cdef: - double value_left = self.sum_left[0] / self.weighted_n_left - double value_right = self.sum_right[0] / self.weighted_n_right + float64_t value_left = self.sum_left[0] / self.weighted_n_left + float64_t value_right = self.sum_right[0] / self.weighted_n_right return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) @@ -1078,14 +1078,14 @@ cdef class MSE(RegressionCriterion): MSE = var_left + var_right """ - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. Evaluate the MSE criterion as impurity of the current node, i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. """ - cdef double impurity + cdef float64_t impurity cdef intp_t k impurity = self.sq_sum_total / self.weighted_n_node_samples @@ -1094,7 +1094,7 @@ cdef class MSE(RegressionCriterion): return impurity / self.n_outputs - cdef double proxy_impurity_improvement(self) noexcept nogil: + cdef float64_t proxy_impurity_improvement(self) noexcept nogil: """Compute a proxy of the impurity reduction. This method is used to speed up the search for the best split. 
@@ -1115,8 +1115,8 @@ cdef class MSE(RegressionCriterion): - 1/n_L * sum_{i left}(y_i)^2 - 1/n_R * sum_{i right}(y_i)^2 """ cdef intp_t k - cdef double proxy_impurity_left = 0.0 - cdef double proxy_impurity_right = 0.0 + cdef float64_t proxy_impurity_left = 0.0 + cdef float64_t proxy_impurity_right = 0.0 for k in range(self.n_outputs): proxy_impurity_left += self.sum_left[k] * self.sum_left[k] @@ -1125,8 +1125,8 @@ cdef class MSE(RegressionCriterion): return (proxy_impurity_left / self.weighted_n_left + proxy_impurity_right / self.weighted_n_right) - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: """Evaluate the impurity in children nodes. i.e. the impurity of the left child (sample_indices[start:pos]) and the @@ -1139,8 +1139,8 @@ cdef class MSE(RegressionCriterion): cdef float64_t y_ik - cdef double sq_sum_left = 0.0 - cdef double sq_sum_right + cdef float64_t sq_sum_left = 0.0 + cdef float64_t sq_sum_right cdef intp_t i cdef intp_t p @@ -1221,7 +1221,7 @@ cdef class MAE(RegressionCriterion): self, const float64_t[:, ::1] y, const float64_t[:] sample_weight, - double weighted_n_samples, + float64_t weighted_n_samples, const intp_t[:] sample_indices, intp_t start, intp_t end, @@ -1395,13 +1395,13 @@ cdef class MAE(RegressionCriterion): self.pos = new_pos return 0 - cdef void node_value(self, double* dest) noexcept nogil: + cdef void node_value(self, float64_t* dest) noexcept nogil: """Computes the node value of sample_indices[start:end] into dest.""" cdef intp_t k for k in range(self.n_outputs): - dest[k] = <double> self.node_medians[k] + dest[k] = <float64_t> self.node_medians[k] - cdef inline double middle_value(self) noexcept nogil: + cdef inline float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average of the left and right children values. 
@@ -1416,17 +1416,17 @@ cdef class MAE(RegressionCriterion): cdef inline bint check_monotonicity( self, cnp.int8_t monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) noexcept nogil: """Check monotonicity constraint is satisfied at the current regression split""" cdef: - double value_left = (<WeightedMedianCalculator> self.left_child_ptr[0]).get_median() - double value_right = (<WeightedMedianCalculator> self.right_child_ptr[0]).get_median() + float64_t value_left = (<WeightedMedianCalculator> self.left_child_ptr[0]).get_median() + float64_t value_right = (<WeightedMedianCalculator> self.right_child_ptr[0]).get_median() return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. Evaluate the MAE criterion as impurity of the current node, @@ -1450,8 +1450,8 @@ cdef class MAE(RegressionCriterion): return impurity / (self.weighted_n_node_samples * self.n_outputs) - cdef void children_impurity(self, double* p_impurity_left, - double* p_impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* p_impurity_left, + float64_t* p_impurity_right) noexcept nogil: """Evaluate the impurity in children nodes. i.e. the impurity of the left child (sample_indices[start:pos]) and the @@ -1507,7 +1507,7 @@ cdef class FriedmanMSE(MSE): improvement = n_left * n_right * diff^2 / (n_left + n_right) """ - cdef double proxy_impurity_improvement(self) noexcept nogil: + cdef float64_t proxy_impurity_improvement(self) noexcept nogil: """Compute a proxy of the impurity reduction. This method is used to speed up the search for the best split. @@ -1518,11 +1518,11 @@ cdef class FriedmanMSE(MSE): The absolute impurity improvement is only computed by the impurity_improvement method once the best split has been found. 
""" - cdef double total_sum_left = 0.0 - cdef double total_sum_right = 0.0 + cdef float64_t total_sum_left = 0.0 + cdef float64_t total_sum_right = 0.0 cdef intp_t k - cdef double diff = 0.0 + cdef float64_t diff = 0.0 for k in range(self.n_outputs): total_sum_left += self.sum_left[k] @@ -1533,14 +1533,14 @@ cdef class FriedmanMSE(MSE): return diff * diff / (self.weighted_n_left * self.weighted_n_right) - cdef double impurity_improvement(self, double impurity_parent, double - impurity_left, double impurity_right) noexcept nogil: + cdef float64_t impurity_improvement(self, float64_t impurity_parent, float64_t + impurity_left, float64_t impurity_right) noexcept nogil: # Note: none of the arguments are used here - cdef double total_sum_left = 0.0 - cdef double total_sum_right = 0.0 + cdef float64_t total_sum_left = 0.0 + cdef float64_t total_sum_right = 0.0 cdef intp_t k - cdef double diff = 0.0 + cdef float64_t diff = 0.0 for k in range(self.n_outputs): total_sum_left += self.sum_left[k] @@ -1574,7 +1574,7 @@ cdef class Poisson(RegressionCriterion): # children_impurity would only need to go over left xor right split, not # both. This could be faster. - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. Evaluate the Poisson criterion as impurity of the current node, @@ -1584,7 +1584,7 @@ cdef class Poisson(RegressionCriterion): return self.poisson_loss(self.start, self.end, self.sum_total, self.weighted_n_node_samples) - cdef double proxy_impurity_improvement(self) noexcept nogil: + cdef float64_t proxy_impurity_improvement(self) noexcept nogil: """Compute a proxy of the impurity reduction. This method is used to speed up the search for the best split. 
@@ -1608,10 +1608,10 @@ cdef class Poisson(RegressionCriterion): - sum{i right}(y_i) * log(mean{i right}(y_i)) """ cdef intp_t k - cdef double proxy_impurity_left = 0.0 - cdef double proxy_impurity_right = 0.0 - cdef double y_mean_left = 0. - cdef double y_mean_right = 0. + cdef float64_t proxy_impurity_left = 0.0 + cdef float64_t proxy_impurity_right = 0.0 + cdef float64_t y_mean_left = 0. + cdef float64_t y_mean_right = 0. for k in range(self.n_outputs): if (self.sum_left[k] <= EPSILON) or (self.sum_right[k] <= EPSILON): @@ -1630,8 +1630,8 @@ cdef class Poisson(RegressionCriterion): return - proxy_impurity_left - proxy_impurity_right - cdef void children_impurity(self, double* impurity_left, - double* impurity_right) noexcept nogil: + cdef void children_impurity(self, float64_t* impurity_left, + float64_t* impurity_right) noexcept nogil: """Evaluate the impurity in children nodes. i.e. the impurity of the left child (sample_indices[start:pos]) and the @@ -1651,7 +1651,7 @@ cdef class Poisson(RegressionCriterion): self, intp_t start, intp_t end, - const double[::1] y_sum, + const float64_t[::1] y_sum, float64_t weight_sum ) noexcept nogil: """Helper function to compute Poisson loss (~deviance) of a given node. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index a096014804847..adc14011cb7a2 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -21,12 +21,12 @@ cdef struct SplitRecord: intp_t pos # Split samples array at the given position, # # i.e. count of samples below threshold for feature. # # pos is >= end if the node is a leaf. - double threshold # Threshold to split at. - double improvement # Impurity improvement given parent node. - double impurity_left # Impurity of the left split. - double impurity_right # Impurity of the right split. 
- double lower_bound # Lower bound on value of both children for monotonicity - double upper_bound # Upper bound on value of both children for monotonicity + float64_t threshold # Threshold to split at. + float64_t improvement # Impurity improvement given parent node. + float64_t impurity_left # Impurity of the left split. + float64_t impurity_right # Impurity of the right split. + float64_t lower_bound # Lower bound on value of both children for monotonicity + float64_t upper_bound # Upper bound on value of both children for monotonicity unsigned char missing_go_to_left # Controls if missing values go to the left node. intp_t n_missing # Number of missing values for the feature being split on @@ -40,14 +40,14 @@ cdef class Splitter: cdef public Criterion criterion # Impurity criterion cdef public intp_t max_features # Number of features to test cdef public intp_t min_samples_leaf # Min samples in a leaf - cdef public double min_weight_leaf # Minimum weight in a leaf + cdef public float64_t min_weight_leaf # Minimum weight in a leaf cdef object random_state # Random state cdef uint32_t rand_r_state # sklearn_rand_r random number state cdef intp_t[::1] samples # Sample indices in X, y cdef intp_t n_samples # X.shape[0] - cdef double weighted_n_samples # Weighted number of samples + cdef float64_t weighted_n_samples # Weighted number of samples cdef intp_t[::1] features # Feature indices in X cdef intp_t[::1] constant_features # Constant features indices cdef intp_t n_features # X.shape[1] @@ -95,20 +95,20 @@ cdef class Splitter: self, intp_t start, intp_t end, - double* weighted_n_node_samples + float64_t* weighted_n_node_samples ) except -1 nogil cdef int node_split( self, - double impurity, # Impurity of the node + float64_t impurity, # Impurity of the node SplitRecord* split, intp_t* n_constant_features, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) except -1 nogil - cdef void node_value(self, double* dest) 
noexcept nogil + cdef void node_value(self, float64_t* dest) noexcept nogil - cdef void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil + cdef void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil - cdef double node_impurity(self) noexcept nogil + cdef float64_t node_impurity(self) noexcept nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index f11c7feb541f8..a9d3a169ec84a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -31,7 +31,7 @@ from ._utils cimport RAND_R_MAX cnp.import_array() -cdef double INFINITY = np.inf +cdef float64_t INFINITY = np.inf # Mitigate precision differences between 32 bit and 64 bit cdef float32_t FEATURE_THRESHOLD = 1e-7 @@ -62,7 +62,7 @@ cdef class Splitter: Criterion criterion, intp_t max_features, intp_t min_samples_leaf, - double min_weight_leaf, + float64_t min_weight_leaf, object random_state, const cnp.int8_t[:] monotonic_cst, ): @@ -81,7 +81,7 @@ cdef class Splitter: which would result in having less samples in a leaf are not considered. - min_weight_leaf : double + min_weight_leaf : float64_t The minimal weight each leaf can have, where the weight is the sum of the weights of each sample in it. @@ -161,7 +161,7 @@ cdef class Splitter: cdef intp_t[::1] samples = self.samples cdef intp_t i, j - cdef double weighted_n_samples = 0.0 + cdef float64_t weighted_n_samples = 0.0 j = 0 for i in range(n_samples): @@ -194,7 +194,7 @@ cdef class Splitter: return 0 cdef int node_reset(self, intp_t start, intp_t end, - double* weighted_n_node_samples) except -1 nogil: + float64_t* weighted_n_node_samples) except -1 nogil: """Reset splitter on node samples[start:end]. 
Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -206,7 +206,7 @@ cdef class Splitter: The index of the first sample to consider end : intp_t The index of the last sample to consider - weighted_n_node_samples : ndarray, dtype=double pointer + weighted_n_node_samples : ndarray, dtype=float64_t pointer The total weight of those samples """ @@ -227,11 +227,11 @@ cdef class Splitter: cdef int node_split( self, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) except -1 nogil: """Find the best split on node samples[start:end]. @@ -244,17 +244,17 @@ cdef class Splitter: pass - cdef void node_value(self, double* dest) noexcept nogil: + cdef void node_value(self, float64_t* dest) noexcept nogil: """Copy the value of node samples[start:end] into dest.""" self.criterion.node_value(dest) - cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" self.criterion.clip_node_value(dest, lower_bound, upper_bound) - cdef double node_impurity(self) noexcept nogil: + cdef float64_t node_impurity(self) noexcept nogil: """Return the impurity of the current node.""" return self.criterion.node_impurity() @@ -290,13 +290,13 @@ cdef inline int node_split_best( Splitter splitter, Partitioner partitioner, Criterion criterion, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, bint with_monotonic_cst, const cnp.int8_t[:] monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) except -1 nogil: """Find the best split on node samples[start:end] @@ -321,12 +321,12 @@ cdef inline int 
node_split_best( cdef float32_t[::1] feature_values = splitter.feature_values cdef intp_t max_features = splitter.max_features cdef intp_t min_samples_leaf = splitter.min_samples_leaf - cdef double min_weight_leaf = splitter.min_weight_leaf + cdef float64_t min_weight_leaf = splitter.min_weight_leaf cdef uint32_t* random_state = &splitter.rand_r_state cdef SplitRecord best_split, current_split - cdef double current_proxy_improvement = -INFINITY - cdef double best_proxy_improvement = -INFINITY + cdef float64_t current_proxy_improvement = -INFINITY + cdef float64_t best_proxy_improvement = -INFINITY cdef intp_t f_i = n_features cdef intp_t f_j @@ -671,13 +671,13 @@ cdef inline int node_split_random( Splitter splitter, Partitioner partitioner, Criterion criterion, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, bint with_monotonic_cst, const cnp.int8_t[:] monotonic_cst, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, ) except -1 nogil: """Find the best random split on node samples[start:end] @@ -694,12 +694,12 @@ cdef inline int node_split_random( cdef intp_t max_features = splitter.max_features cdef intp_t min_samples_leaf = splitter.min_samples_leaf - cdef double min_weight_leaf = splitter.min_weight_leaf + cdef float64_t min_weight_leaf = splitter.min_weight_leaf cdef uint32_t* random_state = &splitter.rand_r_state cdef SplitRecord best_split, current_split - cdef double current_proxy_improvement = - INFINITY - cdef double best_proxy_improvement = - INFINITY + cdef float64_t current_proxy_improvement = - INFINITY + cdef float64_t best_proxy_improvement = - INFINITY cdef intp_t f_i = n_features cdef intp_t f_j @@ -989,7 +989,7 @@ cdef class DensePartitioner: # (feature_values[p] >= end) or (feature_values[p] > feature_values[p - 1]) p[0] += 1 - cdef inline intp_t partition_samples(self, double current_threshold) noexcept nogil: + cdef inline intp_t partition_samples(self, 
float64_t current_threshold) noexcept nogil: """Partition samples for feature_values at the current_threshold.""" cdef: intp_t p = self.start @@ -1013,7 +1013,7 @@ cdef class DensePartitioner: cdef inline void partition_samples_final( self, intp_t best_pos, - double best_threshold, + float64_t best_threshold, intp_t best_feature, intp_t best_n_missing, ) noexcept nogil: @@ -1236,14 +1236,14 @@ cdef class SparsePartitioner: p_prev[0] = p[0] p[0] = p_next - cdef inline intp_t partition_samples(self, double current_threshold) noexcept nogil: + cdef inline intp_t partition_samples(self, float64_t current_threshold) noexcept nogil: """Partition samples for feature_values at the current_threshold.""" return self._partition(current_threshold, self.start_positive) cdef inline void partition_samples_final( self, intp_t best_pos, - double best_threshold, + float64_t best_threshold, intp_t best_feature, intp_t n_missing, ) noexcept nogil: @@ -1251,7 +1251,7 @@ cdef class SparsePartitioner: self.extract_nnz(best_feature) self._partition(best_threshold, best_pos) - cdef inline intp_t _partition(self, double threshold, intp_t zero_pos) noexcept nogil: + cdef inline intp_t _partition(self, float64_t threshold, intp_t zero_pos) noexcept nogil: """Partition samples[start:end] based on threshold.""" cdef: intp_t p, partition_end @@ -1504,11 +1504,11 @@ cdef class BestSplitter(Splitter): cdef int node_split( self, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, - double lower_bound, - double upper_bound + float64_t lower_bound, + float64_t upper_bound ) except -1 nogil: return node_split_best( self, @@ -1540,11 +1540,11 @@ cdef class BestSparseSplitter(Splitter): cdef int node_split( self, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, - double lower_bound, - double upper_bound + float64_t lower_bound, + float64_t upper_bound ) except -1 nogil: return node_split_best( self, @@ -1576,11 +1576,11 @@ 
cdef class RandomSplitter(Splitter): cdef int node_split( self, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, - double lower_bound, - double upper_bound + float64_t lower_bound, + float64_t upper_bound ) except -1 nogil: return node_split_random( self, @@ -1611,11 +1611,11 @@ cdef class RandomSparseSplitter(Splitter): ) cdef int node_split( self, - double impurity, + float64_t impurity, SplitRecord* split, intp_t* n_constant_features, - double lower_bound, - double upper_bound + float64_t lower_bound, + float64_t upper_bound ) except -1 nogil: return node_split_random( self, diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 97634748c3f42..e4081921f40f9 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -48,14 +48,14 @@ cdef class Tree: cdef public intp_t node_count # Counter for node IDs cdef public intp_t capacity # Capacity of tree, in terms of nodes cdef Node* nodes # Array of nodes - cdef double* value # (capacity, n_outputs, max_n_classes) array of values + cdef float64_t* value # (capacity, n_outputs, max_n_classes) array of values cdef intp_t value_stride # = n_outputs * max_n_classes # Methods cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, - intp_t feature, double threshold, double impurity, + intp_t feature, float64_t threshold, float64_t impurity, intp_t n_node_samples, - double weighted_n_node_samples, + float64_t weighted_n_node_samples, unsigned char missing_go_to_left) except -1 nogil cdef int _resize(self, intp_t capacity) except -1 nogil cdef int _resize_c(self, intp_t capacity=*) except -1 nogil @@ -93,9 +93,9 @@ cdef class TreeBuilder: cdef intp_t min_samples_split # Minimum number of samples in an internal node cdef intp_t min_samples_leaf # Minimum number of samples in a leaf - cdef double min_weight_leaf # Minimum weight in a leaf + cdef float64_t min_weight_leaf # Minimum weight in a leaf cdef intp_t max_depth # Maximal tree depth - cdef 
double min_impurity_decrease # Impurity threshold for early stopping + cdef float64_t min_impurity_decrease # Impurity threshold for early stopping cpdef build( self, diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index eef59dfb3ec43..b4ce56a4d2a0b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -59,8 +59,8 @@ cdef extern from "" namespace "std" nogil: from numpy import float32 as DTYPE from numpy import float64 as DOUBLE -cdef double INFINITY = np.inf -cdef double EPSILON = np.finfo('double').eps +cdef float64_t INFINITY = np.inf +cdef float64_t EPSILON = np.finfo('double').eps # Some handy constants (BestFirstTreeBuilder) cdef int IS_FIRST = 1 @@ -145,17 +145,17 @@ cdef struct StackRecord: intp_t depth intp_t parent bint is_left - double impurity + float64_t impurity intp_t n_constant_features - double lower_bound - double upper_bound + float64_t lower_bound + float64_t upper_bound cdef class DepthFirstTreeBuilder(TreeBuilder): """Build a decision tree in depth-first fashion.""" def __cinit__(self, Splitter splitter, intp_t min_samples_split, - intp_t min_samples_leaf, double min_weight_leaf, - intp_t max_depth, double min_impurity_decrease): + intp_t min_samples_leaf, float64_t min_weight_leaf, + intp_t max_depth, float64_t min_impurity_decrease): self.splitter = splitter self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf @@ -190,9 +190,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef Splitter splitter = self.splitter cdef intp_t max_depth = self.max_depth cdef intp_t min_samples_leaf = self.min_samples_leaf - cdef double min_weight_leaf = self.min_weight_leaf + cdef float64_t min_weight_leaf = self.min_weight_leaf cdef intp_t min_samples_split = self.min_samples_split - cdef double min_impurity_decrease = self.min_impurity_decrease + cdef float64_t min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight, 
missing_values_in_feature_mask) @@ -203,14 +203,18 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef intp_t parent cdef bint is_left cdef intp_t n_node_samples = splitter.n_samples - cdef double weighted_n_node_samples + cdef float64_t weighted_n_node_samples cdef SplitRecord split cdef intp_t node_id - cdef double impurity = INFINITY - cdef double lower_bound - cdef double upper_bound - cdef double middle_value + cdef float64_t impurity = INFINITY + cdef float64_t lower_bound + cdef float64_t upper_bound + cdef float64_t middle_value + cdef float64_t left_child_min + cdef float64_t left_child_max + cdef float64_t right_child_min + cdef float64_t right_child_max cdef intp_t n_constant_features cdef bint is_leaf cdef bint first = 1 @@ -375,13 +379,13 @@ cdef struct FrontierRecord: intp_t pos intp_t depth bint is_leaf - double impurity - double impurity_left - double impurity_right - double improvement - double lower_bound - double upper_bound - double middle_value + float64_t impurity + float64_t impurity_left + float64_t impurity_right + float64_t improvement + float64_t lower_bound + float64_t upper_bound + float64_t middle_value cdef inline bool _compare_records( const FrontierRecord& left, @@ -409,7 +413,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): def __cinit__(self, Splitter splitter, intp_t min_samples_split, intp_t min_samples_leaf, min_weight_leaf, intp_t max_depth, intp_t max_leaf_nodes, - double min_impurity_decrease): + float64_t min_impurity_decrease): self.splitter = splitter self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf @@ -442,10 +446,10 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef FrontierRecord record cdef FrontierRecord split_node_left cdef FrontierRecord split_node_right - cdef double left_child_min - cdef double left_child_max - cdef double right_child_min - cdef double right_child_max + cdef float64_t left_child_min + cdef float64_t left_child_max + cdef float64_t right_child_min + cdef 
float64_t right_child_max cdef intp_t n_node_samples = splitter.n_samples cdef intp_t max_split_nodes = max_leaf_nodes - 1 @@ -589,13 +593,13 @@ cdef class BestFirstTreeBuilder(TreeBuilder): Tree tree, intp_t start, intp_t end, - double impurity, + float64_t impurity, bint is_first, bint is_left, Node* parent, intp_t depth, - double lower_bound, - double upper_bound, + float64_t lower_bound, + float64_t upper_bound, FrontierRecord* res ) except -1 nogil: """Adds node w/ partition ``[start, end)`` to the frontier. """ @@ -603,8 +607,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef intp_t node_id cdef intp_t n_node_samples cdef intp_t n_constant_features = 0 - cdef double min_impurity_decrease = self.min_impurity_decrease - cdef double weighted_n_node_samples + cdef float64_t min_impurity_decrease = self.min_impurity_decrease + cdef float64_t weighted_n_node_samples cdef bint is_leaf splitter.node_reset(start, end, &weighted_n_node_samples) @@ -720,20 +724,20 @@ cdef class Tree: feature : array of int, shape [node_count] feature[i] holds the feature to split on, for the internal node i. - threshold : array of double, shape [node_count] + threshold : array of float64_t, shape [node_count] threshold[i] holds the threshold for the internal node i. - value : array of double, shape [node_count, n_outputs, max_n_classes] + value : array of float64_t, shape [node_count, n_outputs, max_n_classes] Contains the constant prediction value of each node. - impurity : array of double, shape [node_count] + impurity : array of float64_t, shape [node_count] impurity[i] holds the impurity (i.e., the value of the splitting criterion) at node i. n_node_samples : array of int, shape [node_count] n_node_samples[i] holds the number of training samples reaching node i. 
- weighted_n_node_samples : array of double, shape [node_count] + weighted_n_node_samples : array of float64_t, shape [node_count] weighted_n_node_samples[i] holds the weighted number of training samples reaching node i. @@ -872,7 +876,7 @@ cdef class Tree: memcpy(self.nodes, cnp.PyArray_DATA(node_ndarray), self.capacity * sizeof(Node)) memcpy(self.value, cnp.PyArray_DATA(value_ndarray), - self.capacity * self.value_stride * sizeof(double)) + self.capacity * self.value_stride * sizeof(float64_t)) cdef int _resize(self, intp_t capacity) except -1 nogil: """Resize all inner arrays to `capacity`, if `capacity` == -1, then @@ -908,7 +912,7 @@ cdef class Tree: if capacity > self.capacity: memset((self.value + self.capacity * self.value_stride), 0, (capacity - self.capacity) * self.value_stride * - sizeof(double)) + sizeof(float64_t)) # if capacity smaller than node_count, adjust the counter if capacity < self.node_count: @@ -918,9 +922,9 @@ cdef class Tree: return 0 cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, - intp_t feature, double threshold, double impurity, + intp_t feature, float64_t threshold, float64_t impurity, intp_t n_node_samples, - double weighted_n_node_samples, + float64_t weighted_n_node_samples, unsigned char missing_go_to_left) except -1 nogil: """Add a node to the tree. @@ -1267,7 +1271,7 @@ cdef class Tree: cdef Node* node = nodes cdef Node* end_node = node + self.node_count - cdef double normalizer = 0. + cdef float64_t normalizer = 0. cdef cnp.float64_t[:] importances = np.zeros(self.n_features) @@ -1338,7 +1342,7 @@ cdef class Tree: def compute_partial_dependence(self, float32_t[:, ::1] X, int[::1] target_features, - double[::1] out): + float64_t[::1] out): """Partial dependence of the response on the ``target_feature`` set. For each sample in ``X`` a tree traversal is performed. @@ -1367,16 +1371,16 @@ cdef class Tree: point. 
""" cdef: - double[::1] weight_stack = np.zeros(self.node_count, - dtype=np.float64) + float64_t[::1] weight_stack = np.zeros(self.node_count, + dtype=np.float64) intp_t[::1] node_idx_stack = np.zeros(self.node_count, dtype=np.intp) intp_t sample_idx intp_t feature_idx int stack_size - double left_sample_frac - double current_weight - double total_weight # used for sanity check only + float64_t left_sample_frac + float64_t current_weight + float64_t total_weight # used for sanity check only Node *current_node # use a pointer to avoid copying attributes intp_t current_node_idx bint is_target_feature @@ -1814,7 +1818,7 @@ def _build_pruned_tree_ccp( Location to place the pruned tree orig_tree : Tree Original tree - ccp_alpha : positive double + ccp_alpha : positive float64_t Complexity parameter. The subtree with the largest cost complexity that is smaller than ``ccp_alpha`` will be chosen. By default, no pruning is performed. @@ -1921,8 +1925,8 @@ cdef _build_pruned_tree( intp_t max_depth_seen = -1 int rc = 0 Node* node - double* orig_value_ptr - double* new_value_ptr + float64_t* orig_value_ptr + float64_t* new_value_ptr stack[BuildPrunedRecord] prune_stack BuildPrunedRecord stack_record @@ -1955,7 +1959,7 @@ cdef _build_pruned_tree( # copy value from original tree to new tree orig_value_ptr = orig_tree.value + value_stride * orig_node_id new_value_ptr = tree.value + value_stride * new_node_id - memcpy(new_value_ptr, orig_value_ptr, sizeof(double) * value_stride) + memcpy(new_value_ptr, orig_value_ptr, sizeof(float64_t) * value_stride) if not is_leaf: # Push right child on stack diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index cd7a77cc1bbc9..4167230bfbf4d 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -49,11 +49,11 @@ cdef intp_t rand_int(intp_t low, intp_t high, uint32_t* random_state) noexcept nogil -cdef double rand_uniform(double low, double high, - uint32_t* random_state) noexcept nogil +cdef float64_t 
rand_uniform(float64_t low, float64_t high, + uint32_t* random_state) noexcept nogil -cdef double log(double x) noexcept nogil +cdef float64_t log(float64_t x) noexcept nogil # ============================================================================= # WeightedPQueue data structure diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 98a8249928b6f..3c0c312b25fbe 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -63,14 +63,14 @@ cdef inline intp_t rand_int(intp_t low, intp_t high, return low + our_rand_r(random_state) % (high - low) -cdef inline double rand_uniform(double low, double high, - uint32_t* random_state) noexcept nogil: - """Generate a random double in [low; high).""" - return ((high - low) * our_rand_r(random_state) / - RAND_R_MAX) + low +cdef inline float64_t rand_uniform(float64_t low, float64_t high, + uint32_t* random_state) noexcept nogil: + """Generate a random float64_t in [low; high).""" + return ((high - low) * our_rand_r(random_state) / + RAND_R_MAX) + low -cdef inline double log(double x) noexcept nogil: +cdef inline float64_t log(float64_t x) noexcept nogil: return ln(x) / ln(2.0) # ============================================================================= @@ -372,7 +372,7 @@ cdef class WeightedMedianCalculator: left and moving to the right. """ cdef int return_value - cdef double original_median = 0.0 + cdef float64_t original_median = 0.0 if self.size() != 0: original_median = self.get_median() @@ -389,7 +389,7 @@ cdef class WeightedMedianCalculator: cdef int update_median_parameters_post_remove( self, float64_t data, float64_t weight, - double original_median) noexcept nogil: + float64_t original_median) noexcept nogil: """Update the parameters used in the median calculation, namely `k` and `sum_w_0_k` after a removal""" # reset parameters because it there are no elements