486 lines
13 KiB
Go
486 lines
13 KiB
Go
package pgdb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
|
|
dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
|
|
dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
|
|
"github.com/Masterminds/squirrel"
|
|
)
|
|
|
|
func getTransactionTypeByAmount(amount int64) dbtypes.TransactionType {
|
|
if amount > 0 {
|
|
return dbtypes.TransactionTypeDeposit
|
|
}
|
|
|
|
return dbtypes.TransactionTypeWithdrawal
|
|
}
|
|
|
|
// TODO: add migration to rebind statuses
|
|
|
|
func TransactionStatusIdToString(status int32) dbtypes.TransactionStatus {
|
|
switch status {
|
|
case 0:
|
|
return dbtypes.TransactionStatusPending
|
|
case 1:
|
|
return dbtypes.TransactionStatusApproved
|
|
case 2:
|
|
return dbtypes.TransactionStatusRejected
|
|
case 3:
|
|
return dbtypes.TransactionStatusNew
|
|
default:
|
|
return dbtypes.TransactionStatusPending
|
|
}
|
|
}
|
|
|
|
func TransactionStatusStringToId(status dbtypes.TransactionStatus) int32 {
|
|
switch status {
|
|
case dbtypes.TransactionStatusNew:
|
|
return 3
|
|
case dbtypes.TransactionStatusPending:
|
|
return 0
|
|
case dbtypes.TransactionStatusApproved:
|
|
return 1
|
|
case dbtypes.TransactionStatusRejected:
|
|
return 2
|
|
default:
|
|
return 3
|
|
}
|
|
}
|
|
|
|
//nolint:gocognit // not so hard
|
|
func (c *client) GetTransactionList(
|
|
ctx context.Context,
|
|
request *dbtypes.TransactionListGetRequest,
|
|
) (*dbtypes.TransactionListGetResponse, error) {
|
|
if request == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName)
|
|
)
|
|
|
|
getTransactions := psql.Select(
|
|
"id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload",
|
|
).From(transactionsTable).
|
|
Where(squirrel.Eq{"owner_id": request.OwnerId}).
|
|
Limit(request.PageSize).
|
|
Offset(countOffset(request.Page, request.PageSize))
|
|
|
|
getTransactions, err := c.setGetTransactionsQueryFilters(getTransactions, request.Filters)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error setting get transactions query filters: %v", dberrors.ErrBadRequest, err)
|
|
}
|
|
|
|
query, args, err := getTransactions.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building get transactions query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error executing get transactions query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
transList := &dbtypes.TransactionListGetResponse{
|
|
Transactions: make([]dbtypes.Transaction, 0, request.PageSize),
|
|
}
|
|
|
|
for rows.Next() {
|
|
var (
|
|
transStatus int32
|
|
ownerId string
|
|
payload, bankAccountId sql.NullString
|
|
transaction dbtypes.Transaction
|
|
)
|
|
|
|
if err := rows.Scan(
|
|
&transaction.Id, &ownerId, &bankAccountId, &transaction.Amount, &transaction.Currency,
|
|
&transStatus, &transaction.CreatedAt, &payload,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, dberrors.ErrNotFound
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if bankAccountId.Valid {
|
|
bankAccountInfo, err := c.getBankAccountInfoById(ctx, c.db, bankAccountId.String)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error getting bank account info: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
transaction.BankAccountInfo = bankAccountInfo
|
|
|
|
ownerInfo, err := c.getOwnerInfoById(ctx, c.db, ownerId, bankAccountInfo.OwnerType)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error getting owner info: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
transaction.OwnerInfo = ownerInfo
|
|
}
|
|
|
|
transaction.Type = getTransactionTypeByAmount(transaction.Amount)
|
|
transaction.Status = TransactionStatusIdToString(transStatus)
|
|
|
|
if payload.Valid {
|
|
var payloadData dbtypes.TransactionPayload
|
|
|
|
if err := json.Unmarshal([]byte(payload.String), &payloadData); err != nil {
|
|
return nil, fmt.Errorf("%w: error unmarshaling transaction payload: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
transaction.Payload = &payloadData
|
|
}
|
|
|
|
transList.Transactions = append(transList.Transactions, transaction)
|
|
}
|
|
|
|
return transList, nil
|
|
}
|
|
|
|
func (c *client) setGetTransactionsQueryFilters(
|
|
query squirrel.SelectBuilder,
|
|
filters *dbtypes.TransactionListFilters,
|
|
) (squirrel.SelectBuilder, error) {
|
|
if filters == nil {
|
|
return query, nil
|
|
}
|
|
|
|
if filters.Type != nil {
|
|
switch *filters.Type {
|
|
case dbtypes.TransactionTypeDeposit:
|
|
query = query.Where(squirrel.Gt{
|
|
"amount": 0,
|
|
})
|
|
|
|
case dbtypes.TransactionTypeWithdrawal:
|
|
query = query.Where(squirrel.Lt{
|
|
"amount": 0,
|
|
})
|
|
|
|
default:
|
|
return query, fmt.Errorf("%w: invalid transaction type: %v", dberrors.ErrBadRequest, *filters.Type)
|
|
}
|
|
}
|
|
|
|
if filters.Status != nil {
|
|
query = query.Where(squirrel.Eq{"status": TransactionStatusStringToId(*filters.Status)})
|
|
}
|
|
|
|
if filters.BankAccountId != nil {
|
|
query = query.Where(squirrel.Eq{"bank_account_id": *filters.BankAccountId})
|
|
}
|
|
|
|
return query, nil
|
|
}
|
|
|
|
func (c *client) getOwnerInfoById(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
ownerId string,
|
|
ownerType string,
|
|
) (*dbtypes.TransactionOwnerInfo, error) {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
clientTableName = fmt.Sprintf("%s.%s", c.config.Schema, UsersTableName)
|
|
companyTableName = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
|
)
|
|
|
|
var getOwnerInfoById squirrel.SelectBuilder
|
|
|
|
// TODO: reingeneer the DB
|
|
switch ownerType {
|
|
case "agent":
|
|
getOwnerInfoById = psql.Select(
|
|
"uid",
|
|
"name",
|
|
).
|
|
From(clientTableName).
|
|
Where(squirrel.Eq{"uid": ownerId})
|
|
|
|
case "company":
|
|
getOwnerInfoById = psql.Select(
|
|
"id",
|
|
"name",
|
|
).
|
|
From(companyTableName).
|
|
Where(squirrel.Eq{"id": ownerId})
|
|
|
|
default:
|
|
return nil, fmt.Errorf("%w: invalid owner type", dberrors.ErrBadRequest)
|
|
}
|
|
|
|
query, args, err := getOwnerInfoById.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building get owner info by id query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
row := driver.QueryRowContext(ctx, query, args...)
|
|
|
|
var ownerInfo dbtypes.TransactionOwnerInfo
|
|
|
|
if err := row.Scan(
|
|
&ownerInfo.Id,
|
|
&ownerInfo.Name,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, dberrors.ErrNotFound
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w: error scanning row for get owner info by id query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
return &ownerInfo, nil
|
|
}
|
|
|
|
func (c *client) getBankAccountInfoById(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
bankAccountId string,
|
|
) (*dbtypes.BankAccountInfo, error) {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
bankAccountsTableName = fmt.Sprintf("%s.%s", c.config.Schema, BankAccountsTableName)
|
|
)
|
|
|
|
getBankAccountInfoById := psql.Select(
|
|
"id",
|
|
"account_number",
|
|
"bank_name",
|
|
"bik",
|
|
"correspondent_account",
|
|
"owner_type",
|
|
).
|
|
From(bankAccountsTableName).
|
|
Where(squirrel.Eq{"id": bankAccountId})
|
|
|
|
query, args, err := getBankAccountInfoById.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building get bank account info by id query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
row := driver.QueryRowContext(ctx, query, args...)
|
|
|
|
var accountInfo dbtypes.BankAccountInfo
|
|
|
|
if err := row.Scan(
|
|
&accountInfo.Id,
|
|
&accountInfo.AccountNumber,
|
|
&accountInfo.BankName,
|
|
&accountInfo.Bik,
|
|
&accountInfo.CorrespondentAccount,
|
|
&accountInfo.OwnerType,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, dberrors.ErrNotFound
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w: error scanning row for get bank account info by id query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
return &accountInfo, nil
|
|
}
|
|
|
|
func (c *client) CreateTransaction(
|
|
ctx context.Context,
|
|
request *dbtypes.TransactionCreateRequest,
|
|
) (*dbtypes.TransactionCreateResponse, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
|
}
|
|
|
|
tx, err := c.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
result, err := c.createTransactionWithDriver(ctx, tx, request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating transaction: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("%w: error committing transaction: %w", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (c *client) createTransactionWithDriver(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
request *dbtypes.TransactionCreateRequest,
|
|
) (*dbtypes.TransactionCreateResponse, error) {
|
|
if _, err := c.getRawBalanceForUpdate(ctx, driver, request.OwnerId); err != nil {
|
|
return nil, fmt.Errorf("error getting raw balance for update: %w", err)
|
|
}
|
|
|
|
result, err := c.createTransaction(ctx, driver, request)
|
|
if err != nil {
|
|
if errors.Is(err, dberrors.ErrConflict) {
|
|
return result, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("error creating transaction: %w", err)
|
|
}
|
|
|
|
if err := c.updateBalance(ctx, driver, request.Amount, request.OwnerId); err != nil {
|
|
return nil, fmt.Errorf("error updating balance: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (c *client) getRawBalanceForUpdate(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
ownerId string,
|
|
) (int64, error) {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName)
|
|
)
|
|
|
|
getBalance := psql.Select(
|
|
"raw_balance",
|
|
).From(balancesTable).
|
|
Where(squirrel.Eq{"owner_id": ownerId}).
|
|
Suffix("FOR UPDATE")
|
|
|
|
query, args, err := getBalance.ToSql()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("%w: error building 'get balance' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
row := driver.QueryRowContext(ctx, query, args...)
|
|
|
|
var rawBalance int64
|
|
|
|
if err := row.Scan(&rawBalance); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return 0, dberrors.ErrNotFound
|
|
}
|
|
|
|
return 0, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
return rawBalance, nil
|
|
}
|
|
|
|
func (c *client) updateBalance(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
amountDelta int64,
|
|
ownerId string,
|
|
) error {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName)
|
|
)
|
|
|
|
updateBalance := psql.Update(balancesTable).
|
|
SetMap(map[string]any{
|
|
"raw_balance": squirrel.Expr("raw_balance + ?", amountDelta),
|
|
"updated_at": squirrel.Expr("NOW()"),
|
|
}).
|
|
Where(squirrel.Eq{"owner_id": ownerId})
|
|
|
|
query, args, err := updateBalance.ToSql()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error building 'update balance' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
res, err := driver.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error executing 'update balance' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rowsAffected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error getting rows affected for 'update balance' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return dberrors.ErrInternal
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *client) createTransaction(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
request *dbtypes.TransactionCreateRequest,
|
|
) (*dbtypes.TransactionCreateResponse, error) {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName)
|
|
)
|
|
|
|
var payload []byte
|
|
|
|
if request.Payload == nil {
|
|
payload = []byte("{}")
|
|
} else {
|
|
payloadBytes, err := json.Marshal(request.Payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error marshaling transaction payload: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
payload = payloadBytes
|
|
}
|
|
|
|
createTransaction := psql.Insert(transactionsTable).
|
|
Columns(
|
|
"id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload",
|
|
).
|
|
Values(
|
|
request.RequestId, request.OwnerId, request.BankAccountId, request.Amount, request.Currency,
|
|
dbtypes.TransactionStatusNew, squirrel.Expr("CURRENT_TIMESTAMP"), payload,
|
|
).
|
|
Suffix("ON CONFLICT (id, owner_id) DO NOTHING")
|
|
|
|
query, args, err := createTransaction.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building 'create transaction' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
res, err := driver.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error executing 'create transaction' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rowsAffected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error getting rows affected for 'create transaction' query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return &dbtypes.TransactionCreateResponse{
|
|
Id: request.RequestId,
|
|
}, dberrors.ErrConflict
|
|
}
|
|
|
|
return &dbtypes.TransactionCreateResponse{
|
|
Id: request.RequestId,
|
|
}, nil
|
|
}
|