Merge branch 'main' into dev

This commit is contained in:
ZeroDeng
2026-05-11 11:48:55 +08:00
committed by GitHub
77 changed files with 3204 additions and 208 deletions
+1 -1
View File
@@ -28,7 +28,7 @@ const (
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com https://static.airwallex.com https://checkout.airwallex.com https://static-demo.airwallex.com https://checkout-demo.airwallex.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://static.airwallex.com https://checkout.airwallex.com https://static-demo.airwallex.com https://checkout-demo.airwallex.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com https://checkout.airwallex.com https://checkout-demo.airwallex.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ(用户消息队列)模式常量
const (
+55 -10
View File
@@ -459,6 +459,7 @@ type PublicOrderResult struct {
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Currency string `json:"currency"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
@@ -481,6 +482,7 @@ func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
Amount: order.Amount,
PayAmount: order.PayAmount,
FeeRate: order.FeeRate,
Currency: service.PaymentOrderCurrency(order),
PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status,
@@ -554,24 +556,67 @@ func isMobile(c *gin.Context) bool {
return false
}
func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
if len(orders) == 0 {
return orders
}
out := make([]*dbent.PaymentOrder, 0, len(orders))
type PaymentOrderResult struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Currency string `json:"currency"`
PaymentType string `json:"payment_type"`
OutTradeNo string `json:"out_trade_no"`
Status string `json:"status"`
OrderType string `json:"order_type"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
PaidAt *time.Time `json:"paid_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
RefundAmount float64 `json:"refund_amount"`
RefundReason *string `json:"refund_reason,omitempty"`
RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
RefundRequestReason *string `json:"refund_request_reason,omitempty"`
PlanID *int64 `json:"plan_id,omitempty"`
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
}
func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []PaymentOrderResult {
out := make([]PaymentOrderResult, 0, len(orders))
for _, order := range orders {
out = append(out, sanitizePaymentOrderForResponse(order))
if item := sanitizePaymentOrderForResponse(order); item != nil {
out = append(out, *item)
}
}
return out
}
func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *PaymentOrderResult {
if order == nil {
return nil
}
cloned := *order
cloned.ProviderSnapshot = nil
return &cloned
return &PaymentOrderResult{
ID: order.ID,
UserID: order.UserID,
Amount: order.Amount,
PayAmount: order.PayAmount,
FeeRate: order.FeeRate,
Currency: service.PaymentOrderCurrency(order),
PaymentType: order.PaymentType,
OutTradeNo: order.OutTradeNo,
Status: order.Status,
OrderType: order.OrderType,
CreatedAt: order.CreatedAt,
ExpiresAt: order.ExpiresAt,
PaidAt: order.PaidAt,
CompletedAt: order.CompletedAt,
RefundAmount: order.RefundAmount,
RefundReason: order.RefundReason,
RefundRequestedAt: order.RefundRequestedAt,
RefundRequestedBy: order.RefundRequestedBy,
RefundRequestReason: order.RefundRequestReason,
PlanID: order.PlanID,
ProviderInstanceID: order.ProviderInstanceID,
}
}
func isWeChatBrowser(c *gin.Context) bool {
@@ -114,6 +114,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderSnapshot(map[string]any{"currency": "HKD"}).
Save(context.Background())
require.NoError(t, err)
@@ -141,6 +142,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
Currency string `json:"currency"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
@@ -155,6 +157,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
require.Equal(t, 90.64, resp.Data.PayAmount)
require.Equal(t, 0.03, resp.Data.FeeRate)
require.Equal(t, "HKD", resp.Data.Currency)
require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
require.Equal(t, service.OrderStatusPending, resp.Data.Status)
@@ -202,6 +205,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderSnapshot(map[string]any{"currency": "USD"}).
Save(context.Background())
require.NoError(t, err)
@@ -242,6 +246,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
require.Equal(t, 100.0, resp.Data["amount"])
require.Equal(t, 103.0, resp.Data["pay_amount"])
require.Equal(t, 0.03, resp.Data["fee_rate"])
require.Equal(t, "USD", resp.Data["currency"])
require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
@@ -2,6 +2,7 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -60,6 +61,12 @@ func (h *PaymentWebhookHandler) StripeWebhook(c *gin.Context) {
h.handleNotify(c, payment.TypeStripe)
}
// AirwallexWebhook 处理空中云汇 Webhook 事件。
// POST /api/v1/payment/webhook/airwallex
func (h *PaymentWebhookHandler) AirwallexWebhook(c *gin.Context) {
h.handleNotify(c, payment.TypeAirwallex)
}
// handleNotify is the shared logic for all provider webhook handlers.
func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) {
var rawBody string
@@ -146,6 +153,17 @@ func extractOutTradeNo(rawBody, providerKey string) string {
if err == nil {
return values.Get("out_trade_no")
}
case payment.TypeAirwallex:
var payload struct {
Data struct {
Object struct {
MerchantOrderID string `json:"merchant_order_id"`
} `json:"object"`
} `json:"data"`
}
if err := json.Unmarshal([]byte(rawBody), &payload); err == nil {
return strings.TrimSpace(payload.Data.Object.MerchantOrderID)
}
}
// For other providers (Stripe, Alipay direct, WxPay direct), the registry
// typically has only one instance, so no instance lookup is needed.
@@ -183,14 +201,14 @@ const (
wxpaySuccessMessage = "成功"
)
// writeSuccessResponse sends the provider-specific success response.
// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
// Stripe expects an empty 200; others accept plain text "success".
// writeSuccessResponse 返回各支付服务商要求的成功响应。
// 微信支付需要 JSON {"code":"SUCCESS","message":"成功"}
// Stripe 和空中云汇接受空 200,其它服务商接受纯文本 "success"
func writeSuccessResponse(c *gin.Context, providerKey string) {
switch providerKey {
case payment.TypeWxpay:
c.JSON(http.StatusOK, wxpaySuccessResponse{Code: wxpaySuccessCode, Message: wxpaySuccessMessage})
case payment.TypeStripe:
case payment.TypeStripe, payment.TypeAirwallex:
c.String(http.StatusOK, "")
default:
c.String(http.StatusOK, "success")
@@ -47,6 +47,13 @@ func TestWriteSuccessResponse(t *testing.T) {
wantContentType: "text/plain",
wantBody: "",
},
{
name: "airwallex returns empty 200",
providerKey: payment.TypeAirwallex,
wantCode: http.StatusOK,
wantContentType: "text/plain",
wantBody: "",
},
{
name: "easypay returns plain text success",
providerKey: "easypay",
@@ -165,6 +172,12 @@ func TestExtractOutTradeNo(t *testing.T) {
rawBody: "{}",
want: "",
},
{
name: "airwallex payment intent payload",
providerKey: payment.TypeAirwallex,
rawBody: `{"name":"payment_intent.succeeded","data":{"object":{"merchant_order_id":"sub2_awx_123"}}}`,
want: "sub2_awx_123",
},
}
for _, tt := range tests {
@@ -220,7 +233,7 @@ type webhookHandlerProviderStub struct {
verifyErr error
}
func (p webhookHandlerProviderStub) Name() string { return p.key }
func (p webhookHandlerProviderStub) Name() string { return p.key }
func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.PaymentType(p.key)}
+2 -17
View File
@@ -1,24 +1,9 @@
package payment
import (
"fmt"
"github.com/shopspring/decimal"
)
const centsPerYuan = 100
// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
// Uses shopspring/decimal for precision.
func YuanToFen(yuanStr string) (int64, error) {
d, err := decimal.NewFromString(yuanStr)
if err != nil {
return 0, fmt.Errorf("invalid amount: %s", yuanStr)
}
return d.Mul(decimal.NewFromInt(centsPerYuan)).IntPart(), nil
return AmountToMinorUnit(yuanStr, DefaultPaymentCurrency)
}
// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
func FenToYuan(fen int64) float64 {
return decimal.NewFromInt(fen).Div(decimal.NewFromInt(centsPerYuan)).InexactFloat64()
return MinorUnitToAmount(fen, DefaultPaymentCurrency)
}
+101
View File
@@ -126,3 +126,104 @@ func TestYuanToFenRoundTrip(t *testing.T) {
}
}
}
func TestPaymentCurrencyHelpers(t *testing.T) {
tests := []struct {
name string
currency string
amount string
wantMinor int64
wantBack float64
}{
{name: "hkd uses cents", currency: "hkd", amount: "12.34", wantMinor: 1234, wantBack: 12.34},
{name: "jpy has no minor unit", currency: "JPY", amount: "12", wantMinor: 12, wantBack: 12},
{name: "kwd uses three decimal minor units", currency: "KWD", amount: "12.345", wantMinor: 12345, wantBack: 12.345},
{name: "isk uses Stripe legacy two-decimal API amount", currency: "ISK", amount: "12", wantMinor: 1200, wantBack: 12},
{name: "ugx uses Stripe legacy two-decimal API amount", currency: "UGX", amount: "12.00", wantMinor: 1200, wantBack: 12},
{name: "empty currency defaults to cny", currency: "", amount: "1.23", wantMinor: 123, wantBack: 1.23},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := AmountToMinorUnit(tt.amount, tt.currency)
if err != nil {
t.Fatalf("AmountToMinorUnit(%q, %q) unexpected error: %v", tt.amount, tt.currency, err)
}
if got != tt.wantMinor {
t.Fatalf("AmountToMinorUnit(%q, %q) = %d, want %d", tt.amount, tt.currency, got, tt.wantMinor)
}
back := MinorUnitToAmount(got, tt.currency)
if math.Abs(back-tt.wantBack) > 1e-9 {
t.Fatalf("MinorUnitToAmount(%d, %q) = %f, want %f", got, tt.currency, back, tt.wantBack)
}
})
}
}
func TestFormatAmountForCurrency(t *testing.T) {
tests := []struct {
currency string
amount float64
want string
}{
{currency: "CNY", amount: 12.3, want: "12.30"},
{currency: "JPY", amount: 12, want: "12"},
{currency: "KWD", amount: 12.345, want: "12.345"},
{currency: "ISK", amount: 12, want: "12"},
}
for _, tt := range tests {
t.Run(tt.currency, func(t *testing.T) {
if got := FormatAmountForCurrency(tt.amount, tt.currency); got != tt.want {
t.Fatalf("FormatAmountForCurrency(%v, %q) = %q, want %q", tt.amount, tt.currency, got, tt.want)
}
})
}
}
func TestAmountToMinorUnitRejectsUnsupportedPrecision(t *testing.T) {
if _, err := AmountToMinorUnit("100.50", "JPY"); err == nil {
t.Fatal("expected fractional JPY amount to fail")
}
if _, err := AmountToMinorUnit("100.50", "ISK"); err == nil {
t.Fatal("expected fractional ISK amount to fail")
}
if _, err := AmountToMinorUnit("100.50", "UGX"); err == nil {
t.Fatal("expected fractional UGX amount to fail")
}
if _, err := AmountToMinorUnit("12.345", "HKD"); err == nil {
t.Fatal("expected amount with more than two decimal places to fail")
}
if _, err := AmountToMinorUnit("12.3456", "KWD"); err == nil {
t.Fatal("expected amount with more than three decimal places to fail")
}
if got, err := AmountToMinorUnit("100.00", "JPY"); err != nil || got != 100 {
t.Fatalf("AmountToMinorUnit integer-form JPY = (%d, %v), want (100, nil)", got, err)
}
}
func TestThreeDecimalPaymentCurrencies(t *testing.T) {
for _, currency := range []string{"BHD", "IQD", "JOD", "KWD", "LYD", "OMR", "TND"} {
t.Run(currency, func(t *testing.T) {
got, err := AmountToMinorUnit("12.345", currency)
if err != nil {
t.Fatalf("AmountToMinorUnit(%q, %q) unexpected error: %v", "12.345", currency, err)
}
if got != 12345 {
t.Fatalf("AmountToMinorUnit(%q, %q) = %d, want 12345", "12.345", currency, got)
}
if back := MinorUnitToAmount(got, currency); math.Abs(back-12.345) > 1e-9 {
t.Fatalf("MinorUnitToAmount(%d, %q) = %f, want 12.345", got, currency, back)
}
})
}
}
func TestNormalizePaymentCurrencyRejectsInvalidCodes(t *testing.T) {
if _, err := NormalizePaymentCurrency("HK"); err == nil {
t.Fatal("expected invalid two-letter currency to fail")
}
if _, err := NormalizePaymentCurrency("US1"); err == nil {
t.Fatal("expected non-letter currency to fail")
}
}
+118
View File
@@ -0,0 +1,118 @@
package payment
import (
"fmt"
"strings"
"github.com/shopspring/decimal"
)
const DefaultPaymentCurrency = "CNY"
type paymentCurrencyAmountUnit struct {
apiMinorUnit int
maxFractionDigits int
}
var (
zeroDecimalAmountUnit = paymentCurrencyAmountUnit{apiMinorUnit: 0, maxFractionDigits: 0}
twoDecimalAmountUnit = paymentCurrencyAmountUnit{apiMinorUnit: 2, maxFractionDigits: 2}
threeDecimalAmountUnit = paymentCurrencyAmountUnit{apiMinorUnit: 3, maxFractionDigits: 3}
stripeLegacyZeroAmount = paymentCurrencyAmountUnit{apiMinorUnit: 2, maxFractionDigits: 0}
)
var paymentCurrencyAmountUnits = map[string]paymentCurrencyAmountUnit{
"BIF": zeroDecimalAmountUnit,
"CLP": zeroDecimalAmountUnit,
"DJF": zeroDecimalAmountUnit,
"GNF": zeroDecimalAmountUnit,
"JPY": zeroDecimalAmountUnit,
"KMF": zeroDecimalAmountUnit,
"KRW": zeroDecimalAmountUnit,
"MGA": zeroDecimalAmountUnit,
"PYG": zeroDecimalAmountUnit,
"RWF": zeroDecimalAmountUnit,
"VND": zeroDecimalAmountUnit,
"VUV": zeroDecimalAmountUnit,
"XAF": zeroDecimalAmountUnit,
"XOF": zeroDecimalAmountUnit,
"XPF": zeroDecimalAmountUnit,
"ISK": stripeLegacyZeroAmount,
"UGX": stripeLegacyZeroAmount,
"BHD": threeDecimalAmountUnit,
"IQD": threeDecimalAmountUnit,
"JOD": threeDecimalAmountUnit,
"KWD": threeDecimalAmountUnit,
"LYD": threeDecimalAmountUnit,
"OMR": threeDecimalAmountUnit,
"TND": threeDecimalAmountUnit,
}
func NormalizePaymentCurrency(raw string) (string, error) {
currency := strings.ToUpper(strings.TrimSpace(raw))
if currency == "" {
return DefaultPaymentCurrency, nil
}
if len(currency) != 3 {
return "", fmt.Errorf("payment currency must be a 3-letter ISO currency code")
}
for _, ch := range currency {
if ch < 'A' || ch > 'Z' {
return "", fmt.Errorf("payment currency must be a 3-letter ISO currency code")
}
}
return currency, nil
}
func CurrencyMinorUnit(currency string) int {
return paymentCurrencyAmountUnitFor(currency).apiMinorUnit
}
// CurrencyMaxFractionDigits 返回支付金额允许展示和输入的小数位数。
func CurrencyMaxFractionDigits(currency string) int {
return paymentCurrencyAmountUnitFor(currency).maxFractionDigits
}
// FormatAmountForCurrency 按币种允许的小数位格式化支付金额。
func FormatAmountForCurrency(amount float64, currency string) string {
return decimal.NewFromFloat(amount).StringFixed(int32(CurrencyMaxFractionDigits(currency)))
}
func paymentCurrencyAmountUnitFor(currency string) paymentCurrencyAmountUnit {
normalized, err := NormalizePaymentCurrency(currency)
if err != nil {
return twoDecimalAmountUnit
}
if amountUnit, ok := paymentCurrencyAmountUnits[normalized]; ok {
return amountUnit
}
return twoDecimalAmountUnit
}
func AmountToMinorUnit(amountStr, currency string) (int64, error) {
d, err := decimal.NewFromString(strings.TrimSpace(amountStr))
if err != nil {
return 0, fmt.Errorf("invalid amount: %s", amountStr)
}
normalizedCurrency, err := NormalizePaymentCurrency(currency)
if err != nil {
return 0, err
}
amountUnit := paymentCurrencyAmountUnitFor(normalizedCurrency)
precisionFactor := decimal.New(1, int32(amountUnit.maxFractionDigits))
scaledForPrecision := d.Mul(precisionFactor)
if !scaledForPrecision.Equal(scaledForPrecision.Truncate(0)) {
if amountUnit.maxFractionDigits == 0 {
return 0, fmt.Errorf("payment amount for %s must be a whole number", normalizedCurrency)
}
return 0, fmt.Errorf("payment amount for %s must not have more than %d decimal places", normalizedCurrency, amountUnit.maxFractionDigits)
}
factor := decimal.New(1, int32(amountUnit.apiMinorUnit))
minorAmount := d.Mul(factor)
return minorAmount.IntPart(), nil
}
func MinorUnitToAmount(value int64, currency string) float64 {
factor := decimal.New(1, int32(CurrencyMinorUnit(currency)))
return decimal.NewFromInt(value).Div(factor).InexactFloat64()
}
+9 -7
View File
@@ -4,16 +4,18 @@ import (
"github.com/shopspring/decimal"
)
// CalculatePayAmount computes the total pay amount given a recharge amount and
// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
func CalculatePayAmount(rechargeAmount float64, feeRate float64) string {
return CalculatePayAmountForCurrency(rechargeAmount, feeRate, DefaultPaymentCurrency)
}
// CalculatePayAmountForCurrency 按币种精度计算应付金额,手续费向上取整到该币种最小支付单位。
func CalculatePayAmountForCurrency(rechargeAmount float64, feeRate float64, currency string) string {
fractionDigits := int32(CurrencyMaxFractionDigits(currency))
amount := decimal.NewFromFloat(rechargeAmount)
if feeRate <= 0 {
return amount.StringFixed(2)
return amount.StringFixed(fractionDigits)
}
rate := decimal.NewFromFloat(feeRate)
fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(2)
return amount.Add(fee).StringFixed(2)
fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(fractionDigits)
return amount.Add(fee).StringFixed(fractionDigits)
}
+52
View File
@@ -109,3 +109,55 @@ func TestCalculatePayAmount(t *testing.T) {
})
}
}
func TestCalculatePayAmountForCurrency(t *testing.T) {
t.Parallel()
tests := []struct {
name string
amount float64
feeRate float64
currency string
expected string
}{
{
name: "zero decimal currency rounds fee up to whole unit",
amount: 100,
feeRate: 2.5,
currency: "JPY",
expected: "103",
},
{
name: "three decimal currency keeps three decimal places",
amount: 12.345,
feeRate: 1,
currency: "KWD",
expected: "12.469",
},
{
name: "stripe legacy zero decimal currency displays whole unit",
amount: 100,
feeRate: 2.5,
currency: "ISK",
expected: "103",
},
{
name: "default currency keeps existing two decimal behavior",
amount: 10,
feeRate: 3.33,
currency: "CNY",
expected: "10.34",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := CalculatePayAmountForCurrency(tt.amount, tt.feeRate, tt.currency)
if got != tt.expected {
t.Fatalf("CalculatePayAmountForCurrency(%v, %v, %q) = %q, want %q", tt.amount, tt.feeRate, tt.currency, got, tt.expected)
}
})
}
}
@@ -0,0 +1,639 @@
package provider
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/google/uuid"
"github.com/shopspring/decimal"
)
const (
airwallexDemoAPIBase = "https://api-demo.airwallex.com/api/v1"
airwallexProdAPIBase = "https://api.airwallex.com/api/v1"
airwallexDefaultCountry = "CN"
airwallexHTTPTimeout = 15 * time.Second
airwallexMaxResponseSize = 1 << 20
airwallexMaxErrorSummary = 512
airwallexTokenSkew = 2 * time.Minute
airwallexWebhookTolerance = 5 * time.Minute
airwallexEventPaymentSucceeded = "payment_intent.succeeded"
airwallexEventPaymentCancelled = "payment_intent.cancelled"
airwallexPaymentStatusSucceeded = "SUCCEEDED"
airwallexPaymentStatusCancelled = "CANCELLED"
airwallexRefundStatusReceived = "RECEIVED"
airwallexRefundStatusAccepted = "ACCEPTED"
airwallexRefundStatusSettled = "SETTLED"
airwallexRefundStatusFailed = "FAILED"
)
type Airwallex struct {
instanceID string
config map[string]string
httpClient *http.Client
}
type airwallexTokenState struct {
mu sync.Mutex
token string
expiresAt time.Time
}
var airwallexAccessTokens sync.Map
func NewAirwallex(instanceID string, config map[string]string) (*Airwallex, error) {
for _, k := range []string{"clientId", "apiKey", "webhookSecret", "apiBase"} {
if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("airwallex config missing required key: %s", k)
}
}
cfg := cloneStringMap(config)
apiBase, err := normalizeAirwallexAPIBase(cfg["apiBase"])
if err != nil {
return nil, err
}
cfg["apiBase"] = apiBase
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
if err != nil {
return nil, fmt.Errorf("airwallex config currency: %w", err)
}
cfg["currency"] = currency
countryCode, err := normalizeAirwallexCountryCode(cfg["countryCode"])
if err != nil {
return nil, err
}
cfg["countryCode"] = countryCode
return &Airwallex{
instanceID: instanceID,
config: cfg,
httpClient: &http.Client{Timeout: airwallexHTTPTimeout},
}, nil
}
func normalizeAirwallexCountryCode(raw string) (string, error) {
countryCode := strings.ToUpper(strings.TrimSpace(raw))
if countryCode == "" {
return airwallexDefaultCountry, nil
}
if len(countryCode) != 2 {
return "", fmt.Errorf("airwallex config countryCode must be a two-letter ISO country code")
}
for _, ch := range countryCode {
if ch < 'A' || ch > 'Z' {
return "", fmt.Errorf("airwallex config countryCode must be a two-letter ISO country code")
}
}
return countryCode, nil
}
func normalizeAirwallexAPIBase(raw string) (string, error) {
base := strings.TrimSpace(raw)
if base == "" {
return "", fmt.Errorf("airwallex apiBase is required")
}
parsed, err := url.Parse(base)
if err != nil || parsed.Scheme != "https" || parsed.Host == "" {
return "", fmt.Errorf("airwallex apiBase must be an HTTPS URL")
}
host := strings.ToLower(parsed.Host)
if host != "api-demo.airwallex.com" && host != "api.airwallex.com" {
return "", fmt.Errorf("airwallex apiBase host must be api-demo.airwallex.com or api.airwallex.com")
}
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.RawPath = ""
parsed.Path = strings.TrimRight(parsed.Path, "/")
if parsed.Path == "" {
parsed.Path = "/api/v1"
}
if parsed.Path != "/api/v1" {
return "", fmt.Errorf("airwallex apiBase path must be /api/v1")
}
return parsed.String(), nil
}
func (a *Airwallex) Name() string { return "空中云汇" }
func (a *Airwallex) ProviderKey() string { return payment.TypeAirwallex }
func (a *Airwallex) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAirwallex}
}
func (a *Airwallex) MerchantIdentityMetadata() map[string]string {
if a == nil {
return nil
}
metadata := map[string]string{"currency": a.currency()}
if accountID := strings.TrimSpace(a.config["accountId"]); accountID != "" {
metadata["account_id"] = accountID
}
return metadata
}
func (a *Airwallex) currency() string {
if a == nil {
return payment.DefaultPaymentCurrency
}
currency, err := payment.NormalizePaymentCurrency(a.config["currency"])
if err != nil {
return payment.DefaultPaymentCurrency
}
return currency
}
func (a *Airwallex) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
amount, err := decimal.NewFromString(req.Amount)
if err != nil || amount.LessThanOrEqual(decimal.Zero) {
return nil, fmt.Errorf("airwallex create payment: invalid amount %s", req.Amount)
}
token, err := a.accessToken(ctx)
if err != nil {
return nil, fmt.Errorf("airwallex auth: %w", err)
}
currency := a.currency()
requestID := airwallexDeterministicRequestID("payment-intent", req.OrderID, req.Amount, currency)
payload := airwallexCreatePaymentIntentRequest{
RequestID: requestID,
Amount: newAirwallexRequestAmount(amount),
Currency: currency,
MerchantOrderID: req.OrderID,
ReturnURL: req.ReturnURL,
Metadata: map[string]string{
"order_id": req.OrderID,
},
}
if descriptor := strings.TrimSpace(a.config["descriptor"]); descriptor != "" {
payload.Descriptor = descriptor
}
var intent airwallexPaymentIntent
if err := a.doJSON(ctx, http.MethodPost, "/pa/payment_intents/create", token, payload, &intent); err != nil {
return nil, fmt.Errorf("airwallex create payment: %w", err)
}
if strings.TrimSpace(intent.ID) == "" || strings.TrimSpace(intent.ClientSecret) == "" {
return nil, fmt.Errorf("airwallex create payment: missing payment intent id or client secret")
}
return &payment.CreatePaymentResponse{
TradeNo: intent.ID,
ClientSecret: intent.ClientSecret,
IntentID: intent.ID,
Currency: currency,
CountryCode: a.config["countryCode"],
PaymentEnv: a.checkoutEnv(),
}, nil
}
func (a *Airwallex) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
intentID := strings.TrimSpace(tradeNo)
if intentID == "" {
return nil, fmt.Errorf("airwallex query order: missing payment intent id")
}
token, err := a.accessToken(ctx)
if err != nil {
return nil, fmt.Errorf("airwallex auth: %w", err)
}
var intent airwallexPaymentIntent
if err := a.doJSON(ctx, http.MethodGet, "/pa/payment_intents/"+url.PathEscape(intentID), token, nil, &intent); err != nil {
return nil, fmt.Errorf("airwallex query order: %w", err)
}
return &payment.QueryOrderResponse{
TradeNo: intent.ID,
Status: airwallexProviderStatus(intent.Status),
Amount: intent.Amount.InexactFloat64(),
Metadata: a.intentMetadata(intent, ""),
}, nil
}
func (a *Airwallex) VerifyNotification(_ context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
if err := verifyAirwallexWebhookSignature(rawBody, headers, a.config["webhookSecret"], time.Now()); err != nil {
return nil, err
}
var event airwallexWebhookEvent
if err := json.Unmarshal([]byte(rawBody), &event); err != nil {
return nil, fmt.Errorf("airwallex parse webhook: %w", err)
}
switch event.Name {
case airwallexEventPaymentSucceeded, airwallexEventPaymentCancelled:
default:
return nil, nil
}
var intent airwallexPaymentIntent
if err := json.Unmarshal(event.Data.Object, &intent); err != nil {
return nil, fmt.Errorf("airwallex parse payment intent: %w", err)
}
if strings.TrimSpace(intent.ID) == "" || strings.TrimSpace(intent.MerchantOrderID) == "" {
return nil, fmt.Errorf("airwallex webhook missing payment intent id or merchant_order_id")
}
status := payment.ProviderStatusFailed
if event.Name == airwallexEventPaymentSucceeded {
if strings.ToUpper(strings.TrimSpace(intent.Status)) != airwallexPaymentStatusSucceeded {
return nil, fmt.Errorf("airwallex succeeded webhook has non-succeeded status: %s", intent.Status)
}
status = payment.NotificationStatusSuccess
}
return &payment.PaymentNotification{
TradeNo: intent.ID,
OrderID: intent.MerchantOrderID,
Amount: intent.Amount.InexactFloat64(),
Status: status,
RawData: rawBody,
Metadata: a.intentMetadata(intent, event.accountID()),
}, nil
}
func (a *Airwallex) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
intentID := strings.TrimSpace(req.TradeNo)
if intentID == "" {
return nil, fmt.Errorf("airwallex refund missing payment intent id")
}
amount, err := decimal.NewFromString(req.Amount)
if err != nil || amount.LessThanOrEqual(decimal.Zero) {
return nil, fmt.Errorf("airwallex refund: invalid amount %s", req.Amount)
}
token, err := a.accessToken(ctx)
if err != nil {
return nil, fmt.Errorf("airwallex auth: %w", err)
}
payload := airwallexCreateRefundRequest{
RequestID: airwallexDeterministicRequestID("refund", intentID, req.Amount),
PaymentIntentID: intentID,
Amount: newAirwallexRequestAmount(amount),
Reason: strings.TrimSpace(req.Reason),
}
if payload.Reason == "" {
payload.Reason = "refund"
}
var resp airwallexRefund
if err := a.doJSON(ctx, http.MethodPost, "/pa/refunds/create", token, payload, &resp); err != nil {
return nil, fmt.Errorf("airwallex refund: %w", err)
}
if strings.TrimSpace(resp.ID) == "" {
return nil, fmt.Errorf("airwallex refund: missing refund id")
}
refundResp := &payment.RefundResponse{
RefundID: resp.ID,
Status: airwallexRefundProviderStatus(resp.Status),
}
if refundResp.Status != payment.ProviderStatusSuccess {
return refundResp, fmt.Errorf("airwallex refund not settled: status %s", strings.ToUpper(strings.TrimSpace(resp.Status)))
}
return refundResp, nil
}
func (a *Airwallex) CancelPayment(ctx context.Context, tradeNo string) error {
intentID := strings.TrimSpace(tradeNo)
if intentID == "" {
return nil
}
token, err := a.accessToken(ctx)
if err != nil {
return fmt.Errorf("airwallex auth: %w", err)
}
var intent airwallexPaymentIntent
if err := a.doJSON(ctx, http.MethodPost, "/pa/payment_intents/"+url.PathEscape(intentID)+"/cancel", token, nil, &intent); err != nil {
return fmt.Errorf("airwallex cancel payment: %w", err)
}
return nil
}
func (a *Airwallex) intentMetadata(intent airwallexPaymentIntent, accountID string) map[string]string {
metadata := map[string]string{
"currency": strings.ToUpper(strings.TrimSpace(intent.Currency)),
"status": strings.ToUpper(strings.TrimSpace(intent.Status)),
}
if accountID = strings.TrimSpace(accountID); accountID != "" {
metadata["account_id"] = accountID
} else if configured := strings.TrimSpace(a.config["accountId"]); configured != "" {
metadata["account_id"] = configured
}
return metadata
}
func (a *Airwallex) checkoutEnv() string {
if strings.EqualFold(a.config["apiBase"], airwallexProdAPIBase) {
return "prod"
}
return "demo"
}
func (a *Airwallex) accessToken(ctx context.Context) (string, error) {
cacheKey := a.tokenCacheKey()
rawState, _ := airwallexAccessTokens.LoadOrStore(cacheKey, &airwallexTokenState{})
state, ok := rawState.(*airwallexTokenState)
if !ok {
return "", fmt.Errorf("airwallex auth token cache state type mismatch")
}
state.mu.Lock()
defer state.mu.Unlock()
if state.token != "" && time.Now().Add(airwallexTokenSkew).Before(state.expiresAt) {
return state.token, nil
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.config["apiBase"]+"/authentication/login", nil)
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-client-id", a.config["clientId"])
req.Header.Set("x-api-key", a.config["apiKey"])
if accountID := strings.TrimSpace(a.config["accountId"]); accountID != "" {
req.Header.Set("x-login-as", accountID)
}
body, status, err := a.do(req)
if err != nil {
return "", err
}
if status < http.StatusOK || status >= http.StatusMultipleChoices {
return "", formatAirwallexAuthHTTPError(status, body)
}
var resp airwallexAuthResponse
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("parse authentication response: %w", err)
}
if strings.TrimSpace(resp.Token) == "" {
return "", fmt.Errorf("authentication response missing token")
}
expiresAt, err := parseAirwallexTime(resp.ExpiresAt)
if err != nil {
expiresAt = time.Now().Add(25 * time.Minute)
}
state.token = resp.Token
state.expiresAt = expiresAt
return state.token, nil
}
func formatAirwallexAuthHTTPError(status int, body []byte) error {
summary := summarizeAirwallexResponse(body)
if status == http.StatusUnauthorized || status == http.StatusForbidden {
return fmt.Errorf("authentication HTTP %d: %s; Airwallex credentials were rejected, check Client ID/API Key, API Base environment (sandbox: https://api-demo.airwallex.com/api/v1, production: https://api.airwallex.com/api/v1), and Account ID (leave it empty for single-account scoped keys)", status, summary)
}
return fmt.Errorf("authentication HTTP %d: %s", status, summary)
}
func (a *Airwallex) tokenCacheKey() string {
sum := sha256.Sum256([]byte(a.config["apiKey"]))
return a.config["apiBase"] + "|" + a.config["clientId"] + "|" + strings.TrimSpace(a.config["accountId"]) + "|" + hex.EncodeToString(sum[:8])
}
func (a *Airwallex) doJSON(ctx context.Context, method, path, token string, payload any, out any) error {
var bodyReader io.Reader
if payload != nil {
body, err := json.Marshal(payload)
if err != nil {
return err
}
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequestWithContext(ctx, method, a.config["apiBase"]+path, bodyReader)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
if accountID := strings.TrimSpace(a.config["accountId"]); accountID != "" {
req.Header.Set("x-on-behalf-of", accountID)
}
body, status, err := a.do(req)
if err != nil {
return err
}
if status < http.StatusOK || status >= http.StatusMultipleChoices {
return fmt.Errorf("HTTP %d: %s", status, summarizeAirwallexResponse(body))
}
if out == nil || len(bytes.TrimSpace(body)) == 0 {
return nil
}
if err := json.Unmarshal(body, out); err != nil {
return fmt.Errorf("parse response: %w", err)
}
return nil
}
func (a *Airwallex) do(req *http.Request) ([]byte, int, error) {
client := a.httpClient
if client == nil {
client = &http.Client{Timeout: airwallexHTTPTimeout}
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, airwallexMaxResponseSize))
if err != nil {
return nil, resp.StatusCode, err
}
return body, resp.StatusCode, nil
}
func airwallexProviderStatus(status string) string {
switch strings.ToUpper(strings.TrimSpace(status)) {
case airwallexPaymentStatusSucceeded:
return payment.ProviderStatusPaid
case airwallexPaymentStatusCancelled:
return payment.ProviderStatusFailed
default:
return payment.ProviderStatusPending
}
}
func airwallexRefundProviderStatus(status string) string {
switch strings.ToUpper(strings.TrimSpace(status)) {
case airwallexRefundStatusSettled:
return payment.ProviderStatusSuccess
case airwallexRefundStatusFailed:
return payment.ProviderStatusFailed
case airwallexRefundStatusReceived, airwallexRefundStatusAccepted:
return payment.ProviderStatusPending
default:
return payment.ProviderStatusPending
}
}
func airwallexDeterministicRequestID(parts ...string) string {
hash := sha256.Sum256([]byte(strings.Join(parts, "\x00")))
var id uuid.UUID
copy(id[:], hash[:16])
id[6] = (id[6] & 0x0f) | 0x40
id[8] = (id[8] & 0x3f) | 0x80
return id.String()
}
func verifyAirwallexWebhookSignature(rawBody string, headers map[string]string, secret string, now time.Time) error {
secret = strings.TrimSpace(secret)
if secret == "" {
return fmt.Errorf("airwallex webhookSecret not configured")
}
timestamp := strings.TrimSpace(headers["x-timestamp"])
signature := strings.ToLower(strings.TrimSpace(headers["x-signature"]))
if timestamp == "" || signature == "" {
return fmt.Errorf("airwallex notification missing x-timestamp or x-signature header")
}
mac := hmac.New(sha256.New, []byte(secret))
_, _ = mac.Write([]byte(timestamp))
_, _ = mac.Write([]byte(rawBody))
expected := hex.EncodeToString(mac.Sum(nil))
if !hmac.Equal([]byte(expected), []byte(signature)) {
return fmt.Errorf("airwallex invalid signature")
}
ts, err := parseAirwallexWebhookTimestamp(timestamp)
if err != nil {
return err
}
if now.IsZero() {
now = time.Now()
}
if diff := now.Sub(ts).Abs(); diff > airwallexWebhookTolerance {
return fmt.Errorf("airwallex webhook timestamp outside tolerance")
}
return nil
}
func parseAirwallexWebhookTimestamp(raw string) (time.Time, error) {
ts, err := decimal.NewFromString(strings.TrimSpace(raw))
if err != nil {
return time.Time{}, fmt.Errorf("airwallex invalid webhook timestamp")
}
millis := ts.IntPart()
if millis <= 0 {
return time.Time{}, fmt.Errorf("airwallex invalid webhook timestamp")
}
return time.UnixMilli(millis), nil
}
func parseAirwallexTime(raw string) (time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return time.Time{}, fmt.Errorf("empty time")
}
for _, layout := range []string{time.RFC3339, "2006-01-02T15:04:05-0700", "2006-01-02T15:04:05.000-0700"} {
if t, err := time.Parse(layout, raw); err == nil {
return t, nil
}
}
return time.Time{}, fmt.Errorf("invalid time: %s", raw)
}
func summarizeAirwallexResponse(body []byte) string {
summary := strings.Join(strings.Fields(string(body)), " ")
if summary == "" {
return "<empty>"
}
if len(summary) > airwallexMaxErrorSummary {
return summary[:airwallexMaxErrorSummary] + "..."
}
return summary
}
type airwallexAuthResponse struct {
Token string `json:"token"`
ExpiresAt string `json:"expires_at"`
}
type airwallexCreatePaymentIntentRequest struct {
RequestID string `json:"request_id"`
Amount airwallexRequestAmount `json:"amount"`
Currency string `json:"currency"`
MerchantOrderID string `json:"merchant_order_id"`
ReturnURL string `json:"return_url,omitempty"`
Descriptor string `json:"descriptor,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
type airwallexCreateRefundRequest struct {
RequestID string `json:"request_id"`
PaymentIntentID string `json:"payment_intent_id"`
Amount airwallexRequestAmount `json:"amount,omitempty"`
Reason string `json:"reason,omitempty"`
}
type airwallexRequestAmount struct {
decimal.Decimal
}
func newAirwallexRequestAmount(amount decimal.Decimal) airwallexRequestAmount {
return airwallexRequestAmount{Decimal: amount}
}
func (a airwallexRequestAmount) MarshalJSON() ([]byte, error) {
return []byte(a.String()), nil
}
func (a *airwallexRequestAmount) UnmarshalJSON(data []byte) error {
amount, err := decimal.NewFromString(strings.Trim(string(data), `"`))
if err != nil {
return err
}
a.Decimal = amount
return nil
}
type airwallexPaymentIntent struct {
ID string `json:"id"`
RequestID string `json:"request_id"`
ClientSecret string `json:"client_secret"`
MerchantOrderID string `json:"merchant_order_id"`
Amount decimal.Decimal `json:"amount"`
Currency string `json:"currency"`
Status string `json:"status"`
Metadata map[string]string `json:"metadata"`
}
type airwallexRefund struct {
ID string `json:"id"`
RequestID string `json:"request_id"`
PaymentIntentID string `json:"payment_intent_id"`
Amount decimal.Decimal `json:"amount"`
Currency string `json:"currency"`
Status string `json:"status"`
}
type airwallexWebhookEvent struct {
ID string `json:"id"`
Name string `json:"name"`
AccountID string `json:"accountId"`
AccountIDSnake string `json:"account_id"`
Data struct {
Object json.RawMessage `json:"object"`
} `json:"data"`
}
func (e airwallexWebhookEvent) accountID() string {
if accountID := strings.TrimSpace(e.AccountID); accountID != "" {
return accountID
}
return strings.TrimSpace(e.AccountIDSnake)
}
var (
_ payment.Provider = (*Airwallex)(nil)
_ payment.CancelableProvider = (*Airwallex)(nil)
_ payment.MerchantIdentityProvider = (*Airwallex)(nil)
)
@@ -0,0 +1,352 @@
//go:build unit
package provider
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
func TestNewAirwallexValidatesConfig(t *testing.T) {
t.Parallel()
_, err := NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "secret",
"apiBase": "https://evil.example.com/api/v1",
})
require.ErrorContains(t, err, "apiBase host")
_, err = NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "secret",
"apiBase": airwallexDemoAPIBase,
"countryCode": "C1",
})
require.ErrorContains(t, err, "countryCode")
prov, err := NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "secret",
"apiBase": airwallexDemoAPIBase,
})
require.NoError(t, err)
require.Equal(t, payment.TypeAirwallex, prov.ProviderKey())
require.Equal(t, []payment.PaymentType{payment.TypeAirwallex}, prov.SupportedTypes())
require.Equal(t, payment.DefaultPaymentCurrency, prov.config["currency"])
require.Equal(t, airwallexDefaultCountry, prov.config["countryCode"])
}
func TestAirwallexCreatePaymentUsesServerAmountAndStableRequestID(t *testing.T) {
t.Parallel()
var createRequests []airwallexCreatePaymentIntentRequest
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/authentication/login":
require.Equal(t, "cid", r.Header.Get("x-client-id"))
require.Equal(t, "key", r.Header.Get("x-api-key"))
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
case "/api/v1/pa/payment_intents/create":
require.Equal(t, "Bearer token-1", r.Header.Get("Authorization"))
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Contains(t, string(body), `"amount":12.34`)
var payload airwallexCreatePaymentIntentRequest
require.NoError(t, json.Unmarshal(body, &payload))
createRequests = append(createRequests, payload)
_, _ = w.Write([]byte(`{"id":"int_123","client_secret":"secret_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"REQUIRES_PAYMENT_METHOD"}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
prov := mustTestAirwallexProvider(t, server)
resp, err := prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_order",
Amount: "12.34",
ReturnURL: "https://merchant.example.com/payment/result",
})
require.NoError(t, err)
require.Equal(t, "int_123", resp.TradeNo)
require.Equal(t, "secret_123", resp.ClientSecret)
require.Equal(t, "int_123", resp.IntentID)
require.Equal(t, "CNY", resp.Currency)
require.Equal(t, "CN", resp.CountryCode)
require.Equal(t, "demo", resp.PaymentEnv)
require.Len(t, createRequests, 1)
require.Equal(t, "12.34", createRequests[0].Amount.StringFixed(2))
require.Equal(t, "CNY", createRequests[0].Currency)
require.Equal(t, "sub2_order", createRequests[0].MerchantOrderID)
require.Equal(t, airwallexDeterministicRequestID("payment-intent", "sub2_order", "12.34", "CNY"), createRequests[0].RequestID)
}
func TestAirwallexCreatePaymentUsesConfiguredCurrency(t *testing.T) {
t.Parallel()
var createRequest airwallexCreatePaymentIntentRequest
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/authentication/login":
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
case "/api/v1/pa/payment_intents/create":
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(body, &createRequest))
_, _ = w.Write([]byte(`{"id":"int_123","client_secret":"secret_123","amount":12.34,"currency":"HKD","merchant_order_id":"sub2_order","status":"REQUIRES_PAYMENT_METHOD"}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
prov, err := NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "whsec",
"apiBase": airwallexDemoAPIBase,
"currency": "hkd",
"countryCode": "HK",
})
require.NoError(t, err)
prov.config["apiBase"] = server.URL + "/api/v1"
prov.httpClient = server.Client()
resp, err := prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_order",
Amount: "12.34",
ReturnURL: "https://merchant.example.com/payment/result",
})
require.NoError(t, err)
require.Equal(t, "HKD", createRequest.Currency)
require.Equal(t, "HKD", resp.Currency)
require.Equal(t, "HK", resp.CountryCode)
require.Equal(t, "HKD", prov.MerchantIdentityMetadata()["currency"])
}
func TestAirwallexRequestsUseConfiguredAccountID(t *testing.T) {
t.Parallel()
paRequestCount := 0
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/authentication/login":
require.Equal(t, "acct_123", r.Header.Get("x-login-as"))
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
case "/api/v1/pa/payment_intents/create":
paRequestCount++
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
_, _ = w.Write([]byte(`{"id":"int_123","client_secret":"secret_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"REQUIRES_PAYMENT_METHOD"}`))
case "/api/v1/pa/payment_intents/int_123":
paRequestCount++
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
_, _ = w.Write([]byte(`{"id":"int_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"SUCCEEDED"}`))
case "/api/v1/pa/refunds/create":
paRequestCount++
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
_, _ = w.Write([]byte(`{"id":"ref_123","payment_intent_id":"int_123","amount":12.34,"currency":"CNY","status":"SETTLED"}`))
case "/api/v1/pa/payment_intents/int_123/cancel":
paRequestCount++
require.Equal(t, "acct_123", r.Header.Get("x-on-behalf-of"))
_, _ = w.Write([]byte(`{"id":"int_123","amount":12.34,"currency":"CNY","merchant_order_id":"sub2_order","status":"CANCELLED"}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
prov, err := NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "whsec",
"apiBase": airwallexDemoAPIBase,
"accountId": "acct_123",
})
require.NoError(t, err)
prov.config["apiBase"] = server.URL + "/api/v1"
prov.httpClient = server.Client()
_, err = prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_order",
Amount: "12.34",
})
require.NoError(t, err)
_, err = prov.QueryOrder(context.Background(), "int_123")
require.NoError(t, err)
_, err = prov.Refund(context.Background(), payment.RefundRequest{
TradeNo: "int_123",
Amount: "12.34",
Reason: "test refund",
})
require.NoError(t, err)
require.NoError(t, prov.CancelPayment(context.Background(), "int_123"))
require.Contains(t, prov.tokenCacheKey(), "acct_123")
require.Equal(t, 4, paRequestCount)
}
func TestAirwallexRefundRejectsUnsettledStatus(t *testing.T) {
t.Parallel()
for _, status := range []string{"RECEIVED", "ACCEPTED", "FAILED"} {
t.Run(status, func(t *testing.T) {
t.Parallel()
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/authentication/login":
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
case "/api/v1/pa/refunds/create":
_, _ = w.Write([]byte(`{"id":"ref_123","payment_intent_id":"int_123","amount":12.34,"currency":"CNY","status":"` + status + `"}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
prov := mustTestAirwallexProvider(t, server)
resp, err := prov.Refund(context.Background(), payment.RefundRequest{
TradeNo: "int_123",
Amount: "12.34",
Reason: "test refund",
})
require.ErrorContains(t, err, "airwallex refund not settled")
require.NotNil(t, resp)
require.Equal(t, "ref_123", resp.RefundID)
if status == airwallexRefundStatusFailed {
require.Equal(t, payment.ProviderStatusFailed, resp.Status)
} else {
require.Equal(t, payment.ProviderStatusPending, resp.Status)
}
})
}
}
func TestAirwallexAuthErrorIncludesCredentialGuidance(t *testing.T) {
t.Parallel()
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/api/v1/authentication/login", r.URL.Path)
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"code":"credentials_invalid","details":["Access Denied"],"message":"UNAUTHORIZED","source":""}`))
}))
defer server.Close()
prov := mustTestAirwallexProvider(t, server)
_, err := prov.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_order",
Amount: "12.34",
})
require.ErrorContains(t, err, "credentials_invalid")
require.ErrorContains(t, err, "API Base environment")
require.ErrorContains(t, err, "https://api-demo.airwallex.com/api/v1")
require.ErrorContains(t, err, "https://api.airwallex.com/api/v1")
require.ErrorContains(t, err, "Account ID")
}
func TestAirwallexVerifyNotificationRequiresValidSignatureAndCurrency(t *testing.T) {
t.Parallel()
prov, err := NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "whsec",
"apiBase": airwallexDemoAPIBase,
"accountId": "acct_123",
})
require.NoError(t, err)
raw := `{"id":"evt_1","name":"payment_intent.succeeded","accountId":"acct_123","data":{"object":{"id":"int_123","merchant_order_id":"sub2_abc","amount":88.66,"currency":"CNY","status":"SUCCEEDED"}}}`
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
headers := signedAirwallexHeaders(raw, timestamp, "whsec")
n, err := prov.VerifyNotification(context.Background(), raw, headers)
require.NoError(t, err)
require.NotNil(t, n)
require.Equal(t, "int_123", n.TradeNo)
require.Equal(t, "sub2_abc", n.OrderID)
require.Equal(t, payment.NotificationStatusSuccess, n.Status)
require.InDelta(t, 88.66, n.Amount, 0.0001)
require.Equal(t, "CNY", n.Metadata["currency"])
require.Equal(t, "acct_123", n.Metadata["account_id"])
headers["x-signature"] = strings.Repeat("0", 64)
_, err = prov.VerifyNotification(context.Background(), raw, headers)
require.ErrorContains(t, err, "invalid signature")
}
func TestVerifyAirwallexWebhookSignatureRejectsReplay(t *testing.T) {
t.Parallel()
raw := `{"id":"evt_1"}`
timestamp := "1778241600000"
headers := signedAirwallexHeaders(raw, timestamp, "whsec")
err := verifyAirwallexWebhookSignature(raw, headers, "whsec", time.UnixMilli(1778241600000).Add(airwallexWebhookTolerance+time.Millisecond))
require.ErrorContains(t, err, "outside tolerance")
}
func TestAirwallexQueryOrderMapsSucceeded(t *testing.T) {
t.Parallel()
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/authentication/login":
_, _ = w.Write([]byte(`{"token":"token-1","expires_at":"2099-01-01T00:00:00Z"}`))
case "/api/v1/pa/payment_intents/int_123":
_, _ = w.Write([]byte(`{"id":"int_123","amount":99.01,"currency":"CNY","merchant_order_id":"sub2_order","status":"SUCCEEDED"}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
prov := mustTestAirwallexProvider(t, server)
resp, err := prov.QueryOrder(context.Background(), "int_123")
require.NoError(t, err)
require.Equal(t, payment.ProviderStatusPaid, resp.Status)
require.InDelta(t, 99.01, resp.Amount, 0.0001)
require.Equal(t, "CNY", resp.Metadata["currency"])
require.Equal(t, "SUCCEEDED", resp.Metadata["status"])
}
func mustTestAirwallexProvider(t *testing.T, server *httptest.Server) *Airwallex {
t.Helper()
prov, err := NewAirwallex("1", map[string]string{
"clientId": "cid",
"apiKey": "key",
"webhookSecret": "whsec",
"apiBase": airwallexDemoAPIBase,
})
require.NoError(t, err)
prov.config["apiBase"] = server.URL + "/api/v1"
prov.httpClient = server.Client()
return prov
}
func signedAirwallexHeaders(rawBody, timestamp, secret string) map[string]string {
mac := hmac.New(sha256.New, []byte(secret))
_, _ = mac.Write([]byte(timestamp))
_, _ = mac.Write([]byte(rawBody))
return map[string]string{
"x-timestamp": timestamp,
"x-signature": hex.EncodeToString(mac.Sum(nil)),
}
}
@@ -17,6 +17,8 @@ func CreateProvider(providerKey string, instanceID string, config map[string]str
return NewWxpay(instanceID, config)
case payment.TypeStripe:
return NewStripe(instanceID, config)
case payment.TypeAirwallex:
return NewAirwallex(instanceID, config)
default:
return nil, fmt.Errorf("unknown provider key: %s", providerKey)
}
+57 -11
View File
@@ -14,7 +14,6 @@ import (
// Stripe constants.
const (
stripeCurrency = "cny"
stripeEventPaymentSuccess = "payment_intent.succeeded"
stripeEventPaymentFailed = "payment_intent.payment_failed"
)
@@ -34,9 +33,15 @@ func NewStripe(instanceID string, config map[string]string) (*Stripe, error) {
if config["secretKey"] == "" {
return nil, fmt.Errorf("stripe config missing required key: secretKey")
}
cfg := cloneStringMap(config)
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
if err != nil {
return nil, fmt.Errorf("stripe config currency: %w", err)
}
cfg["currency"] = currency
return &Stripe{
instanceID: instanceID,
config: config,
config: cfg,
}, nil
}
@@ -60,6 +65,24 @@ func (s *Stripe) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeStripe}
}
func (s *Stripe) MerchantIdentityMetadata() map[string]string {
if s == nil {
return nil
}
return map[string]string{"currency": s.currency()}
}
func (s *Stripe) currency() string {
if s == nil {
return payment.DefaultPaymentCurrency
}
currency, err := payment.NormalizePaymentCurrency(s.config["currency"])
if err != nil {
return payment.DefaultPaymentCurrency
}
return currency
}
// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
var stripePaymentMethodTypes = map[string][]string{
payment.TypeCard: {"card"},
@@ -72,7 +95,8 @@ var stripePaymentMethodTypes = map[string][]string{
func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
s.ensureInit()
amountInCents, err := payment.YuanToFen(req.Amount)
currency := s.currency()
amountInMinorUnit, err := payment.AmountToMinorUnit(req.Amount, currency)
if err != nil {
return nil, fmt.Errorf("stripe create payment: %w", err)
}
@@ -86,8 +110,8 @@ func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentReq
}
params := &stripe.PaymentIntentCreateParams{
Amount: stripe.Int64(amountInCents),
Currency: stripe.String(stripeCurrency),
Amount: stripe.Int64(amountInMinorUnit),
Currency: stripe.String(strings.ToLower(currency)),
PaymentMethodTypes: pmTypes,
Description: stripe.String(req.Subject),
Metadata: map[string]string{"orderId": req.OrderID},
@@ -113,6 +137,7 @@ func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentReq
return &payment.CreatePaymentResponse{
TradeNo: pi.ID,
ClientSecret: pi.ClientSecret,
Currency: currency,
}, nil
}
@@ -133,10 +158,14 @@ func (s *Stripe) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
status = payment.ProviderStatusFailed
}
currency := stripeIntentCurrency(pi.Currency, s.currency())
return &payment.QueryOrderResponse{
TradeNo: pi.ID,
Status: status,
Amount: payment.FenToYuan(pi.Amount),
Amount: payment.MinorUnitToAmount(pi.Amount, currency),
Metadata: map[string]string{
"currency": currency,
},
}, nil
}
@@ -174,12 +203,16 @@ func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string
if err := json.Unmarshal(event.Data.Raw, &pi); err != nil {
return nil, fmt.Errorf("stripe parse payment_intent: %w", err)
}
currency := stripeIntentCurrency(pi.Currency, payment.DefaultPaymentCurrency)
return &payment.PaymentNotification{
TradeNo: pi.ID,
OrderID: pi.Metadata["orderId"],
Amount: payment.FenToYuan(pi.Amount),
Amount: payment.MinorUnitToAmount(pi.Amount, currency),
Status: status,
RawData: rawBody,
Metadata: map[string]string{
"currency": currency,
},
}, nil
}
@@ -187,14 +220,14 @@ func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string
func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
s.ensureInit()
amountInCents, err := payment.YuanToFen(req.Amount)
amountInMinorUnit, err := payment.AmountToMinorUnit(req.Amount, s.currency())
if err != nil {
return nil, fmt.Errorf("stripe refund: %w", err)
}
params := &stripe.RefundCreateParams{
PaymentIntent: stripe.String(req.TradeNo),
Amount: stripe.Int64(amountInCents),
Amount: stripe.Int64(amountInMinorUnit),
Reason: stripe.String(string(stripe.RefundReasonRequestedByCustomer)),
}
params.Context = ctx
@@ -215,6 +248,18 @@ func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*paymen
}, nil
}
func stripeIntentCurrency(raw stripe.Currency, fallback string) string {
currency, err := payment.NormalizePaymentCurrency(string(raw))
if err != nil || currency == payment.DefaultPaymentCurrency && strings.TrimSpace(string(raw)) == "" {
normalizedFallback, fallbackErr := payment.NormalizePaymentCurrency(fallback)
if fallbackErr == nil {
return normalizedFallback
}
return payment.DefaultPaymentCurrency
}
return currency
}
// resolveStripeMethodTypes converts instance supported_types (comma-separated)
// into Stripe API payment_method_types. Falls back to ["card"] if empty.
func resolveStripeMethodTypes(instanceSubMethods string) []string {
@@ -257,6 +302,7 @@ func (s *Stripe) CancelPayment(ctx context.Context, tradeNo string) error {
// Ensure interface compliance.
var (
_ payment.Provider = (*Stripe)(nil)
_ payment.CancelableProvider = (*Stripe)(nil)
_ payment.Provider = (*Stripe)(nil)
_ payment.CancelableProvider = (*Stripe)(nil)
_ payment.MerchantIdentityProvider = (*Stripe)(nil)
)
+10 -3
View File
@@ -17,6 +17,7 @@ const (
TypeCard PaymentType = "card"
TypeLink PaymentType = "link"
TypeEasyPay PaymentType = "easypay"
TypeAirwallex PaymentType = "airwallex"
)
// Order status constants shared across payment and service layers.
@@ -82,6 +83,8 @@ func GetBasePaymentType(t string) string {
switch {
case t == TypeEasyPay:
return TypeEasyPay
case t == TypeAirwallex:
return TypeAirwallex
case t == TypeStripe || t == TypeCard || t == TypeLink:
return TypeStripe
case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
@@ -96,7 +99,7 @@ func GetBasePaymentType(t string) string {
// CreatePaymentRequest holds the parameters for creating a new payment.
type CreatePaymentRequest struct {
OrderID string // Internal order ID
Amount string // Pay amount in CNY (formatted to 2 decimal places)
Amount string // 支付金额,按服务商实例配置的币种解释
PaymentType string // e.g. "alipay", "wxpay", "stripe"
Subject string // Product description
NotifyURL string // Webhook callback URL
@@ -141,7 +144,11 @@ type CreatePaymentResponse struct {
TradeNo string // Third-party transaction ID
PayURL string // H5 payment URL (alipay/wxpay)
QRCode string // QR code content for scanning
ClientSecret string // Stripe PaymentIntent client secret
ClientSecret string // Stripe PaymentIntent 客户端密钥
IntentID string // 前端 SDK 需要的服务商支付意图 ID
Currency string // 服务商支付币种
CountryCode string // 服务商收银台国家/地区代码
PaymentEnv string // 服务商前端环境标识
ResultType CreatePaymentResultType // Typed result contract for frontend flows
OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
@@ -151,7 +158,7 @@ type CreatePaymentResponse struct {
type QueryOrderResponse struct {
TradeNo string
Status string // "pending", "paid", "failed", "refunded"
Amount float64 // Amount in CNY
Amount float64 // 按服务商返回币种解释的金额
PaidAt string // RFC3339 timestamp or empty
Metadata map[string]string
}
@@ -20,8 +20,35 @@ const (
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
// StripeDomain is the domain for Stripe.js SDK
StripeDomain = "https://*.stripe.com"
// AirwallexStaticDomain 是 Airwallex 生产环境 SDK 脚本域名。
AirwallexStaticDomain = "https://static.airwallex.com"
// AirwallexCheckoutDomain 是 Airwallex 生产环境收银台元素和 iframe 域名。
AirwallexCheckoutDomain = "https://checkout.airwallex.com"
// AirwallexDemoStaticDomain 是 Airwallex 沙箱环境 SDK 脚本域名。
AirwallexDemoStaticDomain = "https://static-demo.airwallex.com"
// AirwallexDemoCheckoutDomain 是 Airwallex 沙箱环境收银台元素和 iframe 域名。
AirwallexDemoCheckoutDomain = "https://checkout-demo.airwallex.com"
)
var requiredCSPDirectiveValues = []struct {
directive string
value string
}{
{"script-src", CloudflareInsightsDomain},
{"script-src", StripeDomain},
{"frame-src", StripeDomain},
{"script-src", AirwallexStaticDomain},
{"script-src", AirwallexCheckoutDomain},
{"style-src", AirwallexStaticDomain},
{"style-src", AirwallexCheckoutDomain},
{"frame-src", AirwallexCheckoutDomain},
{"script-src", AirwallexDemoStaticDomain},
{"script-src", AirwallexDemoCheckoutDomain},
{"style-src", AirwallexDemoStaticDomain},
{"style-src", AirwallexDemoCheckoutDomain},
{"frame-src", AirwallexDemoCheckoutDomain},
}
// GenerateNonce generates a cryptographically secure random nonce.
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
func GenerateNonce() (string, error) {
@@ -100,29 +127,39 @@ func isAPIRoutePath(c *gin.Context) bool {
strings.HasPrefix(path, "/images")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
// and Stripe.js domains. This allows the application to work correctly even if the
// config file has an older CSP policy.
// enhanceCSPPolicy 确保 CSP 策略包含 nonce 支持和支付 SDK 必需域名。
// 这样旧配置文件没有及时补域名时,前端支付组件仍能正常加载。
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
policy = addToDirective(policy, "script-src", NonceTemplate)
}
// Add Cloudflare Insights domain to script-src if not present
if !strings.Contains(policy, CloudflareInsightsDomain) {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
// Add Stripe.js domain to script-src and frame-src if not present
if !strings.Contains(policy, "stripe.com") {
policy = addToDirective(policy, "script-src", StripeDomain)
policy = addToDirective(policy, "frame-src", StripeDomain)
for _, required := range requiredCSPDirectiveValues {
if !directiveHasValue(policy, required.directive, required.value) {
policy = addToDirective(policy, required.directive, required.value)
}
}
return policy
}
func directiveHasValue(policy, directive, value string) bool {
for _, rawDirective := range strings.Split(policy, ";") {
fields := strings.Fields(strings.TrimSpace(rawDirective))
if len(fields) == 0 || fields[0] != directive {
continue
}
for _, field := range fields[1:] {
if field == value {
return true
}
}
return false
}
return false
}
// addToDirective adds a value to a specific CSP directive.
// If the directive doesn't exist, it will be added after default-src.
func addToDirective(policy, directive, value string) string {
@@ -330,6 +330,52 @@ func TestEnhanceCSPPolicy(t *testing.T) {
assert.NotContains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, "'nonce-existing'")
})
t.Run("adds_airwallex_domains_for_payment_sdk", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__; style-src 'self'; frame-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, "script-src 'self' __CSP_NONCE__")
assert.Contains(t, enhanced, AirwallexStaticDomain)
assert.Contains(t, enhanced, AirwallexCheckoutDomain)
assert.Contains(t, enhanced, AirwallexDemoStaticDomain)
assert.Contains(t, enhanced, AirwallexDemoCheckoutDomain)
assert.Contains(t, enhanced, "style-src 'self'")
assert.Contains(t, enhanced, "frame-src 'self'")
})
t.Run("does_not_duplicate_airwallex_domains", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' https://static.airwallex.com https://static-demo.airwallex.com; frame-src https://checkout.airwallex.com https://checkout-demo.airwallex.com"
enhanced := enhanceCSPPolicy(policy)
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexStaticDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexCheckoutDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexStaticDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexCheckoutDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "frame-src", AirwallexCheckoutDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexDemoStaticDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "script-src", AirwallexDemoCheckoutDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexDemoStaticDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "style-src", AirwallexDemoCheckoutDomain))
assert.Equal(t, 1, countDirectiveValue(enhanced, "frame-src", AirwallexDemoCheckoutDomain))
})
}
func countDirectiveValue(policy, directive, value string) int {
for _, rawDirective := range strings.Split(policy, ";") {
fields := strings.Fields(strings.TrimSpace(rawDirective))
if len(fields) == 0 || fields[0] != directive {
continue
}
count := 0
for _, field := range fields[1:] {
if field == value {
count++
}
}
return count
}
return 0
}
func TestAddToDirective(t *testing.T) {
@@ -62,6 +62,7 @@ func RegisterPaymentRoutes(
webhook.POST("/alipay", webhookHandler.AlipayNotify)
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
webhook.POST("/stripe", webhookHandler.StripeWebhook)
webhook.POST("/airwallex", webhookHandler.AirwallexWebhook)
}
// --- Admin payment endpoints (admin auth) ---
+6 -4
View File
@@ -3,6 +3,7 @@ package service
import (
"math"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/shopspring/decimal"
)
@@ -22,16 +23,17 @@ func calculateCreditedBalance(paymentAmount, multiplier float64) float64 {
InexactFloat64()
}
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64) float64 {
func calculateGatewayRefundAmount(orderAmount, payAmount, refundAmount float64, currency string) float64 {
if orderAmount <= 0 || payAmount <= 0 || refundAmount <= 0 {
return 0
}
if math.Abs(refundAmount-orderAmount) <= amountToleranceCNY {
return decimal.NewFromFloat(payAmount).Round(2).InexactFloat64()
fractionDigits := int32(payment.CurrencyMaxFractionDigits(currency))
if math.Abs(refundAmount-orderAmount) <= paymentAmountToleranceForCurrency(currency) {
return decimal.NewFromFloat(payAmount).Round(fractionDigits).InexactFloat64()
}
return decimal.NewFromFloat(payAmount).
Mul(decimal.NewFromFloat(refundAmount)).
Div(decimal.NewFromFloat(orderAmount)).
Round(2).
Round(fractionDigits).
InexactFloat64()
}
@@ -8,6 +8,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetAvailableMethodLimits collects all payment types from enabled provider
@@ -25,7 +26,12 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
for pt, insts := range typeInstances {
currency, ok := s.pcAggregateMethodCurrency(insts)
if !ok {
continue
}
ml := pcAggregateMethodLimits(pt, insts)
ml.Currency = currency
resp.Methods[ml.PaymentType] = ml
}
resp.GlobalMin, resp.GlobalMax = pcComputeGlobalRange(resp.Methods)
@@ -82,11 +88,81 @@ func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []stri
matching = append(matching, inst)
}
}
result = append(result, pcAggregateMethodLimits(pt, matching))
currency, ok := s.pcAggregateMethodCurrency(matching)
if !ok {
continue
}
ml := pcAggregateMethodLimits(pt, matching)
ml.Currency = currency
result = append(result, ml)
}
return result, nil
}
func (s *PaymentConfigService) ValidateMethodCurrencyConsistency(ctx context.Context, paymentType string) (string, error) {
method := NormalizeVisibleMethod(paymentType)
if method == "" || s == nil || s.entClient == nil {
return payment.DefaultPaymentCurrency, nil
}
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
if err != nil {
return "", fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
matching := typeInstances[method]
if len(matching) == 0 {
return payment.DefaultPaymentCurrency, nil
}
currency, ok := s.pcAggregateMethodCurrency(matching)
if !ok {
return "", infraerrors.ServiceUnavailable(
"PAYMENT_METHOD_CURRENCY_CONFLICT",
"payment method has enabled provider instances with mixed currencies",
).WithMetadata(map[string]string{"payment_type": method})
}
return currency, nil
}
func (s *PaymentConfigService) pcAggregateMethodCurrency(instances []*dbent.PaymentProviderInstance) (string, bool) {
currency := ""
for _, inst := range instances {
next := s.pcInstancePaymentCurrency(inst)
if next == "" {
continue
}
if currency == "" {
currency = next
continue
}
if currency != next {
return "", false
}
}
if currency == "" {
return payment.DefaultPaymentCurrency, true
}
return currency, true
}
func (s *PaymentConfigService) pcInstancePaymentCurrency(inst *dbent.PaymentProviderInstance) string {
if inst == nil {
return payment.DefaultPaymentCurrency
}
cfg := map[string]string{}
if s != nil {
decrypted, err := s.decryptConfig(inst.Config)
if err == nil && decrypted != nil {
cfg = decrypted
}
}
return paymentProviderConfigCurrency(inst.ProviderKey, cfg)
}
// pcGroupByPaymentType groups instances by user-facing payment type.
// For Stripe providers, ALL sub-types (card, link, alipay, wxpay) map to "stripe"
// because the user sees a single "Stripe" button, not individual sub-methods.
@@ -6,6 +6,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -199,6 +200,61 @@ func TestPcGroupByPaymentType(t *testing.T) {
})
}
func TestPcAggregateMethodCurrency(t *testing.T) {
t.Parallel()
svc := &PaymentConfigService{}
stripe := makeInstance(1, payment.TypeStripe, payment.TypeStripe, "")
stripe.Config = `{"currency":"hkd"}`
currency, ok := svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{stripe})
require.True(t, ok)
require.Equal(t, "HKD", currency)
airwallex := makeInstance(2, payment.TypeAirwallex, payment.TypeAirwallex, "")
airwallex.Config = `{"currency":"usd"}`
currency, ok = svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{stripe, airwallex})
require.False(t, ok)
require.Empty(t, currency)
easypay := makeInstance(3, payment.TypeEasyPay, payment.TypeAlipay, "")
currency, ok = svc.pcAggregateMethodCurrency([]*dbent.PaymentProviderInstance{easypay})
require.True(t, ok)
require.Equal(t, payment.DefaultPaymentCurrency, currency)
}
func TestGetAvailableMethodLimitsOmitsMixedCurrencyMethod(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("Stripe HKD").
SetConfig(`{"currency":"HKD"}`).
SetSupportedTypes("card,link").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeStripe).
SetName("Stripe USD").
SetConfig(`{"currency":"USD"}`).
SetSupportedTypes("card,link").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
svc := &PaymentConfigService{entClient: client}
resp, err := svc.GetAvailableMethodLimits(ctx)
require.NoError(t, err)
require.NotContains(t, resp.Methods, payment.TypeStripe)
_, err = svc.ValidateMethodCurrencyConsistency(ctx, payment.TypeStripe)
require.Error(t, err)
appErr := infraerrors.FromError(err)
require.Equal(t, "PAYMENT_METHOD_CURRENCY_CONFLICT", appErr.Reason)
}
func TestPcComputeGlobalRange(t *testing.T) {
t.Parallel()
@@ -110,10 +110,11 @@ var pendingOrderStatuses = []string{
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
// stripe publishableKey) are returned in plaintext by the admin GET API.
var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
payment.TypeEasyPay: {"pkey": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
payment.TypeAirwallex: {"apikey": {}, "webhooksecret": {}},
}
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
@@ -121,10 +122,11 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
// all provider identity fields that are snapshotted into orders or used by
// webhook/refund verification.
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}, "currency": {}},
payment.TypeAirwallex: {"clientid": {}, "apikey": {}, "webhooksecret": {}, "apibase": {}, "accountid": {}, "currency": {}},
}
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
@@ -175,7 +177,7 @@ func (s *PaymentConfigService) countPendingOrdersByPlan(ctx context.Context, pla
}
var validProviderKeys = map[string]bool{
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true,
payment.TypeEasyPay: true, payment.TypeAlipay: true, payment.TypeWxpay: true, payment.TypeStripe: true, payment.TypeAirwallex: true,
}
func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
@@ -44,6 +44,13 @@ func TestValidateProviderRequest(t *testing.T) {
supportedTypes: "",
wantErr: false,
},
{
name: "valid airwallex provider",
providerKey: payment.TypeAirwallex,
providerName: "Airwallex Provider",
supportedTypes: payment.TypeAirwallex,
wantErr: false,
},
{
name: "valid alipay provider",
providerKey: "alipay",
@@ -120,6 +127,7 @@ func TestIsSensitiveProviderConfigField(t *testing.T) {
{"stripe", "webhookSecret", true},
{"stripe", "SecretKey", true}, // case-insensitive
{"stripe", "publishableKey", false},
{"stripe", "currency", false},
{"stripe", "appId", false},
// Alipay
@@ -142,6 +150,14 @@ func TestIsSensitiveProviderConfigField(t *testing.T) {
{"easypay", "pid", false},
{"easypay", "apiBase", false},
// Airwallex
{payment.TypeAirwallex, "apiKey", true},
{payment.TypeAirwallex, "webhookSecret", true},
{payment.TypeAirwallex, "clientId", false},
{payment.TypeAirwallex, "apiBase", false},
{payment.TypeAirwallex, "accountId", false},
{payment.TypeAirwallex, "currency", false},
// Unknown provider: never sensitive
{"unknown", "secretKey", false},
}
@@ -395,6 +411,42 @@ func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t
fieldName: "pid",
wantValue: "pid-test",
},
{
name: "stripe currency",
providerKey: payment.TypeStripe,
createConfig: validStripeProviderConfig,
supportedType: []string{payment.TypeStripe},
updateConfig: map[string]string{"currency": "HKD"},
fieldName: "currency",
wantValue: "CNY",
},
{
name: "airwallex accountId",
providerKey: payment.TypeAirwallex,
createConfig: validAirwallexProviderConfig,
supportedType: []string{payment.TypeAirwallex},
updateConfig: map[string]string{"accountId": "acct-updated"},
fieldName: "accountId",
wantValue: "acct-test",
},
{
name: "airwallex currency",
providerKey: payment.TypeAirwallex,
createConfig: validAirwallexProviderConfig,
supportedType: []string{payment.TypeAirwallex},
updateConfig: map[string]string{"currency": "HKD"},
fieldName: "currency",
wantValue: "CNY",
},
{
name: "airwallex webhookSecret",
providerKey: payment.TypeAirwallex,
createConfig: validAirwallexProviderConfig,
supportedType: []string{payment.TypeAirwallex},
updateConfig: map[string]string{"webhookSecret": "whsec-updated"},
fieldName: "webhookSecret",
wantValue: "whsec-test",
},
}
for _, tc := range tests {
@@ -506,6 +558,39 @@ func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *test
}
}
func TestUpdateProviderInstanceClearsAirwallexAccountID(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: payment.TypeAirwallex,
Name: "airwallex-clear-account",
Config: validAirwallexProviderConfig(t),
SupportedTypes: []string{payment.TypeAirwallex},
Enabled: true,
})
require.NoError(t, err)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: map[string]string{"accountId": ""},
})
require.NoError(t, err)
require.NotNil(t, updated)
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Empty(t, cfg["accountId"])
require.Equal(t, "client-id-test", cfg["clientId"])
}
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
t.Helper()
@@ -545,11 +630,26 @@ func providerPendingOrderPaymentType(providerKey string) string {
return payment.TypeWxpay
case payment.TypeAlipay:
return payment.TypeAlipay
case payment.TypeAirwallex:
return payment.TypeAirwallex
case payment.TypeStripe:
return payment.TypeStripe
default:
return payment.TypeAlipay
}
}
func validStripeProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"secretKey": "sk_test_123",
"publishableKey": "pk_test_123",
"webhookSecret": "whsec-test",
"currency": "CNY",
}
}
func boolPtrValue(v bool) *bool {
return &v
}
@@ -577,6 +677,19 @@ func validEasyPayProviderConfig(t *testing.T) map[string]string {
}
}
func validAirwallexProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"clientId": "client-id-test",
"apiKey": "api-key-test",
"webhookSecret": "whsec-test",
"apiBase": "https://api-demo.airwallex.com/api/v1",
"accountId": "acct-test",
"currency": "CNY",
}
}
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
@@ -103,6 +103,7 @@ type UpdatePaymentConfigRequest struct {
// MethodLimits holds per-payment-type limits.
type MethodLimits struct {
PaymentType string `json:"payment_type"`
Currency string `json:"currency"`
FeeRate float64 `json:"fee_rate"`
DailyLimit float64 `json:"daily_limit"`
SingleMin float64 `json:"single_min"`
@@ -0,0 +1,28 @@
package service
import (
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func paymentProviderConfigCurrency(providerKey string, cfg map[string]string) string {
switch strings.TrimSpace(providerKey) {
case payment.TypeStripe, payment.TypeAirwallex:
currency, err := payment.NormalizePaymentCurrency(cfg["currency"])
if err == nil {
return currency
}
}
return payment.DefaultPaymentCurrency
}
func PaymentOrderCurrency(order *dbent.PaymentOrder) string {
if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
if currency, err := payment.NormalizePaymentCurrency(snapshot.Currency); err == nil {
return currency
}
}
return payment.DefaultPaymentCurrency
}
@@ -101,13 +101,21 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return fmt.Errorf("invalid paid amount from provider: %v", paid)
}
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
if math.Abs(paid-o.PayAmount) > paymentAmountToleranceForCurrency(PaymentOrderCurrency(o)) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
return fmt.Errorf("amount mismatch: expected %s, got %s", strconv.FormatFloat(o.PayAmount, 'f', -1, 64), strconv.FormatFloat(paid, 'f', -1, 64))
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func paymentAmountToleranceForCurrency(currency string) float64 {
minorUnit := payment.CurrencyMinorUnit(currency)
if minorUnit <= 2 {
return amountToleranceCNY
}
return math.Pow10(-minorUnit) / 2
}
func isValidProviderAmount(amount float64) bool {
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
}
@@ -366,3 +366,55 @@ func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *t
})
assert.ErrorContains(t, err, "easypay pid mismatch")
}
func TestValidateProviderNotificationMetadataRejectsAirwallexSnapshotMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeAirwallex,
ProviderSnapshot: map[string]any{
"schema_version": 2,
"merchant_id": "acct_expected",
"currency": "CNY",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeAirwallex, map[string]string{
"account_id": "acct_other",
"currency": "CNY",
"status": "SUCCEEDED",
})
assert.ErrorContains(t, err, "airwallex account_id mismatch")
err = validateProviderNotificationMetadata(order, payment.TypeAirwallex, map[string]string{
"account_id": "acct_expected",
"currency": "USD",
"status": "SUCCEEDED",
})
assert.ErrorContains(t, err, "airwallex currency mismatch")
}
func TestValidateProviderNotificationMetadataRejectsStripeCurrencyMismatch(t *testing.T) {
t.Parallel()
order := &dbent.PaymentOrder{
PaymentType: payment.TypeStripe,
ProviderSnapshot: map[string]any{
"schema_version": 2,
"currency": "HKD",
},
}
err := validateProviderNotificationMetadata(order, payment.TypeStripe, map[string]string{
"currency": "USD",
})
assert.ErrorContains(t, err, "stripe currency mismatch")
}
func TestPaymentAmountToleranceForThreeDecimalCurrency(t *testing.T) {
t.Parallel()
assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("CNY"))
assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("JPY"))
assert.InDelta(t, 0.0005, paymentAmountToleranceForCurrency("KWD"), 1e-12)
}
+84 -7
View File
@@ -57,8 +57,17 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
orderAmount = calculateCreditedBalance(req.Amount, cfg.BalanceRechargeMultiplier)
}
feeRate := cfg.RechargeFeeRate
payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
methodCurrency := payment.DefaultPaymentCurrency
if s.configService != nil {
methodCurrency, err = s.configService.ValidateMethodCurrencyConsistency(ctx, req.PaymentType)
if err != nil {
return nil, err
}
}
payAmountStr, payAmount, err := calculateCreateOrderPayAmount(limitAmount, feeRate, methodCurrency)
if err != nil {
return nil, err
}
sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
if err != nil {
return nil, err
@@ -66,6 +75,19 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
return nil, err
}
selectedCurrency := payment.DefaultPaymentCurrency
if sel != nil {
selectedCurrency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
}
if selectedCurrency != methodCurrency {
payAmountStr, payAmount, err = calculateCreateOrderPayAmount(limitAmount, feeRate, selectedCurrency)
if err != nil {
return nil, err
}
}
if err := validateSelectedCreateOrderAmountCurrency(payAmountStr, sel); err != nil {
return nil, err
}
oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
if err != nil {
return nil, err
@@ -257,7 +279,7 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
snapshot["merchant_id"] = merchantID
}
snapshot["currency"] = "CNY"
snapshot["currency"] = payment.DefaultPaymentCurrency
}
if providerKey == payment.TypeAlipay {
if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
@@ -269,6 +291,15 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
snapshot["merchant_id"] = merchantID
}
}
if providerKey == payment.TypeStripe {
snapshot["currency"] = paymentProviderConfigCurrency(providerKey, sel.Config)
}
if providerKey == payment.TypeAirwallex {
if accountID := strings.TrimSpace(sel.Config["accountId"]); accountID != "" {
snapshot["merchant_id"] = accountID
}
snapshot["currency"] = paymentProviderConfigCurrency(providerKey, sel.Config)
}
if len(snapshot) == 1 {
return nil
@@ -377,7 +408,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
subject := s.buildPaymentSubject(plan, limitAmount, cfg, sel)
outTradeNo := order.OutTradeNo
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
if err != nil {
@@ -466,20 +497,24 @@ func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
return sel.SupportedTypes
}
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig, sel *payment.InstanceSelection) string {
if plan != nil {
if plan.ProductName != "" {
return plan.ProductName
}
return "Sub2API Subscription " + plan.Name
}
amountStr := strconv.FormatFloat(limitAmount, 'f', 2, 64)
currency := payment.DefaultPaymentCurrency
if sel != nil {
currency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
}
amountStr := payment.FormatAmountForCurrency(limitAmount, currency)
pf := strings.TrimSpace(cfg.ProductNamePrefix)
sf := strings.TrimSpace(cfg.ProductNameSuffix)
if pf != "" || sf != "" {
return strings.TrimSpace(pf + " " + amountStr + " " + sf)
}
return "Sub2API " + amountStr + " CNY"
return "Sub2API " + amountStr + " " + currency
}
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
@@ -540,6 +575,44 @@ func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context
return nil
}
func calculateCreateOrderPayAmount(limitAmount, feeRate float64, currency string) (string, float64, error) {
if err := validateCreateOrderAmountCurrency(limitAmount, currency); err != nil {
return "", 0, err
}
payAmountStr := payment.CalculatePayAmountForCurrency(limitAmount, feeRate, currency)
if _, err := payment.AmountToMinorUnit(payAmountStr, currency); err != nil {
return "", 0, infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
WithMetadata(map[string]string{"currency": currency})
}
payAmount, err := strconv.ParseFloat(payAmountStr, 64)
if err != nil {
return "", 0, infraerrors.BadRequest("INVALID_AMOUNT", "invalid payment amount").
WithMetadata(map[string]string{"currency": currency})
}
return payAmountStr, payAmount, nil
}
func validateCreateOrderAmountCurrency(amount float64, currency string) error {
amountStr := strconv.FormatFloat(amount, 'f', -1, 64)
if _, err := payment.AmountToMinorUnit(amountStr, currency); err != nil {
return infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
WithMetadata(map[string]string{"currency": currency})
}
return nil
}
func validateSelectedCreateOrderAmountCurrency(payAmount string, sel *payment.InstanceSelection) error {
if sel == nil {
return nil
}
currency := paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
if _, err := payment.AmountToMinorUnit(payAmount, currency); err != nil {
return infraerrors.BadRequest("INVALID_AMOUNT", err.Error()).
WithMetadata(map[string]string{"currency": currency})
}
return nil
}
func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
return false
@@ -596,6 +669,10 @@ func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest,
PayURL: pr.PayURL,
QRCode: pr.QRCode,
ClientSecret: pr.ClientSecret,
IntentID: pr.IntentID,
Currency: pr.Currency,
CountryCode: pr.CountryCode,
PaymentEnv: pr.PaymentEnv,
OAuth: pr.OAuth,
JSAPI: pr.JSAPI,
JSAPIPayload: pr.JSAPI,
@@ -188,6 +188,38 @@ func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey str
return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
}
}
case payment.TypeStripe:
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
if actual == "" {
return fmt.Errorf("stripe notification missing currency")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("stripe currency mismatch: expected %s, got %s", expected, actual)
}
}
case payment.TypeAirwallex:
if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
actual := strings.TrimSpace(metadata["account_id"])
if actual == "" {
return fmt.Errorf("airwallex account_id missing")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("airwallex account_id mismatch: expected %s, got %s", expected, actual)
}
}
if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
if actual == "" {
return fmt.Errorf("airwallex notification missing currency")
}
if !strings.EqualFold(expected, actual) {
return fmt.Errorf("airwallex currency mismatch: expected %s, got %s", expected, actual)
}
}
if actual := strings.TrimSpace(metadata["status"]); actual != "" && !strings.EqualFold(actual, "SUCCEEDED") {
return fmt.Errorf("airwallex status mismatch: expected SUCCEEDED, got %s", actual)
}
}
return nil
@@ -164,6 +164,30 @@ func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *te
require.NotContains(t, snapshot, "pkey")
}
func TestBuildPaymentOrderProviderSnapshot_IncludesProviderCurrency(t *testing.T) {
t.Parallel()
stripeSnapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "77",
ProviderKey: payment.TypeStripe,
Config: map[string]string{
"currency": "hkd",
},
}, CreateOrderRequest{})
require.Equal(t, "HKD", stripeSnapshot["currency"])
airwallexSnapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
InstanceID: "78",
ProviderKey: payment.TypeAirwallex,
Config: map[string]string{
"currency": "usd",
"accountId": "acct-78",
},
}, CreateOrderRequest{})
require.Equal(t, "USD", airwallexSnapshot["currency"])
require.Equal(t, "acct-78", airwallexSnapshot["merchant_id"])
}
func valueOrEmpty(v *string) string {
if v == nil {
return ""
@@ -91,6 +91,53 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
}
}
func TestValidateSelectedCreateOrderAmountCurrencyRejectsFractionalZeroDecimal(t *testing.T) {
t.Parallel()
err := validateSelectedCreateOrderAmountCurrency("100.50", &payment.InstanceSelection{
ProviderKey: payment.TypeStripe,
Config: map[string]string{"currency": "JPY"},
})
if err == nil {
t.Fatal("expected fractional JPY amount to fail")
}
if appErr := infraerrors.FromError(err); appErr.Reason != "INVALID_AMOUNT" {
t.Fatalf("reason = %q, want INVALID_AMOUNT", appErr.Reason)
}
}
func TestCalculateCreateOrderPayAmountUsesCurrencyPrecision(t *testing.T) {
t.Parallel()
amountStr, amount, err := calculateCreateOrderPayAmount(100, 2.5, "JPY")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if amountStr != "103" || amount != 103 {
t.Fatalf("JPY pay amount = (%q, %v), want (103, 103)", amountStr, amount)
}
amountStr, amount, err = calculateCreateOrderPayAmount(12.345, 1, "KWD")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if amountStr != "12.469" || amount != 12.469 {
t.Fatalf("KWD pay amount = (%q, %v), want (12.469, 12.469)", amountStr, amount)
}
}
func TestCalculateCreateOrderPayAmountRejectsFractionalZeroDecimal(t *testing.T) {
t.Parallel()
_, _, err := calculateCreateOrderPayAmount(100.5, 0, "JPY")
if err == nil {
t.Fatal("expected fractional JPY amount to fail")
}
if appErr := infraerrors.FromError(err); appErr.Reason != "INVALID_AMOUNT" {
t.Fatalf("reason = %q, want INVALID_AMOUNT", appErr.Reason)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+28 -5
View File
@@ -226,10 +226,11 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if amt <= 0 {
amt = o.Amount
}
if amt-o.Amount > amountToleranceCNY {
orderCurrency := PaymentOrderCurrency(o)
if amt-o.Amount > paymentAmountToleranceForCurrency(orderCurrency) {
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
}
ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt)
ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt, orderCurrency)
rr := strings.TrimSpace(reason)
if rr == "" && o.RefundRequestReason != nil {
rr = *o.RefundRequestReason
@@ -339,13 +340,35 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
})
return err
}
_, err = prov.Refund(ctx, payment.RefundRequest{
resp, err := prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
Amount: formatGatewayRefundAmount(p.GatewayAmount, p.Order),
Reason: p.Reason,
})
return err
if err != nil {
return err
}
return validateRefundProviderResponse(resp)
}
func formatGatewayRefundAmount(amount float64, order *dbent.PaymentOrder) string {
return payment.FormatAmountForCurrency(amount, PaymentOrderCurrency(order))
}
func validateRefundProviderResponse(resp *payment.RefundResponse) error {
if resp == nil {
return fmt.Errorf("payment refund response missing")
}
status := strings.TrimSpace(resp.Status)
switch status {
case payment.ProviderStatusSuccess, payment.ProviderStatusRefunded, payment.ProviderStatusPending:
return nil
case payment.ProviderStatusFailed:
return fmt.Errorf("payment refund failed: status %s", status)
default:
return fmt.Errorf("payment refund returned unknown status: %s", status)
}
}
// getRefundProvider creates a provider using the order's original instance config.
@@ -8,6 +8,7 @@ import (
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
@@ -184,3 +185,26 @@ func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
})
require.ErrorContains(t, err, "alipay app_id mismatch")
}
func TestCalculateGatewayRefundAmountUsesCurrencyPrecision(t *testing.T) {
require.InDelta(t, 6.173, calculateGatewayRefundAmount(100, 12.345, 50, "KWD"), 1e-12)
require.InDelta(t, 12.345, calculateGatewayRefundAmount(100, 12.345, 100, "KWD"), 1e-12)
require.InDelta(t, 52, calculateGatewayRefundAmount(100, 103, 50, "JPY"), 1e-12)
}
func TestFormatGatewayRefundAmountUsesOrderCurrency(t *testing.T) {
order := &dbent.PaymentOrder{
ProviderSnapshot: map[string]any{
"currency": "KWD",
},
}
require.Equal(t, "12.345", formatGatewayRefundAmount(12.345, order))
}
func TestValidateRefundProviderResponseAcceptsPending(t *testing.T) {
require.NoError(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusPending}))
require.NoError(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusSuccess}))
require.Error(t, validateRefundProviderResponse(&payment.RefundResponse{Status: payment.ProviderStatusFailed}))
require.Error(t, validateRefundProviderResponse(nil))
}
@@ -97,6 +97,10 @@ type CreateOrderResponse struct {
PayURL string `json:"pay_url,omitempty"`
QRCode string `json:"qr_code,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
IntentID string `json:"intent_id,omitempty"`
Currency string `json:"currency,omitempty"`
CountryCode string `json:"country_code,omitempty"`
PaymentEnv string `json:"payment_env,omitempty"`
OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`