Understanding and Creating Embeddings with a Simple Neural Network

Jul 14, 2024

Introduction to Embeddings

  • Embeddings in AI: A core concept in modern AI in which data is represented as vectors in a multi-dimensional space
  • Utility: Embeddings place similar concepts close to each other and distinct concepts far apart in vector space
  • Objective: Create embeddings from scratch using a simple neural network without complex structures like Transformers

Embeddings Basics

  • Definition: An embedding is a vector: a list of floating-point numbers that act as coordinates in a multi-dimensional space
  • Analogy: Similar to latitude and longitude, but with many more dimensions
  • Visualization: Tools like Cohere's playground can be used to generate and visualize embeddings

Creating Embeddings: A Simple Neural Network Approach

  • Example with Words: Words like bananas, apples, and rice (food-related) sit close together in embedding space, just as camera and tripod (photography-related) form their own nearby cluster; see the toy example below
  • Dimensionality Reduction: Embeddings provide a compact representation of data while preserving meaningful relationships
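
To make "close" and "far" concrete, here is a toy illustration; the vectors below are made-up numbers for demonstration, not outputs of a real embedding model:

```python
import numpy as np

# Made-up 4-dimensional "embeddings", purely illustrative; real embeddings
# typically have hundreds or thousands of dimensions.
embeddings = {
    "bananas": np.array([0.90, 0.80, 0.10, 0.00]),
    "apples":  np.array([0.85, 0.75, 0.15, 0.05]),
    "rice":    np.array([0.80, 0.90, 0.20, 0.10]),
    "camera":  np.array([0.10, 0.00, 0.90, 0.80]),
    "tripod":  np.array([0.05, 0.10, 0.85, 0.90]),
}

def distance(a, b):
    """Euclidean distance between two word embeddings."""
    return np.linalg.norm(embeddings[a] - embeddings[b])

print(distance("bananas", "apples"))   # small: both food-related
print(distance("camera", "tripod"))    # small: both photography-related
print(distance("bananas", "camera"))   # large: unrelated concepts
```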

Importance of Embeddings

  • Recommendation Systems: Embeddings provide context to models for better answer generation
  • Similarity Search: Helps find related documents or data points by computing distances between embeddings (a minimal sketch follows)
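
Similarity search reduces to comparing a query embedding against stored embeddings and keeping the nearest ones. In the sketch below, the corpus_embeddings and query_embedding arrays are placeholders for vectors produced by whatever embedding model is in use:

```python
import numpy as np

def nearest_neighbors(query_embedding, corpus_embeddings, k=3):
    """Return indices of the k corpus embeddings closest to the query (Euclidean distance)."""
    distances = np.linalg.norm(corpus_embeddings - query_embedding, axis=1)
    return np.argsort(distances)[:k]

# Placeholder data: 1,000 documents embedded into 128 dimensions, plus one query.
corpus_embeddings = np.random.rand(1000, 128)
query_embedding = np.random.rand(128)

print(nearest_neighbors(query_embedding, corpus_embeddings))  # indices of the 3 closest documents
```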

Key Concept: Siamese Network

  • Explanation: A neural network with two inputs whose branches share the same weights, so both inputs are mapped into the same embedding space
  • Training a Siamese Network: Train the network to minimize the distance between embeddings of similar data and maximize it for dissimilar data

Practical Implementation

  • Dataset Chosen: MNIST dataset with handwritten digits (28x28 pixels)
  • Normalization: Pixel values are normalized to a range between 0 and 1
  • Pair Generation: Generate pairs of images (positive pairs: same digit, negative pairs: different digits)
  • Architecture: Simple Fully Connected Network
  • Embedding Size: Set to 128 for simplicity

Code Walkthrough

Pre-processing and Pair Generation

  • Data Reshaping: Flatten each 28x28 image into a 784-value vector
  • Normalization: pixel_value / 255.0
  • Pair Creation: Function to create positive and negative pairs along with their labels
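
A minimal sketch of this pre-processing, assuming TensorFlow/Keras; the make_pairs helper is my own naming, and it labels similar pairs 1 and dissimilar pairs 0 (the original code's convention may differ):

```python
import numpy as np
from tensorflow import keras

# Load MNIST, flatten each 28x28 image to a 784-value vector, scale pixels to [0, 1].
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

def make_pairs(images, labels):
    """Build positive pairs (same digit, label 1) and negative pairs (different digit, label 0)."""
    digit_indices = [np.where(labels == d)[0] for d in range(10)]
    pairs, pair_labels = [], []
    for idx, image in enumerate(images):
        digit = labels[idx]
        # Positive pair: another random image of the same digit.
        pos_idx = np.random.choice(digit_indices[digit])
        pairs.append([image, images[pos_idx]])
        pair_labels.append(1.0)
        # Negative pair: a random image of a different digit.
        other_digit = np.random.choice([d for d in range(10) if d != digit])
        neg_idx = np.random.choice(digit_indices[other_digit])
        pairs.append([image, images[neg_idx]])
        pair_labels.append(0.0)
    return np.array(pairs), np.array(pair_labels, dtype="float32")

pairs_train, labels_train = make_pairs(x_train, y_train)
pairs_test, labels_test = make_pairs(x_test, y_test)
```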

Building the Siamese Network

  • Input Layers: Two inputs of size 784
  • Network Body: A few fully connected layers with ReLU activation ending in an embedding layer of size 128
  • Distance Calculation: A Lambda layer to compute Euclidean distance between the embeddings
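
A sketch of that architecture in Keras; the intermediate layer widths are my own guesses, only the 784-pixel input and the 128-dimensional embedding come from the description above:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

embedding_size = 128

def build_embedding_model():
    """The shared body: maps a flattened 784-pixel image to a 128-dimensional embedding."""
    inputs = keras.Input(shape=(784,))
    x = layers.Dense(256, activation="relu")(inputs)
    x = layers.Dense(128, activation="relu")(x)
    outputs = layers.Dense(embedding_size)(x)
    return keras.Model(inputs, outputs, name="embedding_model")

embedding_model = build_embedding_model()

# Two inputs passed through the *same* model instance, so the weights are shared.
input_a = keras.Input(shape=(784,))
input_b = keras.Input(shape=(784,))
embedding_a = embedding_model(input_a)
embedding_b = embedding_model(input_b)

def euclidean_distance(tensors):
    """Euclidean distance between the two embeddings, kept numerically stable."""
    a, b = tensors
    return tf.sqrt(tf.maximum(tf.reduce_sum(tf.square(a - b), axis=1, keepdims=True), 1e-7))

distance = layers.Lambda(euclidean_distance)([embedding_a, embedding_b])
siamese = keras.Model([input_a, input_b], distance)
```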

Loss Function: Contrastive Loss

  • Principle: Penalize the network if similar pairs are far apart and if dissimilar pairs are close
  • Implementation: Contrastive loss as outlined in a 2005 paper (function provided in the code)
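
With the pair-labeling convention from the sketches above (1 = similar, 0 = dissimilar), the loss can be written as follows; the margin of 1.0 is an assumption:

```python
import tensorflow as tf

def contrastive_loss(y_true, y_pred, margin=1.0):
    """Pull similar pairs (label 1) together, push dissimilar pairs (label 0) at least `margin` apart."""
    y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), tf.shape(y_pred))
    similar_term = y_true * tf.square(y_pred)
    dissimilar_term = (1.0 - y_true) * tf.square(tf.maximum(margin - y_pred, 0.0))
    return tf.reduce_mean(similar_term + dissimilar_term)
```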

Model Compilation and Training

  • Compilation: Using the Adam optimizer and a binary accuracy metric
  • Training Process: Train for a specified number of epochs (e.g., 5 epochs)
  • Results: Training and validation accuracy trends observed in the output
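
A sketch of compiling and training; because the network outputs a raw distance, I stand in a simple thresholded accuracy for the binary accuracy mentioned above (the 0.5 distance threshold is an assumption):

```python
def pair_accuracy(y_true, y_pred):
    """Count a pair as 'similar' when its predicted distance falls below 0.5."""
    y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), tf.shape(y_pred))
    predictions = tf.cast(y_pred < 0.5, y_pred.dtype)
    return tf.reduce_mean(tf.cast(tf.equal(y_true, predictions), tf.float32))

siamese.compile(optimizer="adam", loss=contrastive_loss, metrics=[pair_accuracy])

history = siamese.fit(
    [pairs_train[:, 0], pairs_train[:, 1]], labels_train,
    validation_data=([pairs_test[:, 0], pairs_test[:, 1]], labels_test),
    batch_size=128,
    epochs=5,
)
```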

Evaluating the Model

  • Predict and Display: Function to predict and display pairs of images to visualize the network's performance
  • Accuracy: Approx. 97% accuracy achieved in similarity prediction
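
One way to write such a helper with matplotlib, continuing from the sketches above (the layout details are my own choices):

```python
import matplotlib.pyplot as plt

def show_pair_predictions(model, pairs, labels, n=5):
    """Predict distances for the first n test pairs and display each pair side by side."""
    distances = model.predict([pairs[:n, 0], pairs[:n, 1]]).ravel()
    fig, axes = plt.subplots(n, 2, figsize=(4, 2 * n))
    for i in range(n):
        for j in range(2):
            axes[i, j].imshow(pairs[i, j].reshape(28, 28), cmap="gray")
            axes[i, j].axis("off")
        axes[i, 0].set_title(f"label={int(labels[i])}, distance={distances[i]:.2f}")
    plt.tight_layout()
    plt.show()

show_pair_predictions(siamese, pairs_test, labels_test)
```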

Using the Embeddings

  • New Model for Embeddings: Extract the embedding model from the Siamese network for generating embeddings
  • Embedding Distance Calculation: Compute distances between sample embeddings and verify manually that similar samples are closer together
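
Because both branches share their weights, the shared body (embedding_model in the sketches above) can be reused on its own to embed single images, and the distances can be checked by hand:

```python
import numpy as np

# Embed every test image into the 128-dimensional space.
embeddings = embedding_model.predict(x_test)

def embedding_distance(i, j):
    """Euclidean distance between the embeddings of test images i and j."""
    return np.linalg.norm(embeddings[i] - embeddings[j])

# Another image of the same digit as x_test[0], and one of a different digit.
same_idx = np.where(y_test == y_test[0])[0][1]
diff_idx = np.where(y_test != y_test[0])[0][0]

print("same digit distance:     ", embedding_distance(0, same_idx))
print("different digit distance:", embedding_distance(0, diff_idx))
```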

Conclusion

  • Applications: Anomaly detection, recommendation systems, embeddings in large language models
  • Future Learning: Explore more advanced architectures and embedding techniques for different applications