Нейросеть для таблиц

Как TabM помогает в бизнесе, медицине и науке

Лаборатория исследований искусственного интеллекта Yandex Research представила новую нейросетевую архитектуру для работы с табличными данными — TabM. Она позволяет быстро обрабатывать большие массивы данных и строить высокоточные прогнозы, что востребовано в бизнесе, исследованиях и медицине. Модели для работы с табличными данными помогают оптимизировать поставки, прогнозировать энергопотребление, классифицировать пациентов по риску заболеваний и решать многие другие задачи.

Фото: Александр Миридонов, Коммерсантъ

Фото: Александр Миридонов, Коммерсантъ

Разработку использовали на Kaggle — платформе международных соревнований по анализу данных и машинному обучению от Google. В частности, новую архитектуру применяли для предсказания выживаемости пациентов после трансплантации костного мозга. За решение этой и других задач с помощью TabM призеры и победители Kaggle получили в совокупности $60 тыс.

TabM (от англ. Tabular DL model that makes Multiple predictions) — это эффективная реализация так называемого ансамбля моделей, когда каждая модель проводит свой анализ, после чего прогноз усредняется. Архитектура TabM позволяет добиться оптимального соотношения точности прогноза и необходимых вычислительных мощностей.

По результатам тестирования на 46 наборах данных TabM превзошла другие решения не только по занимаемому в среднем месту (1,7 у TabM против 2,9 у ближайшего конкурента), но и по стабильности работы, что важно для практического применения. Благодаря способности объединять усилия нескольких подмоделей и эффективному использованию вычислительных ресурсов, TabM успешно конкурирует с классическими моделями градиентного бустинга — CatBoost, XGBoost, LightGBM,— которые долгое время считались лучшим решением для табличных данных.

Архитектура уже доступна разработчикам и исследователям на GitHub, а научная статья — на arXiv.

С 2019 года исследователи Yandex Research опубликовали восемь научных статей по глубокому обучению моделей для работы с табличными данными. В общей сложности статьи получили более 1,9 тыс. цитирований. В частности, статью о TabM цитировали Университет Мангейма (Германия), Национальный университет Сингапура, Корейский университет, Иллинойсский университет в Урбане-Шампейне. В разные годы статьи были приняты на самые влиятельные конференции по ИИ, в том числе NeurIPS, ICLR и ICML.

Артем Бабенко, руководитель лаборатории Yandex Research, рассказал «Ъ-Науке» о том, как новая нейросеть анализирует табличные данные, как она применяется в медицине, финансах и логистике и почему важна стабильность ее работы.

— Что такое TabM и в чем ее основное преимущество перед другими моделями для табличных данных?

— TabM — это нейросетевая архитектура для работы с табличными данными, построенная на основе многослойного персептрона (MLP), в которой реализован подход эффективного ансамблирования. Главная фишка TabM — она позволяет получать сразу несколько предсказаний с помощью одной модели и благодаря этому достигает такой же или даже лучшей точности, как у традиционных ансамблей, но при этом потребляет сильно меньше вычислительных ресурсов. На данный момент TabM — простой и мощный инструмент для задач машинного обучения с табличными данными, который по соотношению «качество—стоимость использования» обходит большинство существующих альтернатив, как нейросетевых, так и моделей на решающих деревьях.

— Как работает ансамблевый подход в TabM и почему он эффективен?

— Ансамблирование в TabM реализуется через так называемое неявную параметризацию: внутри одной нейросети фактически существует несколько «виртуальных» моделей, каждая из которых делает свое предсказание. Итоговое предсказание формируется усреднением предсказаний индивидуальных моделей. Большинство весов у моделей общие, и лишь небольшие адаптеры (специальные множители) уникальны для каждой из них, что практически не увеличивает размер модели. Если обучать много отдельных нейросетей, то каждая из них по отдельности может ошибаться и даже «запоминать» шум в данных — то есть переобучаться. Но когда мы объединяем их предсказания, случайные ошибки отдельных моделей друг друга компенсируют и итоговый ответ становится более точным.

Недавно мы выпустили Python-пакет, упрощающий использование TabM. Он включает реализацию модели на PyTorch, готовые слои и функции для создания собственных ансамблей, а также примеры на Jupyter и Google Colab.

— В каких сферах может применяться TabM?

— TabM — универсальный инструмент в области глубокого обучения на табличных данных. Это могут быть самые разные задачи: кредитный скоринг в банках, прогнозирование спроса и цен в ритейле, анализ медицинских данных, логистика, промышленный интернет вещей, онлайн-реклама и многое другое. В тех сферах, где сейчас стандартом де-факто считаются решения на основе классических моделей градиентного бустинга (CatBoost, XGBoost или LightGBM), также может эффективно применяться и TabM. Это особенно актуально, если одновременно важны точность, стабильность и высокая вычислительная эффективность на больших или сложных датасетах.

— Какие результаты показала TabM в соревнованиях Kaggle и какие задачи с ее помощью решали?

— Наша архитектура стала востребованным инструментом в самых популярных в комьюнити международных соревнованиях по машинному обучению на платформе Kaggle. В престижном медицинском конкурсе от CIBMTR (Center for International Blood and Marrow Transplant Research), где требовалось предсказать выживаемость пациентов после трансплантации костного мозга, TabM вошла в состав решений, занявших сразу четыре призовых места. В соревновании Маастрихтского университета по анализу MCTS-алгоритмов для игр TabM также применялась в решении команды-победителя. А в соревновании Jane Street Capital по прогнозированию рыночных данных в реальном времени архитектура появлялась в открытых примерах решений (публичных ноутбуках), которые часто служат отправной точкой для построения успешных моделей. В сумме команды, применявшие TabM, выиграли $60 тыс. призовых.

— Насколько TabM превосходит классические алгоритмы (CatBoost, XGBoost, LightGBM)?

— По результатам тестирования на 46 наборах данных TabM превзошла другие решения не только по среднему месту в общем рейтинге (1,7 у TabM против 2,9 у ближайшего конкурента), но и по стабильности работы. Именно стабильность крайне важна для практического применения: TabM демонстрирует высокие результаты на разных наборах данных и сценариях. CatBoost, XGBoost и LightGBM, которые традиционно считались «золотым стандартом» для работы с табличными данными, идут вровень с TabM, иногда даже уступая нашей архитектуре. TabM можно рассматривать как современную и удобную альтернативу моделям градиентного бустинга: при очень близком или равном качестве она часто оказывается быстрее и удобнее при работе с большими датасетами.

— Как TabM сочетает высокую точность и вычислительную эффективность?

— Обычно для ансамбля нужно обучать целое семейство отдельных моделей, что увеличивает время работы и расходует много вычислительных ресурсов. В TabM мы добиваемся эффекта ансамбля гораздо проще: все строится внутри одной нейросети, где большая часть параметров общая для всех «виртуальных» моделей, а индивидуальные настройки задаются только специальными малыми адаптерами. Это позволяет быстро обучать и применять модель даже на очень больших датасетах — например, если задача требует быстро обрабатывать миллионы данных в реальном времени. В результате TabM отлично подходит для промышленного применения, где важна и точность, и скорость.

— Где можно найти реализацию TabM и научную статью о ней?

— Научная работа была представлена на ICLR — одной из крупнейших в мире конференций по искусственному интеллекту — и опубликована в архиве научных статей препринтов Корнеллского университета. Сама архитектура выложена в открытом доступе на GitHub.

— Почему стабильность работы TabM важна для ее практического применения?

— В реальных задачах одна из ключевых проблем — это когда модель ведет себя непредсказуемо на новых данных или нестандартных выборках. Для бизнеса особенно важно, чтобы такая ситуация не возникала. TabM как раз отличается высокой стабильностью: на разных выборках и типах задач она демонстрирует предсказуемое качество — без резких спадов, которые могут возникать у других современных нейросетевых моделей.

Татьяна Репина