1
This commit is contained in:
506
internal/database/postgres/company.go
Normal file
506
internal/database/postgres/company.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package pgdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
|
||||
dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func CompanyModerationStatusIdToString(status int32) dbtypes.CompanyModerationStatus {
|
||||
switch status {
|
||||
case 0:
|
||||
return dbtypes.CompanyModerationStatusPending
|
||||
case 1:
|
||||
return dbtypes.CompanyModerationStatusApproved
|
||||
case 2:
|
||||
return dbtypes.CompanyModerationStatusRejected
|
||||
case 3:
|
||||
return dbtypes.CompanyModerationStatusNew
|
||||
default:
|
||||
return dbtypes.CompanyModerationStatusNew
|
||||
}
|
||||
}
|
||||
|
||||
func CompanyModerationStatusStringToId(status dbtypes.CompanyModerationStatus) int32 {
|
||||
switch status {
|
||||
case dbtypes.CompanyModerationStatusPending:
|
||||
return 0
|
||||
case dbtypes.CompanyModerationStatusApproved:
|
||||
return 1
|
||||
case dbtypes.CompanyModerationStatusRejected:
|
||||
return 2
|
||||
case dbtypes.CompanyModerationStatusNew:
|
||||
return 3
|
||||
default:
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocognit // TODO: refactor
|
||||
func (c *client) GetCompanyList(
|
||||
ctx context.Context,
|
||||
request *dbtypes.CompanyListGetRequest,
|
||||
) (*dbtypes.CompanyListGetResponse, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
||||
}
|
||||
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
companiesTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
||||
)
|
||||
|
||||
getComList := psql.Select(
|
||||
"id", "uid", "name", "legal_person", "description", "website",
|
||||
"physical_address", "legal_address", "inn", "is_active", // TODO: add KPP when DB supports it
|
||||
"has_moderation_ticket", "staff", "metadata", "additional_fields_tmpl",
|
||||
).
|
||||
From(companiesTable).
|
||||
Where(squirrel.Eq{"uid": request.Id})
|
||||
|
||||
query, args, err := getComList.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error executing get distributor company list query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var res dbtypes.CompanyListGetResponse
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
name, legalPerson, description, website, physicalAddress sql.NullString
|
||||
legalAddress, inn, metadata, additionalFieldsTmpl sql.NullString
|
||||
staff pq.StringArray
|
||||
company dbtypes.Company
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&company.Id, &company.OwnerId, &name, &legalPerson, &description, &website,
|
||||
&physicalAddress, &legalAddress, &inn, &company.IsActive,
|
||||
&company.HasModerationTicket, &staff, &metadata, &additionalFieldsTmpl,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, dberrors.ErrNotFound
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if name.Valid {
|
||||
company.Name = &name.String
|
||||
}
|
||||
|
||||
if legalPerson.Valid {
|
||||
company.LegalPerson = &legalPerson.String
|
||||
}
|
||||
|
||||
if description.Valid {
|
||||
company.Description = &description.String
|
||||
}
|
||||
|
||||
if website.Valid {
|
||||
company.Website = &website.String
|
||||
}
|
||||
|
||||
if physicalAddress.Valid {
|
||||
company.PhysicalAddress = &physicalAddress.String
|
||||
}
|
||||
|
||||
if legalAddress.Valid {
|
||||
company.LegalAddress = &legalAddress.String
|
||||
}
|
||||
|
||||
if inn.Valid {
|
||||
company.Inn = &inn.String
|
||||
}
|
||||
|
||||
company.Staff = staff
|
||||
|
||||
if metadata.Valid {
|
||||
company.Metadata = &metadata.String
|
||||
}
|
||||
|
||||
if additionalFieldsTmpl.Valid {
|
||||
company.ExtraFieldsTemplate = &additionalFieldsTmpl.String
|
||||
}
|
||||
|
||||
res.Companies = append(res.Companies, company)
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (c *client) GetCompanyById(
|
||||
ctx context.Context,
|
||||
request *dbtypes.CompanyByIdGetRequest,
|
||||
) (*dbtypes.CompanyByIdGetResponse, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
||||
}
|
||||
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
companiesTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
||||
)
|
||||
|
||||
getComList := psql.Select(
|
||||
"id", "uid", "name", "legal_person", "description", "website",
|
||||
"physical_address", "legal_address", "inn", "is_active", // TODO: add KPP when DB supports it
|
||||
"has_moderation_ticket", "staff", "metadata", "additional_fields_tmpl",
|
||||
).
|
||||
From(companiesTable).
|
||||
Where(squirrel.And{
|
||||
squirrel.Eq{"id": request.CompanyId},
|
||||
squirrel.Eq{"uid": request.Id},
|
||||
})
|
||||
|
||||
query, args, err := getComList.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
row := c.db.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var (
|
||||
name, legalPerson, description, website, physicalAddress sql.NullString
|
||||
legalAddress, inn, metadata, additionalFieldsTmpl sql.NullString
|
||||
staff pq.StringArray
|
||||
company dbtypes.Company
|
||||
)
|
||||
|
||||
if err := row.Scan(
|
||||
&company.Id, &company.OwnerId, &name, &legalPerson, &description, &website,
|
||||
&physicalAddress, &legalAddress, &inn, &company.IsActive,
|
||||
&company.HasModerationTicket, &staff, &metadata, &additionalFieldsTmpl,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, dberrors.ErrNotFound
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if name.Valid {
|
||||
company.Name = &name.String
|
||||
}
|
||||
|
||||
if legalPerson.Valid {
|
||||
company.LegalPerson = &legalPerson.String
|
||||
}
|
||||
|
||||
if description.Valid {
|
||||
company.Description = &description.String
|
||||
}
|
||||
|
||||
if website.Valid {
|
||||
company.Website = &website.String
|
||||
}
|
||||
|
||||
if physicalAddress.Valid {
|
||||
company.PhysicalAddress = &physicalAddress.String
|
||||
}
|
||||
|
||||
if legalAddress.Valid {
|
||||
company.LegalAddress = &legalAddress.String
|
||||
}
|
||||
|
||||
if inn.Valid {
|
||||
company.Inn = &inn.String
|
||||
}
|
||||
|
||||
company.Staff = staff
|
||||
|
||||
if metadata.Valid {
|
||||
company.Metadata = &metadata.String
|
||||
}
|
||||
|
||||
if additionalFieldsTmpl.Valid {
|
||||
company.ExtraFieldsTemplate = &additionalFieldsTmpl.String
|
||||
}
|
||||
|
||||
return &dbtypes.CompanyByIdGetResponse{Company: company}, nil
|
||||
}
|
||||
|
||||
func (c *client) CreateCompany(
|
||||
ctx context.Context,
|
||||
request *dbtypes.CompanyCreateRequest,
|
||||
) (*dbtypes.CompanyCreateResponse, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
res, err := c.createCompany(ctx, tx, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating company: %w", err)
|
||||
}
|
||||
|
||||
if err := c.createCompanyValidationTicket(ctx, tx, res.Id, request); err != nil {
|
||||
return nil, fmt.Errorf("error creating company validation ticket: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("error committing transaction: %w", err)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) createCompany(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
request *dbtypes.CompanyCreateRequest,
|
||||
) (*dbtypes.CompanyCreateResponse, error) {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
companyTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
||||
)
|
||||
|
||||
// TODO: use normal uuid after DB reengineering
|
||||
comId := fmt.Sprintf("%sCOM", strings.ReplaceAll(uuid.NewString(), "-", ""))
|
||||
|
||||
createCompany := psql.Insert(companyTable).
|
||||
Columns(
|
||||
"id", "uid", "is_active", "has_moderation_ticket", "metadata", "additional_fields_tmpl",
|
||||
).
|
||||
Values(
|
||||
comId, request.OwnerId, false, true, request.Metadata, request.ExtraFieldsTemplate,
|
||||
)
|
||||
|
||||
query, args, err := createCompany.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building create company query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
res, err := driver.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error executing create company query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error getting rows affected for create company query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return nil, dberrors.ErrInternal
|
||||
}
|
||||
|
||||
return &dbtypes.CompanyCreateResponse{Id: comId}, nil
|
||||
}
|
||||
|
||||
func (c *client) createCompanyValidationTicket(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
companyId string,
|
||||
request *dbtypes.CompanyCreateRequest,
|
||||
) error {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
comValTable = fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
|
||||
)
|
||||
|
||||
var (
|
||||
ticketId = fmt.Sprintf("%sTCK", strings.ReplaceAll(uuid.NewString(), "-", ""))
|
||||
)
|
||||
|
||||
createCompany := psql.Insert(comValTable).
|
||||
Columns(
|
||||
"id", "company_id", "name", "legal_person", "description", "website",
|
||||
"physical_address", "legal_address", "inn", // TODO: add KPP when DB supports it
|
||||
"staff", "status",
|
||||
).
|
||||
Values(
|
||||
ticketId, companyId, request.Name, request.LegalPerson, request.Description, request.Website,
|
||||
request.PhysicalAddress, request.LegalAddress, request.Inn,
|
||||
request.Staff, CompanyModerationStatusStringToId(dbtypes.CompanyModerationStatusPending), // TODO: switch to status "NEW"
|
||||
)
|
||||
|
||||
query, args, err := createCompany.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error building create company moderation ticket query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
res, err := driver.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error executing create company moderation ticket query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error getting rows affected for create company moderation ticket query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return dberrors.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) UpdateCompany(
|
||||
ctx context.Context,
|
||||
request *dbtypes.CompanyUpdateRequest,
|
||||
) (*dbtypes.CompanyUpdateResponse, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
res, err := c.updateCompany(ctx, tx, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating company: %w", err)
|
||||
}
|
||||
|
||||
if err := c.updateCompanyValidationTicket(ctx, tx, request.Id, request); err != nil {
|
||||
return nil, fmt.Errorf("error updating company validation ticket: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("error committing transaction: %w", err)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) updateCompany(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
request *dbtypes.CompanyUpdateRequest,
|
||||
) (*dbtypes.CompanyUpdateResponse, error) {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
companyTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
|
||||
)
|
||||
|
||||
updateCompany := psql.Update(companyTable).
|
||||
SetMap(map[string]any{
|
||||
"is_active": false,
|
||||
"has_moderation_ticket": true,
|
||||
"metadata": request.Metadata,
|
||||
"additional_fields_tmpl": request.ExtraFields,
|
||||
}).
|
||||
Where(squirrel.Eq{"id": request.Id})
|
||||
|
||||
query, args, err := updateCompany.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error building update company query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
res, err := driver.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error executing update company query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: error getting rows affected for update company query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return nil, dberrors.ErrInternal
|
||||
}
|
||||
|
||||
return &dbtypes.CompanyUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// NOTE: do we believe that every company has a moderation ticket?
|
||||
func (c *client) updateCompanyValidationTicket(
|
||||
ctx context.Context,
|
||||
driver Driver,
|
||||
companyId string,
|
||||
request *dbtypes.CompanyUpdateRequest,
|
||||
) error {
|
||||
var (
|
||||
psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
comValTable = fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
|
||||
)
|
||||
|
||||
updateCompany := psql.Update(comValTable).
|
||||
Where(squirrel.Eq{"company_id": companyId})
|
||||
|
||||
if request.Name != nil {
|
||||
updateCompany = updateCompany.Set("name", *request.Name)
|
||||
}
|
||||
|
||||
if request.LegalPerson != nil {
|
||||
updateCompany = updateCompany.Set("legal_person", *request.LegalPerson)
|
||||
}
|
||||
|
||||
if request.Description != nil {
|
||||
updateCompany = updateCompany.Set("description", *request.Description)
|
||||
}
|
||||
|
||||
if request.Website != nil {
|
||||
updateCompany = updateCompany.Set("website", *request.Website)
|
||||
}
|
||||
|
||||
if request.PhysicalAddress != nil {
|
||||
updateCompany = updateCompany.Set("physical_address", *request.PhysicalAddress)
|
||||
}
|
||||
|
||||
if request.LegalAddress != nil {
|
||||
updateCompany = updateCompany.Set("legal_address", *request.LegalAddress)
|
||||
}
|
||||
|
||||
if request.Inn != nil {
|
||||
updateCompany = updateCompany.Set("inn", *request.Inn)
|
||||
}
|
||||
|
||||
if len(request.Staff) > 0 {
|
||||
updateCompany = updateCompany.Set("staff", request.Staff)
|
||||
}
|
||||
|
||||
query, args, err := updateCompany.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error building update company moderation ticket query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
res, err := driver.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error executing update company moderation ticket query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: error getting rows affected for update company moderation ticket query: %v", dberrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return dberrors.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user