package pgdb import ( "context" "database/sql" "encoding/json" "errors" "fmt" dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors" dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types" "github.com/Masterminds/squirrel" ) func getTransactionTypeByAmount(amount int64) dbtypes.TransactionType { if amount > 0 { return dbtypes.TransactionTypeDeposit } return dbtypes.TransactionTypeWithdrawal } // TODO: add migration to rebind statuses func TransactionStatusIdToString(status int32) dbtypes.TransactionStatus { switch status { case 0: return dbtypes.TransactionStatusPending case 1: return dbtypes.TransactionStatusApproved case 2: return dbtypes.TransactionStatusRejected case 3: return dbtypes.TransactionStatusNew default: return dbtypes.TransactionStatusPending } } func TransactionStatusStringToId(status dbtypes.TransactionStatus) int32 { switch status { case dbtypes.TransactionStatusNew: return 3 case dbtypes.TransactionStatusPending: return 0 case dbtypes.TransactionStatusApproved: return 1 case dbtypes.TransactionStatusRejected: return 2 default: return 3 } } //nolint:gocognit // not so hard func (c *client) GetTransactionList( ctx context.Context, request *dbtypes.TransactionListGetRequest, ) (*dbtypes.TransactionListGetResponse, error) { if request == nil { return nil, nil } var ( psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName) ) getTransactions := psql.Select( "id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload", ).From(transactionsTable). Where(squirrel.Eq{"owner_id": request.OwnerId}). Limit(request.PageSize). Offset(countOffset(request.Page, request.PageSize)) getTransactions, err := c.setGetTransactionsQueryFilters(getTransactions, request.Filters) if err != nil { return nil, fmt.Errorf("%w: error setting get transactions query filters: %v", dberrors.ErrBadRequest, err) } query, args, err := getTransactions.ToSql() if err != nil { return nil, fmt.Errorf("%w: error building get transactions query: %v", dberrors.ErrInternal, err) } rows, err := c.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("%w: error executing get transactions query: %v", dberrors.ErrInternal, err) } defer rows.Close() transList := &dbtypes.TransactionListGetResponse{ Transactions: make([]dbtypes.Transaction, 0, request.PageSize), } for rows.Next() { var ( transStatus int32 ownerId string payload, bankAccountId sql.NullString transaction dbtypes.Transaction ) if err := rows.Scan( &transaction.Id, &ownerId, &bankAccountId, &transaction.Amount, &transaction.Currency, &transStatus, &transaction.CreatedAt, &payload, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, dberrors.ErrNotFound } return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err) } if bankAccountId.Valid { bankAccountInfo, err := c.getBankAccountInfoById(ctx, c.db, bankAccountId.String) if err != nil { return nil, fmt.Errorf("%w: error getting bank account info: %v", dberrors.ErrInternal, err) } transaction.BankAccountInfo = bankAccountInfo ownerInfo, err := c.getOwnerInfoById(ctx, c.db, ownerId, bankAccountInfo.OwnerType) if err != nil { return nil, fmt.Errorf("%w: error getting owner info: %v", dberrors.ErrInternal, err) } transaction.OwnerInfo = ownerInfo } transaction.Type = getTransactionTypeByAmount(transaction.Amount) transaction.Status = TransactionStatusIdToString(transStatus) if payload.Valid { var payloadData dbtypes.TransactionPayload if err := json.Unmarshal([]byte(payload.String), &payloadData); err != nil { return nil, fmt.Errorf("%w: error unmarshaling transaction payload: %v", dberrors.ErrInternal, err) } transaction.Payload = &payloadData } transList.Transactions = append(transList.Transactions, transaction) } return transList, nil } func (c *client) setGetTransactionsQueryFilters( query squirrel.SelectBuilder, filters *dbtypes.TransactionListFilters, ) (squirrel.SelectBuilder, error) { if filters == nil { return query, nil } if filters.Type != nil { switch *filters.Type { case dbtypes.TransactionTypeDeposit: query = query.Where(squirrel.Gt{ "amount": 0, }) case dbtypes.TransactionTypeWithdrawal: query = query.Where(squirrel.Lt{ "amount": 0, }) default: return query, fmt.Errorf("%w: invalid transaction type: %v", dberrors.ErrBadRequest, *filters.Type) } } if filters.Status != nil { query = query.Where(squirrel.Eq{"status": TransactionStatusStringToId(*filters.Status)}) } if filters.BankAccountId != nil { query = query.Where(squirrel.Eq{"bank_account_id": *filters.BankAccountId}) } return query, nil } func (c *client) getOwnerInfoById( ctx context.Context, driver Driver, ownerId string, ownerType string, ) (*dbtypes.TransactionOwnerInfo, error) { var ( psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) clientTableName = fmt.Sprintf("%s.%s", c.config.Schema, UsersTableName) companyTableName = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName) ) var getOwnerInfoById squirrel.SelectBuilder // TODO: reingeneer the DB switch ownerType { case "agent": getOwnerInfoById = psql.Select( "uid", "name", ). From(clientTableName). Where(squirrel.Eq{"uid": ownerId}) case "company": getOwnerInfoById = psql.Select( "id", "name", ). From(companyTableName). Where(squirrel.Eq{"id": ownerId}) default: return nil, fmt.Errorf("%w: invalid owner type", dberrors.ErrBadRequest) } query, args, err := getOwnerInfoById.ToSql() if err != nil { return nil, fmt.Errorf("%w: error building get owner info by id query: %v", dberrors.ErrInternal, err) } row := driver.QueryRowContext(ctx, query, args...) var ownerInfo dbtypes.TransactionOwnerInfo if err := row.Scan( &ownerInfo.Id, &ownerInfo.Name, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, dberrors.ErrNotFound } return nil, fmt.Errorf("%w: error scanning row for get owner info by id query: %v", dberrors.ErrInternal, err) } return &ownerInfo, nil } func (c *client) getBankAccountInfoById( ctx context.Context, driver Driver, bankAccountId string, ) (*dbtypes.BankAccountInfo, error) { var ( psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) bankAccountsTableName = fmt.Sprintf("%s.%s", c.config.Schema, BankAccountsTableName) ) getBankAccountInfoById := psql.Select( "id", "account_number", "bank_name", "bik", "correspondent_account", "owner_type", ). From(bankAccountsTableName). Where(squirrel.Eq{"id": bankAccountId}) query, args, err := getBankAccountInfoById.ToSql() if err != nil { return nil, fmt.Errorf("%w: error building get bank account info by id query: %v", dberrors.ErrInternal, err) } row := driver.QueryRowContext(ctx, query, args...) var accountInfo dbtypes.BankAccountInfo if err := row.Scan( &accountInfo.Id, &accountInfo.AccountNumber, &accountInfo.BankName, &accountInfo.Bik, &accountInfo.CorrespondentAccount, &accountInfo.OwnerType, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, dberrors.ErrNotFound } return nil, fmt.Errorf("%w: error scanning row for get bank account info by id query: %v", dberrors.ErrInternal, err) } return &accountInfo, nil } func (c *client) CreateTransaction( ctx context.Context, request *dbtypes.TransactionCreateRequest, ) (*dbtypes.TransactionCreateResponse, error) { if request == nil { return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest) } tx, err := c.db.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err) } defer func() { _ = tx.Rollback() }() result, err := c.createTransactionWithDriver(ctx, tx, request) if err != nil { return nil, fmt.Errorf("error creating transaction: %w", err) } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("%w: error committing transaction: %w", dberrors.ErrInternal, err) } return result, nil } func (c *client) createTransactionWithDriver( ctx context.Context, driver Driver, request *dbtypes.TransactionCreateRequest, ) (*dbtypes.TransactionCreateResponse, error) { if _, err := c.getRawBalanceForUpdate(ctx, driver, request.OwnerId); err != nil { return nil, fmt.Errorf("error getting raw balance for update: %w", err) } result, err := c.createTransaction(ctx, driver, request) if err != nil { if errors.Is(err, dberrors.ErrConflict) { return result, nil } return nil, fmt.Errorf("error creating transaction: %w", err) } if err := c.updateBalance(ctx, driver, request.Amount, request.OwnerId); err != nil { return nil, fmt.Errorf("error updating balance: %w", err) } return result, nil } func (c *client) getRawBalanceForUpdate( ctx context.Context, driver Driver, ownerId string, ) (int64, error) { var ( psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName) ) getBalance := psql.Select( "raw_balance", ).From(balancesTable). Where(squirrel.Eq{"owner_id": ownerId}). Suffix("FOR UPDATE") query, args, err := getBalance.ToSql() if err != nil { return 0, fmt.Errorf("%w: error building 'get balance' query: %v", dberrors.ErrInternal, err) } row := driver.QueryRowContext(ctx, query, args...) var rawBalance int64 if err := row.Scan(&rawBalance); err != nil { if errors.Is(err, sql.ErrNoRows) { return 0, dberrors.ErrNotFound } return 0, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err) } return rawBalance, nil } func (c *client) updateBalance( ctx context.Context, driver Driver, amountDelta int64, ownerId string, ) error { var ( psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) balancesTable = fmt.Sprintf("%s.%s", c.config.Schema, BalancesTableName) ) updateBalance := psql.Update(balancesTable). SetMap(map[string]any{ "raw_balance": squirrel.Expr("raw_balance + ?", amountDelta), "updated_at": squirrel.Expr("NOW()"), }). Where(squirrel.Eq{"owner_id": ownerId}) query, args, err := updateBalance.ToSql() if err != nil { return fmt.Errorf("%w: error building 'update balance' query: %v", dberrors.ErrInternal, err) } res, err := driver.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("%w: error executing 'update balance' query: %v", dberrors.ErrInternal, err) } rowsAffected, err := res.RowsAffected() if err != nil { return fmt.Errorf("%w: error getting rows affected for 'update balance' query: %v", dberrors.ErrInternal, err) } if rowsAffected == 0 { return dberrors.ErrInternal } return nil } func (c *client) createTransaction( ctx context.Context, driver Driver, request *dbtypes.TransactionCreateRequest, ) (*dbtypes.TransactionCreateResponse, error) { var ( psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) transactionsTable = fmt.Sprintf("%s.%s", c.config.Schema, TransactionsTableName) ) var payload []byte if request.Payload == nil { payload = []byte("{}") } else { payloadBytes, err := json.Marshal(request.Payload) if err != nil { return nil, fmt.Errorf("%w: error marshaling transaction payload: %v", dberrors.ErrInternal, err) } payload = payloadBytes } createTransaction := psql.Insert(transactionsTable). Columns( "id", "owner_id", "bank_account_id", "amount", "currency", "status", "created_at", "payload", ). Values( request.RequestId, request.OwnerId, request.BankAccountId, request.Amount, request.Currency, dbtypes.TransactionStatusNew, squirrel.Expr("CURRENT_TIMESTAMP"), payload, ). Suffix("ON CONFLICT (id, owner_id) DO NOTHING") query, args, err := createTransaction.ToSql() if err != nil { return nil, fmt.Errorf("%w: error building 'create transaction' query: %v", dberrors.ErrInternal, err) } res, err := driver.ExecContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("%w: error executing 'create transaction' query: %v", dberrors.ErrInternal, err) } rowsAffected, err := res.RowsAffected() if err != nil { return nil, fmt.Errorf("%w: error getting rows affected for 'create transaction' query: %v", dberrors.ErrInternal, err) } if rowsAffected == 0 { return &dbtypes.TransactionCreateResponse{ Id: request.RequestId, }, dberrors.ErrConflict } return &dbtypes.TransactionCreateResponse{ Id: request.RequestId, }, nil }