This commit is contained in:
Alex Shevchuk
2025-08-18 17:12:04 +03:00
commit d84487d238
157 changed files with 160686 additions and 0 deletions

View File

@@ -0,0 +1,485 @@
package pgdb
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
"github.com/Masterminds/squirrel"
)
func getTransactionTypeByAmount(amount int64) dbtypes.TransactionType {
if amount > 0 {
return dbtypes.TransactionTypeDeposit
}
return dbtypes.TransactionTypeWithdrawal
}
// TODO: add migration to rebind statuses
func TransactionStatusIdToString(status int32) dbtypes.TransactionStatus {
switch status {
case 0:
return dbtypes.TransactionStatusPending
case 1:
return dbtypes.TransactionStatusApproved
case 2:
return dbtypes.TransactionStatusRejected
case 3:
return dbtypes.TransactionStatusNew
default:
return dbtypes.TransactionStatusPending
}
}
func TransactionStatusStringToId(status dbtypes.TransactionStatus) int32 {
switch status {
case dbtypes.TransactionStatusNew:
return 3
case dbtypes.TransactionStatusPending:
return 0
case dbtypes.TransactionStatusApproved:
return 1
case dbtypes.TransactionStatusRejected:
return 2
default:
return 3
}
}
//nolint:gocognit // not so hard
func (c *client) GetTransactionList(
ctx context.Context,
request *dbtypes.TransactionListGetRequest,
) (*dbtypes.TransactionListGetResponse, error) {
if request == nil {
return nil, nil
}
var (
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName)
)
getTransactions := psql.Select(
"id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload",
).From(transactionsTable).
Where(squirrel.Eq{"owner_id": request.OwnerId}).
Limit(request.PageSize).
Offset(countOffset(request.Page, request.PageSize))
getTransactions, err := c.setGetTransactionsQueryFilters(getTransactions, request.Filters)
if err != nil {
return nil, fmt.Errorf("%w: error setting get transactions query filters: %v", dberrors.ErrBadRequest, err)
}
query, args, err := getTransactions.ToSql()
if err != nil {
return nil, fmt.Errorf("%w: error building get transactions query: %v", dberrors.ErrInternal, err)
}
rows, err := c.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("%w: error executing get transactions query: %v", dberrors.ErrInternal, err)
}
defer rows.Close()
transList := &dbtypes.TransactionListGetResponse{
Transactions: make([]dbtypes.Transaction, 0, request.PageSize),
}
for rows.Next() {
var (
transStatus int32
ownerId string
payload, bankAccountId sql.NullString
transaction dbtypes.Transaction
)
if err := rows.Scan(
&transaction.Id, &ownerId, &bankAccountId, &transaction.Amount, &transaction.Currency,
&transStatus, &transaction.CreatedAt, &payload,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, dberrors.ErrNotFound
}
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
}
if bankAccountId.Valid {
bankAccountInfo, err := c.getBankAccountInfoById(ctx, c.db, bankAccountId.String)
if err != nil {
return nil, fmt.Errorf("%w: error getting bank account info: %v", dberrors.ErrInternal, err)
}
transaction.BankAccountInfo = bankAccountInfo
ownerInfo, err := c.getOwnerInfoById(ctx, c.db, ownerId, bankAccountInfo.OwnerType)
if err != nil {
return nil, fmt.Errorf("%w: error getting owner info: %v", dberrors.ErrInternal, err)
}
transaction.OwnerInfo = ownerInfo
}
transaction.Type = getTransactionTypeByAmount(transaction.Amount)
transaction.Status = TransactionStatusIdToString(transStatus)
if payload.Valid {
var payloadData dbtypes.TransactionPayload
if err := json.Unmarshal([]byte(payload.String), &payloadData); err != nil {
return nil, fmt.Errorf("%w: error unmarshaling transaction payload: %v", dberrors.ErrInternal, err)
}
transaction.Payload = &payloadData
}
transList.Transactions = append(transList.Transactions, transaction)
}
return transList, nil
}
func (c *client) setGetTransactionsQueryFilters(
query squirrel.SelectBuilder,
filters *dbtypes.TransactionListFilters,
) (squirrel.SelectBuilder, error) {
if filters == nil {
return query, nil
}
if filters.Type != nil {
switch *filters.Type {
case dbtypes.TransactionTypeDeposit:
query = query.Where(squirrel.Gt{
"amount": 0,
})
case dbtypes.TransactionTypeWithdrawal:
query = query.Where(squirrel.Lt{
"amount": 0,
})
default:
return query, fmt.Errorf("%w: invalid transaction type: %v", dberrors.ErrBadRequest, *filters.Type)
}
}
if filters.Status != nil {
query = query.Where(squirrel.Eq{"status": TransactionStatusStringToId(*filters.Status)})
}
if filters.BankAccountId != nil {
query = query.Where(squirrel.Eq{"bank_account_id": *filters.BankAccountId})
}
return query, nil
}
func (c *client) getOwnerInfoById(
ctx context.Context,
driver Driver,
ownerId string,
ownerType string,
) (*dbtypes.TransactionOwnerInfo, error) {
var (
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
clientTableName = fmt.Sprintf("%s.%s", c.config.Schema, UsersTableName)
companyTableName = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
)
var getOwnerInfoById squirrel.SelectBuilder
// TODO: reingeneer the DB
switch ownerType {
case "agent":
getOwnerInfoById = psql.Select(
"uid",
"name",
).
From(clientTableName).
Where(squirrel.Eq{"uid": ownerId})
case "company":
getOwnerInfoById = psql.Select(
"id",
"name",
).
From(companyTableName).
Where(squirrel.Eq{"id": ownerId})
default:
return nil, fmt.Errorf("%w: invalid owner type", dberrors.ErrBadRequest)
}
query, args, err := getOwnerInfoById.ToSql()
if err != nil {
return nil, fmt.Errorf("%w: error building get owner info by id query: %v", dberrors.ErrInternal, err)
}
row := driver.QueryRowContext(ctx, query, args...)
var ownerInfo dbtypes.TransactionOwnerInfo
if err := row.Scan(
&ownerInfo.Id,
&ownerInfo.Name,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, dberrors.ErrNotFound
}
return nil, fmt.Errorf("%w: error scanning row for get owner info by id query: %v", dberrors.ErrInternal, err)
}
return &ownerInfo, nil
}
func (c *client) getBankAccountInfoById(
ctx context.Context,
driver Driver,
bankAccountId string,
) (*dbtypes.BankAccountInfo, error) {
var (
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
bankAccountsTableName = fmt.Sprintf("%s.%s", c.config.Schema, BankAccountsTableName)
)
getBankAccountInfoById := psql.Select(
"id",
"account_number",
"bank_name",
"bik",
"correspondent_account",
"owner_type",
).
From(bankAccountsTableName).
Where(squirrel.Eq{"id": bankAccountId})
query, args, err := getBankAccountInfoById.ToSql()
if err != nil {
return nil, fmt.Errorf("%w: error building get bank account info by id query: %v", dberrors.ErrInternal, err)
}
row := driver.QueryRowContext(ctx, query, args...)
var accountInfo dbtypes.BankAccountInfo
if err := row.Scan(
&accountInfo.Id,
&accountInfo.AccountNumber,
&accountInfo.BankName,
&accountInfo.Bik,
&accountInfo.CorrespondentAccount,
&accountInfo.OwnerType,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, dberrors.ErrNotFound
}
return nil, fmt.Errorf("%w: error scanning row for get bank account info by id query: %v", dberrors.ErrInternal, err)
}
return &accountInfo, nil
}
func (c *client) CreateTransaction(
ctx context.Context,
request *dbtypes.TransactionCreateRequest,
) (*dbtypes.TransactionCreateResponse, error) {
if request == nil {
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
}
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
}
defer func() { _ = tx.Rollback() }()
result, err := c.createTransactionWithDriver(ctx, tx, request)
if err != nil {
return nil, fmt.Errorf("error creating transaction: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("%w: error committing transaction: %w", dberrors.ErrInternal, err)
}
return result, nil
}
func (c *client) createTransactionWithDriver(
ctx context.Context,
driver Driver,
request *dbtypes.TransactionCreateRequest,
) (*dbtypes.TransactionCreateResponse, error) {
if _, err := c.getRawBalanceForUpdate(ctx, driver, request.OwnerId); err != nil {
return nil, fmt.Errorf("error getting raw balance for update: %w", err)
}
result, err := c.createTransaction(ctx, driver, request)
if err != nil {
if errors.Is(err, dberrors.ErrConflict) {
return result, nil
}
return nil, fmt.Errorf("error creating transaction: %w", err)
}
if err := c.updateBalance(ctx, driver, request.Amount, request.OwnerId); err != nil {
return nil, fmt.Errorf("error updating balance: %w", err)
}
return result, nil
}
func (c *client) getRawBalanceForUpdate(
ctx context.Context,
driver Driver,
ownerId string,
) (int64, error) {
var (
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName)
)
getBalance := psql.Select(
"raw_balance",
).From(balancesTable).
Where(squirrel.Eq{"owner_id": ownerId}).
Suffix("FOR UPDATE")
query, args, err := getBalance.ToSql()
if err != nil {
return 0, fmt.Errorf("%w: error building 'get balance' query: %v", dberrors.ErrInternal, err)
}
row := driver.QueryRowContext(ctx, query, args...)
var rawBalance int64
if err := row.Scan(&rawBalance); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, dberrors.ErrNotFound
}
return 0, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
}
return rawBalance, nil
}
func (c *client) updateBalance(
ctx context.Context,
driver Driver,
amountDelta int64,
ownerId string,
) error {
var (
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName)
)
updateBalance := psql.Update(balancesTable).
SetMap(map[string]any{
"raw_balance": squirrel.Expr("raw_balance + ?", amountDelta),
"updated_at": squirrel.Expr("NOW()"),
}).
Where(squirrel.Eq{"owner_id": ownerId})
query, args, err := updateBalance.ToSql()
if err != nil {
return fmt.Errorf("%w: error building 'update balance' query: %v", dberrors.ErrInternal, err)
}
res, err := driver.ExecContext(ctx, query, args...)
if err != nil {
return fmt.Errorf("%w: error executing 'update balance' query: %v", dberrors.ErrInternal, err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("%w: error getting rows affected for 'update balance' query: %v", dberrors.ErrInternal, err)
}
if rowsAffected == 0 {
return dberrors.ErrInternal
}
return nil
}
func (c *client) createTransaction(
ctx context.Context,
driver Driver,
request *dbtypes.TransactionCreateRequest,
) (*dbtypes.TransactionCreateResponse, error) {
var (
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName)
)
var payload []byte
if request.Payload == nil {
payload = []byte("{}")
} else {
payloadBytes, err := json.Marshal(request.Payload)
if err != nil {
return nil, fmt.Errorf("%w: error marshaling transaction payload: %v", dberrors.ErrInternal, err)
}
payload = payloadBytes
}
createTransaction := psql.Insert(transactionsTable).
Columns(
"id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload",
).
Values(
request.RequestId, request.OwnerId, request.BankAccountId, request.Amount, request.Currency,
dbtypes.TransactionStatusNew, squirrel.Expr("CURRENT_TIMESTAMP"), payload,
).
Suffix("ON CONFLICT (id, owner_id) DO NOTHING")
query, args, err := createTransaction.ToSql()
if err != nil {
return nil, fmt.Errorf("%w: error building 'create transaction' query: %v", dberrors.ErrInternal, err)
}
res, err := driver.ExecContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("%w: error executing 'create transaction' query: %v", dberrors.ErrInternal, err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return nil, fmt.Errorf("%w: error getting rows affected for 'create transaction' query: %v", dberrors.ErrInternal, err)
}
if rowsAffected == 0 {
return &dbtypes.TransactionCreateResponse{
Id: request.RequestId,
}, dberrors.ErrConflict
}
return &dbtypes.TransactionCreateResponse{
Id: request.RequestId,
}, nil
}