def aggregated_diatomics_results(results: dict[str, dict]) -> dict[str, float]:
    """
    Aggregate per-molecule diatomics results into summary applicability metrics.

    Leaderboard metric:
        combined_roughness: avg_roughness × (1 + avg(min_pos_err / r_range)).
            The multiplicative penalty couples smoothness with position
            accuracy; the normalised error is intended to lie in [0, 1] and
            introduces no free parameters.

    Stored diagnostic metrics (not used for ranking on their own):
        avg_roughness: arithmetic mean of the per-molecule "roughness" values.
        avg_min_position_error: arithmetic mean of the per-molecule
            "min_position_error" values.
        avg_rmse: arithmetic mean of the per-molecule "rmse" values.

    Args:
        results: mapping of molecule name to its per-molecule metric dict
            (keys read here: "roughness", "min_position_error", "r_range",
            "rmse"). ``None`` or an empty mapping is treated as "no data",
            and individual molecule entries may be ``None``.

    Returns:
        Dict with the four keys above. Every value is ``None`` when no
        molecule provides a usable roughness; the two diagnostic averages
        are individually ``None`` when no molecule provides that metric.
    """

    def _usable(value) -> bool:
        # None means "not computed"; NaN/inf means the evaluation failed.
        # Either would silently poison np.mean, so both are skipped.
        return value is not None and bool(np.isfinite(value))

    roughness_values = []
    normalized_pos_err_values = []
    position_error_values = []
    rmse_values = []

    # `results` comes straight from a stored record and may be missing.
    for mol_results in (results or {}).values():
        if mol_results is None:
            continue
        roughness = mol_results.get("roughness")
        if _usable(roughness):
            roughness_values.append(roughness)
        r_range = mol_results.get("r_range")
        min_pos_err = mol_results.get("min_position_error")
        # The position error is only counted when it can be normalised.
        if _usable(min_pos_err) and _usable(r_range) and r_range > 0:
            normalized_pos_err_values.append(min_pos_err / r_range)
            position_error_values.append(min_pos_err)
        rmse = mol_results.get("rmse")
        if _usable(rmse):
            rmse_values.append(rmse)

    if not roughness_values:
        return {
            "combined_roughness": None,
            "avg_roughness": None,
            "avg_min_position_error": None,
            "avg_rmse": None,
        }

    avg_roughness = float(np.mean(roughness_values))
    # Without any position data the penalty factor degenerates to 1.
    avg_norm_pos_err = (
        float(np.mean(normalized_pos_err_values)) if normalized_pos_err_values else 0.0
    )
    return {
        "combined_roughness": float(avg_roughness * (1 + avg_norm_pos_err)),
        "avg_roughness": avg_roughness,
        "avg_min_position_error": float(np.mean(position_error_values))
        if position_error_values
        else None,
        "avg_rmse": float(np.mean(rmse_values)) if rmse_values else None,
    }
def calculate_diatomics_roughness_results(self) -> dict[str, float]:
    """
    Build the per-model leaderboard score for the homonuclear diatomics task.

    The score is the "combined_roughness" entry of each model's aggregated
    diatomics metrics as returned by ``self.fetcher.fetch_diatomics_results()``
    (lower is better). Models whose aggregated metrics are ``None`` or whose
    "combined_roughness" is missing/``None`` are left out of the result.
    """
    scores: dict[str, float] = {}
    for model_name, summary in self.fetcher.fetch_diatomics_results().items():
        if summary is None:
            continue
        score = summary.get("combined_roughness")
        if score is not None:
            scores[model_name] = score
    return scores
def fetch_diatomics_results(self) -> dict[str, dict]:
    """
    Collect aggregated diatomics metrics for every leaderboard model.

    For each model in ``self.leaderboard_models`` the "homonuclear_diatomics"
    record is queried. A model is included, keyed by its pretty name, only
    when exactly one record exists; otherwise a warning is logged and the
    model is skipped.
    """
    aggregated: dict[str, dict] = {}
    for lam in self.leaderboard_models:
        records = CalculatorRecord.query(
            model_name=lam.model_name, task_name="homonuclear_diatomics"
        )
        if len(records) == 1:
            aggregated[lam.model_metadata.pretty_name] = (
                aggregated_diatomics_results(records[0].metrics)
            )
            continue
        # Zero or duplicated records are ambiguous, so the model is dropped.
        logging.warning(
            f"Expected one record for {lam.model_name} and homonuclear_diatomics, "
            f"but got {len(records)}"
        )
    return aggregated
null diff --git a/lambench/tasks/calculator/diatomics/diatomics.json b/lambench/tasks/calculator/diatomics/diatomics.json new file mode 100644 index 0000000..4e1ccaa --- /dev/null +++ b/lambench/tasks/calculator/diatomics/diatomics.json @@ -0,0 +1,1698 @@ +[ + { + "name": "FF", + "method": "PBE", + "R": [ + 0.57, + 0.77821052, + 0.98642106, + 1.19463158, + 1.4028421, + 1.61105264, + 1.81926316, + 2.02747368, + 2.23568422, + 2.44389474, + 2.65210526, + 2.86031578, + 3.06852632, + 3.27673684, + 3.48494736, + 3.6931579, + 3.90136842, + 4.10957894, + 4.31778948, + 4.526 + ], + "E": [ + 154.88514885, + 36.96380847, + 6.05981993, + -2.26124562, + -3.77595519, + -3.35903872, + -2.49608131, + -1.67387083, + -2.0229941, + -1.4131129, + -1.56698144, + -1.42110229, + -1.31664677, + -1.52899743, + -1.5303025, + -1.53619596, + -1.53555029, + -1.53648081, + -1.53125219, + -1.53143593 + ], + "F": [ + 1058.39660978, + 263.42078266, + 73.58367893, + 17.18716842, + 0.76238048, + -3.69242918, + -4.24336118, + -3.35823412, + -1.35965055, + 0.54258923, + -0.82109989, + -0.58326945, + -0.39742174, + 0.00684241, + -0.05691121, + 0.00167606, + 0.0164027, + -0.00693765, + -0.00301217, + 0.0 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.000016, + 0.777924, + 0.708964, + 0.7569, + 0.748225, + 0.741321, + 0.712336, + 0.712336, + 0.712336, + 0.712336, + 0.712336, + 0.712336, + 0.712336 + ] + }, + { + "name": "AlAl", + "method": "PBE", + "R": [ + 1.21, + 1.41589286, + 1.62178572, + 1.82767858, + 2.03357142, + 2.23946428, + 2.44535714, + 2.65125, + 2.85714286, + 3.06303572, + 3.26892858, + 3.47482142, + 3.68071428, + 3.88660714, + 4.0925, + 4.29839286, + 4.50428572, + 4.71017858, + 4.91607142, + 5.12196428, + 5.32785714, + 5.53375, + 5.73964286, + 5.94553572, + 6.15142858, + 6.35732142, + 6.56321428, + 6.76910714, + 6.975 + ], + "E": [ + 21.69553068, + 11.14410695, + 4.46990509, + 0.69626246, + -1.21676381, + -2.04208305, + -2.27886427, + -2.22546192, + -2.23126744, + 
-2.13493129, + -1.99294153, + -1.83531198, + -1.67788789, + -1.5284843, + -1.39083808, + -1.26651214, + -1.1571301, + -1.05868295, + -0.97419251, + -0.90180437, + -0.83982298, + -0.78814601, + -0.74514801, + -0.71051327, + -0.68320322, + -0.66276761, + -0.64851475, + -0.64011496, + -0.57647315 + ], + "F": [ + 60.57266875, + 41.354763, + 24.44625509, + 13.06806215, + 6.14128837, + 2.26162033, + 0.25579529, + 0.41602807, + -0.27806619, + -0.61013121, + -0.74631893, + -0.7757976, + -0.7518143, + -0.70137733, + -0.64001526, + -0.57416177, + -0.48706032, + -0.44642018, + -0.38713862, + -0.33408605, + -0.27793413, + -0.22865668, + -0.1892725, + -0.15183375, + -0.11960044, + -0.08665681, + -0.0572195, + -0.02949642, + 0.0 + ], + "S^2": [ + 0.106276, + 0.098596, + 0.058564, + 0.038809, + 0.0289, + 0.023409, + 0.020449, + 0.025281, + 0.022801, + 0.020736, + 0.019321, + 0.018225, + 0.017424, + 0.0169, + 0.016641, + 0.016384, + 0.016384, + 0.016129, + 0.016384, + 0.016384, + 0.016641, + 0.0169, + 0.016641, + 0.0169, + 0.0169, + 0.017161, + 0.017161, + 0.017161, + 0.017161 + ] + }, + { + "name": "SiSi", + "method": "PBE", + "R": [ + 1.11, + 1.31282142, + 1.51564286, + 1.71846428, + 1.92128572, + 2.12410714, + 2.32692858, + 2.52975, + 2.73257142, + 2.93539286, + 3.13821428, + 3.34103572, + 3.54385714, + 3.74667858, + 3.9495, + 4.15232142, + 4.35514286, + 4.55796428, + 4.76078572, + 4.96360714, + 5.16642858, + 5.36925, + 5.57207142, + 5.77489286, + 5.97771428, + 6.18053572, + 6.38335714, + 6.58617858, + 6.789 + ], + "E": [ + 33.12026864, + 13.25774131, + 2.00110273, + -2.49417775, + -4.44198323, + -5.1071608, + -5.24703455, + -4.98742036, + -4.54345777, + -4.04252783, + -3.54839397, + -3.0996546, + -2.68388013, + -2.32916398, + -2.38126229, + -2.24245054, + -2.13768097, + -2.03626856, + -1.9573986, + -1.88489854, + -1.84197426, + -1.80103498, + -1.78620606, + -1.74383145, + -1.68289036, + -1.7109778, + -1.70150633, + -1.69604422, + -0.87555372 + ], + "F": [ + 124.45839432, + 
68.53195105, + 33.04559539, + 13.37175758, + 6.21281285, + 0.81988881, + -0.47728931, + -1.88020767, + -2.39683204, + -2.48400236, + -2.35926277, + -2.20970491, + -1.87359833, + -1.61397952, + -0.73841286, + -0.62991633, + -0.52491607, + -0.43387721, + -0.35474778, + -0.31693349, + -0.22822001, + -0.1881757, + -0.13374293, + -0.11164342, + -0.04003388, + -0.0597873, + -0.03960887, + -0.02345417, + 0.0 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.130321, + 0.114921, + 0.069696, + 0.067081, + 0.066049, + 0.067081, + 0.068644, + 0.071289, + 0.072361, + 0.074529, + 0.310249, + 0.311364, + 0.314721, + 0.3136, + 0.315844, + 0.315844, + 0.318096, + 0.319225, + 0.319225, + 0.321489, + 0.326041, + 0.321489, + 0.322624, + 0.322624, + 0.341056 + ] + }, + { + "name": "SS", + "method": "PBE", + "R": [ + 1.05, + 1.250375, + 1.45075, + 1.651125, + 1.8515, + 2.051875, + 2.25225, + 2.452625, + 2.653, + 2.853375, + 3.05375, + 3.254125, + 3.4545, + 3.654875, + 3.85525, + 4.055625, + 4.256, + 4.456375, + 4.65675, + 4.857125, + 5.0575, + 5.257875, + 5.45825, + 5.658625, + 5.859 + ], + "E": [ + 56.6632406, + 17.59053523, + -0.48155823, + -5.78096745, + -7.12850504, + -6.90163556, + -6.10408707, + -5.1706445, + -4.26951883, + -3.48454744, + -2.82856388, + -2.61429184, + -2.45644616, + -2.31281962, + -2.30469358, + -2.21163025, + -2.13131514, + -2.07984364, + -2.03543091, + -2.00166027, + -2.14408425, + -2.14443081, + -2.14453532, + -2.14455228, + -1.89473979 + ], + "F": [ + 317.15397377, + 120.14746006, + 43.86271234, + 13.66400583, + 1.60573497, + -3.13060913, + -4.51443905, + -4.61397552, + -4.20957069, + -3.59285054, + -2.95131997, + -0.82030173, + -0.74309327, + -0.58769699, + -0.52709542, + -0.41845518, + 0.03331956, + -0.25494417, + -0.18020652, + -0.14812717, + 0.00442569, + 0.00474803, + 0.00328453, + 0.002413, + 0.0 + ], + "S^2": [ + 0.133225, + 0.221857, + 0.252004, + 0.300304, + 0.334084, + 0.351649, + 0.358801, + 0.358801, + 0.356409, + 0.352836, + 0.346921, + 1.340964, + 
1.334025, + 1.279161, + 1.311025, + 1.301881, + 1.238769, + 1.288225, + 1.281424, + 1.2769, + 1.245456, + 1.245456, + 1.245456, + 1.245456, + 1.243225 + ] + }, + { + "name": "LiLi", + "method": "PBE", + "R": [ + 1.28, + 1.48353846, + 1.68707692, + 1.89061538, + 2.09415384, + 2.2976923, + 2.50123076, + 2.70476924, + 2.9083077, + 3.11184616, + 3.31538462, + 3.51892308, + 3.72246154, + 3.926, + 4.12953846, + 4.33307692, + 4.53661538, + 4.74015384, + 4.9436923, + 5.14723076, + 5.35076924, + 5.5543077, + 5.75784616, + 5.96138462, + 6.16492308, + 6.36846154, + 6.572 + ], + "E": [ + 2.57567241, + 1.24409369, + 0.21632582, + -0.51468971, + -0.98958252, + -1.26786043, + -1.40617495, + -1.44929497, + -1.42991395, + -1.37109178, + -1.28881753, + -1.19406826, + -1.09431241, + -0.99457511, + -0.89818493, + -0.80729571, + -0.72325252, + -0.64684731, + -0.57850035, + -0.51839306, + -0.46655827, + -0.42294349, + -0.38745548, + -0.35998967, + -0.3404493, + -0.32875772, + -0.32485202 + ], + "F": [ + 7.27327891, + 5.80165488, + 4.30052111, + 2.91692256, + 1.79992067, + 0.98140216, + 0.41437564, + 0.03606571, + -0.20753374, + -0.35715219, + -0.44197593, + -0.4825392, + -0.49314499, + -0.4837446, + -0.4612383, + -0.430389, + -0.39446085, + -0.35568997, + -0.31551901, + -0.2748906, + -0.23433785, + -0.19417691, + -0.15451309, + -0.11535542, + -0.0766322, + -0.03822277, + 0.0 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + { + "name": "BB", + "method": "PBE", + "R": [ + 0.84, + 1.04324, + 1.24648, + 1.44972, + 1.65296, + 1.8562, + 2.05944, + 2.26268, + 2.46592, + 2.66916, + 2.8724, + 3.07564, + 3.27888, + 3.48212, + 3.68536, + 3.8886, + 4.09184, + 4.29508, + 4.49832, + 4.70156, + 4.9048, + 5.10804, + 5.31128, + 5.51452, + 5.71776, + 5.921 + ], + "E": [ + 17.92459699, + 5.10213103, + -1.46814943, + -3.82963048, + 
-4.20597957, + -3.81948386, + -3.2179762, + -2.91230378, + -2.52699665, + -2.1845205, + -1.89245253, + -1.64778795, + -1.44434982, + -1.27587501, + -1.136815, + -1.02245563, + -0.92891465, + -0.85280637, + -0.79117735, + -0.74184674, + -0.7032786, + -0.88208624, + -0.88206431, + -0.88198997, + -0.88194666, + -0.88193624 + ], + "F": [ + 93.92111239, + 37.32727365, + 20.01361834, + 5.30975879, + -0.67289735, + -2.71375276, + -3.02695686, + -1.96187803, + -1.80382689, + -1.5589735, + -1.31070063, + -1.0963932, + -0.91531766, + -0.75700482, + -0.62128469, + -0.50944978, + -0.41521949, + -0.33543178, + -0.2671784, + -0.20897224, + -0.16575556, + 0.0043833, + 0.00015651, + 0.00015074, + 0.00026771, + 0.0 + ], + "S^2": [ + 0.0, + 0.0, + 0.292681, + 0.219024, + 0.184041, + 0.166464, + 0.158404, + 0.001369, + 0.000784, + 0.000361, + 0.000121, + 0.000016, + 0.000001, + 0.000025, + 0.000064, + 0.0001, + 0.000121, + 0.000121, + 0.000121, + 0.000121, + 0.000121, + 0.173889, + 0.173889, + 0.173889, + 0.173889, + 0.173889 + ] + }, + { + "name": "HH", + "method": "PBE", + "R": [ + 0.31, + 0.51058824, + 0.71117648, + 0.9117647, + 1.11235294, + 1.31294118, + 1.51352942, + 1.71411764, + 1.91470588, + 2.11529412, + 2.31588236, + 2.51647058, + 2.71705882, + 2.91764706, + 3.1182353, + 3.31882352, + 3.51941176, + 3.72 + ], + "E": [ + 6.16625385, + -4.94345379, + -6.74744559, + -6.44406252, + -5.60031954, + -4.66963245, + -3.77130272, + -3.08785779, + -2.57081413, + -2.26444356, + -2.35873978, + -2.28288566, + -2.1319407, + -2.25239887, + -2.20410061, + -2.24005623, + -2.2376133, + -2.22853887 + ], + "F": [ + 114.89461228, + 20.95182509, + 1.50958795, + -3.4657748, + -4.62739978, + -4.53944251, + -4.05228775, + -2.70741756, + -1.12798246, + -0.34896878, + -0.39594383, + -0.14748482, + 0.22604304, + -0.0576391, + -0.03009607, + -0.01355068, + -0.01382041, + 0.01622277 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.000001, + 0.043681, + 0.1132525, + 0.1365305, + 0.131769, + 
0.139876, + 0.1425125, + 0.142129, + 0.1444, + 0.142884, + 0.143641, + 0.1440205 + ] + }, + { + "name": "CC", + "method": "PBE", + "R": [ + 0.76, + 0.96552174, + 1.17104348, + 1.37656522, + 1.58208696, + 1.7876087, + 1.99313044, + 2.19865218, + 2.40417392, + 2.60969566, + 2.8152174, + 3.02073914, + 3.22626086, + 3.4317826, + 3.63730434, + 3.84282608, + 4.04834782, + 4.25386956, + 4.4593913, + 4.66491304, + 4.87043478, + 5.07595652, + 5.28147826, + 5.487 + ], + "E": [ + 18.39398234, + -3.70916039, + -8.68022979, + -9.32511298, + -8.156171, + -7.21772019, + -6.06685, + -5.04620074, + -4.18746572, + -3.85125001, + -3.55457882, + -3.34001635, + -3.1528414, + -3.00701982, + -2.89442774, + -2.80780126, + -2.76357043, + -2.76487168, + -2.76574912, + -2.76638309, + -2.76683959, + -2.76715143, + -2.76733522, + -2.76737918 + ], + "F": [ + 186.81226636, + 50.62478918, + 8.27496101, + -3.03514161, + -7.19269976, + -5.75305105, + -5.34518048, + -4.58538267, + -3.79960972, + -1.60255122, + -1.26671198, + -1.02168691, + -0.80995213, + -0.62902694, + -0.48236673, + -0.35239001, + -0.00571169, + 0.00599138, + 0.00439536, + 0.00328237, + 0.00240676, + 0.00165546, + 0.00095136, + 0.0 + ], + "S^2": [ + 0.429025, + 0.391876, + 0.0, + 0.002916, + 0.002704, + 0.2116, + 0.210681, + 0.214369, + 0.219961, + 0.944784, + 0.942841, + 0.942841, + 0.946729, + 0.950625, + 0.954529, + 0.956484, + 0.978121, + 0.978121, + 0.978121, + 0.978121, + 0.978121, + 0.978121, + 0.978121, + 0.978121 + ] + }, + { + "name": "NN", + "method": "PBE", + "R": [ + 0.71, + 0.91163636, + 1.11327272, + 1.3149091, + 1.51654546, + 1.71818182, + 1.91981818, + 2.12145454, + 2.3230909, + 2.52472728, + 2.72636364, + 2.928, + 3.12963636, + 3.33127272, + 3.5329091, + 3.73454546, + 3.93618182, + 4.13781818, + 4.33945454, + 4.5410909, + 4.74272728, + 4.94436364, + 5.146 + ], + "E": [ + 18.96597146, + -11.59519341, + -16.67070934, + -14.79548657, + -11.77597015, + -9.4021712, + -8.01090622, + -7.19560518, + -6.75533958, + 
-6.52741805, + -6.40681504, + -6.3415502, + -6.30497872, + -6.28388808, + -6.27132438, + -6.26394025, + -6.25890354, + -6.25600206, + -6.25400689, + -6.25284465, + -6.25203368, + -6.25154061, + -6.25128731 + ], + "F": [ + 279.97158378, + 64.2096515, + -0.37033147, + -14.3313419, + -14.63134059, + -8.81855796, + -5.3786076, + -2.98461774, + -1.54219118, + -0.82495166, + -0.43246396, + -0.23541783, + -0.12631534, + -0.09157462, + -0.03640523, + -0.02644672, + -0.00821568, + 0.0001443, + -0.006985, + -0.0119869, + -0.00233793, + 0.0053014, + -0.00138884 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.001, + 0.378, + 1.054, + 1.374, + 1.618, + 1.768, + 1.841, + 1.883, + 1.903, + 1.913, + 1.919, + 1.922, + 1.922, + 1.924, + 1.925, + 1.925, + 1.925, + 1.926, + 1.926, + 1.926 + ] + }, + { + "name": "OO", + "method": "PBE", + "R": [ + 0.66, + 0.87, + 1.08, + 1.29, + 1.5, + 1.71, + 1.92, + 2.13, + 2.34, + 2.55, + 2.76, + 2.97, + 3.18, + 3.39, + 3.6, + 3.81, + 4.02, + 4.23, + 4.44, + 4.65 + ], + "E": [ + 56.92344016, + 18.14099414, + -7.50638681, + -8.6733768, + -7.3668323, + -5.6850076, + -4.75241627, + -4.26377908, + -4.03256358, + -3.92315738, + -3.87129054, + -3.84450264, + -3.83045256, + -3.82286857, + -3.8186312, + -3.814846, + -3.81495192, + -3.81282392, + -3.81248171, + -3.81668996 + ], + "F": [ + 467.55998715, + 107.21200237, + 19.8404717, + -3.0884703, + -7.82392488, + -7.40729454, + -3.22952379, + -1.57393269, + -0.73835829, + -0.35547724, + -0.16316444, + -0.09427935, + -0.03634869, + -0.02531456, + -0.01400939, + -0.00676212, + -0.00141637, + -0.00858585, + 0.00012734, + 0.01162883 + ], + "S^2": [ + 0.011, + 0.003, + 0.0, + 0.0, + 0.0, + 0.318, + 1.14, + 1.36, + 1.458, + 1.502, + 1.522, + 1.531, + 1.536, + 1.538, + 1.539, + 1.54, + 1.54, + 1.54, + 1.54, + 1.542 + ] + }, + { + "name": "BeBe", + "method": "PBE", + "R": [ + 0.96, + 1.16712, + 1.37424, + 1.58136, + 1.78848, + 1.9956, + 2.20272, + 2.40984, + 2.61696, + 2.82408, + 3.0312, + 3.23832, + 3.44544, + 3.65256, + 
3.85968, + 4.0668, + 4.27392, + 4.48104, + 4.68816, + 4.89528, + 5.1024, + 5.30952, + 5.51664, + 5.72376, + 5.93088, + 6.138 + ], + "E": [ + 14.70966247, + 7.7203951, + 4.01751326, + 1.90205899, + 0.58532173, + -0.14064997, + -0.42898422, + -0.5009807, + -0.47459607, + -0.41112364, + -0.34043148, + -0.27646552, + -0.22396243, + -0.18341418, + -0.1532106, + -0.13134141, + -0.11381847, + -0.10487865, + -0.09719365, + -0.09180388, + -0.08797139, + -0.08521545, + -0.08314996, + -0.0816177, + -0.08044389, + -0.07946674 + ], + "F": [ + 43.30316358, + 24.93765328, + 12.82608435, + 7.7416394, + 4.4564196, + 2.16076775, + 0.73927492, + 0.04438674, + -0.25193742, + -0.34055459, + -0.33102582, + -0.28184451, + -0.22375087, + -0.16883101, + -0.12357826, + -0.08836932, + -0.06204519, + -0.04357134, + -0.03038775, + -0.02130521, + -0.01519562, + -0.01102413, + -0.00799724, + -0.00593137, + -0.00470464, + -0.00387667 + ], + "S^2": [ + 0.0075, + 0.0, + 0.0, + 0.004, + 0.0035, + 0.009, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + { + "name": "NaNa", + "method": "PBE", + "R": [ + 1.66, + 1.863, + 2.066, + 2.269, + 2.472, + 2.675, + 2.878, + 3.081, + 3.284, + 3.487, + 3.69, + 3.893, + 4.096, + 4.299, + 4.502, + 4.705, + 4.908, + 5.111, + 5.314, + 5.517, + 5.72, + 5.923, + 6.126, + 6.329, + 6.532, + 6.735, + 6.938, + 7.141, + 7.344, + 7.547, + 7.75 + ], + "E": [ + 0.85746712, + 0.25630231, + -0.25912653, + -0.65257794, + -0.92396474, + -1.09062756, + -1.17494915, + -1.19872526, + -1.17963041, + -1.13190933, + -1.06677752, + -0.98672863, + -0.90696988, + -0.82344115, + -0.78009135, + -0.71689258, + -0.66244914, + -0.61657388, + -0.57886378, + -0.54824949, + -0.52452931, + -0.50432721, + -0.49102115, + -0.48159264, + -0.47410483, + -0.46891745, + -0.4651316, + -0.46249019, + -0.45791341, + -0.45944426, + -0.45850815 + ], + "F": [ + 3.04952992, + 2.80597705, + 2.25438401, + 
1.63198516, + 1.06266211, + 0.60021239, + 0.25095605, + -0.00209925, + -0.17446922, + -0.28696104, + -0.35391261, + -0.3856012, + -0.39582727, + -0.39370212, + -0.33055081, + -0.28823158, + -0.24463737, + -0.20340733, + -0.164492, + -0.13764615, + -0.10500679, + -0.08227191, + -0.0578671, + -0.04225327, + -0.02941287, + -0.01878402, + -0.01161329, + -0.00875706, + -0.00432825, + -0.00638915, + -0.00338371 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.01, + 0.014, + 0.016, + 0.057, + 0.069, + 0.079, + 0.087, + 0.094, + 0.099, + 0.103, + 0.106, + 0.108, + 0.11, + 0.111, + 0.112, + 0.112, + 0.112, + 0.113, + 0.113, + 0.113 + ] + }, + { + "name": "AsAs", + "method": "PBE", + "R": [ + 1.19, + 1.39165218, + 1.59330434, + 1.79495652, + 1.9966087, + 2.19826086, + 2.39991304, + 2.60156522, + 2.8032174, + 3.00486956, + 3.20652174, + 3.40817392, + 3.60982608, + 3.81147826, + 4.01313044, + 4.2147826, + 4.41643478, + 4.61808696, + 4.81973914, + 5.0213913, + 5.22304348, + 5.42469566, + 5.62634782, + 5.828 + ], + "E": [ + 46.90130549, + 15.27079259, + 0.50425907, + -5.44350514, + -7.39262292, + -7.53727007, + -6.90760872, + -6.08840425, + -5.39953158, + -4.84332073, + -4.40800292, + -4.08327657, + -3.85371493, + -3.69782731, + -3.59449622, + -3.52708512, + -3.48325587, + -3.45489444, + -3.43655039, + -3.42467479, + -3.41696805, + -3.41169726, + -3.40846369, + -3.40669411 + ], + "F": [ + 211.79539103, + 107.70688613, + 45.82332632, + 17.017641, + 4.00045641, + -1.75461446, + -4.02922825, + -3.76126935, + -3.07956524, + -2.44689545, + -1.86444679, + -1.34402966, + -0.92160205, + -0.61006728, + -0.39469772, + -0.26509602, + -0.1708267, + -0.10898799, + -0.0697068, + -0.04437456, + -0.02644437, + -0.02413272, + -0.0147068, + -0.00846194 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.101, + 0.544, + 0.754, + 0.884, + 0.988, + 1.074, + 1.137, + 1.18, + 1.207, + 1.225, + 1.236, + 1.243, + 1.247, + 1.25, + 1.252, + 1.252, 
+ 1.253, + 1.254 + ] + }, + { + "name": "PP", + "method": "PBE", + "R": [ + 1.07, + 1.27083334, + 1.47166666, + 1.6725, + 1.87333334, + 2.07416666, + 2.275, + 2.47583334, + 2.67666666, + 2.8775, + 3.07833334, + 3.27916666, + 3.48, + 3.68083334, + 3.88166666, + 4.0825, + 4.28333334, + 4.48416666, + 4.685, + 4.88583334, + 5.08666666, + 5.2875, + 5.48833334, + 5.68916666, + 5.89 + ], + "E": [ + 45.77685669, + 11.8285476, + -2.53221431, + -7.75867909, + -9.04116862, + -8.65052068, + -7.62989401, + -6.60695669, + -5.80059777, + -5.17610919, + -4.70799017, + -4.37835492, + -4.15974948, + -4.0192357, + -3.93007014, + -3.8737127, + -3.83806226, + -3.81542695, + -3.80099527, + -3.79170717, + -3.78570181, + -3.78178332, + -3.77923188, + -3.7775969, + -3.77652785 + ], + "F": [ + 236.22493536, + 109.58056307, + 42.53599618, + 13.56016691, + 1.06192921, + -4.0452132, + -5.49770567, + -4.51230119, + -3.53193444, + -2.70588937, + -1.98952767, + -1.32745987, + -0.86758504, + -0.55867817, + -0.35600692, + -0.22737413, + -0.14623663, + -0.09527736, + -0.04018311, + -0.02066007, + -0.00920534, + 0.00666714, + 0.01311918, + 0.01594326, + -0.00486346 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.002, + 0.22, + 0.585, + 0.758, + 0.876, + 0.978, + 1.063, + 1.121, + 1.157, + 1.178, + 1.192, + 1.199, + 1.204, + 1.207, + 1.209, + 1.21, + 1.211, + 1.211, + 1.211, + 1.21 + ] + }, + { + "name": "MgMg", + "method": "PBE", + "R": [ + 1.41, + 1.61551612, + 1.82103226, + 2.02654838, + 2.23206452, + 2.43758064, + 2.64309678, + 2.8486129, + 3.05412904, + 3.25964516, + 3.4651613, + 3.67067742, + 3.87619354, + 4.08170968, + 4.2872258, + 4.49274194, + 4.69825806, + 4.9037742, + 5.10929032, + 5.31480646, + 5.52032258, + 5.7258387, + 5.93135484, + 6.13687096, + 6.3423871, + 6.54790322, + 6.75341936, + 6.95893548, + 7.16445162, + 7.36996774, + 7.57548388, + 7.781 + ], + "E": [ + 11.6829813, + 7.96232039, + 5.06644745, + 2.98551402, + 1.64462471, + 0.82447508, + 0.34479291, + 0.07841884, + 
-0.0588093, + -0.12019509, + -0.13817893, + -0.1340168, + -0.11895453, + -0.1001041, + -0.08129658, + -0.06452872, + -0.04983171, + -0.03876635, + -0.02977393, + -0.02289903, + -0.01771109, + -0.01380399, + -0.01086075, + -0.00846704, + -0.00680896, + -0.00565185, + -0.00460832, + -0.00367814, + -0.00313205, + -0.00266563, + -0.00226028, + -0.00196101 + ], + "F": [ + 20.41658485, + 15.88300562, + 12.236662, + 8.15138628, + 5.09918246, + 3.047378, + 1.74122638, + 0.92869958, + 0.44907609, + 0.17271731, + 0.02125798, + -0.05478929, + -0.08664891, + -0.09362164, + -0.08767509, + -0.07586521, + -0.06248754, + -0.04950229, + -0.03832268, + -0.0291802, + -0.02206609, + -0.0166681, + -0.01267574, + -0.00962994, + -0.00765317, + -0.0058975, + -0.00477701, + -0.00389654, + -0.00320456, + -0.002442, + -0.00197497, + -0.00134118 + ], + "S^2": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + } +] diff --git a/lambench/tasks/calculator/diatomics/diatomics.py b/lambench/tasks/calculator/diatomics/diatomics.py new file mode 100644 index 0000000..efe5f8f --- /dev/null +++ b/lambench/tasks/calculator/diatomics/diatomics.py @@ -0,0 +1,190 @@ +""" +The reference data (diatomics.json) is derived from the MLIP Arena project: + + MLIP Arena — Benchmark machine learning interatomic potential at scale + Yuan Chiang, Lawrence Berkeley National Laboratory + https://github.com/atomind-ai/mlip-arena + +Licensed under the Apache License, Version 2.0 (the "License"); you may not +use this file except in compliance with the License. 
You may obtain a copy of +the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. + +---- + +Homonuclear diatomics dissociation curve roughness task (Applicability). + +This task evaluates whether a model produces physically smooth and topologically +correct potential energy surfaces along simple bond-stretching coordinates. + +Leaderboard metric (Applicability-Roughness ↓): + combined_roughness: avg_roughness × (1 + mean(min_position_error / r_range)), where + avg_roughness is the arithmetic mean of per-molecule RMSE( d²(E_model-E_DFT)/dr² ) in eV/Ų. + Penalises high-frequency oscillations and misplaced equilibrium minima relative to the DFT curve. + +Stored diagnostic metrics (not ranked on their own, available for analysis): + avg_min_position_error: mean absolute deviation of the predicted equilibrium + bond length from the DFT reference (Å). If the model + fails to produce exactly one minimum, the molecule's + scan range is used as a data-driven penalty. + avg_rmse: mean energy RMSE over molecules (eV). + +Reference data: lambench/tasks/calculator/diatomics/diatomics.json + List of dicts with keys: + name – molecule label, e.g. 
_LABEL_FILE = Path(__file__).parent / "diatomics.json"
_MIN_PROMINENCE = 0.01  # eV — genuine well depth threshold for find_peaks


def _element_from_name(mol_name: str) -> str:
    """Return the first half of the label: 'AlAl' → 'Al', 'HH' → 'H'."""
    return mol_name[: len(mol_name) // 2]


def _minima_positions(energies: np.ndarray, bond_lengths: np.ndarray) -> list[float]:
    """
    Return the bond lengths at local energy minima that pass the prominence gate.

    Minima are located as peaks of ``-energies`` with a prominence of at least
    ``_MIN_PROMINENCE``. NaN energies are dropped first and the peak indices
    are mapped back through the same mask, so the returned values are entries
    of ``bond_lengths``. Fewer than three valid points yields an empty list.
    """
    valid_mask = ~np.isnan(energies)
    if valid_mask.sum() < 3:
        return []
    idx_valid, _ = find_peaks(-energies[valid_mask], prominence=_MIN_PROMINENCE)
    return bond_lengths[valid_mask][idx_valid].tolist()


def _min_position_error(
    pos_dft: list[float], pos_model: list[float], r_range: float
) -> float:
    """
    Position error of the equilibrium minimum.

    Args:
        pos_dft: minima positions of the reference curve; only the first
            entry is used as the reference equilibrium position.
        pos_model: minima positions of the model curve.
        r_range: scan range of the molecule, returned as the penalty value.

    Returns:
        ``r_range`` when the model does not have exactly one minimum (a
        data-driven penalty with no free parameter), otherwise the absolute
        difference between the model minimum and ``pos_dft[0]``.

    Raises:
        ValueError: if the model has exactly one minimum but ``pos_dft`` is
            empty, i.e. the reference curve has no detectable minimum.

    NOTE(review): an earlier comment assumed every reference curve has exactly
    one minimum. Noisy dissociation tails in the bundled data (e.g. the FF
    entry) look like they can produce extra minima above the prominence gate,
    so ``pos_dft`` may hold more than one entry — confirm.
    """
    if len(pos_model) != 1:
        return r_range
    if not pos_dft:
        # Previously surfaced as a bare IndexError on pos_dft[0].
        raise ValueError(
            "reference curve has no detectable minimum; "
            "cannot compute the minimum position error"
        )
    return abs(pos_model[0] - pos_dft[0])
+ + δ²_i = (Δres_{i+1} - Δres_i) / Δr² ≈ d²(E_model - E_DFT)/dr² + + Δr² normalisation makes the metric comparable across molecules with + different grid spacings. Returns None when fewer than 3 valid points remain. + """ + delta2 = np.diff(residuals, n=2) + valid = delta2[~np.isnan(delta2)] + if len(valid) == 0: + return None + return float(np.sqrt(np.mean((valid / dr**2) ** 2))) + + +def _predict_energies( + model: ASEModel, element: str, bond_lengths: np.ndarray +) -> np.ndarray: + """Evaluate model energy for a homonuclear dimer at each bond length (eV).""" + calc = model.calc + cell = 30.0 + energies = [] + for r in bond_lengths: + atoms = Atoms( + symbols=[element, element], + positions=[[0.0, 0.0, 0.0], [r, 0.0, 0.0]], + cell=[cell, cell, cell], + pbc=True, + ) + atoms.calc = calc + try: + e = atoms.get_potential_energy() + if not np.isfinite(e): + raise ValueError("non-finite energy") + except Exception as exc: + logging.warning(f"{element}2 @ r={r:.3f} Å failed: {exc}") + e = np.nan + energies.append(e) + return np.array(energies) + + +def run_inference(model: ASEModel, test_data: Path | None = None) -> dict[str, dict]: + """ + Evaluate model PES roughness on homonuclear diatomic dissociation curves. + + Args: + model: loaded ASEModel + test_data: directory containing diatomics.json, or None to use the + bundled reference file next to this module. + + Returns: + Per-molecule dict, e.g.:: + + { + "HH": {"roughness": 0.012, "min_position_error": 0.02, + "n_minima_model": 1, "rmse": 0.05}, + "NN": {...}, + ... 
+ } + """ + label_path = _LABEL_FILE if test_data is None else test_data / "diatomics.json" + + with open(label_path) as fh: + reference_data: list[dict] = json.load(fh) + + results: dict[str, dict] = {} + + for entry in reference_data: + mol_name: str = entry["name"] + bond_lengths = np.array(entry["R"]) + dft_energies = np.array(entry["E"]) + + element = _element_from_name(mol_name) + dr = float(np.mean(np.diff(bond_lengths))) + + r_range = float(bond_lengths[-1] - bond_lengths[0]) + pos_dft = _minima_positions(dft_energies, bond_lengths) + model_energies = _predict_energies(model, element, bond_lengths) + residuals = model_energies - dft_energies + pos_model = _minima_positions(model_energies, bond_lengths) + + mol_result = { + "roughness": _compute_roughness(residuals, dr), + "min_position_error": _min_position_error(pos_dft, pos_model, r_range), + "n_minima_model": len(pos_model), + "rmse": float(np.sqrt(np.nanmean(residuals**2))), + "r_range": r_range, + } + results[mol_name] = mol_result + logging.info(f"{mol_name}: {mol_result}") + + return results diff --git a/pyproject.toml b/pyproject.toml index 7833081..4d026cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,4 +59,4 @@ visualization = "lambench.metrics.visualization:main" include = ["lambench*"] [tool.setuptools.package-data] -"*" = ["*.yml"] +"*" = ["*.yml", "*.json"] diff --git a/tests/metrics/test_metrics_calculator.py b/tests/metrics/test_metrics_calculator.py index 41138e9..1288301 100644 --- a/tests/metrics/test_metrics_calculator.py +++ b/tests/metrics/test_metrics_calculator.py @@ -90,12 +90,28 @@ def test_summarize_final_rankings(metrics_calculator): metrics_calculator.calculate_stability_results = MagicMock( return_value={"model1": 0.2, "model2": 0.5} ) + metrics_calculator.calculate_diatomics_roughness_results = MagicMock( + return_value={"model1": 0.05, "model2": 0.03} + ) _, result = metrics_calculator.summarize_final_rankings() assert result is not None assert 
def test_calculate_diatomics_roughness_results(metrics_calculator, mock_raw_results):
    """Only models whose combined_roughness is not None receive a score."""
    expected_scores = {"model1": 0.05, "model2": 0.03}
    raw_results = {
        "model1": {"combined_roughness": 0.05, "avg_roughness": 0.04},
        "model2": {"combined_roughness": 0.03, "avg_roughness": 0.025},
        "model3": None,
        "model4": {"combined_roughness": None, "avg_roughness": 0.06},
    }
    mock_raw_results.fetch_diatomics_results.return_value = raw_results

    result = metrics_calculator.calculate_diatomics_roughness_results()

    assert set(result.keys()) == set(expected_scores)
    for model_name, score in expected_scores.items():
        np.testing.assert_almost_equal(result[model_name], score)
# ---------------------------------------------------------------------------
# _compute_roughness
# ---------------------------------------------------------------------------


def test_compute_roughness_flat_residuals():
    """An identically-zero residual curve has zero roughness."""
    flat = np.zeros(10)
    roughness = _compute_roughness(flat, dr=0.1)
    assert roughness == pytest.approx(0.0)


def test_compute_roughness_oscillating():
    """A zig-zag residual curve yields a strictly positive roughness."""
    zigzag = np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0], dtype=float)
    value = _compute_roughness(zigzag, dr=0.2)
    assert value is not None
    assert value > 0


def test_compute_roughness_dr_scaling():
    """Roughness goes as 1/dr²: half the spacing gives four times the value."""
    zigzag = np.array([0.0, 0.1, 0.0, 0.1, 0.0], dtype=float)
    coarse = _compute_roughness(zigzag, dr=0.2)
    fine = _compute_roughness(zigzag, dr=0.1)
    assert fine == pytest.approx(4 * coarse, rel=1e-6)


def test_compute_roughness_too_few_valid_points():
    """A single valid point between NaNs cannot form a second difference → None."""
    sparse = np.array([np.nan, 0.5, np.nan])
    result = _compute_roughness(sparse, dr=0.1)
    assert result is None


def test_compute_roughness_all_nan():
    """An all-NaN residual curve → None."""
    all_nan = np.array([np.nan, np.nan, np.nan])
    assert _compute_roughness(all_nan, dr=0.1) is None


# ---------------------------------------------------------------------------
# _min_position_error
# ---------------------------------------------------------------------------


def test_min_position_error_exact_match():
    """One model minimum exactly on the DFT minimum → zero error."""
    error = _min_position_error([1.0], [1.0], r_range=2.0)
    assert error == pytest.approx(0.0)


def test_min_position_error_offset():
    """One model minimum shifted by 0.15 Å → error of 0.15 Å."""
    error = _min_position_error([1.0], [1.15], r_range=2.0)
    assert error == pytest.approx(0.15)


def test_min_position_error_no_minima():
    """No model minimum → the scan range is returned as penalty."""
    error = _min_position_error([1.0], [], r_range=2.0)
    assert error == pytest.approx(2.0)


def test_min_position_error_two_minima():
    """Two model minima → the scan range is returned as penalty."""
    error = _min_position_error([1.0], [0.9, 1.4], r_range=2.0)
    assert error == pytest.approx(2.0)


# ---------------------------------------------------------------------------
# _minima_positions
# ---------------------------------------------------------------------------


def test_minima_positions_single_well():
    """A parabola centred at 1.5 Å yields exactly one minimum near 1.5 Å."""
    grid = np.linspace(0.5, 5.0, 50)
    parabola = (grid - 1.5) ** 2
    minima = _minima_positions(parabola, grid)
    assert len(minima) == 1
    assert minima[0] == pytest.approx(1.5, abs=0.2)


def test_minima_positions_too_few_points():
    """Two grid points are not enough to detect a minimum."""
    grid = np.array([1.0, 1.5])
    energies = np.array([0.5, 0.0])
    assert _minima_positions(energies, grid) == []


def test_minima_positions_all_nan():
    """An all-NaN energy curve has no minima."""
    grid = np.array([1.0, 1.5, 2.0])
    energies = np.array([np.nan, np.nan, np.nan])
    assert _minima_positions(energies, grid) == []


# ---------------------------------------------------------------------------
# aggregated_diatomics_results
# ---------------------------------------------------------------------------


def _mol(roughness, min_pos_err, r_range, rmse=0.05):
    """Build one per-molecule result entry for the aggregation tests."""
    entry = {}
    entry["roughness"] = roughness
    entry["min_position_error"] = min_pos_err
    entry["r_range"] = r_range
    entry["rmse"] = rmse
    return entry


def test_aggregated_combined_roughness_formula():
    """combined = avg_roughness × (1 + avg(min_pos_err / r_range))."""
    per_molecule = {
        "HH": _mol(roughness=0.01, min_pos_err=0.10, r_range=2.0),
        "NN": _mol(roughness=0.02, min_pos_err=0.40, r_range=2.0),
    }
    agg = aggregated_diatomics_results(per_molecule)

    expected_roughness = (0.01 + 0.02) / 2
    expected_norm_err = (0.10 / 2.0 + 0.40 / 2.0) / 2
    expected_combined = expected_roughness * (1 + expected_norm_err)

    assert agg["combined_roughness"] == pytest.approx(expected_combined)
    assert agg["avg_roughness"] == pytest.approx(expected_roughness)


def test_aggregated_perfect_position_no_penalty():
    """Zero position error → combined_roughness equals avg_roughness."""
    per_molecule = {"HH": _mol(roughness=0.01, min_pos_err=0.0, r_range=2.0)}
    agg = aggregated_diatomics_results(per_molecule)
    combined, plain = agg["combined_roughness"], agg["avg_roughness"]
    assert combined == pytest.approx(plain)


def test_aggregated_worst_position_doubles_roughness():
    """Position error equal to r_range → combined_roughness = 2 × avg_roughness."""
    per_molecule = {"HH": _mol(roughness=0.01, min_pos_err=2.0, r_range=2.0)}
    agg = aggregated_diatomics_results(per_molecule)
    combined, plain = agg["combined_roughness"], agg["avg_roughness"]
    assert combined == pytest.approx(2 * plain)


def test_aggregated_empty_results():
    """An empty input → every aggregated metric is None."""
    agg = aggregated_diatomics_results({})
    for key in (
        "combined_roughness",
        "avg_roughness",
        "avg_min_position_error",
        "avg_rmse",
    ):
        assert agg[key] is None


def test_aggregated_none_molecule_skipped():
    """A None entry is ignored; the remaining molecule defines the averages."""
    per_molecule = {
        "HH": None,
        "NN": _mol(roughness=0.02, min_pos_err=0.20, r_range=2.0),
    }
    agg = aggregated_diatomics_results(per_molecule)
    assert agg["combined_roughness"] is not None
    assert agg["avg_roughness"] == pytest.approx(0.02)