
Can AI Really Learn Mental Arithmetic? Implicit Chain of Thought Theoretically Proven for the First Time, With Stuart Russell's Participation

MachineHeart (机器之心), 2026-06-08 09:45
UC Berkeley has mathematically proven the feasibility of implicit Chain-of-Thought.

Over the past year, the cost of running AI reasoning models has been a headache for many developers.

"Slow thinking" models truly shine on math, code, and logic problems. The price is that each call generates hundreds or even thousands of "thinking tokens". These tokens, which come before the answer, are the model's scratch paper for step-by-step calculation: visible, but expensive. For a complex math problem, the "thinking process" alone may consume more than ten times the computing resources of an ordinary conversation.

In the thinking mode, even simple communication consumes a lot of tokens.

Recently, some new techniques have shown that inference costs can be reduced. However, no matter how the architecture is optimized, as long as the intermediate steps of the Chain-of-Thought (CoT) are still generated token by token, there is a fundamental lower bound on inference latency. Each step can start only after the previous one is completed, so the longer the reasoning chain, the longer the wait.

This is a structural problem, not an engineering problem.

So, is it possible to let the model "hide the scratch paper in its brain", retaining the reasoning ability of the explicit Chain-of-Thought without outputting any intermediate steps?

This is exactly the problem that "Implicit Chain-of-Thought (ICoT)" aims to solve. Just a few days ago, a research team from UC Berkeley and Princeton University took a crucial step on this problem. They not only presented a solution but also rigorously proved its effectiveness mathematically.

Paper title: Transformers Provably Learn to Internalize Chain-of-Thought

Paper address: https://arxiv.org/abs/2605.28600v1

The main authors of this research are from UC Berkeley and Princeton University. The first author is Yixiao Huang, a doctoral student at Berkeley. The advising professors include Jiantao Jiao, Stuart Russell, Somayeh Sojoudi, and Song Mei.

In recent years, this team has published a series of works that analyze the training mechanisms of Transformers with mathematical methods, covering topics from the formation of attention patterns to the optimization dynamics of multi-step reasoning. This research on ICoT is their attempt to extend that theoretical toolkit to the new field of "implicit reasoning".

The Cost of Chain-of-Thought

To understand the significance of this research, we first need to figure out where the high cost of the Chain-of-Thought comes from.

Let's make an analogy. Suppose you are tutoring a student in multi-digit multiplication. One way is to have the student write down each step of the operation on paper and calculate line by line: first the units digit, then the tens digit, and finally the sum. This is the explicit Chain-of-Thought: each intermediate result is visible and can therefore be checked and corrected. Another way is to have the student do the calculation mentally and directly report the final answer.

There is an essential difference in information processing between these two methods. The former is serial: each step depends on the result of the previous step and cannot be parallelized. The latter is different: if the brain can handle all intermediate calculations at once, the answer can be obtained almost immediately.

For large language models (LLMs), this difference is directly reflected in inference latency and token consumption. The explicit Chain-of-Thought requires the model to generate each intermediate token one by one. If the reasoning chain has k steps, at least k additional tokens need to be output, and these tokens must be generated strictly serially. For the current state-of-the-art reasoning models, this number is often in the hundreds to thousands.

The idea of ICoT is: Can we train the model to "internalize" the intermediate steps into the hidden state and only output the answer during the final inference, with the intermediate steps completely invisible?

This idea itself is not new. In their 2024 paper "From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step", Yuntian Deng et al. proposed a training method: first, let the model learn to answer with a complete Chain-of-Thought, then gradually "hide" the intermediate tokens, one more at each stage, so that the model gets used to completing its reasoning with fewer visible clues. This method is effective in experiments, but it has an obvious flaw: if the Chain-of-Thought has k steps, k - 1 training stages are required, so the training cost increases linearly with the length of the reasoning chain.

A more fundamental problem is that no one knows why this method is effective. Can we theoretically guarantee that what ICoT learns is equivalent to the explicit CoT? Under what conditions can we guarantee it? These questions remain unresolved.

Core Innovation: Redesigning the Training Curriculum with a Tree Structure

The core contributions of this paper are at two levels: a new training method and the first rigorous mathematical proof for this method.

The experimental platform for the research is the "k-parity" problem, a classic test bed in theoretical computer science.

Given n bits, each taking the value +1 or -1, select k of them and determine whether their product is +1 or -1. What makes this problem distinctive is that, when no intermediate steps are provided, no finite-precision gradient-descent algorithm can solve it with non-trivial accuracy using a polynomial number of samples. However, once a complete Chain-of-Thought is provided as assistance, even a single-layer Transformer can learn it efficiently. This contrast makes it an ideal sandbox for studying the mechanism of CoT.
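To make the task concrete, here is a minimal Python sketch of how a k-parity example could be generated. The function name `sample_parity_example`, the random seed, and the way the hidden subset is drawn are illustrative assumptions, not details taken from the paper; n = 30 and k = 16 are the values used in the experiment described later in this article.

```python
import random

def sample_parity_example(n, support, rng):
    """Draw n random bits in {-1, +1}; the label is the product over `support`."""
    x = [rng.choice((-1, 1)) for _ in range(n)]
    y = 1
    for i in support:
        y *= x[i]
    return x, y

rng = random.Random(0)                       # fixed seed, illustration only
n, k = 30, 16
support = sorted(rng.sample(range(n), k))    # hidden subset of k positions (assumed uniform)
x, y = sample_parity_example(n, support, rng)
print(support, y)
```

Flipping any single bit inside `support` flips the label, while bits outside it have no effect at all. That is why the label alone gives a learner so little to hold on to.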

Key insight: The structure of the Chain-of-Thought is actually a tree.

The parity of k bits can be decomposed into a binary tree of depth log₂k. The leaf nodes are the original input bits, and each internal node computes the product of its two child nodes, recursing up to the root node, which yields the final answer. The structure of this tree determines the hierarchical relationship of the intermediate steps: the first layer calculates pairwise products, the second layer calculates the products of the results of the first layer, and so on.
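The decomposition can be sketched in a few lines of Python. This is only an illustration of the product tree itself; the helper name `parity_tree` and the example bits are made up for this sketch, and pairing adjacent positions is an assumption about how the leaves are grouped.

```python
def parity_tree(bits):
    """Return all levels of the product tree, from the leaves (level 0) to the root."""
    if not bits or len(bits) & (len(bits) - 1):
        raise ValueError("number of bits must be a power of two")
    levels = [list(bits)]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        # each internal node is the product of its two children
        levels.append([prev[i] * prev[i + 1] for i in range(0, len(prev), 2)])
    return levels

levels = parity_tree([1, -1, -1, 1, -1, 1, 1, 1])
for depth, values in enumerate(levels):
    print(depth, values)
```

For these 8 bits the tree has 3 levels above the leaves (log₂8 = 3), and the single value at the root is the parity of the whole input.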

The standard ICoT method hides only one token at a time and makes no use of this tree structure. The "Log-ICoT" proposed in this paper hides an entire layer of the tree at once, which reduces the original k - 1 training stages to log₂k. For k = 16, that is a reduction from 15 stages to 4.
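The gap between the two schedules grows quickly with k. The short sketch below simply evaluates the two stage counts quoted above (k - 1 versus log₂k) for a few powers of two; the function name `stage_counts` is a hypothetical helper, not something from the paper.

```python
def stage_counts(k):
    """Return (standard ICoT stages, Log-ICoT stages) for k a power of two."""
    if k < 2 or k & (k - 1):
        raise ValueError("k must be a power of two, at least 2")
    return k - 1, k.bit_length() - 1   # k - 1 versus log2(k)

for k in (4, 16, 64, 1024):
    standard, log_icot = stage_counts(k)
    print(f"k={k:5d}  standard ICoT: {standard:5d} stages  Log-ICoT: {log_icot:2d} stages")
```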

This is not just an improvement in engineering efficiency. More importantly, it aligns the training process with the hierarchical structure inside the model: each Transformer layer is responsible for absorbing exactly one level of the Chain-of-Thought tree.

Schematic comparison of the three training paradigms: Explicit CoT, Standard ICoT, Log-ICoT

Theoretical Proof: Writing "Internalization" as a Theorem for the First Time

The most significant milestone of this research is the first rigorous convergence guarantee for ICoT.

Core content of the theorem (Theorem 1): An L-layer Transformer trained under the Log-ICoT curriculum needs only a polynomial number of samples (on the order of n^(2 + ε)) and log₂k gradient steps to predict the correct k-parity result directly from the raw input bits at test time with probability close to 1; the error is exponentially small.

This matches the sample complexity of the explicit CoT, but no intermediate tokens need to be output during inference.

The proof faces two main technical challenges, and the team addresses each with a dedicated design choice:

The first challenge is "representation collapse". In a multi-layer Transformer, as the number of layers increases, the vector representations at different positions tend to become uniform and lose their distinguishability, and the gradient signal vanishes as well. The team introduced "gated connections": each layer "opens the door" and activates only at the positions corresponding to its tree level, while the remaining positions stay closed. This concentrates the gradient signal of each layer on the part of the task it should handle and prevents the representation from being averaged out.
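As a rough illustration of what a position-wise gate could look like, the sketch below builds a 0/1 mask per layer. Everything about the layout is an assumption made for this example: the sequence is taken to hold the k leaves first, followed by each tree level in order, and layer number `layer` is taken to be open only at the positions of tree level `layer`. The names `level_slices` and `gate_mask` are hypothetical, and the paper's actual gating construction may differ.

```python
def level_slices(k):
    """Position ranges of each tree level in a flat sequence: leaves first, root last."""
    slices, start, width = [], 0, k
    while width >= 1:
        slices.append(range(start, start + width))
        start += width
        width //= 2
    return slices

def gate_mask(k, layer):
    """1 where layer `layer` (1-indexed) is open, 0 where it stays closed."""
    slices = level_slices(k)
    total = slices[-1].stop
    open_positions = set(slices[layer])   # assumed: layer l handles tree level l
    return [1 if p in open_positions else 0 for p in range(total)]

for layer in (1, 2):
    print(layer, gate_mask(4, layer))
```

With k = 4 the flat sequence has 7 positions (4 leaves, 2 first-level nodes, 1 root), and each layer's mask selects a disjoint block of them.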

The second challenge is "error propagation". In multi-stage training, small approximation errors in the early stages are amplified layer by layer in later stages, ultimately drowning out the useful signal. The solution is to apply integer quantization (rounding to the nearest integer) to the attention weights after each gradient update. This seemingly crude operation has a precise "locking" effect: for layers that are already trained, subsequent gradient updates are extremely small, so quantization rounds the weights straight back to their original values and leaves the early training results unchanged.
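The locking effect is easy to see on toy numbers. The sketch below is not the paper's update rule; it assumes plain gradient descent followed by rounding, and the weight values, gradients, and learning rate are invented purely for illustration.

```python
def quantized_step(weights, grads, lr):
    """One gradient-descent step followed by rounding each weight to the nearest integer."""
    return [round(w - lr * g) for w, g in zip(weights, grads)]

# Hypothetical, already-trained layer: integer weights, small residual gradients.
trained = [3, 0, -2]
tiny_grads = [0.4, -0.3, 0.2]
locked = quantized_step(trained, tiny_grads, lr=1.0)

# Hypothetical untrained layer: a large gradient still moves the weight.
fresh = quantized_step([0], [-1.8], lr=1.0)

print(locked, fresh)
```

Any update smaller than half a unit is rounded away, so `locked` equals `trained`, while the larger update in the second case survives rounding and moves the weight to a new integer.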

Layer-by-layer attention heat map of a 4-layer Transformer after training. Each layer focuses precisely on the nodes of its corresponding tree level.

Experiment: Achieving 100% Accuracy in 4 Stages

The theoretical proof needs experimental verification. The team ran a complete experiment with n = 30 input bits and k = 16 (i.e., a 4-layer Transformer and 4 training stages).

The training dynamics are highly consistent with the theoretical predictions. In the first stage, the complete Chain-of-Thought is visible, and the loss quickly drops to near zero. In each subsequent stage, half of the remaining Chain-of-Thought positions are replaced with all-zero padding, and the loss shows a brief spike, which corresponds to the moment when the model starts to "digest" a new layer of the Chain-of-Thought. The spike then quickly subsides as the model adapts to the new constraints.
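The stage-wise masking can be sketched as follows. The exact schedule is an assumption made for this example: the intermediate tree levels (without the final answer) are laid out lowest level first, and stage s replaces the lowest s - 1 levels with zeros, so stage 1 shows everything and the last stage shows nothing. The function name `masked_chain` and the example values are hypothetical.

```python
def masked_chain(cot_levels, stage):
    """Zero out the Chain-of-Thought levels hidden at `stage` (1-indexed).

    `cot_levels` holds the intermediate tree levels, lowest first, without the root.
    Assumption: stage s hides the lowest s - 1 levels.
    """
    out = []
    for depth, values in enumerate(cot_levels, start=1):
        hidden = depth < stage
        out.extend([0] * len(values) if hidden else values)
    return out

cot_levels = [[-1, -1, -1, 1], [1, -1]]   # k = 8: two intermediate levels
for stage in (1, 2, 3):
    print(stage, masked_chain(cot_levels, stage))
```

For k = 8 this gives three stages (log₂8), and the final stage leaves only zero padding where the chain used to be, which mirrors the situation at the end of training described in the article.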

At the end of the fourth stage, all Chain-of-Thought positions are filled with zeros and the model sees only the original input bits, yet the accuracy on the validation set reaches 100%.

The visualization of the attention weights further confirms the theoretical analysis: the attention of the first layer focuses on the node pairs of the first layer of the tree (pairs of input bits), the second layer focuses on the node pairs of the second layer, and so on. The model has indeed learned to "engrave" each layer of the Chain-of-Thought into the corresponding Transformer layer, rather than cramming all the information chaotically into one layer.

Conclusion

The contribution of this paper is, first of all, to fill a theoretical gap.

ICoT as a practice has been shown by several papers to be effective on practical tasks (such as arithmetic and reasoning problems). However, there is a huge gap between "it works" on one side and "why it works" and "under what conditions it is guaranteed to work" on the other. This paper bridges that gap for the first time, using rigorous mathematical language to show that the implicit Chain-of-Thought is not just a technique that happens to work, but a provable training method under clear conditions.

This means that the "silent thinking" of reasoning models has obtained a mathematical justification for the first time.

From a longer-term perspective, this work points to a goal that is not yet achieved but clearly defined: systematically "compressing" the long Chain-of-Thought of large-scale reasoning models into the model's hidden layers through structured curriculum training. At that point, the model would still have complete reasoning ability, but users would perceive only the direct answer, without long waiting times or expensive thinking-token bills.

Of course, there is still a long way to go from the current theoretical result to engineering practice. The paper itself clearly points out that the current proof relies on several simplifying assumptions: a fixed value matrix, preset gating weights, and a synthetic task structure represented by parity. The challenge of applying Log-ICoT to real LLMs lies in how to design a reasonable "stage division" when the task has no clear hierarchical structure.

This article is from the WeChat official account "MachineHeart" (ID: almosthuman2014). The author is MachineHeart, which studies reasoning. It is published by 36Kr with authorization.