Подбирали количество рекурсий, обнаружили оптимальные значения для HRM T = 3, n = 3 (эквивалентно 48 рекурсиям) и для TRM T = 3, n = 6 (42 рекурсии), это на Sudoku-Extreme. TRM требует бэкпропа через всю глубину рекурсии (правда T на это не влияет, там T-1 шагов делаются без градиента), так что увеличение начинает приводить к Out of Memory. 🧪 Эксперименты Тестирование такое же, как и в статье для HRM: ARC-AGI-1 и -2, Sudoku-Extreme, Maze-Hard. В Sudoku-Extreme использовались 1K примеров для обучения и проверка на 423K примеров. Maze-Hard по 1000 примеров в обучении и тесте. То есть вроде как в HRM, может с поправкой на random seed и конкретные выборки тысячи примеров. Для ARC-AGI использовался также датасет ConceptARC для аугментации (это вроде не как в HRM, но похоже на то, что делала команда ARC-AGI в своей проверке). Аугментации тоже не уверен, что целиком повторяли таковые из статьи про HRM, надо копать глубже. Цифры для HRM в точности такие же как в оригинальной статье, так что видимо брали из самой статьи, но с другой стороны код для HRM в репе TRM тоже лежит. Общий результат, TRM достигает ещё более высоких цифр, чем HRM: * 74.7%/87.4% (версия с attention/версия с MLP) против 55% для Sudoku * 85.3% (версия с attention, версия с MLP даёт 0) против 74.5% для Maze * 44.6%/29.6% (attn/MLP) против 40.3% для ARC-AGI-1 * 7.8%/2.4% (attn/MLP) против 5.0% для ARC-AGI-2 Интересно, что для судоку лучше работает версия с MLP, для остальных, требующих большего контекста, лучше версия с вниманием. Версия TRM с вниманием содержала 7M параметров, версия с MLP -- 5M для Sudoku и 19M для остальных задач. HRM всегда была 27M. В приложении есть небольшая секция про идеи, которые не сработали. Среди таковых: * Замена SwiGLU MLP на SwiGLU MoE -- генерализация сильно просела, но возможно на большем количестве данных было бы по-другому. * пробовали проводить градиенты меньше, чем через всю рекурсию -- например, только через последние 4 шага -- никак не помогло, только всё усложнило. * убирание ACT всё ухудшило * общие веса для эмбеддингов входа и выхода всё ухудшили * замена рекурсии на fixed-point iteration из TocrhDEQ замедлило и ухудшило. Возможно, это лишнее подтверждение, что сходимость к неподвижной точке не важна. ARC-AGI проверили результаты для TRM (https://x.com/arcprize/status/1978872651180577060) - ARC-AGI-1: 40%, $1.76/task - ARC-AGI-2: 6.2%, $2.10/task Здесь разброс между статьёй и измерениями самих ARC меньше, чем был у HRM. TRM меньше, но рантайм жрёт больше (неудивительно при наличии рекурсии). Возможно, более хорошие результаты не от того, что модель умнее, а от того, что училась дольше? Не понял, насколько модели одинаковы по части затраченных FLOPS, было бы интересно посмотреть. --- Короче, работа прикольная, эмпирический результат интересный. Нет чувства, что глубоко понятна теоретическая часть, почему именно эти рекурсии работают так хорошо. Также эта работа -- прикольный пример какой-то архитектурной изобретательности в противовес вечному скейлингу моделей (хотя отскейлить эту конкретную тоже интересно, как и распространить её на другие классы задач). Думаю, будут развития. Эксперименты не выглядят сильно дорогими, рантайм от <24 часов до примерно трёх дней максимум на 4*H100, если верить данным в репе. Всем хороших рекурсий!
Подбирали количество рекурсий, обнаружили оптимальные значения для HRM T = 3, n…
Из этого канала
- #4143Агенты для исследования массово выходят в опенсорс. Сразу две работы за…
Агенты для исследования массово выходят в опенсорс. Сразу две работы за последнее время: Barbarians at the Gate: How AI is Upending Systems Research…
- #4144Если не видели, тут очередной курс по трансформерам выкладывают. CME 295 -…
Если не видели, тут очередной курс по трансформерам выкладывают. CME 295 - Transformers & Large Language Models This course explores the world of Transformers…
- #4145Дистилляцию в BitNet (тернарные веса и 1.58-битные модели) завезли!…
Дистилляцию в BitNet (тернарные веса и 1.58-битные модели) завезли! https://t.me/gonzoMLpodcasts/990
- #4129Интересно, что это отличается от латентного ризонинга в стиле Coconut…
Интересно, что это отличается от латентного ризонинга в стиле Coconut (https://t.me/gonzoML/3567), там он был на уровне токенов при авторегрессионной…
- #4128На входе у неё три элемента: input (x), latent (z) и prediction (y), они все…
На входе у неё три элемента: input (x), latent (z) и prediction (y), они все суммируются в одно значение.