1
This commit is contained in:
485
internal/database/postgres/transaction.go
Normal file
485
internal/database/postgres/transaction.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package pgdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
|
||||
dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
|
||||
"github.com/Masterminds/squirrel"
|
||||
)
|
||||
|
||||
func getTransactionTypeByAmount(amount int64) dbtypes.TransactionType {
|
||||
if amount > 0 {
|
||||
return dbtypes.TransactionTypeDeposit
|
||||
}
|
||||
|
||||
return dbtypes.TransactionTypeWithdrawal
|
||||
}
|
||||
|
||||
// TODO: add migration to rebind statuses
|
||||
|
||||
func TransactionStatusIdToString(status int32) dbtypes.TransactionStatus {
|
||||
switch status {
|
||||
case 0:
|
||||
return dbtypes.TransactionStatusPending
|
||||
case 1:
|
||||
return dbtypes.TransactionStatusApproved
|
||||
case 2:
|
||||
return dbtypes.TransactionStatusRejected
|
||||
case 3:
|
||||
return dbtypes.TransactionStatusNew
|
||||
default:
|
||||
return dbtypes.TransactionStatusPending
|
||||
}
|
||||
}
|
||||
|
||||
func TransactionStatusStringToId(status dbtypes.TransactionStatus) int32 {
|
||||
switch status {
|
||||
case dbtypes.TransactionStatusNew:
|
||||
return 3
|
||||
case dbtypes.TransactionStatusPending:
|
||||
return 0
|
||||
case dbtypes.TransactionStatusApproved:
|
||||
return 1
|
||||
case dbtypes.TransactionStatusRejected:
|
||||
return 2
|
||||
default:
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocognit // not so hard
|
||||
func (c *client) GetTransactionList(
|
||||
ctx context.Context,
|
||||
request *dbtypes.TransactionListGetRequest,
|
||||
) (*dbtypes.TransactionListGetResponse, error) {
|
||||
if request == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName)
|
||||
)
|
||||
|
||||
getTransactions := psql.Select(
|
||||
"id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload",
|
||||
).From(transactionsTable).
|
||||
Where(squirrel.Eq{"owner_id": request.OwnerId}).
|
||||
Limit(request.PageSize).
|
||||
Offset(countOffset(request.Page, request.PageSize))
|
||||
|
||||
getTransactions, err := c.setGetTransactionsQueryFilters(getTransactions, request.Filters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error setting get transactions query filters: %v", dberrors.ErrBadRequest, err)
|
||||
}
|
||||
|
||||
query, args, err := getTransactions.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building get transactions query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error executing get transactions query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
transList := &dbtypes.TransactionListGetResponse{
|
||||
Transactions: make([]dbtypes.Transaction, 0, request.PageSize),
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
transStatus int32
|
||||
ownerId string
|
||||
payload, bankAccountId sql.NullString
|
||||
transaction dbtypes.Transaction
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&transaction.Id, &ownerId, &bankAccountId, &transaction.Amount, &transaction.Currency,
|
||||
&transStatus, &transaction.CreatedAt, &payload,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, dberrors.ErrNotFound
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if bankAccountId.Valid {
|
||||
bankAccountInfo, err := c.getBankAccountInfoById(ctx, c.db, bankAccountId.String)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error getting bank account info: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
transaction.BankAccountInfo = bankAccountInfo
|
||||
|
||||
ownerInfo, err := c.getOwnerInfoById(ctx, c.db, ownerId, bankAccountInfo.OwnerType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error getting owner info: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
transaction.OwnerInfo = ownerInfo
|
||||
}
|
||||
|
||||
transaction.Type = getTransactionTypeByAmount(transaction.Amount)
|
||||
transaction.Status = TransactionStatusIdToString(transStatus)
|
||||
|
||||
if payload.Valid {
|
||||
var payloadData dbtypes.TransactionPayload
|
||||
|
||||
if err := json.Unmarshal([]byte(payload.String), &payloadData); err != nil {
|
||||
return nil, fmt.Errorf("%w: error unmarshaling transaction payload: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
transaction.Payload = &payloadData
|
||||
}
|
||||
|
||||
transList.Transactions = append(transList.Transactions, transaction)
|
||||
}
|
||||
|
||||
return transList, nil
|
||||
}
|
||||
|
||||
func (c *client) setGetTransactionsQueryFilters(
|
||||
query squirrel.SelectBuilder,
|
||||
filters *dbtypes.TransactionListFilters,
|
||||
) (squirrel.SelectBuilder, error) {
|
||||
if filters == nil {
|
||||
return query, nil
|
||||
}
|
||||
|
||||
if filters.Type != nil {
|
||||
switch *filters.Type {
|
||||
case dbtypes.TransactionTypeDeposit:
|
||||
query = query.Where(squirrel.Gt{
|
||||
"amount": 0,
|
||||
})
|
||||
|
||||
case dbtypes.TransactionTypeWithdrawal:
|
||||
query = query.Where(squirrel.Lt{
|
||||
"amount": 0,
|
||||
})
|
||||
|
||||
default:
|
||||
return query, fmt.Errorf("%w: invalid transaction type: %v", dberrors.ErrBadRequest, *filters.Type)
|
||||
}
|
||||
}
|
||||
|
||||
if filters.Status != nil {
|
||||
query = query.Where(squirrel.Eq{"status": TransactionStatusStringToId(*filters.Status)})
|
||||
}
|
||||
|
||||
if filters.BankAccountId != nil {
|
||||
query = query.Where(squirrel.Eq{"bank_account_id": *filters.BankAccountId})
|
||||
}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (c *client) getOwnerInfoById(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
ownerId string,
|
||||
ownerType string,
|
||||
) (*dbtypes.TransactionOwnerInfo, error) {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
clientTableName = fmt.Sprintf("%s.%s", c.config.Schema, UsersTableName)
|
||||
companyTableName = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
||||
)
|
||||
|
||||
var getOwnerInfoById squirrel.SelectBuilder
|
||||
|
||||
// TODO: reingeneer the DB
|
||||
switch ownerType {
|
||||
case "agent":
|
||||
getOwnerInfoById = psql.Select(
|
||||
"uid",
|
||||
"name",
|
||||
).
|
||||
From(clientTableName).
|
||||
Where(squirrel.Eq{"uid": ownerId})
|
||||
|
||||
case "company":
|
||||
getOwnerInfoById = psql.Select(
|
||||
"id",
|
||||
"name",
|
||||
).
|
||||
From(companyTableName).
|
||||
Where(squirrel.Eq{"id": ownerId})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: invalid owner type", dberrors.ErrBadRequest)
|
||||
}
|
||||
|
||||
query, args, err := getOwnerInfoById.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building get owner info by id query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
row := driver.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var ownerInfo dbtypes.TransactionOwnerInfo
|
||||
|
||||
if err := row.Scan(
|
||||
&ownerInfo.Id,
|
||||
&ownerInfo.Name,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, dberrors.ErrNotFound
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: error scanning row for get owner info by id query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return &ownerInfo, nil
|
||||
}
|
||||
|
||||
func (c *client) getBankAccountInfoById(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
bankAccountId string,
|
||||
) (*dbtypes.BankAccountInfo, error) {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
bankAccountsTableName = fmt.Sprintf("%s.%s", c.config.Schema, BankAccountsTableName)
|
||||
)
|
||||
|
||||
getBankAccountInfoById := psql.Select(
|
||||
"id",
|
||||
"account_number",
|
||||
"bank_name",
|
||||
"bik",
|
||||
"correspondent_account",
|
||||
"owner_type",
|
||||
).
|
||||
From(bankAccountsTableName).
|
||||
Where(squirrel.Eq{"id": bankAccountId})
|
||||
|
||||
query, args, err := getBankAccountInfoById.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building get bank account info by id query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
row := driver.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var accountInfo dbtypes.BankAccountInfo
|
||||
|
||||
if err := row.Scan(
|
||||
&accountInfo.Id,
|
||||
&accountInfo.AccountNumber,
|
||||
&accountInfo.BankName,
|
||||
&accountInfo.Bik,
|
||||
&accountInfo.CorrespondentAccount,
|
||||
&accountInfo.OwnerType,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, dberrors.ErrNotFound
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: error scanning row for get bank account info by id query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return &accountInfo, nil
|
||||
}
|
||||
|
||||
func (c *client) CreateTransaction(
|
||||
ctx context.Context,
|
||||
request *dbtypes.TransactionCreateRequest,
|
||||
) (*dbtypes.TransactionCreateResponse, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
result, err := c.createTransactionWithDriver(ctx, tx, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating transaction: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("%w: error committing transaction: %w", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *client) createTransactionWithDriver(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
request *dbtypes.TransactionCreateRequest,
|
||||
) (*dbtypes.TransactionCreateResponse, error) {
|
||||
if _, err := c.getRawBalanceForUpdate(ctx, driver, request.OwnerId); err != nil {
|
||||
return nil, fmt.Errorf("error getting raw balance for update: %w", err)
|
||||
}
|
||||
|
||||
result, err := c.createTransaction(ctx, driver, request)
|
||||
if err != nil {
|
||||
if errors.Is(err, dberrors.ErrConflict) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("error creating transaction: %w", err)
|
||||
}
|
||||
|
||||
if err := c.updateBalance(ctx, driver, request.Amount, request.OwnerId); err != nil {
|
||||
return nil, fmt.Errorf("error updating balance: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *client) getRawBalanceForUpdate(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
ownerId string,
|
||||
) (int64, error) {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName)
|
||||
)
|
||||
|
||||
getBalance := psql.Select(
|
||||
"raw_balance",
|
||||
).From(balancesTable).
|
||||
Where(squirrel.Eq{"owner_id": ownerId}).
|
||||
Suffix("FOR UPDATE")
|
||||
|
||||
query, args, err := getBalance.ToSql()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: error building 'get balance' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
row := driver.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var rawBalance int64
|
||||
|
||||
if err := row.Scan(&rawBalance); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, dberrors.ErrNotFound
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return rawBalance, nil
|
||||
}
|
||||
|
||||
func (c *client) updateBalance(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
amountDelta int64,
|
||||
ownerId string,
|
||||
) error {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName)
|
||||
)
|
||||
|
||||
updateBalance := psql.Update(balancesTable).
|
||||
SetMap(map[string]any{
|
||||
"raw_balance": squirrel.Expr("raw_balance + ?", amountDelta),
|
||||
"updated_at": squirrel.Expr("NOW()"),
|
||||
}).
|
||||
Where(squirrel.Eq{"owner_id": ownerId})
|
||||
|
||||
query, args, err := updateBalance.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error building 'update balance' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
res, err := driver.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error executing 'update balance' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error getting rows affected for 'update balance' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return dberrors.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createTransaction(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
request *dbtypes.TransactionCreateRequest,
|
||||
) (*dbtypes.TransactionCreateResponse, error) {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName)
|
||||
)
|
||||
|
||||
var payload []byte
|
||||
|
||||
if request.Payload == nil {
|
||||
payload = []byte("{}")
|
||||
} else {
|
||||
payloadBytes, err := json.Marshal(request.Payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error marshaling transaction payload: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
payload = payloadBytes
|
||||
}
|
||||
|
||||
createTransaction := psql.Insert(transactionsTable).
|
||||
Columns(
|
||||
"id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload",
|
||||
).
|
||||
Values(
|
||||
request.RequestId, request.OwnerId, request.BankAccountId, request.Amount, request.Currency,
|
||||
dbtypes.TransactionStatusNew, squirrel.Expr("CURRENT_TIMESTAMP"), payload,
|
||||
).
|
||||
Suffix("ON CONFLICT (id, owner_id) DO NOTHING")
|
||||
|
||||
query, args, err := createTransaction.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building 'create transaction' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
res, err := driver.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error executing 'create transaction' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error getting rows affected for 'create transaction' query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return &dbtypes.TransactionCreateResponse{
|
||||
Id: request.RequestId,
|
||||
}, dberrors.ErrConflict
|
||||
}
|
||||
|
||||
return &dbtypes.TransactionCreateResponse{
|
||||
Id: request.RequestId,
|
||||
}, nil
|
||||
}
|
Reference in New Issue
Block a user