Files
test_deploy/internal/database/postgres/company.go
Alex Shevchuk d84487d238 1
2025-08-18 17:12:04 +03:00

507 lines
14 KiB
Go

package pgdb
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
"github.com/Masterminds/squirrel"
"github.com/google/uuid"
"github.com/lib/pq"
)
// CompanyModerationStatusIdToString converts an integer moderation status id
// (the encoding produced by CompanyModerationStatusStringToId) back into a
// dbtypes.CompanyModerationStatus:
//
//	0 -> pending, 1 -> approved, 2 -> rejected, 3 -> new.
//
// Any id outside that set is reported as dbtypes.CompanyModerationStatusNew.
func CompanyModerationStatusIdToString(status int32) dbtypes.CompanyModerationStatus {
	switch status {
	case 0:
		return dbtypes.CompanyModerationStatusPending
	case 1:
		return dbtypes.CompanyModerationStatusApproved
	case 2:
		return dbtypes.CompanyModerationStatusRejected
	}
	// 3 is the explicit id of "new"; unknown ids fall back to "new" as well.
	return dbtypes.CompanyModerationStatusNew
}
// CompanyModerationStatusStringToId converts a dbtypes.CompanyModerationStatus
// into the integer id written to the database:
//
//	pending -> 0, approved -> 1, rejected -> 2, new -> 3.
//
// Any unrecognized status is encoded as 3 (new).
func CompanyModerationStatusStringToId(status dbtypes.CompanyModerationStatus) int32 {
	switch status {
	case dbtypes.CompanyModerationStatusPending:
		return 0
	case dbtypes.CompanyModerationStatusApproved:
		return 1
	case dbtypes.CompanyModerationStatusRejected:
		return 2
	}
	// "new" maps to 3, and so does anything unrecognized.
	return 3
}
// GetCompanyList returns every company whose owner uid equals request.Id.
//
// Errors:
//   - a nil request yields a wrapped dberrors.ErrBadRequest;
//   - failures building or executing the query, scanning a row, or iterating
//     the result set yield a wrapped dberrors.ErrInternal.
//
// An owner with no companies is not an error: the response is returned with a
// nil Companies slice.
//
//nolint:gocognit // TODO: refactor
func (c *client) GetCompanyList(
	ctx context.Context,
	request *dbtypes.CompanyListGetRequest,
) (*dbtypes.CompanyListGetResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	var (
		psql           = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
		companiesTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
	)

	getComList := psql.Select(
		"id", "uid", "name", "legal_person", "description", "website",
		"physical_address", "legal_address", "inn", "is_active", // TODO: add KPP when DB supports it
		"has_moderation_ticket", "staff", "metadata", "additional_fields_tmpl",
	).
		From(companiesTable).
		Where(squirrel.Eq{"uid": request.Id})

	query, args, err := getComList.ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, err)
	}

	rows, err := c.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%w: error executing get distributor company list query: %v", dberrors.ErrInternal, err)
	}
	defer rows.Close()

	// optional maps a nullable text column to the *string used by
	// dbtypes.Company: NULL becomes nil, anything else a pointer to a copy.
	optional := func(ns sql.NullString) *string {
		if !ns.Valid {
			return nil
		}
		s := ns.String
		return &s
	}

	var res dbtypes.CompanyListGetResponse
	for rows.Next() {
		var (
			name, legalPerson, description, website, physicalAddress sql.NullString
			legalAddress, inn, metadata, additionalFieldsTmpl         sql.NullString
			staff                                                     pq.StringArray
			company                                                   dbtypes.Company
		)
		if err := rows.Scan(
			&company.Id, &company.OwnerId, &name, &legalPerson, &description, &website,
			&physicalAddress, &legalAddress, &inn, &company.IsActive,
			&company.HasModerationTicket, &staff, &metadata, &additionalFieldsTmpl,
		); err != nil {
			// (*sql.Rows).Scan never reports sql.ErrNoRows (only (*sql.Row).Scan
			// does), so every failure here is an internal one.
			return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
		}

		company.Name = optional(name)
		company.LegalPerson = optional(legalPerson)
		company.Description = optional(description)
		company.Website = optional(website)
		company.PhysicalAddress = optional(physicalAddress)
		company.LegalAddress = optional(legalAddress)
		company.Inn = optional(inn)
		company.Metadata = optional(metadata)
		company.ExtraFieldsTemplate = optional(additionalFieldsTmpl)
		company.Staff = staff

		res.Companies = append(res.Companies, company)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// fetching the next row failed; only rows.Err tells the two apart. Without
	// this check a mid-iteration failure is returned as a truncated success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("%w: error iterating company list rows: %v", dberrors.ErrInternal, err)
	}

	return &res, nil
}
// GetCompanyById returns the single company whose id equals request.CompanyId
// and whose owner uid equals request.Id.
//
// A nil request yields a wrapped dberrors.ErrBadRequest. When no row matches
// both conditions, dberrors.ErrNotFound is returned. Query-building and
// scanning failures yield a wrapped dberrors.ErrInternal.
func (c *client) GetCompanyById(
	ctx context.Context,
	request *dbtypes.CompanyByIdGetRequest,
) (*dbtypes.CompanyByIdGetResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	table := fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
	query, args, buildErr := squirrel.StatementBuilder.
		PlaceholderFormat(squirrel.Dollar).
		Select(
			"id", "uid", "name", "legal_person", "description", "website",
			"physical_address", "legal_address", "inn", "is_active", // TODO: add KPP when DB supports it
			"has_moderation_ticket", "staff", "metadata", "additional_fields_tmpl",
		).
		From(table).
		Where(squirrel.And{
			squirrel.Eq{"id": request.CompanyId},
			squirrel.Eq{"uid": request.Id},
		}).
		ToSql()
	if buildErr != nil {
		return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, buildErr)
	}

	var (
		company                                 dbtypes.Company
		staff                                   pq.StringArray
		name, legalPerson, description, website sql.NullString
		physicalAddress, legalAddress, inn      sql.NullString
		metadata, additionalFieldsTmpl          sql.NullString
	)
	scanErr := c.db.QueryRowContext(ctx, query, args...).Scan(
		&company.Id, &company.OwnerId, &name, &legalPerson, &description, &website,
		&physicalAddress, &legalAddress, &inn, &company.IsActive,
		&company.HasModerationTicket, &staff, &metadata, &additionalFieldsTmpl,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, dberrors.ErrNotFound
	case scanErr != nil:
		return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, scanErr)
	}

	// Copy every non-NULL text column into its optional field; NULL columns
	// leave the field at its zero value (nil).
	for _, col := range []struct {
		src sql.NullString
		dst **string
	}{
		{name, &company.Name},
		{legalPerson, &company.LegalPerson},
		{description, &company.Description},
		{website, &company.Website},
		{physicalAddress, &company.PhysicalAddress},
		{legalAddress, &company.LegalAddress},
		{inn, &company.Inn},
		{metadata, &company.Metadata},
		{additionalFieldsTmpl, &company.ExtraFieldsTemplate},
	} {
		if col.src.Valid {
			text := col.src.String
			*col.dst = &text
		}
	}
	company.Staff = staff

	return &dbtypes.CompanyByIdGetResponse{Company: company}, nil
}
// CreateCompany inserts a company and its moderation ticket inside a single
// transaction, so either both rows are written or neither is.
//
// A nil request yields a wrapped dberrors.ErrBadRequest. Failures to start or
// commit the transaction are wrapped with dberrors.ErrInternal. Errors from the
// insert helpers get extra context and keep the dberrors sentinel the helpers
// attached (dberrors.ErrInternal).
func (c *client) CreateCompany(
	ctx context.Context,
	request *dbtypes.CompanyCreateRequest,
) (*dbtypes.CompanyCreateResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
	}
	// On any early return this undoes the partial work. After a successful
	// Commit it is expected to be a no-op (sql.ErrTxDone for *sql.Tx), which is
	// why its error is deliberately discarded.
	defer func() { _ = tx.Rollback() }()

	res, err := c.createCompany(ctx, tx, request)
	if err != nil {
		return nil, fmt.Errorf("error creating company: %w", err)
	}

	if err := c.createCompanyValidationTicket(ctx, tx, res.Id, request); err != nil {
		return nil, fmt.Errorf("error creating company validation ticket: %w", err)
	}

	if err := tx.Commit(); err != nil {
		// Classify commit failures as internal, like every other database
		// failure here, so callers using errors.Is(err, dberrors.ErrInternal)
		// recognize them.
		return nil, fmt.Errorf("%w: error committing transaction: %v", dberrors.ErrInternal, err)
	}

	return res, nil
}
// createCompany inserts a new company row owned by request.OwnerId through
// driver (CreateCompany passes its transaction here) and returns the generated
// company id.
//
// The row starts with is_active=false and has_moderation_ticket=true, and
// stores request.Metadata and request.ExtraFieldsTemplate. Failures building
// or executing the insert, reading RowsAffected, or an insert reporting zero
// affected rows are all returned as dberrors.ErrInternal.
func (c *client) createCompany(
	ctx context.Context,
	driver Driver,
	request *dbtypes.CompanyCreateRequest,
) (*dbtypes.CompanyCreateResponse, error) {
	// TODO: use normal uuid after DB reengineering
	// Company ids are a dash-less UUID followed by a "COM" suffix.
	newID := fmt.Sprintf("%sCOM", strings.ReplaceAll(uuid.NewString(), "-", ""))

	insert := squirrel.StatementBuilder.
		PlaceholderFormat(squirrel.Dollar).
		Insert(fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)).
		Columns("id", "uid", "is_active", "has_moderation_ticket", "metadata", "additional_fields_tmpl").
		Values(newID, request.OwnerId, false, true, request.Metadata, request.ExtraFieldsTemplate)

	query, args, err := insert.ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building create company query: %v", dberrors.ErrInternal, err)
	}

	result, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%w: error executing create company query: %v", dberrors.ErrInternal, err)
	}

	switch affected, countErr := result.RowsAffected(); {
	case countErr != nil:
		return nil, fmt.Errorf("%w: error getting rows affected for create company query: %v", dberrors.ErrInternal, countErr)
	case affected == 0:
		return nil, dberrors.ErrInternal
	}

	return &dbtypes.CompanyCreateResponse{Id: newID}, nil
}
// createCompanyValidationTicket inserts a moderation ticket for companyId
// through driver. The ticket carries the descriptive fields from request
// (name, legal person, description, website, addresses, INN, staff) and
// starts in the pending status.
//
// Failures building or executing the insert, reading RowsAffected, or an
// insert reporting zero affected rows are returned as dberrors.ErrInternal.
func (c *client) createCompanyValidationTicket(
	ctx context.Context,
	driver Driver,
	companyId string,
	request *dbtypes.CompanyCreateRequest,
) error {
	// Ticket ids are a dash-less UUID followed by a "TCK" suffix.
	ticketID := fmt.Sprintf("%sTCK", strings.ReplaceAll(uuid.NewString(), "-", ""))
	ticketsTable := fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)

	insertTicket := squirrel.StatementBuilder.
		PlaceholderFormat(squirrel.Dollar).
		Insert(ticketsTable).
		Columns(
			"id", "company_id", "name", "legal_person", "description", "website",
			"physical_address", "legal_address", "inn", // TODO: add KPP when DB supports it
			"staff", "status",
		).
		Values(
			ticketID, companyId, request.Name, request.LegalPerson, request.Description, request.Website,
			request.PhysicalAddress, request.LegalAddress, request.Inn,
			request.Staff, CompanyModerationStatusStringToId(dbtypes.CompanyModerationStatusPending), // TODO: switch to status "NEW"
		)

	query, args, err := insertTicket.ToSql()
	if err != nil {
		return fmt.Errorf("%w: error building create company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	result, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("%w: error executing create company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	switch affected, countErr := result.RowsAffected(); {
	case countErr != nil:
		return fmt.Errorf("%w: error getting rows affected for create company moderation ticket query: %v", dberrors.ErrInternal, countErr)
	case affected == 0:
		return dberrors.ErrInternal
	default:
		return nil
	}
}
// UpdateCompany updates the company identified by request.Id and the
// moderation ticket attached to it inside a single transaction, so either both
// updates are applied or neither is.
//
// A nil request yields a wrapped dberrors.ErrBadRequest. Failures to start or
// commit the transaction are wrapped with dberrors.ErrInternal. Errors from the
// update helpers get extra context and keep the dberrors sentinel the helpers
// attached (dberrors.ErrInternal).
func (c *client) UpdateCompany(
	ctx context.Context,
	request *dbtypes.CompanyUpdateRequest,
) (*dbtypes.CompanyUpdateResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
	}
	// On any early return this undoes the partial work. After a successful
	// Commit it is expected to be a no-op (sql.ErrTxDone for *sql.Tx), which is
	// why its error is deliberately discarded.
	defer func() { _ = tx.Rollback() }()

	res, err := c.updateCompany(ctx, tx, request)
	if err != nil {
		return nil, fmt.Errorf("error updating company: %w", err)
	}

	if err := c.updateCompanyValidationTicket(ctx, tx, request.Id, request); err != nil {
		return nil, fmt.Errorf("error updating company validation ticket: %w", err)
	}

	if err := tx.Commit(); err != nil {
		// Classify commit failures as internal, like every other database
		// failure here, so callers using errors.Is(err, dberrors.ErrInternal)
		// recognize them.
		return nil, fmt.Errorf("%w: error committing transaction: %v", dberrors.ErrInternal, err)
	}

	return res, nil
}
// updateCompany rewrites the company row whose id equals request.Id through
// driver. It sets is_active=false and has_moderation_ticket=true, and writes
// request.Metadata and request.ExtraFields into metadata and
// additional_fields_tmpl.
//
// Failures building or executing the update, reading RowsAffected, or an
// update that matched no row are all returned as dberrors.ErrInternal.
//
// NOTE(review): metadata and additional_fields_tmpl are set unconditionally,
// so nil request values presumably overwrite stored data with NULL. Confirm
// this is intended.
func (c *client) updateCompany(
	ctx context.Context,
	driver Driver,
	request *dbtypes.CompanyUpdateRequest,
) (*dbtypes.CompanyUpdateResponse, error) {
	table := fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
	changes := map[string]any{
		"is_active":              false,
		"has_moderation_ticket":  true,
		"metadata":               request.Metadata,
		"additional_fields_tmpl": request.ExtraFields,
	}

	query, args, err := squirrel.StatementBuilder.
		PlaceholderFormat(squirrel.Dollar).
		Update(table).
		SetMap(changes).
		Where(squirrel.Eq{"id": request.Id}).
		ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building update company query: %v", dberrors.ErrInternal, err)
	}

	result, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%w: error executing update company query: %v", dberrors.ErrInternal, err)
	}

	switch affected, countErr := result.RowsAffected(); {
	case countErr != nil:
		return nil, fmt.Errorf("%w: error getting rows affected for update company query: %v", dberrors.ErrInternal, countErr)
	case affected == 0:
		return nil, dberrors.ErrInternal
	}

	return &dbtypes.CompanyUpdateResponse{}, nil
}
// NOTE: do we believe that every company has a moderation ticket?
//
// updateCompanyValidationTicket applies a partial update, through driver, to
// the moderation ticket(s) whose company_id equals companyId. Only the request
// fields that are set are written: non-nil pointers, and staff when it is
// non-empty.
//
// Failures building or executing the update, reading RowsAffected, or an
// update that matched no row are all returned as dberrors.ErrInternal.
//
// NOTE(review): when none of the ticket fields is set, the builder has no SET
// clause. squirrel's UpdateBuilder.ToSql is expected to reject that ("at least
// one Set clause"), so a metadata-only UpdateCompany appears to fail with
// ErrInternal. Confirm whether that should be a no-op or a bad request.
func (c *client) updateCompanyValidationTicket(
	ctx context.Context,
	driver Driver,
	companyId string,
	request *dbtypes.CompanyUpdateRequest,
) error {
	ticketsTable := fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
	stmt := squirrel.StatementBuilder.
		PlaceholderFormat(squirrel.Dollar).
		Update(ticketsTable).
		Where(squirrel.Eq{"company_id": companyId})

	// SET clauses are added in a fixed column order, only for provided fields.
	if v := request.Name; v != nil {
		stmt = stmt.Set("name", *v)
	}
	if v := request.LegalPerson; v != nil {
		stmt = stmt.Set("legal_person", *v)
	}
	if v := request.Description; v != nil {
		stmt = stmt.Set("description", *v)
	}
	if v := request.Website; v != nil {
		stmt = stmt.Set("website", *v)
	}
	if v := request.PhysicalAddress; v != nil {
		stmt = stmt.Set("physical_address", *v)
	}
	if v := request.LegalAddress; v != nil {
		stmt = stmt.Set("legal_address", *v)
	}
	if v := request.Inn; v != nil {
		stmt = stmt.Set("inn", *v)
	}
	if len(request.Staff) != 0 {
		stmt = stmt.Set("staff", request.Staff)
	}

	query, args, err := stmt.ToSql()
	if err != nil {
		return fmt.Errorf("%w: error building update company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	result, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("%w: error executing update company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	switch affected, countErr := result.RowsAffected(); {
	case countErr != nil:
		return fmt.Errorf("%w: error getting rows affected for update company moderation ticket query: %v", dberrors.ErrInternal, countErr)
	case affected == 0:
		return dberrors.ErrInternal
	default:
		return nil
	}
}