
Quick Intro: My name is Strad and I am a new grad working in tech, wanting to learn and write more about AI safety and how tech will affect our future. I'm trying to challenge myself to write a short article a day to get back into writing. Would love any feedback on the article and any advice on writing in this field!

A popular question in the interpretability of AI models is the degree to which they simply memorize their training data versus learn actual concepts that help them produce accurate outputs. With recent advances in the field, it is clear that models do tend to learn underlying concepts that help them generalize, especially as their size increases.

One phenomenon found in models that helps emphasize this point is "grokking." This is the tendency of a model that has failed to achieve high validation accuracy to suddenly improve after a large number of additional training iterations.

[Image: concept art of a model remembering individual data points vs. creating a representation of them based on a general concept it has learned]

To better understand this phenomenon, I provide below an overview of the foundational paper, "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets," which first demonstrated the effect and coined the term.

The Setup

To demonstrate grokking in practice, the researchers behind this paper used a synthetic dataset consisting of sudoku-like tables, each representing a given operation applied between two symbols. An example of a small version of one of these tables can be seen below:

Here, the star symbol represents a binary operation: any rule where two inputs go in and one output comes out, such as addition, multiplication, division, or combinations of them. Each cell holds the result of the binary operation applied to that cell's row and column symbols. The goal of the AI models in this paper was to determine the values of the missing cells.

A key aspect of this dataset is that the symbols used in the binary operation tables don't represent any mathematical concept; they are essentially random to the model. This means the model gets no extra information from the symbols themselves and has to learn whatever binary operation connects the symbols in a given table from scratch.
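To make the setup concrete, here is a minimal sketch of how such a dataset could be built. This is my own illustration rather than the paper's actual code: I use addition modulo a small prime as the hidden operation and shuffled integer tokens as the "meaningless" symbols.

```python
import random

# Minimal sketch (not the paper's code) of a binary operation table dataset.
# The hidden rule is a ★ b = (a + b) mod p, but the model only ever sees
# (row_symbol, column_symbol) -> result_symbol pairs over shuffled tokens.
p = 7
random.seed(0)
symbols = list(range(p))
random.shuffle(symbols)          # map each residue to an arbitrary, "meaningless" token

def star(a, b):
    """The hidden binary operation the model has to learn."""
    return (a + b) % p

# Every filled-in cell of the table is one example.
cells = [((symbols[a], symbols[b]), symbols[star(a, b)])
         for a in range(p) for b in range(p)]

random.shuffle(cells)
train_fraction = 0.5             # fraction of the table's cells revealed during training
split = int(train_fraction * len(cells))
train_cells, val_cells = cells[:split], cells[split:]

print(f"{len(train_cells)} training cells, {len(val_cells)} held-out cells")
```

The model's job is then to predict the held-out cells, which it can only do reliably if it picks up something about the underlying operation rather than memorizing the revealed cells.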

Key Findings from the Paper

The Grokking Phenomenon

The dataset of binary operation tables was used to train a model. As expected, the model's accuracy on the training data was very poor at the beginning, then shot up and converged to 100% after a sufficient number of optimization steps. During this time, the model's accuracy on the validation set remained close to 0.

What was interesting is that after continuing to train the model for significantly more optimization steps, well beyond the amount usually used in practice, the validation accuracy shot up to nearly 100%. The increase was sudden and fast rather than gradual. It seemed as if the model suddenly shifted its internals in a way that allowed it to understand what was necessary to accurately solve the binary operation tables. The researchers coined the term "grokking" to describe this phenomenon.

Even more interesting was what occurred slightly before this high accuracy was achieved: the model's error actually started to go up, and then suddenly dropped once performance improved.

This lends support to the idea that after enough optimization steps, the model starts to explore new strategies for solving the tables rather than just memorizing the training data. While this exploration temporarily increases error, it also allows the model to discover strategies that let it quickly improve at the task.
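To make the training procedure concrete, here is a minimal sketch, continuing from the dataset code above, of the kind of over-long training run in which grokking shows up. It is not the paper's setup (the authors trained a small transformer); it simply illustrates logging train and validation accuracy far past the point where training accuracy saturates.

```python
import torch
import torch.nn as nn

# Continues from the dataset sketch above (train_cells, val_cells, p).
# NOT the paper's setup; this only illustrates the shape of a long run.
torch.manual_seed(0)

def to_tensors(cells):
    xy = torch.tensor([[a, b] for (a, b), _ in cells])   # (N, 2) symbol pairs
    z = torch.tensor([out for _, out in cells])           # (N,) result symbols
    return xy, z

train_x, train_y = to_tensors(train_cells)
val_x, val_y = to_tensors(val_cells)

model = nn.Sequential(nn.Embedding(p, 32), nn.Flatten(),
                      nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, p))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()

def accuracy(x, y):
    with torch.no_grad():
        return (model(x).argmax(dim=-1) == y).float().mean().item()

# Train far past the point where training accuracy saturates and watch
# whether validation accuracy eventually jumps.
for step in range(1, 100_001):
    opt.zero_grad()
    loss_fn(model(train_x), train_y).backward()
    opt.step()
    if step % 5_000 == 0:
        print(f"step {step}: train acc {accuracy(train_x, train_y):.2f}, "
              f"val acc {accuracy(val_x, val_y):.2f}")
```

In a run that groks, the training accuracy line saturates early while the validation accuracy line stays flat for a long stretch and then jumps.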

A Tradeoff Between Data and Time

Another interesting result from the experiments in the paper was that the fraction of data used for training had an inverse relationship with the amount of time it took for the model to experience grokking and converge to a high validation accuracy.

What was even more interesting was just how sensitive this time was to the amount of training data. The paper stated that for some of the experiments, reducing the amount of training data by just 1% was enough to increase the time to reach grokking by 40–50%. It was also shown that with sufficiently little training data, the model never experienced grokking, meaning there wasn't enough data to learn the concepts necessary for high performance.
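A rough way to explore this tradeoff, again as an illustrative sketch built on the code above rather than a reproduction of the paper's experiments, is to sweep the training fraction and record how many optimization steps it takes for validation accuracy to cross a threshold:

```python
# Illustrative sweep (not the paper's experiment): for each training fraction,
# count the optimization steps until validation accuracy crosses a threshold,
# returning None if it never does within the step budget.
def steps_to_generalize(train_fraction, max_steps=100_000, threshold=0.95):
    random.shuffle(cells)
    split = int(train_fraction * len(cells))
    tr_x, tr_y = to_tensors(cells[:split])
    va_x, va_y = to_tensors(cells[split:])
    net = nn.Sequential(nn.Embedding(p, 32), nn.Flatten(),
                        nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, p))
    opt = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1.0)
    for step in range(1, max_steps + 1):
        opt.zero_grad()
        nn.functional.cross_entropy(net(tr_x), tr_y).backward()
        opt.step()
        if step % 500 == 0:
            with torch.no_grad():
                val_acc = (net(va_x).argmax(dim=-1) == va_y).float().mean().item()
            if val_acc >= threshold:
                return step
    return None

for frac in (0.3, 0.5, 0.7, 0.9):
    print(frac, steps_to_generalize(frac))
```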

Weight Decay Reduces the Time It Takes for Grokking to Occur

Many different optimization techniques were tested to see which ones decreased the time it took for grokking to occur in a model. What the researchers found was that weight decay had the most significant impact on decreasing this time.

Weight decay is a regularization technique that penalizes large weights, causing the model to trend towards smaller ones. This biases the model towards simpler solutions, which the researchers hypothesize pushes it towards flatter regions of the loss landscape, where models that grok tend to lie.

In other words, simpler solutions are more likely to generalize since they are less likely to be overfitted to the training data, and weight decay makes the models more likely to find one of these solutions earlier.
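Concretely, weight decay just shrinks every weight toward zero a little at each step. In the sketches above it is enabled through the weight_decay argument of AdamW; written out by hand, the decoupled form looks roughly like this:

```python
# Weight decay as used in the sketches above: AdamW's weight_decay argument.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)

# Written out by hand, decoupled weight decay is just a small shrinkage of
# every weight toward zero, applied after each gradient step.
lr, weight_decay = 1e-3, 1.0
with torch.no_grad():
    for w in model.parameters():
        w -= lr * weight_decay * w
```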

Visualization After Grokking Shows Recognizable Mathematical Structures

The researchers also took the embeddings these models created for the abstract symbols after grokking and, using a dimensionality-reduction technique called t-SNE, mapped them onto a 2D coordinate system.

The point of this was to get an idea of how the models structured the relationships between these symbols based on the data in the binary operation tables. What the researchers found was that some of the structures that the model created accurately matched up with the known mathematical structures associated with the given binary operation.

This further supports the idea that the model is actually learning the binary operation itself and applying it to the data, rather than just remembering what values appeared in previous operation tables with similar patterns.

The researchers also suggest that this mapping of embeddings to real mathematical structures might mean that novel mathematical structures could be discovered in the future using the internal representations of mathematical concepts learnt by an AI model.
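As a rough illustration of this kind of analysis (my own sketch, not the paper's plots), one can pull the learned embedding vectors out of the toy model from the sketches above and project them to 2D with scikit-learn's t-SNE. For a modular-addition operation, one might hope to see the symbols arranged in something like a ring.

```python
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Illustrative only: project the learned symbol embeddings from the toy model
# above into 2D and look for structure (e.g. a ring for modular addition).
emb = model[0].weight.detach().numpy()        # (p, 32) embedding matrix
coords = TSNE(n_components=2, perplexity=3, random_state=0).fit_transform(emb)

plt.scatter(coords[:, 0], coords[:, 1])
for i, (x, y) in enumerate(coords):
    plt.annotate(str(i), (x, y))              # label each point with its token id
plt.title("t-SNE of learned symbol embeddings")
plt.show()
```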

Takeaway

The phenomenon of grokking has been of interest to interpretability research as a whole since it emphasizes the fact that model outputs alone are not a reliable way of determining how a model is actually reasoning. A model could have a high accuracy due to its memory of the training data or due to a deeper concept it learned that allowed it to generalize.

Furthermore, grokking can aid interpretability efforts: comparing snapshots of a model from before and after grokking can help reveal how its internals changed to produce a more general representation of the concepts it learned.

While this paper was limited in that it only tested data of a very specific format, grokking became well known in the interpretability space afterwards, and much related work has been done since, which I hope to look into and review in future articles!
