Make meta-learners scikit-learn compliant via BaseEstimator #912

Open
aman-coder03 wants to merge 8 commits into uber:master from aman-coder03:feature/sklearn-compliant-meta-learners

@aman-coder03

Contributor

Proposed changes

This PR makes all meta-learners (BaseSLearner, BaseTLearner, BaseXLearner, BaseRLearner, BaseDRLearner) proper scikit-learn estimators by having them inherit from BaseEstimator.

Problem

every __init__ was transforming its constructor arguments (e.g. deepcopying learner into self.model_c / self.model_t) instead of storing them verbatim, which....

Solution

follows the standard scikit-learn convention...

  • __init__ stores all arguments verbatim — no logic, no deepcopy
  • all model construction moves to fit()
  • get_params / set_params work for free via BaseEstimator
  • clone(learner) works correctly without safe=False
  • _unfitted_clone / _model_*_template machinery removed
  • bootstrap uses clone(self) directly
  • Pipeline / GridSearchCV compatible out of the box
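
The convention listed above can be sketched with a toy estimator. This is only an illustration under stated assumptions: `ToyTLearner` and its attribute names are hypothetical and are not CausalML's `BaseTLearner`; only `BaseEstimator`, `clone`, and `LinearRegression` are real scikit-learn APIs.

```python
from copy import deepcopy

import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import LinearRegression


class ToyTLearner(BaseEstimator):
    """Hypothetical stand-in for a T-learner; not CausalML's BaseTLearner."""

    def __init__(self, learner=None, control_name=0):
        # Store constructor arguments verbatim: no copies, no defaults resolved.
        self.learner = learner
        self.control_name = control_name

    def fit(self, X, treatment, y):
        # Build the working models here so self.learner itself stays unfitted.
        base = self.learner if self.learner is not None else LinearRegression()
        is_control = treatment == self.control_name
        self.model_c_ = deepcopy(base).fit(X[is_control], y[is_control])
        self.model_t_ = deepcopy(base).fit(X[~is_control], y[~is_control])
        return self

    def predict(self, X):
        # Treatment effect estimate: treated outcome minus control outcome.
        return self.model_t_.predict(X) - self.model_c_.predict(X)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
treatment = rng.integers(0, 2, size=200)
y = X[:, 0] + 2.0 * treatment + rng.normal(scale=0.1, size=200)

template = LinearRegression()
model = ToyTLearner(learner=template)
model.fit(X, treatment, y)

# get_params is inherited from BaseEstimator and returns the stored argument itself.
params = model.get_params(deep=False)

# clone() needs no safe=False and yields an unfitted copy with cloned parameters.
fresh = clone(model)
```

Because `__init__` does nothing but assignment, the template passed in is never fitted, and the clone carries no fitted state.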

closes #911
follow-up to #904 / #910

Types of changes

What types of changes does your code introduce to CausalML?
Put an x in the boxes that apply

  • Bugfix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation Update (if none of the other choices apply)

Checklist

Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

  • I have read the CONTRIBUTING doc
  • I have signed the CLA
  • Lint and unit tests pass locally with my changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have added necessary documentation (if appropriate)
  • Any dependent changes have been merged and published in downstream modules

Further comments

If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc. This PR template is adopted from appium.

@aman-coder03 changed the title from "make meta-learners sklearn-compliant via BaseEstimator" to "Make meta-learners scikit-learn compliant via BaseEstimator" on Jun 13, 2026
@jeongyoonlee

Collaborator

@aman-coder03. Can you rebase your changes against the latest master and resolve conflicts? Thanks.

@jeongyoonlee added the enhancement label on Jun 15, 2026

@jeongyoonlee left a comment

Collaborator

Thanks for the contribution, @aman-coder03. The refactor itself is the right approach and is clean in base.py, slearner.py, rlearner (base), and drlearner.py. But the master merge (6408621) re-introduced the old #910 template machinery on top of the new code in tlearner and xlearner, and the cleanup commits missed it — tests pass only because the leftover old block runs last and yields the same numbers. A few things before we can merge.

Blocking

  • tlearner botched merge: __init__ still keeps the old deepcopy/self.model_c/self.model_t and self._model_c_template/_model_t_template (tlearner.py:62-76), and still defines __repr__/_unfitted_clone/import copy (:1,82,87-93) that this PR claims to remove; in fit(), self.models_t is assigned twice (:146 then :154) and self.model_c is fit twice (:150 and :160), with the leftover block overwriting the new one. Keep either the new block or the old, not both.
  • xlearner botched merge — same pattern: __init__ retains self.model_mu_c/self._model_mu_c_template (xlearner.py:62-69) and fit() fits self.model_mu_c twice (:155 and :159), recomputing models_mu_c/var_c redundantly. Same fix.
  • XGBRRegressor breaks under the new clone path — its __init__ uses *args, **kwargs (rlearner.py:498-499) and stores a transformed effect_learner_objective (:506), so now that BaseLearner is a BaseEstimator and fit_bootstrap_ensemble calls clone(self), get_params() raises RuntimeError and XGBRRegressor().fit(..., store_bootstraps=True) / predict(return_ci=True) will throw (the existing XGBRRegressor test doesn't hit return_ci, so CI stays green). Give it an explicit parameter signature or override get_params / exclude it from the estimator contract.
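
The XGBRRegressor failure mode can be reproduced in isolation. The sketch below uses two hypothetical classes (neither is CausalML code) and relies only on scikit-learn's documented behaviour that `get_params()` rejects a `*args` constructor; the default objective string is a placeholder, not the library's actual default.

```python
from sklearn.base import BaseEstimator, clone


class VarargsRegressor(BaseEstimator):
    """Hypothetical estimator whose __init__ hides its parameters behind *args."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class ExplicitRegressor(BaseEstimator):
    """Hypothetical fix: every parameter is named in the signature."""

    def __init__(self, effect_learner_objective="placeholder-objective", xgb_kwargs=None):
        # Placeholder default above; both arguments are stored verbatim.
        self.effect_learner_objective = effect_learner_objective
        self.xgb_kwargs = xgb_kwargs


def clone_error(estimator):
    """Return the exception type raised by clone(), or None on success."""
    try:
        clone(estimator)
    except Exception as exc:
        return type(exc)
    return None


# clone() calls get_params(), which refuses to introspect a *args signature.
varargs_result = clone_error(VarargsRegressor(max_depth=4))

# With an explicit signature, clone() can read and replay every parameter.
explicit_result = clone_error(ExplicitRegressor(xgb_kwargs={"max_depth": 4}))
```

Under these assumptions `varargs_result` is `RuntimeError` and `explicit_result` is `None`, which matches the reviewer's description of why a bootstrap path built on `clone(self)` would throw.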

Non-blocking

  • base.py strips the full docstrings off get_feature_importances, get_shap_values, plot_importance, plot_shap_values, plot_shap_dependence, and _set_propensity_models — unrelated to sklearn compliance and a Sphinx-autodoc regression (and the PR checklist has "documentation" unchecked); please restore them.
  • slearner.py:115 has a duplicated comment line, one copy ending in a stray \ (merge cruft).
  • tlearner.py predict docstring (~221-230) has duplicated/mangled Returns: blocks from the merge.
  • fit() no longer returns self — fine if deliberate, but note strict sklearn check_estimator won't fully pass and learner.fit(...).predict() chaining won't work; confirm that's intended given the "sklearn compliant" framing.

@jeongyoonlee

Collaborator

Reviewed the latest revision. The BaseEstimator refactor is sound — I verified the #904/#910 warm-start invariant still holds (every fit() deepcopies the verbatim-stored arg before fitting; self.learner stays unfitted across refits), all 31 test_meta_learners.py pass, and clone()/get_params() work across all learners. A few issues remain, two of them merge artifacts from the unresolved conflicts.

Blocking

  • BaseTClassifier.predict (tlearner.py:487): the return_ci and return_components check was moved back to the bottom of the method, reverting #910's deliberate fail-fast ordering. The base BaseTLearner.predict still checks at the top, so the two are now inconsistent (the classifier runs all prediction work before raising). Move it back to the top.
  • XGBRRegressor (rlearner.py:512): dropping *args/**kwargs is an undeclared breaking change — XGBRRegressor(max_depth=4, ...) now raises TypeError with no replacement passthrough; either restore a verbatim-stored **kwargs path or declare/document the break (the "Breaking change" box is unchecked).
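
As a rough illustration of the fail-fast ordering requested for predict, the toy class below validates its flags before doing any prediction work. The class, the exact rule (the two flags being mutually exclusive), the exception type, and the message are all assumptions made for the sketch; the review does not spell out what the real check does.

```python
import numpy as np


class ToyClassifierLearner:
    """Hypothetical learner used only to illustrate validation ordering."""

    def __init__(self):
        self.prediction_calls = 0

    def _expensive_predict(self, X):
        # Stand-in for the per-group model predictions.
        self.prediction_calls += 1
        return np.zeros(len(X))

    def predict(self, X, return_components=False, return_ci=False):
        # Assumed rule: validate the flag combination first, so an invalid
        # call fails before any model is evaluated.
        if return_ci and return_components:
            raise ValueError("return_ci and return_components cannot both be True")
        return self._expensive_predict(X)


learner = ToyClassifierLearner()
message = None
try:
    learner.predict(np.ones((5, 2)), return_components=True, return_ci=True)
except ValueError as exc:
    message = str(exc)
```

With the check at the top, `prediction_calls` is still zero after the invalid call; with the check at the bottom, the expensive work would already have run.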

Non-blocking

  • XGBRRegressor.predict (rlearner.py:686) is byte-identical to the inherited BaseRLearner.predict (rlearner.py:170) — remove the duplicate.
  • Removing self.propensity = None / self.propensity_model = None from __init__ makes estimate_ate(pretrain=True) before fit() raise AttributeError instead of the intended ValueError("no propensity score, please call fit() first") (xlearner.py:320) — keep the initializers or guard with hasattr.
  • __repr__ removed from all learners falls back to BaseEstimator's default repr, which changes repr output in notebooks/doctests.
  • BaseXClassifier.__init__ (xlearner.py:449) passes resolved values (control_outcome_learner or outcome_learner) to super() rather than storing verbatim; clone() still round-trips, but it diverges from the store-verbatim convention this PR establishes.
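
The AttributeError versus ValueError point can be shown with two hypothetical classes. Only the error message and the `estimate_ate(pretrain=True)` call come from the review; the class names, bodies, and return value are invented for the sketch.

```python
class MissingInitializer:
    """Hypothetical learner whose __init__ no longer sets self.propensity."""

    def estimate_ate(self, pretrain=False):
        # self.propensity does not exist before fit(), so this lookup fails.
        if pretrain and self.propensity is None:
            raise ValueError("no propensity score, please call fit() first")
        return 0.0  # placeholder result


class GuardedLearner:
    """Hypothetical learner that keeps the initializer and guards the lookup."""

    def __init__(self):
        # Non-parameter state; setting it here does not affect get_params().
        self.propensity = None

    def estimate_ate(self, pretrain=False):
        # getattr also covers the case where the initializer is dropped later.
        if pretrain and getattr(self, "propensity", None) is None:
            raise ValueError("no propensity score, please call fit() first")
        return 0.0  # placeholder result


def error_type(learner):
    """Return the exception type raised by estimate_ate(pretrain=True)."""
    try:
        learner.estimate_ate(pretrain=True)
    except Exception as exc:
        return type(exc)
    return None


missing_result = error_type(MissingInitializer())
guarded_result = error_type(GuardedLearner())
```

Either keeping the initializer or using the `getattr` guard restores the intended `ValueError`; the sketch does both only to show that they compose.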

One doc follow-up: the _model_*_template mechanism is gone, so any internal notes referencing it are now stale — the invariant holds via the verbatim-store + deepcopy-in-fit() path instead.

@jeongyoonlee

Collaborator

@aman-coder03, let's wrap up this PR first before updating #901, as merging this will dissolve some of #901's blockers, and this is also smaller - less work to rebase against.

@jeongyoonlee

Collaborator

A few items left on 5f999bc:

Blocking

  • Tests. The PR still adds none, and the XGBRRegressor fix is exactly the path CI can't see — the existing XGBRRegressor() test never exercises return_ci, so green CI does not prove the clone path works (as noted last round). Please add:
    • clone(XGBRRegressor()) + get_params() round-trip (incl. xgb_kwargs), and XGBRRegressor().fit(..., store_bootstraps=True) followed by predict(return_ci=True);
    • a clone()/get_params contract test for each of S/T/X/R/DR — regressor and classifier (the classifier overrides are where the botched merge hid last time);
    • one fixed-seed bit-identical check (predict/estimate_ate/return_ci) as an equivalence guard.
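
One possible shape for the requested contract and equivalence checks is sketched below. To avoid guessing CausalML constructor signatures, it runs against scikit-learn estimators as stand-ins; the helper names `check_clone_contract` and `check_bit_identical` are hypothetical, and a real suite would presumably parametrize them over the S/T/X/R/DR regressors and classifiers.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression


def check_clone_contract(estimator):
    """Assert that clone() and get_params() round-trip constructor arguments."""
    params = estimator.get_params(deep=False)
    cloned = clone(estimator)
    assert type(cloned) is type(estimator)
    assert cloned is not estimator
    assert set(cloned.get_params(deep=False)) == set(params)
    # Rebuilding from get_params must store each argument verbatim.
    rebuilt = type(estimator)(**params)
    for name, value in params.items():
        assert rebuilt.get_params(deep=False)[name] is value
    return cloned


def check_bit_identical(make_estimator, X, y):
    """Assert two fits from the same seeded configuration predict identically."""
    first = make_estimator().fit(X, y).predict(X)
    second = make_estimator().fit(X, y).predict(X)
    np.testing.assert_array_equal(first, second)


# Stand-ins only; the real tests would pass the meta-learners here.
stand_ins = [LinearRegression(), LogisticRegression(C=0.5)]
clones = [check_clone_contract(est) for est in stand_ins]

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 4))
y = X[:, 0] - X[:, 1] + rng.normal(scale=0.1, size=100)
check_bit_identical(
    lambda: RandomForestRegressor(n_estimators=5, random_state=42), X, y
)
```

The identity check inside `check_clone_contract` is the part that would catch an `__init__` that copies or resolves its arguments instead of storing them.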

Non-blocking (fine to fold into the same round)

  • XGBRRegressor get_params/set_params overrides are now redundant: xgb_kwargs is an explicit param, so BaseEstimator already surfaces and routes it. Drop both, and store self.xgb_kwargs = xgb_kwargs verbatim (coalesce with or {} in fit()) instead of ... if ... else {} in __init__.
  • xgb_kwargs is still a breaking call-signature change vs the old **kwargs (XGBRRegressor(max_depth=4) no longer works). Either tick "Breaking change" + add a one-line migration note, or accept **kwargs and fold into xgb_kwargs.
  • fit() still doesn't return self (raised last round). Given the "Pipeline / GridSearchCV out of the box" goal, return self from every fit() override (leave fit_predict returning predictions) — or confirm the deviation is intended.
  • DR vs X consistency: X added a self.propensity = {} sentinel so estimate_ate(pretrain=True) before fit() fails cleanly; DR dropped its self.propensity = None with no equivalent (→ AttributeError). Make them consistent.
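
A minimal sketch of the verbatim `xgb_kwargs` pattern together with a `fit()` that returns `self`, assuming a hypothetical `ToyXGBLearner` that does not wrap xgboost at all. It is meant to show that `BaseEstimator` already routes an explicit parameter through `get_params`/`set_params` with no overrides, and that resolving the default inside `fit()` leaves the stored value untouched.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class ToyXGBLearner(BaseEstimator):
    """Hypothetical stand-in; it does not wrap xgboost."""

    def __init__(self, xgb_kwargs=None):
        # Stored verbatim, including None.
        self.xgb_kwargs = xgb_kwargs

    def fit(self, X, y):
        # Resolve the default at fit time; the stored parameter is not modified.
        self.resolved_kwargs_ = dict(self.xgb_kwargs or {})
        self.mean_ = float(np.mean(y))
        return self  # enables fit(...).predict(...) chaining

    def predict(self, X):
        # Toy prediction: the training mean for every row.
        return np.full(len(X), self.mean_)


X = np.zeros((4, 2))
y = np.array([1.0, 2.0, 3.0, 4.0])

learner = ToyXGBLearner()
default_param = learner.get_params()["xgb_kwargs"]  # None, exactly as passed

# No get_params/set_params overrides are needed for an explicit parameter.
learner.set_params(xgb_kwargs={"max_depth": 4})
copy_of_learner = clone(learner)

# Chaining works only because fit() returns self.
predictions = learner.fit(X, y).predict(X)
```

If backward compatibility with `XGBRRegressor(max_depth=4)` matters more than strict convention, that would need a separate decision; this sketch covers only the explicit-parameter route.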


Development

Successfully merging this pull request may close these issues.

Make meta-learners scikit-learn-compliant estimators (inherit BaseEstimator)

2 participants