Make meta-learners scikit-learn compliant via BaseEstimator #912
aman-coder03 wants to merge 8 commits
|
@aman-coder03, can you rebase your changes against the latest master and resolve the conflicts? Thanks.
jeongyoonlee left a comment
Thanks for the contribution, @aman-coder03. The refactor itself is the right approach and is clean in base.py, slearner.py, rlearner (base), and drlearner.py. But the master merge (6408621) re-introduced the old #910 template machinery on top of the new code in tlearner and xlearner, and the cleanup commits missed it — tests pass only because the leftover old block runs last and yields the same numbers. A few things before we can merge.
Blocking
- tlearner botched merge: `__init__` still keeps the old `deepcopy` / `self.model_c` / `self.model_t` and `self._model_c_template` / `_model_t_template` (tlearner.py:62-76), and still defines `__repr__` / `_unfitted_clone` / `import copy` (:1,82,87-93) that this PR claims to remove; in `fit()`, `self.models_t` is assigned twice (:146 then :154) and `self.model_c` is fit twice (:150 and :160), the leftover block overwriting the new one. Keep either the new block or the old, not both.
- xlearner botched merge: same pattern. `__init__` retains `self.model_mu_c` / `self._model_mu_c_template` (xlearner.py:62-69) and `fit()` fits `self.model_mu_c` twice (:155 and :159), recomputing `models_mu_c` / `var_c` redundantly. Same fix.
- XGBRRegressor breaks under the new clone path: its `__init__` uses `*args, **kwargs` (rlearner.py:498-499) and stores a transformed `effect_learner_objective` (:506), so now that `BaseLearner` is a `BaseEstimator` and `fit_bootstrap_ensemble` calls `clone(self)`, `get_params()` raises `RuntimeError` and `XGBRRegressor().fit(..., store_bootstraps=True)` / `predict(return_ci=True)` will throw (the existing XGBRRegressor test doesn't hit `return_ci`, so CI stays green). Give it an explicit parameter signature, or override `get_params` / exclude it from the estimator contract.
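The `*args` failure in the last item can be reproduced in isolation. The following is a minimal sketch that assumes only scikit-learn's public `BaseEstimator` and `clone`; `VarargsLearner` and `ExplicitLearner` are hypothetical stand-ins written for illustration, not the actual `XGBRRegressor` code.

```python
from sklearn.base import BaseEstimator, clone


class VarargsLearner(BaseEstimator):
    """Hypothetical stand-in whose __init__ accepts *args."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class ExplicitLearner(BaseEstimator):
    """Hypothetical fix: every parameter is named and stored verbatim."""

    def __init__(self, objective="squared_error", n_estimators=100):
        self.objective = objective
        self.n_estimators = n_estimators


def params_or_error(estimator):
    """Return get_params() output, or the RuntimeError it raised."""
    try:
        return estimator.get_params()
    except RuntimeError as exc:
        return exc


# BaseEstimator inspects the __init__ signature and rejects varargs.
broken = params_or_error(VarargsLearner())

# With an explicit signature, clone() can rebuild the estimator from its params.
fixed = clone(ExplicitLearner(n_estimators=10))
```

Because `clone()` relies on `get_params()`, any bootstrap path that clones the learner would hit the same error for a varargs constructor.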
Non-blocking
- `base.py` strips the full docstrings off `get_feature_importances`, `get_shap_values`, `plot_importance`, `plot_shap_values`, `plot_shap_dependence`, and `_set_propensity_models`. This is unrelated to sklearn compliance and a Sphinx-autodoc regression (and the PR checklist has "documentation" unchecked); please restore them.
- `slearner.py:115` has a duplicated comment line, one copy ending in a stray `\` (merge cruft).
- `tlearner.py` predict docstring (~221-230) has duplicated/mangled `Returns:` blocks from the merge.
- `fit()` no longer returns `self`. Fine if deliberate, but note strict sklearn `check_estimator` won't fully pass and `learner.fit(...).predict()` chaining won't work; confirm that's intended given the "sklearn compliant" framing.
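For the last point, this is roughly what the chaining convention looks like. It is a minimal sketch with a hypothetical `MeanLearner` toy estimator (not a CausalML class); the only assumption is the usual scikit-learn convention that `fit()` returns `self`.

```python
import numpy as np
from sklearn.base import BaseEstimator


class MeanLearner(BaseEstimator):
    """Hypothetical toy estimator that predicts the training mean."""

    def __init__(self, shift=0.0):
        self.shift = shift

    def fit(self, X, y):
        self.mean_ = float(np.mean(y)) + self.shift
        # Returning self is what makes fit(...).predict(...) chaining possible.
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


X = np.zeros((4, 1))
y = np.array([1.0, 2.0, 3.0, 4.0])

# Works only because fit() returns the estimator; if fit() returned None,
# the chained call would fail with an AttributeError on None.
preds = MeanLearner().fit(X, y).predict(X)
```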
|
Reviewed the latest revision. The Blocking
Non-blocking
One doc follow-up: the |
|
@aman-coder03, let's wrap up this PR before updating #901: merging this will dissolve some of #901's blockers, and this PR is also smaller, so it is less work to rebase.
|
A few items left on Blocking
Non-blocking (fine to fold into the same round)
|
Proposed changes
This PR makes all meta-learners (`BaseSLearner`, `BaseTLearner`, `BaseXLearner`, `BaseRLearner`, `BaseDRLearner`) proper scikit-learn estimators by inheriting `BaseEstimator`
Problem
every `__init__` was transforming its constructor arguments (e.g. deepcopying `learner` into `self.model_c` / `self.model_t`) instead of storing them verbatim, which....
- `get_params` / `set_params`
- `clone()` fall back to `deepcopy` of already-fitted models in the bootstrap path
- `_unfitted_clone` / `_model_*_template` workarounds introduced in #910 (Fix #904: Prevent deepcopy of fitted templates in bootstrap and correct predict validation ordering)
Solution
follows the standard scikit-learn convention...
- `__init__` stores all arguments verbatim: no logic, no `deepcopy`
- `fit()`
- `get_params` / `set_params` work for free via `BaseEstimator`
- `clone(learner)` works correctly without `safe=False`
- `_unfitted_clone` / `_model_*_template` machinery removed
- `clone(self)` directly
- `Pipeline` / `GridSearchCV` compatible out of the box
closes #911
follow-up to #904 / #910
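The convention described above can be sketched as follows. This is an illustrative toy under stated assumptions, not the PR's code: `ToyTLearner`, its `control_name` parameter, and the trailing-underscore fitted attributes are hypothetical, and returning `self` from `fit()` is a choice made for this sketch only.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import LinearRegression


class ToyTLearner(BaseEstimator):
    """Hypothetical two-model learner; not the CausalML BaseTLearner."""

    def __init__(self, learner=None, control_name=0):
        # Store arguments verbatim: no copying, no validation, no defaults.
        self.learner = learner
        self.control_name = control_name

    def fit(self, X, treatment, y):
        base = LinearRegression() if self.learner is None else self.learner
        is_control = treatment == self.control_name
        # Fresh unfitted copies are made here, so self.learner stays untouched.
        self.model_c_ = clone(base).fit(X[is_control], y[is_control])
        self.model_t_ = clone(base).fit(X[~is_control], y[~is_control])
        return self

    def predict(self, X):
        # Estimated effect: treated-model prediction minus control-model prediction.
        return self.model_t_.predict(X) - self.model_c_.predict(X)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
treatment = rng.integers(0, 2, size=200)
y = X[:, 0] + 2.0 * treatment  # noiseless outcome with a constant effect of 2.0

template = LinearRegression()
fitted = ToyTLearner(learner=template).fit(X, treatment, y)

# clone() rebuilds an unfitted copy from get_params(); no safe=False needed.
fresh = clone(fitted)
```

Because the constructor argument is never mutated, the same `template` object can be reused across learners, and a clone of a fitted learner carries the parameters but none of the fitted state.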
Types of changes
What types of changes does your code introduce to CausalML?
Put an `x` in the boxes that apply
Checklist
Put an `x` in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.
Further comments
If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc. This PR template is adopted from appium.