Cached Prediction
After fitting the model, call BaseModel.predict() to run inference on test data.
model = Classifier()
model.fit(train_data, train_labels)
model.predict(test_data)
To avoid rebuilding the TensorFlow graph on every call to BaseModel.predict(), use the model.cached_predict() context manager.
model = Classifier()
model.fit(train_data, train_labels)
with model.cached_predict():
    model.predict(test_data)  # triggers prediction graph construction
    model.predict(test_data)  # graph is already cached, so subsequent calls are faster
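The caching behavior above can be sketched as follows. This is an illustrative toy, not the library's actual implementation: the class name Classifier is reused from the examples, but _build_graph, _graph, and build_count are hypothetical stand-ins for the real TensorFlow graph-construction machinery.

```python
from contextlib import contextmanager


class Classifier:
    """Toy model illustrating how a cached_predict() context manager might work."""

    def __init__(self):
        self._caching = False   # True while inside cached_predict()
        self._graph = None      # stand-in for the compiled prediction graph
        self.build_count = 0    # counts how often the graph is (re)built

    def _build_graph(self):
        # Stand-in for expensive TensorFlow graph construction.
        self.build_count += 1
        return "graph"

    @contextmanager
    def cached_predict(self):
        # Keep the prediction graph alive for all predict() calls in this block.
        self._caching = True
        try:
            yield self
        finally:
            self._caching = False
            self._graph = None  # drop the cached graph on exit

    def predict(self, data):
        if self._graph is None:
            self._graph = self._build_graph()
        result = [0 for _ in data]  # placeholder inference
        if not self._caching:
            # Outside the context manager the graph is discarded, so the
            # next predict() call must rebuild it from scratch.
            self._graph = None
        return result


model = Classifier()
model.predict([1, 2])
model.predict([1, 2])
print(model.build_count)  # graph rebuilt on each call -> 2

with model.cached_predict():
    model.predict([1, 2])
    model.predict([1, 2])
print(model.build_count)  # only one more build inside the context -> 3
```

The design choice to clear the cache on context exit keeps the two code paths consistent: predictions are always made against a graph that matches the model's current state, and memory held by the cached graph is released as soon as the block ends.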