اذهب إلى المحتوى
  • 0

كيف أستطيع ان احصل على اوزان الطبقات في keras.

Moatasm Elshahry

السؤال

أنا احاول ان احصل على اوزان الطبقات في keras، قمت بكتابة الكود

import tensorflow as tf
import tensorflow.contrib.keras.api.keras.backend as K
from tensorflow.contrib.keras.api.keras.layers import Dense

tf.reset_default_graph()
init = tf.global_variables_initializer()
sess =  tf.Session()
K.set_session(sess) 

input_x = tf.placeholder(tf.float32, [None, 10], name='input_x')    
dense1 = Dense(10, activation='relu')(input_x)

sess.run(init)

dense1.get_weights()

لمحاولة عمل هذا لكن تظهر لي تلك المشكلة:

AttributeError: 'Tensor' object has no attribute 'weights'

ما الحل لتلك المشكلة؟

رابط هذا التعليق
شارك على الشبكات الإجتماعية

Recommended Posts

  • 0

لو قمت بكتابة

dense1 = Dense(10, activation='relu')(input_x)

فان dense1 ليست طبقة وانما هي الخرج، اما اذا اردت تعريف طبقة معينة تكتبها هكذا:

Dense(10, activation='relu')

لذا هذا ما يبدو انك تحاول تحقيقه:

dense1 = Dense(10, activation='relu')
y = dense1(input_x)

ويكون الكود الصحيح كالتالي:

import tensorflow as tf
from tensorflow.contrib.keras import layers

input_x = tf.placeholder(tf.float32, [None, 10], name='input_x')    
dense1 = layers.Dense(10, activation='relu')
y = dense1(input_x)

weights = dense1.get_weights()

 

رابط هذا التعليق
شارك على الشبكات الإجتماعية

  • 0

اولا, أنت تستعمل حاليا نسخة قديمة من tensorflow, الافضل ان تقوم بالانتقال الى النسخة 2.

ثانيا, انت تستعمل ال functional api, و dense1 تعتبر مخرجات في هذه الحالة و ليست طبقة, لكي تعمل معك الدالة get_weights() يجب ان يكون الاستدعاء على طبقة او Model , لان ال Model ايضا يتعبر طبقة.

هذا مثال يوضح طريقة استعمال الدالة ()get_weights مع Model كامل للحصول على الاوزان.

 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential(
    [
        layers.Dense(2, activation="relu", name="layer1"),
        layers.Dense(3, activation="relu", name="layer2"),
        layers.Dense(4, name="layer3"),
    ]
)

x = tf.ones((3, 3))
y = model(x)

model.get_weights()

النتيجة


# [array([[ 0.49743366, -0.5760138 ],
#         [-0.01432669,  0.6724075 ],
#         [ 0.747548  ,  0.8514656 ]], dtype=float32),
#  array([0., 0.], dtype=float32),
#  array([[-0.20251024, -0.29482132,  1.074827  ],
#         [-0.5137537 ,  0.1377641 ,  0.44807172]], dtype=float32),
#  array([0., 0., 0.], dtype=float32),
#  array([[-0.00770271, -0.14924443,  0.38956785,  0.85535455],
#         [ 0.13032079,  0.05204147,  0.14340723,  0.81273234],
#         [-0.24441314,  0.49561667, -0.10915673,  0.12434208]],
#        dtype=float32),
#  array([0., 0., 0., 0.], dtype=float32)]

 

تم التعديل في بواسطة Walid K
رابط هذا التعليق
شارك على الشبكات الإجتماعية

انضم إلى النقاش

يمكنك أن تنشر الآن وتسجل لاحقًا. إذا كان لديك حساب، فسجل الدخول الآن لتنشر باسم حسابك.

زائر
أجب على هذا السؤال...

×   لقد أضفت محتوى بخط أو تنسيق مختلف.   Restore formatting

  Only 75 emoji are allowed.

×   Your link has been automatically embedded.   Display as a link instead

×   جرى استعادة المحتوى السابق..   امسح المحرر

×   You cannot paste images directly. Upload or insert images from URL.

  • إعلانات

  • تابعنا على



×
×
  • أضف...