feat: add Airwallex payments and multi-currency support
This commit is contained in:
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
@@ -22,16 +23,17 @@ func calculateCreditedBalance(paymentAmount, multiplier float64) float64 {
|
||||
InexactFloat64()
|
||||
}
|
||||
|
||||
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64) float64 {
|
||||
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64, currency string) float64 {
|
||||
if orderAmount <= 0 || payAmount <= 0 || refundAmount <= 0 {
|
||||
return 0
|
||||
}
|
||||
if math.Abs(refundAmount-orderAmount) <= amountToleranceCNY {
|
||||
return decimal.NewFromFloat(payAmount).Round(2).InexactFloat64()
|
||||
fractionDigits := int32(payment.CurrencyMaxFractionDigits(currency))
|
||||
if math.Abs(refundAmount-orderAmount) <= paymentAmountToleranceForCurrency(currency) {
|
||||
return decimal.NewFromFloat(payAmount).Round(fractionDigits).InexactFloat64()
|
||||
}
|
||||
return decimal.NewFromFloat(payAmount).
|
||||
Mul(decimal.NewFromFloat(refundAmount)).
|
||||
Div(decimal.NewFromFloat(orderAmount)).
|
||||
Round(2).
|
||||
Round(fractionDigits).
|
||||
InexactFloat64()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// GetAvailableMethodLimits collects all payment types from enabled provider
|
||||
@@ -25,7 +26,12 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
|
||||
Methods: make(map[string]MethodLimits, len(typeInstances)),
|
||||
}
|
||||
for pt, insts := range typeInstances {
|
||||
currency, ok := s.pcAggregateMethodCurrency(insts)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ml := pcAggregateMethodLimits(pt, insts)
|
||||
ml.Currency = currency
|
||||
resp.Methods[ml.PaymentType] = ml
|
||||
}
|
||||
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
|
||||
@@ -82,11 +88,81 @@ func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []stri
|
||||
matching = append(matching, inst)
|
||||
}
|
||||
}
|
||||
result = append(result, pcAggregateMethodLimits(pt, matching))
|
||||
currency, ok := s.pcAggregateMethodCurrency(matching)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ml := pcAggregateMethodLimits(pt, matching)
|
||||
ml.Currency = currency
|
||||
result = append(result, ml)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) ValidateMethodCurrencyConsistency(ctx context.Context, paymentType string) (string, error) {
|
||||
method := NormalizeVisibleMethod(paymentType)
|
||||
if method == "" || s == nil || s.entClient == nil {
|
||||
return payment.DefaultPaymentCurrency, nil
|
||||
}
|
||||
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("query provider instances: %w", err)
|
||||
}
|
||||
|
||||
typeInstances := pcGroupByPaymentType(instances)
|
||||
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
|
||||
matching := typeInstances[method]
|
||||
if len(matching) == 0 {
|
||||
return payment.DefaultPaymentCurrency, nil
|
||||
}
|
||||
|
||||
currency, ok := s.pcAggregateMethodCurrency(matching)
|
||||
if !ok {
|
||||
return "", infraerrors.ServiceUnavailable(
|
||||
"PAYMENT_METHOD_CURRENCY_CONFLICT",
|
||||
"payment method has enabled provider instances with mixed currencies",
|
||||
).WithMetadata(map[string]string{"payment_type": method})
|
||||
}
|
||||
return currency, nil
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) pcAggregateMethodCurrency(instances []*dbent.PaymentProviderInstance) (string, bool) {
|
||||
currency := ""
|
||||
for _, inst := range instances {
|
||||
next := s.pcInstancePaymentCurrency(inst)
|
||||
if next == "" {
|
||||
continue
|
||||
}
|
||||
if currency == "" {
|
||||
currency = next
|
||||
continue
|
||||
}
|
||||
if currency != next {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
if currency == "" {
|
||||
return payment.DefaultPaymentCurrency, true
|
||||
}
|
||||
return currency, true
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) pcInstancePaymentCurrency(inst *dbent.PaymentProviderInstance) string {
|
||||
if inst == nil {
|
||||
return payment.DefaultPaymentCurrency
|
||||
}
|
||||
cfg := map[string]string{}
|
||||
if s != nil {
|
||||
decrypted, err := s.decryptConfig(inst.Config)
|
||||
if err == nil && decrypted != nil {
|
||||
cfg = decrypted
|
||||
}
|
||||
}
|
||||
return paymentProviderConfigCurrency(inst.ProviderKey, cfg)
|
||||
}
|
||||
|
||||
// pcGroupByPaymentType groups instances by user-facing payment type.
|
||||
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
|
||||
// because the user sees a single "Stripe" button, not individual sub-methods.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -199,6 +200,61 @@ func TestPcGroupByPaymentType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPcAggregateMethodCurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &PaymentConfigService{}
|
||||
stripe := makeInstance(1, payment.TypeStripe, payment.TypeStripe, "")
|
||||
stripe.Config = `{"currency":"hkd"}`
|
||||
currency, ok := svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{stripe})
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "HKD", currency)
|
||||
|
||||
airwallex := makeInstance(2, payment.TypeAirwallex, payment.TypeAirwallex, "")
|
||||
airwallex.Config = `{"currency":"usd"}`
|
||||
currency, ok = svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{stripe, airwallex})
|
||||
require.False(t, ok)
|
||||
require.Empty(t, currency)
|
||||
|
||||
easypay := makeInstance(3, payment.TypeEasyPay, payment.TypeAlipay, "")
|
||||
currency, ok = svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{easypay})
|
||||
require.True(t, ok)
|
||||
require.Equal(t, payment.DefaultPaymentCurrency, currency)
|
||||
}
|
||||
|
||||
func TestGetAvailableMethodLimitsOmitsMixedCurrencyMethod(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
|
||||
_, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeStripe).
|
||||
SetName("Stripe HKD").
|
||||
SetConfig(`{"currency":"HKD"}`).
|
||||
SetSupportedTypes("card,link").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeStripe).
|
||||
SetName("Stripe USD").
|
||||
SetConfig(`{"currency":"USD"}`).
|
||||
SetSupportedTypes("card,link").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &PaymentConfigService{entClient: client}
|
||||
resp, err := svc.GetAvailableMethodLimits(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, resp.Methods, payment.TypeStripe)
|
||||
|
||||
_, err = svc.ValidateMethodCurrencyConsistency(ctx, payment.TypeStripe)
|
||||
require.Error(t, err)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Equal(t, "PAYMENT_METHOD_CURRENCY_CONFLICT", appErr.Reason)
|
||||
}
|
||||
|
||||
func TestPcComputeGlobalRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -110,10 +110,11 @@ var pendingOrderStatuses = []string{
|
||||
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
|
||||
// stripe publishableKey) are returned in plaintext by the admin GET API.
|
||||
var providerSensitiveConfigFields = map[string]map[string]struct{}{
|
||||
payment.TypeEasyPay: {"pkey": {}},
|
||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
|
||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
|
||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
||||
payment.TypeEasyPay: {"pkey": {}},
|
||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
|
||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
|
||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
||||
payment.TypeAirwallex: {"apikey": {}, "webhooksecret": {}},
|
||||
}
|
||||
|
||||
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
|
||||
@@ -121,10 +122,11 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
|
||||
// all provider identity fields that are snapshotted into orders or used by
|
||||
// webhook/refund verification.
|
||||
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
|
||||
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
|
||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
|
||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
|
||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
|
||||
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
|
||||
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
|
||||
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
|
||||
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}, "currency": {}},
|
||||
payment.TypeAirwallex: {"clientid": {}, "apikey": {}, "webhooksecret": {}, "apibase": {}, "accountid": {}, "currency": {}},
|
||||
}
|
||||
|
||||
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
|
||||
@@ -175,7 +177,7 @@ func (s *PaymentConfigService) countPendingOrdersByPlan(ctx context.Context, pla
|
||||
}
|
||||
|
||||
var validProviderKeys = map[string]bool{
|
||||
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true,
|
||||
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true, payment.TypeAirwallex: true,
|
||||
}
|
||||
|
||||
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
|
||||
|
||||
@@ -44,6 +44,13 @@ func TestValidateProviderRequest(t *testing.T) {
|
||||
supportedTypes: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid airwallex provider",
|
||||
providerKey: payment.TypeAirwallex,
|
||||
providerName: "Airwallex Provider",
|
||||
supportedTypes: payment.TypeAirwallex,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid alipay provider",
|
||||
providerKey: "alipay",
|
||||
@@ -120,6 +127,7 @@ func TestIsSensitiveProviderConfigField(t *testing.T) {
|
||||
{"stripe", "webhookSecret", true},
|
||||
{"stripe", "SecretKey", true}, // case-insensitive
|
||||
{"stripe", "publishableKey", false},
|
||||
{"stripe", "currency", false},
|
||||
{"stripe", "appId", false},
|
||||
|
||||
// Alipay
|
||||
@@ -142,6 +150,14 @@ func TestIsSensitiveProviderConfigField(t *testing.T) {
|
||||
{"easypay", "pid", false},
|
||||
{"easypay", "apiBase", false},
|
||||
|
||||
// Airwallex
|
||||
{payment.TypeAirwallex, "apiKey", true},
|
||||
{payment.TypeAirwallex, "webhookSecret", true},
|
||||
{payment.TypeAirwallex, "clientId", false},
|
||||
{payment.TypeAirwallex, "apiBase", false},
|
||||
{payment.TypeAirwallex, "accountId", false},
|
||||
{payment.TypeAirwallex, "currency", false},
|
||||
|
||||
// Unknown provider: never sensitive
|
||||
{"unknown", "secretKey", false},
|
||||
}
|
||||
@@ -395,6 +411,42 @@ func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t
|
||||
fieldName: "pid",
|
||||
wantValue: "pid-test",
|
||||
},
|
||||
{
|
||||
name: "stripe currency",
|
||||
providerKey: payment.TypeStripe,
|
||||
createConfig: validStripeProviderConfig,
|
||||
supportedType: []string{payment.TypeStripe},
|
||||
updateConfig: map[string]string{"currency": "HKD"},
|
||||
fieldName: "currency",
|
||||
wantValue: "CNY",
|
||||
},
|
||||
{
|
||||
name: "airwallex accountId",
|
||||
providerKey: payment.TypeAirwallex,
|
||||
createConfig: validAirwallexProviderConfig,
|
||||
supportedType: []string{payment.TypeAirwallex},
|
||||
updateConfig: map[string]string{"accountId": "acct-updated"},
|
||||
fieldName: "accountId",
|
||||
wantValue: "acct-test",
|
||||
},
|
||||
{
|
||||
name: "airwallex currency",
|
||||
providerKey: payment.TypeAirwallex,
|
||||
createConfig: validAirwallexProviderConfig,
|
||||
supportedType: []string{payment.TypeAirwallex},
|
||||
updateConfig: map[string]string{"currency": "HKD"},
|
||||
fieldName: "currency",
|
||||
wantValue: "CNY",
|
||||
},
|
||||
{
|
||||
name: "airwallex webhookSecret",
|
||||
providerKey: payment.TypeAirwallex,
|
||||
createConfig: validAirwallexProviderConfig,
|
||||
supportedType: []string{payment.TypeAirwallex},
|
||||
updateConfig: map[string]string{"webhookSecret": "whsec-updated"},
|
||||
fieldName: "webhookSecret",
|
||||
wantValue: "whsec-test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -506,6 +558,39 @@ func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *test
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateProviderInstanceClearsAirwallexAccountID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
svc := &PaymentConfigService{
|
||||
entClient: client,
|
||||
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
|
||||
}
|
||||
|
||||
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
|
||||
ProviderKey: payment.TypeAirwallex,
|
||||
Name: "airwallex-clear-account",
|
||||
Config: validAirwallexProviderConfig(t),
|
||||
SupportedTypes: []string{payment.TypeAirwallex},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
|
||||
Config: map[string]string{"accountId": ""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
|
||||
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
|
||||
require.NoError(t, err)
|
||||
cfg, err := svc.decryptConfig(saved.Config)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg["accountId"])
|
||||
require.Equal(t, "client-id-test", cfg["clientId"])
|
||||
}
|
||||
|
||||
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
|
||||
t.Helper()
|
||||
|
||||
@@ -545,11 +630,26 @@ func providerPendingOrderPaymentType(providerKey string) string {
|
||||
return payment.TypeWxpay
|
||||
case payment.TypeAlipay:
|
||||
return payment.TypeAlipay
|
||||
case payment.TypeAirwallex:
|
||||
return payment.TypeAirwallex
|
||||
case payment.TypeStripe:
|
||||
return payment.TypeStripe
|
||||
default:
|
||||
return payment.TypeAlipay
|
||||
}
|
||||
}
|
||||
|
||||
func validStripeProviderConfig(t *testing.T) map[string]string {
|
||||
t.Helper()
|
||||
|
||||
return map[string]string{
|
||||
"secretKey": "sk_test_123",
|
||||
"publishableKey": "pk_test_123",
|
||||
"webhookSecret": "whsec-test",
|
||||
"currency": "CNY",
|
||||
}
|
||||
}
|
||||
|
||||
func boolPtrValue(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
@@ -577,6 +677,19 @@ func validEasyPayProviderConfig(t *testing.T) map[string]string {
|
||||
}
|
||||
}
|
||||
|
||||
func validAirwallexProviderConfig(t *testing.T) map[string]string {
|
||||
t.Helper()
|
||||
|
||||
return map[string]string{
|
||||
"clientId": "client-id-test",
|
||||
"apiKey": "api-key-test",
|
||||
"webhookSecret": "whsec-test",
|
||||
"apiBase": "https://api-demo.airwallex.com/api/v1",
|
||||
"accountId": "acct-test",
|
||||
"currency": "CNY",
|
||||
}
|
||||
}
|
||||
|
||||
func validWxpayProviderConfig(t *testing.T) map[string]string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ type UpdatePaymentConfigRequest struct {
|
||||
// MethodLimits holds per-payment-type limits.
|
||||
type MethodLimits struct {
|
||||
PaymentType string `json:"payment_type"`
|
||||
Currency string `json:"currency"`
|
||||
FeeRate float64 `json:"fee_rate"`
|
||||
DailyLimit float64 `json:"daily_limit"`
|
||||
SingleMin float64 `json:"single_min"`
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func paymentProviderConfigCurrency(providerKey string, cfg map[string]string) string {
|
||||
switch strings.TrimSpace(providerKey) {
|
||||
case payment.TypeStripe, payment.TypeAirwallex:
|
||||
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
|
||||
if err == nil {
|
||||
return currency
|
||||
}
|
||||
}
|
||||
return payment.DefaultPaymentCurrency
|
||||
}
|
||||
|
||||
func PaymentOrderCurrency(order *dbent.PaymentOrder) string {
|
||||
if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
|
||||
if currency, err := payment.NormalizePaymentCurrency(snapshot.Currency); err == nil {
|
||||
return currency
|
||||
}
|
||||
}
|
||||
return payment.DefaultPaymentCurrency
|
||||
}
|
||||
@@ -101,13 +101,21 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
|
||||
})
|
||||
return fmt.Errorf("invalid paid amount from provider: %v", paid)
|
||||
}
|
||||
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
|
||||
if math.Abs(paid-o.PayAmount) > paymentAmountToleranceForCurrency(PaymentOrderCurrency(o)) {
|
||||
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
|
||||
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
|
||||
return fmt.Errorf("amount mismatch: expected %s, got %s", strconv.FormatFloat(o.PayAmount, 'f', -1, 64), strconv.FormatFloat(paid, 'f', -1, 64))
|
||||
}
|
||||
return s.toPaid(ctx, o, tradeNo, paid, pk)
|
||||
}
|
||||
|
||||
func paymentAmountToleranceForCurrency(currency string) float64 {
|
||||
minorUnit := payment.CurrencyMinorUnit(currency)
|
||||
if minorUnit <= 2 {
|
||||
return amountToleranceCNY
|
||||
}
|
||||
return math.Pow10(-minorUnit) / 2
|
||||
}
|
||||
|
||||
func isValidProviderAmount(amount float64) bool {
|
||||
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
|
||||
}
|
||||
|
||||
@@ -366,3 +366,55 @@ func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *t
|
||||
})
|
||||
assert.ErrorContains(t, err, "easypay pid mismatch")
|
||||
}
|
||||
|
||||
func TestValidateProviderNotificationMetadataRejectsAirwallexSnapshotMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
order := &dbent.PaymentOrder{
|
||||
PaymentType: payment.TypeAirwallex,
|
||||
ProviderSnapshot: map[string]any{
|
||||
"schema_version": 2,
|
||||
"merchant_id": "acct_expected",
|
||||
"currency": "CNY",
|
||||
},
|
||||
}
|
||||
|
||||
err := validateProviderNotificationMetadata(order, payment.TypeAirwallex, map[string]string{
|
||||
"account_id": "acct_other",
|
||||
"currency": "CNY",
|
||||
"status": "SUCCEEDED",
|
||||
})
|
||||
assert.ErrorContains(t, err, "airwallex account_id mismatch")
|
||||
|
||||
err = validateProviderNotificationMetadata(order, payment.TypeAirwallex, map[string]string{
|
||||
"account_id": "acct_expected",
|
||||
"currency": "USD",
|
||||
"status": "SUCCEEDED",
|
||||
})
|
||||
assert.ErrorContains(t, err, "airwallex currency mismatch")
|
||||
}
|
||||
|
||||
func TestValidateProviderNotificationMetadataRejectsStripeCurrencyMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
order := &dbent.PaymentOrder{
|
||||
PaymentType: payment.TypeStripe,
|
||||
ProviderSnapshot: map[string]any{
|
||||
"schema_version": 2,
|
||||
"currency": "HKD",
|
||||
},
|
||||
}
|
||||
|
||||
err := validateProviderNotificationMetadata(order, payment.TypeStripe, map[string]string{
|
||||
"currency": "USD",
|
||||
})
|
||||
assert.ErrorContains(t, err, "stripe currency mismatch")
|
||||
}
|
||||
|
||||
func TestPaymentAmountToleranceForThreeDecimalCurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("CNY"))
|
||||
assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("JPY"))
|
||||
assert.InDelta(t, 0.0005, paymentAmountToleranceForCurrency("KWD"), 1e-12)
|
||||
}
|
||||
|
||||
@@ -57,8 +57,17 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
|
||||
orderAmount = calculateCreditedBalance(req.Amount, cfg.BalanceRechargeMultiplier)
|
||||
}
|
||||
feeRate := cfg.RechargeFeeRate
|
||||
payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
|
||||
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
|
||||
methodCurrency := payment.DefaultPaymentCurrency
|
||||
if s.configService != nil {
|
||||
methodCurrency, err = s.configService.ValidateMethodCurrencyConsistency(ctx, req.PaymentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
payAmountStr, payAmount, err := calculateCreateOrderPayAmount(limitAmount, feeRate, methodCurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -66,6 +75,19 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
|
||||
if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectedCurrency := payment.DefaultPaymentCurrency
|
||||
if sel != nil {
|
||||
selectedCurrency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
|
||||
}
|
||||
if selectedCurrency != methodCurrency {
|
||||
payAmountStr, payAmount, err = calculateCreateOrderPayAmount(limitAmount, feeRate, selectedCurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := validateSelectedCreateOrderAmountCurrency(payAmountStr, sel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -257,7 +279,7 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
|
||||
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
|
||||
snapshot["merchant_id"] = merchantID
|
||||
}
|
||||
snapshot["currency"] = "CNY"
|
||||
snapshot["currency"] = payment.DefaultPaymentCurrency
|
||||
}
|
||||
if providerKey == payment.TypeAlipay {
|
||||
if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
|
||||
@@ -269,6 +291,15 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
|
||||
snapshot["merchant_id"] = merchantID
|
||||
}
|
||||
}
|
||||
if providerKey == payment.TypeStripe {
|
||||
snapshot["currency"] = paymentProviderConfigCurrency(providerKey, sel.Config)
|
||||
}
|
||||
if providerKey == payment.TypeAirwallex {
|
||||
if accountID := strings.TrimSpace(sel.Config["accountId"]); accountID != "" {
|
||||
snapshot["merchant_id"] = accountID
|
||||
}
|
||||
snapshot["currency"] = paymentProviderConfigCurrency(providerKey, sel.Config)
|
||||
}
|
||||
|
||||
if len(snapshot) == 1 {
|
||||
return nil
|
||||
@@ -377,7 +408,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
|
||||
return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
|
||||
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
|
||||
}
|
||||
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
|
||||
subject := s.buildPaymentSubject(plan, limitAmount, cfg, sel)
|
||||
outTradeNo := order.OutTradeNo
|
||||
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
|
||||
if err != nil {
|
||||
@@ -466,20 +497,24 @@ func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
|
||||
return sel.SupportedTypes
|
||||
}
|
||||
|
||||
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
|
||||
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig, sel *payment.InstanceSelection) string {
|
||||
if plan != nil {
|
||||
if plan.ProductName != "" {
|
||||
return plan.ProductName
|
||||
}
|
||||
return "Sub2API Subscription " + plan.Name
|
||||
}
|
||||
amountStr := strconv.FormatFloat(limitAmount, 'f', 2, 64)
|
||||
currency := payment.DefaultPaymentCurrency
|
||||
if sel != nil {
|
||||
currency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
|
||||
}
|
||||
amountStr := payment.FormatAmountForCurrency(limitAmount, currency)
|
||||
pf := strings.TrimSpace(cfg.ProductNamePrefix)
|
||||
sf := strings.TrimSpace(cfg.ProductNameSuffix)
|
||||
if pf != "" || sf != "" {
|
||||
return strings.TrimSpace(pf + " " + amountStr + " " + sf)
|
||||
}
|
||||
return "Sub2API " + amountStr + " CNY"
|
||||
return "Sub2API " + amountStr + " " + currency
|
||||
}
|
||||
|
||||
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
|
||||
@@ -540,6 +575,44 @@ func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateCreateOrderPayAmount(limitAmount, feeRate float64, currency string) (string, float64, error) {
|
||||
if err := validateCreateOrderAmountCurrency(limitAmount, currency); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
payAmountStr := payment.CalculatePayAmountForCurrency(limitAmount, feeRate, currency)
|
||||
if _, err := payment.AmountToMinorUnit(payAmountStr, currency); err != nil {
|
||||
return "", 0, infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
|
||||
WithMetadata(map[string]string{"currency": currency})
|
||||
}
|
||||
payAmount, err := strconv.ParseFloat(payAmountStr, 64)
|
||||
if err != nil {
|
||||
return "", 0, infraerrors.BadRequest("INVALID_AMOUNT", "invalid payment amount").
|
||||
WithMetadata(map[string]string{"currency": currency})
|
||||
}
|
||||
return payAmountStr, payAmount, nil
|
||||
}
|
||||
|
||||
func validateCreateOrderAmountCurrency(amount float64, currency string) error {
|
||||
amountStr := strconv.FormatFloat(amount, 'f', -1, 64)
|
||||
if _, err := payment.AmountToMinorUnit(amountStr, currency); err != nil {
|
||||
return infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
|
||||
WithMetadata(map[string]string{"currency": currency})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSelectedCreateOrderAmountCurrency(payAmount string, sel *payment.InstanceSelection) error {
|
||||
if sel == nil {
|
||||
return nil
|
||||
}
|
||||
currency := paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
|
||||
if _, err := payment.AmountToMinorUnit(payAmount, currency); err != nil {
|
||||
return infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
|
||||
WithMetadata(map[string]string{"currency": currency})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
|
||||
if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
|
||||
return false
|
||||
@@ -596,6 +669,10 @@ func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest,
|
||||
PayURL: pr.PayURL,
|
||||
QRCode: pr.QRCode,
|
||||
ClientSecret: pr.ClientSecret,
|
||||
IntentID: pr.IntentID,
|
||||
Currency: pr.Currency,
|
||||
CountryCode: pr.CountryCode,
|
||||
PaymentEnv: pr.PaymentEnv,
|
||||
OAuth: pr.OAuth,
|
||||
JSAPI: pr.JSAPI,
|
||||
JSAPIPayload: pr.JSAPI,
|
||||
|
||||
@@ -188,6 +188,38 @@ func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey str
|
||||
return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
case payment.TypeStripe:
|
||||
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
|
||||
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
|
||||
if actual == "" {
|
||||
return fmt.Errorf("stripe notification missing currency")
|
||||
}
|
||||
if !strings.EqualFold(expected, actual) {
|
||||
return fmt.Errorf("stripe currency mismatch: expected %s, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
case payment.TypeAirwallex:
|
||||
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
|
||||
actual := strings.TrimSpace(metadata["account_id"])
|
||||
if actual == "" {
|
||||
return fmt.Errorf("airwallex account_id missing")
|
||||
}
|
||||
if !strings.EqualFold(expected, actual) {
|
||||
return fmt.Errorf("airwallex account_id mismatch: expected %s, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
|
||||
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
|
||||
if actual == "" {
|
||||
return fmt.Errorf("airwallex notification missing currency")
|
||||
}
|
||||
if !strings.EqualFold(expected, actual) {
|
||||
return fmt.Errorf("airwallex currency mismatch: expected %s, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
if actual := strings.TrimSpace(metadata["status"]); actual != "" && !strings.EqualFold(actual, "SUCCEEDED") {
|
||||
return fmt.Errorf("airwallex status mismatch: expected SUCCEEDED, got %s", actual)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -164,6 +164,30 @@ func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *te
|
||||
require.NotContains(t, snapshot, "pkey")
|
||||
}
|
||||
|
||||
func TestBuildPaymentOrderProviderSnapshot_IncludesProviderCurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stripeSnapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
|
||||
InstanceID: "77",
|
||||
ProviderKey: payment.TypeStripe,
|
||||
Config: map[string]string{
|
||||
"currency": "hkd",
|
||||
},
|
||||
}, CreateOrderRequest{})
|
||||
require.Equal(t, "HKD", stripeSnapshot["currency"])
|
||||
|
||||
airwallexSnapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
|
||||
InstanceID: "78",
|
||||
ProviderKey: payment.TypeAirwallex,
|
||||
Config: map[string]string{
|
||||
"currency": "usd",
|
||||
"accountId": "acct-78",
|
||||
},
|
||||
}, CreateOrderRequest{})
|
||||
require.Equal(t, "USD", airwallexSnapshot["currency"])
|
||||
require.Equal(t, "acct-78", airwallexSnapshot["merchant_id"])
|
||||
}
|
||||
|
||||
func valueOrEmpty(v *string) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
|
||||
@@ -91,6 +91,53 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSelectedCreateOrderAmountCurrencyRejectsFractionalZeroDecimal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := validateSelectedCreateOrderAmountCurrency("100.50", &payment.InstanceSelection{
|
||||
ProviderKey: payment.TypeStripe,
|
||||
Config: map[string]string{"currency": "JPY"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected fractional JPY amount to fail")
|
||||
}
|
||||
if appErr := infraerrors.FromError(err); appErr.Reason != "INVALID_AMOUNT" {
|
||||
t.Fatalf("reason = %q, want INVALID_AMOUNT", appErr.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCreateOrderPayAmountUsesCurrencyPrecision(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
amountStr, amount, err := calculateCreateOrderPayAmount(100, 2.5, "JPY")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if amountStr != "103" || amount != 103 {
|
||||
t.Fatalf("JPY pay amount = (%q, %v), want (103, 103)", amountStr, amount)
|
||||
}
|
||||
|
||||
amountStr, amount, err = calculateCreateOrderPayAmount(12.345, 1, "KWD")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if amountStr != "12.469" || amount != 12.469 {
|
||||
t.Fatalf("KWD pay amount = (%q, %v), want (12.469, 12.469)", amountStr, amount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCreateOrderPayAmountRejectsFractionalZeroDecimal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, err := calculateCreateOrderPayAmount(100.5, 0, "JPY")
|
||||
if err == nil {
|
||||
t.Fatal("expected fractional JPY amount to fail")
|
||||
}
|
||||
if appErr := infraerrors.FromError(err); appErr.Reason != "INVALID_AMOUNT" {
|
||||
t.Fatalf("reason = %q, want INVALID_AMOUNT", appErr.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
|
||||
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
|
||||
|
||||
|
||||
@@ -226,10 +226,11 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
|
||||
if amt <= 0 {
|
||||
amt = o.Amount
|
||||
}
|
||||
if amt-o.Amount > amountToleranceCNY {
|
||||
orderCurrency := PaymentOrderCurrency(o)
|
||||
if amt-o.Amount > paymentAmountToleranceForCurrency(orderCurrency) {
|
||||
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
|
||||
}
|
||||
ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt)
|
||||
ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt, orderCurrency)
|
||||
rr := strings.TrimSpace(reason)
|
||||
if rr == "" && o.RefundRequestReason != nil {
|
||||
rr = *o.RefundRequestReason
|
||||
@@ -339,13 +340,35 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
|
||||
})
|
||||
return err
|
||||
}
|
||||
_, err = prov.Refund(ctx, payment.RefundRequest{
|
||||
resp, err := prov.Refund(ctx, payment.RefundRequest{
|
||||
TradeNo: p.Order.PaymentTradeNo,
|
||||
OrderID: p.Order.OutTradeNo,
|
||||
Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
|
||||
Amount: formatGatewayRefundAmount(p.GatewayAmount, p.Order),
|
||||
Reason: p.Reason,
|
||||
})
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return validateRefundProviderResponse(resp)
|
||||
}
|
||||
|
||||
func formatGatewayRefundAmount(amount float64, order *dbent.PaymentOrder) string {
|
||||
return payment.FormatAmountForCurrency(amount, PaymentOrderCurrency(order))
|
||||
}
|
||||
|
||||
func validateRefundProviderResponse(resp *payment.RefundResponse) error {
|
||||
if resp == nil {
|
||||
return fmt.Errorf("payment refund response missing")
|
||||
}
|
||||
status := strings.TrimSpace(resp.Status)
|
||||
switch status {
|
||||
case payment.ProviderStatusSuccess, payment.ProviderStatusRefunded, payment.ProviderStatusPending:
|
||||
return nil
|
||||
case payment.ProviderStatusFailed:
|
||||
return fmt.Errorf("payment refund failed: status %s", status)
|
||||
default:
|
||||
return fmt.Errorf("payment refund returned unknown status: %s", status)
|
||||
}
|
||||
}
|
||||
|
||||
// getRefundProvider creates a provider using the order's original instance config.
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -184,3 +185,26 @@ func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
|
||||
})
|
||||
require.ErrorContains(t, err, "alipay app_id mismatch")
|
||||
}
|
||||
|
||||
func TestCalculateGatewayRefundAmountUsesCurrencyPrecision(t *testing.T) {
|
||||
require.InDelta(t, 6.173, calculateGatewayRefundAmount(100, 12.345, 50, "KWD"), 1e-12)
|
||||
require.InDelta(t, 12.345, calculateGatewayRefundAmount(100, 12.345, 100, "KWD"), 1e-12)
|
||||
require.InDelta(t, 52, calculateGatewayRefundAmount(100, 103, 50, "JPY"), 1e-12)
|
||||
}
|
||||
|
||||
func TestFormatGatewayRefundAmountUsesOrderCurrency(t *testing.T) {
|
||||
order := &dbent.PaymentOrder{
|
||||
ProviderSnapshot: map[string]any{
|
||||
"currency": "KWD",
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, "12.345", formatGatewayRefundAmount(12.345, order))
|
||||
}
|
||||
|
||||
func TestValidateRefundProviderResponseAcceptsPending(t *testing.T) {
|
||||
require.NoError(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusPending}))
|
||||
require.NoError(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusSuccess}))
|
||||
require.Error(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusFailed}))
|
||||
require.Error(t, validateRefundProviderResponse(nil))
|
||||
}
|
||||
|
||||
@@ -97,6 +97,10 @@ type CreateOrderResponse struct {
|
||||
PayURL string `json:"pay_url,omitempty"`
|
||||
QRCode string `json:"qr_code,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
IntentID string `json:"intent_id,omitempty"`
|
||||
Currency string `json:"currency,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
PaymentEnv string `json:"payment_env,omitempty"`
|
||||
OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
|
||||
JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
|
||||
JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user