ظهور الخطأ التالي: ValueError: n_splits=3 cannot be greater than the number of members in each class. أثناء محاولة تقسيم البيانات باستخدام StratifiedKFold
أحاول تقسيم البيانات باستخدام StratifiedKFold لكن يظهر لي الخطأ التالي :
import numpy as np
from sklearn.model_selection importStratifiedKFold
X = np.array([[1,4],[2,1],[3,4],[7,8],[2,8]])
y = np.array([2,1,3,4,4])
skf =StratifiedKFold(n_splits=3)print(skf.get_n_splits(X, y))for train_index, test_index in skf.split(X, y):print("TRAIN:"+str(train_index)+'\n'+"TEST:"+str(test_index),end='\n\n')
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]print('X_train:\n '+str(X_train),end='\n\n')print('X_test:\n '+str(X_test),end='\n\n')print('y_train:\n '+str(y_train),end='\n\n')print('y_test:\n'+str(y_test),end='\n\n')---------------------------------------------------------------------------ValueErrorTraceback(most recent call last)<ipython-input-56-6c55afa3238f>in<module>6print(skf.get_n_splits(X, y))7---->8for train_index, test_index in skf.split(X, y):9# للتقسيمة index عرض ال10print("TRAIN:"+str(train_index)+'\n'+"TEST:"+str(test_index),end='\n\n')~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)333.format(self.n_splits, n_samples))334-->335for train, test in super().split(X, y, groups):336yield train, test
337~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)78 X, y, groups = indexable(X, y, groups)79 indices = np.arange(_num_samples(X))--->80for test_index in self._iter_test_masks(X, y, groups):81 train_index = indices[np.logical_not(test_index)]82 test_index = indices[test_index]~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in _iter_test_masks(self, X, y, groups)690691def _iter_test_masks(self, X, y=None, groups=None):-->692 test_folds = self._make_test_folds(X, y)693for i in range(self.n_splits):694yield test_folds == i
~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in _make_test_folds(self, X, y)661raiseValueError("n_splits=%d cannot be greater than the"662" number of members in each class."-->663%(self.n_splits))664if self.n_splits > min_groups:665 warnings.warn(("The least populated class in y has only %d"ValueError: n_splits=3 cannot be greater than the number of members in each class.
السؤال
Meezo ML
أحاول تقسيم البيانات باستخدام StratifiedKFold لكن يظهر لي الخطأ التالي :
ماهو عدد الأعضاء؟ وكيف نحل المشكلة؟
2 أجوبة على هذا السؤال
Recommended Posts
انضم إلى النقاش
يمكنك أن تنشر الآن وتسجل لاحقًا. إذا كان لديك حساب، فسجل الدخول الآن لتنشر باسم حسابك.