Попаднах на точно същия проблем и човече беше заешка дупка. Исках да публикувам решението си тук, тъй като може да спести на някого ден работа:
Структури от данни, специфични за нишките на TensorFlow
В TensorFlow има две ключови структури от данни, които работят зад кулисите, когато извикате model.predict
(или keras.models.load_model
или keras.backend.clear_session
, или почти всяка друга функция, взаимодействаща с бекенда на TensorFlow):
- Графика TensorFlow, която представя структурата на вашия модел Keras
- Сесия TensorFlow, която е връзката между текущата ви графика и времето за изпълнение на TensorFlow
Нещо, което не е изрично ясно в документите без известно копаене, е, че и сесията, и графиката са свойства на текущата нишка . Вижте документи за API тук и тук.
Използване на модели TensorFlow в различни нишки
Естествено е да искате да заредите модела си веднъж и след това да извикате .predict()
върху него няколко пъти по-късно:
from keras.models import load_model
MY_MODEL = load_model('path/to/model/file')
def some_worker_function(inputs):
return MY_MODEL.predict(inputs)
В контекста на уеб сървър или работнически пул като Celery, това означава, че ще заредите модела, когато импортирате модула, съдържащ load_model
ред, тогава друга нишка ще изпълни some_worker_function
, изпълняващ прогнозиране върху глобалната променлива, съдържаща модела на Keras. Въпреки това, опитът за изпълнение на прогнозиране върху модел, зареден в друга нишка, води до грешки "тензорът не е елемент от тази графика". Благодарение на няколкото SO публикации, които засягат тази тема, като ValueError:Tensor Tensor(...) не е елемент от тази графика. При използване на глобална променлива keras модел. За да накарате това да работи, трябва да се придържате към графиката TensorFlow, която е била използвана - както видяхме по-рано, графиката е свойство на текущата нишка. Актуализираният код изглежда така:
from keras.models import load_model
import tensorflow as tf
MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()
def some_worker_function(inputs):
with MY_GRAPH.as_default():
return MY_MODEL.predict(inputs)
Донякъде изненадващият обрат тук е:горният код е достатъчен, ако използвате Thread
s, но виси за неопределено време, ако използвате Process
ес. И по подразбиране Celery използва процеси за управление на всички свои работни групи. Така че в този момент нещата са все още не работи на Celery.
Защо това работи само в Thread
s?
В Python, Thread
s споделят същия глобален контекст на изпълнение като родителския процес. От документите на Python _thread:
Този модул предоставя примитиви от ниско ниво за работа с множество нишки (наричани още леки процеси или задачи) — множество нишки на контрол, споделящи своето глобално пространство за данни.
Тъй като нишките не са действителни отделни процеси, те използват един и същ интерпретатор на python и по този начин са обект на прословутото Global Interpeter Lock (GIL). Може би по-важното за това разследване е, че тесподелят глобално пространство за данни с родителя.
За разлика от това, Process
es са действителни нови процеси, породени от програмата. Това означава:
- Нов екземпляр на интерпретатора на Python (и без GIL)
- Глобалното адресно пространство е дублирано
Обърнете внимание на разликата тук. Докато Thread
имат достъп до споделена единична глобална променлива на сесията (съхранена вътрешно в tensorflow_backend
модул на Keras), Process
es имат дубликати на променливата Session.
Най-доброто ми разбиране за този проблем е, че променливата Session трябва да представлява уникална връзка между клиент (процес) и времето за изпълнение на TensorFlow, но като се дублира в процеса на разклоняване, тази информация за връзката не се коригира правилно. Това кара TensorFlow да виси, когато се опитва да използва сесия, създадена в различен процес. Ако някой има повече представа за това как работи това под капака в TensorFlow, ще се радвам да го чуя!
Решението / Заобиколно решение
Настроих Celery така, че да използва Thread
s вместо Process
es за обединяване. Има някои недостатъци на този подход (вижте коментара на GIL по-горе), но това ни позволява да заредим модела само веднъж. Така или иначе всъщност не сме обвързани с процесора, тъй като времето за изпълнение на TensorFlow максимизира всички процесорни ядра (това може да заобиколи GIL, тъй като не е написан на Python). Трябва да предоставите на Celery отделна библиотека, за да правите обединяване на базата на нишки; документите предлагат две опции:gevent
или eventlet
. След това предавате избраната от вас библиотека в работния файл чрез --pool
аргумент на командния ред.
Като алтернатива изглежда (както вече разбрахте @pX0r), че други бекендове на Keras като Theano нямат този проблем. Това има смисъл, тъй като тези проблеми са тясно свързани с подробностите за внедряването на TensorFlow. Аз лично все още не съм пробвал Theano, така че пробегът ви може да варира.
Знам, че този въпрос беше публикуван преди време, но проблемът все още е там, така че се надявам това да помогне на някой!