Cached Prediction
After fitting the model, call BaseModel.predict() to run inference on test data.
from finetune import Classifier
model = Classifier()
model.fit(train_data, train_labels)
model.predict(test_data)
Each call to BaseModel.predict() normally recreates the TensorFlow prediction graph.
To avoid that overhead across repeated calls, wrap them in the model.cached_predict() context manager:
from finetune import Classifier
model = Classifier()
model.fit(train_data, train_labels)
with model.cached_predict():
    model.predict(test_data)  # triggers prediction graph construction
    model.predict(test_data)  # graph is already cached, so subsequent calls are faster
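To illustrate the caching idea without TensorFlow, here is a minimal sketch built on a hypothetical ToyModel class. It is not finetune's implementation. It assumes the cache lives only for the duration of the with block, and it uses a counter, graph_builds, to show how many times the stand-in "graph" gets built with and without the context manager.

```python
import contextlib


class ToyModel:
    """Hypothetical stand-in for a model; NOT finetune's actual implementation."""

    def __init__(self):
        self._cache_enabled = False
        self._cached_graph = None
        self.graph_builds = 0  # counts how many times the "graph" was built

    def _build_graph(self):
        # Stand-in for expensive graph construction.
        self.graph_builds += 1
        return lambda xs: [len(str(x)) for x in xs]

    def predict(self, data):
        if self._cache_enabled:
            # Build once, then reuse for every call inside cached_predict().
            if self._cached_graph is None:
                self._cached_graph = self._build_graph()
            graph = self._cached_graph
        else:
            # Without caching, the graph is rebuilt on every call.
            graph = self._build_graph()
        return graph(data)

    @contextlib.contextmanager
    def cached_predict(self):
        self._cache_enabled = True
        try:
            yield self
        finally:
            # Assumption: the cached graph is discarded when the block exits.
            self._cache_enabled = False
            self._cached_graph = None


uncached = ToyModel()
uncached.predict(["a", "bb"])
uncached.predict(["a", "bb"])
print(uncached.graph_builds)  # 2: one build per call

cached = ToyModel()
with cached.cached_predict():
    cached.predict(["a", "bb"])
    cached.predict(["ccc"])
print(cached.graph_builds)  # 1: built once, then reused
```

A context manager fits this pattern well: the try/finally guarantees the cached state is released even if a prediction inside the block raises.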