اذهب إلى المحتوى

السؤال

نشر

أنا احاول ان احصل على اوزان الطبقات في keras، قمت بكتابة الكود

import tensorflow as tf
import tensorflow.contrib.keras.api.keras.backend as K
from tensorflow.contrib.keras.api.keras.layers import Dense

tf.reset_default_graph()
init = tf.global_variables_initializer()
sess =  tf.Session()
K.set_session(sess) 

input_x = tf.placeholder(tf.float32, [None, 10], name='input_x')    
dense1 = Dense(10, activation='relu')(input_x)

sess.run(init)

dense1.get_weights()

لمحاولة عمل هذا لكن تظهر لي تلك المشكلة:

AttributeError: 'Tensor' object has no attribute 'weights'

ما الحل لتلك المشكلة؟

Recommended Posts

  • 0
نشر

لو قمت بكتابة

dense1 = Dense(10, activation='relu')(input_x)

فان dense1 ليست طبقة وانما هي الخرج، اما اذا اردت تعريف طبقة معينة تكتبها هكذا:

Dense(10, activation='relu')

لذا هذا ما يبدو انك تحاول تحقيقه:

dense1 = Dense(10, activation='relu')
y = dense1(input_x)

ويكون الكود الصحيح كالتالي:

import tensorflow as tf
from tensorflow.contrib.keras import layers

input_x = tf.placeholder(tf.float32, [None, 10], name='input_x')    
dense1 = layers.Dense(10, activation='relu')
y = dense1(input_x)

weights = dense1.get_weights()

 

  • 0
نشر (معدل)

اولا, أنت تستعمل حاليا نسخة قديمة من tensorflow, الافضل ان تقوم بالانتقال الى النسخة 2.

ثانيا, انت تستعمل ال functional api, و dense1 تعتبر مخرجات في هذه الحالة و ليست طبقة, لكي تعمل معك الدالة get_weights() يجب ان يكون الاستدعاء على طبقة او Model , لان ال Model ايضا يتعبر طبقة.

هذا مثال يوضح طريقة استعمال الدالة ()get_weights مع Model كامل للحصول على الاوزان.

 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential(
    [
        layers.Dense(2, activation="relu", name="layer1"),
        layers.Dense(3, activation="relu", name="layer2"),
        layers.Dense(4, name="layer3"),
    ]
)

x = tf.ones((3, 3))
y = model(x)

model.get_weights()

النتيجة


# [array([[ 0.49743366, -0.5760138 ],
#         [-0.01432669,  0.6724075 ],
#         [ 0.747548  ,  0.8514656 ]], dtype=float32),
#  array([0., 0.], dtype=float32),
#  array([[-0.20251024, -0.29482132,  1.074827  ],
#         [-0.5137537 ,  0.1377641 ,  0.44807172]], dtype=float32),
#  array([0., 0., 0.], dtype=float32),
#  array([[-0.00770271, -0.14924443,  0.38956785,  0.85535455],
#         [ 0.13032079,  0.05204147,  0.14340723,  0.81273234],
#         [-0.24441314,  0.49561667, -0.10915673,  0.12434208]],
#        dtype=float32),
#  array([0., 0., 0., 0.], dtype=float32)]

 

تم التعديل في بواسطة Walid K

انضم إلى النقاش

يمكنك أن تنشر الآن وتسجل لاحقًا. إذا كان لديك حساب، فسجل الدخول الآن لتنشر باسم حسابك.

زائر
أجب على هذا السؤال...

×   لقد أضفت محتوى بخط أو تنسيق مختلف.   Restore formatting

  Only 75 emoji are allowed.

×   Your link has been automatically embedded.   Display as a link instead

×   جرى استعادة المحتوى السابق..   امسح المحرر

×   You cannot paste images directly. Upload or insert images from URL.

  • إعلانات

  • تابعنا على



×
×
  • أضف...