اذهب إلى المحتوى

السؤال

نشر

قمت ببناء شبكة عصبية تلاففية لتصنيف الصور باستخدام تنسرفلو، والآن انتهيت من تدريبها لكنني أريد الآن حفظ جميع المتغيرات والغراف للشبكة في ملف لاستخدامه في المستقبل، كيف يمكنني القيام بذلك من خلال مكتبة تنسرفلو؟
 

Recommended Posts

  • 1
نشر

حسناً سأبدأ بتوضيح فكرة بسيطة (يمكنك تجاوزها إذا لم تكن مهتماً بالتفاصيل). ملفات تنسرفلو تتكون من ملفين رئيسيين هما "Meta graph" و "Checkpoint file" بحيث أن الأول هو عبارة عن بروتوكول ال "buffer" الذي يقوم بحفظ كامل الغراف Graph (أي جميع المتغيرات والعمليات ..إلخ) وهذا الملف يكون بامتداد meta. أما الثاني فهو ملف ثنائي يحتوي على جميع قيم الأوزان وال bias والتدرجات "gradients" وجميع المتغيرات الأخرى المحفوظة. ويكون ملف واحد بامتداد ckpt في الأصدارات التي تسبق الإصدار 0.11 أما في الإصدارات الأحدث فيتم تخزين هذه المعلومات في ملفين:

mymodel.data-00000-of-00001 # هذا هو الملف الذي يحوي متغيرات التدريب
mymodel.index

وإلى جانب هذه الملفات يحتفظ تنسرفلو بملف آخر هو checkpoint يحتفظ بآخر نقاط  ال checkpoints التي تم حفظها. أي في حالة النسخ الأقدم يتم تخزين النموذج كالتالي:

inception_v1.meta
inception_v1.ckpt
checkpoint

أما الأحدث:

mymodel.data-00000-of-00001
mymodel.index
inception_v1.meta
checkpoint

الآن سنبدأ (يمكنك تخطي القسم السابق). أول خطوة ستقوم بها هي أخذ كائن من الكلاس Saver، ويجب أن تقوم بإنشائه داخل الجلسة التي قمت بتعريف نموذجك ومتغيراتك بها لأن المتغيرات والعمليات ووو تكون نشطة (أو موجودة أو حية .. أيَاً يكن التعبير) فقط ضمن الجسلة الخاصة بها:

saver = tf.train.Saver() 
saver.save(sess, 'mymodel') # الوسيط الأول هو اسم الجلسة والثاني اسم الملف الذي نريد حفظه فيه

مثال:

import tensorflow as tf
# على سبيل المثال 
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
save_ = tf.train.Saver()
mysess = tf.Session()
sess.run(tf.global_variables_initializer())
save_.save(mysess, 'mymodel')

إذا أردنا حفظ النموذج بعد 1000 تكرار، فسنقوم باستدعاء Save وتمرير عدد الخطوات أيضاً :

save_.save(mysess, 'mymodel',global_step=1000)

إذا أردت أن يتم حفظ النموذج أيضاً بعد كل 3 ساعات مثلاً، ولمرتين فقط يمكنك القيام بالتالي:

save_ = tf.train.Saver(max_to_keep=2, keep_checkpoint_every_n_hours=3)
#max_to_keep نحدد العدد الأعظمي لعمليات الحفظ
# keep_checkpoint_every_n_hours الساعات

أيضاً يجب أن تعلم أنه عندما لانقوم بتحديد أي وسيط للكلاس Saver  فهذا يعني أننا نريد حفظ كل المتغيرات، وبالتالي إذا أردنا تحديد مايتم حفظه يجب أن نمرره لباني الصف كقائمة أو قاموس:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# نمرر له مانريد حفظه فقط
save_ = tf.train.Saver([w1,w2])
mysess = tf.Session()
mysess.run(tf.global_variables_initializer())
save_.save(mysess, 'mymodel')

يمكنك إنشاء الشبكة عن طريق كتابة كود بيثون لإنشاء كل طبقة يدوياً كنموذج أصلي. ومع ذلك، إذا فكرت في الأمر، فقد حفظنا الشبكة في ملف meta. والذي يمكننا استخدامه لإعادة إنشاء الشبكة باستخدام:

save_ = tf.train.import_meta_graph('mymodel.meta')

تذكر أن import_meta_graph تلحق الشبكة المحددة في ملف meta بالغراف الحالي. لذلك، سيؤدي هذا إلى إنشاء الغراف / الشبكة لك ولكننا ما زلنا بحاجة إلى تحميل قيم الأوزان التي دربناها على هذا الغراف. ويمكننا استعادتها بالشكل التالي:

with tf.Session() as mysess:
  new_save_ = tf.train.import_meta_graph(.....meta')
  new_save_.restore(sess, tf.train.latest_checkpoint('./'))

تذكر أيضاً أنها لا يتم حفظ ال placeholders.

 

 

انضم إلى النقاش

يمكنك أن تنشر الآن وتسجل لاحقًا. إذا كان لديك حساب، فسجل الدخول الآن لتنشر باسم حسابك.

زائر
أجب على هذا السؤال...

×   لقد أضفت محتوى بخط أو تنسيق مختلف.   Restore formatting

  Only 75 emoji are allowed.

×   Your link has been automatically embedded.   Display as a link instead

×   جرى استعادة المحتوى السابق..   امسح المحرر

×   You cannot paste images directly. Upload or insert images from URL.

  • إعلانات

  • تابعنا على



×
×
  • أضف...