507 lines
14 KiB
Go
507 lines
14 KiB
Go
package pgdb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
|
|
dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
|
|
"github.com/Masterminds/squirrel"
|
|
"github.com/google/uuid"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
func CompanyModerationStatusIdToString(status int32) dbtypes.CompanyModerationStatus {
|
|
switch status {
|
|
case 0:
|
|
return dbtypes.CompanyModerationStatusPending
|
|
case 1:
|
|
return dbtypes.CompanyModerationStatusApproved
|
|
case 2:
|
|
return dbtypes.CompanyModerationStatusRejected
|
|
case 3:
|
|
return dbtypes.CompanyModerationStatusNew
|
|
default:
|
|
return dbtypes.CompanyModerationStatusNew
|
|
}
|
|
}
|
|
|
|
func CompanyModerationStatusStringToId(status dbtypes.CompanyModerationStatus) int32 {
|
|
switch status {
|
|
case dbtypes.CompanyModerationStatusPending:
|
|
return 0
|
|
case dbtypes.CompanyModerationStatusApproved:
|
|
return 1
|
|
case dbtypes.CompanyModerationStatusRejected:
|
|
return 2
|
|
case dbtypes.CompanyModerationStatusNew:
|
|
return 3
|
|
default:
|
|
return 3
|
|
}
|
|
}
|
|
|
|
//nolint:gocognit // TODO: refactor
|
|
func (c *client) GetCompanyList(
|
|
ctx context.Context,
|
|
request *dbtypes.CompanyListGetRequest,
|
|
) (*dbtypes.CompanyListGetResponse, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
|
}
|
|
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
companiesTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
|
)
|
|
|
|
getComList := psql.Select(
|
|
"id", "uid", "name", "legal_person", "description", "website",
|
|
"physical_address", "legal_address", "inn", "is_active", // TODO: add KPP when DB supports it
|
|
"has_moderation_ticket", "staff", "metadata", "additional_fields_tmpl",
|
|
).
|
|
From(companiesTable).
|
|
Where(squirrel.Eq{"uid": request.Id})
|
|
|
|
query, args, err := getComList.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error executing get distributor company list query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
var res dbtypes.CompanyListGetResponse
|
|
|
|
for rows.Next() {
|
|
var (
|
|
name, legalPerson, description, website, physicalAddress sql.NullString
|
|
legalAddress, inn, metadata, additionalFieldsTmpl sql.NullString
|
|
staff pq.StringArray
|
|
company dbtypes.Company
|
|
)
|
|
|
|
if err := rows.Scan(
|
|
&company.Id, &company.OwnerId, &name, &legalPerson, &description, &website,
|
|
&physicalAddress, &legalAddress, &inn, &company.IsActive,
|
|
&company.HasModerationTicket, &staff, &metadata, &additionalFieldsTmpl,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, dberrors.ErrNotFound
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if name.Valid {
|
|
company.Name = &name.String
|
|
}
|
|
|
|
if legalPerson.Valid {
|
|
company.LegalPerson = &legalPerson.String
|
|
}
|
|
|
|
if description.Valid {
|
|
company.Description = &description.String
|
|
}
|
|
|
|
if website.Valid {
|
|
company.Website = &website.String
|
|
}
|
|
|
|
if physicalAddress.Valid {
|
|
company.PhysicalAddress = &physicalAddress.String
|
|
}
|
|
|
|
if legalAddress.Valid {
|
|
company.LegalAddress = &legalAddress.String
|
|
}
|
|
|
|
if inn.Valid {
|
|
company.Inn = &inn.String
|
|
}
|
|
|
|
company.Staff = staff
|
|
|
|
if metadata.Valid {
|
|
company.Metadata = &metadata.String
|
|
}
|
|
|
|
if additionalFieldsTmpl.Valid {
|
|
company.ExtraFieldsTemplate = &additionalFieldsTmpl.String
|
|
}
|
|
|
|
res.Companies = append(res.Companies, company)
|
|
}
|
|
|
|
return &res, nil
|
|
}
|
|
|
|
func (c *client) GetCompanyById(
|
|
ctx context.Context,
|
|
request *dbtypes.CompanyByIdGetRequest,
|
|
) (*dbtypes.CompanyByIdGetResponse, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
|
}
|
|
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
companiesTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
|
)
|
|
|
|
getComList := psql.Select(
|
|
"id", "uid", "name", "legal_person", "description", "website",
|
|
"physical_address", "legal_address", "inn", "is_active", // TODO: add KPP when DB supports it
|
|
"has_moderation_ticket", "staff", "metadata", "additional_fields_tmpl",
|
|
).
|
|
From(companiesTable).
|
|
Where(squirrel.And{
|
|
squirrel.Eq{"id": request.CompanyId},
|
|
squirrel.Eq{"uid": request.Id},
|
|
})
|
|
|
|
query, args, err := getComList.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
row := c.db.QueryRowContext(ctx, query, args...)
|
|
|
|
var (
|
|
name, legalPerson, description, website, physicalAddress sql.NullString
|
|
legalAddress, inn, metadata, additionalFieldsTmpl sql.NullString
|
|
staff pq.StringArray
|
|
company dbtypes.Company
|
|
)
|
|
|
|
if err := row.Scan(
|
|
&company.Id, &company.OwnerId, &name, &legalPerson, &description, &website,
|
|
&physicalAddress, &legalAddress, &inn, &company.IsActive,
|
|
&company.HasModerationTicket, &staff, &metadata, &additionalFieldsTmpl,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, dberrors.ErrNotFound
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if name.Valid {
|
|
company.Name = &name.String
|
|
}
|
|
|
|
if legalPerson.Valid {
|
|
company.LegalPerson = &legalPerson.String
|
|
}
|
|
|
|
if description.Valid {
|
|
company.Description = &description.String
|
|
}
|
|
|
|
if website.Valid {
|
|
company.Website = &website.String
|
|
}
|
|
|
|
if physicalAddress.Valid {
|
|
company.PhysicalAddress = &physicalAddress.String
|
|
}
|
|
|
|
if legalAddress.Valid {
|
|
company.LegalAddress = &legalAddress.String
|
|
}
|
|
|
|
if inn.Valid {
|
|
company.Inn = &inn.String
|
|
}
|
|
|
|
company.Staff = staff
|
|
|
|
if metadata.Valid {
|
|
company.Metadata = &metadata.String
|
|
}
|
|
|
|
if additionalFieldsTmpl.Valid {
|
|
company.ExtraFieldsTemplate = &additionalFieldsTmpl.String
|
|
}
|
|
|
|
return &dbtypes.CompanyByIdGetResponse{Company: company}, nil
|
|
}
|
|
|
|
func (c *client) CreateCompany(
|
|
ctx context.Context,
|
|
request *dbtypes.CompanyCreateRequest,
|
|
) (*dbtypes.CompanyCreateResponse, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
|
}
|
|
|
|
tx, err := c.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
res, err := c.createCompany(ctx, tx, request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating company: %w", err)
|
|
}
|
|
|
|
if err := c.createCompanyValidationTicket(ctx, tx, res.Id, request); err != nil {
|
|
return nil, fmt.Errorf("error creating company validation ticket: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("error committing transaction: %w", err)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (c *client) createCompany(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
request *dbtypes.CompanyCreateRequest,
|
|
) (*dbtypes.CompanyCreateResponse, error) {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
companyTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
|
)
|
|
|
|
// TODO: use normal uuid after DB reengineering
|
|
comId := fmt.Sprintf("%sCOM", strings.ReplaceAll(uuid.NewString(), "-", ""))
|
|
|
|
createCompany := psql.Insert(companyTable).
|
|
Columns(
|
|
"id", "uid", "is_active", "has_moderation_ticket", "metadata", "additional_fields_tmpl",
|
|
).
|
|
Values(
|
|
comId, request.OwnerId, false, true, request.Metadata, request.ExtraFieldsTemplate,
|
|
)
|
|
|
|
query, args, err := createCompany.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building create company query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
res, err := driver.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error executing create company query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rowsAffected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error getting rows affected for create company query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return nil, dberrors.ErrInternal
|
|
}
|
|
|
|
return &dbtypes.CompanyCreateResponse{Id: comId}, nil
|
|
}
|
|
|
|
func (c *client) createCompanyValidationTicket(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
companyId string,
|
|
request *dbtypes.CompanyCreateRequest,
|
|
) error {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
comValTable = fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
|
|
)
|
|
|
|
var (
|
|
ticketId = fmt.Sprintf("%sTCK", strings.ReplaceAll(uuid.NewString(), "-", ""))
|
|
)
|
|
|
|
createCompany := psql.Insert(comValTable).
|
|
Columns(
|
|
"id", "company_id", "name", "legal_person", "description", "website",
|
|
"physical_address", "legal_address", "inn", // TODO: add KPP when DB supports it
|
|
"staff", "status",
|
|
).
|
|
Values(
|
|
ticketId, companyId, request.Name, request.LegalPerson, request.Description, request.Website,
|
|
request.PhysicalAddress, request.LegalAddress, request.Inn,
|
|
request.Staff, CompanyModerationStatusStringToId(dbtypes.CompanyModerationStatusPending), // TODO: switch to status "NEW"
|
|
)
|
|
|
|
query, args, err := createCompany.ToSql()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error building create company moderation ticket query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
res, err := driver.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error executing create company moderation ticket query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rowsAffected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error getting rows affected for create company moderation ticket query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return dberrors.ErrInternal
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *client) UpdateCompany(
|
|
ctx context.Context,
|
|
request *dbtypes.CompanyUpdateRequest,
|
|
) (*dbtypes.CompanyUpdateResponse, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
|
}
|
|
|
|
tx, err := c.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
res, err := c.updateCompany(ctx, tx, request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error updating company: %w", err)
|
|
}
|
|
|
|
if err := c.updateCompanyValidationTicket(ctx, tx, request.Id, request); err != nil {
|
|
return nil, fmt.Errorf("error updating company validation ticket: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("error committing transaction: %w", err)
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (c *client) updateCompany(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
request *dbtypes.CompanyUpdateRequest,
|
|
) (*dbtypes.CompanyUpdateResponse, error) {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
companyTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
|
)
|
|
|
|
updateCompany := psql.Update(companyTable).
|
|
SetMap(map[string]any{
|
|
"is_active": false,
|
|
"has_moderation_ticket": true,
|
|
"metadata": request.Metadata,
|
|
"additional_fields_tmpl": request.ExtraFields,
|
|
}).
|
|
Where(squirrel.Eq{"id": request.Id})
|
|
|
|
query, args, err := updateCompany.ToSql()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error building update company query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
res, err := driver.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error executing update company query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rowsAffected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: error getting rows affected for update company query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return nil, dberrors.ErrInternal
|
|
}
|
|
|
|
return &dbtypes.CompanyUpdateResponse{}, nil
|
|
}
|
|
|
|
// NOTE: do we believe that every company has a moderation ticket?
|
|
func (c *client) updateCompanyValidationTicket(
|
|
ctx context.Context,
|
|
driver Driver,
|
|
companyId string,
|
|
request *dbtypes.CompanyUpdateRequest,
|
|
) error {
|
|
var (
|
|
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
|
|
|
comValTable = fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
|
|
)
|
|
|
|
updateCompany := psql.Update(comValTable).
|
|
Where(squirrel.Eq{"company_id": companyId})
|
|
|
|
if request.Name != nil {
|
|
updateCompany = updateCompany.Set("name", *request.Name)
|
|
}
|
|
|
|
if request.LegalPerson != nil {
|
|
updateCompany = updateCompany.Set("legal_person", *request.LegalPerson)
|
|
}
|
|
|
|
if request.Description != nil {
|
|
updateCompany = updateCompany.Set("description", *request.Description)
|
|
}
|
|
|
|
if request.Website != nil {
|
|
updateCompany = updateCompany.Set("website", *request.Website)
|
|
}
|
|
|
|
if request.PhysicalAddress != nil {
|
|
updateCompany = updateCompany.Set("physical_address", *request.PhysicalAddress)
|
|
}
|
|
|
|
if request.LegalAddress != nil {
|
|
updateCompany = updateCompany.Set("legal_address", *request.LegalAddress)
|
|
}
|
|
|
|
if request.Inn != nil {
|
|
updateCompany = updateCompany.Set("inn", *request.Inn)
|
|
}
|
|
|
|
if len(request.Staff) > 0 {
|
|
updateCompany = updateCompany.Set("staff", request.Staff)
|
|
}
|
|
|
|
query, args, err := updateCompany.ToSql()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error building update company moderation ticket query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
res, err := driver.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error executing update company moderation ticket query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
rowsAffected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("%w: error getting rows affected for update company moderation ticket query: %v", dberrors.ErrInternal, err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return dberrors.ErrInternal
|
|
}
|
|
|
|
return nil
|
|
}
|