package pgdb

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	dberrors "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/errors"
	dbtypes "git-molva.ru/Molva/molva-backend/services/api_gateway/internal/database/types"
	"github.com/Masterminds/squirrel"
	"github.com/google/uuid"
	"github.com/lib/pq"
)

// CompanyModerationStatusIdToString converts the numeric moderation status id
// stored in the database into its dbtypes.CompanyModerationStatus value.
// Unknown ids fall back to dbtypes.CompanyModerationStatusNew.
//
// The id mapping must stay in sync with CompanyModerationStatusStringToId.
func CompanyModerationStatusIdToString(status int32) dbtypes.CompanyModerationStatus {
	switch status {
	case 0:
		return dbtypes.CompanyModerationStatusPending
	case 1:
		return dbtypes.CompanyModerationStatusApproved
	case 2:
		return dbtypes.CompanyModerationStatusRejected
	case 3:
		return dbtypes.CompanyModerationStatusNew
	default:
		return dbtypes.CompanyModerationStatusNew
	}
}

// CompanyModerationStatusStringToId is the inverse of
// CompanyModerationStatusIdToString: it maps a moderation status to the
// numeric id stored in the database. Unknown statuses map to 3 (New).
func CompanyModerationStatusStringToId(status dbtypes.CompanyModerationStatus) int32 {
	switch status {
	case dbtypes.CompanyModerationStatusPending:
		return 0
	case dbtypes.CompanyModerationStatusApproved:
		return 1
	case dbtypes.CompanyModerationStatusRejected:
		return 2
	case dbtypes.CompanyModerationStatusNew:
		return 3
	default:
		return 3
	}
}

// companySelectBuilder returns a SELECT over the companies table whose column
// list matches, in order, the scan destinations used by scanCompanyRecord.
// Callers add their own WHERE clause.
func (c *client) companySelectBuilder() squirrel.SelectBuilder {
	psql := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)

	return psql.Select(
		"id",
		"uid",
		"name",
		"legal_person",
		"description",
		"website",
		"physical_address",
		"legal_address",
		"inn",
		"is_active",
		// TODO: add KPP when DB supports it
		"has_moderation_ticket",
		"staff",
		"metadata",
		"additional_fields_tmpl",
	).From(fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName))
}

// scanCompanyRecord reads one row produced by companySelectBuilder into a
// dbtypes.Company. NULL text columns are left as nil pointers. The scan error,
// if any, is returned unwrapped so callers can test for sql.ErrNoRows.
func scanCompanyRecord(row interface{ Scan(dest ...any) error }) (dbtypes.Company, error) {
	var name, legalPerson, description, website, physicalAddress sql.NullString
	var legalAddress, inn, metadata, additionalFieldsTmpl sql.NullString
	var staff pq.StringArray
	var company dbtypes.Company

	if err := row.Scan(
		&company.Id,
		&company.OwnerId, // uid column
		&name,
		&legalPerson,
		&description,
		&website,
		&physicalAddress,
		&legalAddress,
		&inn,
		&company.IsActive,
		&company.HasModerationTicket,
		&staff,
		&metadata,
		&additionalFieldsTmpl,
	); err != nil {
		return dbtypes.Company{}, err
	}

	company.Name = nullableCompanyString(name)
	company.LegalPerson = nullableCompanyString(legalPerson)
	company.Description = nullableCompanyString(description)
	company.Website = nullableCompanyString(website)
	company.PhysicalAddress = nullableCompanyString(physicalAddress)
	company.LegalAddress = nullableCompanyString(legalAddress)
	company.Inn = nullableCompanyString(inn)
	company.Staff = staff
	company.Metadata = nullableCompanyString(metadata)
	company.ExtraFieldsTemplate = nullableCompanyString(additionalFieldsTmpl)

	return company, nil
}

// nullableCompanyString converts a nullable text column into an optional
// string: nil for SQL NULL, otherwise a pointer to a copy of the value.
func nullableCompanyString(ns sql.NullString) *string {
	if !ns.Valid {
		return nil
	}

	return &ns.String
}

// GetCompanyList returns every company whose owner (the uid column) equals
// request.Id. An owner with no companies yields an empty (nil) Companies slice
// rather than an error.
//
// Errors: dberrors.ErrBadRequest for a nil request; dberrors.ErrInternal for
// query-building, execution, scanning, or row-iteration failures.
func (c *client) GetCompanyList(
	ctx context.Context,
	request *dbtypes.CompanyListGetRequest,
) (*dbtypes.CompanyListGetResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	query, args, err := c.companySelectBuilder().
		Where(squirrel.Eq{"uid": request.Id}).
		ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building get distributor company list query: %v", dberrors.ErrInternal, err)
	}

	rows, err := c.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%w: error executing get distributor company list query: %v", dberrors.ErrInternal, err)
	}
	defer rows.Close()

	var res dbtypes.CompanyListGetResponse

	for rows.Next() {
		company, err := scanCompanyRecord(rows)
		if err != nil {
			return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
		}

		res.Companies = append(res.Companies, company)
	}

	// rows.Next returns false both at the end of the result set and on a
	// driver/network error; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("%w: error iterating get distributor company list rows: %v", dberrors.ErrInternal, err)
	}

	return &res, nil
}

// GetCompanyById returns the company with id request.CompanyId, provided it is
// owned by request.Id (the uid column).
//
// Errors: dberrors.ErrBadRequest for a nil request; dberrors.ErrNotFound when
// no row matches both id and owner; dberrors.ErrInternal otherwise.
func (c *client) GetCompanyById(
	ctx context.Context,
	request *dbtypes.CompanyByIdGetRequest,
) (*dbtypes.CompanyByIdGetResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	query, args, err := c.companySelectBuilder().
		Where(squirrel.And{
			squirrel.Eq{"id": request.CompanyId},
			squirrel.Eq{"uid": request.Id},
		}).
		ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building get company by id query: %v", dberrors.ErrInternal, err)
	}

	company, err := scanCompanyRecord(c.db.QueryRowContext(ctx, query, args...))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, dberrors.ErrNotFound
		}

		return nil, fmt.Errorf("%w: error scanning row: %v", dberrors.ErrInternal, err)
	}

	return &dbtypes.CompanyByIdGetResponse{Company: company}, nil
}

// CreateCompany inserts a new company owned by request.OwnerId together with
// its initial moderation ticket, both inside a single transaction, and returns
// the generated company id.
//
// Errors: dberrors.ErrBadRequest for a nil request; dberrors.ErrInternal for
// failures starting the transaction or writing either row. Commit errors are
// wrapped as-is.
func (c *client) CreateCompany(
	ctx context.Context,
	request *dbtypes.CompanyCreateRequest,
) (*dbtypes.CompanyCreateResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone, so its
	// error is deliberately ignored; on any early return it undoes the work.
	defer func() { _ = tx.Rollback() }()

	res, err := c.createCompany(ctx, tx, request)
	if err != nil {
		return nil, fmt.Errorf("error creating company: %w", err)
	}

	if err := c.createCompanyValidationTicket(ctx, tx, res.Id, request); err != nil {
		return nil, fmt.Errorf("error creating company validation ticket: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("error committing transaction: %w", err)
	}

	return res, nil
}

// createCompany inserts the companies row for request through driver (the
// caller's transaction). The row starts with is_active=false and
// has_moderation_ticket=true; its id is a dash-free UUID suffixed with "COM".
func (c *client) createCompany(
	ctx context.Context,
	driver Driver,
	request *dbtypes.CompanyCreateRequest,
) (*dbtypes.CompanyCreateResponse, error) {
	var (
		psql         = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
		companyTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
	)

	// TODO: use normal uuid after DB reengineering
	comId := fmt.Sprintf("%sCOM", strings.ReplaceAll(uuid.NewString(), "-", ""))

	insert := psql.Insert(companyTable).
		Columns(
			"id",
			"uid",
			"is_active",
			"has_moderation_ticket",
			"metadata",
			"additional_fields_tmpl",
		).
		Values(
			comId,
			request.OwnerId,
			false,
			true,
			request.Metadata,
			request.ExtraFieldsTemplate,
		)

	query, args, err := insert.ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building create company query: %v", dberrors.ErrInternal, err)
	}

	res, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%w: error executing create company query: %v", dberrors.ErrInternal, err)
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return nil, fmt.Errorf("%w: error getting rows affected for create company query: %v", dberrors.ErrInternal, err)
	}

	if rowsAffected == 0 {
		return nil, fmt.Errorf("%w: create company query affected no rows", dberrors.ErrInternal)
	}

	return &dbtypes.CompanyCreateResponse{Id: comId}, nil
}

// createCompanyValidationTicket inserts the moderation ticket carrying the
// descriptive fields of request for companyId through driver (the caller's
// transaction). The ticket id is a dash-free UUID suffixed with "TCK" and its
// status is currently Pending.
func (c *client) createCompanyValidationTicket(
	ctx context.Context,
	driver Driver,
	companyId string,
	request *dbtypes.CompanyCreateRequest,
) error {
	var (
		psql        = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
		comValTable = fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
		ticketId    = fmt.Sprintf("%sTCK", strings.ReplaceAll(uuid.NewString(), "-", ""))
	)

	insert := psql.Insert(comValTable).
		Columns(
			"id",
			"company_id",
			"name",
			"legal_person",
			"description",
			"website",
			"physical_address",
			"legal_address",
			"inn",
			// TODO: add KPP when DB supports it
			"staff",
			"status",
		).
		Values(
			ticketId,
			companyId,
			request.Name,
			request.LegalPerson,
			request.Description,
			request.Website,
			request.PhysicalAddress,
			request.LegalAddress,
			request.Inn,
			// NOTE(review): assumes request.Staff is a driver.Valuer such as
			// pq.StringArray; a plain []string would need pq.Array — confirm.
			request.Staff,
			CompanyModerationStatusStringToId(dbtypes.CompanyModerationStatusPending), // TODO: switch to status "NEW"
		)

	query, args, err := insert.ToSql()
	if err != nil {
		return fmt.Errorf("%w: error building create company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	res, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("%w: error executing create company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("%w: error getting rows affected for create company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	if rowsAffected == 0 {
		return fmt.Errorf("%w: create company moderation ticket query affected no rows", dberrors.ErrInternal)
	}

	return nil
}

// UpdateCompany updates the company identified by request.Id and its
// moderation ticket inside a single transaction.
//
// Errors: dberrors.ErrBadRequest for a nil request or a request with no
// ticket fields to change; dberrors.ErrInternal for storage failures. Commit
// errors are wrapped as-is.
func (c *client) UpdateCompany(
	ctx context.Context,
	request *dbtypes.CompanyUpdateRequest,
) (*dbtypes.CompanyUpdateResponse, error) {
	if request == nil {
		return nil, fmt.Errorf("%w: request is nil", dberrors.ErrBadRequest)
	}

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("%w: error starting transaction: %v", dberrors.ErrInternal, err)
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone, so its
	// error is deliberately ignored; on any early return it undoes the work.
	defer func() { _ = tx.Rollback() }()

	res, err := c.updateCompany(ctx, tx, request)
	if err != nil {
		return nil, fmt.Errorf("error updating company: %w", err)
	}

	if err := c.updateCompanyValidationTicket(ctx, tx, request.Id, request); err != nil {
		return nil, fmt.Errorf("error updating company validation ticket: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("error committing transaction: %w", err)
	}

	return res, nil
}

// updateCompany updates the companies row with id request.Id through driver
// (the caller's transaction). It always sets is_active=false and
// has_moderation_ticket=true, and overwrites metadata and
// additional_fields_tmpl with the request values.
//
// NOTE(review): zero affected rows (unknown company id) is reported as
// dberrors.ErrInternal, not ErrNotFound — confirm that is intended.
func (c *client) updateCompany(
	ctx context.Context,
	driver Driver,
	request *dbtypes.CompanyUpdateRequest,
) (*dbtypes.CompanyUpdateResponse, error) {
	var (
		psql         = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
		companyTable = fmt.Sprintf("%s.%s", c.config.Schema, CompaniesTableName)
	)

	update := psql.Update(companyTable).
		SetMap(map[string]any{
			"is_active":              false,
			"has_moderation_ticket":  true,
			"metadata":               request.Metadata,
			"additional_fields_tmpl": request.ExtraFields,
		}).
		Where(squirrel.Eq{"id": request.Id})

	query, args, err := update.ToSql()
	if err != nil {
		return nil, fmt.Errorf("%w: error building update company query: %v", dberrors.ErrInternal, err)
	}

	res, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("%w: error executing update company query: %v", dberrors.ErrInternal, err)
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return nil, fmt.Errorf("%w: error getting rows affected for update company query: %v", dberrors.ErrInternal, err)
	}

	if rowsAffected == 0 {
		return nil, fmt.Errorf("%w: update company query affected no rows", dberrors.ErrInternal)
	}

	return &dbtypes.CompanyUpdateResponse{}, nil
}

// updateCompanyValidationTicket applies the provided (non-nil / non-empty)
// descriptive fields of request to the moderation ticket(s) of companyId
// through driver (the caller's transaction).
//
// A request that sets none of those fields yields dberrors.ErrBadRequest:
// squirrel cannot build an UPDATE without SET clauses.
//
// NOTE: assumes every company already has a moderation ticket; zero affected
// rows is reported as dberrors.ErrInternal.
func (c *client) updateCompanyValidationTicket(
	ctx context.Context,
	driver Driver,
	companyId string,
	request *dbtypes.CompanyUpdateRequest,
) error {
	var (
		psql        = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
		comValTable = fmt.Sprintf("%s.%s", c.config.Schema, CompanyValidationTicketsTableName)
	)

	fields := make(map[string]any, 8)
	if request.Name != nil {
		fields["name"] = *request.Name
	}
	if request.LegalPerson != nil {
		fields["legal_person"] = *request.LegalPerson
	}
	if request.Description != nil {
		fields["description"] = *request.Description
	}
	if request.Website != nil {
		fields["website"] = *request.Website
	}
	if request.PhysicalAddress != nil {
		fields["physical_address"] = *request.PhysicalAddress
	}
	if request.LegalAddress != nil {
		fields["legal_address"] = *request.LegalAddress
	}
	if request.Inn != nil {
		fields["inn"] = *request.Inn
	}
	if len(request.Staff) > 0 {
		fields["staff"] = request.Staff
	}

	// Without any SET clause squirrel's ToSql fails; report that as a client
	// error rather than an opaque internal one.
	if len(fields) == 0 {
		return fmt.Errorf("%w: no company moderation ticket fields to update", dberrors.ErrBadRequest)
	}

	query, args, err := psql.Update(comValTable).
		SetMap(fields).
		Where(squirrel.Eq{"company_id": companyId}).
		ToSql()
	if err != nil {
		return fmt.Errorf("%w: error building update company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	res, err := driver.ExecContext(ctx, query, args...)
	if err != nil {
		return fmt.Errorf("%w: error executing update company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("%w: error getting rows affected for update company moderation ticket query: %v", dberrors.ErrInternal, err)
	}

	if rowsAffected == 0 {
		return fmt.Errorf("%w: no company moderation ticket found for company %q", dberrors.ErrInternal, companyId)
	}

	return nil
}