diff --git a/client.go b/client.go index 1fb7260b..30a8e65b 100644 --- a/client.go +++ b/client.go @@ -992,6 +992,13 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.testSignals.queueCleaner = &queueCleaner.TestSignals } + if driver.DatabaseName() == riverdriver.DatabaseNameSQLite { + sqliteNotificationCleaner := maintenance.NewSQLiteNotificationCleaner(archetype, &maintenance.SQLiteNotificationCleanerConfig{ + Schema: config.Schema, + }, driver.GetExecutor()) + maintenanceServices = append(maintenanceServices, sqliteNotificationCleaner) + } + { var scheduleFunc func(time.Time) time.Time if config.ReindexerSchedule != nil { @@ -2378,8 +2385,6 @@ type JobListResult struct { LastCursor *JobListCursor } -const databaseNameSQLite = "sqlite" - var errJobListParamsMetadataNotSupportedSQLite = errors.New("JobListParams.Metadata is not supported on SQLite") // JobList returns a paginated list of jobs matching the provided filters. The @@ -2401,7 +2406,7 @@ func (c *Client[TTx]) JobList(ctx context.Context, params *JobListParams) (*JobL } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { + if c.driver.DatabaseName() == riverdriver.DatabaseNameSQLite && params.metadataCalled { return nil, errJobListParamsMetadataNotSupportedSQLite } @@ -2442,7 +2447,7 @@ func (c *Client[TTx]) JobListTx(ctx context.Context, tx TTx, params *JobListPara } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { + if c.driver.DatabaseName() == riverdriver.DatabaseNameSQLite && params.metadataCalled { return nil, errJobListParamsMetadataNotSupportedSQLite } diff --git a/internal/maintenance/sqlite_notification_cleaner.go b/internal/maintenance/sqlite_notification_cleaner.go new file mode 100644 index 00000000..2fd07f7c --- /dev/null +++ b/internal/maintenance/sqlite_notification_cleaner.go @@ -0,0 +1,152 @@ +package maintenance + +import ( + "cmp" + "context" + "errors" + "log/slog" + "time" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/riversharedmaintenance" + "github.com/riverqueue/river/rivershared/startstop" + "github.com/riverqueue/river/rivershared/testsignal" + "github.com/riverqueue/river/rivershared/util/testutil" + "github.com/riverqueue/river/rivershared/util/timeutil" +) + +const ( + SQLiteNotificationCleanerIntervalDefault = time.Minute + SQLiteNotificationCleanerRetentionPeriodDefault = 5 * time.Minute +) + +// SQLiteNotificationCleanerTestSignals are internal signals used exclusively in tests. +type SQLiteNotificationCleanerTestSignals struct { + DeletedBatch testsignal.TestSignal[struct{}] // notifies when runOnce finishes a pass +} + +func (ts *SQLiteNotificationCleanerTestSignals) Init(tb testutil.TestingTB) { + ts.DeletedBatch.Init(tb) +} + +type SQLiteNotificationCleanerConfig struct { + // Interval is the amount of time to wait between cleaner runs. + Interval time.Duration + + // RetentionPeriod is the amount of time to keep notification rows around + // before they're removed. + RetentionPeriod time.Duration + + // Schema where River tables are located. Empty string omits schema. + Schema string + + // Timeout is the timeout for each delete query. + Timeout time.Duration +} + +func (c *SQLiteNotificationCleanerConfig) mustValidate() *SQLiteNotificationCleanerConfig { + if c.Interval <= 0 { + panic("SQLiteNotificationCleanerConfig.Interval must be above zero") + } + if c.RetentionPeriod <= 0 { + panic("SQLiteNotificationCleanerConfig.RetentionPeriod must be above zero") + } + if c.Timeout <= 0 { + panic("SQLiteNotificationCleanerConfig.Timeout must be above zero") + } + + return c +} + +// SQLiteNotificationCleaner periodically removes old rows from SQLite's +// notification outbox. It is only needed for the SQLite driver's emulated +// listen/notify support. +type SQLiteNotificationCleaner struct { + riversharedmaintenance.QueueMaintainerServiceBase + startstop.BaseStartStop + + // exported for test purposes + Config *SQLiteNotificationCleanerConfig + TestSignals SQLiteNotificationCleanerTestSignals + + exec riverdriver.Executor +} + +// NewSQLiteNotificationCleaner returns a SQLite notification cleaner. +func NewSQLiteNotificationCleaner(archetype *baseservice.Archetype, config *SQLiteNotificationCleanerConfig, exec riverdriver.Executor) *SQLiteNotificationCleaner { + return baseservice.Init(archetype, &SQLiteNotificationCleaner{ + Config: (&SQLiteNotificationCleanerConfig{ + Interval: cmp.Or(config.Interval, SQLiteNotificationCleanerIntervalDefault), + RetentionPeriod: cmp.Or(config.RetentionPeriod, SQLiteNotificationCleanerRetentionPeriodDefault), + Schema: config.Schema, + Timeout: cmp.Or(config.Timeout, riversharedmaintenance.TimeoutDefault), + }).mustValidate(), + exec: exec, + }) +} + +func (s *SQLiteNotificationCleaner) Start(ctx context.Context) error { //nolint:dupl + ctx, shouldStart, started, stopped := s.StartInit(ctx) + if !shouldStart { + return nil + } + + s.StaggerStart(ctx) + + go func() { + started() + defer stopped() // this defer should come first so it's last out + + s.Logger.DebugContext(ctx, s.Name+riversharedmaintenance.LogPrefixRunLoopStarted) + defer s.Logger.DebugContext(ctx, s.Name+riversharedmaintenance.LogPrefixRunLoopStopped) + + ticker := timeutil.NewTickerWithInitialTick(ctx, s.Config.Interval) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + + res, err := s.runOnce(ctx) + if err != nil { + if !errors.Is(err, context.Canceled) { + s.Logger.ErrorContext(ctx, s.Name+": Error cleaning SQLite notifications", slog.String("error", err.Error())) + } + continue + } + + if res.NumNotificationsDeleted > 0 { + s.Logger.InfoContext(ctx, s.Name+riversharedmaintenance.LogPrefixRanSuccessfully, + slog.Int("num_notifications_deleted", res.NumNotificationsDeleted), + ) + } + } + }() + + return nil +} + +type sqliteNotificationCleanerRunOnceResult struct { + NumNotificationsDeleted int +} + +func (s *SQLiteNotificationCleaner) runOnce(ctx context.Context) (*sqliteNotificationCleanerRunOnceResult, error) { + ctx, cancelFunc := context.WithTimeout(ctx, s.Config.Timeout) + defer cancelFunc() + + numDeleted, err := s.exec.NotificationDeleteBefore(ctx, &riverdriver.NotificationDeleteBeforeParams{ + CreatedAtHorizon: time.Now().Add(-s.Config.RetentionPeriod), + Schema: s.Config.Schema, + }) + if err != nil { + return nil, err + } + + s.TestSignals.DeletedBatch.Signal(struct{}{}) + + return &sqliteNotificationCleanerRunOnceResult{ + NumNotificationsDeleted: numDeleted, + }, nil +} diff --git a/internal/maintenance/sqlite_notification_cleaner_test.go b/internal/maintenance/sqlite_notification_cleaner_test.go new file mode 100644 index 00000000..348dacce --- /dev/null +++ b/internal/maintenance/sqlite_notification_cleaner_test.go @@ -0,0 +1,105 @@ +package maintenance + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstoptest" +) + +func TestSQLiteNotificationCleaner(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + exec riverdriver.Executor + schema string + } + + setup := func(t *testing.T) (*SQLiteNotificationCleaner, *testBundle) { + t.Helper() + + driver := riverpgxv5.New(riversharedtest.DBPool(ctx, t)) + tx, schema := riverdbtest.TestTxPgxDriver(ctx, t, driver, nil) + + bundle := &testBundle{ + exec: driver.UnwrapExecutor(tx), + schema: schema, + } + + cleaner := NewSQLiteNotificationCleaner( + riversharedtest.BaseServiceArchetype(t), + &SQLiteNotificationCleanerConfig{ + Interval: time.Hour, + RetentionPeriod: time.Hour, + Schema: bundle.schema, + Timeout: time.Second, + }, + bundle.exec, + ) + cleaner.StaggerStartupDisable(true) + t.Cleanup(cleaner.Stop) + + return cleaner, bundle + } + + notificationCount := func(t *testing.T, exec riverdriver.Executor) int { + t.Helper() + + var count int + require.NoError(t, exec.QueryRow(ctx, "SELECT count(*) FROM river_notification").Scan(&count)) + return count + } + + t.Run("Defaults", func(t *testing.T) { + t.Parallel() + + cleaner := NewSQLiteNotificationCleaner( + riversharedtest.BaseServiceArchetype(t), + &SQLiteNotificationCleanerConfig{}, + nil, + ) + + require.Equal(t, SQLiteNotificationCleanerIntervalDefault, cleaner.Config.Interval) + require.Equal(t, SQLiteNotificationCleanerRetentionPeriodDefault, cleaner.Config.RetentionPeriod) + }) + + t.Run("DeletesExpiredNotifications", func(t *testing.T) { + t.Parallel() + + cleaner, bundle := setup(t) + cleaner.TestSignals.Init(t) + + now := time.Now() + require.NoError(t, bundle.exec.Exec(ctx, ` + INSERT INTO river_notification (created_at, payload, topic) + VALUES + ($1, 'old_payload_1', 'topic'), + ($2, 'old_payload_2', 'topic'), + ($3, 'new_payload', 'topic') + `, now.Add(-2*time.Hour), now.Add(-61*time.Minute), now.Add(-30*time.Minute))) + + res, err := cleaner.runOnce(ctx) + require.NoError(t, err) + require.Equal(t, 2, res.NumNotificationsDeleted) + cleaner.TestSignals.DeletedBatch.WaitOrTimeout() + require.Equal(t, 1, notificationCount(t, bundle.exec)) + }) + + t.Run("StartStopStress", func(t *testing.T) { + t.Parallel() + + cleaner, _ := setup(t) + cleaner.Logger = riversharedtest.LoggerWarn(t) // loop started/stop log is very noisy; suppress + + startstoptest.Stress(ctx, t, cleaner) + }) +} diff --git a/metadata_test.go b/metadata_test.go index 2736b0c8..94298915 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -4,8 +4,9 @@ import ( "context" "testing" - "github.com/riverqueue/river/internal/jobexecutor" "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/jobexecutor" ) func TestMetadataSet(t *testing.T) { diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index 8727a6e3..4a650d2b 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -24,6 +24,11 @@ import ( const AllQueuesString = "*" +const ( + DatabaseNamePostgres = "postgres" + DatabaseNameSQLite = "sqlite" +) + const MigrationLineMain = "main" var ( @@ -138,12 +143,14 @@ type Driver[TTx any] interface { // API is not stable. DO NOT USE. SupportsListener() bool - // SupportsListenNotify indicates whether the underlying database supports - // listen/notify. This differs from SupportsListener in that even if a - // driver doesn't a support a listener but the database supports the - // underlying listen/notify mechanism, it will still broadcast in case there - // are other clients/drivers on the database that do support a listener. If - // listen/notify can't be supported at all, no broadcast attempt is made. + // SupportsListenNotify indicates whether the driver can broadcast + // notifications that a listener can receive, either through a native + // database mechanism like Postgres LISTEN/NOTIFY or a driver-specific + // emulation. This differs from SupportsListener in that even if a driver + // doesn't support a listener but the database supports the underlying + // notification mechanism, it will still broadcast in case there are other + // clients/drivers on the database that do support a listener. If + // notifications can't be supported at all, no broadcast attempt is made. // // API is not stable. DO NOT USE. SupportsListenNotify() bool @@ -256,6 +263,14 @@ type Executor interface { // the `line` column was added to the migrations table. MigrationInsertManyAssumingMain(ctx context.Context, params *MigrationInsertManyAssumingMainParams) ([]*Migration, error) + // NotificationDeleteBefore deletes notifications before a certain time + // horizon. + // + // A "notification" in this context refers to a row in `river_notification` + // which is a special table implemented in some databases (e.g. SQLite) that + // simulates Postgres' listen/notify when not available. + NotificationDeleteBefore(ctx context.Context, params *NotificationDeleteBeforeParams) (int, error) + NotifyMany(ctx context.Context, params *NotifyManyParams) error PGAdvisoryXactLock(ctx context.Context, key int64) (*struct{}, error) @@ -775,6 +790,11 @@ type NotifyManyParams struct { Schema string } +type NotificationDeleteBeforeParams struct { + CreatedAtHorizon time.Time + Schema string +} + type ProducerKeepAliveParams struct { ID int64 QueueName string @@ -883,8 +903,10 @@ func MigrationLineMainTruncateTables(version int) []string { return []string{"river_job", "river_leader"} case 4: return []string{"river_job", "river_leader", "river_queue"} - case 0, 5, 6: + case 5, 6: return []string{"river_job", "river_leader", "river_queue", "river_client", "river_client_queue"} + case 0, 7: + return []string{"river_job", "river_leader", "river_queue", "river_client", "river_client_queue", "river_notification"} } panic(fmt.Sprintf("unrecognized migration version: %d", version)) diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/models.go b/riverdriver/riverdatabasesql/internal/dbsqlc/models.go index c8188fc7..07dc3581 100644 --- a/riverdriver/riverdatabasesql/internal/dbsqlc/models.go +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/models.go @@ -111,6 +111,13 @@ type RiverMigration struct { CreatedAt time.Time } +type RiverNotification struct { + ID int64 + CreatedAt time.Time + Payload string + Topic string +} + type RiverQueue struct { Name string CreatedAt time.Time diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/river_notification.sql.go b/riverdriver/riverdatabasesql/internal/dbsqlc/river_notification.sql.go new file mode 100644 index 00000000..33577934 --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/river_notification.sql.go @@ -0,0 +1,24 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_notification.sql + +package dbsqlc + +import ( + "context" + "time" +) + +const notificationDeleteBefore = `-- name: NotificationDeleteBefore :execrows +DELETE FROM /* TEMPLATE: schema */river_notification +WHERE created_at < $1::timestamptz +` + +func (q *Queries) NotificationDeleteBefore(ctx context.Context, db DBTX, createdAtHorizon time.Time) (int64, error) { + result, err := db.ExecContext(ctx, notificationDeleteBefore, createdAtHorizon) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml b/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml index 77805117..93a5e9bc 100644 --- a/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml @@ -8,6 +8,7 @@ sql: - ../../../riverpgxv5/internal/dbsqlc/river_job.sql - ../../../riverpgxv5/internal/dbsqlc/river_leader.sql - ../../../riverpgxv5/internal/dbsqlc/river_migration.sql + - ../../../riverpgxv5/internal/dbsqlc/river_notification.sql - ../../../riverpgxv5/internal/dbsqlc/river_queue.sql - ../../../riverpgxv5/internal/dbsqlc/schema.sql schema: @@ -17,6 +18,7 @@ sql: - ../../../riverpgxv5/internal/dbsqlc/river_job.sql - ../../../riverpgxv5/internal/dbsqlc/river_leader.sql - ../../../riverpgxv5/internal/dbsqlc/river_migration.sql + - ../../../riverpgxv5/internal/dbsqlc/river_notification.sql - ../../../riverpgxv5/internal/dbsqlc/river_queue.sql - ../../../riverpgxv5/internal/dbsqlc/schema.sql gen: diff --git a/riverdriver/riverdatabasesql/migration/main/007_notification_outbox.down.sql b/riverdriver/riverdatabasesql/migration/main/007_notification_outbox.down.sql new file mode 100644 index 00000000..bd5d4a89 --- /dev/null +++ b/riverdriver/riverdatabasesql/migration/main/007_notification_outbox.down.sql @@ -0,0 +1 @@ +DROP TABLE /* TEMPLATE: schema */river_notification; diff --git a/riverdriver/riverdatabasesql/migration/main/007_notification_outbox.up.sql b/riverdriver/riverdatabasesql/migration/main/007_notification_outbox.up.sql new file mode 100644 index 00000000..8f91f880 --- /dev/null +++ b/riverdriver/riverdatabasesql/migration/main/007_notification_outbox.up.sql @@ -0,0 +1,10 @@ +CREATE TABLE /* TEMPLATE: schema */river_notification ( + id bigserial PRIMARY KEY, + created_at timestamptz NOT NULL DEFAULT now(), + payload text NOT NULL, + topic text NOT NULL, + CONSTRAINT topic_length CHECK (length(topic) > 0 AND length(topic) < 128) +); + +CREATE INDEX river_notification_created_at_idx ON /* TEMPLATE: schema */river_notification (created_at); +CREATE INDEX river_notification_topic_id_idx ON /* TEMPLATE: schema */river_notification (topic, id); diff --git a/riverdriver/riverdatabasesql/river_database_sql_driver.go b/riverdriver/riverdatabasesql/river_database_sql_driver.go index 16a3ae35..32733e89 100644 --- a/riverdriver/riverdatabasesql/river_database_sql_driver.go +++ b/riverdriver/riverdatabasesql/river_database_sql_driver.go @@ -54,7 +54,7 @@ func New(dbPool *sql.DB) *Driver { const argPlaceholder = "$" func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) DatabaseName() string { return riverdriver.DatabaseNamePostgres } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d} @@ -859,6 +859,11 @@ func (e *Executor) MigrationInsertManyAssumingMain(ctx context.Context, params * }), nil } +func (e *Executor) NotificationDeleteBefore(ctx context.Context, params *riverdriver.NotificationDeleteBeforeParams) (int, error) { + numDeleted, err := dbsqlc.New().NotificationDeleteBefore(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.CreatedAtHorizon) + return int(numDeleted), interpretError(err) +} + func (e *Executor) NotifyMany(ctx context.Context, params *riverdriver.NotifyManyParams) error { return dbsqlc.New().PGNotifyMany(ctx, e.dbtx, &dbsqlc.PGNotifyManyParams{ Payload: params.Payload, diff --git a/riverdriver/riverdrivertest/driver_client_test.go b/riverdriver/riverdrivertest/driver_client_test.go index 7a597e65..8f99cec4 100644 --- a/riverdriver/riverdrivertest/driver_client_test.go +++ b/riverdriver/riverdrivertest/driver_client_test.go @@ -319,7 +319,7 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, require.NoError(t, err) // SQLite can't support multiple concurrent transactions, so skip this extra check there. - if bundle.driver.DatabaseName() != databaseNameSQLite { + if bundle.driver.DatabaseName() != riverdriver.DatabaseNameSQLite { _, otherExecTx := beginTx(ctx, t, bundle) // Both jobs present because other transaction doesn't see the deletion. @@ -404,7 +404,7 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, require.ErrorIs(t, err, rivertype.ErrNotFound) // SQLite can't support multiple concurrent transactions, so skip this extra check there. - if bundle.driver.DatabaseName() != databaseNameSQLite { + if bundle.driver.DatabaseName() != riverdriver.DatabaseNameSQLite { _, otherExecTx := beginTx(ctx, t, bundle) // Jobs present because other transaction doesn't see the deletions. @@ -519,7 +519,7 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, }) listRes, err := client.JobList(ctx, river.NewJobListParams().Metadata(`{"foo":"bar"}`)) - if bundle.driver.DatabaseName() == databaseNameSQLite { + if bundle.driver.DatabaseName() == riverdriver.DatabaseNameSQLite { t.Logf("Ignoring unsupported JobListResult.Metadata on SQLite") require.EqualError(t, err, "JobListParams.Metadata is not supported on SQLite") return @@ -583,7 +583,7 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, }) listRes, err := client.JobListTx(ctx, tx, river.NewJobListParams().Metadata(`{"foo":"bar"}`)) - if bundle.driver.DatabaseName() == databaseNameSQLite { + if bundle.driver.DatabaseName() == riverdriver.DatabaseNameSQLite { t.Logf("Ignoring unsupported JobListTxResult.Metadata on SQLite") require.EqualError(t, err, "JobListParams.Metadata is not supported on SQLite") return @@ -607,7 +607,7 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, listParams := river.NewJobListParams() - if bundle.driver.DatabaseName() == databaseNameSQLite { + if bundle.driver.DatabaseName() == riverdriver.DatabaseNameSQLite { listParams = listParams.Where("metadata ->> @json_path = @json_val", river.NamedArgs{"json_path": "$.foo", "json_val": "bar"}) } else { // "bar" is quoted in this branch because `jsonb_path_query_first` needs to be compared to a JSON value diff --git a/riverdriver/riverdrivertest/executor_tx.go b/riverdriver/riverdrivertest/executor_tx.go index 451c65c2..e4313932 100644 --- a/riverdriver/riverdrivertest/executor_tx.go +++ b/riverdriver/riverdrivertest/executor_tx.go @@ -166,7 +166,7 @@ func exerciseExecutorTx[TTx any](ctx context.Context, t *testing.T, { driver, _ := driverWithSchema(ctx, t, nil) - if driver.DatabaseName() == databaseNameSQLite { + if driver.DatabaseName() == riverdriver.DatabaseNameSQLite { t.Logf("Skipping PGAdvisoryXactLock test for SQLite") return } diff --git a/riverdriver/riverdrivertest/job_delete.go b/riverdriver/riverdrivertest/job_delete.go index 787a5d09..2b1919dc 100644 --- a/riverdriver/riverdrivertest/job_delete.go +++ b/riverdriver/riverdrivertest/job_delete.go @@ -261,7 +261,7 @@ func exerciseJobDelete[TTx any](ctx context.Context, t *testing.T, executorWithT // since we only expect to need `queues_excluded` on SQLite (and not // `queues_included` for the foreseeable future), I've just set // SQLite to not support `queues_included` for the time being. - if bundle.driver.DatabaseName() == databaseNameSQLite { + if bundle.driver.DatabaseName() == riverdriver.DatabaseNameSQLite { t.Logf("Skipping JobDeleteBefore with QueuesIncluded test for SQLite") return } diff --git a/riverdriver/riverdrivertest/listener.go b/riverdriver/riverdrivertest/listener.go index 79aa7211..7a096075 100644 --- a/riverdriver/riverdrivertest/listener.go +++ b/riverdriver/riverdrivertest/listener.go @@ -128,6 +128,10 @@ func exerciseListener[TTx any](ctx context.Context, t *testing.T, driverWithPool listener = driver.GetListener(&riverdriver.GetListenenerParams{Schema: ""}) ) + if driver.DatabaseName() == riverdriver.DatabaseNameSQLite { + t.Skip("SQLite has no search_path") + } + listener.SetAfterConnectExec("SET search_path TO 'public'") connectListener(ctx, t, listener) diff --git a/riverdriver/riverdrivertest/migration.go b/riverdriver/riverdrivertest/migration.go index ba87a2bf..856040ad 100644 --- a/riverdriver/riverdrivertest/migration.go +++ b/riverdriver/riverdrivertest/migration.go @@ -58,6 +58,7 @@ func exerciseMigration[TTx any](ctx context.Context, t *testing.T, t.Parallel() driver, _ := driverWithSchema(ctx, t, nil) + expectedLatestTables := []string{"river_job", "river_leader", "river_queue", "river_client", "river_client_queue", "river_notification"} require.Empty(t, driver.GetMigrationTruncateTables(riverdriver.MigrationLineMain, 1)) require.Equal(t, []string{"river_job", "river_leader"}, @@ -70,7 +71,9 @@ func exerciseMigration[TTx any](ctx context.Context, t *testing.T, driver.GetMigrationTruncateTables(riverdriver.MigrationLineMain, 5)) require.Equal(t, []string{"river_job", "river_leader", "river_queue", "river_client", "river_client_queue"}, driver.GetMigrationTruncateTables(riverdriver.MigrationLineMain, 6)) - require.Equal(t, []string{"river_job", "river_leader", "river_queue", "river_client", "river_client_queue"}, + require.Equal(t, expectedLatestTables, + driver.GetMigrationTruncateTables(riverdriver.MigrationLineMain, 7)) + require.Equal(t, expectedLatestTables, driver.GetMigrationTruncateTables(riverdriver.MigrationLineMain, 0)) }) }) diff --git a/riverdriver/riverdrivertest/notification.go b/riverdriver/riverdrivertest/notification.go new file mode 100644 index 00000000..311f77dd --- /dev/null +++ b/riverdriver/riverdrivertest/notification.go @@ -0,0 +1,64 @@ +package riverdrivertest + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" +) + +func exerciseNotification[TTx any](ctx context.Context, t *testing.T, executorWithTx func(ctx context.Context, t *testing.T) (riverdriver.Executor, riverdriver.Driver[TTx])) { + t.Helper() + + t.Run("NotificationDeleteBefore", func(t *testing.T) { + t.Parallel() + + exec, driver := executorWithTx(ctx, t) + + insertQuery := ` + INSERT INTO river_notification (created_at, payload, topic) + VALUES + ($1, $2, $3), + ($4, $5, $6), + ($7, $8, $9) + ` + if driver.DatabaseName() == riverdriver.DatabaseNameSQLite { + insertQuery = ` + INSERT INTO river_notification (created_at, payload, topic) + VALUES + (?, ?, ?), + (?, ?, ?), + (?, ?, ?) + ` + } + createdAt := func(t time.Time) any { return t } + if driver.DatabaseName() == riverdriver.DatabaseNameSQLite { + // Keep this in the same format that the SQLite driver uses for + // CreatedAtHorizon so SQLite's text comparison stays chronological. + createdAt = func(t time.Time) any { + const sqliteFormat = "2006-01-02 15:04:05.999" + return t.UTC().Round(time.Millisecond).Format(sqliteFormat) + } + } + + now := time.Now().UTC() + require.NoError(t, exec.Exec(ctx, insertQuery, + createdAt(now.Add(-2*time.Hour)), "old_payload_1", "topic", + createdAt(now.Add(-61*time.Minute)), "old_payload_2", "topic", + createdAt(now.Add(-30*time.Minute)), "new_payload", "topic", + )) + + numDeleted, err := exec.NotificationDeleteBefore(ctx, &riverdriver.NotificationDeleteBeforeParams{ + CreatedAtHorizon: now.Add(-time.Hour), + }) + require.NoError(t, err) + require.Equal(t, 2, numDeleted) + + var count int + require.NoError(t, exec.QueryRow(ctx, "SELECT count(*) FROM river_notification").Scan(&count)) + require.Equal(t, 1, count) + }) +} diff --git a/riverdriver/riverdrivertest/riverdrivertest.go b/riverdriver/riverdrivertest/riverdrivertest.go index c51093e9..ad4eb976 100644 --- a/riverdriver/riverdrivertest/riverdrivertest.go +++ b/riverdriver/riverdrivertest/riverdrivertest.go @@ -33,6 +33,7 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, exerciseDriverPool(ctx, t, driverWithSchema, executorWithTx) exerciseMigration(ctx, t, driverWithSchema, executorWithTx) + exerciseNotification(ctx, t, executorWithTx) exerciseSQLFragments(ctx, t, executorWithTx) exerciseExecutorTx(ctx, t, driverWithSchema, executorWithTx) exerciseSchemaIntrospection(ctx, t, driverWithSchema, executorWithTx) @@ -45,11 +46,7 @@ func Exercise[TTx any](ctx context.Context, t *testing.T, exerciseQueue(ctx, t, executorWithTx) } -const ( - databaseNamePostgres = "postgres" - databaseNameSQLite = "sqlite" - testClientID = "test-client-id" -) +const testClientID = "test-client-id" func exerciseDriverPool[TTx any](ctx context.Context, t *testing.T, driverWithSchema func(ctx context.Context, t *testing.T, opts *riverdbtest.TestSchemaOpts) (riverdriver.Driver[TTx], string), @@ -89,10 +86,10 @@ func exerciseDriverPool[TTx any](ctx context.Context, t *testing.T, _, driver := executorWithTx(ctx, t) switch driver.DatabaseName() { - case databaseNamePostgres: + case riverdriver.DatabaseNamePostgres: + require.True(t, driver.SupportsListenNotify()) + case riverdriver.DatabaseNameSQLite: require.True(t, driver.SupportsListenNotify()) - case databaseNameSQLite: - require.False(t, driver.SupportsListenNotify()) default: require.FailNow(t, "Don't know how to check SupportsListenNotify for: "+driver.DatabaseName()) } diff --git a/riverdriver/riverdrivertest/schema_introspection.go b/riverdriver/riverdrivertest/schema_introspection.go index c52cbb25..14eca2a7 100644 --- a/riverdriver/riverdrivertest/schema_introspection.go +++ b/riverdriver/riverdrivertest/schema_introspection.go @@ -90,7 +90,7 @@ func exerciseSchemaIntrospection[TTx any](ctx context.Context, t *testing.T, // the index name, but on Postgres it should go before the table. // The schema is empty for SQLite anyway since we're operating in // isolation in a particular database file. - if driver.DatabaseName() == databaseNameSQLite { + if driver.DatabaseName() == riverdriver.DatabaseNameSQLite { require.NoError(t, driver.GetExecutor().Exec(ctx, "CREATE INDEX river_job_index_drop_if_exists ON river_job (id)")) } else { require.NoError(t, driver.GetExecutor().Exec(ctx, fmt.Sprintf("CREATE INDEX river_job_index_drop_if_exists ON %s.river_job (id)", schema))) @@ -142,7 +142,7 @@ func exerciseSchemaIntrospection[TTx any](ctx context.Context, t *testing.T, Index: "river_job_prioritized_fetching_index", Schema: "custom_schema", }) - if bundle.driver.DatabaseName() == databaseNameSQLite { + if bundle.driver.DatabaseName() == riverdriver.DatabaseNameSQLite { requireMissingRelation(t, err, "custom_schema", "sqlite_master") } else { require.NoError(t, err) @@ -205,7 +205,7 @@ func exerciseSchemaIntrospection[TTx any](ctx context.Context, t *testing.T, IndexNames: []string{"river_job_kind", "river_job_prioritized_fetching_index"}, Schema: "custom_schema_that_does_not_exist", }) - if bundle.driver.DatabaseName() == databaseNameSQLite { + if bundle.driver.DatabaseName() == riverdriver.DatabaseNameSQLite { requireMissingRelation(t, err, "custom_schema_that_does_not_exist", "sqlite_master") } else { require.NoError(t, err) @@ -245,7 +245,7 @@ func exerciseSchemaIntrospection[TTx any](ctx context.Context, t *testing.T, // empty because they're actually separate databases and can't be // referenced with their fully qualified name. So instead, extract // the name of the current database via pragma and use it as schema. - if driver1.DatabaseName() == databaseNameSQLite { + if driver1.DatabaseName() == riverdriver.DatabaseNameSQLite { getCurrentSchema := func(exec riverdriver.Executor) string { var databaseFile string require.NoError(t, exec.QueryRow(ctx, "SELECT file FROM pragma_database_list WHERE name = ?1", "main").Scan(&databaseFile)) diff --git a/riverdriver/riverdrivertest/schema_name.go b/riverdriver/riverdrivertest/schema_name.go index 14f249db..9388ec7d 100644 --- a/riverdriver/riverdrivertest/schema_name.go +++ b/riverdriver/riverdrivertest/schema_name.go @@ -26,7 +26,7 @@ func exerciseSchemaName[TTx any](ctx context.Context, t *testing.T, // In SQLite schemas are files assigned to particular names, so this // check isn't relevant in the same way. - if driver.DatabaseName() != databaseNamePostgres { + if driver.DatabaseName() != riverdriver.DatabaseNamePostgres { t.Skip("Skipping; schema names with spaces only relevant for Postgres") } diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/models.go b/riverdriver/riverpgxv5/internal/dbsqlc/models.go index 4c8ba801..104ed01a 100644 --- a/riverdriver/riverpgxv5/internal/dbsqlc/models.go +++ b/riverdriver/riverpgxv5/internal/dbsqlc/models.go @@ -113,6 +113,13 @@ type RiverMigration struct { CreatedAt time.Time } +type RiverNotification struct { + ID int64 + CreatedAt time.Time + Payload string + Topic string +} + type RiverQueue struct { Name string CreatedAt time.Time diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/river_notification.sql b/riverdriver/riverpgxv5/internal/dbsqlc/river_notification.sql new file mode 100644 index 00000000..576d2444 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_notification.sql @@ -0,0 +1,16 @@ +-- This table isn't used under Postgres currently, but we have it in place +-- because its useful for simulating under Postgres as if we were running +-- SQLite, and it may be useful as a good listen/notify alternative for Postgres +-- down the line instead of poll-only mode in cases like where a bouncer makes +-- listen/notify difficult to use. +CREATE TABLE river_notification ( + id bigserial PRIMARY KEY, + created_at timestamptz NOT NULL DEFAULT now(), + payload text NOT NULL, + topic text NOT NULL, + CONSTRAINT topic_length CHECK (length(topic) > 0 AND length(topic) < 128) +); + +-- name: NotificationDeleteBefore :execrows +DELETE FROM /* TEMPLATE: schema */river_notification +WHERE created_at < @created_at_horizon::timestamptz; diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/river_notification.sql.go b/riverdriver/riverpgxv5/internal/dbsqlc/river_notification.sql.go new file mode 100644 index 00000000..cb460451 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_notification.sql.go @@ -0,0 +1,24 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_notification.sql + +package dbsqlc + +import ( + "context" + "time" +) + +const notificationDeleteBefore = `-- name: NotificationDeleteBefore :execrows +DELETE FROM /* TEMPLATE: schema */river_notification +WHERE created_at < $1::timestamptz +` + +func (q *Queries) NotificationDeleteBefore(ctx context.Context, db DBTX, createdAtHorizon time.Time) (int64, error) { + result, err := db.Exec(ctx, notificationDeleteBefore, createdAtHorizon) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml b/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml index a818dcfa..89fe855c 100644 --- a/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml +++ b/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml @@ -9,6 +9,7 @@ sql: - river_job_copyfrom.sql - river_leader.sql - river_migration.sql + - river_notification.sql - river_queue.sql - schema.sql schema: @@ -18,6 +19,7 @@ sql: - river_job.sql - river_leader.sql - river_migration.sql + - river_notification.sql - river_queue.sql - schema.sql gen: diff --git a/riverdriver/riverpgxv5/migration/main/007_notification_outbox.down.sql b/riverdriver/riverpgxv5/migration/main/007_notification_outbox.down.sql new file mode 100644 index 00000000..bd5d4a89 --- /dev/null +++ b/riverdriver/riverpgxv5/migration/main/007_notification_outbox.down.sql @@ -0,0 +1 @@ +DROP TABLE /* TEMPLATE: schema */river_notification; diff --git a/riverdriver/riverpgxv5/migration/main/007_notification_outbox.up.sql b/riverdriver/riverpgxv5/migration/main/007_notification_outbox.up.sql new file mode 100644 index 00000000..8f91f880 --- /dev/null +++ b/riverdriver/riverpgxv5/migration/main/007_notification_outbox.up.sql @@ -0,0 +1,10 @@ +CREATE TABLE /* TEMPLATE: schema */river_notification ( + id bigserial PRIMARY KEY, + created_at timestamptz NOT NULL DEFAULT now(), + payload text NOT NULL, + topic text NOT NULL, + CONSTRAINT topic_length CHECK (length(topic) > 0 AND length(topic) < 128) +); + +CREATE INDEX river_notification_created_at_idx ON /* TEMPLATE: schema */river_notification (created_at); +CREATE INDEX river_notification_topic_id_idx ON /* TEMPLATE: schema */river_notification (topic, id); diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index efbddf1f..54c3f567 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -64,7 +64,7 @@ func New(dbPool *pgxpool.Pool) *Driver { const argPlaceholder = "$" func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "postgres" } +func (d *Driver) DatabaseName() string { return riverdriver.DatabaseNamePostgres } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{templateReplaceWrapper{d.dbPool, &d.replacer}, d} @@ -844,6 +844,11 @@ func (e *Executor) MigrationInsertManyAssumingMain(ctx context.Context, params * }), nil } +func (e *Executor) NotificationDeleteBefore(ctx context.Context, params *riverdriver.NotificationDeleteBeforeParams) (int, error) { + numDeleted, err := dbsqlc.New().NotificationDeleteBefore(schemaTemplateParam(ctx, params.Schema), e.dbtx, params.CreatedAtHorizon) + return int(numDeleted), interpretError(err) +} + func (e *Executor) NotifyMany(ctx context.Context, params *riverdriver.NotifyManyParams) error { return dbsqlc.New().PGNotifyMany(ctx, e.dbtx, &dbsqlc.PGNotifyManyParams{ Payload: params.Payload, diff --git a/riverdriver/riversqlite/internal/dbsqlc/models.go b/riverdriver/riversqlite/internal/dbsqlc/models.go index 1a9ce5c5..7c00c356 100644 --- a/riverdriver/riversqlite/internal/dbsqlc/models.go +++ b/riverdriver/riversqlite/internal/dbsqlc/models.go @@ -61,6 +61,13 @@ type RiverMigration struct { CreatedAt time.Time } +type RiverNotification struct { + ID int64 + CreatedAt time.Time + Payload string + Topic string +} + type RiverQueue struct { Name string CreatedAt time.Time diff --git a/riverdriver/riversqlite/internal/dbsqlc/river_notification.sql b/riverdriver/riversqlite/internal/dbsqlc/river_notification.sql new file mode 100644 index 00000000..ddaf3d70 --- /dev/null +++ b/riverdriver/riversqlite/internal/dbsqlc/river_notification.sql @@ -0,0 +1,32 @@ +CREATE TABLE river_notification ( + id integer PRIMARY KEY AUTOINCREMENT, + created_at timestamp NOT NULL DEFAULT (datetime('now', 'subsec')), + payload text NOT NULL, + topic text NOT NULL, + CONSTRAINT topic_length CHECK (length(topic) > 0 AND length(topic) < 128) +); + +-- name: NotificationDeleteBefore :execrows +DELETE FROM /* TEMPLATE: schema */river_notification +WHERE created_at < cast(@created_at_horizon AS text); + +-- name: NotificationGetAfter :one +SELECT * +FROM /* TEMPLATE: schema */river_notification +WHERE id > @after +ORDER BY id ASC +LIMIT 1; + +-- name: NotificationGetLastID :one +SELECT cast(coalesce(max(id), 0) AS integer) +FROM /* TEMPLATE: schema */river_notification; + +-- name: NotificationInsertMany :exec +INSERT INTO /* TEMPLATE: schema */river_notification ( + payload, + topic +) +SELECT + json_extract(value, '$.payload'), + json_extract(value, '$.topic') +FROM json_each(cast(@notifications AS blob)); diff --git a/riverdriver/riversqlite/internal/dbsqlc/river_notification.sql.go b/riverdriver/riversqlite/internal/dbsqlc/river_notification.sql.go new file mode 100644 index 00000000..13eb7deb --- /dev/null +++ b/riverdriver/riversqlite/internal/dbsqlc/river_notification.sql.go @@ -0,0 +1,71 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 +// source: river_notification.sql + +package dbsqlc + +import ( + "context" +) + +const notificationDeleteBefore = `-- name: NotificationDeleteBefore :execrows +DELETE FROM /* TEMPLATE: schema */river_notification +WHERE created_at < cast(?1 AS text) +` + +func (q *Queries) NotificationDeleteBefore(ctx context.Context, db DBTX, createdAtHorizon string) (int64, error) { + result, err := db.ExecContext(ctx, notificationDeleteBefore, createdAtHorizon) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +const notificationGetAfter = `-- name: NotificationGetAfter :one +SELECT id, created_at, payload, topic +FROM /* TEMPLATE: schema */river_notification +WHERE id > ?1 +ORDER BY id ASC +LIMIT 1 +` + +func (q *Queries) NotificationGetAfter(ctx context.Context, db DBTX, after int64) (*RiverNotification, error) { + row := db.QueryRowContext(ctx, notificationGetAfter, after) + var i RiverNotification + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.Payload, + &i.Topic, + ) + return &i, err +} + +const notificationGetLastID = `-- name: NotificationGetLastID :one +SELECT cast(coalesce(max(id), 0) AS integer) +FROM /* TEMPLATE: schema */river_notification +` + +func (q *Queries) NotificationGetLastID(ctx context.Context, db DBTX) (int64, error) { + row := db.QueryRowContext(ctx, notificationGetLastID) + var column_1 int64 + err := row.Scan(&column_1) + return column_1, err +} + +const notificationInsertMany = `-- name: NotificationInsertMany :exec +INSERT INTO /* TEMPLATE: schema */river_notification ( + payload, + topic +) +SELECT + json_extract(value, '$.payload'), + json_extract(value, '$.topic') +FROM json_each(cast(?1 AS blob)) +` + +func (q *Queries) NotificationInsertMany(ctx context.Context, db DBTX, notifications []byte) error { + _, err := db.ExecContext(ctx, notificationInsertMany, notifications) + return err +} diff --git a/riverdriver/riversqlite/internal/dbsqlc/sqlc.yaml b/riverdriver/riversqlite/internal/dbsqlc/sqlc.yaml index dd86a267..e9886e38 100644 --- a/riverdriver/riversqlite/internal/dbsqlc/sqlc.yaml +++ b/riverdriver/riversqlite/internal/dbsqlc/sqlc.yaml @@ -7,6 +7,7 @@ sql: - river_job.sql - river_leader.sql - river_migration.sql + - river_notification.sql - river_queue.sql - schema.sql schema: @@ -15,6 +16,7 @@ sql: - river_job.sql - river_leader.sql - river_migration.sql + - river_notification.sql - river_queue.sql - schema.sql gen: diff --git a/riverdriver/riversqlite/migration/main/007_notification_outbox.down.sql b/riverdriver/riversqlite/migration/main/007_notification_outbox.down.sql new file mode 100644 index 00000000..bd5d4a89 --- /dev/null +++ b/riverdriver/riversqlite/migration/main/007_notification_outbox.down.sql @@ -0,0 +1 @@ +DROP TABLE /* TEMPLATE: schema */river_notification; diff --git a/riverdriver/riversqlite/migration/main/007_notification_outbox.up.sql b/riverdriver/riversqlite/migration/main/007_notification_outbox.up.sql new file mode 100644 index 00000000..c33ad2f9 --- /dev/null +++ b/riverdriver/riversqlite/migration/main/007_notification_outbox.up.sql @@ -0,0 +1,10 @@ +CREATE TABLE /* TEMPLATE: schema */river_notification ( + id integer PRIMARY KEY AUTOINCREMENT, + created_at timestamp NOT NULL DEFAULT (datetime('now', 'subsec')), + payload text NOT NULL, + topic text NOT NULL, + CONSTRAINT topic_length CHECK (length(topic) > 0 AND length(topic) < 128) +); + +CREATE INDEX /* TEMPLATE: schema */river_notification_created_at_idx ON river_notification (created_at); +CREATE INDEX /* TEMPLATE: schema */river_notification_topic_id_idx ON river_notification (topic, id); diff --git a/riverdriver/riversqlite/river_sqlite_driver.go b/riverdriver/riversqlite/river_sqlite_driver.go index b7d729ae..3aac9093 100644 --- a/riverdriver/riversqlite/river_sqlite_driver.go +++ b/riverdriver/riversqlite/river_sqlite_driver.go @@ -78,14 +78,20 @@ func New(dbPool *sql.DB) *Driver { const argPlaceholder = "?" func (d *Driver) ArgPlaceholder() string { return argPlaceholder } -func (d *Driver) DatabaseName() string { return "sqlite" } +func (d *Driver) DatabaseName() string { return riverdriver.DatabaseNameSQLite } func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, templateReplaceWrapper{d.dbPool, &d.replacer}, d, nil} } func (d *Driver) GetListener(params *riverdriver.GetListenenerParams) riverdriver.Listener { - panic(riverdriver.ErrNotImplemented) + return &Listener{ + dbPool: d.dbPool, + pollInterval: notificationPollIntervalDefault, + replacer: &d.replacer, + schema: params.Schema, + topics: make(map[string]struct{}), + } } func (d *Driver) GetMigrationDefaultLines() []string { return []string{riverdriver.MigrationLineMain} } @@ -121,8 +127,8 @@ func (d *Driver) SQLFragmentColumnIn(column string, values any) (string, any, er return fmt.Sprintf("%s IN (SELECT value FROM json_each(cast(@%s AS blob)))", column, column), arg, nil } -func (d *Driver) SupportsListener() bool { return false } -func (d *Driver) SupportsListenNotify() bool { return false } +func (d *Driver) SupportsListener() bool { return true } +func (d *Driver) SupportsListenNotify() bool { return true } func (d *Driver) TimePrecision() time.Duration { return time.Millisecond } func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.ExecutorTx { @@ -1269,8 +1275,31 @@ func (e *Executor) MigrationInsertManyAssumingMain(ctx context.Context, params * return migrations, nil } +func (e *Executor) NotificationDeleteBefore(ctx context.Context, params *riverdriver.NotificationDeleteBeforeParams) (int, error) { + numDeleted, err := dbsqlc.New().NotificationDeleteBefore( + schemaTemplateParam(ctx, params.Schema), + e.dbtx, + timeString(params.CreatedAtHorizon), + ) + return int(numDeleted), interpretError(err) +} + func (e *Executor) NotifyMany(ctx context.Context, params *riverdriver.NotifyManyParams) error { - return riverdriver.ErrNotImplemented + if len(params.Payload) < 1 { + return nil + } + + notifications, err := json.Marshal(sliceutil.Map(params.Payload, func(payload string) notificationPayload { + return notificationPayload{ + Payload: payload, + Topic: params.Topic, + } + })) + if err != nil { + return err + } + + return dbsqlc.New().NotificationInsertMany(schemaTemplateParam(ctx, params.Schema), e.dbtx, notifications) } func (e *Executor) PGAdvisoryXactLock(ctx context.Context, key int64) (*struct{}, error) { diff --git a/riverdriver/riversqlite/river_sqlite_listener.go b/riverdriver/riversqlite/river_sqlite_listener.go new file mode 100644 index 00000000..444c21f9 --- /dev/null +++ b/riverdriver/riversqlite/river_sqlite_listener.go @@ -0,0 +1,271 @@ +package riversqlite + +import ( + "context" + "database/sql" + "errors" + "sync" + "time" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riversqlite/internal/dbsqlc" + "github.com/riverqueue/river/rivershared/sqlctemplate" +) + +const ( + notificationPollIntervalDefault = 50 * time.Millisecond +) + +// Listener receives SQLite notifications from the river_notification outbox +// table. SQLite doesn't have a native LISTEN/NOTIFY equivalent, so NotifyMany +// appends rows to river_notification and this listener polls for rows with IDs +// greater than its remembered lastID. The lastID marker is initialized to the +// current max ID on connect so historical rows aren't replayed, and advances +// past every observed row so unlistened topics don't get delivered later if +// they're re-listened. +type Listener struct { + afterConnectExec string // should only ever be used in testing + dbPool *sql.DB + isConnected bool + + // lastID is safe to use as a visibility cursor because SQLite serializes + // writers. A transaction that has inserted a lower notification ID holds + // the write lock until commit/rollback, so another transaction can't insert + // and commit a higher notification ID first. This would not be true in a + // multi-writer database where sequence IDs may be allocated before commit. + lastID int64 + + mu sync.Mutex + pollInterval time.Duration + replacer *sqlctemplate.Replacer + schema string + topics map[string]struct{} +} + +type notificationPayload struct { + Payload string `json:"payload"` + Topic string `json:"topic"` +} + +func (l *Listener) Close(context.Context) error { + l.mu.Lock() + defer l.mu.Unlock() + + l.isConnected = false + return nil +} + +func (l *Listener) Connect(ctx context.Context) error { + var ( + afterConnectExec string + dbPool *sql.DB + replacer *sqlctemplate.Replacer + schema string + ) + + l.mu.Lock() + if l.isConnected { + l.mu.Unlock() + return errors.New("connection already established") + } + afterConnectExec = l.afterConnectExec + dbPool = l.dbPool + replacer = l.replacer + schema = l.schema + l.mu.Unlock() + + if dbPool == nil { + return errors.New("database pool is nil") + } + if replacer == nil { + replacer = &sqlctemplate.Replacer{} + } + + if afterConnectExec != "" { + if _, err := dbPool.ExecContext(ctx, afterConnectExec); err != nil { + return err + } + } + + lastID, err := dbsqlc.New().NotificationGetLastID(schemaTemplateParam(ctx, schema), templateReplaceWrapper{dbPool, replacer}) + if err != nil { + return err + } + + l.mu.Lock() + defer l.mu.Unlock() + + if l.isConnected { + return errors.New("connection already established") + } + + l.isConnected = true + l.lastID = lastID + + return nil +} + +func (l *Listener) Listen(_ context.Context, topic string) error { + l.mu.Lock() + defer l.mu.Unlock() + + if !l.isConnected { + return errors.New("listener is not connected") + } + + if l.topics == nil { + l.topics = make(map[string]struct{}) + } + + l.topics[topic] = struct{}{} + return nil +} + +func (l *Listener) Ping(ctx context.Context) error { + dbPool, err := l.stateDBPool() + if err != nil { + return err + } + return dbPool.PingContext(ctx) +} + +func (l *Listener) Schema() string { + l.mu.Lock() + defer l.mu.Unlock() + + return l.schema +} + +func (l *Listener) SetAfterConnectExec(sql string) { + l.mu.Lock() + defer l.mu.Unlock() + + l.afterConnectExec = sql +} + +func (l *Listener) Unlisten(_ context.Context, topic string) error { + l.mu.Lock() + defer l.mu.Unlock() + + if !l.isConnected { + return errors.New("listener is not connected") + } + + delete(l.topics, topic) + return nil +} + +func (l *Listener) WaitForNotification(ctx context.Context) (*riverdriver.Notification, error) { + for { + if err := ctx.Err(); err != nil { + return nil, err + } + + notification, found, err := l.waitForNotificationOnce(ctx) + if errors.Is(err, sql.ErrNoRows) { + if err := l.waitForNextPoll(ctx); err != nil { + return nil, err + } + continue + } + if err != nil { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, err + } + if found { + return notification, nil + } + } +} + +func (l *Listener) stateDBPool() (*sql.DB, error) { + l.mu.Lock() + defer l.mu.Unlock() + + if !l.isConnected { + return nil, errors.New("listener is not connected") + } + if l.dbPool == nil { + return nil, errors.New("database pool is nil") + } + + return l.dbPool, nil +} + +func (l *Listener) waitForNextPoll(ctx context.Context) error { + l.mu.Lock() + pollInterval := l.pollInterval + l.mu.Unlock() + + if pollInterval <= 0 { + pollInterval = notificationPollIntervalDefault + } + + timer := time.NewTimer(pollInterval) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func (l *Listener) waitForNotificationOnce(ctx context.Context) (*riverdriver.Notification, bool, error) { + var ( + after int64 + dbPool *sql.DB + replacer *sqlctemplate.Replacer + schema string + ) + + l.mu.Lock() + if !l.isConnected { + l.mu.Unlock() + return nil, false, errors.New("listener is not connected") + } + after = l.lastID + dbPool = l.dbPool + replacer = l.replacer + schema = l.schema + l.mu.Unlock() + + if dbPool == nil { + return nil, false, errors.New("database pool is nil") + } + + notification, err := dbsqlc.New().NotificationGetAfter( + schemaTemplateParam(ctx, schema), + notificationDBTX(dbPool, replacer), + after, + ) + if err != nil { + return nil, false, err + } + + l.mu.Lock() + defer l.mu.Unlock() + + if notification.ID > l.lastID { + l.lastID = notification.ID + } + + if _, ok := l.topics[notification.Topic]; !ok { + return nil, false, nil + } + + return &riverdriver.Notification{ + Payload: notification.Payload, + Topic: notification.Topic, + }, true, nil +} + +func notificationDBTX(dbPool *sql.DB, replacer *sqlctemplate.Replacer) templateReplaceWrapper { + if replacer == nil { + replacer = &sqlctemplate.Replacer{} + } + return templateReplaceWrapper{dbPool, replacer} +}