Cached Prediction

After fitting the model, call BaseModel.predict() to run inference on test data.

model = Classifier()
model.fit(train_data, train_labels)
model.predict(test_data)

To avoid rebuilding the TensorFlow graph on every call to BaseModel.predict(), wrap repeated predictions in the model.cached_predict() context manager.

model = Classifier()
model.fit(train_data, train_labels)
with model.cached_predict():
    model.predict(test_data) # triggers prediction graph construction
    model.predict(test_data) # graph is already cached, so subsequent calls are faster
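
The caching pattern above can be illustrated with a small, library-independent sketch. This is not the library's actual implementation: ToyModel, _build_graph, and the graph_builds counter are hypothetical names introduced only to show how a context manager might build an expensive resource once and reuse it until the block exits.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for a model whose predict() needs an expensive setup step."""

    def __init__(self):
        self.graph_builds = 0      # counts how often the expensive setup runs
        self._cached_graph = None  # holds the reusable "graph" while caching is active
        self._caching = False

    def _build_graph(self):
        # Stand-in for constructing a prediction graph (assumed to be the costly part).
        self.graph_builds += 1
        return lambda xs: [x * 2 for x in xs]

    def predict(self, data):
        if self._caching:
            # Build once on first use, then reuse for later calls.
            if self._cached_graph is None:
                self._cached_graph = self._build_graph()
            graph = self._cached_graph
        else:
            # Outside the context manager, every call rebuilds.
            graph = self._build_graph()
        return graph(data)

    @contextmanager
    def cached_predict(self):
        self._caching = True
        try:
            yield self
        finally:
            # Release the cached graph when the block exits, even on error.
            self._caching = False
            self._cached_graph = None


model = ToyModel()
model.predict([1, 2])
model.predict([1, 2])
builds_without_cache = model.graph_builds

with model.cached_predict():
    first = model.predict([1, 2])
    second = model.predict([3])
builds_with_cache = model.graph_builds - builds_without_cache
```

In this sketch the two uncached calls each trigger a build, while the two calls inside the with block share a single build. Releasing the cache in the finally clause means the resource is freed even if a prediction raises.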