পুরো ডাটাসেটে যথাযথ ক্লাস (কোনটা কোন ডিজিট) ম্যাপ করা ৭০০০০ ইমেজ আছে যার মধ্যে ৫৫০০০ হচ্ছে ট্রেনিং ইমেজ, ১০০০০ হচ্ছে টেস্ট ইমেজ এবং ৫০০০ হচ্ছে ভ্যালিডেশন ইমেজ। অর্থাৎ পুরো ডাটাসেটটি ৩টি সাবসেটে বিভক্ত। কিছু ডাটা ট্রেনিং এর জন্য, কিছু ডাটা ভ্যালিডেশনের জন্য, আর কিছু ডাটা হচ্ছে ফাইনাল মডেলকে টেস্ট করার জন্য। এই সাবসেট গুলো মিউচুয়ালি এক্সকুসিভ অর্থাৎ একটি সেটের ডাটা আরেকটি সেটের মধ্যে নাই। অর্থাৎ কমন কোন এলিমেন্ট এই ৩টি সেটের মধ্যে নাই। পরীক্ষা করে দেখতে পারি নিচের কোড ওয়ালা সেলটি এক্সিকিউট করে,
# Cell 3
print("Size of:")
print("- Training-set:\t\t{}".format(len(data.train.labels)))
print("- Test-set:\t\t{}".format(len(data.test.labels)))
print("- Validation-set:\t{}".format(len(data.validation.labels)))
আউটপুট,
Size of:
- Training-set: 55000
- Test-set: 10000
- Validation-set: 5000
এই টিউটোরিয়ালে আমরা ভ্যালিডেশন সেটের ব্যবহার করবো না। যা হোক, Cell 2 এর কোডের read_data_sets মেথডের দ্বিতীয় প্যারামিটার নিয়ে একটু কথা বলি. one_hot=True পাঠিয়ে আমরা বলছি যে এই ডাটাসেট এর লেবেল (ফটোর সাপেক্ষে সঠিক উত্তর/ডিজিট) গুলোকে আমরা এই ফরম্যাটে চাই। এই ফরম্যাট ডেসিম্যাল ডিজিটের বাইনারি রিপ্রেজেন্টেশনের মতই কিন্তু একটু অন্যভাবে রিপ্রেজেন্ট করে। মাত্র একটি বিট কে হাই বা 1 করে সেই ডিজিটের অবস্থান প্রকাশ করা হয়। নিচের উদাহরণ দেখলেই ব্যাপারটি সহজেই বোঝা যাবে। যেমন 0 এবং 5 এর বাইনারি রিপ্রেজেন্টেশন হয় নিচের মত,
আর One-Hot Vector প্রেজেন্টেশন হয় নিচের মত,
অর্থাৎ ডিজিটটি যদি 5 হয় তাহলে ৫টি বিট ওয়ালা একটি ভেক্টরের ৫নাম্বার বিটটি হাই অর্থাৎ 1 সেট করে দেয়া হয়। তো, আমাদের আলোচনায় ডাউনলোড করা হাতের লেখার ফটো গুলোর লেবেল গুলো আসছে এই ফরম্যাটে। আমরা ডাটাসেট থেকে প্রথম ৫টি ফটোর লেবেল গুলোর One-Hot Vector রিপ্রেজেন্টেশন দেখতে পারি নিচের মত করে,
# Cell 4
data.test.labels[0:5, :]
আউটপুট আসবে, নিচের মত,
array([[ 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
তাহলে আমরা দেখে দেখেই বলে দিতে পারি প্রথম ৫টি ডিজিটের লেবেল বা নাম কি। প্রথমটার ৭নাম্বার বিটটি হাই, তাই এটি 7. দ্বিতীয়টির ২ নাম্বার বিট হাই, অর্থাৎ এটি 2 লেখা একটি ফটোর লেবেল/নাম।
আমরা চাইলে একই কাজটা কোড লিখেও করতে পারি। যেমন, নিচের লাইন খেয়াল করুন,
# Cell 5
data.test.cls = np.array([label.argmax() for label in data.test.labels])
এখানে লুপ চালিয়ে প্রত্যেকটি লেবেল ভেক্টরকে নিয়ে তার উপর argmax() মেথডটি অ্যাপ্লাই করা হয়েছে। এই মেথডের কাজ হচ্ছে একটি ভেক্টরের মধ্যে যে বিটটি হাই থাকবে তার ইনডেক্স রিটার্ন করবে। হয়ে গেলো? আমরা লেবেল গুলোর One-Hot Vector টাইপের রিপ্রেজেন্টেসন থেকে খুব সহজেই সঠিক ডিজিট নাম্বারটা পেতে পারি। এই পুরো কনভার্সনটা একটা numpy array তে কনভার্ট করে স্টোর করা হচ্ছে।
এখন যদি আমরা data.test.cls ভ্যারিয়েবলের প্রথম ৫টি এলিমেন্ট দেখি তাহলে নিচের মত আউটপুট পাবো,
# Cell 6
data.test.cls[0:5]
array([7, 2, 1, 0, 4])
এতক্ষণে One-Hot Vector প্রেজেন্টেশন এবং argmax মেথডের কাজ বোঝা গেছে নিশ্চয়ই?