From ff76316df5a2358ae8aa73eb8e1ba7bd807e340f Mon Sep 17 00:00:00 2001 From: Nikhil Dev Goyal Date: Thu, 23 Apr 2026 09:03:44 -0700 Subject: [PATCH] Replace Tanh with FastTanh in flash_attention PiperOrigin-RevId: 904484786 --- gemma/flash_attention.cc | 20 ++++++++++---------- gemma/flash_attention_test.cc | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 9fb6e4dd..016f8a08 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -591,7 +591,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( if (att_cap > 0.0f) { VF4 cap = hn::Set(df4, att_cap); VF4 one_over_cap = hn::Set(df4, one_over_att_cap); - new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap))); + new_max = hn::Mul(cap, hn::FastTanh(df4, hn::Mul(new_max, one_over_cap))); } VF4 local_max = new_max; VF4 old_max_vf = hn::Set(df4, kNegInf); @@ -733,7 +733,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( if (att_cap > 0.0f) { VF8 cap = hn::Set(df8, att_cap); VF8 one_over_cap = hn::Set(df8, one_over_att_cap); - new_max = hn::Mul(cap, hn::Tanh(df8, hn::Mul(new_max, one_over_cap))); + new_max = hn::Mul(cap, hn::FastTanh(df8, hn::Mul(new_max, one_over_cap))); } VF8 local_max = new_max; VF8 old_max_vf = hn::Set(df8, kNegInf); @@ -1309,27 +1309,27 @@ static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap, if (att_cap > 0.0f) { VF cap = hn::Set(df, att_cap); VF one_over_cap_vec = hn::Set(df, one_over_cap); - x0 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x0, one_over_cap_vec))); + x0 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x0, one_over_cap_vec))); if constexpr (kVTileSize >= 2) { - x1 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x1, one_over_cap_vec))); + x1 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x1, one_over_cap_vec))); } if constexpr (kVTileSize >= 3) { - x2 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x2, one_over_cap_vec))); + x2 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x2, one_over_cap_vec))); } if constexpr (kVTileSize >= 4) { - x3 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x3, one_over_cap_vec))); + x3 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x3, one_over_cap_vec))); } if constexpr (kVTileSize >= 5) { - x4 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x4, one_over_cap_vec))); + x4 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x4, one_over_cap_vec))); } if constexpr (kVTileSize >= 6) { - x5 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x5, one_over_cap_vec))); + x5 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x5, one_over_cap_vec))); } if constexpr (kVTileSize >= 7) { - x6 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x6, one_over_cap_vec))); + x6 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x6, one_over_cap_vec))); } if constexpr (kVTileSize >= 8) { - x7 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x7, one_over_cap_vec))); + x7 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x7, one_over_cap_vec))); } } } diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 04c00051..06a89b2e 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -257,7 +257,7 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { if (rel_abs_delta > 0.0f) { rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); } - EXPECT_LT(rel_abs_delta, 1e-3) + EXPECT_LT(rel_abs_delta, 3e-3) << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," << c << "]=" << b_row[c]; }