Skip to content
impl_postgres.go 2.08 KiB
Newer Older
Lo^2's avatar
Lo^2 committed
package db

import (
	"code.electrolab.fr/it/vote.electrolab.fr/service"
	"context"
	"database/sql"
	"fmt"
	"github.com/lib/pq"
)

// DBPostgres is a database handle backed by PostgreSQL through lib/pq and
// database/sql. Build it with NewDBPostgres, call Start before issuing any
// query (the query methods use the pool Start opens), and Stop to close it.
type DBPostgres struct {
	// connString is the pq connection string handed to pq.NewConnector in Start.
	connString string
	// conn is the connector created from connString by Start.
	conn       *pq.Connector
	// db is the connection pool opened on conn by Start; nil until then.
	db         *sql.DB
}

// NewDBPostgres returns a DBPostgres that will connect using connString.
// No connection is attempted here; Start creates the connector and pool.
func NewDBPostgres(connString string) *DBPostgres {
	p := new(DBPostgres)
	p.connString = connString
	return p
}

// Name returns the receiver's type name without the leading pointer marker,
// i.e. "db.DBPostgres" (%T on a *DBPostgres yields "*db.DBPostgres").
func (db *DBPostgres) Name() string {
	typeName := fmt.Sprintf("%T", db)
	return typeName[len("*"):]
}

// Start creates a pq connector from connString, opens a database/sql pool on
// it and verifies connectivity with a ping.
//
// Errors wrap the underlying pq / database/sql error with %w, so callers can
// inspect it with errors.Is / errors.As. The message text is unchanged.
//
// If the ping fails, the freshly opened pool is closed so it does not linger.
// db.db is still left pointing at that (closed) pool, so a later Stop stays
// safe: (*sql.DB).Close is idempotent.
func (db *DBPostgres) Start() error {
	var err error

	db.conn, err = pq.NewConnector(db.connString)
	if err != nil {
		return fmt.Errorf("DBPostgres config: %w", err)
	}
	db.db = sql.OpenDB(db.conn)

	if err = db.db.Ping(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.db.Close()
		return fmt.Errorf("DBPostgres connecting: %w", err)
	}
	return nil
}

// Stop closes the connection pool opened by Start.
//
// If no pool exists (Start was never called, or it failed before opening the
// pool), Stop does nothing and returns nil instead of dereferencing a nil
// *sql.DB. A close error is wrapped with %w.
func (db *DBPostgres) Stop() error {
	if db.db == nil {
		return nil
	}
	if err := db.db.Close(); err != nil {
		return fmt.Errorf("DBPostgres closing: %w", err)
	}
	return nil
}

// Dependencies returns the services this one depends on. DBPostgres declares
// none, so the result is always a nil slice.
func (db *DBPostgres) Dependencies() []service.Service {
	var none []service.Service
	return none
}

// ExecContext executes a statement that returns no rows on the connection
// pool, forwarding ctx, query and args to database/sql unchanged.
func (db *DBPostgres) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
	res, err := db.db.ExecContext(ctx, query, args...)
	return res, err
}

// QueryContext runs a query that returns rows on the connection pool.
//
// On error it returns an explicit nil Rows. Rows is declared elsewhere in the
// package. If it is an interface, passing database/sql's (*sql.Rows)(nil)
// straight through would produce a non-nil interface holding a nil pointer,
// which defeats `rows != nil` checks in callers.
func (db *DBPostgres) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
	rows, err := db.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// QueryRowContext runs a query expected to return at most one row on the
// connection pool. Per database/sql, the result is never nil and any error is
// deferred until the row is scanned.
func (db *DBPostgres) QueryRowContext(ctx context.Context, query string, args ...interface{}) Row {
	row := db.db.QueryRowContext(ctx, query, args...)
	return row
}

// Begin starts a transaction on the connection pool and wraps it as a Tx.
//
// On error it returns a nil Tx. Returning sqlTx{nil} would give callers a
// non-nil Tx whose Commit/Rollback dereference a nil *sql.Tx and panic if
// they are called before err is checked (e.g. a deferred Rollback).
func (db *DBPostgres) Begin() (Tx, error) {
	tx, err := db.db.Begin()
	if err != nil {
		return nil, err
	}
	return sqlTx{tx}, nil
}

// sqlTx adapts a *sql.Tx to the package's Tx type; Begin returns it.
// Each method forwards directly to the wrapped transaction.
type sqlTx struct {
	// tx is the underlying database/sql transaction.
	tx *sql.Tx
}

// Commit commits the wrapped transaction and returns its error unchanged.
func (tx sqlTx) Commit() error {
	err := tx.tx.Commit()
	return err
}

// Rollback aborts the wrapped transaction and returns its error unchanged.
func (tx sqlTx) Rollback() error {
	err := tx.tx.Rollback()
	return err
}

// ExecContext executes a statement that returns no rows inside the wrapped
// transaction, forwarding ctx, query and args unchanged.
func (tx sqlTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
	result, err := tx.tx.ExecContext(ctx, query, args...)
	return result, err
}

// QueryContext runs a query that returns rows inside the wrapped transaction.
//
// On error it returns an explicit nil Rows. Rows is declared elsewhere in the
// package. If it is an interface, passing database/sql's (*sql.Rows)(nil)
// straight through would produce a non-nil interface holding a nil pointer,
// which defeats `rows != nil` checks in callers.
func (tx sqlTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
	rows, err := tx.tx.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// QueryRowContext runs a query expected to return at most one row inside the
// wrapped transaction. Per database/sql, the result is never nil and any error
// is deferred until the row is scanned.
func (tx sqlTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) Row {
	row := tx.tx.QueryRowContext(ctx, query, args...)
	return row
}