اذهب إلى المحتوى
  • 0

ظهور الخطأ التالي: ValueError: n_splits=3 cannot be greater than the number of members in each class. أثناء محاولة تقسيم البيانات باستخدام StratifiedKFold

Meezo ML

السؤال

أحاول تقسيم البيانات باستخدام StratifiedKFold لكن يظهر لي الخطأ التالي :

import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1,4],[2,1],[3,4],[7,8],[2,8]])
y = np.array([2,1,3,4,4])
skf = StratifiedKFold(n_splits=3)
print(skf.get_n_splits(X, y))

for train_index, test_index in skf.split(X, y):
    print("TRAIN:"+str(train_index)+'\n'+"TEST:"+str(test_index),end='\n\n')
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    print('X_train:\n '+str(X_train),end='\n\n')
    print('X_test:\n '+str(X_test),end='\n\n')
    print('y_train:\n '+str(y_train),end='\n\n')
    print('y_test:\n' +str(y_test),end='\n\n')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-56-6c55afa3238f> in <module>
      6 print(skf.get_n_splits(X, y))
      7 
----> 8 for train_index, test_index in skf.split(X, y):
      9     # للتقسيمة index  عرض ال
     10     print("TRAIN:"+str(train_index)+'\n'+"TEST:"+str(test_index),end='\n\n')

~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
    333                 .format(self.n_splits, n_samples))
    334 
--> 335         for train, test in super().split(X, y, groups):
    336             yield train, test
    337 

~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
     78         X, y, groups = indexable(X, y, groups)
     79         indices = np.arange(_num_samples(X))
---> 80         for test_index in self._iter_test_masks(X, y, groups):
     81             train_index = indices[np.logical_not(test_index)]
     82             test_index = indices[test_index]

~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in _iter_test_masks(self, X, y, groups)
    690 
    691     def _iter_test_masks(self, X, y=None, groups=None):
--> 692         test_folds = self._make_test_folds(X, y)
    693         for i in range(self.n_splits):
    694             yield test_folds == i

~\anaconda3\lib\site-packages\sklearn\model_selection\_split.py in _make_test_folds(self, X, y)
    661             raise ValueError("n_splits=%d cannot be greater than the"
    662                              " number of members in each class."
--> 663                              % (self.n_splits))
    664         if self.n_splits > min_groups:
    665             warnings.warn(("The least populated class in y has only %d"

ValueError: n_splits=3 cannot be greater than the number of members in each class.

ماهو عدد الأعضاء؟ وكيف نحل المشكلة؟

رابط هذا التعليق
شارك على الشبكات الإجتماعية

Recommended Posts

  • 2

أنت لديك 5 عينات، موزعة كالتالي:
عينة للأصناف 2و1و3
وعينتين للصنف 4
ال "number of members in each class" يقصد بها عدد الأعضاء في كل صف أي عدد العينات من أجل كل صف.
يحاول StratifiedKFold  الحفاظ على نسبة معينة في كل fold من هؤلاء الأعضاء.
أنت حددت 3 تقسيمات "Fold" وبالتالي في كل تقسيمة من  أجل الكلاس 2 مثلاً، يجب على الأقل الحفاظ على نسبة0.33 =1/3 عضو. ومن الكلاس 1 و 3 أيضاً يجب أن تتحقق نفس النسبة (0.33 عضو).
أما من أجل الكلاس 4 فيجب أن تتحقق النسبة 2/3=0.67 عضو.
لكن 0.33 تعني جزء من العينة وهذا غير ممكن! يجب على الأقل أن يكون عدد الأعضاء في الكلاسات من 1 إلى 3 يساوي 3 ومن الكلاس 4 أيضاً 3 لكي يصبح 1 عضو في كل تقسيمة على الأقل.
ولهذا السبب ظهر الخطأ.
 إذاً يجب أن يكون لديك على الأقل 3 عينات من أجل كل كلاس.
لاحظ أيضاً أنك إذا وضعت 2 (أقصد تقسيمتين) سوف ينجح الأمر لكنه سيعطيك التحذير التالي:

UserWarning: The least populated class in y has only 1 members, which is less than n_splits=2.
  % (min_groups, self.n_splits)), UserWarning)

حيث أنه يعطيك خطأ إذا لم يجد أي كلاس يحقق الشرط (لأنه لن يكون لفكرة الخوارزمية معنى بعد ذلك)، بينما يعطيك تحذير إذا كان هناك كلاس واحد  على الأقل يحققه.
 

رابط هذا التعليق
شارك على الشبكات الإجتماعية

  • 1

في الخطأ المذكور أعلاه لديك عدد 4 أصناف بحساب عينة واحدة لكل صنف بإستثناء الصنف 4 الذي لديه عينتين، فبالتالي لا يمكن التقسيم لثلاثة أجزاء للقيام بعملية التحقق، على الأقل يجب أن يساوي أحد الأصناف القيمة الموضوعة ل n_splits. لاحظ أنه يمكن لبرنامج العمل في حال كان عدد الأقسام 2 لانها تحقق الشرط (الصنف 4 يحتوي على عينتين).

import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1,4],[2,1],[3,4],[7,8],[2,8]])
y = np.array([2,2,1,3,4])
skf = StratifiedKFold(n_splits=2)
print(skf.get_n_splits(X, y))

for train_index, test_index in skf.split(X, y):
    print("TRAIN:"+str(train_index)+'\n'+"TEST:"+str(test_index),end='\n\n')
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    print('X_train:\n '+str(X_train),end='\n\n')
    print('X_test:\n '+str(X_test),end='\n\n')
    print('y_train:\n '+str(y_train),end='\n\n')
    print('y_test:\n' +str(y_test),end='\n\n')

لكن بعد تنفيذ البرنامج يظهر إشعار التنبيه التالي:

/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_split.py:667: UserWarning: The least populated class in y has only 1 members, which is less than n_splits=2.
  % (min_groups, self.n_splits)), UserWarning)

والذي بلفت إنتباهك للأصناف التي تحتوي فقط على عنصر واحد مثل 1و2و3.

و الان لتفادي ذلك الخطأ يمكنك زيادة عدد العينات في البرنامج بمعدل 3 أو أكثر لكل صنف على حدة، راجع المثال أدناه:

import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([[1,4],[2,1],[3,4],[7,8],[2,8],[1,4],[2,1],[3,4],[7,8],[2,8],[1,4],[2,1],[3,4],[7,8],[2,8]])
y = np.array([2,2,1,3,4,2,2,1,3,4,2,2,1,3,4])
skf = StratifiedKFold(n_splits=3)
print(skf.get_n_splits(X, y))

for train_index, test_index in skf.split(X, y):
    print("TRAIN:"+str(train_index)+'\n'+"TEST:"+str(test_index),end='\n\n')
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    print('X_train:\n '+str(X_train),end='\n\n')
    print('X_test:\n '+str(X_test),end='\n\n')
    print('y_train:\n '+str(y_train),end='\n\n')
    print('y_test:\n' +str(y_test),end='\n\n')

البرنامج يحتوي على نفس البيانات الموجودة في المثال الذي أنتج الخطأ و لكن بزيادة عدد العينات بمعدل 3 عينات أقلاها لكل صنف.

رابط هذا التعليق
شارك على الشبكات الإجتماعية

انضم إلى النقاش

يمكنك أن تنشر الآن وتسجل لاحقًا. إذا كان لديك حساب، فسجل الدخول الآن لتنشر باسم حسابك.

زائر
أجب على هذا السؤال...

×   لقد أضفت محتوى بخط أو تنسيق مختلف.   Restore formatting

  Only 75 emoji are allowed.

×   Your link has been automatically embedded.   Display as a link instead

×   جرى استعادة المحتوى السابق..   امسح المحرر

×   You cannot paste images directly. Upload or insert images from URL.

  • إعلانات

  • تابعنا على



×
×
  • أضف...