Compare commits

...

5 Commits

Author SHA1 Message Date
Owen Smallwood
201d743eee adds tests 2025-12-12 14:21:01 -06:00
Owen Smallwood
76e7d56cf4 have to manually close connection on commit/rollback 2025-12-12 12:01:48 -06:00
Owen Smallwood
d387521230 simplify error checking 2025-12-12 11:08:28 -06:00
Owen Smallwood
3834dc989f remove comment 2025-12-11 12:27:15 -06:00
Owen Smallwood
96f586fa45 retry when we fail to get a db connection due to a possible transient error 2025-12-11 12:25:20 -06:00
2 changed files with 237 additions and 4 deletions

View File

@@ -3,15 +3,25 @@ package dbimpl
import (
"context"
"database/sql"
"errors"
"time"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
)
const (
defaultConnAttemptTimeout = 3 * time.Second
defaultConnMaxRetries = 3
defaultConnRetryBackoff = 1 * time.Second
)
// NewDB converts a *sql.DB to a db.DB.
func NewDB(d *sql.DB, driverName string) db.DB {
ret := sqlDB{
DB: d,
driverName: driverName,
log: log.New("resource-db"),
}
ret.WithTxFunc = db.NewWithTxFunc(ret.BeginTx)
@@ -22,6 +32,7 @@ type sqlDB struct {
*sql.DB
db.WithTxFunc
driverName string
log log.Logger
}
func (d sqlDB) DriverName() string {
@@ -37,20 +48,77 @@ func (d sqlDB) QueryRowContext(ctx context.Context, query string, args ...any) d
}
func (d sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (db.Tx, error) {
tx, err := d.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return sqlTx{tx}, err
var tx *sql.Tx
var err error
var conn *sql.Conn
// try to acquire a connection with retries on transient errors
for attempt := 1; attempt <= defaultConnMaxRetries; attempt++ {
connCtx, cancel := context.WithTimeout(ctx, defaultConnAttemptTimeout)
conn, err = d.DB.Conn(connCtx)
cancel()
if err == nil {
break
}
// retry if connection deadline exceeded and attempts remain
if errors.Is(err, context.DeadlineExceeded) && attempt < defaultConnMaxRetries {
d.log.Warn("Timeout when acquiring database connection, retrying", "attempt", attempt, "max_retries", defaultConnMaxRetries)
time.Sleep(defaultConnRetryBackoff)
continue
}
if errors.Is(err, context.DeadlineExceeded) {
d.log.Error("Timeout exceeded while trying to acquire database connection", "attempt", attempt, "max_retries", defaultConnMaxRetries)
return nil, err
}
return nil, err
}
// once we have a connection, begin the transaction
tx, err = conn.BeginTx(ctx, opts)
if err != nil {
if closeErr := conn.Close(); closeErr != nil {
d.log.Error("Failed to close connection after BeginTx error", "error", closeErr)
}
return nil, err
}
return sqlTx{Tx: tx, conn: conn}, nil
}
// sqlTx wraps sql.Tx to add connection management
// since we are manually acquiring the connection in BeginTx() we need to also manually close it after Commit/Rollback
type sqlTx struct {
*sql.Tx
conn *sql.Conn
}
// NewTx wraps an existing *sql.Tx with sqlTx
func NewTx(tx *sql.Tx) db.Tx {
return sqlTx{tx}
return sqlTx{Tx: tx}
}
func (tx sqlTx) Commit() error {
err := tx.Tx.Commit()
if tx.conn != nil {
if err = tx.conn.Close(); err != nil {
return err
}
}
return err
}
func (tx sqlTx) Rollback() error {
err := tx.Tx.Rollback()
if tx.conn != nil {
if err = tx.conn.Close(); err != nil {
return err
}
}
return err
}
func (tx sqlTx) QueryContext(ctx context.Context, query string, args ...any) (db.Rows, error) {

View File

@@ -1,8 +1,13 @@
package dbimpl
import (
"context"
"database/sql"
"database/sql/driver"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -25,3 +30,163 @@ func TestDB_BeginTx(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, tx)
}
func TestDB_BeginTx_RetriesOnTimeout(t *testing.T) {
t.Parallel()
// Register a driver that times out on the first two connection attempts
driverName := "test-timeout-driver-" + t.Name()
timeoutDriver := &timeoutTestDriver{
timeoutsRemaining: 2,
}
sql.Register(driverName, timeoutDriver)
sqlDB, err := sql.Open(driverName, "")
require.NoError(t, err)
sqlDB.SetMaxOpenConns(1)
d := NewDB(sqlDB, driverName)
// Use a longer timeout context since we expect retries with backoff
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
tx, err := d.BeginTx(ctx, nil)
require.NoError(t, err)
require.NotNil(t, tx)
// Should have attempted 3 connections (2 timeouts + 1 success)
require.Equal(t, int32(3), timeoutDriver.connAttempts.Load())
}
func TestDB_BeginTx_ConnectionReleasedAfterCommitAndRollback(t *testing.T) {
t.Parallel()
// Register a driver that tracks open connections
driverName := "test-conn-tracking-driver-" + t.Name()
trackingDriver := &connTrackingDriver{}
sql.Register(driverName, trackingDriver)
sqlDB, err := sql.Open(driverName, "")
require.NoError(t, err)
// Setting MaxIdleConns to 0 ensures connections are closed when returned to pool
sqlDB.SetMaxIdleConns(0)
sqlDB.SetMaxOpenConns(1)
d := NewDB(sqlDB, driverName)
ctx := testutil.NewDefaultTestContext(t)
// Test Commit releases connection
tx1, err := d.BeginTx(ctx, nil)
require.NoError(t, err)
require.Equal(t, int32(1), trackingDriver.openConns.Load(), "should have 1 open connection after BeginTx")
err = tx1.Commit()
require.NoError(t, err)
require.Equal(t, int32(0), trackingDriver.openConns.Load(), "should have 0 open connections after Commit")
// Test Rollback releases connection
tx2, err := d.BeginTx(ctx, nil)
require.NoError(t, err)
require.Equal(t, int32(1), trackingDriver.openConns.Load(), "should have 1 open connection after BeginTx")
err = tx2.Rollback()
require.NoError(t, err)
require.Equal(t, int32(0), trackingDriver.openConns.Load(), "should have 0 open connections after Rollback")
}
// timeoutTestDriver simulates connection timeouts for testing retry logic
type timeoutTestDriver struct {
mu sync.Mutex
timeoutsRemaining int
connAttempts atomic.Int32
}
func (d *timeoutTestDriver) Open(name string) (driver.Conn, error) {
return d.connect(context.Background())
}
func (d *timeoutTestDriver) OpenConnector(name string) (driver.Connector, error) {
return &timeoutTestConnector{driver: d}, nil
}
func (d *timeoutTestDriver) connect(ctx context.Context) (driver.Conn, error) {
d.connAttempts.Add(1)
d.mu.Lock()
shouldTimeout := d.timeoutsRemaining > 0
if shouldTimeout {
d.timeoutsRemaining--
}
d.mu.Unlock()
if shouldTimeout {
// Block until context is cancelled to simulate a slow connection
<-ctx.Done()
return nil, ctx.Err()
}
return &timeoutTestConn{}, nil
}
type timeoutTestConnector struct {
driver *timeoutTestDriver
}
func (c *timeoutTestConnector) Connect(ctx context.Context) (driver.Conn, error) {
return c.driver.connect(ctx)
}
func (c *timeoutTestConnector) Driver() driver.Driver {
return c.driver
}
type timeoutTestConn struct{}
func (c *timeoutTestConn) Prepare(query string) (driver.Stmt, error) { return testStmt{}, nil }
func (c *timeoutTestConn) Close() error { return nil }
func (c *timeoutTestConn) Begin() (driver.Tx, error) { return testTx{}, nil }
func (c *timeoutTestConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
return testTx{}, nil
}
// connTrackingDriver tracks open connections for testing connection lifecycle
type connTrackingDriver struct {
openConns atomic.Int32
}
func (d *connTrackingDriver) Open(name string) (driver.Conn, error) {
d.openConns.Add(1)
return &connTrackingConn{driver: d}, nil
}
func (d *connTrackingDriver) OpenConnector(name string) (driver.Connector, error) {
return &connTrackingConnector{driver: d}, nil
}
type connTrackingConnector struct {
driver *connTrackingDriver
}
func (c *connTrackingConnector) Connect(ctx context.Context) (driver.Conn, error) {
c.driver.openConns.Add(1)
return &connTrackingConn{driver: c.driver}, nil
}
func (c *connTrackingConnector) Driver() driver.Driver {
return c.driver
}
type connTrackingConn struct {
driver *connTrackingDriver
}
func (c *connTrackingConn) Prepare(query string) (driver.Stmt, error) { return testStmt{}, nil }
func (c *connTrackingConn) Close() error {
c.driver.openConns.Add(-1)
return nil
}
func (c *connTrackingConn) Begin() (driver.Tx, error) { return testTx{}, nil }
func (c *connTrackingConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
return testTx{}, nil
}