mirror of
https://github.com/grafana/grafana.git
synced 2025-12-21 12:04:45 +08:00
Compare commits
5 Commits
docs/add-t
...
unified-st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
201d743eee | ||
|
|
76e7d56cf4 | ||
|
|
d387521230 | ||
|
|
3834dc989f | ||
|
|
96f586fa45 |
@@ -3,15 +3,25 @@ package dbimpl
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultConnAttemptTimeout = 3 * time.Second
|
||||
defaultConnMaxRetries = 3
|
||||
defaultConnRetryBackoff = 1 * time.Second
|
||||
)
|
||||
|
||||
// NewDB converts a *sql.DB to a db.DB.
|
||||
func NewDB(d *sql.DB, driverName string) db.DB {
|
||||
ret := sqlDB{
|
||||
DB: d,
|
||||
driverName: driverName,
|
||||
log: log.New("resource-db"),
|
||||
}
|
||||
ret.WithTxFunc = db.NewWithTxFunc(ret.BeginTx)
|
||||
|
||||
@@ -22,6 +32,7 @@ type sqlDB struct {
|
||||
*sql.DB
|
||||
db.WithTxFunc
|
||||
driverName string
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func (d sqlDB) DriverName() string {
|
||||
@@ -37,20 +48,77 @@ func (d sqlDB) QueryRowContext(ctx context.Context, query string, args ...any) d
|
||||
}
|
||||
|
||||
func (d sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (db.Tx, error) {
|
||||
tx, err := d.DB.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sqlTx{tx}, err
|
||||
var tx *sql.Tx
|
||||
var err error
|
||||
var conn *sql.Conn
|
||||
|
||||
// try to acquire a connection with retries on transient errors
|
||||
for attempt := 1; attempt <= defaultConnMaxRetries; attempt++ {
|
||||
connCtx, cancel := context.WithTimeout(ctx, defaultConnAttemptTimeout)
|
||||
conn, err = d.DB.Conn(connCtx)
|
||||
cancel()
|
||||
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// retry if connection deadline exceeded and attempts remain
|
||||
if errors.Is(err, context.DeadlineExceeded) && attempt < defaultConnMaxRetries {
|
||||
d.log.Warn("Timeout when acquiring database connection, retrying", "attempt", attempt, "max_retries", defaultConnMaxRetries)
|
||||
time.Sleep(defaultConnRetryBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
d.log.Error("Timeout exceeded while trying to acquire database connection", "attempt", attempt, "max_retries", defaultConnMaxRetries)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// once we have a connection, begin the transaction
|
||||
tx, err = conn.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
d.log.Error("Failed to close connection after BeginTx error", "error", closeErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sqlTx{Tx: tx, conn: conn}, nil
|
||||
}
|
||||
|
||||
// sqlTx wraps sql.Tx to add connection management
|
||||
// since we are manually acquiring the connection in BeginTx() we need to also manually close it after Commit/Rollback
|
||||
type sqlTx struct {
|
||||
*sql.Tx
|
||||
conn *sql.Conn
|
||||
}
|
||||
|
||||
// NewTx wraps an existing *sql.Tx with sqlTx
|
||||
func NewTx(tx *sql.Tx) db.Tx {
|
||||
return sqlTx{tx}
|
||||
return sqlTx{Tx: tx}
|
||||
}
|
||||
|
||||
func (tx sqlTx) Commit() error {
|
||||
err := tx.Tx.Commit()
|
||||
if tx.conn != nil {
|
||||
if err = tx.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (tx sqlTx) Rollback() error {
|
||||
err := tx.Tx.Rollback()
|
||||
if tx.conn != nil {
|
||||
if err = tx.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (tx sqlTx) QueryContext(ctx context.Context, query string, args ...any) (db.Rows, error) {
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package dbimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -25,3 +30,163 @@ func TestDB_BeginTx(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tx)
|
||||
}
|
||||
|
||||
func TestDB_BeginTx_RetriesOnTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Register a driver that times out on the first two connection attempts
|
||||
driverName := "test-timeout-driver-" + t.Name()
|
||||
timeoutDriver := &timeoutTestDriver{
|
||||
timeoutsRemaining: 2,
|
||||
}
|
||||
sql.Register(driverName, timeoutDriver)
|
||||
|
||||
sqlDB, err := sql.Open(driverName, "")
|
||||
require.NoError(t, err)
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
d := NewDB(sqlDB, driverName)
|
||||
|
||||
// Use a longer timeout context since we expect retries with backoff
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := d.BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tx)
|
||||
|
||||
// Should have attempted 3 connections (2 timeouts + 1 success)
|
||||
require.Equal(t, int32(3), timeoutDriver.connAttempts.Load())
|
||||
}
|
||||
|
||||
func TestDB_BeginTx_ConnectionReleasedAfterCommitAndRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Register a driver that tracks open connections
|
||||
driverName := "test-conn-tracking-driver-" + t.Name()
|
||||
trackingDriver := &connTrackingDriver{}
|
||||
sql.Register(driverName, trackingDriver)
|
||||
|
||||
sqlDB, err := sql.Open(driverName, "")
|
||||
require.NoError(t, err)
|
||||
// Setting MaxIdleConns to 0 ensures connections are closed when returned to pool
|
||||
sqlDB.SetMaxIdleConns(0)
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
d := NewDB(sqlDB, driverName)
|
||||
ctx := testutil.NewDefaultTestContext(t)
|
||||
|
||||
// Test Commit releases connection
|
||||
tx1, err := d.BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), trackingDriver.openConns.Load(), "should have 1 open connection after BeginTx")
|
||||
|
||||
err = tx1.Commit()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(0), trackingDriver.openConns.Load(), "should have 0 open connections after Commit")
|
||||
|
||||
// Test Rollback releases connection
|
||||
tx2, err := d.BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), trackingDriver.openConns.Load(), "should have 1 open connection after BeginTx")
|
||||
|
||||
err = tx2.Rollback()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(0), trackingDriver.openConns.Load(), "should have 0 open connections after Rollback")
|
||||
}
|
||||
|
||||
// timeoutTestDriver simulates connection timeouts for testing retry logic
|
||||
type timeoutTestDriver struct {
|
||||
mu sync.Mutex
|
||||
timeoutsRemaining int
|
||||
connAttempts atomic.Int32
|
||||
}
|
||||
|
||||
func (d *timeoutTestDriver) Open(name string) (driver.Conn, error) {
|
||||
return d.connect(context.Background())
|
||||
}
|
||||
|
||||
func (d *timeoutTestDriver) OpenConnector(name string) (driver.Connector, error) {
|
||||
return &timeoutTestConnector{driver: d}, nil
|
||||
}
|
||||
|
||||
func (d *timeoutTestDriver) connect(ctx context.Context) (driver.Conn, error) {
|
||||
d.connAttempts.Add(1)
|
||||
|
||||
d.mu.Lock()
|
||||
shouldTimeout := d.timeoutsRemaining > 0
|
||||
if shouldTimeout {
|
||||
d.timeoutsRemaining--
|
||||
}
|
||||
d.mu.Unlock()
|
||||
|
||||
if shouldTimeout {
|
||||
// Block until context is cancelled to simulate a slow connection
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return &timeoutTestConn{}, nil
|
||||
}
|
||||
|
||||
type timeoutTestConnector struct {
|
||||
driver *timeoutTestDriver
|
||||
}
|
||||
|
||||
func (c *timeoutTestConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
return c.driver.connect(ctx)
|
||||
}
|
||||
|
||||
func (c *timeoutTestConnector) Driver() driver.Driver {
|
||||
return c.driver
|
||||
}
|
||||
|
||||
type timeoutTestConn struct{}
|
||||
|
||||
func (c *timeoutTestConn) Prepare(query string) (driver.Stmt, error) { return testStmt{}, nil }
|
||||
func (c *timeoutTestConn) Close() error { return nil }
|
||||
func (c *timeoutTestConn) Begin() (driver.Tx, error) { return testTx{}, nil }
|
||||
func (c *timeoutTestConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
|
||||
return testTx{}, nil
|
||||
}
|
||||
|
||||
// connTrackingDriver tracks open connections for testing connection lifecycle
|
||||
type connTrackingDriver struct {
|
||||
openConns atomic.Int32
|
||||
}
|
||||
|
||||
func (d *connTrackingDriver) Open(name string) (driver.Conn, error) {
|
||||
d.openConns.Add(1)
|
||||
return &connTrackingConn{driver: d}, nil
|
||||
}
|
||||
|
||||
func (d *connTrackingDriver) OpenConnector(name string) (driver.Connector, error) {
|
||||
return &connTrackingConnector{driver: d}, nil
|
||||
}
|
||||
|
||||
type connTrackingConnector struct {
|
||||
driver *connTrackingDriver
|
||||
}
|
||||
|
||||
func (c *connTrackingConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
c.driver.openConns.Add(1)
|
||||
return &connTrackingConn{driver: c.driver}, nil
|
||||
}
|
||||
|
||||
func (c *connTrackingConnector) Driver() driver.Driver {
|
||||
return c.driver
|
||||
}
|
||||
|
||||
type connTrackingConn struct {
|
||||
driver *connTrackingDriver
|
||||
}
|
||||
|
||||
func (c *connTrackingConn) Prepare(query string) (driver.Stmt, error) { return testStmt{}, nil }
|
||||
func (c *connTrackingConn) Close() error {
|
||||
c.driver.openConns.Add(-1)
|
||||
return nil
|
||||
}
|
||||
func (c *connTrackingConn) Begin() (driver.Tx, error) { return testTx{}, nil }
|
||||
func (c *connTrackingConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
|
||||
return testTx{}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user