From 07bcffe78cea356fdee588dbe4f69870f4733ba4 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 12 Jun 2026 15:44:32 +0900 Subject: [PATCH 1/4] adapter: split dynamodb.go schema/keys/attrvalue/errors into cohesive files (no behavior change) Pure code movement within package adapter: schema+table management, storage key prefix/segment encoding, attribute-value compare/clone/ key-encode helpers, and DynamoDB error mapping/transient classifier move out of dynamodb.go into dedicated dynamodb_schema.go / dynamodb_keys.go / dynamodb_attrvalue.go / dynamodb_errors.go. No declarations changed (byte-identical, gofmt; imports recomputed per file via goimports). --- adapter/dynamodb_attrvalue.go | 530 +++++++++++++++++++ adapter/dynamodb_errors.go | 78 +++ adapter/dynamodb_keys.go | 191 +++++++ adapter/dynamodb_schema.go | 938 ++++++++++++++++++++++++++++++++++ 4 files changed, 1737 insertions(+) create mode 100644 adapter/dynamodb_attrvalue.go create mode 100644 adapter/dynamodb_errors.go create mode 100644 adapter/dynamodb_keys.go create mode 100644 adapter/dynamodb_schema.go diff --git a/adapter/dynamodb_attrvalue.go b/adapter/dynamodb_attrvalue.go new file mode 100644 index 00000000..97ebc105 --- /dev/null +++ b/adapter/dynamodb_attrvalue.go @@ -0,0 +1,530 @@ +package adapter + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "math/big" + "sort" + "strconv" + "strings" + + "github.com/cockroachdb/errors" +) + +var attributeValueKeyExtractors = map[attributeValueKind]func(attributeValue) string{ + attributeValueKindString: func(attr attributeValue) string { return attr.stringValue() }, + attributeValueKindNumber: func(attr attributeValue) string { return attr.numberValue() }, + attributeValueKindBinary: func(attr attributeValue) string { return string(attr.B) }, +} + +var attributeValueKeyByteExtractors = map[attributeValueKind]func(attributeValue) []byte{ + attributeValueKindString: func(attr attributeValue) []byte { + return []byte(attr.stringValue()) + }, + attributeValueKindBinary: func(attr attributeValue) []byte { + return bytes.Clone(attr.B) + }, +} + +var attributeValueScalarEqualityComparators = map[attributeValueKind]func(attributeValue, attributeValue) bool{ + attributeValueKindString: func(left attributeValue, right attributeValue) bool { return left.stringValue() == right.stringValue() }, + attributeValueKindNumber: numberAttributeValueEqual, + attributeValueKindBinary: func(left attributeValue, right attributeValue) bool { return bytes.Equal(left.B, right.B) }, + attributeValueKindBool: func(left attributeValue, right attributeValue) bool { return *left.BOOL == *right.BOOL }, + attributeValueKindNull: func(left attributeValue, right attributeValue) bool { return *left.NULL == *right.NULL }, + attributeValueKindStringSet: func(left attributeValue, right attributeValue) bool { + return unorderedStringSlicesEqual(left.SS, right.SS) + }, + attributeValueKindNumberSet: func(left attributeValue, right attributeValue) bool { + return unorderedNumberSlicesEqual(left.NS, right.NS) + }, + attributeValueKindBinarySet: func(left attributeValue, right attributeValue) bool { + return unorderedBinarySlicesEqual(left.BS, right.BS) + }, +} + +var attributeValueSortFormatters = map[attributeValueKind]func(attributeValue) string{ + attributeValueKindString: func(attr attributeValue) string { return attr.stringValue() }, + attributeValueKindNumber: func(attr attributeValue) string { return attr.numberValue() }, + attributeValueKindBinary: func(attr attributeValue) string { return base64.RawURLEncoding.EncodeToString(attr.B) }, + attributeValueKindBool: formatBoolAttributeValue, + attributeValueKindNull: func(attributeValue) string { return "" }, + attributeValueKindStringSet: func(attr attributeValue) string { return strings.Join(sortedStringSlice(attr.SS), "\x00") }, + attributeValueKindNumberSet: func(attr attributeValue) string { return strings.Join(sortedNumberStrings(attr.NS), "\x00") }, + attributeValueKindBinarySet: func(attr attributeValue) string { return strings.Join(sortedBinaryStrings(attr.BS), "\x00") }, +} + +func attributeValueAsKey(attr attributeValue) (string, error) { + kind, count := detectAttributeValueKind(attr) + if count != 1 { + return "", errors.New("unsupported key attribute type") + } + extract, ok := attributeValueKeyExtractors[kind] + if !ok { + return "", errors.New("unsupported key attribute type") + } + return extract(attr), nil +} + +func attributeValueAsKeyBytes(attr attributeValue) ([]byte, error) { + kind, count := detectAttributeValueKind(attr) + if count != 1 { + return nil, errors.New("unsupported key attribute type") + } + if kind == attributeValueKindNumber { + return encodeNumericKeyBytes(attr.numberValue()) + } + extract, ok := attributeValueKeyByteExtractors[kind] + if !ok { + return nil, errors.New("unsupported key attribute type") + } + return extract(attr), nil +} + +func attributeValueAsKeySegment(attr attributeValue) ([]byte, error) { + raw, err := attributeValueAsKeyBytes(attr) + if err != nil { + return nil, err + } + return encodeDynamoKeySegment(raw), nil +} + +type numericKeyParts struct { + negative bool + exponent int64 + digits []byte +} + +func encodeNumericKeyBytes(v string) ([]byte, error) { + parts, err := parseNumericKeyParts(v) + if err != nil { + return nil, err + } + if len(parts.digits) == 0 { + return []byte{0x01}, nil + } + body := encodeOrderedSignedInt64(parts.exponent) + body = append(body, dynamoKeyEscapeByte) + body = append(body, parts.digits...) + if !parts.negative { + return append([]byte{0x02}, body...), nil + } + return append([]byte{0x00}, invertBytes(body)...), nil +} + +func parseNumericKeyParts(v string) (numericKeyParts, error) { + trimmed, negative, exp10, err := parseNumericKeyLiteral(v) + if err != nil { + return numericKeyParts{}, err + } + digits, exponent, zero, err := normalizeNumericKeyParts(trimmed, exp10) + if err != nil { + return numericKeyParts{}, err + } + if zero { + return numericKeyParts{}, nil + } + return numericKeyParts{ + negative: negative, + exponent: exponent, + digits: digits, + }, nil +} + +func parseNumericKeyLiteral(v string) (string, bool, int64, error) { + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return "", false, 0, errors.New("unsupported key attribute type") + } + + negative := false + switch trimmed[0] { + case '+': + trimmed = trimmed[1:] + case '-': + negative = true + trimmed = trimmed[1:] + } + if trimmed == "" { + return "", false, 0, errors.New("unsupported key attribute type") + } + + exp10 := int64(0) + if idx := strings.IndexAny(trimmed, "eE"); idx >= 0 { + expPart := strings.TrimSpace(trimmed[idx+1:]) + trimmed = trimmed[:idx] + parsedExp, err := parseNumericExponent(expPart) + if err != nil { + return "", false, 0, err + } + exp10 = parsedExp + } + return trimmed, negative, exp10, nil +} + +func parseNumericExponent(expPart string) (int64, error) { + if expPart == "" { + return 0, errors.New("unsupported key attribute type") + } + parsedExp, err := strconv.ParseInt(expPart, 10, 64) + if err != nil { + return 0, errors.New("unsupported key attribute type") + } + return parsedExp, nil +} + +func normalizeNumericKeyParts(trimmed string, exp10 int64) ([]byte, int64, bool, error) { + intPart, fracPart, err := splitNumericMantissa(trimmed) + if err != nil { + return nil, 0, false, err + } + combined := intPart + fracPart + leadingZeros := leadingZeroCount(combined) + if leadingZeros == len(combined) { + return nil, 0, true, nil + } + digits := []byte(strings.TrimRight(combined[leadingZeros:], "0")) + if len(digits) == 0 { + return nil, 0, true, nil + } + exponent := int64(len(intPart)) + exp10 - int64(leadingZeros) + return digits, exponent, false, nil +} + +func splitNumericMantissa(trimmed string) (string, string, error) { + if strings.Count(trimmed, ".") > 1 { + return "", "", errors.New("unsupported key attribute type") + } + intPart := trimmed + fracPart := "" + if before, after, ok := strings.Cut(trimmed, "."); ok { + intPart = before + fracPart = after + } + if intPart == "" && fracPart == "" { + return "", "", errors.New("unsupported key attribute type") + } + if !decimalDigitsOnly(intPart) || !decimalDigitsOnly(fracPart) { + return "", "", errors.New("unsupported key attribute type") + } + return intPart, fracPart, nil +} + +func leadingZeroCount(v string) int { + count := 0 + for count < len(v) && v[count] == '0' { + count++ + } + return count +} + +func decimalDigitsOnly(v string) bool { + for i := range v { + if v[i] < '0' || v[i] > '9' { + return false + } + } + return true +} + +func encodeOrderedSignedInt64(v int64) []byte { + switch { + case v < 0: + return append([]byte{0x00}, invertBytes(encodeOrderedUint64(signedMagnitude(v)))...) + case v == 0: + return []byte{0x01} + default: + return append([]byte{0x02}, encodeOrderedUint64(uint64(v))...) + } +} + +func signedMagnitude(v int64) uint64 { + if v >= 0 { + return uint64(v) + } + abs := big.NewInt(v) + abs.Abs(abs) + return abs.Uint64() +} + +var orderedUint64LengthPrefix = [...]byte{0, 1, 2, 3, 4, 5, 6, 7, 8} + +func encodeOrderedUint64(v uint64) []byte { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], v) + start := 0 + for start < len(buf)-1 && buf[start] == 0 { + start++ + } + width := len(buf) - start + out := make([]byte, 0, width+1) + out = append(out, orderedUint64LengthPrefix[width]) + out = append(out, buf[start:]...) + return out +} + +func invertBytes(in []byte) []byte { + out := make([]byte, len(in)) + for i := range in { + out[i] = ^in[i] + } + return out +} + +func encodeDynamoKeySegment(raw []byte) []byte { + out := encodeDynamoKeySegmentPrefix(raw) + out = append(out, dynamoKeyEscapeByte, dynamoKeyTerminatorByte) + return out +} + +func encodeDynamoKeySegmentPrefix(raw []byte) []byte { + return appendEscapedDynamoKeyBytes(make([]byte, 0, len(raw)+dynamoKeySegmentOverhead), raw) +} + +func appendEscapedDynamoKeyBytes(dst []byte, raw []byte) []byte { + for _, b := range raw { + if b == dynamoKeyEscapeByte { + dst = append(dst, dynamoKeyEscapeByte, dynamoKeyEscapedZeroByte) + continue + } + dst = append(dst, b) + } + return dst +} + +func attributeValueEqual(left attributeValue, right attributeValue) bool { + leftKind, leftCount := detectAttributeValueKind(left) + rightKind, rightCount := detectAttributeValueKind(right) + if leftCount == 0 && rightCount == 0 { + return true + } + if leftCount != 1 || rightCount != 1 || leftKind != rightKind { + return false + } + if leftKind == attributeValueKindMap { + return mapAttributeValueEqual(left, right) + } + if leftKind == attributeValueKindList { + return listAttributeValueEqual(left, right) + } + compare, ok := attributeValueScalarEqualityComparators[leftKind] + if !ok { + return false + } + return compare(left, right) +} + +func numberAttributeValueEqual(left attributeValue, right attributeValue) bool { + cmp, ok := compareNumericAttributeString(left.numberValue(), right.numberValue()) + if !ok { + return left.numberValue() == right.numberValue() + } + return cmp == 0 +} + +func mapAttributeValueEqual(left attributeValue, right attributeValue) bool { + if len(left.M) != len(right.M) { + return false + } + for key, leftValue := range left.M { + rightValue, ok := right.M[key] + if !ok || !attributeValueEqual(leftValue, rightValue) { + return false + } + } + return true +} + +func listAttributeValueEqual(left attributeValue, right attributeValue) bool { + if len(left.L) != len(right.L) { + return false + } + for i := range left.L { + if !attributeValueEqual(left.L[i], right.L[i]) { + return false + } + } + return true +} + +func compareAttributeValueSortKey(left attributeValue, right attributeValue) int { + if left.hasNumberType() && right.hasNumberType() { + if cmp, ok := compareNumericAttributeString(left.numberValue(), right.numberValue()); ok { + return cmp + } + } + if left.hasBinaryType() && right.hasBinaryType() { + return bytes.Compare(left.B, right.B) + } + return strings.Compare(attributeValueSortFallback(left), attributeValueSortFallback(right)) +} + +func compareNumericAttributeString(left string, right string) (int, bool) { + leftRat := &big.Rat{} + rightRat := &big.Rat{} + if _, ok := leftRat.SetString(strings.TrimSpace(left)); !ok { + return 0, false + } + if _, ok := rightRat.SetString(strings.TrimSpace(right)); !ok { + return 0, false + } + return leftRat.Cmp(rightRat), true +} + +func attributeValueSortFallback(attr attributeValue) string { + kind, count := detectAttributeValueKind(attr) + if count != 1 { + return "" + } + format, ok := attributeValueSortFormatters[kind] + if !ok { + return "" + } + return format(attr) +} + +func formatBoolAttributeValue(attr attributeValue) string { + if *attr.BOOL { + return "1" + } + return "0" +} + +func unorderedStringSlicesEqual(left []string, right []string) bool { + if len(left) != len(right) { + return false + } + lv := sortedStringSlice(left) + rv := sortedStringSlice(right) + for i := range lv { + if lv[i] != rv[i] { + return false + } + } + return true +} + +func unorderedNumberSlicesEqual(left []string, right []string) bool { + if len(left) != len(right) { + return false + } + lv := sortedNumberStrings(left) + rv := sortedNumberStrings(right) + for i := range lv { + if lv[i] != rv[i] { + return false + } + } + return true +} + +func unorderedBinarySlicesEqual(left [][]byte, right [][]byte) bool { + if len(left) != len(right) { + return false + } + lv := sortedBinaryStrings(left) + rv := sortedBinaryStrings(right) + for i := range lv { + if lv[i] != rv[i] { + return false + } + } + return true +} + +func sortedStringSlice(in []string) []string { + out := append([]string(nil), in...) + sort.Strings(out) + return out +} + +func sortedNumberStrings(in []string) []string { + out := make([]string, len(in)) + for i := range in { + out[i] = canonicalNumberString(in[i]) + } + sort.Strings(out) + return out +} + +func sortedBinaryStrings(in [][]byte) []string { + out := make([]string, len(in)) + for i := range in { + out[i] = base64.RawURLEncoding.EncodeToString(in[i]) + } + sort.Strings(out) + return out +} + +func canonicalNumberString(v string) string { + rat := &big.Rat{} + if _, ok := rat.SetString(strings.TrimSpace(v)); !ok { + return strings.TrimSpace(v) + } + return rat.RatString() +} + +func reverseItems(items []map[string]attributeValue) { + for i, j := 0, len(items)-1; i < j; i, j = i+1, j-1 { + items[i], items[j] = items[j], items[i] + } +} + +func cloneAttributeValueMap(in map[string]attributeValue) map[string]attributeValue { + if in == nil { + return nil + } + out := make(map[string]attributeValue, len(in)) + for k, v := range in { + out[k] = cloneAttributeValue(v) + } + return out +} + +func cloneAttributeValueList(in []attributeValue) []attributeValue { + if in == nil { + return nil + } + out := make([]attributeValue, 0, len(in)) + for _, value := range in { + out = append(out, cloneAttributeValue(value)) + } + return out +} + +func cloneAttributeValue(in attributeValue) attributeValue { + out := attributeValue{} + if in.S != nil { + s := *in.S + out.S = &s + } + if in.N != nil { + n := *in.N + out.N = &n + } + if in.B != nil { + out.B = bytes.Clone(in.B) + } + if in.BOOL != nil { + b := *in.BOOL + out.BOOL = &b + } + if in.NULL != nil { + n := *in.NULL + out.NULL = &n + } + out.SS = cloneStringSlice(in.SS) + out.NS = cloneStringSlice(in.NS) + out.BS = cloneBinarySet(in.BS) + if in.L != nil { + out.L = make([]attributeValue, len(in.L)) + for i := range in.L { + out.L[i] = cloneAttributeValue(in.L[i]) + } + } + if in.M != nil { + out.M = cloneAttributeValueMap(in.M) + } + return out +} diff --git a/adapter/dynamodb_errors.go b/adapter/dynamodb_errors.go new file mode 100644 index 00000000..6f98ad2e --- /dev/null +++ b/adapter/dynamodb_errors.go @@ -0,0 +1,78 @@ +package adapter + +import ( + "net/http" + + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +type dynamoAPIError struct { + status int + errorType string + message string +} + +func (e *dynamoAPIError) Error() string { + if e == nil { + return "" + } + if e.message != "" { + return e.message + } + return http.StatusText(e.status) +} + +func newDynamoAPIError(status int, errorType string, message string) error { + return &dynamoAPIError{ + status: status, + errorType: errorType, + message: message, + } +} + +func writeDynamoErrorFromErr(w http.ResponseWriter, err error) { + var apiErr *dynamoAPIError + if errors.As(err, &apiErr) { + writeDynamoError(w, apiErr.status, apiErr.errorType, apiErr.message) + return + } + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) +} + +// dynamoErrIsTransient reports whether err is a transient/internal failure +// (Pebble error, context deadline, decode failure) as opposed to a structured +// validation/malformed-input error. A *dynamoAPIError is always a deliberate +// validation result (it carries an HTTP status + error type), so it is NOT +// transient; everything else — a raw wrapped store/context error — is. The +// lease pre-pass uses this to decide whether an unresolvable item must fail +// closed (transient) or may be skipped (validation, rejected identically by +// the read path). +func dynamoErrIsTransient(err error) bool { + if err == nil { + return false + } + var apiErr *dynamoAPIError + return !errors.As(err, &apiErr) +} + +func writeDynamoError(w http.ResponseWriter, status int, errorType string, message string) { + if message == "" { + message = http.StatusText(status) + } + + resp := map[string]string{"message": message} + if errorType != "" { + resp["__type"] = errorType + w.Header().Set("x-amzn-ErrorType", errorType) + } + w.Header().Set("Content-Type", "application/x-amz-json-1.0") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(resp) +} + +func writeDynamoJSON(w http.ResponseWriter, payload any) { + w.Header().Set("Content-Type", "application/x-amz-json-1.0") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(payload) +} diff --git a/adapter/dynamodb_keys.go b/adapter/dynamodb_keys.go new file mode 100644 index 00000000..d832d683 --- /dev/null +++ b/adapter/dynamodb_keys.go @@ -0,0 +1,191 @@ +package adapter + +import ( + "bytes" + "context" + "encoding/base64" + "strconv" + "strings" + + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" +) + +func (d *DynamoDBServer) scanAllByPrefix(ctx context.Context, prefix []byte) ([]*store.KVPair, error) { + return d.scanAllByPrefixAt(ctx, prefix, snapshotTS(d.coordinator.Clock(), d.store)) +} + +func (d *DynamoDBServer) scanAllByPrefixAt(ctx context.Context, prefix []byte, readTS uint64) ([]*store.KVPair, error) { + readPin := d.pinReadTS(readTS) + defer readPin.Release() + + end := prefixScanEnd(prefix) + start := bytes.Clone(prefix) + + out := make([]*store.KVPair, 0, dynamoScanPageLimit) + for { + kvs, err := d.store.ScanAt(ctx, start, end, dynamoScanPageLimit, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if len(kvs) == 0 { + break + } + for _, kvp := range kvs { + if !bytes.HasPrefix(kvp.Key, prefix) { + return out, nil + } + out = append(out, kvp) + } + if len(kvs) < dynamoScanPageLimit { + break + } + start = nextScanCursor(kvs[len(kvs)-1].Key) + if end != nil && bytes.Compare(start, end) > 0 { + break + } + } + return out, nil +} + +func nextScanCursor(lastKey []byte) []byte { + next := make([]byte, len(lastKey)+1) + copy(next, lastKey) + return next +} + +func minInt(a int, b int) int { + if a < b { + return a + } + return b +} + +func dynamoTableMetaKey(tableName string) []byte { + return []byte(dynamoTableMetaPrefix + encodeDynamoSegment(tableName)) +} + +func dynamoTableGenerationKey(tableName string) []byte { + return []byte(dynamoTableGenerationPrefix + encodeDynamoSegment(tableName)) +} + +func dynamoItemPrefixForTable(tableName string, generation uint64) []byte { + return []byte(dynamoItemPrefix + encodeDynamoSegment(tableName) + "|" + strconv.FormatUint(generation, 10) + "|") +} + +func dynamoItemHashPrefixForTable(tableName string, generation uint64, hashKey []byte) []byte { + base := dynamoItemPrefixForTable(tableName, generation) + return append(base, hashKey...) +} + +func legacyDynamoItemHashPrefixForTable(tableName string, generation uint64, hashKey string) []byte { + return []byte( + dynamoItemPrefix + + encodeDynamoSegment(tableName) + "|" + + strconv.FormatUint(generation, 10) + "|" + + encodeDynamoSegment(hashKey) + "|", + ) +} + +func dynamoGSIPrefixForTable(tableName string, generation uint64) []byte { + return []byte( + dynamoGSIPrefix + + encodeDynamoSegment(tableName) + "|" + + strconv.FormatUint(generation, 10) + "|", + ) +} + +func dynamoGSIIndexPrefixForTable(tableName string, generation uint64, indexName string) []byte { + return []byte( + dynamoGSIPrefix + + encodeDynamoSegment(tableName) + "|" + + strconv.FormatUint(generation, 10) + "|" + + encodeDynamoSegment(indexName) + "|", + ) +} + +func dynamoGSIHashPrefixForTable(tableName string, generation uint64, indexName string, hashKey []byte) []byte { + base := dynamoGSIIndexPrefixForTable(tableName, generation, indexName) + return append(base, hashKey...) +} + +func legacyDynamoGSIHashPrefixForTable(tableName string, generation uint64, indexName string, hashKey string) []byte { + return []byte( + dynamoGSIPrefix + + encodeDynamoSegment(tableName) + "|" + + strconv.FormatUint(generation, 10) + "|" + + encodeDynamoSegment(indexName) + "|" + + encodeDynamoSegment(hashKey) + "|", + ) +} + +func dynamoGSIKey( + tableName string, + generation uint64, + indexName string, + indexHash []byte, + indexRange []byte, + pkHash []byte, + pkRange []byte, +) []byte { + key := dynamoGSIIndexPrefixForTable(tableName, generation, indexName) + key = append(key, indexHash...) + key = append(key, indexRange...) + key = append(key, pkHash...) + key = append(key, pkRange...) + return key +} + +func legacyDynamoGSIKey(tableName string, generation uint64, indexName string, indexHash string, indexRange string, pkHash string, pkRange string) []byte { + return []byte( + dynamoGSIPrefix + + encodeDynamoSegment(tableName) + "|" + + strconv.FormatUint(generation, 10) + "|" + + encodeDynamoSegment(indexName) + "|" + + encodeDynamoSegment(indexHash) + "|" + + encodeDynamoSegment(indexRange) + "|" + + encodeDynamoSegment(pkHash) + "|" + + encodeDynamoSegment(pkRange), + ) +} + +func dynamoItemKey(tableName string, generation uint64, hashKey []byte, rangeKey []byte) []byte { + key := dynamoItemPrefixForTable(tableName, generation) + key = append(key, hashKey...) + key = append(key, rangeKey...) + return key +} + +func legacyDynamoItemKey(tableName string, generation uint64, hashKey, rangeKey string) []byte { + return []byte( + dynamoItemPrefix + + encodeDynamoSegment(tableName) + "|" + + strconv.FormatUint(generation, 10) + "|" + + encodeDynamoSegment(hashKey) + "|" + + encodeDynamoSegment(rangeKey), + ) +} + +func encodeDynamoSegment(v string) string { + return base64.RawURLEncoding.EncodeToString([]byte(v)) +} + +func decodeDynamoSegment(v string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(v) + if err != nil { + return "", errors.WithStack(err) + } + return string(b), nil +} + +func tableNameFromMetaKey(key []byte) (string, bool) { + enc, ok := strings.CutPrefix(string(key), dynamoTableMetaPrefix) + if !ok || enc == "" { + return "", false + } + name, err := decodeDynamoSegment(enc) + if err != nil { + return "", false + } + return name, true +} diff --git a/adapter/dynamodb_schema.go b/adapter/dynamodb_schema.go new file mode 100644 index 00000000..91e0bee5 --- /dev/null +++ b/adapter/dynamodb_schema.go @@ -0,0 +1,938 @@ +package adapter + +import ( + "bytes" + "context" + "io" + "log/slog" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +type createTableAttributeDefinition struct { + AttributeName string `json:"AttributeName"` + AttributeType string `json:"AttributeType"` +} + +type createTableKeySchemaElement struct { + AttributeName string `json:"AttributeName"` + KeyType string `json:"KeyType"` +} + +type createTableGSI struct { + IndexName string `json:"IndexName"` + KeySchema []createTableKeySchemaElement `json:"KeySchema"` + Projection createTableProjection `json:"Projection"` +} + +type createTableProjection struct { + ProjectionType string `json:"ProjectionType"` + NonKeyAttributes []string `json:"NonKeyAttributes"` +} + +type createTableInput struct { + TableName string `json:"TableName"` + AttributeDefinitions []createTableAttributeDefinition `json:"AttributeDefinitions"` + KeySchema []createTableKeySchemaElement `json:"KeySchema"` + GlobalSecondaryIndexes []createTableGSI `json:"GlobalSecondaryIndexes"` +} + +type deleteTableInput struct { + TableName string `json:"TableName"` +} + +type describeTableInput struct { + TableName string `json:"TableName"` +} + +type listTablesInput struct { + ExclusiveStartTableName string `json:"ExclusiveStartTableName"` + Limit int32 `json:"Limit"` +} + +type dynamoKeySchema struct { + HashKey string `json:"hash_key"` + RangeKey string `json:"range_key,omitempty"` +} + +type dynamoGSIProjection struct { + ProjectionType string `json:"projection_type"` + NonKeyAttributes []string `json:"non_key_attributes,omitempty"` +} + +type dynamoGlobalSecondaryIndex struct { + KeySchema dynamoKeySchema `json:"key_schema"` + Projection dynamoGSIProjection `json:"projection"` +} + +func (g *dynamoGlobalSecondaryIndex) UnmarshalJSON(b []byte) error { + type rawGSI struct { + KeySchema *dynamoKeySchema `json:"key_schema"` + Projection *dynamoGSIProjection `json:"projection"` + HashKey string `json:"hash_key"` + RangeKey string `json:"range_key"` + } + + var raw rawGSI + if err := json.Unmarshal(b, &raw); err != nil { + return errors.WithStack(err) + } + + if raw.KeySchema != nil { + g.KeySchema = *raw.KeySchema + } else { + g.KeySchema = dynamoKeySchema{ + HashKey: raw.HashKey, + RangeKey: raw.RangeKey, + } + } + + if raw.Projection != nil && strings.TrimSpace(raw.Projection.ProjectionType) != "" { + g.Projection = *raw.Projection + } else { + // Older schema snapshots stored only the key schema. Those GSIs behaved + // like ALL projections, so preserve that behavior when normalizing. + g.Projection = dynamoGSIProjection{ProjectionType: "ALL"} + } + + return nil +} + +type dynamoTableSchema struct { + TableName string `json:"table_name"` + AttributeDefinitions map[string]string `json:"attribute_definitions,omitempty"` + PrimaryKey dynamoKeySchema `json:"primary_key"` + GlobalSecondaryIndexes map[string]dynamoGlobalSecondaryIndex `json:"global_secondary_indexes,omitempty"` + KeyEncodingVersion int `json:"key_encoding_version,omitempty"` + MigratingFromGeneration uint64 `json:"migrating_from_generation,omitempty"` + Generation uint64 `json:"generation"` +} + +func (d *DynamoDBServer) createTable(w http.ResponseWriter, r *http.Request) { + in, err := decodeCreateTableInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + unlock := d.lockTableOperations([]string{in.TableName}) + defer unlock() + schema, err := buildCreateTableSchema(in) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return + } + if err := d.createTableWithRetry(r.Context(), in.TableName, schema); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + d.observeTables(r.Context(), schema.TableName) + writeDynamoJSON(w, map[string]any{ + "TableDescription": map[string]any{ + "TableName": in.TableName, + "TableStatus": "ACTIVE", + }, + }) +} + +func decodeCreateTableInput(bodyReader io.Reader) (createTableInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return createTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in createTableInput + if err := json.Unmarshal(body, &in); err != nil { + return createTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return createTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + return in, nil +} + +func buildCreateTableSchema(in createTableInput) (*dynamoTableSchema, error) { + primary, err := parseCreateTableKeySchema(in.KeySchema) + if err != nil { + return nil, err + } + attrDefs := make(map[string]string, len(in.AttributeDefinitions)) + for _, def := range in.AttributeDefinitions { + if strings.TrimSpace(def.AttributeName) == "" { + return nil, errors.New("invalid attribute definition") + } + attrDefs[def.AttributeName] = def.AttributeType + } + gsis := make(map[string]dynamoGlobalSecondaryIndex, len(in.GlobalSecondaryIndexes)) + for _, gsi := range in.GlobalSecondaryIndexes { + if strings.TrimSpace(gsi.IndexName) == "" { + return nil, errors.New("invalid global secondary index") + } + ks, err := parseCreateTableKeySchema(gsi.KeySchema) + if err != nil { + return nil, err + } + projection, err := buildCreateTableProjection(gsi.Projection) + if err != nil { + return nil, err + } + gsis[gsi.IndexName] = dynamoGlobalSecondaryIndex{ + KeySchema: ks, + Projection: projection, + } + } + return &dynamoTableSchema{ + TableName: in.TableName, + AttributeDefinitions: attrDefs, + PrimaryKey: primary, + GlobalSecondaryIndexes: gsis, + KeyEncodingVersion: dynamoOrderedKeyEncodingV2, + }, nil +} + +func buildCreateTableProjection(in createTableProjection) (dynamoGSIProjection, error) { + switch strings.TrimSpace(in.ProjectionType) { + case "", "ALL": + return dynamoGSIProjection{ProjectionType: "ALL"}, nil + case "KEYS_ONLY": + return dynamoGSIProjection{ProjectionType: "KEYS_ONLY"}, nil + case "INCLUDE": + return dynamoGSIProjection{ + ProjectionType: "INCLUDE", + NonKeyAttributes: append([]string(nil), in.NonKeyAttributes...), + }, nil + default: + return dynamoGSIProjection{}, errors.New("invalid projection") + } +} + +func (d *DynamoDBServer) createTableWithRetry(ctx context.Context, tableName string, baseSchema *dynamoTableSchema) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := d.nextTxnReadTS() + exists, err := d.tableExistsAt(ctx, tableName, readTS) + if err != nil { + return err + } + if exists { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceInUse, "table already exists") + } + nextGeneration, err := d.nextTableGenerationAt(ctx, tableName, readTS) + if err != nil { + return err + } + req, err := makeCreateTableRequest(baseSchema, nextGeneration) + if err != nil { + return err + } + if _, err := d.coordinator.Dispatch(ctx, req); err == nil { + return nil + } + if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "create table retry attempts exhausted") +} + +func (d *DynamoDBServer) tableExistsAt(ctx context.Context, tableName string, readTS uint64) (bool, error) { + _, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) + if err != nil { + return false, errors.WithStack(err) + } + return exists, nil +} + +func (d *DynamoDBServer) nextTableGenerationAt(ctx context.Context, tableName string, readTS uint64) (uint64, error) { + lastGeneration, err := d.loadTableGenerationAt(ctx, tableName, readTS) + if err != nil { + return 0, errors.WithStack(err) + } + return lastGeneration + 1, nil +} + +func makeCreateTableRequest(baseSchema *dynamoTableSchema, nextGeneration uint64) (*kv.OperationGroup[kv.OP], error) { + schema := &dynamoTableSchema{ + TableName: baseSchema.TableName, + AttributeDefinitions: baseSchema.AttributeDefinitions, + PrimaryKey: baseSchema.PrimaryKey, + GlobalSecondaryIndexes: baseSchema.GlobalSecondaryIndexes, + KeyEncodingVersion: baseSchema.KeyEncodingVersion, + MigratingFromGeneration: baseSchema.MigratingFromGeneration, + Generation: nextGeneration, + } + schemaBytes, err := encodeStoredDynamoTableSchema(schema) + if err != nil { + return nil, errors.WithStack(err) + } + return &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Put, Key: dynamoTableMetaKey(baseSchema.TableName), Value: schemaBytes}, + {Op: kv.Put, Key: dynamoTableGenerationKey(baseSchema.TableName), Value: []byte(strconv.FormatUint(nextGeneration, 10))}, + }, + }, nil +} + +func (d *DynamoDBServer) deleteTable(w http.ResponseWriter, r *http.Request) { + in, err := decodeDeleteTableInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + unlock := d.lockTableOperations([]string{in.TableName}) + defer unlock() + if err := d.deleteTableWithRetry(r.Context(), in.TableName); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + resp := map[string]any{ + "TableDescription": map[string]any{ + "TableName": in.TableName, + "TableStatus": "DELETING", + }, + } + writeDynamoJSON(w, resp) +} + +func decodeDeleteTableInput(bodyReader io.Reader) (deleteTableInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return deleteTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in deleteTableInput + if err := json.Unmarshal(body, &in); err != nil { + return deleteTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return deleteTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + return in, nil +} + +func (d *DynamoDBServer) deleteTableWithRetry(ctx context.Context, tableName string) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := d.nextTxnReadTS() + schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: dynamoTableMetaKey(tableName)}, + }, + } + if _, err := d.coordinator.Dispatch(ctx, req); err != nil { + if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + } else { + d.launchDeletedTableCleanup(tableName, schema.Generation) + return nil + } + + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "delete table retry attempts exhausted") +} + +func (d *DynamoDBServer) launchDeletedTableCleanup(tableName string, generation uint64) { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), tableCleanupAsyncTimeout) + defer cancel() + if err := d.cleanupDeletedTableGeneration(ctx, tableName, generation); err != nil { + slog.Error("dynamodb delete table cleanup failed", + "table", tableName, + "generation", generation, + "error", err, + ) + } + }() +} + +func (d *DynamoDBServer) cleanupDeletedTableGeneration(ctx context.Context, tableName string, generation uint64) error { + prefixes := [][]byte{ + dynamoItemPrefixForTable(tableName, generation), + dynamoGSIPrefixForTable(tableName, generation), + } + // Dispatch a single DEL_PREFIX operation per prefix. The FSM on each node + // scans and writes tombstones locally, avoiding the enumerate-then-batch- + // delete loop that previously required many Raft proposals. + for _, prefix := range prefixes { + _, err := d.coordinator.Dispatch(ctx, &kv.OperationGroup[kv.OP]{ + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.DelPrefix, Key: prefix}, + }, + }) + if err != nil { + return errors.WithStack(err) + } + } + return nil +} + +func (d *DynamoDBServer) dispatchDeleteBatch(ctx context.Context, keys [][]byte) error { + elems := make([]*kv.Elem[kv.OP], 0, len(keys)) + for _, key := range keys { + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) + } + req := &kv.OperationGroup[kv.OP]{ + IsTxn: false, + Elems: elems, + } + _, err := d.coordinator.Dispatch(ctx, req) + if err != nil { + return errors.WithStack(err) + } + return nil +} + +func (d *DynamoDBServer) describeTable(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return + } + var in describeTableInput + if err := json.Unmarshal(body, &in); err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return + } + if strings.TrimSpace(in.TableName) == "" { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, "missing table name") + return + } + if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + schema, exists, err := d.loadTableSchema(r.Context(), in.TableName) + if err != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) + return + } + if !exists { + writeDynamoError(w, http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + return + } + writeDynamoJSON(w, map[string]any{"Table": describeTableShape(schema)}) +} + +func (d *DynamoDBServer) listTables(w http.ResponseWriter, r *http.Request) { + in, err := decodeListTablesInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + names, err := d.listTableNames(r.Context()) + if err != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) + return + } + outNames, hasNext := paginateTableNames(names, in) + + resp := map[string]any{"TableNames": outNames} + if hasNext && len(outNames) > 0 { + resp["LastEvaluatedTableName"] = outNames[len(outNames)-1] + } + writeDynamoJSON(w, resp) +} + +func decodeListTablesInput(bodyReader io.Reader) (listTablesInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return listTablesInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in listTablesInput + if len(bytes.TrimSpace(body)) == 0 { + return in, nil + } + if err := json.Unmarshal(body, &in); err != nil { + return listTablesInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return in, nil +} + +func paginateTableNames(names []string, in listTablesInput) ([]string, bool) { + start := findExclusiveStartIndex(names, in.ExclusiveStartTableName) + limit := resolveTableListLimit(in.Limit, len(names)) + end := min(start+limit, len(names)) + return names[start:end], end < len(names) +} + +func findExclusiveStartIndex(names []string, startName string) int { + if startName == "" { + return 0 + } + for i, name := range names { + if name == startName { + return i + 1 + } + } + return 0 +} + +func resolveTableListLimit(limit int32, tableCount int) int { + if limit <= 0 || int(limit) >= tableCount { + return tableCount + } + return int(limit) +} + +func (d *DynamoDBServer) listTableNames(ctx context.Context) ([]string, error) { + kvs, err := d.scanAllByPrefix(ctx, []byte(dynamoTableMetaPrefix)) + if err != nil { + return nil, err + } + names := make([]string, 0, len(kvs)) + for _, kvp := range kvs { + name, ok := tableNameFromMetaKey(kvp.Key) + if !ok { + continue + } + names = append(names, name) + } + sort.Strings(names) + return names, nil +} + +func parseCreateTableKeySchema(elems []createTableKeySchemaElement) (dynamoKeySchema, error) { + var ks dynamoKeySchema + for _, e := range elems { + switch strings.ToUpper(strings.TrimSpace(e.KeyType)) { + case "HASH": + ks.HashKey = e.AttributeName + case "RANGE": + ks.RangeKey = e.AttributeName + } + } + if strings.TrimSpace(ks.HashKey) == "" { + return dynamoKeySchema{}, errors.New("missing HASH key schema") + } + return ks, nil +} + +func (t *dynamoTableSchema) keySchemaForQuery(indexName string) (dynamoKeySchema, error) { + if strings.TrimSpace(indexName) == "" { + return t.PrimaryKey, nil + } + gsi, ok := t.GlobalSecondaryIndexes[indexName] + if !ok { + return dynamoKeySchema{}, errors.New("unknown index") + } + return gsi.KeySchema, nil +} + +func (t *dynamoTableSchema) gsiProjectedAttributeSet(indexName string) (bool, map[string]struct{}, error) { + gsi, ok := t.GlobalSecondaryIndexes[indexName] + if !ok { + return false, nil, errors.New("unknown index") + } + if strings.EqualFold(gsi.Projection.ProjectionType, "ALL") { + return true, nil, nil + } + out := map[string]struct{}{ + t.PrimaryKey.HashKey: {}, + gsi.KeySchema.HashKey: {}, + } + if t.PrimaryKey.RangeKey != "" { + out[t.PrimaryKey.RangeKey] = struct{}{} + } + if gsi.KeySchema.RangeKey != "" { + out[gsi.KeySchema.RangeKey] = struct{}{} + } + for _, attr := range gsi.Projection.NonKeyAttributes { + out[attr] = struct{}{} + } + return false, out, nil +} + +func (t *dynamoTableSchema) projectItemForIndex(indexName string, item map[string]attributeValue) (map[string]attributeValue, error) { + allProjected, projected, err := t.gsiProjectedAttributeSet(indexName) + if err != nil { + return nil, err + } + if allProjected { + return cloneAttributeValueMap(item), nil + } + out := make(map[string]attributeValue, len(projected)) + for attr := range projected { + if value, ok := item[attr]; ok { + out[attr] = cloneAttributeValue(value) + } + } + return out, nil +} + +func (t *dynamoTableSchema) usesOrderedKeyEncoding() bool { + return t != nil && t.KeyEncodingVersion >= dynamoOrderedKeyEncodingV2 +} + +func (t *dynamoTableSchema) needsLegacyKeyMigration() bool { + return t != nil && (!t.usesOrderedKeyEncoding() || t.MigratingFromGeneration != 0) +} + +func (t *dynamoTableSchema) migrationSourceSchema() *dynamoTableSchema { + if t == nil || t.MigratingFromGeneration == 0 { + return nil + } + return &dynamoTableSchema{ + TableName: t.TableName, + AttributeDefinitions: t.AttributeDefinitions, + PrimaryKey: t.PrimaryKey, + GlobalSecondaryIndexes: t.GlobalSecondaryIndexes, + KeyEncodingVersion: 0, + Generation: t.MigratingFromGeneration, + } +} + +func (t *dynamoTableSchema) itemKeyFromAttributes(attrs map[string]attributeValue) ([]byte, error) { + if !t.usesOrderedKeyEncoding() { + return t.legacyItemKeyFromAttributes(attrs) + } + primary, err := t.primaryKeyValues(attrs) + if err != nil { + return nil, err + } + return dynamoItemKey(t.TableName, t.Generation, primary.hash, primary.rangeKey), nil +} + +func (t *dynamoTableSchema) legacyItemKeyFromAttributes(attrs map[string]attributeValue) ([]byte, error) { + hashAttr, ok := attrs[t.PrimaryKey.HashKey] + if !ok { + return nil, errors.New("missing hash key attribute") + } + hashKey, err := attributeValueAsKey(hashAttr) + if err != nil { + return nil, err + } + rangeKey := "" + if t.PrimaryKey.RangeKey != "" { + rangeAttr, ok := attrs[t.PrimaryKey.RangeKey] + if !ok { + return nil, errors.New("missing range key attribute") + } + rangeKey, err = attributeValueAsKey(rangeAttr) + if err != nil { + return nil, err + } + } + return legacyDynamoItemKey(t.TableName, t.Generation, hashKey, rangeKey), nil +} + +func (t *dynamoTableSchema) gsiKeyFromAttributes(indexName string, attrs map[string]attributeValue) ([]byte, bool, error) { + if !t.usesOrderedKeyEncoding() { + return t.legacyGSIKeyFromAttributes(indexName, attrs) + } + gsi, ok := t.GlobalSecondaryIndexes[indexName] + if !ok { + return nil, false, errors.New("global secondary index not found") + } + primary, err := t.primaryKeyValues(attrs) + if err != nil { + return nil, false, err + } + index, include, err := gsiKeyValues(attrs, gsi.KeySchema) + if err != nil || !include { + return nil, include, err + } + return dynamoGSIKey(t.TableName, t.Generation, indexName, index.hash, index.rangeKey, primary.hash, primary.rangeKey), true, nil +} + +func (t *dynamoTableSchema) legacyGSIKeyFromAttributes(indexName string, attrs map[string]attributeValue) ([]byte, bool, error) { + gsi, ok := t.GlobalSecondaryIndexes[indexName] + if !ok { + return nil, false, errors.New("global secondary index not found") + } + pkHash, pkRange, err := t.legacyPrimaryKeyValues(attrs) + if err != nil { + return nil, false, err + } + indexHash, indexRange, include, err := legacyGSIKeyValues(attrs, gsi.KeySchema) + if err != nil || !include { + return nil, include, err + } + return legacyDynamoGSIKey(t.TableName, t.Generation, indexName, indexHash, indexRange, pkHash, pkRange), true, nil +} + +func (t *dynamoTableSchema) gsiEntryKeysForItem(attrs map[string]attributeValue) ([][]byte, error) { + if len(t.GlobalSecondaryIndexes) == 0 || len(attrs) == 0 { + return nil, nil + } + if !t.usesOrderedKeyEncoding() { + return t.legacyGSIEntryKeysForItem(attrs) + } + primary, err := t.primaryKeyValues(attrs) + if err != nil { + return nil, err + } + indexNames := sortedGSIIndexNames(t.GlobalSecondaryIndexes) + keys := make([][]byte, 0, len(indexNames)) + for _, indexName := range indexNames { + gsi := t.GlobalSecondaryIndexes[indexName] + index, include, err := gsiKeyValues(attrs, gsi.KeySchema) + if err != nil { + return nil, err + } + if !include { + continue + } + keys = append(keys, dynamoGSIKey(t.TableName, t.Generation, indexName, index.hash, index.rangeKey, primary.hash, primary.rangeKey)) + } + return keys, nil +} + +func (t *dynamoTableSchema) legacyGSIEntryKeysForItem(attrs map[string]attributeValue) ([][]byte, error) { + primaryHash, primaryRange, err := t.legacyPrimaryKeyValues(attrs) + if err != nil { + return nil, err + } + indexNames := sortedGSIIndexNames(t.GlobalSecondaryIndexes) + keys := make([][]byte, 0, len(indexNames)) + for _, indexName := range indexNames { + gsi := t.GlobalSecondaryIndexes[indexName] + indexHash, indexRange, include, err := legacyGSIKeyValues(attrs, gsi.KeySchema) + if err != nil { + return nil, err + } + if !include { + continue + } + keys = append(keys, legacyDynamoGSIKey(t.TableName, t.Generation, indexName, indexHash, indexRange, primaryHash, primaryRange)) + } + return keys, nil +} + +func (t *dynamoTableSchema) legacyPrimaryKeyValues(attrs map[string]attributeValue) (string, string, error) { + hashAttr, ok := attrs[t.PrimaryKey.HashKey] + if !ok { + return "", "", errors.New("missing hash key attribute") + } + hash, err := attributeValueAsKey(hashAttr) + if err != nil { + return "", "", err + } + rangeKey := "" + if t.PrimaryKey.RangeKey != "" { + rangeAttr, ok := attrs[t.PrimaryKey.RangeKey] + if !ok { + return "", "", errors.New("missing range key attribute") + } + rangeKey, err = attributeValueAsKey(rangeAttr) + if err != nil { + return "", "", err + } + } + return hash, rangeKey, nil +} + +type dynamoEncodedKeyValues struct { + hash []byte + rangeKey []byte +} + +func (t *dynamoTableSchema) primaryKeyValues(attrs map[string]attributeValue) (dynamoEncodedKeyValues, error) { + hashAttr, ok := attrs[t.PrimaryKey.HashKey] + if !ok { + return dynamoEncodedKeyValues{}, errors.New("missing hash key attribute") + } + hash, err := attributeValueAsKeySegment(hashAttr) + if err != nil { + return dynamoEncodedKeyValues{}, err + } + var rangeKey []byte + if t.PrimaryKey.RangeKey != "" { + rangeAttr, ok := attrs[t.PrimaryKey.RangeKey] + if !ok { + return dynamoEncodedKeyValues{}, errors.New("missing range key attribute") + } + rangeKey, err = attributeValueAsKeySegment(rangeAttr) + if err != nil { + return dynamoEncodedKeyValues{}, err + } + } + return dynamoEncodedKeyValues{hash: hash, rangeKey: rangeKey}, nil +} + +func gsiKeyValues(attrs map[string]attributeValue, ks dynamoKeySchema) (dynamoEncodedKeyValues, bool, error) { + hashAttr, ok := attrs[ks.HashKey] + if !ok { + return dynamoEncodedKeyValues{}, false, nil + } + hash, err := attributeValueAsKeySegment(hashAttr) + if err != nil { + return dynamoEncodedKeyValues{}, false, err + } + var rangeKey []byte + if ks.RangeKey != "" { + rangeAttr, ok := attrs[ks.RangeKey] + if !ok { + return dynamoEncodedKeyValues{}, false, nil + } + rangeKey, err = attributeValueAsKeySegment(rangeAttr) + if err != nil { + return dynamoEncodedKeyValues{}, false, err + } + } + return dynamoEncodedKeyValues{hash: hash, rangeKey: rangeKey}, true, nil +} + +func legacyGSIKeyValues(attrs map[string]attributeValue, ks dynamoKeySchema) (string, string, bool, error) { + hashAttr, ok := attrs[ks.HashKey] + if !ok { + return "", "", false, nil + } + hash, err := attributeValueAsKey(hashAttr) + if err != nil { + return "", "", false, err + } + rangeKey := "" + if ks.RangeKey != "" { + rangeAttr, ok := attrs[ks.RangeKey] + if !ok { + return "", "", false, nil + } + rangeKey, err = attributeValueAsKey(rangeAttr) + if err != nil { + return "", "", false, err + } + } + return hash, rangeKey, true, nil +} + +func sortedGSIIndexNames(indexes map[string]dynamoGlobalSecondaryIndex) []string { + names := make([]string, 0, len(indexes)) + for name := range indexes { + names = append(names, name) + } + sort.Strings(names) + return names +} + +func describeTableShape(t *dynamoTableSchema) map[string]any { + attrDefs := make([]map[string]string, 0, len(t.AttributeDefinitions)) + for name, typ := range t.AttributeDefinitions { + attrDefs = append(attrDefs, map[string]string{ + "AttributeName": name, + "AttributeType": typ, + }) + } + sort.Slice(attrDefs, func(i, j int) bool { + return attrDefs[i]["AttributeName"] < attrDefs[j]["AttributeName"] + }) + + keySchema := []map[string]string{{ + "AttributeName": t.PrimaryKey.HashKey, + "KeyType": "HASH", + }} + if t.PrimaryKey.RangeKey != "" { + keySchema = append(keySchema, map[string]string{ + "AttributeName": t.PrimaryKey.RangeKey, + "KeyType": "RANGE", + }) + } + + resp := map[string]any{ + "TableName": t.TableName, + "TableStatus": "ACTIVE", + "KeySchema": keySchema, + "AttributeDefinitions": attrDefs, + } + + if len(t.GlobalSecondaryIndexes) > 0 { + gsis := make([]map[string]any, 0, len(t.GlobalSecondaryIndexes)) + indexNames := make([]string, 0, len(t.GlobalSecondaryIndexes)) + for name := range t.GlobalSecondaryIndexes { + indexNames = append(indexNames, name) + } + sort.Strings(indexNames) + for _, name := range indexNames { + gsi := t.GlobalSecondaryIndexes[name] + ks := gsi.KeySchema + projection := map[string]any{ + "ProjectionType": gsi.Projection.ProjectionType, + } + if len(gsi.Projection.NonKeyAttributes) > 0 { + projection["NonKeyAttributes"] = append([]string(nil), gsi.Projection.NonKeyAttributes...) + } + idxKeySchema := []map[string]string{{ + "AttributeName": ks.HashKey, + "KeyType": "HASH", + }} + if ks.RangeKey != "" { + idxKeySchema = append(idxKeySchema, map[string]string{ + "AttributeName": ks.RangeKey, + "KeyType": "RANGE", + }) + } + indexDesc := map[string]any{ + "IndexName": name, + "IndexStatus": "ACTIVE", + "KeySchema": idxKeySchema, + "Projection": projection, + } + gsis = append(gsis, indexDesc) + } + resp["GlobalSecondaryIndexes"] = gsis + } + + return resp +} + +func (d *DynamoDBServer) loadTableSchema(ctx context.Context, tableName string) (*dynamoTableSchema, bool, error) { + return d.loadTableSchemaAt(ctx, tableName, snapshotTS(d.coordinator.Clock(), d.store)) +} + +func (d *DynamoDBServer) loadTableSchemaAt(ctx context.Context, tableName string, ts uint64) (*dynamoTableSchema, bool, error) { + b, err := d.store.GetAt(ctx, dynamoTableMetaKey(tableName), ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, false, nil + } + return nil, false, errors.WithStack(err) + } + schema, err := decodeStoredDynamoTableSchema(b) + if err != nil { + return nil, false, err + } + d.observeTables(ctx, schema.TableName) + return schema, true, nil +} + +func (d *DynamoDBServer) loadTableGenerationAt(ctx context.Context, tableName string, ts uint64) (uint64, error) { + b, err := d.store.GetAt(ctx, dynamoTableGenerationKey(tableName), ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return 0, nil + } + return 0, errors.WithStack(err) + } + gen, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) + if err != nil { + return 0, errors.WithStack(err) + } + return gen, nil +} From 95a29e1f42be0af17a20931e48577a236427b613 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 12 Jun 2026 15:44:32 +0900 Subject: [PATCH 2/4] adapter: split dynamodb.go item/query/transact/lease handlers into cohesive files (no behavior change) Pure code movement within package adapter: GetItem read path, Put/Update/Delete write paths + one-phase dedup, Query/Scan + read iterators + key-condition parsing, BatchWrite/TransactWrite/TransactGet, and the lease pre-pass helpers (leaseCheckScan/leaseCheckQuery/ leaseCheckTransactGetItems, multiShardReadLeasePlan, queryLeaseKey, leaseReadGroupKeys) move into dynamodb_item_read.go / dynamodb_item_write.go / dynamodb_query_scan.go / dynamodb_transact.go / dynamodb_lease.go. No declarations changed. --- adapter/dynamodb_item_read.go | 189 +++ adapter/dynamodb_item_write.go | 921 +++++++++++++ adapter/dynamodb_lease.go | 522 ++++++++ adapter/dynamodb_query_scan.go | 2281 ++++++++++++++++++++++++++++++++ adapter/dynamodb_transact.go | 1220 +++++++++++++++++ 5 files changed, 5133 insertions(+) create mode 100644 adapter/dynamodb_item_read.go create mode 100644 adapter/dynamodb_item_write.go create mode 100644 adapter/dynamodb_lease.go create mode 100644 adapter/dynamodb_query_scan.go create mode 100644 adapter/dynamodb_transact.go diff --git a/adapter/dynamodb_item_read.go b/adapter/dynamodb_item_read.go new file mode 100644 index 00000000..26f64e01 --- /dev/null +++ b/adapter/dynamodb_item_read.go @@ -0,0 +1,189 @@ +package adapter + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +type getItemInput struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + ProjectionExpression string `json:"ProjectionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ConsistentRead *bool `json:"ConsistentRead"` +} + +func (d *DynamoDBServer) parseGetItemInput(w http.ResponseWriter, r *http.Request) (getItemInput, bool) { + body, err := io.ReadAll(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return getItemInput{}, false + } + var in getItemInput + if err := json.Unmarshal(body, &in); err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return getItemInput{}, false + } + if strings.TrimSpace(in.TableName) == "" { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, "missing table name") + return getItemInput{}, false + } + if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { + writeDynamoErrorFromErr(w, err) + return getItemInput{}, false + } + return in, true +} + +func (d *DynamoDBServer) getItem(w http.ResponseWriter, r *http.Request) { + in, ok := d.parseGetItemInput(w, r) + if !ok { + return + } + // Tentative TS for schema resolution only; schemas change rarely + // so a slight pre-lease stale is acceptable. The item read below + // is sampled AFTER the lease check. + tentativeTS := d.resolveDynamoReadTS(in.ConsistentRead) + _, itemKey, ok := d.resolveGetItemTarget(w, r, in, tentativeTS) + if !ok { + return + } + // Lease-check the shard that actually owns the ITEM key with a + // bounded timeout so a stalled Raft cannot hang this handler + // indefinitely if the client never cancels. Use defer so the + // cancel runs even if LeaseReadForKey panics or a future + // refactor inserts an early return; the cost of keeping ctx + // alive until handler exit is negligible because the next + // in-handler calls are local store reads. + leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) + defer leaseCancel() + if _, err := kv.LeaseReadForKeyThrough(d.coordinator, leaseCtx, itemKey); err != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) + return + } + // Re-sample readTS AFTER the lease confirmation so that any write + // that completed on the same shard BEFORE the confirmation is + // visible. Sampling earlier would violate linearizability for + // ConsistentRead=false reads by returning a snapshot from before + // the most recent confirmed commit. + readTS := d.resolveDynamoReadTS(in.ConsistentRead) + // Pin readTS so concurrent MVCC GC cannot reclaim versions + // between the schema revalidation and the item read below; + // matches the pattern already used by queryItems / scanItems / + // transactGetItems. + readPin := d.pinReadTS(readTS) + defer readPin.Release() + + // Re-resolve schema + itemKey at readTS and verify that the key + // we lease-checked is STILL the key that will be read. A table + // migration that commits between the tentative schema load and + // the lease confirmation may shift the item to a different shard + // even if the request parameters are unchanged, so comparing the + // computed item keys (not just generation) catches any future + // schema change that alters item routing. + finalSchema, freshItemKey, ok := d.resolveGetItemTarget(w, r, in, readTS) + if !ok { + return + } + if !bytes.Equal(freshItemKey, itemKey) { + writeDynamoError(w, http.StatusServiceUnavailable, dynamoErrInternal, + "table routing changed during read; please retry") + return + } + + current, found, err := d.readLogicalItemAt(r.Context(), finalSchema, in.Key, readTS) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return + } + if !found { + writeDynamoJSON(w, map[string]any{}) + return + } + d.observeReadMetrics(r.Context(), in.TableName, 1, 1) + projected, err := projectItem(current.item, in.ProjectionExpression, in.ExpressionAttributeNames) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + writeDynamoJSON(w, map[string]any{"Item": projected}) +} + +// resolveGetItemTarget loads the schema and computes the item key whose +// shard must be lease-checked before the read. Returns false after +// writing an error response; the caller should simply return. +func (d *DynamoDBServer) resolveGetItemTarget(w http.ResponseWriter, r *http.Request, in getItemInput, readTS uint64) (*dynamoTableSchema, []byte, bool) { + schema, exists, err := d.loadTableSchemaAt(r.Context(), in.TableName, readTS) + if err != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) + return nil, nil, false + } + if !exists { + writeDynamoError(w, http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + return nil, nil, false + } + itemKey, err := schema.itemKeyFromAttributes(in.Key) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return nil, nil, false + } + return schema, itemKey, true +} + +func (d *DynamoDBServer) readItemAtKeyAt(ctx context.Context, key []byte, ts uint64) (map[string]attributeValue, bool, error) { + b, err := d.store.GetAt(ctx, key, ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, false, nil + } + return nil, false, errors.WithStack(err) + } + item, err := decodeStoredDynamoItem(b) + if err != nil { + return nil, false, err + } + return item, true, nil +} + +func (d *DynamoDBServer) readLogicalItemAt( + ctx context.Context, + schema *dynamoTableSchema, + key map[string]attributeValue, + ts uint64, +) (*dynamoItemLocation, bool, error) { + itemKey, err := schema.itemKeyFromAttributes(key) + if err != nil { + return nil, false, err + } + item, found, err := d.readItemAtKeyAt(ctx, itemKey, ts) + if err != nil { + return nil, false, err + } + if found { + return &dynamoItemLocation{schema: schema, key: itemKey, item: item}, true, nil + } + sourceSchema := schema.migrationSourceSchema() + if sourceSchema == nil { + return nil, false, nil + } + sourceKey, err := sourceSchema.itemKeyFromAttributes(key) + if err != nil { + return nil, false, err + } + item, found, err = d.readItemAtKeyAt(ctx, sourceKey, ts) + if err != nil { + return nil, false, err + } + if !found { + return nil, false, nil + } + return &dynamoItemLocation{schema: sourceSchema, key: sourceKey, item: item}, true, nil +} diff --git a/adapter/dynamodb_item_write.go b/adapter/dynamodb_item_write.go new file mode 100644 index 00000000..a5682259 --- /dev/null +++ b/adapter/dynamodb_item_write.go @@ -0,0 +1,921 @@ +package adapter + +import ( + "bytes" + "context" + "io" + "maps" + "net/http" + "sort" + "strings" + "time" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +type updateItemInput struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + UpdateExpression string `json:"UpdateExpression"` + ConditionExpression string `json:"ConditionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` + ReturnValues string `json:"ReturnValues"` +} + +type putItemInput struct { + TableName string `json:"TableName"` + Item map[string]attributeValue `json:"Item"` + ConditionExpression string `json:"ConditionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` + ReturnValues string `json:"ReturnValues"` +} + +type deleteItemInput struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + ConditionExpression string `json:"ConditionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` + ReturnValues string `json:"ReturnValues"` +} + +func (d *DynamoDBServer) putItem(w http.ResponseWriter, r *http.Request) { + in, err := decodePutItemInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + plan, err := d.putItemWithRetry(r.Context(), in) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + d.observeWrittenItems(r.Context(), in.TableName, 1) + resp := map[string]any{} + if attrs := putItemReturnAttributes(in.ReturnValues, plan.current); len(attrs) > 0 { + resp["Attributes"] = attrs + } + writeDynamoJSON(w, resp) +} + +func decodePutItemInput(bodyReader io.Reader) (putItemInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return putItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in putItemInput + if err := json.Unmarshal(body, &in); err != nil { + return putItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return putItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + if err := validatePutItemReturnValues(in.ReturnValues); err != nil { + return putItemInput{}, err + } + return in, nil +} + +func validatePutItemReturnValues(returnValues string) error { + switch strings.TrimSpace(returnValues) { + case "", dynamoReturnValueNone, dynamoReturnValueAllOld: + return nil + default: + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported ReturnValues") + } +} + +func (d *DynamoDBServer) putItemWithRetry(ctx context.Context, in putItemInput) (*itemWritePlan, error) { + return d.retryItemWriteWithGeneration( + ctx, + in.TableName, + "put item retry attempts exhausted", + func(readTS uint64) (*itemWritePlan, error) { + return d.preparePutItemWrite(ctx, in, readTS) + }, + ) +} + +type itemWritePlan struct { + req *kv.OperationGroup[kv.OP] + generation uint64 + cleanup [][]byte + current map[string]attributeValue + next map[string]attributeValue +} + +func (d *DynamoDBServer) retryItemWriteWithGeneration( + ctx context.Context, + tableName string, + exhaustedMessage string, + prepare func(readTS uint64) (*itemWritePlan, error), +) (*itemWritePlan, error) { + // Option-2 one-phase dedup (gated, default off): on a retryable write error, + // reuse the failed attempt's write set under a fresh commit_ts + prev_commit_ts + // so the FSM no-ops a commit that already landed under leadership churn, + // instead of re-reading and re-appending (the :duplicate-elements anomaly). + // See docs/design/2026_06_03_partial_dynamodb_onephase_dedup.md. + // + // Leader-only (codex P1, PR #920): the dedup path allocates commit_ts from + // the LOCAL HLC and carries it as prev_commit_ts, so that timestamp MUST be + // leader-issued to stay globally unique — otherwise two frontends could mint + // the same commit_ts in one millisecond and the exact-ts probe would dedup + // against the wrong writer's version, losing an update. On the leader the + // single HLC issues monotonic unique values, and NextFenced's physical-ceiling + // fence keeps a deposed leader's window disjoint from its successor's. A + // non-leader (reachable only when no leaderMap HTTP proxy forwards follower + // ingress) falls back to the legacy path, where Coordinator.Dispatch redirects + // to the leader and the LEADER allocates commit_ts — never this follower's HLC. + if d.onePhaseTxnDedup && d.coordinator.IsLeader() { + return d.retryItemWriteWithGenerationDedup(ctx, tableName, exhaustedMessage, prepare) + } + return d.retryItemWriteWithGenerationLegacy(ctx, tableName, exhaustedMessage, prepare) +} + +// retryItemWriteWithGenerationLegacy is the pre-dedup retry loop: it recomputes +// the write set from a fresh read on every retryable error. It is the active +// path whenever the dedup gate is off or this node is not the leader, so it +// stays byte-identical to the pre-feature behavior. +func (d *DynamoDBServer) retryItemWriteWithGenerationLegacy( + ctx context.Context, + tableName string, + exhaustedMessage string, + prepare func(readTS uint64) (*itemWritePlan, error), +) (*itemWritePlan, error) { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := d.nextTxnReadTS() + plan, err := prepare(readTS) + if err != nil { + return nil, err + } + if plan.req == nil { + return plan, nil + } + plan.req.StartTS = readTS + if err = d.commitItemWrite(ctx, plan.req); err != nil { + if !isRetryableTransactWriteError(err) { + return nil, errors.WithStack(err) + } + } else { + retry, verifyErr := d.handleGenerationFenceResult( + ctx, + d.verifyTableGeneration(ctx, tableName, plan.generation), + plan.cleanup, + ) + if verifyErr != nil { + return nil, verifyErr + } + if !retry { + return plan, nil + } + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return nil, errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return nil, newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, exhaustedMessage) +} + +// reusableItemWrite captures a dispatched single-item write attempt so a +// subsequent retry can REUSE its exact write set (the same Put/Del elems) under +// a fresh commit_ts and probe whether it already landed, instead of re-reading +// and recomputing the item. Recomputing is what duplicates a list_append under +// leadership churn: attempt 1 commits at C1 but returns a WriteConflict, the +// retry re-reads the now-larger list and appends again. Reuse + the FSM's +// exact-ts dedup probe close that. See option 2 in +// docs/design/2026_06_03_partial_dynamodb_onephase_dedup.md. +type reusableItemWrite struct { + // plan holds the reused OperationGroup (plan.req: Elems + fixed StartTS) and + // the captured current/next item. The client-visible result + // (updateItemReturnAttributes over current/next) is invariant across reuse + // — the write set was built once from attempt 1's read — so plan is also the + // correct value to return when the FSM dedup no-ops the apply (R1). + plan *itemWritePlan + // commitTS is the most recent dispatched commit_ts for this write set; the + // next retry passes it as PrevCommitTS so the FSM probes exactly the attempt + // that might have landed. + commitTS uint64 + // probeKey is kv.PrimaryKeyForElems(plan.req.Elems) — the same key the FSM + // uses as meta.PrimaryKey — so the adapter-side self-inflicted-conflict guard + // and the FSM dedup probe agree on the point they query (R4). + probeKey []byte +} + +// retryItemWriteWithGenerationDedup is the option-2 retry loop. The first +// attempt computes the write set from a fresh read; any retryable failure makes +// the next iteration REUSE that write set under a fresh commit_ts carrying +// prev_commit_ts, so the FSM no-ops if the prior attempt already landed. A +// genuine WriteConflict on a reuse (the self-conflict probe missed) drops the +// pending attempt and recomputes from a fresh read. +func (d *DynamoDBServer) retryItemWriteWithGenerationDedup( + ctx context.Context, + tableName string, + exhaustedMessage string, + prepare func(readTS uint64) (*itemWritePlan, error), +) (*itemWritePlan, error) { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + var pending *reusableItemWrite + for range transactRetryMaxAttempts { + var ( + plan *itemWritePlan + err error + ) + if pending != nil { + plan, pending, err = d.itemWriteReuseAttempt(ctx, tableName, pending) + } else { + plan, pending, err = d.itemWriteFirstAttempt(ctx, tableName, prepare) + } + if err != nil { + // commitItemWrite already wraps dispatch errors; the attempt helpers + // return them raw, so return raw here too (no double WithStack). + if !isRetryableTransactWriteError(err) { + return nil, err + } + } else if plan != nil { + return plan, nil + } + if waitErr := waitRetryWithDeadline(ctx, deadline, backoff); waitErr != nil { + return nil, errors.WithStack(waitErr) + } + backoff = nextTransactRetryBackoff(backoff) + } + return nil, newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, exhaustedMessage) +} + +// itemWriteFirstAttempt runs the recompute branch of the dedup loop: a fresh +// read snapshot, a locally-allocated commit_ts, and a dispatch. On a retryable +// write error it returns a reusableItemWrite so the next iteration reuses this +// write set. Return shapes match itemWriteReuseAttempt (see +// retryItemWriteWithGenerationDedup). +func (d *DynamoDBServer) itemWriteFirstAttempt( + ctx context.Context, + tableName string, + prepare func(readTS uint64) (*itemWritePlan, error), +) (*itemWritePlan, *reusableItemWrite, error) { + readTS := d.nextTxnReadTS() + plan, err := prepare(readTS) + if err != nil { + return nil, nil, err + } + if plan.req == nil { + return plan, nil, nil + } + // NextFenced (not Next) honors the HLC physical-ceiling fence so a + // stale-leader window cannot mint a colliding commit_ts (HLC-4); + // ErrCeilingExpired is non-retryable and surfaces to the client. + commitTS, err := d.coordinator.Clock().NextFenced() + if err != nil { + return nil, nil, errors.Wrap(err, "dynamodb item-write first attempt: allocate commitTS") + } + plan.req.StartTS = readTS + plan.req.CommitTS = commitTS + if dispErr := d.commitItemWrite(ctx, plan.req); dispErr != nil { + // dispErr is already wrapped by commitItemWrite; return it raw. + if isRetryableTransactWriteError(dispErr) { + return nil, &reusableItemWrite{ + plan: plan, + commitTS: commitTS, + probeKey: kv.PrimaryKeyForElems(plan.req.Elems), + }, dispErr + } + return nil, nil, dispErr + } + return d.finishItemWriteAttempt(ctx, tableName, plan) +} + +// itemWriteReuseAttempt runs one reuse iteration: re-dispatch the captured write +// set under a fresh commit_ts carrying pending.commitTS as PrevCommitTS, so the +// FSM probes whether the prior attempt landed. +func (d *DynamoDBServer) itemWriteReuseAttempt( + ctx context.Context, + tableName string, + pending *reusableItemWrite, +) (*itemWritePlan, *reusableItemWrite, error) { + commitTS, err := d.coordinator.Clock().NextFenced() + if err != nil { + return nil, pending, errors.Wrap(err, "dynamodb item-write reuse: allocate commitTS") + } + pending.plan.req.CommitTS = commitTS + pending.plan.req.PrevCommitTS = pending.commitTS + dispErr := d.commitItemWrite(ctx, pending.plan.req) + if dispErr == nil { + return d.finishItemWriteAttempt(ctx, tableName, pending.plan) + } + if errors.Is(dispErr, store.ErrWriteConflict) { + return d.resolveReuseWriteConflict(ctx, tableName, pending, commitTS, dispErr) + } + if isRetryableTransactWriteError(dispErr) { + // Still ambiguous (e.g. TxnLocked): this reuse may itself have landed, + // so the next retry must probe THIS commit_ts. dispErr is already + // wrapped by commitItemWrite; return it raw. + pending.commitTS = commitTS + return nil, pending, dispErr + } + return nil, nil, dispErr +} + +// resolveReuseWriteConflict handles a WriteConflict from a reuse dispatch via +// the self-inflicted-conflict guard: probe whether THIS reuse's commit_ts +// actually landed (the apply may have committed but surfaced WriteConflict under +// churn). On a hit the conflict is against our own commit — return the cached +// plan, no double-apply. On a miss the write key is genuinely held by another +// txn — drop pending so the next iteration recomputes from a fresh read. +func (d *DynamoDBServer) resolveReuseWriteConflict( + ctx context.Context, + tableName string, + pending *reusableItemWrite, + commitTS uint64, + dispErr error, +) (*itemWritePlan, *reusableItemWrite, error) { + if len(pending.probeKey) > 0 { + landed, perr := d.store.CommittedVersionAt(ctx, pending.probeKey, commitTS) + if perr != nil { + // Fail closed: a probe read error makes "did our reuse land?" + // unknowable, and a blind recompute would double-append if it HAD + // landed. Surface the probe error instead of silently recomputing, + // matching the FSM-side dedupProbeOnePhase (kv/fsm.go) which also + // propagates probe errors. The wrapped error is non-retryable, so + // the loop returns it to the client rather than re-applying. + return nil, nil, errors.Wrap(perr, "dynamodb item-write: self-conflict probe") + } + if landed { + // The reuse landed at commitTS. Run the SAME generation fence + + // cleanup the normal success path runs (finishItemWriteAttempt), so + // a table dropped/recreated under us cleans up the landed write and + // recomputes instead of returning a stale plan (coderabbit major). + return d.finishItemWriteAttempt(ctx, tableName, pending.plan) + } + } + // Probe missed (or no probe key): a genuine cross-writer conflict. dispErr + // is already wrapped by commitItemWrite; return it raw so the loop recomputes. + return nil, nil, dispErr +} + +// finishItemWriteAttempt runs the table-generation fence after a successful +// commit. Returns (plan, nil, nil) when the write is durable; (nil, nil, nil) +// when the generation changed and the caller must recompute from a fresh read; +// (nil, nil, err) on a fence error. +func (d *DynamoDBServer) finishItemWriteAttempt( + ctx context.Context, + tableName string, + plan *itemWritePlan, +) (*itemWritePlan, *reusableItemWrite, error) { + retry, verifyErr := d.handleGenerationFenceResult( + ctx, + d.verifyTableGeneration(ctx, tableName, plan.generation), + plan.cleanup, + ) + if verifyErr != nil { + return nil, nil, verifyErr + } + if retry { + return nil, nil, nil + } + return plan, nil, nil +} + +func (d *DynamoDBServer) preparePutItemWrite(ctx context.Context, in putItemInput, readTS uint64) (*itemWritePlan, error) { + schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if !exists { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + itemKey, err := schema.itemKeyFromAttributes(in.Item) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + keyAttrs, err := primaryKeyAttributes(schema.PrimaryKey, in.Item) + if err != nil { + return nil, err + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, keyAttrs, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var current map[string]attributeValue + if found { + current = currentLocation.item + } + if err := validateConditionOnItem( + in.ConditionExpression, + in.ExpressionAttributeNames, + in.ExpressionAttributeValues, + valueOrEmptyMap(current, found), + ); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) + } + req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, in.Item, currentLocation) + if err != nil { + return nil, err + } + return &itemWritePlan{ + req: req, + generation: schema.Generation, + cleanup: cleanup, + current: cloneAttributeValueMap(current), + next: cloneAttributeValueMap(in.Item), + }, nil +} + +func (d *DynamoDBServer) commitItemWrite(ctx context.Context, req *kv.OperationGroup[kv.OP]) error { + _, err := d.coordinator.Dispatch(ctx, req) + if err != nil { + return errors.WithStack(err) + } + return nil +} + +func (d *DynamoDBServer) deleteItem(w http.ResponseWriter, r *http.Request) { + in, shouldReturnOld, err := decodeDeleteItemInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + lockKey, err := dynamoItemUpdateLockKey(in.TableName, in.Key) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return + } + unlock := d.lockItemUpdate(lockKey) + defer unlock() + plan, err := d.deleteItemWithRetry(r.Context(), in) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + if len(plan.current) > 0 { + d.observeWrittenItems(r.Context(), in.TableName, 1) + } + resp := map[string]any{} + if shouldReturnOld && len(plan.current) > 0 { + resp["Attributes"] = plan.current + } + writeDynamoJSON(w, resp) +} + +func decodeDeleteItemInput(bodyReader io.Reader) (deleteItemInput, bool, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return deleteItemInput{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in deleteItemInput + if err := json.Unmarshal(body, &in); err != nil { + return deleteItemInput{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return deleteItemInput{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + shouldReturnOld, err := parseDeleteItemReturnValues(in.ReturnValues) + if err != nil { + return deleteItemInput{}, false, err + } + return in, shouldReturnOld, nil +} + +func parseDeleteItemReturnValues(returnValues string) (bool, error) { + switch strings.TrimSpace(returnValues) { + case "", dynamoReturnValueNone: + return false, nil + case dynamoReturnValueAllOld: + return true, nil + default: + return false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported ReturnValues") + } +} + +type deleteItemPlan struct { + req *kv.OperationGroup[kv.OP] + generation uint64 + current map[string]attributeValue +} + +func (d *DynamoDBServer) deleteItemWithRetry(ctx context.Context, in deleteItemInput) (*deleteItemPlan, error) { + var deletePlan *deleteItemPlan + _, err := d.retryItemWriteWithGeneration( + ctx, + in.TableName, + "delete retry attempts exhausted", + func(readTS uint64) (*itemWritePlan, error) { + var err error + deletePlan, err = d.prepareDeleteItemWrite(ctx, in, readTS) + if err != nil { + return nil, err + } + return &itemWritePlan{ + req: deletePlan.req, + generation: deletePlan.generation, + }, nil + }, + ) + if err != nil { + return nil, err + } + return deletePlan, nil +} + +func (d *DynamoDBServer) prepareDeleteItemWrite(ctx context.Context, in deleteItemInput, readTS uint64) (*deleteItemPlan, error) { + schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if !exists { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + current := map[string]attributeValue(nil) + if found { + current = currentLocation.item + } + if err := validateConditionOnItem( + in.ConditionExpression, + in.ExpressionAttributeNames, + in.ExpressionAttributeValues, + valueOrEmptyMap(current, found), + ); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) + } + if !found { + return &deleteItemPlan{current: nil}, nil + } + req, err := buildItemDeleteRequestWithSource(currentLocation) + if err != nil { + return nil, err + } + return &deleteItemPlan{ + req: req, + generation: schema.Generation, + current: cloneAttributeValueMap(current), + }, nil +} + +func (d *DynamoDBServer) updateItem(w http.ResponseWriter, r *http.Request) { + in, err := decodeUpdateItemInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + lockKey, err := dynamoItemUpdateLockKey(in.TableName, in.Key) + if err != nil { + writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) + return + } + unlock := d.lockItemUpdate(lockKey) + defer unlock() + plan, err := d.updateItemWithRetry(r.Context(), in) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + d.observeWrittenItems(r.Context(), in.TableName, 1) + resp := map[string]any{} + if attrs := updateItemReturnAttributes(in.ReturnValues, plan.current, plan.next); len(attrs) > 0 { + resp["Attributes"] = attrs + } + writeDynamoJSON(w, resp) +} + +func decodeUpdateItemInput(bodyReader io.Reader) (updateItemInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return updateItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in updateItemInput + if err := json.Unmarshal(body, &in); err != nil { + return updateItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return updateItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + if err := validateUpdateItemReturnValues(in.ReturnValues); err != nil { + return updateItemInput{}, err + } + return in, nil +} + +func validateUpdateItemReturnValues(returnValues string) error { + switch strings.TrimSpace(returnValues) { + case "", + dynamoReturnValueNone, + dynamoReturnValueAllOld, + dynamoReturnValueUpdatedOld, + dynamoReturnValueAllNew, + dynamoReturnValueUpdatedNew: + return nil + default: + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported ReturnValues") + } +} + +func (d *DynamoDBServer) updateItemWithRetry(ctx context.Context, in updateItemInput) (*itemWritePlan, error) { + return d.retryItemWriteWithGeneration( + ctx, + in.TableName, + "update retry attempts exhausted", + func(readTS uint64) (*itemWritePlan, error) { + return d.prepareUpdateItemWrite(ctx, in, readTS) + }, + ) +} + +func (d *DynamoDBServer) prepareUpdateItemWrite(ctx context.Context, in updateItemInput, readTS uint64) (*itemWritePlan, error) { + schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if !exists { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + itemKey, err := schema.itemKeyFromAttributes(in.Key) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var current map[string]attributeValue + if !found { + current = map[string]attributeValue{} + } else { + current = currentLocation.item + } + nextItem, err := buildUpdatedItem(schema, in, current) + if err != nil { + return nil, err + } + req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, nextItem, currentLocation) + if err != nil { + return nil, err + } + return &itemWritePlan{ + req: req, + generation: schema.Generation, + cleanup: cleanup, + current: cloneAttributeValueMap(current), + next: cloneAttributeValueMap(nextItem), + }, nil +} + +func buildUpdatedItem(schema *dynamoTableSchema, in updateItemInput, current map[string]attributeValue) (map[string]attributeValue, error) { + if err := validateConditionOnItem(in.ConditionExpression, in.ExpressionAttributeNames, in.ExpressionAttributeValues, current); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) + } + nextItem := cloneAttributeValueMap(current) + maps.Copy(nextItem, in.Key) + if err := applyUpdateExpression(in.UpdateExpression, in.ExpressionAttributeNames, in.ExpressionAttributeValues, nextItem); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if err := ensurePrimaryKeyUnchanged(schema.PrimaryKey, in.Key, nextItem); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return nextItem, nil +} + +func ensurePrimaryKeyUnchanged(keySchema dynamoKeySchema, originalKey map[string]attributeValue, nextItem map[string]attributeValue) error { + if err := ensureSinglePrimaryKeyUnchanged(keySchema.HashKey, originalKey, nextItem); err != nil { + return err + } + if keySchema.RangeKey != "" { + if err := ensureSinglePrimaryKeyUnchanged(keySchema.RangeKey, originalKey, nextItem); err != nil { + return err + } + } + return nil +} + +func ensureSinglePrimaryKeyUnchanged(attrName string, originalKey map[string]attributeValue, nextItem map[string]attributeValue) error { + keyVal, ok := originalKey[attrName] + if !ok { + return errors.New("missing key attribute") + } + nextVal, ok := nextItem[attrName] + if !ok { + return errors.New("cannot remove key attribute") + } + if !attributeValueEqual(keyVal, nextVal) { + return errors.New("cannot update primary key attribute") + } + return nil +} + +type dynamoItemLocation struct { + schema *dynamoTableSchema + key []byte + item map[string]attributeValue +} + +func buildItemWriteRequestWithSource( + targetSchema *dynamoTableSchema, + targetKey []byte, + nextItem map[string]attributeValue, + current *dynamoItemLocation, +) (*kv.OperationGroup[kv.OP], [][]byte, error) { + payload, err := encodeStoredDynamoItem(nextItem) + if err != nil { + return nil, nil, errors.WithStack(err) + } + elems := []*kv.Elem[kv.OP]{{Op: kv.Put, Key: targetKey, Value: payload}} + cleanup := [][]byte{targetKey} + delKeys, putKeys, err := itemStorageDelta(targetSchema, targetKey, nextItem, current) + if err != nil { + return nil, nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + for _, key := range delKeys { + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) + } + for _, key := range putKeys { + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: key, Value: targetKey}) + cleanup = append(cleanup, key) + } + return &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: elems, + }, cleanup, nil +} + +func itemStorageDelta( + targetSchema *dynamoTableSchema, + targetKey []byte, + nextItem map[string]attributeValue, + current *dynamoItemLocation, +) ([][]byte, [][]byte, error) { + oldKeys, err := itemStorageKeys(current) + if err != nil { + return nil, nil, err + } + newKeys, err := targetSchema.gsiEntryKeysForItem(nextItem) + if err != nil { + return nil, nil, err + } + newSet := bytesToSet(newKeys) + oldSet := bytesToSet(oldKeys) + delete(oldSet, string(targetKey)) + delKeys := make([][]byte, 0, len(oldKeys)) + for key, raw := range oldSet { + if _, ok := newSet[key]; ok { + continue + } + delKeys = append(delKeys, raw) + } + putKeys := make([][]byte, 0, len(newKeys)) + for key, raw := range newSet { + if _, ok := oldSet[key]; ok { + continue + } + putKeys = append(putKeys, raw) + } + return delKeys, putKeys, nil +} + +func itemStorageKeys(current *dynamoItemLocation) ([][]byte, error) { + if current == nil || len(current.item) == 0 { + return nil, nil + } + gsiKeys, err := current.schema.gsiEntryKeysForItem(current.item) + if err != nil { + return nil, err + } + out := make([][]byte, 0, len(gsiKeys)+1) + out = append(out, bytes.Clone(current.key)) + out = append(out, gsiKeys...) + return out, nil +} + +func bytesToSet(keys [][]byte) map[string][]byte { + out := make(map[string][]byte, len(keys)) + for _, key := range keys { + out[string(key)] = key + } + return out +} + +func putItemReturnAttributes(returnValues string, current map[string]attributeValue) map[string]attributeValue { + if !strings.EqualFold(strings.TrimSpace(returnValues), dynamoReturnValueAllOld) || len(current) == 0 { + return nil + } + return cloneAttributeValueMap(current) +} + +func updateItemReturnAttributes(returnValues string, current map[string]attributeValue, next map[string]attributeValue) map[string]attributeValue { + switch strings.TrimSpace(returnValues) { + case "", dynamoReturnValueNone: + return nil + case dynamoReturnValueAllOld: + if len(current) == 0 { + return nil + } + return cloneAttributeValueMap(current) + case dynamoReturnValueAllNew: + return cloneAttributeValueMap(next) + case dynamoReturnValueUpdatedOld: + return selectUpdatedAttributes(current, next, true) + case dynamoReturnValueUpdatedNew: + return selectUpdatedAttributes(current, next, false) + default: + return nil + } +} + +func selectUpdatedAttributes(current map[string]attributeValue, next map[string]attributeValue, oldValues bool) map[string]attributeValue { + keys := updatedAttributeNames(current, next) + if len(keys) == 0 { + return nil + } + out := make(map[string]attributeValue, len(keys)) + for _, key := range keys { + if oldValues { + if value, ok := current[key]; ok { + out[key] = value + } + continue + } + if value, ok := next[key]; ok { + out[key] = value + } + } + if len(out) == 0 { + return nil + } + return out +} + +func updatedAttributeNames(current map[string]attributeValue, next map[string]attributeValue) []string { + seen := make(map[string]struct{}, len(current)+len(next)) + for name := range current { + seen[name] = struct{}{} + } + for name := range next { + seen[name] = struct{}{} + } + names := make([]string, 0, len(seen)) + for name := range seen { + names = append(names, name) + } + sort.Strings(names) + out := make([]string, 0, len(names)) + for _, name := range names { + oldVal, oldOK := current[name] + newVal, newOK := next[name] + if !oldOK && !newOK { + continue + } + if oldOK && newOK && attributeValueEqual(oldVal, newVal) { + continue + } + out = append(out, name) + } + return out +} + +func buildItemDeleteRequestWithSource(current *dynamoItemLocation) (*kv.OperationGroup[kv.OP], error) { + if current == nil { + return &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: nil, + }, nil + } + elems := []*kv.Elem[kv.OP]{{Op: kv.Del, Key: current.key}} + delKeys, err := itemStorageKeys(current) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + for _, key := range delKeys { + if bytes.Equal(key, current.key) { + continue + } + elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) + } + return &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: elems, + }, nil +} diff --git a/adapter/dynamodb_lease.go b/adapter/dynamodb_lease.go new file mode 100644 index 00000000..fd8aed7c --- /dev/null +++ b/adapter/dynamodb_lease.go @@ -0,0 +1,522 @@ +package adapter + +import ( + "context" + "net/http" + "strings" + "sync" + + "github.com/bootjp/elastickv/kv" + "github.com/cockroachdb/errors" +) + +// leaseReadKeyless performs a keyless quorum-freshness lease check for +// multi-shard read handlers (Scan, GSI/whole-table Query fallback). These +// reads visit every shard the range intersects, so the check fences EVERY +// group the coordinator owns via LeaseReadAllGroupsThrough — a default-group- +// only lease would let a non-default group serve a stale snapshot. A +// single-group coordinator falls back to one LeaseRead, so single-group +// deployments still issue exactly one lease read. It bounds the wait with +// dynamoLeaseReadTimeout so a stalled Raft cannot hang the handler when the +// client never cancels, and writes the same InternalServerError that getItem +// produces on lease failure. Returns false after writing an error response; +// the caller should simply return. +// leaseReadKeyless fences every group via the keyless all-groups lease check. +// `leaseCtx` MUST be the SAME context the pre-pass armed (it bounds the entire +// pre-pass — schema read + the lease that lands here — by dynamoLeaseReadTimeout +// total; coderabbit Major on PR #952 round-4). Creating a fresh +// context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) here would re-arm +// the 5s budget per call, so a slow schema read followed by the keyless +// fallback could consume close to 10s end-to-end. Callers that do NOT have a +// pre-pass context must pass their own bounded ctx; r.Context() with the +// handler's own timeout-on-the-roundabout is the conservative choice. +func (d *DynamoDBServer) leaseReadKeyless(w http.ResponseWriter, leaseCtx context.Context) bool { + if err := kv.LeaseReadAllGroupsThrough(d.coordinator, leaseCtx); err != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) + return false + } + return true +} + +// leaseCheckScan runs the Scan lease pre-pass. A Scan reads the whole table, so +// a VALID scan must fence EVERY group via the keyless all-groups check. But a +// scan against a missing table, an unknown index, or a GSI with +// ConsistentRead=true never touches data: the read path rejects it with a +// deterministic 4xx, so establishing freshness is unnecessary and a failed +// all-groups fence on a degraded deployment would mask that 4xx with a 500 +// (codex #952 P2-A). leaseCheckScan therefore cheaply pre-validates the request +// (schema load + the same GSI read-option checks scanItems re-runs) at a +// tentative timestamp and skips the lease on a client-side validation error, +// while still failing closed (fencing every group) on a transient schema-read +// failure. Returns false after writing an error response; the caller returns. +func (d *DynamoDBServer) leaseCheckScan(w http.ResponseWriter, r *http.Request, in scanInput) bool { + // leaseCtx bounds the pre-validation schema read AND the lease read so a + // stalled schema read cannot block the handler past dynamoLeaseReadTimeout + // before the lease phase begins. leaseReadKeyless creates its own bounded + // context for the actual lease read. + leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) + defer leaseCancel() + schema, plan, err := d.multiShardReadLeasePlan(leaseCtx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) + if err != nil { + // Transient/internal schema-read failure: fail closed by fencing every + // group. leaseReadKeyless writes the same InternalServerError on its + // own failure. + return d.leaseReadKeyless(w, leaseCtx) + } + if plan == queryLeaseSkip { + // Client-side validation problem (table not found, unknown index, + // unsupported ConsistentRead): the read path re-runs the identical + // validation and surfaces the deterministic 4xx, so skip the lease so a + // degraded-lease failure cannot mask it with a 500 (codex #952 P2-A). + return true + } + // Malformed ExclusiveStartKey is a deterministic ValidationException the + // read path rejects in resolveTableReadBounds / resolveGSIReadBounds — + // before the iterator is constructed and before any store read. If we let + // the lease run first, a degraded shard's 500 would mask that 4xx + // (codex #952 P2 round-3). Pre-validate against the loaded schema and skip + // leasing on failure; the read path will surface the identical error. + if scanExclusiveStartKeyInvalid(schema, in) { + return true + } + // Same logic for a malformed ProjectionExpression: newReadPageState runs + // resolveProjectionAttributes before the iterator reads from the store, so a + // parse failure is a deterministic ValidationException the lease pre-pass + // must not mask (codex #952 P2 round-4 line 2346). validateGSIReadOptions + // already covers the GSI case; this catches the base-table path that the + // earlier ExclusiveStartKey check left exposed. + if projectionInvalid(in.ProjectionExpression, in.ExpressionAttributeNames) { + return true + } + // Valid whole-table read: fence every group (fail closed). + return d.leaseReadKeyless(w, leaseCtx) +} + +// multiShardReadLeasePlan cheaply classifies whether a multi-shard read (Scan or +// a GSI/whole-table Query) is a VALID data read that must fence every group +// (queryLeaseAllGroups) or a CLIENT-side validation problem the read path +// rejects identically without touching data (queryLeaseSkip). It performs the +// same table-existence and GSI read-option checks prepareReadSchema runs, at a +// tentative timestamp (schema only, no readTS sampling), so the lease pre-pass +// never masks a deterministic 4xx with a degraded-lease 500. A transient/internal +// schema-read failure is returned as an error so the caller fails closed. +// +// The loaded schema is returned (nil when the table is missing or on error) so +// callers that need a further deterministic validation (the GSI Query +// KeyConditionExpression check) can reuse it without a second schema load. +// +// Validation failures are reported via queryLeaseSkip rather than an error: the +// read path re-runs the same resolution and reports the identical validation +// error, so error mapping is unchanged. +func (d *DynamoDBServer) multiShardReadLeasePlan( + ctx context.Context, + tableName string, + indexName string, + selectValue string, + projectionExpression string, + names map[string]string, + consistentRead *bool, +) (*dynamoTableSchema, queryLeasePlan, error) { + tentativeTS := snapshotTS(d.coordinator.Clock(), d.store) + schema, exists, err := d.loadTableSchemaAt(ctx, tableName, tentativeTS) + if err != nil { + // loadTableSchemaAt maps ErrKeyNotFound to (_, false, nil); any error + // reaching here is a transient store/context/decode failure, so fail + // closed. + return nil, queryLeaseAllGroups, errors.WithStack(err) + } + if !exists { + // Table not found is a deterministic ResourceNotFoundException the read + // path produces without touching data; skip the lease. + return nil, queryLeaseSkip, nil + } + // validateGSIReadOptions runs the identical unknown-index / GSI + // ConsistentRead / projection checks prepareReadSchema performs. Any failure + // is a *dynamoAPIError (a deterministic ValidationException), so classify it + // as a skip; a transient failure is impossible here (no store access). + if err := validateGSIReadOptions(schema, indexName, selectValue, projectionExpression, names, consistentRead); err != nil { + if dynamoErrIsTransient(err) { + // Defensive: validateGSIReadOptions returns only *dynamoAPIError, so + // this is unreachable. Fail closed if a future change adds a + // transient path. + return nil, queryLeaseAllGroups, errors.WithStack(err) + } + return schema, queryLeaseSkip, nil + } + return schema, queryLeaseAllGroups, nil +} + +// leaseCheckQuery lease-checks the shard a Query reads with a bounded +// timeout, writing the same InternalServerError getItem produces on +// failure. When the request resolves to a single base-table hash-key +// prefix (the common case), the check is routed to that prefix's owning +// group via LeaseReadForKey so a multi-group deployment confirms the +// shard that actually holds the data — not the default group. GSI +// queries and any request whose prefix cannot be resolved here fall back +// to the keyless check, which establishes freshness across every shard +// the range can touch. Returns false after writing an error response; +// the caller should simply return. +func (d *DynamoDBServer) leaseCheckQuery(w http.ResponseWriter, r *http.Request, in queryInput) bool { + // leaseCtx bounds the entire pre-pass — the schema read that resolves + // the lease key and the lease read itself — so a stalled schema read + // cannot block the handler past dynamoLeaseReadTimeout before the lease + // phase begins. The keyless fallback creates its own bounded context. + leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) + defer leaseCancel() + leaseKey, plan, err := d.queryLeaseKey(leaseCtx, in) + if err != nil { + // Transient/internal schema read failure: the routing key could + // not be resolved, so fail closed by fencing EVERY group via the + // keyless check (a strict superset of the single group this query + // would have routed to). leaseReadKeyless writes the same + // InternalServerError on its own failure. + return d.leaseReadKeyless(w, leaseCtx) + } + switch plan { + case queryLeaseSkip: + // Client-side validation problem (table not found, malformed/ + // unsupported KeyConditionExpression): the request touches no data, + // so establishing freshness is unnecessary. Skip the lease entirely + // and let queryItems re-run the identical resolution and surface the + // deterministic ResourceNotFoundException/ValidationException — a + // lease failure on the fallback must not mask that 4xx with a 500 in + // a degraded deployment (codex #952 P2). This matches getItem, which + // writes the 4xx before any lease read. + return true + case queryLeaseAllGroups: + // GSI / whole-table query: a VALID read that spans multiple shards, + // so the keyless all-groups check is the correct fence (fail closed). + return d.leaseReadKeyless(w, leaseCtx) + case queryLeaseSingleGroup: + if _, err := kv.LeaseReadForKeyThrough(d.coordinator, leaseCtx, leaseKey); err != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) + return false + } + return true + default: + // Unreachable: queryLeaseKey only returns the three plans above. Fail + // closed via the all-groups fence rather than silently proceeding. + return d.leaseReadKeyless(w, leaseCtx) + } +} + +// queryLeasePlan classifies how a Query lease pre-pass must fence the read. +type queryLeasePlan int + +const ( + // queryLeaseSingleGroup: the query routes to exactly one shard group; + // fence that group via the resolved leaseKey. + queryLeaseSingleGroup queryLeasePlan = iota + // queryLeaseAllGroups: a VALID multi-shard read (GSI query or whole-table + // prefix); fence every group via the keyless all-groups check. + queryLeaseAllGroups + // queryLeaseSkip: a CLIENT-side validation problem (table not found, + // malformed/unsupported KeyConditionExpression) that the read path rejects + // deterministically without touching data; skip the lease so the handler's + // 4xx is never masked by a transient lease failure. + queryLeaseSkip +) + +// queryLeaseKey resolves the single hash-key prefix a base-table Query reads, +// at a tentative timestamp (schema only, no readTS sampling), so the lease +// check can be routed to the owning shard group. It returns: +// - (prefix, queryLeaseSingleGroup, nil) when the query routes to exactly +// one shard group; +// - (nil, queryLeaseAllGroups, nil) for a VALID multi-shard read (GSI query +// or whole-table prefix) the caller must fence across every group; +// - (nil, queryLeaseSkip, nil) for a CLIENT-side validation problem (table +// not found, unknown index, GSI ConsistentRead, malformed/unsupported +// KeyConditionExpression) the read path rejects identically — the caller +// skips the lease so the deterministic 4xx is not masked by a transient +// lease failure (codex #952 P2). GSI queries are validated against the +// table schema here before being classified as a multi-shard read so an +// invalid index can never trigger the all-groups fence (codex #952 P2-B); +// - (nil, _, err) for a TRANSIENT/INTERNAL schema-read failure (leaseCtx +// deadline, Pebble error) so the caller fails closed. +// +// Validation failures are reported via queryLeaseSkip rather than an error: the +// read path re-runs the same resolution and reports the identical validation +// error, so error mapping is unchanged. +func (d *DynamoDBServer) queryLeaseKey(ctx context.Context, in queryInput) ([]byte, queryLeasePlan, error) { + if strings.TrimSpace(in.IndexName) != "" { + // A GSI query is a multi-shard read, but only when it passes the same + // validation the read path runs: a query against a missing table, + // unknown index, GSI ConsistentRead, or malformed KeyConditionExpression + // touches no data and the read path rejects it with a deterministic 4xx. + // Fencing every group before that validation would mask the 4xx with a + // degraded-lease 500, so classify those as a skip (codex #952 P2-B). + schema, plan, err := d.multiShardReadLeasePlan(ctx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) + if err != nil { + return nil, queryLeaseAllGroups, errors.WithStack(err) + } + if plan == queryLeaseSkip { + return nil, queryLeaseSkip, nil + } + // Malformed ExclusiveStartKey is a deterministic ValidationException the + // read path rejects before the iterator is constructed (codex #952 P2 + // round-3). Skip leasing on failure so a degraded shard cannot mask + // that 4xx with a 500. + if queryExclusiveStartKeyInvalid(schema, in) { + return nil, queryLeaseSkip, nil + } + // Malformed ProjectionExpression is the same kind of deterministic + // ValidationException newReadPageState raises before the iterator + // touches data (codex #952 P2 round-4 line 2492); skip the lease so a + // degraded shard cannot mask it. + if projectionInvalid(in.ProjectionExpression, in.ExpressionAttributeNames) { + return nil, queryLeaseSkip, nil + } + // Schema + GSI options are valid; the KeyConditionExpression is the last + // deterministic validation the read path runs before touching data. + return nil, gsiQueryLeasePlan(in, schema), nil + } + tentativeTS := snapshotTS(d.coordinator.Clock(), d.store) + schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, tentativeTS) + if err != nil { + // loadTableSchemaAt maps ErrKeyNotFound to (_, false, nil); any + // error reaching here is a transient store/context/decode failure, + // so fail closed. + return nil, queryLeaseSingleGroup, errors.WithStack(err) + } + if !exists { + // Table not found is a deterministic ResourceNotFoundException the + // read path produces without touching data; skip the lease. + return nil, queryLeaseSkip, nil + } + // Same ExclusiveStartKey pre-check as the GSI branch above (base table). + if queryExclusiveStartKeyInvalid(schema, in) { + return nil, queryLeaseSkip, nil + } + // Same ProjectionExpression pre-check (base-table path; codex #952 P2 round-4 + // line 2492). + if projectionInvalid(in.ProjectionExpression, in.ExpressionAttributeNames) { + return nil, queryLeaseSkip, nil + } + prefix, plan := queryLeasePrefix(in, schema) + return prefix, plan, nil +} + +// queryLeasePrefix resolves the single hash-key prefix a base-table Query +// reads, classifying the read into queryLeaseSingleGroup (resolved prefix), +// queryLeaseAllGroups (whole-table prefix: a valid multi-shard read), or +// queryLeaseSkip (malformed KeyConditionExpression: a validation error the +// read path rejects identically). The validation error is deliberately not +// surfaced — only the routing classification matters here, and the read path +// reports the identical error downstream. +func queryLeasePrefix(in queryInput, schema *dynamoTableSchema) ([]byte, queryLeasePlan) { + keySchema, cond, err := resolveQueryCondition(in, schema) + if err != nil { + // Malformed/unsupported KeyConditionExpression: a deterministic + // ValidationException the read path produces without touching data. + return nil, queryLeaseSkip + } + // A query whose key schema hash key differs from the primary hash + // key reads the whole-table prefix (see queryScanPrefix), which can + // span multiple shards; let the all-groups check cover them. + if keySchema.HashKey != schema.PrimaryKey.HashKey { + return nil, queryLeaseAllGroups + } + prefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) + if err != nil { + // queryScanPrefix only fails on an unparseable hash-key value — a + // ValidationException the read path rejects identically. + return nil, queryLeaseSkip + } + return prefix, queryLeaseSingleGroup +} + +// gsiQueryLeasePlan classifies a GSI Query (already known to name a valid index +// on an existing table) as the multi-shard all-groups read it is, unless its +// KeyConditionExpression is malformed — the last deterministic validation the +// read path runs before touching data. resolveQueryCondition does no store +// access and returns only *dynamoAPIError, so a failure is a ValidationException +// the read path rejects identically; classify it as a skip so the lease pre-pass +// cannot mask that 4xx with a degraded-lease 500 (codex #952 P2-B). Like +// queryLeasePrefix, the validation error is deliberately not surfaced — only the +// routing classification matters here. +func gsiQueryLeasePlan(in queryInput, schema *dynamoTableSchema) queryLeasePlan { + if _, _, err := resolveQueryCondition(in, schema); err != nil { + return queryLeaseSkip + } + return queryLeaseAllGroups +} + +// leaseCheckTransactGetItems performs a quorum-freshness lease check on every +// shard the TransactGetItems request will read, with a bounded timeout, BEFORE +// the caller resolves the single snapshot timestamp. Item keys are resolved at a +// tentative timestamp (schemas change rarely, so a slight pre-lease stale schema +// is acceptable) used only to route the lease check; the actual snapshot +// timestamp is sampled by the caller afterwards. Items whose schema or key +// cannot be resolved here are skipped: they never reach a store read, and +// buildTransactGetItemsResponses surfaces the identical validation error +// downstream so error mapping is unchanged. When every item is skipped no +// shard is touched, so the function returns true without a lease read. +// +// Keys are first deduplicated by value, then collapsed to one representative key +// per owning Raft group, so a transaction touching up to transactGetItemsMaxItems +// keys that share a group issues a single lease read instead of one per key. +// Each group maintains its own lease, so checking one key per group still +// establishes freshness for every shard the transaction reads. Returns false +// after writing the same InternalServerError getItem produces on lease failure; +// the caller should simply return. +func (d *DynamoDBServer) leaseCheckTransactGetItems(w http.ResponseWriter, r *http.Request, in transactGetItemsInput) bool { + // leaseCtx bounds the entire pre-pass — both the per-item schema reads + // that resolve keys and the lease reads themselves — so a stalled + // schema read (Pebble backpressure, iterator leak) cannot block the + // handler past dynamoLeaseReadTimeout before the lease phase begins. + leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) + defer leaseCancel() + uniqueKeys, skipLease, transientErr := d.resolveTransactGetItemKeys(leaseCtx, in) + if transientErr != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, transientErr.Error()) + return false + } + if skipLease { + return true + } + groupKeys := kv.LeaseReadGroupKeys(d.coordinator, uniqueKeys) + if leaseErr := d.leaseReadGroupKeys(leaseCtx, groupKeys); leaseErr != nil { + writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, leaseErr.Error()) + return false + } + return true +} + +// resolveTransactGetItemKeys runs the per-item schema resolution that the +// lease pre-pass needs. Returns (uniqueKeys, skipLease, transientErr) where +// skipLease is true when the read path will surface a deterministic 4xx via +// buildTransactGetItemsResponses without touching any store — leasing the +// valid items in that case only risks masking that 4xx with a degraded-shard +// 500 (codex P2 #952). skipLease covers three cases: +// - (a) every item was malformed (nothing to fence) +// - (b) at least one item was malformed and at least one was valid +// (buildTransactGetItemsResponses returns a ValidationException for the +// malformed item; the valid items never reach a store read) +// - (c) the request contains a duplicate (table, key) pair — DynamoDB +// rejects this with `Transaction request cannot include multiple +// operations on one item`, a deterministic ValidationException the read +// path produces before touching data, so the lease must be skipped for +// the same reason malformed-mixed-with-valid is skipped. +// +// transientErr is the schema-read failure the caller MUST fail closed on +// (CLAUDE.md: the slow conditions the fence targets are exactly when a +// silently-dropped item would let a stale snapshot through). +func (d *DynamoDBServer) resolveTransactGetItemKeys(ctx context.Context, in transactGetItemsInput) ([][]byte, bool, error) { + tentativeTS := snapshotTS(d.coordinator.Clock(), d.store) + schemaCache := make(map[string]*dynamoTableSchema) + seenKeys := make(map[string]struct{}, len(in.TransactItems)) + uniqueKeys := make([][]byte, 0, len(in.TransactItems)) + hasMalformed := false + hasDuplicate := false + for _, item := range in.TransactItems { + itemKey, ok, err := d.transactGetItemKey(ctx, item, schemaCache, tentativeTS) + if err != nil { + return nil, false, err + } + if !ok { + hasMalformed = true + continue + } + if _, dup := seenKeys[string(itemKey)]; dup { + hasDuplicate = true + continue + } + seenKeys[string(itemKey)] = struct{}{} + uniqueKeys = append(uniqueKeys, itemKey) + } + if hasMalformed || hasDuplicate || len(uniqueKeys) == 0 { + return nil, true, nil + } + return uniqueKeys, false, nil +} + +// leaseReadGroupKeys fences every group whose key appears in groupKeys. The +// single-group case stays on the calling goroutine; multi-group fan-out is +// concurrent so a 100-item TransactGetItems does not serialize into 100 Raft +// round-trips and blow dynamoLeaseReadTimeout (gemini HIGH on PR #952). The +// fan-out is bounded by len(groupKeys) ≤ transactGetItemsMaxItems (100), so a +// per-call goroutine pool is unnecessary. Returns the first error seen across +// all goroutines (the rest are dropped to preserve the single-response +// contract at the HTTP layer). +func (d *DynamoDBServer) leaseReadGroupKeys(ctx context.Context, groupKeys [][]byte) error { + if len(groupKeys) == 0 { + return nil + } + if len(groupKeys) == 1 { + _, err := kv.LeaseReadForKeyThrough(d.coordinator, ctx, groupKeys[0]) + return errors.WithStack(err) + } + // Derive a cancellable child so the first error cancels the sibling lease + // reads instead of letting them run out the full dynamoLeaseReadTimeout + // budget (coderabbit Major on PR #952 round-4). The siblings observe the + // cancellation via the LeaseReadForKeyThrough's own context check. + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + errCh := make(chan error, len(groupKeys)) + var wg sync.WaitGroup + for _, itemKey := range groupKeys { + wg.Add(1) + go func(k []byte) { + defer wg.Done() + if _, err := kv.LeaseReadForKeyThrough(d.coordinator, cancelCtx, k); err != nil { + select { + case errCh <- err: + cancel() // unwind the remaining goroutines on the first error. + default: + } + } + }(itemKey) + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return err + } + } + return nil +} + +// transactGetItemKey resolves the storage key for one TransactGetItems Get at +// the tentative timestamp. It returns (key, true, nil) on success, +// (nil, false, nil) when the item is MALFORMED (nil Get, empty/unknown table, +// or an invalid key) — the read path rejects those identically, so the lease +// pre-pass may safely skip them — and (nil, false, err) for a TRANSIENT or +// INTERNAL schema-read failure (leaseCtx deadline, Pebble error) that the +// caller MUST fail closed on rather than skip, otherwise the item's shard goes +// unfenced and a stale read can slip through. The malformed/transient split +// keys off dynamoErrIsTransient: validation errors are *dynamoAPIError, +// everything else is treated as transient. It never writes a response: +// validation is left to the read path so error mapping stays identical. +func (d *DynamoDBServer) transactGetItemKey(ctx context.Context, item transactGetItem, schemaCache map[string]*dynamoTableSchema, tentativeTS uint64) ([]byte, bool, error) { + if item.Get == nil || strings.TrimSpace(item.Get.TableName) == "" { + return nil, false, nil + } + schema, err := d.resolveTransactTableSchema(ctx, schemaCache, item.Get.TableName, tentativeTS) + if err != nil { + if dynamoErrIsTransient(err) { + return nil, false, errors.WithStack(err) + } + // Validation error (table not found): the read path rejects it + // identically, so skip rather than fail closed. + return nil, false, nil + } + // itemKeyFromAttributes only fails on malformed key attributes + // (missing/unparseable hash or range key), a pure validation error the + // read path rejects identically; transactGetItemKeyFromSchema swallows + // it to ok=false so the item is skipped, not failed closed. + itemKey, ok := transactGetItemKeyFromSchema(schema, item.Get.Key) + return itemKey, ok, nil +} + +// transactGetItemKeyFromSchema computes the storage key for a TransactGetItems +// Get, returning ok=false when the key attributes are malformed. The error is +// deliberately discarded: it is a validation failure the read path reports +// downstream, and the lease pre-pass only needs the routing key. +func transactGetItemKeyFromSchema(schema *dynamoTableSchema, key map[string]attributeValue) ([]byte, bool) { + itemKey, err := schema.itemKeyFromAttributes(key) + if err != nil { + return nil, false + } + return itemKey, true +} diff --git a/adapter/dynamodb_query_scan.go b/adapter/dynamodb_query_scan.go new file mode 100644 index 00000000..90db633c --- /dev/null +++ b/adapter/dynamodb_query_scan.go @@ -0,0 +1,2281 @@ +package adapter + +import ( + "bytes" + "context" + "encoding/base64" + "io" + "net/http" + "sort" + "strings" + "sync" + + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +type queryInput struct { + TableName string `json:"TableName"` + IndexName string `json:"IndexName"` + KeyConditionExpression string `json:"KeyConditionExpression"` + FilterExpression string `json:"FilterExpression"` + ProjectionExpression string `json:"ProjectionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` + ScanIndexForward *bool `json:"ScanIndexForward"` + Limit *int32 `json:"Limit"` + ExclusiveStartKey map[string]attributeValue `json:"ExclusiveStartKey"` + Select string `json:"Select"` + ConsistentRead *bool `json:"ConsistentRead"` +} + +type scanInput struct { + TableName string `json:"TableName"` + IndexName string `json:"IndexName"` + FilterExpression string `json:"FilterExpression"` + ProjectionExpression string `json:"ProjectionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` + ExclusiveStartKey map[string]attributeValue `json:"ExclusiveStartKey"` + Limit *int32 `json:"Limit"` + Select string `json:"Select"` + ConsistentRead *bool `json:"ConsistentRead"` +} + +func (d *DynamoDBServer) query(w http.ResponseWriter, r *http.Request) { + in, err := decodeQueryInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + // Lease-check the shard the Query reads BEFORE queryItems samples + // readTS, so the quorum-freshness bound is established without + // changing read-snapshot semantics (sampling readTS only after the + // confirmation keeps any commit that landed before it visible). A + // base-table Query on a single partition key reads exactly one + // hash-key prefix, which routes to one shard group, so the check is + // routed by that prefix in a multi-group deployment. GSI queries and + // queries whose prefix cannot be resolved fall back to the keyless + // check, which spans every shard the range can touch. + if !d.leaseCheckQuery(w, r, in) { + return + } + out, err := d.queryItems(r.Context(), in) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + d.observeReadMetrics(r.Context(), in.TableName, out.count, out.scannedCount) + resp := map[string]any{ + "Items": out.items, + "Count": out.count, + "ScannedCount": out.scannedCount, + } + if len(out.lastEvaluatedKey) > 0 { + resp["LastEvaluatedKey"] = out.lastEvaluatedKey + } + writeDynamoJSON(w, resp) +} + +func (d *DynamoDBServer) scan(w http.ResponseWriter, r *http.Request) { + in, err := decodeScanInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + // A Scan reads the whole table and therefore spans every shard that + // holds any of its items. leaseCheckScan establishes the quorum-freshness + // bound across every group BEFORE scanItems samples readTS — but only for + // a request that passes the cheap table/GSI validation, so a scan that + // will deterministically 4xx is not masked by a degraded-lease 500 + // (codex #952 P2-A). + if !d.leaseCheckScan(w, r, in) { + return + } + out, err := d.scanItems(r.Context(), in) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + d.observeReadMetrics(r.Context(), in.TableName, out.count, out.scannedCount) + resp := map[string]any{ + "Items": out.items, + "Count": out.count, + "ScannedCount": out.scannedCount, + } + if len(out.lastEvaluatedKey) > 0 { + resp["LastEvaluatedKey"] = out.lastEvaluatedKey + } + writeDynamoJSON(w, resp) +} + +// projectionInvalid returns true when the ProjectionExpression cannot be +// parsed against the given ExpressionAttributeNames. resolveProjectionAttributes +// is the same validator newReadPageState runs before the iterator reads from +// the store, so a true result means the read path WILL reject the request with +// a deterministic ValidationException without touching data. Pre-pass uses +// this to skip leasing in that case (codex #952 P2 round-4 lines 2346, 2492). +// An empty ProjectionExpression is the common "project everything" case and +// returns false (no validation needed). +func projectionInvalid(projectionExpression string, names map[string]string) bool { + if strings.TrimSpace(projectionExpression) == "" { + return false + } + _, err := resolveProjectionAttributes(projectionExpression, names) + return err != nil +} + +// scanExclusiveStartKeyInvalid returns true when in.ExclusiveStartKey cannot be +// decoded against the table's primary key (Scan with no IndexName) or the named +// GSI (Scan with IndexName). It mirrors the validation resolveTableReadBounds / +// resolveGSIReadBounds run in scanItems so the lease pre-pass can route the +// invalid case to the same skip-lease path as table-not-found etc. A nil schema +// is treated as "not invalid" because multiShardReadLeasePlan already classified +// the request as queryLeaseSkip in that case and we never reach here. +func scanExclusiveStartKeyInvalid(schema *dynamoTableSchema, in scanInput) bool { + if schema == nil || len(in.ExclusiveStartKey) == 0 { + return false + } + if strings.TrimSpace(in.IndexName) == "" { + _, err := schema.itemKeyFromAttributes(in.ExclusiveStartKey) + return err != nil + } + _, _, err := schema.gsiKeyFromAttributes(in.IndexName, in.ExclusiveStartKey) + return err != nil +} + +// queryExclusiveStartKeyInvalid mirrors the validation +// resolveQueryExclusiveStartKey runs inside queryItems' read-bounds resolution +// (`adapter/dynamodb.go` resolveQueryExclusiveStartKey / resolveTableReadBounds / +// resolveGSIReadBounds): a malformed ExclusiveStartKey produces a deterministic +// ValidationException without touching any store. Returning true routes the +// lease pre-pass to queryLeaseSkip so a degraded-shard 500 cannot mask that 4xx +// (codex #952 P2 round-3). Mirrors scanExclusiveStartKeyInvalid for the +// Query path — kept separate because the GSI vs base-table dispatch differs +// from the Scan input. +func queryExclusiveStartKeyInvalid(schema *dynamoTableSchema, in queryInput) bool { + if schema == nil || len(in.ExclusiveStartKey) == 0 { + return false + } + if strings.TrimSpace(in.IndexName) == "" { + _, err := schema.itemKeyFromAttributes(in.ExclusiveStartKey) + return err != nil + } + _, _, err := schema.gsiKeyFromAttributes(in.IndexName, in.ExclusiveStartKey) + return err != nil +} + +func decodeQueryInput(bodyReader io.Reader) (queryInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return queryInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in queryInput + if err := json.Unmarshal(body, &in); err != nil { + return queryInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return queryInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + if err := validateReadSelect(in.Select); err != nil { + return queryInput{}, err + } + if _, _, err := resolveReadLimit(in.Limit); err != nil { + return queryInput{}, err + } + return in, nil +} + +func decodeScanInput(bodyReader io.Reader) (scanInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return scanInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in scanInput + if err := json.Unmarshal(body, &in); err != nil { + return scanInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.TableName) == "" { + return scanInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + if err := validateReadSelect(in.Select); err != nil { + return scanInput{}, err + } + if _, _, err := resolveReadLimit(in.Limit); err != nil { + return scanInput{}, err + } + return in, nil +} + +type queryOutput struct { + items []map[string]attributeValue + count int + scannedCount int + lastEvaluatedKey map[string]attributeValue +} + +type readPageOptions struct { + filterExpression string + projectionExpression string + expressionAttributeNames map[string]string + expressionAttributeValues map[string]attributeValue + exclusiveStartKey map[string]attributeValue + limit *int32 + selectValue string + lastEvaluatedKeyBuilder func(map[string]attributeValue) map[string]attributeValue +} + +type dynamoReadIterator interface { + Next(context.Context) (map[string]attributeValue, bool, error) +} + +type queryRangeOperator string + +const ( + queryRangeOpEqual queryRangeOperator = "=" + queryRangeOpLessThan queryRangeOperator = "<" + queryRangeOpLessOrEq queryRangeOperator = "<=" + queryRangeOpGreater queryRangeOperator = ">" + queryRangeOpGreaterEq queryRangeOperator = ">=" + queryRangeOpBetween queryRangeOperator = "BETWEEN" + queryRangeOpBeginsWith queryRangeOperator = "BEGINS_WITH" +) + +type queryRangeCondition struct { + attr string + op queryRangeOperator + value1 attributeValue + value2 attributeValue +} + +type queryCondition struct { + hashAttr string + hashValue attributeValue + rangeCond *queryRangeCondition +} + +func (d *DynamoDBServer) queryItems(ctx context.Context, in queryInput) (*queryOutput, error) { + schema, readTS, err := d.prepareReadSchema(ctx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) + if err != nil { + return nil, err + } + readPin := d.pinReadTS(readTS) + defer readPin.Release() + keySchema, cond, err := resolveQueryCondition(in, schema) + if err != nil { + return nil, err + } + opts := readPageOptions{ + filterExpression: in.FilterExpression, + projectionExpression: in.ProjectionExpression, + expressionAttributeNames: in.ExpressionAttributeNames, + expressionAttributeValues: in.ExpressionAttributeValues, + exclusiveStartKey: in.ExclusiveStartKey, + limit: in.Limit, + selectValue: in.Select, + lastEvaluatedKeyBuilder: func(item map[string]attributeValue) map[string]attributeValue { + return makeReadLastEvaluatedKey(schema.PrimaryKey, keySchema, item) + }, + } + if schema.MigratingFromGeneration == 0 { + if out, ok, err := d.streamQueryItems(ctx, in, schema, keySchema, cond, readTS, opts); ok || err != nil { + return out, err + } + } + items, err := d.loadQueryItemsWithMigration(ctx, in, schema, keySchema, cond, readTS) + if err != nil { + return nil, err + } + items, err = projectReadItemsForIndex(schema, in.IndexName, items) + if err != nil { + return nil, err + } + orderQueryItems(items, keySchema.RangeKey, in.ScanIndexForward) + return finalizeReadPage(schema, items, opts) +} + +func (d *DynamoDBServer) scanItems(ctx context.Context, in scanInput) (*queryOutput, error) { + schema, readTS, err := d.prepareReadSchema(ctx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) + if err != nil { + return nil, err + } + readPin := d.pinReadTS(readTS) + defer readPin.Release() + indexKeySchema := schema.PrimaryKey + if strings.TrimSpace(in.IndexName) != "" { + indexKeySchema, err = schema.keySchemaForQuery(in.IndexName) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + } + opts := readPageOptions{ + filterExpression: in.FilterExpression, + projectionExpression: in.ProjectionExpression, + expressionAttributeNames: in.ExpressionAttributeNames, + expressionAttributeValues: in.ExpressionAttributeValues, + exclusiveStartKey: in.ExclusiveStartKey, + limit: in.Limit, + selectValue: in.Select, + lastEvaluatedKeyBuilder: func(item map[string]attributeValue) map[string]attributeValue { + return makeReadLastEvaluatedKey(schema.PrimaryKey, indexKeySchema, item) + }, + } + if schema.MigratingFromGeneration == 0 { + if out, ok, err := d.streamScanItems(ctx, in, schema, readTS, opts); ok || err != nil { + return out, err + } + } + items, err := d.loadScanItemsWithMigration(ctx, in, schema, indexKeySchema, readTS) + if err != nil { + return nil, err + } + items, err = projectReadItemsForIndex(schema, in.IndexName, items) + if err != nil { + return nil, err + } + return finalizeReadPage(schema, items, opts) +} + +func (d *DynamoDBServer) prepareReadSchema( + ctx context.Context, + tableName string, + indexName string, + selectValue string, + projectionExpression string, + names map[string]string, + consistentRead *bool, +) (*dynamoTableSchema, uint64, error) { + if err := d.ensureLegacyTableMigration(ctx, tableName); err != nil { + return nil, 0, err + } + readTS := d.resolveDynamoReadTS(consistentRead) + schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) + if err != nil { + return nil, 0, errors.WithStack(err) + } + if !exists { + return nil, 0, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + if err := validateGSIReadOptions(schema, indexName, selectValue, projectionExpression, names, consistentRead); err != nil { + return nil, 0, err + } + return schema, readTS, nil +} + +func (d *DynamoDBServer) loadQueryItemsWithMigration( + ctx context.Context, + in queryInput, + schema *dynamoTableSchema, + keySchema dynamoKeySchema, + cond queryCondition, + readTS uint64, +) ([]map[string]attributeValue, error) { + items, err := d.queryItemsByKeyCondition(ctx, in, schema, keySchema, cond, readTS) + if err != nil { + return nil, err + } + return d.mergeReadItemsFromSourceSchema(schema, keySchema, items, func(sourceSchema *dynamoTableSchema) ([]map[string]attributeValue, error) { + return d.queryItemsByKeyCondition(ctx, in, sourceSchema, keySchema, cond, readTS) + }) +} + +func (d *DynamoDBServer) loadScanItemsWithMigration( + ctx context.Context, + in scanInput, + schema *dynamoTableSchema, + indexKeySchema dynamoKeySchema, + readTS uint64, +) ([]map[string]attributeValue, error) { + items, err := d.scanItemsBySource(ctx, in, schema, readTS) + if err != nil { + return nil, err + } + return d.mergeReadItemsFromSourceSchema(schema, indexKeySchema, items, func(sourceSchema *dynamoTableSchema) ([]map[string]attributeValue, error) { + return d.scanItemsBySource(ctx, in, sourceSchema, readTS) + }) +} + +func (d *DynamoDBServer) mergeReadItemsFromSourceSchema( + schema *dynamoTableSchema, + orderKey dynamoKeySchema, + items []map[string]attributeValue, + loadSource func(*dynamoTableSchema) ([]map[string]attributeValue, error), +) ([]map[string]attributeValue, error) { + sourceSchema := schema.migrationSourceSchema() + if sourceSchema == nil { + return items, nil + } + sourceItems, err := loadSource(sourceSchema) + if err != nil { + return nil, err + } + return mergeMigratingReadItems(schema.PrimaryKey, orderKey, items, sourceItems) +} + +// consistentReadLatestTS is a read timestamp used for ConsistentRead=true reads. +// The value is far above any realistic HLC timestamp (~Unix-nanosecond range, +// ≪ 10^19), so reading at this TS from the leader's Pebble store returns the +// most recently committed version of any key. It avoids the noStartTS +// sentinel (^uint64(0)) used by the coordinator. +// +// This sentinel is used on BOTH the leader and followers: +// - On a follower, the read is proxied to the leader via proxyRawGet with +// ts=consistentReadLatestTS, so the leader reads the absolute latest version. +// - On the leader, the LeaderRoutedStore performs a linearizable read fence +// (ensuring all committed Raft entries are applied) and then reads locally +// at consistentReadLatestTS, returning the latest committed version. +// +// Using store.LastCommitTS() instead would introduce a TOCTOU race: the +// timestamp is captured before the linearizable fence, so a write committed +// after LastCommitTS() but applied during the fence would be missed. +const consistentReadLatestTS = ^uint64(0) - 1 + +func (d *DynamoDBServer) resolveDynamoReadTS(consistentRead *bool) uint64 { + if consistentRead != nil && *consistentRead { + return consistentReadLatestTS + } + return snapshotTS(d.coordinator.Clock(), d.store) +} + +func validateGSIReadOptions( + schema *dynamoTableSchema, + indexName string, + selectValue string, + projectionExpression string, + names map[string]string, + consistentRead *bool, +) error { + if strings.TrimSpace(indexName) == "" { + return nil + } + attrs, err := resolveProjectionAttributes(projectionExpression, names) + if err != nil { + return err + } + return validateProjectedGSIRead(schema, indexName, selectValue, attrs, consistentRead) +} + +func validateProjectedGSIRead( + schema *dynamoTableSchema, + indexName string, + selectValue string, + attrs []string, + consistentRead *bool, +) error { + if err := validateGSIConsistentRead(consistentRead); err != nil { + return err + } + allProjected, projected, err := schema.gsiProjectedAttributeSet(indexName) + if err != nil { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if err := validateGSISelectValue(selectValue, allProjected); err != nil { + return err + } + return validateProjectedAttributes(attrs, projected, allProjected) +} + +func validateGSIConsistentRead(consistentRead *bool) error { + if consistentRead == nil || !*consistentRead { + return nil + } + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "ConsistentRead is not supported on global secondary indexes") +} + +func validateGSISelectValue(selectValue string, allProjected bool) error { + if !strings.EqualFold(strings.TrimSpace(selectValue), "ALL_ATTRIBUTES") || allProjected { + return nil + } + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "ALL_ATTRIBUTES is not supported for this index projection") +} + +func validateProjectedAttributes(attrs []string, projected map[string]struct{}, allProjected bool) error { + if allProjected || len(attrs) == 0 { + return nil + } + for _, attr := range attrs { + if _, ok := projected[attr]; ok { + continue + } + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "requested attribute is not projected into index") + } + return nil +} + +func projectReadItemsForIndex(schema *dynamoTableSchema, indexName string, items []map[string]attributeValue) ([]map[string]attributeValue, error) { + if strings.TrimSpace(indexName) == "" || len(items) == 0 { + return items, nil + } + out := make([]map[string]attributeValue, 0, len(items)) + for _, item := range items { + projected, err := schema.projectItemForIndex(indexName, item) + if err != nil { + return nil, err + } + out = append(out, projected) + } + return out, nil +} + +func resolveQueryCondition(in queryInput, schema *dynamoTableSchema) (dynamoKeySchema, queryCondition, error) { + keySchema, err := schema.keySchemaForQuery(in.IndexName) + if err != nil { + return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + keyExpr, err := replaceNames(in.KeyConditionExpression, in.ExpressionAttributeNames) + if err != nil { + return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + parsed, err := parseKeyConditionExpression(keyExpr) + if err != nil { + return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + cond, err := buildQueryCondition(keySchema, parsed, in.ExpressionAttributeValues) + if err != nil { + return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return keySchema, cond, nil +} + +func filterQueryItems(kvs []*store.KVPair, cond queryCondition) ([]map[string]attributeValue, error) { + items := make([]map[string]attributeValue, 0, len(kvs)) + for _, kvp := range kvs { + item, err := decodeStoredDynamoItem(kvp.Value) + if err != nil { + return nil, err + } + if !matchesQueryCondition(item, cond) { + continue + } + items = append(items, item) + } + return items, nil +} + +func orderQueryItems(items []map[string]attributeValue, rangeKey string, scanIndexForward *bool) { + if rangeKey != "" { + sort.SliceStable(items, func(i, j int) bool { + return compareAttributeValueSortKey(items[i][rangeKey], items[j][rangeKey]) < 0 + }) + } + scanForward := true + if scanIndexForward != nil { + scanForward = *scanIndexForward + } + if !scanForward { + reverseItems(items) + } +} + +func mergeMigratingReadItems( + primaryKey dynamoKeySchema, + orderKey dynamoKeySchema, + preferred []map[string]attributeValue, + source []map[string]attributeValue, +) ([]map[string]attributeValue, error) { + if len(source) == 0 { + return preferred, nil + } + out := make([]map[string]attributeValue, 0, len(preferred)+len(source)) + seen := make(map[string]struct{}, len(preferred)+len(source)) + appendItem := func(item map[string]attributeValue) error { + identity, err := itemPrimaryIdentity(primaryKey, item) + if err != nil { + return err + } + if _, ok := seen[identity]; ok { + return nil + } + seen[identity] = struct{}{} + out = append(out, item) + return nil + } + for _, item := range preferred { + if err := appendItem(item); err != nil { + return nil, err + } + } + for _, item := range source { + if err := appendItem(item); err != nil { + return nil, err + } + } + sort.SliceStable(out, func(i, j int) bool { + return compareReadOrder(orderKey, primaryKey, out[i], out[j]) < 0 + }) + return out, nil +} + +func itemPrimaryIdentity(keySchema dynamoKeySchema, item map[string]attributeValue) (string, error) { + var b strings.Builder + if err := appendIdentityPart(&b, item, keySchema.HashKey); err != nil { + return "", err + } + if keySchema.RangeKey != "" { + if err := appendIdentityPart(&b, item, keySchema.RangeKey); err != nil { + return "", err + } + } + return b.String(), nil +} + +func appendIdentityPart(b *strings.Builder, item map[string]attributeValue, attrName string) error { + attr, ok := item[attrName] + if !ok { + return errors.New("missing key attribute") + } + key, err := attributeValueAsKeySegment(attr) + if err != nil { + return err + } + b.WriteString(attrName) + b.WriteByte('=') + b.WriteString(base64.RawURLEncoding.EncodeToString(key)) + b.WriteByte('|') + return nil +} + +func compareReadOrder(orderKey dynamoKeySchema, primaryKey dynamoKeySchema, left map[string]attributeValue, right map[string]attributeValue) int { + if cmp := compareAttributeValueByName(orderKey.HashKey, left, right); cmp != 0 { + return cmp + } + if orderKey.RangeKey != "" { + if cmp := compareAttributeValueByName(orderKey.RangeKey, left, right); cmp != 0 { + return cmp + } + } + if cmp := compareAttributeValueByName(primaryKey.HashKey, left, right); cmp != 0 { + return cmp + } + if primaryKey.RangeKey != "" { + if cmp := compareAttributeValueByName(primaryKey.RangeKey, left, right); cmp != 0 { + return cmp + } + } + return 0 +} + +func compareAttributeValueByName(attrName string, left map[string]attributeValue, right map[string]attributeValue) int { + if attrName == "" { + return 0 + } + leftAttr, leftOK := left[attrName] + rightAttr, rightOK := right[attrName] + switch { + case !leftOK && !rightOK: + return 0 + case !leftOK: + return -1 + case !rightOK: + return 1 + default: + return compareAttributeValueSortKey(leftAttr, rightAttr) + } +} + +func validateReadSelect(selectValue string) error { + switch strings.TrimSpace(selectValue) { + case "", "ALL_ATTRIBUTES", "ALL_PROJECTED_ATTRIBUTES", "SPECIFIC_ATTRIBUTES", "COUNT": + return nil + default: + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported Select") + } +} + +func resolveReadLimit(limit *int32) (int, bool, error) { + if limit == nil { + return 0, false, nil + } + if *limit <= 0 { + return 0, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid Limit") + } + return int(*limit), true, nil +} + +func readIteratorPageLimit(limit *int32) int { + resolved, hasLimit, err := resolveReadLimit(limit) + if err != nil || !hasLimit { + return dynamoScanPageLimit + } + pageLimit := resolved + 1 + if pageLimit > dynamoScanPageLimit { + return dynamoScanPageLimit + } + if pageLimit < 1 { + return 1 + } + return pageLimit +} + +func (d *DynamoDBServer) streamQueryItems( + ctx context.Context, + in queryInput, + schema *dynamoTableSchema, + keySchema dynamoKeySchema, + cond queryCondition, + readTS uint64, + opts readPageOptions, +) (*queryOutput, bool, error) { + iter, ok, err := d.newQueryReadIterator(in, schema, keySchema, cond, readTS, opts) + if err != nil || !ok { + return nil, ok, err + } + out, err := finalizeReadIterator(ctx, schema, iter, opts) + if err != nil { + return nil, true, err + } + return out, true, nil +} + +func (d *DynamoDBServer) streamScanItems( + ctx context.Context, + in scanInput, + schema *dynamoTableSchema, + readTS uint64, + opts readPageOptions, +) (*queryOutput, bool, error) { + iter, ok, err := d.newScanReadIterator(in, schema, readTS, opts) + if err != nil || !ok { + return nil, ok, err + } + out, err := finalizeReadIterator(ctx, schema, iter, opts) + if err != nil { + return nil, true, err + } + return out, true, nil +} + +func finalizeReadIterator( + ctx context.Context, + schema *dynamoTableSchema, + iter dynamoReadIterator, + opts readPageOptions, +) (*queryOutput, error) { + state, err := newReadPageState(schema, opts) + if err != nil { + return nil, err + } + if err := state.consumeIterator(ctx, iter); err != nil { + return nil, err + } + return state.output(), nil +} + +func (d *DynamoDBServer) newQueryReadIterator( + in queryInput, + schema *dynamoTableSchema, + keySchema dynamoKeySchema, + cond queryCondition, + readTS uint64, + opts readPageOptions, +) (dynamoReadIterator, bool, error) { + projector := d.readItemProjector(schema, in.IndexName) + filter := itemReadFilter(func(item map[string]attributeValue) bool { + return matchesQueryCondition(item, cond) + }) + pageLimit := readIteratorPageLimit(opts.limit) + bounds, ok, err := resolveQueryReadBounds(schema, in, keySchema, cond, opts.exclusiveStartKey) + if err != nil || !ok { + return nil, ok, err + } + if strings.TrimSpace(in.IndexName) == "" { + return newTableReadIterator(d, bounds, readTS, pageLimit, projector, filter), true, nil + } + return newGSIReadIterator(d, bounds, readTS, pageLimit, projector, filter), true, nil +} + +func (d *DynamoDBServer) newScanReadIterator( + in scanInput, + schema *dynamoTableSchema, + readTS uint64, + opts readPageOptions, +) (dynamoReadIterator, bool, error) { + projector := d.readItemProjector(schema, in.IndexName) + pageLimit := readIteratorPageLimit(opts.limit) + if strings.TrimSpace(in.IndexName) == "" { + bounds, err := resolveTableReadBounds(schema, in.TableName, opts.exclusiveStartKey) + if err != nil { + return nil, false, err + } + return newTableReadIterator(d, bounds, readTS, pageLimit, projector, nil), true, nil + } + if _, err := schema.keySchemaForQuery(in.IndexName); err != nil { + return nil, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + bounds, ok, err := resolveGSIReadBounds(schema, in.TableName, in.IndexName, opts.exclusiveStartKey) + if err != nil { + return nil, false, err + } + if len(opts.exclusiveStartKey) > 0 && !ok { + return nil, false, nil + } + return newGSIReadIterator(d, bounds, readTS, pageLimit, projector, nil), true, nil +} + +func resolveTableReadBounds( + schema *dynamoTableSchema, + tableName string, + startKey map[string]attributeValue, +) (dynamoReadBounds, error) { + lower := dynamoItemPrefixForTable(tableName, schema.Generation) + upper := prefixScanEnd(lower) + if len(startKey) == 0 { + return dynamoReadBounds{lower: lower, upper: upper}, nil + } + key, err := schema.itemKeyFromAttributes(startKey) + if err != nil { + return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") + } + return dynamoReadBounds{lower: maxBytes(lower, nextScanCursor(key)), upper: upper}, nil +} + +func resolveGSIReadBounds( + schema *dynamoTableSchema, + tableName string, + indexName string, + startKey map[string]attributeValue, +) (dynamoReadBounds, bool, error) { + lower := dynamoGSIIndexPrefixForTable(tableName, schema.Generation, indexName) + upper := prefixScanEnd(lower) + if len(startKey) == 0 { + return dynamoReadBounds{lower: lower, upper: upper}, true, nil + } + key, ok, err := schema.gsiKeyFromAttributes(indexName, startKey) + if err != nil { + return dynamoReadBounds{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") + } + if !ok { + return dynamoReadBounds{}, false, nil + } + return dynamoReadBounds{lower: maxBytes(lower, nextScanCursor(key)), upper: upper}, true, nil +} + +func resolveQueryReadBounds( + schema *dynamoTableSchema, + in queryInput, + keySchema dynamoKeySchema, + cond queryCondition, + startKey map[string]attributeValue, +) (dynamoReadBounds, bool, error) { + if !schema.usesOrderedKeyEncoding() { + return dynamoReadBounds{}, false, nil + } + basePrefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) + if err != nil { + return dynamoReadBounds{}, false, err + } + bounds := dynamoReadBounds{ + lower: basePrefix, + upper: prefixScanEnd(basePrefix), + reverse: queryReadReverse(in.ScanIndexForward), + } + if keySchema.RangeKey != "" && cond.rangeCond != nil { + bounds, err = refineQueryReadBounds(bounds, basePrefix, *cond.rangeCond) + if err != nil { + return dynamoReadBounds{}, false, err + } + } + if len(startKey) == 0 { + return bounds, true, nil + } + startCursor, ok, err := resolveQueryExclusiveStartKey(schema, in, startKey) + if err != nil { + return dynamoReadBounds{}, false, err + } + if !ok { + return dynamoReadBounds{}, false, nil + } + if bounds.reverse { + bounds.upper = minBytes(bounds.upper, startCursor) + } else { + bounds.lower = maxBytes(bounds.lower, nextScanCursor(startCursor)) + } + return bounds, true, nil +} + +func resolveQueryExclusiveStartKey( + schema *dynamoTableSchema, + in queryInput, + startKey map[string]attributeValue, +) ([]byte, bool, error) { + if len(startKey) == 0 { + return nil, true, nil + } + if strings.TrimSpace(in.IndexName) == "" { + key, err := schema.itemKeyFromAttributes(startKey) + if err != nil { + return nil, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") + } + return key, true, nil + } + key, ok, err := schema.gsiKeyFromAttributes(in.IndexName, startKey) + if err != nil { + return nil, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") + } + return key, ok, nil +} + +func queryReadReverse(scanIndexForward *bool) bool { + return scanIndexForward != nil && !*scanIndexForward +} + +type readItemProjector func(map[string]attributeValue) (map[string]attributeValue, error) + +func (d *DynamoDBServer) readItemProjector(schema *dynamoTableSchema, indexName string) readItemProjector { + if strings.TrimSpace(indexName) == "" { + return identityReadItemProjector + } + return func(item map[string]attributeValue) (map[string]attributeValue, error) { + return schema.projectItemForIndex(indexName, item) + } +} + +func identityReadItemProjector(item map[string]attributeValue) (map[string]attributeValue, error) { + return item, nil +} + +func finalizeReadPage(schema *dynamoTableSchema, items []map[string]attributeValue, opts readPageOptions) (*queryOutput, error) { + items, err := applyQueryExclusiveStartKey(schema, opts.exclusiveStartKey, items) + if err != nil { + return nil, err + } + state, err := newReadPageState(schema, opts) + if err != nil { + return nil, err + } + if err := state.consume(items); err != nil { + return nil, err + } + return state.output(), nil +} + +type readPageState struct { + schema *dynamoTableSchema + opts readPageOptions + projection []string + filterExpr string + includeItems bool + limit int + hasLimit bool + outItems []map[string]attributeValue + outCount int + scannedCount int + lastEvaluatedKey map[string]attributeValue +} + +type dynamoReadBounds struct { + lower []byte + upper []byte + reverse bool +} + +type keyRangeKVIterator struct { + server *DynamoDBServer + lower []byte + upper []byte + cursor []byte + readTS uint64 + pageLimit int + reverse bool + page []*store.KVPair + index int + done bool +} + +type emptyReadIterator struct{} + +type tableReadIterator struct { + kv *keyRangeKVIterator + projector readItemProjector + filter itemReadFilter +} + +type gsiReadIterator struct { + server *DynamoDBServer + kv *keyRangeKVIterator + readTS uint64 + projector readItemProjector + filter itemReadFilter + seen map[string]struct{} +} + +func newReadPageState(schema *dynamoTableSchema, opts readPageOptions) (*readPageState, error) { + limit, hasLimit, err := resolveReadLimit(opts.limit) + if err != nil { + return nil, err + } + projection, err := resolveProjectionAttributes(opts.projectionExpression, opts.expressionAttributeNames) + if err != nil { + return nil, err + } + filterExpr, err := replaceNames(opts.filterExpression, opts.expressionAttributeNames) + if err != nil { + return nil, err + } + return &readPageState{ + schema: schema, + opts: opts, + projection: projection, + filterExpr: strings.TrimSpace(filterExpr), + includeItems: !strings.EqualFold(strings.TrimSpace(opts.selectValue), dynamoSelectCount), + limit: limit, + hasLimit: hasLimit, + outItems: make([]map[string]attributeValue, 0), + }, nil +} + +func (s *readPageState) consume(items []map[string]attributeValue) error { + for i, item := range items { + if s.reachedLimit() { + break + } + if err := s.consumeItem(i, item, len(items)); err != nil { + return err + } + } + return nil +} + +func (s *readPageState) consumeIterator(ctx context.Context, iter dynamoReadIterator) error { + var lastItem map[string]attributeValue + for !s.reachedLimit() { + item, ok, err := iter.Next(ctx) + if err != nil { + return errors.WithStack(err) + } + if !ok { + return nil + } + if err := s.consumeReadItem(item); err != nil { + return err + } + lastItem = item + } + if lastItem == nil { + return nil + } + if nextItem, ok, err := iter.Next(ctx); err != nil { + return errors.WithStack(err) + } else if ok && nextItem != nil { + s.lastEvaluatedKey = s.buildLastEvaluatedKey(lastItem) + } + return nil +} + +func (s *readPageState) reachedLimit() bool { + return s.hasLimit && s.scannedCount == s.limit +} + +func (s *readPageState) consumeReadItem(item map[string]attributeValue) error { + s.scannedCount++ + match, err := matchesReadFilter(s.filterExpr, item, s.opts.expressionAttributeValues) + if err != nil { + return err + } + if match { + s.recordMatch(item) + } + return nil +} + +func (s *readPageState) consumeItem(i int, item map[string]attributeValue, totalItems int) error { + if err := s.consumeReadItem(item); err != nil { + return err + } + if s.shouldSetLastEvaluatedKey(i, totalItems) { + s.lastEvaluatedKey = s.buildLastEvaluatedKey(item) + } + return nil +} + +func (s *readPageState) recordMatch(item map[string]attributeValue) { + s.outCount++ + if !s.includeItems { + return + } + s.outItems = append(s.outItems, projectItemByAttributes(item, s.projection)) +} + +func (s *readPageState) shouldSetLastEvaluatedKey(i int, totalItems int) bool { + return s.hasLimit && s.scannedCount == s.limit && i < totalItems-1 +} + +func (s *readPageState) buildLastEvaluatedKey(item map[string]attributeValue) map[string]attributeValue { + if s.opts.lastEvaluatedKeyBuilder != nil { + return s.opts.lastEvaluatedKeyBuilder(item) + } + return makeLastEvaluatedKey(s.schema.PrimaryKey, item) +} + +func (s *readPageState) output() *queryOutput { + items := s.outItems + if !s.includeItems { + items = nil + } + return &queryOutput{ + items: items, + count: s.outCount, + scannedCount: s.scannedCount, + lastEvaluatedKey: s.lastEvaluatedKey, + } +} + +func (emptyReadIterator) Next(context.Context) (map[string]attributeValue, bool, error) { + return nil, false, nil +} + +func newKeyRangeKVIterator( + server *DynamoDBServer, + bounds dynamoReadBounds, + readTS uint64, + pageLimit int, +) *keyRangeKVIterator { + cursor := bytes.Clone(bounds.lower) + if bounds.reverse { + cursor = bytes.Clone(bounds.upper) + } + return &keyRangeKVIterator{ + server: server, + lower: bytes.Clone(bounds.lower), + upper: bytes.Clone(bounds.upper), + cursor: cursor, + readTS: readTS, + pageLimit: pageLimit, + reverse: bounds.reverse, + } +} + +func (it *keyRangeKVIterator) Next(ctx context.Context) (*store.KVPair, bool, error) { + for { + if it.index < len(it.page) { + kvp := it.page[it.index] + it.index++ + return kvp, true, nil + } + if it.done { + return nil, false, nil + } + if err := it.loadNextPage(ctx); err != nil { + return nil, false, err + } + } +} + +func (it *keyRangeKVIterator) loadNextPage(ctx context.Context) error { + if it.reverse { + return it.loadNextPageReverse(ctx) + } + return it.loadNextPageForward(ctx) +} + +func (it *keyRangeKVIterator) loadNextPageForward(ctx context.Context) error { + kvs, err := it.server.store.ScanAt(ctx, it.cursor, it.upper, it.pageLimit, it.readTS) + if err != nil { + return errors.WithStack(err) + } + if len(kvs) == 0 { + it.done = true + it.page = nil + return nil + } + it.page, it.done = filterBoundedKVPairsForward(kvs, it.lower, it.upper, it.pageLimit) + it.index = 0 + if !it.done { + it.cursor = nextScanCursor(kvs[len(kvs)-1].Key) + if it.upper != nil && bytes.Compare(it.cursor, it.upper) >= 0 { + it.done = true + } + } + return nil +} + +func (it *keyRangeKVIterator) loadNextPageReverse(ctx context.Context) error { + kvs, err := it.server.store.ReverseScanAt(ctx, it.lower, it.cursor, it.pageLimit, it.readTS) + if err != nil { + return errors.WithStack(err) + } + if len(kvs) == 0 { + it.done = true + it.page = nil + return nil + } + it.page, it.done = filterBoundedKVPairsReverse(kvs, it.lower, it.cursor, it.pageLimit) + it.index = 0 + if !it.done { + it.cursor = bytes.Clone(kvs[len(kvs)-1].Key) + } + return nil +} + +func filterBoundedKVPairsForward(kvs []*store.KVPair, lower []byte, upper []byte, pageLimit int) ([]*store.KVPair, bool) { + page := make([]*store.KVPair, 0, minInt(len(kvs), pageLimit)) + done := len(kvs) < pageLimit + for _, kvp := range kvs { + if lower != nil && bytes.Compare(kvp.Key, lower) < 0 { + continue + } + if upper != nil && bytes.Compare(kvp.Key, upper) >= 0 { + done = true + break + } + page = append(page, kvp) + } + if len(page) == 0 { + done = true + } + return page, done +} + +func filterBoundedKVPairsReverse(kvs []*store.KVPair, lower []byte, upper []byte, pageLimit int) ([]*store.KVPair, bool) { + page := make([]*store.KVPair, 0, minInt(len(kvs), pageLimit)) + done := len(kvs) < pageLimit + for _, kvp := range kvs { + if lower != nil && bytes.Compare(kvp.Key, lower) < 0 { + done = true + break + } + if upper != nil && bytes.Compare(kvp.Key, upper) >= 0 { + continue + } + page = append(page, kvp) + } + if len(page) == 0 { + done = true + } + return page, done +} + +func newTableReadIterator( + server *DynamoDBServer, + bounds dynamoReadBounds, + readTS uint64, + pageLimit int, + projector readItemProjector, + filter itemReadFilter, +) dynamoReadIterator { + if bounds.upper != nil && bytes.Compare(bounds.lower, bounds.upper) >= 0 { + return emptyReadIterator{} + } + return &tableReadIterator{ + kv: newKeyRangeKVIterator(server, bounds, readTS, pageLimit), + projector: projector, + filter: filter, + } +} + +func (it *tableReadIterator) Next(ctx context.Context) (map[string]attributeValue, bool, error) { + for { + kvp, ok, err := it.kv.Next(ctx) + if err != nil || !ok { + return nil, ok, err + } + item, err := decodeStoredDynamoItem(kvp.Value) + if err != nil { + return nil, false, err + } + item, err = it.projector(item) + if err != nil { + return nil, false, err + } + if it.filter != nil && !it.filter(item) { + continue + } + return item, true, nil + } +} + +func newGSIReadIterator( + server *DynamoDBServer, + bounds dynamoReadBounds, + readTS uint64, + pageLimit int, + projector readItemProjector, + filter itemReadFilter, +) dynamoReadIterator { + if bounds.upper != nil && bytes.Compare(bounds.lower, bounds.upper) >= 0 { + return emptyReadIterator{} + } + return &gsiReadIterator{ + server: server, + kv: newKeyRangeKVIterator(server, bounds, readTS, pageLimit), + readTS: readTS, + projector: projector, + filter: filter, + seen: map[string]struct{}{}, + } +} + +func (it *gsiReadIterator) Next(ctx context.Context) (map[string]attributeValue, bool, error) { + for { + kvp, ok, err := it.kv.Next(ctx) + if err != nil || !ok { + return nil, ok, err + } + itemKey := string(kvp.Value) + if _, exists := it.seen[itemKey]; exists { + continue + } + it.seen[itemKey] = struct{}{} + item, found, err := it.server.readItemAtKeyAt(ctx, kvp.Value, it.readTS) + if err != nil { + return nil, false, err + } + if !found { + continue + } + item, err = it.projector(item) + if err != nil { + return nil, false, err + } + if it.filter != nil && !it.filter(item) { + continue + } + return item, true, nil + } +} + +func matchesReadFilter(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + if strings.TrimSpace(expr) == "" { + return true, nil + } + ok, err := evalConditionExpression(expr, item, values) + if err != nil { + return false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return ok, nil +} + +func resolveProjectionAttributes(expr string, names map[string]string) ([]string, error) { + projectionExpr, err := replaceNames(expr, names) + if err != nil { + return nil, err + } + projection := strings.TrimSpace(projectionExpr) + if projection == "" { + return nil, nil + } + parts, err := splitTopLevelByComma(projection) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ProjectionExpression") + } + attrs := make([]string, 0, len(parts)) + for _, part := range parts { + attr := strings.TrimSpace(part) + if attr == "" { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ProjectionExpression") + } + attrs = append(attrs, attr) + } + return attrs, nil +} + +func projectItem(item map[string]attributeValue, expr string, names map[string]string) (map[string]attributeValue, error) { + attrs, err := resolveProjectionAttributes(expr, names) + if err != nil { + return nil, err + } + return projectItemByAttributes(item, attrs), nil +} + +func projectItemByAttributes(item map[string]attributeValue, attrs []string) map[string]attributeValue { + if len(attrs) == 0 { + return cloneAttributeValueMap(item) + } + out := make(map[string]attributeValue, len(attrs)) + for _, attr := range attrs { + if value, ok := item[attr]; ok { + out[attr] = value + } + } + return out +} + +func decodeItemsFromKVPairs(kvs []*store.KVPair) ([]map[string]attributeValue, error) { + items := make([]map[string]attributeValue, 0, len(kvs)) + for _, kvp := range kvs { + item, err := decodeStoredDynamoItem(kvp.Value) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, nil +} + +func (d *DynamoDBServer) queryItemsByKeyCondition( + ctx context.Context, + in queryInput, + schema *dynamoTableSchema, + keySchema dynamoKeySchema, + cond queryCondition, + readTS uint64, +) ([]map[string]attributeValue, error) { + if strings.TrimSpace(in.IndexName) != "" { + return d.queryItemsByGSI(ctx, in, schema, cond, readTS) + } + scanPrefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) + if err != nil { + return nil, err + } + kvs, err := d.scanAllByPrefixAt(ctx, scanPrefix, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + items, err := filterQueryItems(kvs, cond) + if err != nil { + return nil, err + } + return items, nil +} + +func (d *DynamoDBServer) queryItemsByGSI( + ctx context.Context, + in queryInput, + schema *dynamoTableSchema, + cond queryCondition, + readTS uint64, +) ([]map[string]attributeValue, error) { + keySchema, err := schema.keySchemaForQuery(in.IndexName) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + prefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) + if err != nil { + return nil, err + } + kvs, err := d.scanAllByPrefixAt(ctx, prefix, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + itemKeys := uniqueGSIItemKeys(kvs) + items, err := d.readItemsForGSIQuery(ctx, itemKeys, readTS, cond) + if err != nil { + return nil, err + } + return items, nil +} + +func (d *DynamoDBServer) scanItemsBySource( + ctx context.Context, + in scanInput, + schema *dynamoTableSchema, + readTS uint64, +) ([]map[string]attributeValue, error) { + if strings.TrimSpace(in.IndexName) == "" { + kvs, err := d.scanAllByPrefixAt(ctx, dynamoItemPrefixForTable(in.TableName, schema.Generation), readTS) + if err != nil { + return nil, errors.WithStack(err) + } + return decodeItemsFromKVPairs(kvs) + } + if _, err := schema.keySchemaForQuery(in.IndexName); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + kvs, err := d.scanAllByPrefixAt(ctx, dynamoGSIIndexPrefixForTable(in.TableName, schema.Generation, in.IndexName), readTS) + if err != nil { + return nil, errors.WithStack(err) + } + itemKeys := uniqueGSIItemKeys(kvs) + return d.readItemsAtKeys(ctx, itemKeys, readTS) +} + +func uniqueGSIItemKeys(kvs []*store.KVPair) [][]byte { + if len(kvs) == 0 { + return nil + } + out := make([][]byte, 0, len(kvs)) + seen := make(map[string]struct{}, len(kvs)) + for _, kvp := range kvs { + key := string(kvp.Value) + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + out = append(out, bytes.Clone(kvp.Value)) + } + return out +} + +type gsiReadJob struct { + index int + key []byte +} + +type gsiReadResult struct { + index int + item map[string]attributeValue + err error +} + +type itemReadFilter func(map[string]attributeValue) bool + +func resolveGSIReadWorkerCount(n int) int { + if n <= 0 { + return 0 + } + if n < gsiQueryReadWorkerCount { + return n + } + return gsiQueryReadWorkerCount +} + +func (d *DynamoDBServer) readItemsForGSIQuery( + ctx context.Context, + itemKeys [][]byte, + readTS uint64, + cond queryCondition, +) ([]map[string]attributeValue, error) { + return d.readItemsAtKeysMatching(ctx, itemKeys, readTS, func(item map[string]attributeValue) bool { + return matchesQueryCondition(item, cond) + }) +} + +func (d *DynamoDBServer) readItemsAtKeys( + ctx context.Context, + itemKeys [][]byte, + readTS uint64, +) ([]map[string]attributeValue, error) { + return d.readItemsAtKeysMatching(ctx, itemKeys, readTS, nil) +} + +func (d *DynamoDBServer) readItemsAtKeysMatching( + ctx context.Context, + itemKeys [][]byte, + readTS uint64, + filter itemReadFilter, +) ([]map[string]attributeValue, error) { + if len(itemKeys) == 0 { + return nil, nil + } + workerCount := resolveGSIReadWorkerCount(len(itemKeys)) + jobs := make(chan gsiReadJob) + results := make(chan gsiReadResult, len(itemKeys)) + workerCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + d.startGSIReadWorkers(&wg, workerCount, workerCtx, readTS, filter, jobs, results, cancel) + enqueueGSIReadJobs(workerCtx, jobs, itemKeys) + close(jobs) + wg.Wait() + close(results) + + return collectOrderedGSIReadResults(itemKeys, results) +} + +func (d *DynamoDBServer) startGSIReadWorkers( + wg *sync.WaitGroup, + workerCount int, + ctx context.Context, + readTS uint64, + filter itemReadFilter, + jobs <-chan gsiReadJob, + results chan<- gsiReadResult, + cancel context.CancelFunc, +) { + for range workerCount { + wg.Go(func() { + d.gsiReadWorker(ctx, readTS, filter, jobs, results, cancel) + }) + } +} + +func enqueueGSIReadJobs(ctx context.Context, jobs chan<- gsiReadJob, itemKeys [][]byte) { +enqueueLoop: + for i, key := range itemKeys { + select { + case <-ctx.Done(): + break enqueueLoop + case jobs <- gsiReadJob{index: i, key: key}: + } + } +} + +func collectOrderedGSIReadResults( + itemKeys [][]byte, + results <-chan gsiReadResult, +) ([]map[string]attributeValue, error) { + indexed := make(map[int]map[string]attributeValue, len(itemKeys)) + for res := range results { + if res.err != nil { + return nil, res.err + } + if res.item != nil { + indexed[res.index] = res.item + } + } + items := make([]map[string]attributeValue, 0, len(indexed)) + for i := range itemKeys { + if item := indexed[i]; item != nil { + items = append(items, item) + } + } + return items, nil +} + +func (d *DynamoDBServer) gsiReadWorker( + ctx context.Context, + readTS uint64, + filter itemReadFilter, + jobs <-chan gsiReadJob, + results chan<- gsiReadResult, + cancel context.CancelFunc, +) { + for { + job, ok := nextGSIReadJob(ctx, jobs) + if !ok { + return + } + item, emit, err := d.executeGSIReadJob(ctx, readTS, filter, job.key) + if err != nil { + sendGSIReadError(results, err) + cancel() + return + } + if !emit { + continue + } + if !sendGSIReadResult(ctx, results, gsiReadResult{index: job.index, item: item}) { + return + } + } +} + +func nextGSIReadJob(ctx context.Context, jobs <-chan gsiReadJob) (gsiReadJob, bool) { + select { + case <-ctx.Done(): + return gsiReadJob{}, false + case job, ok := <-jobs: + if !ok { + return gsiReadJob{}, false + } + return job, true + } +} + +func (d *DynamoDBServer) executeGSIReadJob( + ctx context.Context, + readTS uint64, + filter itemReadFilter, + key []byte, +) (map[string]attributeValue, bool, error) { + item, found, err := d.readItemAtKeyAt(ctx, key, readTS) + if err != nil { + return nil, false, err + } + if !found { + return nil, false, nil + } + if filter != nil && !filter(item) { + return nil, false, nil + } + return item, true, nil +} + +func sendGSIReadError(results chan<- gsiReadResult, err error) { + select { + case results <- gsiReadResult{err: err}: + default: + } +} + +func sendGSIReadResult(ctx context.Context, results chan<- gsiReadResult, result gsiReadResult) bool { + select { + case results <- result: + return true + case <-ctx.Done(): + return false + } +} + +func queryScanPrefix(schema *dynamoTableSchema, in queryInput, keySchema dynamoKeySchema, hashValue attributeValue) ([]byte, error) { + if !schema.usesOrderedKeyEncoding() { + hashKey, err := attributeValueAsKey(hashValue) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.IndexName) != "" { + return legacyDynamoGSIHashPrefixForTable(in.TableName, schema.Generation, in.IndexName, hashKey), nil + } + if keySchema.HashKey != schema.PrimaryKey.HashKey { + return dynamoItemPrefixForTable(in.TableName, schema.Generation), nil + } + return legacyDynamoItemHashPrefixForTable(in.TableName, schema.Generation, hashKey), nil + } + hashKey, err := attributeValueAsKeySegment(hashValue) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if strings.TrimSpace(in.IndexName) != "" { + return dynamoGSIHashPrefixForTable(in.TableName, schema.Generation, in.IndexName, hashKey), nil + } + if keySchema.HashKey != schema.PrimaryKey.HashKey { + return dynamoItemPrefixForTable(in.TableName, schema.Generation), nil + } + return dynamoItemHashPrefixForTable(in.TableName, schema.Generation, hashKey), nil +} + +func refineQueryReadBounds( + bounds dynamoReadBounds, + basePrefix []byte, + cond queryRangeCondition, +) (dynamoReadBounds, error) { + switch cond.op { + case queryRangeOpEqual, queryRangeOpLessThan, queryRangeOpLessOrEq, queryRangeOpGreater, queryRangeOpGreaterEq: + return refineQueryComparisonBounds(bounds, basePrefix, cond) + case queryRangeOpBetween: + return refineQueryBetweenBounds(bounds, basePrefix, cond) + case queryRangeOpBeginsWith: + return refineQueryBeginsWithBounds(bounds, basePrefix, cond.value1) + default: + return bounds, nil + } +} + +func refineQueryComparisonBounds( + bounds dynamoReadBounds, + basePrefix []byte, + cond queryRangeCondition, +) (dynamoReadBounds, error) { + prefix, err := appendRangeConditionPrefix(basePrefix, cond.value1) + if err != nil { + return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if cond.op == queryRangeOpEqual { + bounds.lower = prefix + bounds.upper = prefixScanEnd(prefix) + return bounds, nil + } + if cond.op == queryRangeOpLessThan { + bounds.upper = minBytes(bounds.upper, prefix) + return bounds, nil + } + if cond.op == queryRangeOpLessOrEq { + bounds.upper = minBytes(bounds.upper, prefixScanEnd(prefix)) + return bounds, nil + } + if cond.op == queryRangeOpGreater { + bounds.lower = maxBytes(bounds.lower, prefixScanEnd(prefix)) + return bounds, nil + } + if cond.op == queryRangeOpGreaterEq { + bounds.lower = maxBytes(bounds.lower, prefix) + return bounds, nil + } + return bounds, nil +} + +func refineQueryBetweenBounds( + bounds dynamoReadBounds, + basePrefix []byte, + cond queryRangeCondition, +) (dynamoReadBounds, error) { + lower, err := appendRangeConditionPrefix(basePrefix, cond.value1) + if err != nil { + return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + upper, err := appendRangeConditionPrefix(basePrefix, cond.value2) + if err != nil { + return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + bounds.lower = maxBytes(bounds.lower, lower) + bounds.upper = minBytes(bounds.upper, prefixScanEnd(upper)) + return bounds, nil +} + +func refineQueryBeginsWithBounds( + bounds dynamoReadBounds, + basePrefix []byte, + value attributeValue, +) (dynamoReadBounds, error) { + prefix, err := appendRangeConditionPrefixMatch(basePrefix, value) + if err != nil { + return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + bounds.lower = maxBytes(bounds.lower, prefix) + bounds.upper = minBytes(bounds.upper, prefixScanEnd(prefix)) + return bounds, nil +} + +func appendRangeConditionPrefix(basePrefix []byte, value attributeValue) ([]byte, error) { + segment, err := attributeValueAsKeySegment(value) + if err != nil { + return nil, err + } + return append(bytes.Clone(basePrefix), segment...), nil +} + +func appendRangeConditionPrefixMatch(basePrefix []byte, value attributeValue) ([]byte, error) { + raw, err := attributeValueAsKeyBytes(value) + if err != nil { + return nil, err + } + segment := encodeDynamoKeySegmentPrefix(raw) + return append(bytes.Clone(basePrefix), segment...), nil +} + +func maxBytes(left []byte, right []byte) []byte { + if left == nil { + return bytes.Clone(right) + } + if right == nil { + return bytes.Clone(left) + } + if bytes.Compare(left, right) >= 0 { + return bytes.Clone(left) + } + return bytes.Clone(right) +} + +func minBytes(left []byte, right []byte) []byte { + if left == nil { + return bytes.Clone(right) + } + if right == nil { + return bytes.Clone(left) + } + if bytes.Compare(left, right) <= 0 { + return bytes.Clone(left) + } + return bytes.Clone(right) +} + +func applyQueryExclusiveStartKey(schema *dynamoTableSchema, startKey map[string]attributeValue, items []map[string]attributeValue) ([]map[string]attributeValue, error) { + if len(startKey) == 0 { + return items, nil + } + startItemKey, err := schema.itemKeyFromAttributes(startKey) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") + } + descending, hasDirection := queryItemOrderDirection(schema, items) + for i, item := range items { + if remaining, ok := exclusiveStartRemainingItems(schema, item, items, i, startItemKey, descending, hasDirection); ok { + return remaining, nil + } + } + return nil, nil +} + +func exclusiveStartRemainingItems( + schema *dynamoTableSchema, + item map[string]attributeValue, + items []map[string]attributeValue, + index int, + startItemKey []byte, + descending bool, + hasDirection bool, +) ([]map[string]attributeValue, bool) { + itemKey, err := schema.itemKeyFromAttributes(item) + if err != nil { + return nil, false + } + if bytes.Equal(itemKey, startItemKey) { + return items[index+1:], true + } + if !hasDirection || !exclusiveStartShouldAdvance(descending, itemKey, startItemKey) { + return nil, false + } + return items[index:], true +} + +func exclusiveStartShouldAdvance(descending bool, itemKey []byte, startItemKey []byte) bool { + cmp := bytes.Compare(itemKey, startItemKey) + return (!descending && cmp > 0) || (descending && cmp < 0) +} + +func queryItemOrderDirection(schema *dynamoTableSchema, items []map[string]attributeValue) (bool, bool) { + var previous []byte + for _, item := range items { + itemKey, err := schema.itemKeyFromAttributes(item) + if err != nil { + continue + } + if previous == nil { + previous = itemKey + continue + } + cmp := bytes.Compare(itemKey, previous) + if cmp == 0 { + continue + } + return cmp < 0, true + } + return false, false +} + +func makeLastEvaluatedKey(keySchema dynamoKeySchema, item map[string]attributeValue) map[string]attributeValue { + out := map[string]attributeValue{} + if hash, ok := item[keySchema.HashKey]; ok { + out[keySchema.HashKey] = hash + } + if keySchema.RangeKey != "" { + if rk, ok := item[keySchema.RangeKey]; ok { + out[keySchema.RangeKey] = rk + } + } + if len(out) == 0 { + return nil + } + return out +} + +func makeReadLastEvaluatedKey(primary dynamoKeySchema, index dynamoKeySchema, item map[string]attributeValue) map[string]attributeValue { + out := makeLastEvaluatedKey(primary, item) + if len(out) == 0 { + out = map[string]attributeValue{} + } + if hash, ok := item[index.HashKey]; ok { + out[index.HashKey] = hash + } + if index.RangeKey != "" { + if rk, ok := item[index.RangeKey]; ok { + out[index.RangeKey] = rk + } + } + if len(out) == 0 { + return nil + } + return out +} + +type parsedKeyConditionTerm struct { + attr string + op queryRangeOperator + placeholder1 string + placeholder2 string +} + +func parseKeyConditionExpression(expr string) ([]parsedKeyConditionTerm, error) { + expr = strings.TrimSpace(expr) + if expr == "" { + return nil, errors.New("unsupported key condition expression") + } + parts, err := splitKeyConditionTerms(expr) + if err != nil { + return nil, err + } + if len(parts) > updateSplitCount { + return nil, errors.New("unsupported key condition expression") + } + terms := make([]parsedKeyConditionTerm, 0, len(parts)) + for _, part := range parts { + term, err := parseKeyConditionTerm(part) + if err != nil { + return nil, err + } + terms = append(terms, term) + } + return terms, nil +} + +func splitKeyConditionTerms(expr string) ([]string, error) { + upper := strings.ToUpper(expr) + depth := 0 + last := 0 + betweenPending := false + parts := make([]string, 0, splitPartsInitialCapacity) + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth != 0 { + continue + } + keyword := keyConditionKeywordAt(expr, upper, i) + if keyword == "" { + continue + } + if keyword == "BETWEEN" { + betweenPending = true + i += len(keyword) - 1 + continue + } + if betweenPending { + betweenPending = false + i += len(keyword) - 1 + continue + } + part, ok := trimmedNonEmpty(expr[last:i]) + if !ok { + return nil, errors.New("unsupported key condition expression") + } + parts = append(parts, part) + i += len(keyword) - 1 + last = i + 1 + } + if betweenPending { + return nil, errors.New("unsupported key condition expression") + } + tail, ok := trimmedNonEmpty(expr[last:]) + if !ok { + return nil, errors.New("unsupported key condition expression") + } + if len(parts) == 0 { + return []string{tail}, nil + } + return append(parts, tail), nil +} + +func keyConditionKeywordAt(expr string, upper string, pos int) string { + if matchesKeywordTokenAt(upper, "BETWEEN", pos) && + isLogicalKeywordBoundary(expr, pos-1) && + isLogicalKeywordBoundary(expr, pos+len("BETWEEN")) { + return "BETWEEN" + } + if matchesKeywordTokenAt(upper, "AND", pos) && + isLogicalKeywordBoundary(expr, pos-1) && + isLogicalKeywordBoundary(expr, pos+len("AND")) { + return "AND" + } + return "" +} + +func parseKeyConditionTerm(term string) (parsedKeyConditionTerm, error) { + term = strings.TrimSpace(term) + if t, ok, err := parseBeginsWithKeyConditionTerm(term); ok || err != nil { + return t, err + } + if t, ok, err := parseBetweenKeyConditionTerm(term); ok || err != nil { + return t, err + } + return parseComparisonKeyConditionTerm(term) +} + +func parseBeginsWithKeyConditionTerm(term string) (parsedKeyConditionTerm, bool, error) { + const prefix = "BEGINS_WITH(" + upper := strings.ToUpper(term) + if !strings.HasPrefix(upper, prefix) { + return parsedKeyConditionTerm{}, false, nil + } + if !strings.HasSuffix(term, ")") { + return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") + } + inner := strings.TrimSpace(term[len(prefix) : len(term)-1]) + parts := strings.SplitN(inner, ",", updateSplitCount) + if len(parts) != updateSplitCount { + return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") + } + attrName := strings.TrimSpace(parts[0]) + placeholder := strings.TrimSpace(parts[1]) + if attrName == "" || !strings.HasPrefix(placeholder, ":") { + return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") + } + return parsedKeyConditionTerm{ + attr: attrName, + op: queryRangeOpBeginsWith, + placeholder1: placeholder, + }, true, nil +} + +func parseBetweenKeyConditionTerm(term string) (parsedKeyConditionTerm, bool, error) { + upper := strings.ToUpper(term) + betweenIdx := strings.Index(upper, " BETWEEN ") + if betweenIdx < 0 { + return parsedKeyConditionTerm{}, false, nil + } + attrName := strings.TrimSpace(term[:betweenIdx]) + rest := strings.TrimSpace(term[betweenIdx+len(" BETWEEN "):]) + andIdx := strings.Index(strings.ToUpper(rest), " AND ") + if andIdx < 0 { + return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") + } + placeholder1 := strings.TrimSpace(rest[:andIdx]) + placeholder2 := strings.TrimSpace(rest[andIdx+len(" AND "):]) + if attrName == "" || !strings.HasPrefix(placeholder1, ":") || !strings.HasPrefix(placeholder2, ":") { + return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") + } + return parsedKeyConditionTerm{ + attr: attrName, + op: queryRangeOpBetween, + placeholder1: placeholder1, + placeholder2: placeholder2, + }, true, nil +} + +func parseComparisonKeyConditionTerm(term string) (parsedKeyConditionTerm, error) { + operators := []queryRangeOperator{ + queryRangeOpLessOrEq, + queryRangeOpGreaterEq, + queryRangeOpLessThan, + queryRangeOpGreater, + queryRangeOpEqual, + } + for _, op := range operators { + if t, ok := splitComparisonTerm(term, op); ok { + return t, nil + } + } + return parsedKeyConditionTerm{}, errors.New("unsupported key condition expression") +} + +func splitComparisonTerm(term string, op queryRangeOperator) (parsedKeyConditionTerm, bool) { + opStr := string(op) + before, after, ok := strings.Cut(term, opStr) + if !ok { + return parsedKeyConditionTerm{}, false + } + left := strings.TrimSpace(before) + right := strings.TrimSpace(after) + if left == "" || !strings.HasPrefix(right, ":") { + return parsedKeyConditionTerm{}, false + } + return parsedKeyConditionTerm{ + attr: left, + op: op, + placeholder1: right, + }, true +} + +func buildQueryCondition(keySchema dynamoKeySchema, terms []parsedKeyConditionTerm, values map[string]attributeValue) (queryCondition, error) { + hashTerm, rangeTerm, err := classifyQueryConditionTerms(keySchema, terms) + if err != nil { + return queryCondition{}, err + } + hashValue, ok := values[hashTerm.placeholder1] + if !ok { + return queryCondition{}, errors.New("missing key condition value") + } + cond := queryCondition{ + hashAttr: keySchema.HashKey, + hashValue: hashValue, + } + if rangeTerm == nil { + return cond, nil + } + value1, ok := values[rangeTerm.placeholder1] + if !ok { + return queryCondition{}, errors.New("missing key condition value") + } + rangeCond := &queryRangeCondition{ + attr: keySchema.RangeKey, + op: rangeTerm.op, + value1: value1, + } + if rangeTerm.op == queryRangeOpBetween { + value2, ok := values[rangeTerm.placeholder2] + if !ok { + return queryCondition{}, errors.New("missing key condition value") + } + rangeCond.value2 = value2 + } + cond.rangeCond = rangeCond + return cond, nil +} + +func classifyQueryConditionTerms( + keySchema dynamoKeySchema, + terms []parsedKeyConditionTerm, +) (parsedKeyConditionTerm, *parsedKeyConditionTerm, error) { + if len(terms) == 0 || len(terms) > updateSplitCount { + return parsedKeyConditionTerm{}, nil, errors.New("unsupported key condition") + } + hashTerm, ok := findHashConditionTerm(keySchema.HashKey, terms) + if !ok { + return parsedKeyConditionTerm{}, nil, errors.New("unsupported key condition") + } + if len(terms) == 1 { + return hashTerm, nil, nil + } + rangeTerm, ok := findRangeConditionTerm(keySchema.RangeKey, terms, hashTerm) + if !ok { + return parsedKeyConditionTerm{}, nil, errors.New("unsupported key condition") + } + return hashTerm, &rangeTerm, nil +} + +func findHashConditionTerm(hashKey string, terms []parsedKeyConditionTerm) (parsedKeyConditionTerm, bool) { + var hashTerm parsedKeyConditionTerm + found := false + for _, term := range terms { + if term.attr != hashKey || term.op != queryRangeOpEqual { + continue + } + if found { + return parsedKeyConditionTerm{}, false + } + hashTerm = term + found = true + } + return hashTerm, found +} + +func findRangeConditionTerm( + rangeKey string, + terms []parsedKeyConditionTerm, + hashTerm parsedKeyConditionTerm, +) (parsedKeyConditionTerm, bool) { + if strings.TrimSpace(rangeKey) == "" { + return parsedKeyConditionTerm{}, false + } + for _, term := range terms { + if term == hashTerm { + continue + } + if term.attr != rangeKey { + return parsedKeyConditionTerm{}, false + } + return term, true + } + return parsedKeyConditionTerm{}, false +} + +func matchesQueryCondition(item map[string]attributeValue, cond queryCondition) bool { + hashAttr, ok := item[cond.hashAttr] + if !ok || !attributeValueEqual(hashAttr, cond.hashValue) { + return false + } + if cond.rangeCond == nil { + return true + } + rangeAttr, ok := item[cond.rangeCond.attr] + if !ok { + return false + } + return matchesQueryRangeCondition(rangeAttr, *cond.rangeCond) +} + +func matchesQueryRangeCondition(attr attributeValue, cond queryRangeCondition) bool { + if cond.op == queryRangeOpBeginsWith { + return matchesQueryRangeBeginsWith(attr, cond.value1) + } + if cond.op == queryRangeOpBetween { + return matchesQueryRangeBetween(attr, cond.value1, cond.value2) + } + return matchesQueryRangeCompare(attr, cond.value1, cond.op) +} + +func matchesQueryRangeCompare(attr attributeValue, right attributeValue, op queryRangeOperator) bool { + switch op { + case queryRangeOpEqual: + return attributeValueEqual(attr, right) + case queryRangeOpLessThan: + return compareAttributeValueSortKey(attr, right) < 0 + case queryRangeOpLessOrEq: + return compareAttributeValueSortKey(attr, right) <= 0 + case queryRangeOpGreater: + return compareAttributeValueSortKey(attr, right) > 0 + case queryRangeOpGreaterEq: + return compareAttributeValueSortKey(attr, right) >= 0 + case queryRangeOpBetween, queryRangeOpBeginsWith: + return false + default: + return false + } +} + +func matchesQueryRangeBetween(attr attributeValue, lower attributeValue, upper attributeValue) bool { + return compareAttributeValueSortKey(attr, lower) >= 0 && + compareAttributeValueSortKey(attr, upper) <= 0 +} + +func matchesQueryRangeBeginsWith(attr attributeValue, prefixValue attributeValue) bool { + attrKey, err := attributeValueAsKey(attr) + if err != nil { + return false + } + prefix, err := attributeValueAsKey(prefixValue) + if err != nil { + return false + } + return strings.HasPrefix(attrKey, prefix) +} diff --git a/adapter/dynamodb_transact.go b/adapter/dynamodb_transact.go new file mode 100644 index 00000000..3f58a688 --- /dev/null +++ b/adapter/dynamodb_transact.go @@ -0,0 +1,1220 @@ +package adapter + +import ( + "context" + "encoding/base64" + "io" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +type batchWriteItemInput struct { + RequestItems map[string][]batchWriteRequest `json:"RequestItems"` +} + +type batchWriteRequest struct { + PutRequest *batchPutRequest `json:"PutRequest,omitempty"` + DeleteRequest *batchDeleteRequest `json:"DeleteRequest,omitempty"` +} + +type batchPutRequest struct { + Item map[string]attributeValue `json:"Item"` +} + +type batchDeleteRequest struct { + Key map[string]attributeValue `json:"Key"` +} + +type transactWriteItemsInput struct { + TransactItems []transactWriteItem `json:"TransactItems"` +} + +type transactGetItemsInput struct { + TransactItems []transactGetItem `json:"TransactItems"` +} + +type transactGetItem struct { + Get *transactGetItemGet `json:"Get"` +} + +type transactGetItemGet struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + ProjectionExpression string `json:"ProjectionExpression,omitempty"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames,omitempty"` +} + +type transactWriteItem struct { + Put *putItemInput `json:"Put,omitempty"` + Update *transactUpdateInput `json:"Update,omitempty"` + Delete *transactDeleteInput `json:"Delete,omitempty"` + ConditionCheck *transactConditionInput `json:"ConditionCheck,omitempty"` +} + +type transactUpdateInput struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + UpdateExpression string `json:"UpdateExpression"` + ConditionExpression string `json:"ConditionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` +} + +type transactDeleteInput struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + ConditionExpression string `json:"ConditionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` +} + +type transactConditionInput struct { + TableName string `json:"TableName"` + Key map[string]attributeValue `json:"Key"` + ConditionExpression string `json:"ConditionExpression"` + ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` + ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` +} + +func (d *DynamoDBServer) batchWriteItem(w http.ResponseWriter, r *http.Request) { + in, err := decodeBatchWriteItemInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + unprocessed, err := d.batchWriteItems(r.Context(), in) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + for table, written := range batchWriteCommittedCounts(in, unprocessed) { + d.observeWrittenItems(r.Context(), table, written) + } + writeDynamoJSON(w, map[string]any{"UnprocessedItems": unprocessed}) +} + +func decodeBatchWriteItemInput(bodyReader io.Reader) (batchWriteItemInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in batchWriteItemInput + if err := json.Unmarshal(body, &in); err != nil { + return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if len(in.RequestItems) == 0 { + return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing RequestItems") + } + total := 0 + for tableName, requests := range in.RequestItems { + if strings.TrimSpace(tableName) == "" { + return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + total += len(requests) + } + if total == 0 { + return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing write requests") + } + if total > batchWriteItemMaxItems { + return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "too many items in BatchWriteItem") + } + return in, nil +} + +func batchWriteCommittedCounts(in batchWriteItemInput, unprocessed map[string][]batchWriteRequest) map[string]int { + out := make(map[string]int, len(in.RequestItems)) + for table, requests := range in.RequestItems { + written := len(requests) - len(unprocessed[table]) + if written > 0 { + out[table] = written + } + } + return out +} + +func (d *DynamoDBServer) batchWriteItems( + ctx context.Context, + in batchWriteItemInput, +) (map[string][]batchWriteRequest, error) { + tableNames := make([]string, 0, len(in.RequestItems)) + for tableName := range in.RequestItems { + tableNames = append(tableNames, tableName) + } + sort.Strings(tableNames) + unlock := d.lockTableOperations(tableNames) + defer unlock() + for _, tableName := range tableNames { + if err := d.ensureLegacyTableMigrationLocked(ctx, tableName); err != nil { + return nil, err + } + } + if err := d.validateBatchWriteRequests(ctx, tableNames, in.RequestItems); err != nil { + return nil, err + } + unprocessed := make(map[string][]batchWriteRequest) + for _, tableName := range tableNames { + requests := in.RequestItems[tableName] + for _, request := range requests { + err := d.executeBatchWriteRequest(ctx, tableName, request) + if err == nil { + continue + } + if ctx.Err() != nil { + return nil, errors.WithStack(ctx.Err()) + } + unprocessed[tableName] = append(unprocessed[tableName], request) + } + } + return unprocessed, nil +} + +func (d *DynamoDBServer) validateBatchWriteRequests( + ctx context.Context, + tableNames []string, + requestItems map[string][]batchWriteRequest, +) error { + for _, tableName := range tableNames { + schema, exists, err := d.loadTableSchema(ctx, tableName) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + seenKeys := make(map[string]struct{}, len(requestItems[tableName])) + for _, request := range requestItems[tableName] { + key, err := validateBatchWriteRequestForSchema(schema, request) + if err != nil { + return err + } + if _, ok := seenKeys[string(key)]; ok { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "duplicate item in BatchWriteItem") + } + seenKeys[string(key)] = struct{}{} + } + } + return nil +} + +func validateBatchWriteRequestForSchema(schema *dynamoTableSchema, request batchWriteRequest) ([]byte, error) { + switch countBatchWriteActions(request) { + case 1: + default: + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") + } + switch { + case request.PutRequest != nil: + key, err := schema.itemKeyFromAttributes(request.PutRequest.Item) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return key, nil + case request.DeleteRequest != nil: + key, err := schema.itemKeyFromAttributes(request.DeleteRequest.Key) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return key, nil + default: + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") + } +} + +func (d *DynamoDBServer) executeBatchWriteRequest( + ctx context.Context, + tableName string, + request batchWriteRequest, +) error { + schema, exists, err := d.loadTableSchema(ctx, tableName) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + keyAttrs, err := batchWriteRequestKey(schema, request) + if err != nil { + return err + } + lockKey, err := dynamoItemUpdateLockKey(tableName, keyAttrs) + if err != nil { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + unlock := d.lockItemUpdate(lockKey) + defer unlock() + switch countBatchWriteActions(request) { + case 1: + default: + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") + } + switch { + case request.PutRequest != nil: + _, err := d.putItemWithRetry(ctx, putItemInput{ + TableName: tableName, + Item: request.PutRequest.Item, + }) + return err + case request.DeleteRequest != nil: + _, err := d.deleteItemWithRetry(ctx, deleteItemInput{ + TableName: tableName, + Key: request.DeleteRequest.Key, + }) + return err + default: + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") + } +} + +func batchWriteRequestKey(schema *dynamoTableSchema, request batchWriteRequest) (map[string]attributeValue, error) { + switch { + case request.PutRequest != nil: + return primaryKeyAttributes(schema.PrimaryKey, request.PutRequest.Item) + case request.DeleteRequest != nil: + return primaryKeyAttributes(schema.PrimaryKey, request.DeleteRequest.Key) + default: + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") + } +} + +func primaryKeyAttributes(keySchema dynamoKeySchema, attrs map[string]attributeValue) (map[string]attributeValue, error) { + out := make(map[string]attributeValue, primaryKeyAttributeCapacity(keySchema)) + hash, ok := attrs[keySchema.HashKey] + if !ok { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing hash key attribute") + } + out[keySchema.HashKey] = hash + if keySchema.RangeKey != "" { + rangeValue, ok := attrs[keySchema.RangeKey] + if !ok { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing range key attribute") + } + out[keySchema.RangeKey] = rangeValue + } + return out, nil +} + +func primaryKeyAttributeCapacity(keySchema dynamoKeySchema) int { + size := 1 + if keySchema.RangeKey != "" { + size++ + } + return size +} + +func countBatchWriteActions(request batchWriteRequest) int { + count := 0 + if request.PutRequest != nil { + count++ + } + if request.DeleteRequest != nil { + count++ + } + return count +} + +func (d *DynamoDBServer) transactWriteItems(w http.ResponseWriter, r *http.Request) { + in, err := decodeTransactWriteItemsInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + if err := d.transactWriteItemsWithRetry(r.Context(), in); err != nil { + writeDynamoErrorFromErr(w, err) + return + } + for table, written := range transactWriteWrittenCounts(in) { + d.observeWrittenItems(r.Context(), table, written) + } + writeDynamoJSON(w, map[string]any{}) +} + +func decodeTransactWriteItemsInput(bodyReader io.Reader) (transactWriteItemsInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return transactWriteItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in transactWriteItemsInput + if err := json.Unmarshal(body, &in); err != nil { + return transactWriteItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if len(in.TransactItems) == 0 { + return transactWriteItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing transact items") + } + return in, nil +} + +// transactGetItems implements TransactGetItems: reads multiple items atomically +// at a single snapshot timestamp, guaranteeing a consistent view across all keys. +func (d *DynamoDBServer) transactGetItems(w http.ResponseWriter, r *http.Request) { + in, err := decodeTransactGetItemsInput(maxDynamoBodyReader(w, r)) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + + // Lease-check every shard this transaction will read BEFORE the single + // snapshot timestamp is resolved, so the quorum-freshness bound is + // established without changing the single-snapshot-ts semantics. The + // timestamp below is still sampled once and shared by all items. + if !d.leaseCheckTransactGetItems(w, r, in) { + return + } + + // Acquire a single read timestamp for all items to guarantee a consistent snapshot. + readTS := d.nextTxnReadTS() + pin := d.pinReadTS(readTS) + defer pin.Release() + + responses, tableMetrics, err := d.buildTransactGetItemsResponses(r.Context(), in, readTS) + if err != nil { + writeDynamoErrorFromErr(w, err) + return + } + for table, m := range tableMetrics { + d.observeReadMetrics(r.Context(), table, m.found, m.requested) + } + writeDynamoJSON(w, map[string]any{"Responses": responses}) +} + +func decodeTransactGetItemsInput(bodyReader io.Reader) (transactGetItemsInput, error) { + body, err := io.ReadAll(bodyReader) + if err != nil { + return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var in transactGetItemsInput + if err := json.Unmarshal(body, &in); err != nil { + return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + if len(in.TransactItems) == 0 { + return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing transact items") + } + if len(in.TransactItems) > transactGetItemsMaxItems { + return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, + "Too many items in TransactGetItems: "+strconv.Itoa(len(in.TransactItems))+" (max "+strconv.Itoa(transactGetItemsMaxItems)+")") + } + return in, nil +} + +// collectTransactGetTableNames returns a deduplicated list of table names referenced +// in the TransactGetItems input. Used to run ensureLegacyTableMigration once per table. +func collectTransactGetTableNames(in transactGetItemsInput) []string { + seen := make(map[string]struct{}, len(in.TransactItems)) + names := make([]string, 0, len(in.TransactItems)) + for _, item := range in.TransactItems { + if item.Get == nil { + continue + } + t := item.Get.TableName + if _, exists := seen[t]; exists { + continue + } + seen[t] = struct{}{} + names = append(names, t) + } + return names +} + +// transactGetItemsMetrics holds per-table counts for metrics reporting. +type transactGetItemsMetrics struct { + requested int + found int +} + +// buildTransactGetItemsResponses reads each requested item at the given readTS +// and returns the ordered response list and a per-table metrics map. +// schemaCache avoids redundant storage reads when multiple items share the same table. +// seenItemKeys enforces the DynamoDB rule that a transaction may not reference the +// same item more than once. +// ensureLegacyTableMigration is called once per unique table before any item is read. +func (d *DynamoDBServer) buildTransactGetItemsResponses(ctx context.Context, in transactGetItemsInput, readTS uint64) ([]map[string]any, map[string]*transactGetItemsMetrics, error) { + tableNames := collectTransactGetTableNames(in) + for _, tableName := range tableNames { + if err := d.ensureLegacyTableMigration(ctx, tableName); err != nil { + return nil, nil, err + } + } + schemaCache := make(map[string]*dynamoTableSchema) + seenItemKeys := make(map[transactGetSeenKey]struct{}, len(in.TransactItems)) + tableMetrics := make(map[string]*transactGetItemsMetrics) + responses := make([]map[string]any, 0, len(in.TransactItems)) + for _, item := range in.TransactItems { + entry, itemFound, tableName, err := d.readTransactGetItem(ctx, item, schemaCache, seenItemKeys, readTS) + if err != nil { + return nil, nil, err + } + responses = append(responses, entry) + m := tableMetrics[tableName] + if m == nil { + m = &transactGetItemsMetrics{} + tableMetrics[tableName] = m + } + m.requested++ + if itemFound { + m.found++ + } + } + return responses, tableMetrics, nil +} + +// transactGetSeenKey is the map key used for duplicate-item detection in +// TransactGetItems. Using a struct avoids separator-collision risks from +// string concatenation and is more idiomatic Go. +type transactGetSeenKey struct { + tableName string + keyStr string +} + +// readTransactGetItem validates and reads a single item in a TransactGetItems request. +// ensureLegacyTableMigration must be called for g.TableName before invoking this function. +// Returns the response entry, whether the item was found, the table name, and any error. +// Returning the table name avoids the caller having to re-access item.Get after the call. +func (d *DynamoDBServer) readTransactGetItem(ctx context.Context, item transactGetItem, schemaCache map[string]*dynamoTableSchema, seenItemKeys map[transactGetSeenKey]struct{}, readTS uint64) (map[string]any, bool, string, error) { + if item.Get == nil { + return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "TransactGetItems only supports Get operations") + } + g := item.Get + if strings.TrimSpace(g.TableName) == "" { + return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing TableName in Get") + } + schema, err := d.resolveTransactTableSchema(ctx, schemaCache, g.TableName, readTS) + if err != nil { + return nil, false, "", err + } + // Reject duplicate item keys to match real DynamoDB behavior. + // canonicalPrimaryKeyStr reads only hash/range key attributes from g.Key + // by schema name, so extra attributes in the map are safely ignored — + // no separate primaryKeyAttributes extraction is needed. + keyStr, err := canonicalPrimaryKeyStr(schema.PrimaryKey, g.Key) + if err != nil { + return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + seenKey := transactGetSeenKey{tableName: g.TableName, keyStr: keyStr} + if _, dup := seenItemKeys[seenKey]; dup { + return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, + "Transaction request cannot include multiple operations on one item") + } + seenItemKeys[seenKey] = struct{}{} + loc, found, err := d.readLogicalItemAt(ctx, schema, g.Key, readTS) + if err != nil { + // Return the error as-is: storage errors from readItemAtKeyAt surface as + // InternalServerError (500) via writeDynamoErrorFromErr in the HTTP handler. + return nil, false, "", err + } + entry := map[string]any{} + if found { + projected, err := projectItem(loc.item, g.ProjectionExpression, g.ExpressionAttributeNames) + if err != nil { + return nil, false, "", err + } + entry["Item"] = projected + } + return entry, found, g.TableName, nil +} + +// canonicalPrimaryKeyStr returns a collision-free canonical string of primary +// key attributes for duplicate-item detection in TransactGetItems and +// TransactWriteItems. Shared between both operations to avoid duplicated logic. +// +// Takes the table's keySchema so it can write hash key then range key in a +// fixed schema-defined order, avoiding a slice allocation and sort — DynamoDB +// primary keys have at most two attributes, so direct lookup beats sorting. +// +// Format per attribute: "=::", separated by \x1f. +// The length prefix makes the format collision-free: a string value that +// contains \x1f cannot be confused with the inter-attribute separator because +// the decoder knows exactly how many bytes belong to each value. +// Numeric values are normalised; binary values are base64-encoded. +func canonicalPrimaryKeyStr(keySchema dynamoKeySchema, key map[string]attributeValue) (string, error) { + var buf strings.Builder + hashVal, ok := key[keySchema.HashKey] + if !ok { + return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing hash key attribute") + } + buf.WriteString(keySchema.HashKey) + buf.WriteByte('=') + if err := writeCanonicalAttrValue(&buf, hashVal); err != nil { + return "", err + } + if keySchema.RangeKey != "" { + rangeVal, ok := key[keySchema.RangeKey] + if !ok { + return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing range key attribute") + } + buf.WriteByte('\x1f') + buf.WriteString(keySchema.RangeKey) + buf.WriteByte('=') + if err := writeCanonicalAttrValue(&buf, rangeVal); err != nil { + return "", err + } + } + return buf.String(), nil +} + +// writeCanonicalAttrValue appends a length-prefixed typed value for a single +// primary key attribute to buf. Format: "::". +// Supports S (string), N (normalised number), and B (base64-encoded binary). +// The length prefix prevents collisions when string values contain \x1f. +func writeCanonicalAttrValue(buf *strings.Builder, v attributeValue) error { + switch { + case v.S != nil: + buf.WriteString("S:") + buf.WriteString(strconv.Itoa(len(*v.S))) + buf.WriteByte(':') + buf.WriteString(*v.S) + case v.N != nil: + n := canonicalNumberString(*v.N) + buf.WriteString("N:") + buf.WriteString(strconv.Itoa(len(n))) + buf.WriteByte(':') + buf.WriteString(n) + case v.B != nil: + encoded := base64.StdEncoding.EncodeToString(v.B) + buf.WriteString("B:") + buf.WriteString(strconv.Itoa(len(encoded))) + buf.WriteByte(':') + buf.WriteString(encoded) + default: + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported key attribute type for duplicate detection") + } + return nil +} + +func collectTransactWriteTableNames(in transactWriteItemsInput) ([]string, error) { + seen := map[string]struct{}{} + names := make([]string, 0, len(in.TransactItems)) + for _, item := range in.TransactItems { + tableName, err := transactWriteItemTableName(item) + if err != nil { + return nil, err + } + if tableName == "" { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") + } + if _, exists := seen[tableName]; exists { + continue + } + seen[tableName] = struct{}{} + names = append(names, tableName) + } + return names, nil +} + +func transactWriteWrittenCounts(in transactWriteItemsInput) map[string]int { + out := make(map[string]int) + for _, item := range in.TransactItems { + tableName, err := transactWriteItemTableName(item) + if err != nil || strings.TrimSpace(tableName) == "" { + continue + } + switch { + case item.Put != nil, item.Update != nil, item.Delete != nil: + out[tableName]++ + } + } + return out +} + +func transactWriteItemTableName(item transactWriteItem) (string, error) { + switch countTransactWriteItemActions(item) { + case 0: + return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing transact action") + case 1: + default: + return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "multiple transact actions are not supported") + } + switch { + case item.Put != nil: + return strings.TrimSpace(item.Put.TableName), nil + case item.Update != nil: + return strings.TrimSpace(item.Update.TableName), nil + case item.Delete != nil: + return strings.TrimSpace(item.Delete.TableName), nil + default: + return strings.TrimSpace(item.ConditionCheck.TableName), nil + } +} + +func countTransactWriteItemActions(item transactWriteItem) int { + count := 0 + if item.Put != nil { + count++ + } + if item.Update != nil { + count++ + } + if item.Delete != nil { + count++ + } + if item.ConditionCheck != nil { + count++ + } + return count +} + +func (d *DynamoDBServer) transactWriteItemsWithRetry(ctx context.Context, in transactWriteItemsInput) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + var lastErr error + for range transactRetryMaxAttempts { + reqs, generations, cleanupKeys, err := d.buildTransactWriteItemsRequest(ctx, in) + if err != nil { + return err + } + done, retryErr, fatalErr := d.runTransactWriteAttempt(ctx, reqs, generations, cleanupKeys) + if fatalErr != nil { + return fatalErr + } + if done { + return nil + } + if retryErr != nil { + lastErr = retryErr + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + if lastErr != nil { + combined := errors.Join(err, lastErr) + return errors.Wrap(combined, "transact write retry canceled") + } + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + if lastErr != nil { + return errors.Wrapf(lastErr, "transact write retry attempts exhausted after %s", transactRetryMaxDuration) + } + return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "transact write retry attempts exhausted") +} + +func (d *DynamoDBServer) runTransactWriteAttempt( + ctx context.Context, + reqs *kv.OperationGroup[kv.OP], + generations map[string]uint64, + cleanupKeys [][]byte, +) (bool, error, error) { + if len(reqs.Elems) == 0 { + return true, nil, nil + } + if _, err := d.coordinator.Dispatch(ctx, reqs); err != nil { + wrapped := errors.WithStack(err) + if !isRetryableTransactWriteError(err) { + return false, nil, wrapped + } + return false, wrapped, nil + } + retry, verifyErr := d.handleGenerationFenceResult( + ctx, + d.verifyTableGenerations(ctx, generations), + cleanupKeys, + ) + if verifyErr != nil { + return false, nil, verifyErr + } + if !retry { + return true, nil, nil + } + return false, nil, nil +} + +func (d *DynamoDBServer) buildTransactWriteItemsRequest(ctx context.Context, in transactWriteItemsInput) (*kv.OperationGroup[kv.OP], map[string]uint64, [][]byte, error) { + tableNames, err := collectTransactWriteTableNames(in) + if err != nil { + return nil, nil, nil, err + } + for _, tableName := range tableNames { + if err := d.ensureLegacyTableMigration(ctx, tableName); err != nil { + return nil, nil, nil, err + } + } + readTS := d.nextTxnReadTS() + reqs := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + // Keep transaction start aligned with the snapshot used to evaluate + // ConditionCheck/ConditionExpression so concurrent writes after readTS + // are detected as write conflicts at commit time. + StartTS: readTS, + } + schemaCache := make(map[string]*dynamoTableSchema) + tableGenerations := make(map[string]uint64) + cleanup := make([][]byte, 0, len(in.TransactItems)) + // seenItemKeys tracks (tableName, primaryKey) pairs to detect duplicates. + // Real DynamoDB rejects requests with multiple operations on the same item. + seenItemKeys := make(map[string]struct{}, len(in.TransactItems)) + for _, item := range in.TransactItems { + if err := d.processTransactWriteItem(ctx, item, readTS, reqs, schemaCache, seenItemKeys, tableGenerations, &cleanup); err != nil { + return nil, nil, nil, err + } + } + return reqs, tableGenerations, cleanup, nil +} + +// processTransactWriteItem validates and plans a single item within a +// TransactWriteItems request, appending the resulting ops to reqs and cleanup. +func (d *DynamoDBServer) processTransactWriteItem( + ctx context.Context, + item transactWriteItem, + readTS uint64, + reqs *kv.OperationGroup[kv.OP], + schemaCache map[string]*dynamoTableSchema, + seenItemKeys map[string]struct{}, + tableGenerations map[string]uint64, + cleanup *[][]byte, +) error { + tableName, err := transactWriteItemTableName(item) + if err != nil { + return err + } + schema, err := d.resolveTransactTableSchema(ctx, schemaCache, tableName, readTS) + if err != nil { + return err + } + // Reject duplicate item keys to match real DynamoDB behavior. + itemKeyStr, err := transactWriteItemPrimaryKeyStr(schema, item) + if err != nil { + return err + } + compositeKey := tableName + "\x00" + itemKeyStr + if _, dup := seenItemKeys[compositeKey]; dup { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, + "Transaction request cannot include multiple operations on one item") + } + seenItemKeys[compositeKey] = struct{}{} + plan, err := d.buildTransactWriteItemPlan(ctx, schema, item, readTS) + if err != nil { + // Real DynamoDB wraps per-item condition failures in + // TransactionCanceledException rather than surfacing the raw + // ConditionalCheckFailedException to the caller. + var apiErr *dynamoAPIError + if errors.As(err, &apiErr) && apiErr.errorType == dynamoErrConditionalFailed { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrTransactionCanceled, apiErr.message) + } + return err + } + reqs.Elems = append(reqs.Elems, plan.elems...) + reqs.ReadKeys = append(reqs.ReadKeys, plan.readKeys...) + if !plan.writes { + return nil + } + tableGenerations[tableName] = schema.Generation + *cleanup = append(*cleanup, plan.cleanup...) + return nil +} + +// transactWriteItemPrimaryKeyStr returns a canonical string of the item's +// primary key attributes, used to detect duplicate-item violations in +// TransactWriteItems (real DynamoDB returns ValidationException for these). +// Delegates to canonicalPrimaryKeyStr for the actual serialization. +// primaryKeyAttributes is applied uniformly across all operation types so that +// only hash/range key fields are used for duplicate detection, regardless of +// whether the operation carries a full Item (Put) or a Key-only map (Update/Delete/ConditionCheck). +func transactWriteItemPrimaryKeyStr(schema *dynamoTableSchema, item transactWriteItem) (string, error) { + var rawAttrs map[string]attributeValue + switch { + case item.Update != nil: + rawAttrs = item.Update.Key + case item.Delete != nil: + rawAttrs = item.Delete.Key + case item.ConditionCheck != nil: + rawAttrs = item.ConditionCheck.Key + case item.Put != nil: + rawAttrs = item.Put.Item + default: + return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported transact item") + } + keyAttrs, err := primaryKeyAttributes(schema.PrimaryKey, rawAttrs) + if err != nil { + // primaryKeyAttributes already returns a dynamoAPIError; return it directly + // to preserve its status code and error type. + return "", err + } + keyStr, err := canonicalPrimaryKeyStr(schema.PrimaryKey, keyAttrs) + if err != nil { + return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return keyStr, nil +} + +type transactWriteItemPlan struct { + elems []*kv.Elem[kv.OP] + cleanup [][]byte + writes bool + // readKeys contains the raw storage keys that were read during plan + // construction. They are propagated into OperationGroup.ReadKeys so the + // FSM can validate read-write conflicts atomically at commit time, + // preventing lost-update and G0 anomalies on concurrent transactions that + // read the same item at a stale timestamp. + readKeys [][]byte +} + +func (d *DynamoDBServer) buildTransactWriteItemPlan( + ctx context.Context, + schema *dynamoTableSchema, + item transactWriteItem, + readTS uint64, +) (*transactWriteItemPlan, error) { + switch { + case item.Put != nil: + return d.buildTransactPutPlan(ctx, schema, *item.Put, readTS) + case item.Update != nil: + return d.buildTransactUpdatePlan(ctx, schema, *item.Update, readTS) + case item.Delete != nil: + return d.buildTransactDeletePlan(ctx, schema, *item.Delete, readTS) + case item.ConditionCheck != nil: + return d.buildTransactConditionCheckPlan(ctx, schema, *item.ConditionCheck, readTS) + default: + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported transact item") + } +} + +func (d *DynamoDBServer) buildTransactPutPlan( + ctx context.Context, + schema *dynamoTableSchema, + in putItemInput, + readTS uint64, +) (*transactWriteItemPlan, error) { + itemKey, err := schema.itemKeyFromAttributes(in.Item) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + keyAttrs, err := primaryKeyAttributes(schema.PrimaryKey, in.Item) + if err != nil { + return nil, err + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, keyAttrs, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var current map[string]attributeValue + if found { + current = currentLocation.item + } + if err := validateConditionOnItem( + in.ConditionExpression, + in.ExpressionAttributeNames, + in.ExpressionAttributeValues, + valueOrEmptyMap(current, found), + ); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) + } + req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, in.Item, currentLocation) + if err != nil { + return nil, err + } + return &transactWriteItemPlan{ + elems: req.Elems, + cleanup: cleanup, + writes: true, + readKeys: [][]byte{itemKey}, + }, nil +} + +func (d *DynamoDBServer) buildTransactUpdatePlan( + ctx context.Context, + schema *dynamoTableSchema, + in transactUpdateInput, + readTS uint64, +) (*transactWriteItemPlan, error) { + itemKey, err := schema.itemKeyFromAttributes(in.Key) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + var current map[string]attributeValue + if !found { + current = map[string]attributeValue{} + } else { + current = currentLocation.item + } + updateIn := updateItemInput{ + TableName: in.TableName, + Key: in.Key, + UpdateExpression: in.UpdateExpression, + ConditionExpression: in.ConditionExpression, + ExpressionAttributeNames: in.ExpressionAttributeNames, + ExpressionAttributeValues: in.ExpressionAttributeValues, + } + nextItem, err := buildUpdatedItem(schema, updateIn, current) + if err != nil { + return nil, err + } + req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, nextItem, currentLocation) + if err != nil { + return nil, err + } + return &transactWriteItemPlan{ + elems: req.Elems, + cleanup: cleanup, + writes: true, + readKeys: [][]byte{itemKey}, + }, nil +} + +func (d *DynamoDBServer) buildTransactDeletePlan( + ctx context.Context, + schema *dynamoTableSchema, + in transactDeleteInput, + readTS uint64, +) (*transactWriteItemPlan, error) { + itemKey, err := schema.itemKeyFromAttributes(in.Key) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + current := map[string]attributeValue(nil) + if found { + current = currentLocation.item + } + if err := validateConditionOnItem( + in.ConditionExpression, + in.ExpressionAttributeNames, + in.ExpressionAttributeValues, + valueOrEmptyMap(current, found), + ); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) + } + if !found { + // Item does not exist at readTS. Track the key so a concurrent create + // after our snapshot can be detected if the overall transaction + // includes write elems and therefore reaches FSM validation. In a + // pure no-op transaction, these read keys may not be validated. + return &transactWriteItemPlan{readKeys: [][]byte{itemKey}}, nil + } + req, err := buildItemDeleteRequestWithSource(currentLocation) + if err != nil { + return nil, err + } + return &transactWriteItemPlan{ + elems: req.Elems, + writes: true, + readKeys: [][]byte{itemKey}, + }, nil +} + +func (d *DynamoDBServer) buildTransactConditionCheckPlan( + ctx context.Context, + schema *dynamoTableSchema, + in transactConditionInput, + readTS uint64, +) (*transactWriteItemPlan, error) { + if strings.TrimSpace(in.ConditionExpression) == "" { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing condition expression") + } + itemKey, err := schema.itemKeyFromAttributes(in.Key) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) + if err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + current := map[string]attributeValue(nil) + if found { + current = currentLocation.item + } + if err := validateConditionOnItem( + in.ConditionExpression, + in.ExpressionAttributeNames, + in.ExpressionAttributeValues, + valueOrEmptyMap(current, found), + ); err != nil { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) + } + lockKey := itemKey + if currentLocation != nil { + lockKey = currentLocation.key + } + lockReq, lockCleanup, err := buildConditionCheckLockRequest(lockKey, current, found) + if err != nil { + return nil, err + } + return &transactWriteItemPlan{ + elems: lockReq.Elems, + cleanup: lockCleanup, + writes: true, + readKeys: [][]byte{itemKey}, + }, nil +} + +func valueOrEmptyMap(item map[string]attributeValue, found bool) map[string]attributeValue { + if found { + return item + } + return map[string]attributeValue{} +} + +func buildConditionCheckLockRequest( + itemKey []byte, + current map[string]attributeValue, + found bool, +) (*kv.OperationGroup[kv.OP], [][]byte, error) { + if !found { + // Item does not exist: no write is needed. + // Include itemKey in ReadKeys only so OCC conflict detection fires + // if a concurrent writer commits to this key between our startTS and commitTS. + // Writing a Del tombstone here would shadow any concurrently committed Put + // at a higher timestamp, causing G-single-item-realtime anomalies. + // Return nil cleanup since nothing was written by this condition check. + return &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: nil, + }, + nil, + nil + } + payload, err := encodeStoredDynamoItem(current) + if err != nil { + return nil, nil, errors.WithStack(err) + } + elems := []*kv.Elem[kv.OP]{{Op: kv.Put, Key: itemKey, Value: payload}} + return &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: 0, + Elems: elems, + }, + [][]byte{itemKey}, + nil +} + +func (d *DynamoDBServer) resolveTransactTableSchema(ctx context.Context, cache map[string]*dynamoTableSchema, tableName string, readTS uint64) (*dynamoTableSchema, error) { + if schema := cache[tableName]; schema != nil { + return schema, nil + } + schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if !exists { + return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + cache[tableName] = schema + return schema, nil +} + +func isRetryableTransactWriteError(err error) bool { + return errors.Is(err, store.ErrWriteConflict) || errors.Is(err, kv.ErrTxnLocked) +} + +func waitTransactRetryBackoff(ctx context.Context, delay time.Duration) error { + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return errors.WithStack(ctx.Err()) + case <-timer.C: + return nil + } +} + +func waitRetryWithDeadline(ctx context.Context, deadline time.Time, backoff time.Duration) error { + remaining := time.Until(deadline) + if remaining <= 0 { + return errors.New("retry timeout") + } + delay := min(backoff, remaining) + return waitTransactRetryBackoff(ctx, delay) +} + +func nextTransactRetryBackoff(current time.Duration) time.Duration { + next := current * transactRetryBackoffFactor + if next > transactRetryMaxBackoff { + return transactRetryMaxBackoff + } + return next +} + +var errTableGenerationChanged = errors.New("table generation changed") + +func (d *DynamoDBServer) verifyTableGeneration(ctx context.Context, tableName string, expectedGeneration uint64) error { + // Use consistentReadLatestTS to always read the latest committed schema. + // Using a stale snapshotTS can cause false "table not found" results when + // this node's LastCommitTS is behind the table creation timestamp, which + // would erroneously trigger cleanupCommittedKeys and delete live item data. + schema, exists, err := d.loadTableSchemaAt(ctx, tableName, consistentReadLatestTS) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") + } + if schema.Generation != expectedGeneration { + return errors.Wrapf(errTableGenerationChanged, + "table generation changed (table=%s expected=%d actual=%d)", + tableName, expectedGeneration, schema.Generation, + ) + } + return nil +} + +func (d *DynamoDBServer) verifyTableGenerations(ctx context.Context, generations map[string]uint64) error { + for tableName, generation := range generations { + if err := d.verifyTableGeneration(ctx, tableName, generation); err != nil { + return err + } + } + return nil +} + +func isGenerationFenceFailure(err error) bool { + return errors.Is(err, errTableGenerationChanged) || isTableNotFoundError(err) +} + +func (d *DynamoDBServer) handleGenerationFenceResult(ctx context.Context, err error, cleanupKeys [][]byte) (bool, error) { + if err == nil { + return false, nil + } + if !isGenerationFenceFailure(err) { + return false, err + } + if cleanupErr := d.cleanupCommittedKeys(ctx, cleanupKeys); cleanupErr != nil { + return false, cleanupErr + } + if errors.Is(err, errTableGenerationChanged) { + return true, nil + } + return false, err +} + +func isTableNotFoundError(err error) bool { + var apiErr *dynamoAPIError + if !errors.As(err, &apiErr) { + return false + } + return apiErr.errorType == dynamoErrResourceNotFound +} + +func (d *DynamoDBServer) cleanupCommittedKeys(ctx context.Context, keys [][]byte) error { + uniq := uniqueKeys(keys) + if len(uniq) == 0 { + return nil + } + return d.dispatchDeleteBatch(ctx, uniq) +} + +func uniqueKeys(keys [][]byte) [][]byte { + seen := make(map[string]struct{}, len(keys)) + out := make([][]byte, 0, len(keys)) + for _, key := range keys { + s := string(key) + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, key) + } + return out +} From f3f7fb7f400f4752082cde4c2977f1e11e30f44f Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 12 Jun 2026 15:44:32 +0900 Subject: [PATCH 3/4] adapter: split dynamodb.go expression/migration/locks into cohesive files (no behavior change) Pure code movement within package adapter: update-expression parsing, condition-expression evaluation, document-path traversal move into dynamodb_expression.go; legacy table key migration into dynamodb_migration.go; item/table lock striping + txn read-ts helpers into dynamodb_locks.go. No declarations changed. --- adapter/dynamodb_expression.go | 1640 ++++++++++++++++++++++++++++++++ adapter/dynamodb_locks.go | 123 +++ adapter/dynamodb_migration.go | 308 ++++++ 3 files changed, 2071 insertions(+) create mode 100644 adapter/dynamodb_expression.go create mode 100644 adapter/dynamodb_locks.go create mode 100644 adapter/dynamodb_migration.go diff --git a/adapter/dynamodb_expression.go b/adapter/dynamodb_expression.go new file mode 100644 index 00000000..d76c0d43 --- /dev/null +++ b/adapter/dynamodb_expression.go @@ -0,0 +1,1640 @@ +package adapter + +import ( + "bytes" + "maps" + "math/big" + "slices" + "sort" + "strconv" + "strings" + + "github.com/cockroachdb/errors" +) + +func replaceNames(expr string, names map[string]string) (string, error) { + if expr == "" || len(names) == 0 { + return expr, nil + } + if err := validateExpressionAttributeNames(names); err != nil { + return "", err + } + keys := make([]string, 0, len(names)) + for k := range names { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + if len(keys[i]) == len(keys[j]) { + return keys[i] < keys[j] + } + return len(keys[i]) > len(keys[j]) + }) + + // DynamoDB expression attribute names are substituted once. + args := make([]string, 0, len(keys)*replacerArgPairSize) + for _, key := range keys { + args = append(args, key, names[key]) + } + return strings.NewReplacer(args...).Replace(expr), nil +} + +func validateExpressionAttributeNames(names map[string]string) error { + for placeholder, name := range names { + if !isExpressionAttributePlaceholder(placeholder) { + return errors.Errorf("invalid expression attribute placeholder %q", placeholder) + } + if !isExpressionAttributeName(name) { + return errors.Errorf("invalid expression attribute name %q for placeholder %q", name, placeholder) + } + } + return nil +} + +func isExpressionAttributePlaceholder(s string) bool { + if len(s) <= 1 || s[0] != '#' { + return false + } + return isExpressionPlaceholderIdentifier(s[1:]) +} + +func isExpressionPlaceholderIdentifier(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if isExpressionPlaceholderIdentByte(s[i]) { + continue + } + return false + } + return true +} + +func isExpressionPlaceholderIdentByte(b byte) bool { + return b == '_' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') +} + +func isExpressionAttributeName(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if isExpressionAttributeNameByte(s[i]) { + continue + } + return false + } + return true +} + +func isExpressionAttributeNameByte(b byte) bool { + return b == '_' || b == '.' || b == '-' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') +} + +func applyUpdateExpression(expr string, names map[string]string, values map[string]attributeValue, item map[string]attributeValue) error { + updExpr, err := replaceNames(expr, names) + if err != nil { + return err + } + updExpr = strings.TrimSpace(updExpr) + sections, err := parseUpdateExpressionSections(updExpr) + if err != nil { + return err + } + for _, section := range sections { + if err := applyUpdateExpressionSection(section, values, item); err != nil { + return err + } + } + return nil +} + +type updateExpressionSection struct { + action string + body string +} + +func parseUpdateExpressionSections(expr string) ([]updateExpressionSection, error) { + if strings.TrimSpace(expr) == "" { + return nil, errors.New("unsupported update expression") + } + sections := make([]updateExpressionSection, 0, updateSplitCount) + seen := map[string]struct{}{} + i := skipSpaces(expr, 0) + for i < len(expr) { + action, nextPos, ok := parseUpdateActionToken(expr, i) + if !ok { + return nil, errors.New("unsupported update expression") + } + if _, exists := seen[action]; exists { + return nil, errors.New("duplicate update action") + } + seen[action] = struct{}{} + bodyStart := skipSpaces(expr, nextPos) + bodyEnd := findNextUpdateAction(expr, bodyStart) + if bodyEnd < 0 { + bodyEnd = len(expr) + } + body := strings.TrimSpace(expr[bodyStart:bodyEnd]) + if body == "" { + return nil, errors.New("unsupported update expression") + } + sections = append(sections, updateExpressionSection{action: action, body: body}) + if bodyEnd >= len(expr) { + break + } + i = bodyEnd + } + if len(sections) == 0 { + return nil, errors.New("unsupported update expression") + } + return sections, nil +} + +func applyUpdateExpressionSection(section updateExpressionSection, values map[string]attributeValue, item map[string]attributeValue) error { + switch section.action { + case "SET": + return applySetUpdateAction(section.body, values, item) + case "REMOVE": + return applyRemoveUpdateAction(section.body, item) + case "ADD": + return applyAddUpdateAction(section.body, values, item) + case "DELETE": + return applyDeleteUpdateAction(section.body, values, item) + default: + return errors.New("unsupported update action") + } +} + +func parseUpdateActionToken(expr string, pos int) (string, int, bool) { + actions := []string{"SET", "REMOVE", "ADD", "DELETE"} + for _, action := range actions { + end := pos + len(action) + if end > len(expr) { + continue + } + if !strings.EqualFold(expr[pos:end], action) { + continue + } + if !isLogicalKeywordBoundary(expr, pos-1) || !isLogicalKeywordBoundary(expr, end) { + continue + } + return action, end, true + } + return "", 0, false +} + +func skipSpaces(expr string, pos int) int { + for pos < len(expr) && (expr[pos] == ' ' || expr[pos] == '\t' || expr[pos] == '\n' || expr[pos] == '\r') { + pos++ + } + return pos +} + +func findNextUpdateAction(expr string, start int) int { + depth := 0 + for i := start; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth != 0 { + continue + } + _, _, ok := parseUpdateActionToken(expr, i) + if ok { + return i + } + } + return -1 +} + +func applySetUpdateAction(body string, values map[string]attributeValue, item map[string]attributeValue) error { + assignments, err := splitTopLevelByComma(body) + if err != nil { + return errors.New("invalid update expression") + } + for _, assignment := range assignments { + if err := applySingleSetAssignment(assignment, values, item); err != nil { + return err + } + } + return nil +} + +func applySingleSetAssignment(assignment string, values map[string]attributeValue, item map[string]attributeValue) error { + parts := strings.SplitN(assignment, "=", updateSplitCount) + if len(parts) != updateSplitCount { + return errors.New("invalid update expression") + } + path := strings.TrimSpace(parts[0]) + if path == "" { + return errors.New("invalid update expression attribute") + } + valueExpr := strings.TrimSpace(parts[1]) + valueAttr, err := evalUpdateValueExpression(valueExpr, values, item) + if err != nil { + return err + } + return setDocumentPath(item, path, valueAttr) +} + +func applyRemoveUpdateAction(body string, item map[string]attributeValue) error { + attrs, err := splitTopLevelByComma(body) + if err != nil { + return errors.New("invalid update expression") + } + for _, attr := range attrs { + path := strings.TrimSpace(attr) + if path == "" { + return errors.New("invalid update expression attribute") + } + if err := removeDocumentPath(item, path); err != nil { + return err + } + } + return nil +} + +func applyAddUpdateAction(body string, values map[string]attributeValue, item map[string]attributeValue) error { + terms, err := splitTopLevelByComma(body) + if err != nil { + return errors.New("invalid update expression") + } + for _, term := range terms { + if err := applySingleAddTerm(term, values, item); err != nil { + return err + } + } + return nil +} + +func applySingleAddTerm(term string, values map[string]attributeValue, item map[string]attributeValue) error { + parts := strings.Fields(term) + if len(parts) != updateSplitCount { + return errors.New("invalid update expression") + } + path := strings.TrimSpace(parts[0]) + placeholder := strings.TrimSpace(parts[1]) + if path == "" || !strings.HasPrefix(placeholder, ":") { + return errors.New("invalid update expression") + } + addValue, ok := values[placeholder] + if !ok { + return errors.New("missing value attribute") + } + current, exists, err := resolveDocumentPath(item, path) + if err != nil { + return err + } + next, err := addAttributeValue(current, exists, addValue) + if err != nil { + return err + } + return setDocumentPath(item, path, next) +} + +func addNumericAttributeValues(left string, right string) (string, error) { + leftRat, rightRat := &big.Rat{}, &big.Rat{} + if _, ok := leftRat.SetString(strings.TrimSpace(left)); !ok { + return "", errors.New("invalid number attribute") + } + if _, ok := rightRat.SetString(strings.TrimSpace(right)); !ok { + return "", errors.New("invalid number attribute") + } + sum := &big.Rat{} + sum.Add(leftRat, rightRat) + if sum.IsInt() { + return sum.Num().String(), nil + } + out := strings.TrimRight(sum.FloatString(numericUpdateScaleDigits), "0") + out = strings.TrimRight(out, ".") + if out == "" { + return "0", nil + } + return out, nil +} + +func applyDeleteUpdateAction(body string, values map[string]attributeValue, item map[string]attributeValue) error { + terms, err := splitTopLevelByComma(body) + if err != nil { + return errors.New("invalid update expression") + } + for _, term := range terms { + if err := applySingleDeleteTerm(term, values, item); err != nil { + return err + } + } + return nil +} + +func applySingleDeleteTerm(term string, values map[string]attributeValue, item map[string]attributeValue) error { + fields := strings.Fields(strings.TrimSpace(term)) + switch len(fields) { + case 0: + return errors.New("invalid update expression") + case 1: + return removeDocumentPath(item, fields[0]) + case updateSplitCount: + return applyDeleteSetTerm(fields[0], fields[1], values, item) + default: + return errors.New("invalid update expression") + } +} + +func applyDeleteSetTerm(pathExpr string, placeholderExpr string, values map[string]attributeValue, item map[string]attributeValue) error { + path := strings.TrimSpace(pathExpr) + placeholder := strings.TrimSpace(placeholderExpr) + if path == "" || !strings.HasPrefix(placeholder, ":") { + return errors.New("invalid update expression") + } + deleteValue, ok := values[placeholder] + if !ok { + return errors.New("missing value attribute") + } + current, found, err := resolveDocumentPath(item, path) + if err != nil || !found { + return err + } + next, removeAttr, err := deleteAttributeValueElements(current, deleteValue) + if err != nil { + return err + } + if removeAttr { + return removeDocumentPath(item, path) + } + return setDocumentPath(item, path, next) +} + +func splitTopLevelByComma(expr string) ([]string, error) { + depth := 0 + last := 0 + parts := make([]string, 0, splitPartsInitialCapacity) + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth < 0 { + return nil, errors.New("invalid expression") + } + if depth != 0 || expr[i] != ',' { + continue + } + part := strings.TrimSpace(expr[last:i]) + if part == "" { + return nil, errors.New("invalid expression") + } + parts = append(parts, part) + last = i + 1 + } + if depth != 0 { + return nil, errors.New("invalid expression") + } + tail := strings.TrimSpace(expr[last:]) + if tail == "" { + return nil, errors.New("invalid expression") + } + return append(parts, tail), nil +} + +type documentPathToken struct { + attr string + index int + isIndex bool +} + +func parseDocumentPath(path string) ([]documentPathToken, error) { + path = strings.TrimSpace(path) + if path == "" { + return nil, errors.New("invalid document path") + } + tokens := make([]documentPathToken, 0, updateSplitCount) + for pos := 0; pos < len(path); { + nextPos, token, err := consumeDocumentPathToken(path, pos) + if err != nil { + return nil, err + } + pos = nextPos + if token.attr != "" || token.isIndex { + tokens = append(tokens, token) + } + } + if len(tokens) == 0 { + return nil, errors.New("invalid document path") + } + return tokens, nil +} + +func consumeDocumentPathToken(path string, pos int) (int, documentPathToken, error) { + switch path[pos] { + case '.': + return pos + 1, documentPathToken{}, nil + case '[': + return consumeDocumentPathIndex(path, pos) + default: + return consumeDocumentPathAttr(path, pos) + } +} + +func consumeDocumentPathIndex(path string, pos int) (int, documentPathToken, error) { + end := strings.IndexByte(path[pos:], ']') + if end <= 1 { + return 0, documentPathToken{}, errors.New("invalid document path") + } + indexValue, err := strconv.Atoi(path[pos+1 : pos+end]) + if err != nil || indexValue < 0 { + return 0, documentPathToken{}, errors.New("invalid document path") + } + return pos + end + 1, documentPathToken{index: indexValue, isIndex: true}, nil +} + +func consumeDocumentPathAttr(path string, pos int) (int, documentPathToken, error) { + start := pos + for pos < len(path) && path[pos] != '.' && path[pos] != '[' { + pos++ + } + attr := strings.TrimSpace(path[start:pos]) + if attr == "" { + return 0, documentPathToken{}, errors.New("invalid document path") + } + return pos, documentPathToken{attr: attr}, nil +} + +func resolveDocumentPath(item map[string]attributeValue, path string) (attributeValue, bool, error) { + tokens, err := parseDocumentPath(path) + if err != nil { + return attributeValue{}, false, err + } + current := attributeValue{M: item} + found := true + for _, token := range tokens { + current, found = nextDocumentPathValue(current, found, token) + if !found { + return attributeValue{}, false, nil + } + } + return cloneAttributeValue(current), true, nil +} + +func nextDocumentPathValue(current attributeValue, found bool, token documentPathToken) (attributeValue, bool) { + if !found { + return attributeValue{}, false + } + if token.isIndex { + if !current.hasListType() || token.index >= len(current.L) { + return attributeValue{}, false + } + return current.L[token.index], true + } + if !current.hasMapType() { + return attributeValue{}, false + } + value, ok := current.M[token.attr] + if !ok { + return attributeValue{}, false + } + return value, true +} + +func setDocumentPath(item map[string]attributeValue, path string, value attributeValue) error { + tokens, err := parseDocumentPath(path) + if err != nil { + return err + } + root, err := setDocumentPathValue(attributeValue{M: cloneAttributeValueMap(item)}, true, tokens, value) + if err != nil { + return err + } + replaceAttributeValueMap(item, root.M) + return nil +} + +func setDocumentPathValue(current attributeValue, exists bool, tokens []documentPathToken, value attributeValue) (attributeValue, error) { + if len(tokens) == 0 { + return cloneAttributeValue(value), nil + } + token := tokens[0] + if token.isIndex { + return setDocumentPathIndex(current, exists, token, tokens[1:], value) + } + return setDocumentPathAttribute(current, exists, token, tokens[1:], value) +} + +func setDocumentPathIndex(current attributeValue, exists bool, token documentPathToken, rest []documentPathToken, value attributeValue) (attributeValue, error) { + if !exists || !current.hasListType() { + return attributeValue{}, errors.New("invalid document path") + } + list := cloneAttributeValueList(current.L) + if token.index > len(list) { + return attributeValue{}, errors.New("invalid document path") + } + if token.index == len(list) { + return appendDocumentPathIndex(list, rest, value) + } + nextValue, err := setDocumentPathValue(list[token.index], true, rest, value) + if err != nil { + return attributeValue{}, err + } + list[token.index] = nextValue + return attributeValue{L: list}, nil +} + +func appendDocumentPathIndex(list []attributeValue, rest []documentPathToken, value attributeValue) (attributeValue, error) { + child := value + if len(rest) > 0 { + var err error + child, err = setDocumentPathValue(newDocumentContainer(rest[0]), true, rest, value) + if err != nil { + return attributeValue{}, err + } + } + list = append(list, cloneAttributeValue(child)) + return attributeValue{L: list}, nil +} + +func setDocumentPathAttribute(current attributeValue, exists bool, token documentPathToken, rest []documentPathToken, value attributeValue) (attributeValue, error) { + var object map[string]attributeValue + if exists { + if !current.hasMapType() { + return attributeValue{}, errors.New("invalid document path") + } + object = cloneAttributeValueMap(current.M) + } else { + object = map[string]attributeValue{} + } + child, childExists := object[token.attr] + if !childExists && len(rest) > 0 { + child = newDocumentContainer(rest[0]) + childExists = true + } + nextValue, err := setDocumentPathValue(child, childExists, rest, value) + if err != nil { + return attributeValue{}, err + } + object[token.attr] = nextValue + return attributeValue{M: object}, nil +} + +func newDocumentContainer(next documentPathToken) attributeValue { + if next.isIndex { + return attributeValue{L: []attributeValue{}} + } + return attributeValue{M: map[string]attributeValue{}} +} + +func removeDocumentPath(item map[string]attributeValue, path string) error { + tokens, err := parseDocumentPath(path) + if err != nil { + return err + } + root, err := removeDocumentPathValue(attributeValue{M: cloneAttributeValueMap(item)}, true, tokens) + if err != nil { + return err + } + replaceAttributeValueMap(item, root.M) + return nil +} + +func removeDocumentPathValue(current attributeValue, exists bool, tokens []documentPathToken) (attributeValue, error) { + if !exists || len(tokens) == 0 { + return current, nil + } + token := tokens[0] + if token.isIndex { + return removeDocumentPathIndex(current, token, tokens[1:]) + } + return removeDocumentPathAttribute(current, token, tokens[1:]) +} + +func removeDocumentPathIndex(current attributeValue, token documentPathToken, rest []documentPathToken) (attributeValue, error) { + if !current.hasListType() || token.index >= len(current.L) { + return current, nil + } + list := cloneAttributeValueList(current.L) + if len(rest) == 0 { + list = append(list[:token.index], list[token.index+1:]...) + return attributeValue{L: list}, nil + } + nextValue, err := removeDocumentPathValue(list[token.index], true, rest) + if err != nil { + return attributeValue{}, err + } + list[token.index] = nextValue + return attributeValue{L: list}, nil +} + +func removeDocumentPathAttribute(current attributeValue, token documentPathToken, rest []documentPathToken) (attributeValue, error) { + if !current.hasMapType() { + return current, nil + } + object := cloneAttributeValueMap(current.M) + child, ok := object[token.attr] + if !ok { + return current, nil + } + if len(rest) == 0 { + delete(object, token.attr) + return attributeValue{M: object}, nil + } + nextValue, err := removeDocumentPathValue(child, true, rest) + if err != nil { + return attributeValue{}, err + } + object[token.attr] = nextValue + return attributeValue{M: object}, nil +} + +func replaceAttributeValueMap(dst map[string]attributeValue, src map[string]attributeValue) { + clear(dst) + maps.Copy(dst, src) +} + +func deleteAttributeValueElements(current attributeValue, deleteValue attributeValue) (attributeValue, bool, error) { + currentKind, _ := detectAttributeValueKind(current) + deleteKind, _ := detectAttributeValueKind(deleteValue) + if currentKind != deleteKind { + return attributeValue{}, false, errors.New("DELETE supports only matching set attribute types") + } + switch currentKind { + case attributeValueKindStringSet: + return buildDeleteSetResult(attributeValue{SS: subtractStringSet(current.SS, deleteValue.SS)}) + case attributeValueKindNumberSet: + return buildDeleteSetResult(attributeValue{NS: subtractNumberSet(current.NS, deleteValue.NS)}) + case attributeValueKindBinarySet: + return buildDeleteSetResult(attributeValue{BS: subtractBinarySet(current.BS, deleteValue.BS)}) + case attributeValueKindInvalid, + attributeValueKindString, + attributeValueKindNumber, + attributeValueKindBinary, + attributeValueKindBool, + attributeValueKindNull, + attributeValueKindList, + attributeValueKindMap: + return attributeValue{}, false, errors.New("DELETE supports only matching set attribute types") + } + return attributeValue{}, false, errors.New("DELETE supports only matching set attribute types") +} + +func buildDeleteSetResult(next attributeValue) (attributeValue, bool, error) { + if next.hasStringSetType() && len(next.SS) == 0 { + return attributeValue{}, true, nil + } + if next.hasNumberSetType() && len(next.NS) == 0 { + return attributeValue{}, true, nil + } + if next.hasBinarySetType() && len(next.BS) == 0 { + return attributeValue{}, true, nil + } + return next, false, nil +} + +func subtractStringSet(current []string, remove []string) []string { + removeSet := make(map[string]struct{}, len(remove)) + for _, value := range remove { + removeSet[value] = struct{}{} + } + out := make([]string, 0, len(current)) + for _, value := range current { + if _, ok := removeSet[value]; ok { + continue + } + out = append(out, value) + } + return out +} + +func subtractNumberSet(current []string, remove []string) []string { + removeSet := make(map[string]struct{}, len(remove)) + for _, value := range remove { + removeSet[canonicalNumberString(value)] = struct{}{} + } + out := make([]string, 0, len(current)) + for _, value := range current { + if _, ok := removeSet[canonicalNumberString(value)]; ok { + continue + } + out = append(out, value) + } + return out +} + +func subtractBinarySet(current [][]byte, remove [][]byte) [][]byte { + removeSet := make(map[string]struct{}, len(remove)) + for _, value := range remove { + removeSet[string(value)] = struct{}{} + } + out := make([][]byte, 0, len(current)) + for _, value := range current { + if _, ok := removeSet[string(value)]; ok { + continue + } + out = append(out, bytes.Clone(value)) + } + return out +} + +func evalUpdateValueExpression(expr string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { + expr = strings.TrimSpace(expr) + if expr == "" { + return attributeValue{}, errors.New("invalid update expression") + } + if value, handled, err := evalArithmeticUpdateOperand(expr, values, item); handled { + return value, err + } + if value, handled, err := evalNamedUpdateFunction(expr, values, item, "if_not_exists", evalIfNotExistsUpdateValue); handled { + return value, err + } + if value, handled, err := evalNamedUpdateFunction(expr, values, item, "list_append", evalListAppendUpdateValue); handled { + return value, err + } + return evalUpdateTerminalValue(expr, values, item) +} + +func evalArithmeticUpdateOperand(expr string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, bool, error) { + index, op, ok := findTopLevelArithmeticOperator(expr) + if !ok { + return attributeValue{}, false, nil + } + left, err := evalUpdateValueExpression(expr[:index], values, item) + if err != nil { + return attributeValue{}, true, err + } + right, err := evalUpdateValueExpression(expr[index+1:], values, item) + if err != nil { + return attributeValue{}, true, err + } + value, err := applyArithmeticUpdateValue(left, right, op) + return value, true, err +} + +func evalNamedUpdateFunction( + expr string, + values map[string]attributeValue, + item map[string]attributeValue, + name string, + eval func([]string, map[string]attributeValue, map[string]attributeValue) (attributeValue, error), +) (attributeValue, bool, error) { + args, ok, err := parseExpressionFunctionArgs(expr, name) + if err != nil { + return attributeValue{}, true, err + } + if !ok { + return attributeValue{}, false, nil + } + value, err := eval(args, values, item) + return value, true, err +} + +func evalUpdateTerminalValue(expr string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { + if strings.HasPrefix(expr, ":") { + value, ok := values[expr] + if !ok { + return attributeValue{}, errors.New("missing value attribute") + } + return cloneAttributeValue(value), nil + } + value, found, err := resolveDocumentPath(item, expr) + if err != nil { + return attributeValue{}, err + } + if !found { + return attributeValue{}, errors.New("missing value attribute") + } + return value, nil +} + +func evalIfNotExistsUpdateValue(args []string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { + if len(args) != updateSplitCount { + return attributeValue{}, errors.New("invalid update expression") + } + current, found, err := resolveDocumentPath(item, strings.TrimSpace(args[0])) + if err != nil { + return attributeValue{}, err + } + if found { + return current, nil + } + return evalUpdateValueExpression(args[1], values, item) +} + +func evalListAppendUpdateValue(args []string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { + if len(args) != updateSplitCount { + return attributeValue{}, errors.New("invalid update expression") + } + left, err := evalUpdateValueExpression(args[0], values, item) + if err != nil { + return attributeValue{}, err + } + right, err := evalUpdateValueExpression(args[1], values, item) + if err != nil { + return attributeValue{}, err + } + if !left.hasListType() || !right.hasListType() { + return attributeValue{}, errors.New("list_append supports only list attributes") + } + out := make([]attributeValue, 0, len(left.L)+len(right.L)) + for _, value := range left.L { + out = append(out, cloneAttributeValue(value)) + } + for _, value := range right.L { + out = append(out, cloneAttributeValue(value)) + } + return attributeValue{L: out}, nil +} + +func applyArithmeticUpdateValue(left attributeValue, right attributeValue, op byte) (attributeValue, error) { + if !left.hasNumberType() || !right.hasNumberType() { + return attributeValue{}, errors.New("arithmetic update supports only number attributes") + } + rightValue := right.numberValue() + if op == '-' { + rightValue = "-" + rightValue + } + sum, err := addNumericAttributeValues(left.numberValue(), rightValue) + if err != nil { + return attributeValue{}, err + } + return attributeValue{N: &sum}, nil +} + +func findTopLevelArithmeticOperator(expr string) (int, byte, bool) { + depth := 0 + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth != 0 { + continue + } + switch expr[i] { + case '+', '-': + if i == 0 { + continue + } + return i, expr[i], true + } + } + return 0, 0, false +} + +func parseExpressionFunctionArgs(expr string, funcName string) ([]string, bool, error) { + prefix := funcName + "(" + if !strings.HasPrefix(strings.ToLower(expr), strings.ToLower(prefix)) || !strings.HasSuffix(expr, ")") { + return nil, false, nil + } + inner := strings.TrimSpace(expr[len(prefix) : len(expr)-1]) + parts, err := splitTopLevelByComma(inner) + if err != nil { + return nil, true, errors.New("invalid expression") + } + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts, true, nil +} + +func addAttributeValue(current attributeValue, exists bool, addValue attributeValue) (attributeValue, error) { + if addValue.hasNumberType() { + return addNumericUpdateValue(current, exists, addValue) + } + if addValue.hasStringSetType() || addValue.hasNumberSetType() || addValue.hasBinarySetType() { + return addSetUpdateValue(current, exists, addValue) + } + return attributeValue{}, errors.New("ADD supports only number or set attributes") +} + +func addNumericUpdateValue(current attributeValue, exists bool, addValue attributeValue) (attributeValue, error) { + if !exists { + return cloneAttributeValue(addValue), nil + } + if !current.hasNumberType() { + return attributeValue{}, errors.New("ADD supports only number attributes") + } + sum, err := addNumericAttributeValues(current.numberValue(), addValue.numberValue()) + if err != nil { + return attributeValue{}, err + } + return attributeValue{N: &sum}, nil +} + +func addSetUpdateValue(current attributeValue, exists bool, addValue attributeValue) (attributeValue, error) { + if !exists { + return cloneAttributeValue(addValue), nil + } + switch { + case current.hasStringSetType() && addValue.hasStringSetType(): + return attributeValue{SS: mergeStringSet(current.SS, addValue.SS)}, nil + case current.hasNumberSetType() && addValue.hasNumberSetType(): + return attributeValue{NS: mergeNumberSet(current.NS, addValue.NS)}, nil + case current.hasBinarySetType() && addValue.hasBinarySetType(): + return attributeValue{BS: mergeBinarySet(current.BS, addValue.BS)}, nil + default: + return attributeValue{}, errors.New("ADD supports only matching set attribute types") + } +} + +func mergeStringSet(current []string, add []string) []string { + out := make([]string, 0, len(current)+len(add)) + seen := make(map[string]struct{}, len(current)+len(add)) + for _, value := range append(append([]string(nil), current...), add...) { + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} + +func mergeNumberSet(current []string, add []string) []string { + out := make([]string, 0, len(current)+len(add)) + seen := make(map[string]struct{}, len(current)+len(add)) + for _, value := range append(append([]string(nil), current...), add...) { + key := canonicalNumberString(value) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, value) + } + return out +} + +func mergeBinarySet(current [][]byte, add [][]byte) [][]byte { + out := make([][]byte, 0, len(current)+len(add)) + seen := make(map[string]struct{}, len(current)+len(add)) + for _, value := range append(cloneBinarySet(current), add...) { + key := string(value) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, bytes.Clone(value)) + } + return out +} + +func validateConditionOnItem(expr string, names map[string]string, values map[string]attributeValue, item map[string]attributeValue) error { + cond, err := replaceNames(expr, names) + if err != nil { + return err + } + cond = strings.TrimSpace(cond) + if cond == "" { + return nil + } + ok, err := evalConditionExpression(cond, item, values) + if err != nil { + return err + } + if !ok { + return errors.New("conditional check failed") + } + return nil +} + +func evalConditionExpression(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + expr = trimOuterParens(strings.TrimSpace(expr)) + if expr == "" { + return true, nil + } + if ok, handled, err := evalLogicalCondition(expr, "OR", item, values); handled { + return ok, err + } + if ok, handled, err := evalLogicalCondition(expr, "AND", item, values); handled { + return ok, err + } + if rest, ok := trimLeadingKeyword(expr, "NOT"); ok { + ok, err := evalConditionExpression(rest, item, values) + if err != nil { + return false, err + } + return !ok, nil + } + return evalAtomicCondition(expr, item, values) +} + +func trimOuterParens(expr string) string { + for { + expr = strings.TrimSpace(expr) + if !hasOuterParens(expr) { + return expr + } + expr = expr[1 : len(expr)-1] + } +} + +func splitTopLevelByKeyword(expr string, keyword string) []string { + if expr == "" { + return nil + } + upper := strings.ToUpper(expr) + target := strings.ToUpper(keyword) + targetLen := len(target) + if targetLen == 0 { + return nil + } + depth := 0 + last := 0 + betweenPending := false + parts := make([]string, 0, splitPartsInitialCapacity) + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth != 0 { + continue + } + if nextIndex, handled, nextPending := consumeBetweenSplitState(expr, upper, keyword, i, betweenPending); handled { + betweenPending = nextPending + i = nextIndex + continue + } + if !shouldSplitKeywordAt(expr, upper, target, targetLen, i) { + continue + } + part, ok := trimmedNonEmpty(expr[last:i]) + if !ok { + return nil + } + parts = append(parts, part) + i += targetLen - 1 + last = i + 1 + } + return finalizeKeywordSplit(expr[last:], parts) +} + +func consumeBetweenSplitState(expr string, upper string, keyword string, index int, betweenPending bool) (int, bool, bool) { + if !strings.EqualFold(keyword, "AND") { + return index, false, betweenPending + } + if matchesLogicalKeyword(expr, upper, "BETWEEN", index) { + return index + len("BETWEEN") - 1, true, true + } + if betweenPending && matchesLogicalKeyword(expr, upper, "AND", index) { + return index + len("AND") - 1, true, false + } + return index, false, betweenPending +} + +func shouldSplitKeywordAt(expr string, upper string, target string, targetLen int, index int) bool { + return matchesKeywordTokenAt(upper, target, index) && + isLogicalKeywordBoundary(expr, index-1) && + isLogicalKeywordBoundary(expr, index+targetLen) +} + +func matchesLogicalKeyword(expr string, upper string, keyword string, index int) bool { + return matchesKeywordTokenAt(upper, keyword, index) && + isLogicalKeywordBoundary(expr, index-1) && + isLogicalKeywordBoundary(expr, index+len(keyword)) +} + +func evalLogicalCondition(expr string, keyword string, item map[string]attributeValue, values map[string]attributeValue) (bool, bool, error) { + parts := splitTopLevelByKeyword(expr, keyword) + if len(parts) == 0 { + return false, false, nil + } + if strings.EqualFold(keyword, "OR") { + ok, err := evalConditionAny(parts, item, values) + return ok, true, err + } + ok, err := evalConditionAll(parts, item, values) + return ok, true, err +} + +func evalConditionAny(parts []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + for _, part := range parts { + ok, err := evalConditionExpression(part, item, values) + if err != nil { + return false, err + } + if ok { + return true, nil + } + } + return false, nil +} + +func evalConditionAll(parts []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + for _, part := range parts { + ok, err := evalConditionExpression(part, item, values) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + } + return true, nil +} + +func evalAtomicCondition(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + for _, handler := range conditionFunctionHandlers { + if ok, handled, err := evalNamedConditionFunction(expr, item, values, handler); handled { + return ok, err + } + } + if ok, handled, err := evalConditionBetween(expr, item, values); handled { + return ok, err + } + if ok, handled, err := evalConditionIn(expr, item, values); handled { + return ok, err + } + return evalConditionComparison(expr, item, values) +} + +type conditionFunctionHandler struct { + name string + eval func([]string, map[string]attributeValue, map[string]attributeValue) (bool, error) +} + +var conditionFunctionHandlers = []conditionFunctionHandler{ + { + name: "attribute_exists", + eval: func(args []string, item map[string]attributeValue, _ map[string]attributeValue) (bool, error) { + return evalAttributeExistsCondition(args, item) + }, + }, + { + name: "attribute_not_exists", + eval: func(args []string, item map[string]attributeValue, _ map[string]attributeValue) (bool, error) { + return evalAttributeNotExistsCondition(args, item) + }, + }, + {name: "attribute_type", eval: evalAttributeTypeCondition}, + {name: "begins_with", eval: evalBeginsWithCondition}, + {name: "contains", eval: evalContainsCondition}, +} + +func evalNamedConditionFunction( + expr string, + item map[string]attributeValue, + values map[string]attributeValue, + handler conditionFunctionHandler, +) (bool, bool, error) { + args, ok, err := parseExpressionFunctionArgs(expr, handler.name) + if err != nil { + return false, true, err + } + if !ok { + return false, false, nil + } + value, err := handler.eval(args, item, values) + return value, true, err +} + +func evalAttributeExistsCondition(args []string, item map[string]attributeValue) (bool, error) { + if len(args) != 1 { + return false, errors.New("unsupported condition expression") + } + _, found, err := resolveDocumentPath(item, args[0]) + if err != nil { + return false, err + } + return found, nil +} + +func evalAttributeNotExistsCondition(args []string, item map[string]attributeValue) (bool, error) { + ok, err := evalAttributeExistsCondition(args, item) + if err != nil { + return false, err + } + return !ok, nil +} + +func evalAttributeTypeCondition(args []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + if len(args) != updateSplitCount { + return false, errors.New("unsupported condition expression") + } + value, found, err := resolveDocumentPath(item, args[0]) + if err != nil || !found { + return false, err + } + typeValue, ok := values[strings.TrimSpace(args[1])] + if !ok || !typeValue.hasStringType() { + return false, errors.New("unsupported condition expression") + } + return dynamoAttributeType(value) == typeValue.stringValue(), nil +} + +func evalBeginsWithCondition(args []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + if len(args) != updateSplitCount { + return false, errors.New("unsupported condition expression") + } + left, found, err := resolveDocumentPath(item, args[0]) + if err != nil || !found { + return false, err + } + right, ok := values[strings.TrimSpace(args[1])] + if !ok { + return false, errors.New("missing condition value") + } + switch { + case left.hasStringType() && right.hasStringType(): + return strings.HasPrefix(left.stringValue(), right.stringValue()), nil + case left.hasBinaryType() && right.hasBinaryType(): + return bytes.HasPrefix(left.B, right.B), nil + default: + return false, nil + } +} + +func evalContainsCondition(args []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + if len(args) != updateSplitCount { + return false, errors.New("unsupported condition expression") + } + left, found, err := resolveDocumentPath(item, args[0]) + if err != nil || !found { + return false, err + } + right, ok := values[strings.TrimSpace(args[1])] + if !ok { + return false, errors.New("missing condition value") + } + return attributeValueContains(left, right), nil +} + +func attributeValueContains(left attributeValue, right attributeValue) bool { + for _, eval := range attributeValueContainsEvaluators { + if handled, ok := eval(left, right); handled { + return ok + } + } + return false +} + +type attributeValueContainsEvaluator func(attributeValue, attributeValue) (bool, bool) + +var attributeValueContainsEvaluators = []attributeValueContainsEvaluator{ + containsStringAttributeValue, + containsBinaryAttributeValue, + containsListAttributeValue, + containsStringSetAttributeValue, + containsNumberSetAttributeValue, + containsBinarySetAttributeValue, +} + +func containsStringAttributeValue(left attributeValue, right attributeValue) (bool, bool) { + if !left.hasStringType() || !right.hasStringType() { + return false, false + } + return true, strings.Contains(left.stringValue(), right.stringValue()) +} + +func containsBinaryAttributeValue(left attributeValue, right attributeValue) (bool, bool) { + if !left.hasBinaryType() || !right.hasBinaryType() { + return false, false + } + return true, bytes.Contains(left.B, right.B) +} + +func containsListAttributeValue(left attributeValue, right attributeValue) (bool, bool) { + if !left.hasListType() { + return false, false + } + return true, listContainsAttributeValue(left.L, right) +} + +func containsStringSetAttributeValue(left attributeValue, right attributeValue) (bool, bool) { + if !left.hasStringSetType() || !right.hasStringType() { + return false, false + } + return true, stringSetContains(left.SS, right.stringValue()) +} + +func containsNumberSetAttributeValue(left attributeValue, right attributeValue) (bool, bool) { + if !left.hasNumberSetType() || !right.hasNumberType() { + return false, false + } + return true, numberSetContains(left.NS, right.numberValue()) +} + +func containsBinarySetAttributeValue(left attributeValue, right attributeValue) (bool, bool) { + if !left.hasBinarySetType() || !right.hasBinaryType() { + return false, false + } + return true, binarySetContains(left.BS, right.B) +} + +func listContainsAttributeValue(values []attributeValue, needle attributeValue) bool { + for _, value := range values { + if attributeValueEqual(value, needle) { + return true + } + } + return false +} + +func stringSetContains(values []string, needle string) bool { + return slices.Contains(values, needle) +} + +func numberSetContains(values []string, needle string) bool { + for _, value := range values { + if cmp, ok := compareNumericAttributeString(value, needle); ok && cmp == 0 { + return true + } + } + return false +} + +func binarySetContains(values [][]byte, needle []byte) bool { + for _, value := range values { + if bytes.Equal(value, needle) { + return true + } + } + return false +} + +func evalConditionBetween(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, bool, error) { + betweenIndex := findTopLevelKeywordIndex(expr, "BETWEEN") + if betweenIndex < 0 { + return false, false, nil + } + leftExpr := strings.TrimSpace(expr[:betweenIndex]) + rest := strings.TrimSpace(expr[betweenIndex+len("BETWEEN"):]) + andIndex := findTopLevelKeywordIndex(rest, "AND") + if andIndex < 0 { + return false, true, errors.New("unsupported condition expression") + } + lowerExpr := strings.TrimSpace(rest[:andIndex]) + upperExpr := strings.TrimSpace(rest[andIndex+len("AND"):]) + left, found, err := resolveConditionOperand(leftExpr, item, values) + if err != nil || !found { + return false, true, err + } + lower, found, err := resolveConditionOperand(lowerExpr, item, values) + if err != nil || !found { + return false, true, err + } + upper, found, err := resolveConditionOperand(upperExpr, item, values) + if err != nil || !found { + return false, true, err + } + return compareAttributeValueSortKey(left, lower) >= 0 && compareAttributeValueSortKey(left, upper) <= 0, true, nil +} + +func evalConditionIn(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, bool, error) { + inIndex := findTopLevelKeywordIndex(expr, "IN") + if inIndex < 0 { + return false, false, nil + } + left, parts, err := parseConditionInOperands(expr, inIndex, item, values) + if err != nil { + return false, true, err + } + ok, err := conditionInListContains(left, parts, item, values) + return ok, true, err +} + +func parseConditionInOperands(expr string, inIndex int, item map[string]attributeValue, values map[string]attributeValue) (attributeValue, []string, error) { + leftExpr := strings.TrimSpace(expr[:inIndex]) + rest := strings.TrimSpace(expr[inIndex+len("IN"):]) + if !strings.HasPrefix(rest, "(") || !strings.HasSuffix(rest, ")") { + return attributeValue{}, nil, errors.New("unsupported condition expression") + } + left, found, err := resolveConditionOperand(leftExpr, item, values) + if err != nil || !found { + return attributeValue{}, nil, err + } + parts, err := splitTopLevelByComma(rest[1 : len(rest)-1]) + if err != nil { + return attributeValue{}, nil, errors.New("unsupported condition expression") + } + return left, parts, nil +} + +func conditionInListContains(left attributeValue, parts []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + for _, part := range parts { + candidate, found, err := resolveConditionOperand(part, item, values) + if err != nil { + return false, err + } + if found && attributeValueEqual(left, candidate) { + return true, nil + } + } + return false, nil +} + +func evalConditionComparison(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { + index, operator, ok := findTopLevelConditionComparator(expr) + if !ok { + return false, errors.New("unsupported condition expression") + } + left, right, err := resolveConditionComparisonOperands(expr, index, operator, item, values) + if err != nil { + return false, err + } + return compareConditionValues(operator, left, right) +} + +func resolveConditionComparisonOperands( + expr string, + index int, + operator string, + item map[string]attributeValue, + values map[string]attributeValue, +) (attributeValue, attributeValue, error) { + leftExpr := strings.TrimSpace(expr[:index]) + rightExpr := strings.TrimSpace(expr[index+len(operator):]) + left, found, err := resolveConditionOperand(leftExpr, item, values) + if err != nil || !found { + return attributeValue{}, attributeValue{}, err + } + right, found, err := resolveConditionOperand(rightExpr, item, values) + if err != nil || !found { + return attributeValue{}, attributeValue{}, err + } + return left, right, nil +} + +func compareConditionValues(operator string, left attributeValue, right attributeValue) (bool, error) { + switch operator { + case "=": + return attributeValueEqual(left, right), nil + case "<>": + return !attributeValueEqual(left, right), nil + case "<": + return compareAttributeValueSortKey(left, right) < 0, nil + case "<=": + return compareAttributeValueSortKey(left, right) <= 0, nil + case ">": + return compareAttributeValueSortKey(left, right) > 0, nil + case ">=": + return compareAttributeValueSortKey(left, right) >= 0, nil + default: + return false, errors.New("unsupported condition expression") + } +} + +func resolveConditionOperand(expr string, item map[string]attributeValue, values map[string]attributeValue) (attributeValue, bool, error) { + expr = strings.TrimSpace(expr) + if expr == "" { + return attributeValue{}, false, errors.New("unsupported condition expression") + } + if args, ok, err := parseExpressionFunctionArgs(expr, "size"); ok || err != nil { + if err != nil { + return attributeValue{}, false, err + } + return resolveConditionSizeOperand(args, item) + } + if strings.HasPrefix(expr, ":") { + value, ok := values[expr] + if !ok { + return attributeValue{}, false, errors.New("missing condition value") + } + return cloneAttributeValue(value), true, nil + } + value, found, err := resolveDocumentPath(item, expr) + if err != nil { + return attributeValue{}, false, err + } + return value, found, nil +} + +func resolveConditionSizeOperand(args []string, item map[string]attributeValue) (attributeValue, bool, error) { + if len(args) != 1 { + return attributeValue{}, false, errors.New("unsupported condition expression") + } + value, found, err := resolveDocumentPath(item, args[0]) + if err != nil || !found { + return attributeValue{}, false, err + } + size := attributeValueSize(value) + sizeString := strconv.Itoa(size) + return attributeValue{N: &sizeString}, true, nil +} + +func attributeValueSize(value attributeValue) int { + switch { + case value.hasStringType(): + return len(value.stringValue()) + case value.hasBinaryType(): + return len(value.B) + case value.hasStringSetType(): + return len(value.SS) + case value.hasNumberSetType(): + return len(value.NS) + case value.hasBinarySetType(): + return len(value.BS) + case value.hasListType(): + return len(value.L) + case value.hasMapType(): + return len(value.M) + default: + return 0 + } +} + +func findTopLevelKeywordIndex(expr string, keyword string) int { + upper := strings.ToUpper(expr) + target := strings.ToUpper(keyword) + depth := 0 + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth != 0 || !matchesKeywordTokenAt(upper, target, i) { + continue + } + if !isLogicalKeywordBoundary(expr, i-1) || !isLogicalKeywordBoundary(expr, i+len(target)) { + continue + } + return i + } + return -1 +} + +func trimLeadingKeyword(expr string, keyword string) (string, bool) { + upper := strings.ToUpper(strings.TrimSpace(expr)) + keyword = strings.ToUpper(keyword) + if !strings.HasPrefix(upper, keyword) { + return "", false + } + trimmed := strings.TrimSpace(expr) + if !isLogicalKeywordBoundary(trimmed, len(keyword)) { + return "", false + } + return strings.TrimSpace(trimmed[len(keyword):]), true +} + +func findTopLevelConditionComparator(expr string) (int, string, bool) { + operators := []string{"<>", "<=", ">=", "=", "<", ">"} + depth := 0 + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth != 0 { + continue + } + for _, operator := range operators { + if strings.HasPrefix(expr[i:], operator) { + return i, operator, true + } + } + } + return 0, "", false +} + +func dynamoAttributeType(value attributeValue) string { + kind, count := detectAttributeValueKind(value) + if count != 1 { + return "" + } + return string(kind) +} + +func hasOuterParens(expr string) bool { + if len(expr) < 2 || expr[0] != '(' || expr[len(expr)-1] != ')' { + return false + } + depth := 0 + for i := 0; i < len(expr); i++ { + depth = nextParenDepth(depth, expr[i]) + if depth == 0 && i != len(expr)-1 { + return false + } + if depth < 0 { + return false + } + } + return depth == 0 +} + +func nextParenDepth(depth int, ch byte) int { + switch ch { + case '(': + return depth + 1 + case ')': + return depth - 1 + default: + return depth + } +} + +func matchesKeywordTokenAt(upperExpr string, target string, pos int) bool { + end := pos + len(target) + if end > len(upperExpr) { + return false + } + return upperExpr[pos:end] == target +} + +func isLogicalKeywordBoundary(s string, pos int) bool { + if pos < 0 || pos >= len(s) { + return true + } + ch := s[pos] + // Keep identifier-style characters as token characters so expressions like + // "MY_AND_VAR" or "a-OR-b" are not split at logical keyword substrings. + if isExpressionAttributeNameByte(ch) { + return false + } + return true +} + +func trimmedNonEmpty(s string) (string, bool) { + trimmed := strings.TrimSpace(s) + return trimmed, trimmed != "" +} + +func finalizeKeywordSplit(tailExpr string, parts []string) []string { + if len(parts) == 0 { + return nil + } + tail, ok := trimmedNonEmpty(tailExpr) + if !ok { + return nil + } + return append(parts, tail) +} diff --git a/adapter/dynamodb_locks.go b/adapter/dynamodb_locks.go new file mode 100644 index 00000000..8bbb3698 --- /dev/null +++ b/adapter/dynamodb_locks.go @@ -0,0 +1,123 @@ +package adapter + +import ( + "context" + "hash/fnv" + "sort" + "strings" + + "github.com/bootjp/elastickv/kv" + "github.com/cockroachdb/errors" +) + +func (d *DynamoDBServer) lockItemUpdate(lockKey string) func() { + idx := stripeIndex(lockKey, itemUpdateLockStripeCount) + d.itemUpdateLocks[idx].Lock() + return d.itemUpdateLocks[idx].Unlock +} + +func (d *DynamoDBServer) lockTableOperations(tableNames []string) func() { + if len(tableNames) == 0 { + return func() {} + } + idxs := make([]int, 0, len(tableNames)) + seen := map[int]struct{}{} + for _, tableName := range tableNames { + idx := stripeIndex(tableName, tableLockStripeCount) + if _, ok := seen[idx]; ok { + continue + } + seen[idx] = struct{}{} + idxs = append(idxs, idx) + } + sort.Ints(idxs) + for _, idx := range idxs { + d.tableLocks[idx].Lock() + } + return func() { + for i := len(idxs) - 1; i >= 0; i-- { + d.tableLocks[idxs[i]].Unlock() + } + } +} + +func stripeIndex(key string, stripeCount uint32) int { + if stripeCount == 0 { + return 0 + } + h := fnv.New32a() + _, _ = h.Write([]byte(key)) + return int(h.Sum32() % stripeCount) +} + +func dynamoItemUpdateLockKey(tableName string, key map[string]attributeValue) (string, error) { + parts := make([]string, 0, len(key)) + for name := range key { + parts = append(parts, name) + } + sort.Strings(parts) + var b strings.Builder + b.WriteString(tableName) + b.WriteByte('|') + for _, name := range parts { + val, err := attributeValueAsKey(key[name]) + if err != nil { + return "", errors.WithStack(err) + } + b.WriteString(name) + b.WriteByte('=') + b.WriteString(val) + b.WriteByte('|') + } + return b.String(), nil +} + +// globalLastCommitTSProvider is the interface satisfied by stores (e.g. +// LeaderRoutedStore) that can proxy the leader's LastCommitTS. Defined as a +// local interface so DynamoDBServer avoids a hard dependency on the concrete +// kv type. +type globalLastCommitTSProvider interface { + GlobalLastCommitTS(ctx context.Context) uint64 +} + +func (d *DynamoDBServer) nextTxnReadTS() uint64 { + // On a follower the local store.LastCommitTS() may lag behind the leader. + // Use GlobalLastCommitTS so ConsistentRead snapshots and transaction + // start timestamps are aligned with the leader's committed watermark, + // preventing stale pre-reads that cause false Jepsen anomalies and + // unnecessary WriteConflict retries on every follower request. + maxTS := uint64(0) + if p, ok := d.store.(globalLastCommitTSProvider); ok { + maxTS = p.GlobalLastCommitTS(context.Background()) + } else if d.store != nil { + maxTS = d.store.LastCommitTS() + } + + // Advance the HLC so subsequent commitTS calls produce values > maxTS, + // but return maxTS directly as the snapshot — NOT clock.Next(). + // + // clock.Next() can be ahead of store.LastCommitTS() because concurrent + // dispatchTxn calls advance the HLC before their Raft entry is applied. + // If readTS = clock.Next() = T and a concurrent write obtained + // commitTS = T-1 (still in the Raft pipeline), the version at T-1 is + // not yet in Pebble. Reads would see stale data and the FSM conflict + // check (latestTS > startTS: T-1 > T → false) would silently pass, + // allowing corrupted writes. Returning maxTS closes this gap: every + // version at ≤ maxTS is guaranteed visible, and any concurrent write at + // > maxTS triggers a WriteConflict and a retry. + clock := d.coordinator.Clock() + if clock != nil && maxTS > 0 { + clock.Observe(maxTS) + } + if maxTS == 0 { + return 1 + } + return maxTS +} + +func (d *DynamoDBServer) pinReadTS(ts uint64) *kv.ActiveTimestampToken { + if d == nil || d.readTracker == nil { + return &kv.ActiveTimestampToken{} + } + return d.readTracker.Pin(ts) +} diff --git a/adapter/dynamodb_migration.go b/adapter/dynamodb_migration.go new file mode 100644 index 00000000..6fcaff1d --- /dev/null +++ b/adapter/dynamodb_migration.go @@ -0,0 +1,308 @@ +package adapter + +import ( + "bytes" + "context" + "net/http" + "time" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" +) + +func (d *DynamoDBServer) ensureLegacyTableMigration(ctx context.Context, tableName string) error { + unlock := d.lockTableOperations([]string{tableName}) + defer unlock() + return d.ensureLegacyTableMigrationLocked(ctx, tableName) +} + +func (d *DynamoDBServer) ensureLegacyTableMigrationLocked(ctx context.Context, tableName string) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := d.nextTxnReadTS() + schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) + if err != nil { + return errors.WithStack(err) + } + if !exists || !schema.needsLegacyKeyMigration() { + return nil + } + // Admin read-only callers (AdminScanTable) must not trigger + // migration writes. Their own pre-check at the admin readTS + // already rejects needs-migration tables, but the schema can + // transition between that check and this one (Codex r8 P2 on + // PR #805) — refuse rather than racing into write-path code. + if isAdminReadOnlyContext(ctx) { + return errors.Wrap(ErrAdminDynamoValidation, + "table requires a one-time legacy-key migration before admin read endpoints are available; migrate via the SigV4 surface first") + } + if !schema.usesOrderedKeyEncoding() { + err = d.startLegacyTableKeyMigration(ctx, schema, readTS) + } else { + err = d.migrateLegacyTableGeneration(ctx, schema) + } + if err == nil { + continue + } + if !isRetryableTransactWriteError(err) { + return err + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "legacy table migration retry attempts exhausted") +} + +func (d *DynamoDBServer) startLegacyTableKeyMigration( + ctx context.Context, + schema *dynamoTableSchema, + readTS uint64, +) error { + if schema == nil || schema.usesOrderedKeyEncoding() { + return nil + } + nextGeneration, err := d.nextTableGenerationAt(ctx, schema.TableName, readTS) + if err != nil { + return err + } + req, err := makeCreateTableRequest(&dynamoTableSchema{ + TableName: schema.TableName, + AttributeDefinitions: schema.AttributeDefinitions, + PrimaryKey: schema.PrimaryKey, + GlobalSecondaryIndexes: schema.GlobalSecondaryIndexes, + KeyEncodingVersion: dynamoOrderedKeyEncodingV2, + MigratingFromGeneration: schema.Generation, + }, nextGeneration) + if err != nil { + return err + } + req.StartTS = readTS + if _, err := d.coordinator.Dispatch(ctx, req); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (d *DynamoDBServer) migrateLegacyTableGeneration(ctx context.Context, schema *dynamoTableSchema) error { + sourceSchema := schema.migrationSourceSchema() + if sourceSchema == nil { + return nil + } + sourceReadTS := snapshotTS(d.coordinator.Clock(), d.store) + if err := d.migrateLegacySourceItems(ctx, schema, sourceSchema, sourceReadTS); err != nil { + return err + } + empty, err := d.isTableGenerationEmpty(ctx, schema.TableName, sourceSchema.Generation) + if err != nil { + return err + } + if !empty { + return nil + } + return d.finalizeLegacyTableMigration(ctx, schema) +} + +func (d *DynamoDBServer) migrateLegacySourceItems( + ctx context.Context, + targetSchema *dynamoTableSchema, + sourceSchema *dynamoTableSchema, + readTS uint64, +) error { + readPin := d.pinReadTS(readTS) + defer readPin.Release() + + prefix := dynamoItemPrefixForTable(targetSchema.TableName, sourceSchema.Generation) + upper := prefixScanEnd(prefix) + cursor := bytes.Clone(prefix) + for { + kvs, err := d.scanLegacyMigrationPage(ctx, cursor, upper, readTS) + if err != nil { + return err + } + nextCursor, done, err := d.migrateLegacySourcePage(ctx, targetSchema, sourceSchema, prefix, upper, kvs) + if err != nil { + return err + } + if done { + return nil + } + cursor = nextCursor + } +} + +func (d *DynamoDBServer) migrateLegacyItem( + ctx context.Context, + targetSchema *dynamoTableSchema, + sourceSchema *dynamoTableSchema, + sourceKey []byte, + sourceItem map[string]attributeValue, +) error { + lockKey, targetKey, err := resolveLegacyMigrationTarget(targetSchema, sourceItem) + if err != nil { + return err + } + unlock := d.lockItemUpdate(lockKey) + defer unlock() + + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := d.nextTxnReadTS() + req, done, err := d.buildLegacyMigrationRequest(ctx, targetSchema, sourceSchema, targetKey, sourceKey, readTS) + if err != nil { + return err + } + if done { + return nil + } + req.StartTS = readTS + if _, err := d.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "legacy item migration retry attempts exhausted") +} + +func (d *DynamoDBServer) scanLegacyMigrationPage( + ctx context.Context, + cursor []byte, + upper []byte, + readTS uint64, +) ([]*store.KVPair, error) { + kvs, err := d.store.ScanAt(ctx, cursor, upper, dynamoScanPageLimit, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + return kvs, nil +} + +func (d *DynamoDBServer) migrateLegacySourcePage( + ctx context.Context, + targetSchema *dynamoTableSchema, + sourceSchema *dynamoTableSchema, + prefix []byte, + upper []byte, + kvs []*store.KVPair, +) ([]byte, bool, error) { + if len(kvs) == 0 { + return nil, true, nil + } + for _, kvp := range kvs { + if !bytes.HasPrefix(kvp.Key, prefix) { + return nil, true, nil + } + item, err := decodeStoredDynamoItem(kvp.Value) + if err != nil { + return nil, false, err + } + if err := d.migrateLegacyItem(ctx, targetSchema, sourceSchema, kvp.Key, item); err != nil { + return nil, false, err + } + } + if len(kvs) < dynamoScanPageLimit { + return nil, true, nil + } + cursor := nextScanCursor(kvs[len(kvs)-1].Key) + if upper != nil && bytes.Compare(cursor, upper) >= 0 { + return nil, true, nil + } + return cursor, false, nil +} + +func resolveLegacyMigrationTarget(targetSchema *dynamoTableSchema, sourceItem map[string]attributeValue) (string, []byte, error) { + keyAttrs, err := primaryKeyAttributes(targetSchema.PrimaryKey, sourceItem) + if err != nil { + return "", nil, err + } + lockKey, err := dynamoItemUpdateLockKey(targetSchema.TableName, keyAttrs) + if err != nil { + return "", nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + targetKey, err := targetSchema.itemKeyFromAttributes(keyAttrs) + if err != nil { + return "", nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + return lockKey, targetKey, nil +} + +func (d *DynamoDBServer) buildLegacyMigrationRequest( + ctx context.Context, + targetSchema *dynamoTableSchema, + sourceSchema *dynamoTableSchema, + targetKey []byte, + sourceKey []byte, + readTS uint64, +) (*kv.OperationGroup[kv.OP], bool, error) { + _, targetFound, err := d.readItemAtKeyAt(ctx, targetKey, readTS) + if err != nil { + return nil, false, err + } + currentSource, sourceFound, err := d.readItemAtKeyAt(ctx, sourceKey, readTS) + if err != nil { + return nil, false, err + } + if !sourceFound { + return nil, true, nil + } + currentLocation := &dynamoItemLocation{ + schema: sourceSchema, + key: sourceKey, + item: currentSource, + } + if targetFound { + req, err := buildItemDeleteRequestWithSource(currentLocation) + return req, false, err + } + req, _, err := buildItemWriteRequestWithSource(targetSchema, targetKey, currentSource, currentLocation) + return req, false, err +} + +func (d *DynamoDBServer) isTableGenerationEmpty(ctx context.Context, tableName string, generation uint64) (bool, error) { + prefix := dynamoItemPrefixForTable(tableName, generation) + kvs, err := d.store.ScanAt(ctx, prefix, prefixScanEnd(prefix), 1, snapshotTS(d.coordinator.Clock(), d.store)) + if err != nil { + return false, errors.WithStack(err) + } + for _, kvp := range kvs { + if bytes.HasPrefix(kvp.Key, prefix) { + return false, nil + } + } + return true, nil +} + +func (d *DynamoDBServer) finalizeLegacyTableMigration(ctx context.Context, schema *dynamoTableSchema) error { + if schema == nil || schema.MigratingFromGeneration == 0 { + return nil + } + oldGeneration := schema.MigratingFromGeneration + finalized := *schema + finalized.MigratingFromGeneration = 0 + body, err := encodeStoredDynamoTableSchema(&finalized) + if err != nil { + return errors.WithStack(err) + } + readTS := d.nextTxnReadTS() + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: readTS, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Put, Key: dynamoTableMetaKey(schema.TableName), Value: body}, + }, + } + if _, err := d.coordinator.Dispatch(ctx, req); err != nil { + return errors.WithStack(err) + } + d.launchDeletedTableCleanup(schema.TableName, oldGeneration) + return nil +} From cc09b2d6d1abe301ee890fbcca2adf8a47440f56 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 12 Jun 2026 15:44:32 +0900 Subject: [PATCH 4/4] adapter: reduce dynamodb.go to server core after split (no behavior change) dynamodb.go now holds only the server core/registration: DynamoDBServer struct + options + constructor, Run/Stop, HTTP handler/dispatch/health, body-size readers, and the request-metrics recorder. Imports trimmed to what the core uses. No declarations changed. --- adapter/dynamodb.go | 8788 ------------------------------------------- 1 file changed, 8788 deletions(-) diff --git a/adapter/dynamodb.go b/adapter/dynamodb.go index 47794859..b0ca73c4 100644 --- a/adapter/dynamodb.go +++ b/adapter/dynamodb.go @@ -1,21 +1,13 @@ package adapter import ( - "bytes" "context" - "encoding/base64" - "encoding/binary" - "hash/fnv" "io" - "log/slog" "maps" - "math/big" "net" "net/http" "os" - "slices" "sort" - "strconv" "strings" "sync" "time" @@ -24,7 +16,6 @@ import ( "github.com/bootjp/elastickv/monitoring" "github.com/bootjp/elastickv/store" "github.com/cockroachdb/errors" - json "github.com/goccy/go-json" ) const ( @@ -547,8782 +538,3 @@ func (d *DynamoDBServer) observeWrittenItems(ctx context.Context, table string, } state.addTableMetrics(table, 0, 0, writtenItems) } - -type createTableAttributeDefinition struct { - AttributeName string `json:"AttributeName"` - AttributeType string `json:"AttributeType"` -} - -type createTableKeySchemaElement struct { - AttributeName string `json:"AttributeName"` - KeyType string `json:"KeyType"` -} - -type createTableGSI struct { - IndexName string `json:"IndexName"` - KeySchema []createTableKeySchemaElement `json:"KeySchema"` - Projection createTableProjection `json:"Projection"` -} - -type createTableProjection struct { - ProjectionType string `json:"ProjectionType"` - NonKeyAttributes []string `json:"NonKeyAttributes"` -} - -type createTableInput struct { - TableName string `json:"TableName"` - AttributeDefinitions []createTableAttributeDefinition `json:"AttributeDefinitions"` - KeySchema []createTableKeySchemaElement `json:"KeySchema"` - GlobalSecondaryIndexes []createTableGSI `json:"GlobalSecondaryIndexes"` -} - -type deleteTableInput struct { - TableName string `json:"TableName"` -} - -type describeTableInput struct { - TableName string `json:"TableName"` -} - -type listTablesInput struct { - ExclusiveStartTableName string `json:"ExclusiveStartTableName"` - Limit int32 `json:"Limit"` -} - -type queryInput struct { - TableName string `json:"TableName"` - IndexName string `json:"IndexName"` - KeyConditionExpression string `json:"KeyConditionExpression"` - FilterExpression string `json:"FilterExpression"` - ProjectionExpression string `json:"ProjectionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` - ScanIndexForward *bool `json:"ScanIndexForward"` - Limit *int32 `json:"Limit"` - ExclusiveStartKey map[string]attributeValue `json:"ExclusiveStartKey"` - Select string `json:"Select"` - ConsistentRead *bool `json:"ConsistentRead"` -} - -type getItemInput struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - ProjectionExpression string `json:"ProjectionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ConsistentRead *bool `json:"ConsistentRead"` -} - -type scanInput struct { - TableName string `json:"TableName"` - IndexName string `json:"IndexName"` - FilterExpression string `json:"FilterExpression"` - ProjectionExpression string `json:"ProjectionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` - ExclusiveStartKey map[string]attributeValue `json:"ExclusiveStartKey"` - Limit *int32 `json:"Limit"` - Select string `json:"Select"` - ConsistentRead *bool `json:"ConsistentRead"` -} - -type updateItemInput struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - UpdateExpression string `json:"UpdateExpression"` - ConditionExpression string `json:"ConditionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` - ReturnValues string `json:"ReturnValues"` -} - -type putItemInput struct { - TableName string `json:"TableName"` - Item map[string]attributeValue `json:"Item"` - ConditionExpression string `json:"ConditionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` - ReturnValues string `json:"ReturnValues"` -} - -type deleteItemInput struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - ConditionExpression string `json:"ConditionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` - ReturnValues string `json:"ReturnValues"` -} - -type batchWriteItemInput struct { - RequestItems map[string][]batchWriteRequest `json:"RequestItems"` -} - -type batchWriteRequest struct { - PutRequest *batchPutRequest `json:"PutRequest,omitempty"` - DeleteRequest *batchDeleteRequest `json:"DeleteRequest,omitempty"` -} - -type batchPutRequest struct { - Item map[string]attributeValue `json:"Item"` -} - -type batchDeleteRequest struct { - Key map[string]attributeValue `json:"Key"` -} - -type transactWriteItemsInput struct { - TransactItems []transactWriteItem `json:"TransactItems"` -} - -type transactGetItemsInput struct { - TransactItems []transactGetItem `json:"TransactItems"` -} - -type transactGetItem struct { - Get *transactGetItemGet `json:"Get"` -} - -type transactGetItemGet struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - ProjectionExpression string `json:"ProjectionExpression,omitempty"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames,omitempty"` -} - -type transactWriteItem struct { - Put *putItemInput `json:"Put,omitempty"` - Update *transactUpdateInput `json:"Update,omitempty"` - Delete *transactDeleteInput `json:"Delete,omitempty"` - ConditionCheck *transactConditionInput `json:"ConditionCheck,omitempty"` -} - -type transactUpdateInput struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - UpdateExpression string `json:"UpdateExpression"` - ConditionExpression string `json:"ConditionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` -} - -type transactDeleteInput struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - ConditionExpression string `json:"ConditionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` -} - -type transactConditionInput struct { - TableName string `json:"TableName"` - Key map[string]attributeValue `json:"Key"` - ConditionExpression string `json:"ConditionExpression"` - ExpressionAttributeNames map[string]string `json:"ExpressionAttributeNames"` - ExpressionAttributeValues map[string]attributeValue `json:"ExpressionAttributeValues"` -} - -type dynamoKeySchema struct { - HashKey string `json:"hash_key"` - RangeKey string `json:"range_key,omitempty"` -} - -type dynamoGSIProjection struct { - ProjectionType string `json:"projection_type"` - NonKeyAttributes []string `json:"non_key_attributes,omitempty"` -} - -type dynamoGlobalSecondaryIndex struct { - KeySchema dynamoKeySchema `json:"key_schema"` - Projection dynamoGSIProjection `json:"projection"` -} - -func (g *dynamoGlobalSecondaryIndex) UnmarshalJSON(b []byte) error { - type rawGSI struct { - KeySchema *dynamoKeySchema `json:"key_schema"` - Projection *dynamoGSIProjection `json:"projection"` - HashKey string `json:"hash_key"` - RangeKey string `json:"range_key"` - } - - var raw rawGSI - if err := json.Unmarshal(b, &raw); err != nil { - return errors.WithStack(err) - } - - if raw.KeySchema != nil { - g.KeySchema = *raw.KeySchema - } else { - g.KeySchema = dynamoKeySchema{ - HashKey: raw.HashKey, - RangeKey: raw.RangeKey, - } - } - - if raw.Projection != nil && strings.TrimSpace(raw.Projection.ProjectionType) != "" { - g.Projection = *raw.Projection - } else { - // Older schema snapshots stored only the key schema. Those GSIs behaved - // like ALL projections, so preserve that behavior when normalizing. - g.Projection = dynamoGSIProjection{ProjectionType: "ALL"} - } - - return nil -} - -type dynamoTableSchema struct { - TableName string `json:"table_name"` - AttributeDefinitions map[string]string `json:"attribute_definitions,omitempty"` - PrimaryKey dynamoKeySchema `json:"primary_key"` - GlobalSecondaryIndexes map[string]dynamoGlobalSecondaryIndex `json:"global_secondary_indexes,omitempty"` - KeyEncodingVersion int `json:"key_encoding_version,omitempty"` - MigratingFromGeneration uint64 `json:"migrating_from_generation,omitempty"` - Generation uint64 `json:"generation"` -} - -func (d *DynamoDBServer) createTable(w http.ResponseWriter, r *http.Request) { - in, err := decodeCreateTableInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - unlock := d.lockTableOperations([]string{in.TableName}) - defer unlock() - schema, err := buildCreateTableSchema(in) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return - } - if err := d.createTableWithRetry(r.Context(), in.TableName, schema); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - d.observeTables(r.Context(), schema.TableName) - writeDynamoJSON(w, map[string]any{ - "TableDescription": map[string]any{ - "TableName": in.TableName, - "TableStatus": "ACTIVE", - }, - }) -} - -func decodeCreateTableInput(bodyReader io.Reader) (createTableInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return createTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in createTableInput - if err := json.Unmarshal(body, &in); err != nil { - return createTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return createTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - return in, nil -} - -func buildCreateTableSchema(in createTableInput) (*dynamoTableSchema, error) { - primary, err := parseCreateTableKeySchema(in.KeySchema) - if err != nil { - return nil, err - } - attrDefs := make(map[string]string, len(in.AttributeDefinitions)) - for _, def := range in.AttributeDefinitions { - if strings.TrimSpace(def.AttributeName) == "" { - return nil, errors.New("invalid attribute definition") - } - attrDefs[def.AttributeName] = def.AttributeType - } - gsis := make(map[string]dynamoGlobalSecondaryIndex, len(in.GlobalSecondaryIndexes)) - for _, gsi := range in.GlobalSecondaryIndexes { - if strings.TrimSpace(gsi.IndexName) == "" { - return nil, errors.New("invalid global secondary index") - } - ks, err := parseCreateTableKeySchema(gsi.KeySchema) - if err != nil { - return nil, err - } - projection, err := buildCreateTableProjection(gsi.Projection) - if err != nil { - return nil, err - } - gsis[gsi.IndexName] = dynamoGlobalSecondaryIndex{ - KeySchema: ks, - Projection: projection, - } - } - return &dynamoTableSchema{ - TableName: in.TableName, - AttributeDefinitions: attrDefs, - PrimaryKey: primary, - GlobalSecondaryIndexes: gsis, - KeyEncodingVersion: dynamoOrderedKeyEncodingV2, - }, nil -} - -func buildCreateTableProjection(in createTableProjection) (dynamoGSIProjection, error) { - switch strings.TrimSpace(in.ProjectionType) { - case "", "ALL": - return dynamoGSIProjection{ProjectionType: "ALL"}, nil - case "KEYS_ONLY": - return dynamoGSIProjection{ProjectionType: "KEYS_ONLY"}, nil - case "INCLUDE": - return dynamoGSIProjection{ - ProjectionType: "INCLUDE", - NonKeyAttributes: append([]string(nil), in.NonKeyAttributes...), - }, nil - default: - return dynamoGSIProjection{}, errors.New("invalid projection") - } -} - -func (d *DynamoDBServer) createTableWithRetry(ctx context.Context, tableName string, baseSchema *dynamoTableSchema) error { - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - for range transactRetryMaxAttempts { - readTS := d.nextTxnReadTS() - exists, err := d.tableExistsAt(ctx, tableName, readTS) - if err != nil { - return err - } - if exists { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceInUse, "table already exists") - } - nextGeneration, err := d.nextTableGenerationAt(ctx, tableName, readTS) - if err != nil { - return err - } - req, err := makeCreateTableRequest(baseSchema, nextGeneration) - if err != nil { - return err - } - if _, err := d.coordinator.Dispatch(ctx, req); err == nil { - return nil - } - if !isRetryableTransactWriteError(err) { - return errors.WithStack(err) - } - if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { - return errors.WithStack(err) - } - backoff = nextTransactRetryBackoff(backoff) - } - return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "create table retry attempts exhausted") -} - -func (d *DynamoDBServer) tableExistsAt(ctx context.Context, tableName string, readTS uint64) (bool, error) { - _, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) - if err != nil { - return false, errors.WithStack(err) - } - return exists, nil -} - -func (d *DynamoDBServer) nextTableGenerationAt(ctx context.Context, tableName string, readTS uint64) (uint64, error) { - lastGeneration, err := d.loadTableGenerationAt(ctx, tableName, readTS) - if err != nil { - return 0, errors.WithStack(err) - } - return lastGeneration + 1, nil -} - -func makeCreateTableRequest(baseSchema *dynamoTableSchema, nextGeneration uint64) (*kv.OperationGroup[kv.OP], error) { - schema := &dynamoTableSchema{ - TableName: baseSchema.TableName, - AttributeDefinitions: baseSchema.AttributeDefinitions, - PrimaryKey: baseSchema.PrimaryKey, - GlobalSecondaryIndexes: baseSchema.GlobalSecondaryIndexes, - KeyEncodingVersion: baseSchema.KeyEncodingVersion, - MigratingFromGeneration: baseSchema.MigratingFromGeneration, - Generation: nextGeneration, - } - schemaBytes, err := encodeStoredDynamoTableSchema(schema) - if err != nil { - return nil, errors.WithStack(err) - } - return &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Put, Key: dynamoTableMetaKey(baseSchema.TableName), Value: schemaBytes}, - {Op: kv.Put, Key: dynamoTableGenerationKey(baseSchema.TableName), Value: []byte(strconv.FormatUint(nextGeneration, 10))}, - }, - }, nil -} - -func (d *DynamoDBServer) deleteTable(w http.ResponseWriter, r *http.Request) { - in, err := decodeDeleteTableInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - unlock := d.lockTableOperations([]string{in.TableName}) - defer unlock() - if err := d.deleteTableWithRetry(r.Context(), in.TableName); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - resp := map[string]any{ - "TableDescription": map[string]any{ - "TableName": in.TableName, - "TableStatus": "DELETING", - }, - } - writeDynamoJSON(w, resp) -} - -func decodeDeleteTableInput(bodyReader io.Reader) (deleteTableInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return deleteTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in deleteTableInput - if err := json.Unmarshal(body, &in); err != nil { - return deleteTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return deleteTableInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - return in, nil -} - -func (d *DynamoDBServer) deleteTableWithRetry(ctx context.Context, tableName string) error { - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - for range transactRetryMaxAttempts { - readTS := d.nextTxnReadTS() - schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) - if err != nil { - return errors.WithStack(err) - } - if !exists { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - - req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Del, Key: dynamoTableMetaKey(tableName)}, - }, - } - if _, err := d.coordinator.Dispatch(ctx, req); err != nil { - if !isRetryableTransactWriteError(err) { - return errors.WithStack(err) - } - } else { - d.launchDeletedTableCleanup(tableName, schema.Generation) - return nil - } - - if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { - return errors.WithStack(err) - } - backoff = nextTransactRetryBackoff(backoff) - } - return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "delete table retry attempts exhausted") -} - -func (d *DynamoDBServer) launchDeletedTableCleanup(tableName string, generation uint64) { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), tableCleanupAsyncTimeout) - defer cancel() - if err := d.cleanupDeletedTableGeneration(ctx, tableName, generation); err != nil { - slog.Error("dynamodb delete table cleanup failed", - "table", tableName, - "generation", generation, - "error", err, - ) - } - }() -} - -func (d *DynamoDBServer) cleanupDeletedTableGeneration(ctx context.Context, tableName string, generation uint64) error { - prefixes := [][]byte{ - dynamoItemPrefixForTable(tableName, generation), - dynamoGSIPrefixForTable(tableName, generation), - } - // Dispatch a single DEL_PREFIX operation per prefix. The FSM on each node - // scans and writes tombstones locally, avoiding the enumerate-then-batch- - // delete loop that previously required many Raft proposals. - for _, prefix := range prefixes { - _, err := d.coordinator.Dispatch(ctx, &kv.OperationGroup[kv.OP]{ - Elems: []*kv.Elem[kv.OP]{ - {Op: kv.DelPrefix, Key: prefix}, - }, - }) - if err != nil { - return errors.WithStack(err) - } - } - return nil -} - -func (d *DynamoDBServer) dispatchDeleteBatch(ctx context.Context, keys [][]byte) error { - elems := make([]*kv.Elem[kv.OP], 0, len(keys)) - for _, key := range keys { - elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) - } - req := &kv.OperationGroup[kv.OP]{ - IsTxn: false, - Elems: elems, - } - _, err := d.coordinator.Dispatch(ctx, req) - if err != nil { - return errors.WithStack(err) - } - return nil -} - -func (d *DynamoDBServer) describeTable(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return - } - var in describeTableInput - if err := json.Unmarshal(body, &in); err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return - } - if strings.TrimSpace(in.TableName) == "" { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, "missing table name") - return - } - if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - schema, exists, err := d.loadTableSchema(r.Context(), in.TableName) - if err != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) - return - } - if !exists { - writeDynamoError(w, http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - return - } - writeDynamoJSON(w, map[string]any{"Table": describeTableShape(schema)}) -} - -func (d *DynamoDBServer) listTables(w http.ResponseWriter, r *http.Request) { - in, err := decodeListTablesInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - names, err := d.listTableNames(r.Context()) - if err != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) - return - } - outNames, hasNext := paginateTableNames(names, in) - - resp := map[string]any{"TableNames": outNames} - if hasNext && len(outNames) > 0 { - resp["LastEvaluatedTableName"] = outNames[len(outNames)-1] - } - writeDynamoJSON(w, resp) -} - -func decodeListTablesInput(bodyReader io.Reader) (listTablesInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return listTablesInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in listTablesInput - if len(bytes.TrimSpace(body)) == 0 { - return in, nil - } - if err := json.Unmarshal(body, &in); err != nil { - return listTablesInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return in, nil -} - -func paginateTableNames(names []string, in listTablesInput) ([]string, bool) { - start := findExclusiveStartIndex(names, in.ExclusiveStartTableName) - limit := resolveTableListLimit(in.Limit, len(names)) - end := min(start+limit, len(names)) - return names[start:end], end < len(names) -} - -func findExclusiveStartIndex(names []string, startName string) int { - if startName == "" { - return 0 - } - for i, name := range names { - if name == startName { - return i + 1 - } - } - return 0 -} - -func resolveTableListLimit(limit int32, tableCount int) int { - if limit <= 0 || int(limit) >= tableCount { - return tableCount - } - return int(limit) -} - -func (d *DynamoDBServer) listTableNames(ctx context.Context) ([]string, error) { - kvs, err := d.scanAllByPrefix(ctx, []byte(dynamoTableMetaPrefix)) - if err != nil { - return nil, err - } - names := make([]string, 0, len(kvs)) - for _, kvp := range kvs { - name, ok := tableNameFromMetaKey(kvp.Key) - if !ok { - continue - } - names = append(names, name) - } - sort.Strings(names) - return names, nil -} - -func (d *DynamoDBServer) putItem(w http.ResponseWriter, r *http.Request) { - in, err := decodePutItemInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - plan, err := d.putItemWithRetry(r.Context(), in) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - d.observeWrittenItems(r.Context(), in.TableName, 1) - resp := map[string]any{} - if attrs := putItemReturnAttributes(in.ReturnValues, plan.current); len(attrs) > 0 { - resp["Attributes"] = attrs - } - writeDynamoJSON(w, resp) -} - -func decodePutItemInput(bodyReader io.Reader) (putItemInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return putItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in putItemInput - if err := json.Unmarshal(body, &in); err != nil { - return putItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return putItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - if err := validatePutItemReturnValues(in.ReturnValues); err != nil { - return putItemInput{}, err - } - return in, nil -} - -func validatePutItemReturnValues(returnValues string) error { - switch strings.TrimSpace(returnValues) { - case "", dynamoReturnValueNone, dynamoReturnValueAllOld: - return nil - default: - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported ReturnValues") - } -} - -func (d *DynamoDBServer) putItemWithRetry(ctx context.Context, in putItemInput) (*itemWritePlan, error) { - return d.retryItemWriteWithGeneration( - ctx, - in.TableName, - "put item retry attempts exhausted", - func(readTS uint64) (*itemWritePlan, error) { - return d.preparePutItemWrite(ctx, in, readTS) - }, - ) -} - -type itemWritePlan struct { - req *kv.OperationGroup[kv.OP] - generation uint64 - cleanup [][]byte - current map[string]attributeValue - next map[string]attributeValue -} - -func (d *DynamoDBServer) retryItemWriteWithGeneration( - ctx context.Context, - tableName string, - exhaustedMessage string, - prepare func(readTS uint64) (*itemWritePlan, error), -) (*itemWritePlan, error) { - // Option-2 one-phase dedup (gated, default off): on a retryable write error, - // reuse the failed attempt's write set under a fresh commit_ts + prev_commit_ts - // so the FSM no-ops a commit that already landed under leadership churn, - // instead of re-reading and re-appending (the :duplicate-elements anomaly). - // See docs/design/2026_06_03_partial_dynamodb_onephase_dedup.md. - // - // Leader-only (codex P1, PR #920): the dedup path allocates commit_ts from - // the LOCAL HLC and carries it as prev_commit_ts, so that timestamp MUST be - // leader-issued to stay globally unique — otherwise two frontends could mint - // the same commit_ts in one millisecond and the exact-ts probe would dedup - // against the wrong writer's version, losing an update. On the leader the - // single HLC issues monotonic unique values, and NextFenced's physical-ceiling - // fence keeps a deposed leader's window disjoint from its successor's. A - // non-leader (reachable only when no leaderMap HTTP proxy forwards follower - // ingress) falls back to the legacy path, where Coordinator.Dispatch redirects - // to the leader and the LEADER allocates commit_ts — never this follower's HLC. - if d.onePhaseTxnDedup && d.coordinator.IsLeader() { - return d.retryItemWriteWithGenerationDedup(ctx, tableName, exhaustedMessage, prepare) - } - return d.retryItemWriteWithGenerationLegacy(ctx, tableName, exhaustedMessage, prepare) -} - -// retryItemWriteWithGenerationLegacy is the pre-dedup retry loop: it recomputes -// the write set from a fresh read on every retryable error. It is the active -// path whenever the dedup gate is off or this node is not the leader, so it -// stays byte-identical to the pre-feature behavior. -func (d *DynamoDBServer) retryItemWriteWithGenerationLegacy( - ctx context.Context, - tableName string, - exhaustedMessage string, - prepare func(readTS uint64) (*itemWritePlan, error), -) (*itemWritePlan, error) { - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - for range transactRetryMaxAttempts { - readTS := d.nextTxnReadTS() - plan, err := prepare(readTS) - if err != nil { - return nil, err - } - if plan.req == nil { - return plan, nil - } - plan.req.StartTS = readTS - if err = d.commitItemWrite(ctx, plan.req); err != nil { - if !isRetryableTransactWriteError(err) { - return nil, errors.WithStack(err) - } - } else { - retry, verifyErr := d.handleGenerationFenceResult( - ctx, - d.verifyTableGeneration(ctx, tableName, plan.generation), - plan.cleanup, - ) - if verifyErr != nil { - return nil, verifyErr - } - if !retry { - return plan, nil - } - } - if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { - return nil, errors.WithStack(err) - } - backoff = nextTransactRetryBackoff(backoff) - } - return nil, newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, exhaustedMessage) -} - -// reusableItemWrite captures a dispatched single-item write attempt so a -// subsequent retry can REUSE its exact write set (the same Put/Del elems) under -// a fresh commit_ts and probe whether it already landed, instead of re-reading -// and recomputing the item. Recomputing is what duplicates a list_append under -// leadership churn: attempt 1 commits at C1 but returns a WriteConflict, the -// retry re-reads the now-larger list and appends again. Reuse + the FSM's -// exact-ts dedup probe close that. See option 2 in -// docs/design/2026_06_03_partial_dynamodb_onephase_dedup.md. -type reusableItemWrite struct { - // plan holds the reused OperationGroup (plan.req: Elems + fixed StartTS) and - // the captured current/next item. The client-visible result - // (updateItemReturnAttributes over current/next) is invariant across reuse - // — the write set was built once from attempt 1's read — so plan is also the - // correct value to return when the FSM dedup no-ops the apply (R1). - plan *itemWritePlan - // commitTS is the most recent dispatched commit_ts for this write set; the - // next retry passes it as PrevCommitTS so the FSM probes exactly the attempt - // that might have landed. - commitTS uint64 - // probeKey is kv.PrimaryKeyForElems(plan.req.Elems) — the same key the FSM - // uses as meta.PrimaryKey — so the adapter-side self-inflicted-conflict guard - // and the FSM dedup probe agree on the point they query (R4). - probeKey []byte -} - -// retryItemWriteWithGenerationDedup is the option-2 retry loop. The first -// attempt computes the write set from a fresh read; any retryable failure makes -// the next iteration REUSE that write set under a fresh commit_ts carrying -// prev_commit_ts, so the FSM no-ops if the prior attempt already landed. A -// genuine WriteConflict on a reuse (the self-conflict probe missed) drops the -// pending attempt and recomputes from a fresh read. -func (d *DynamoDBServer) retryItemWriteWithGenerationDedup( - ctx context.Context, - tableName string, - exhaustedMessage string, - prepare func(readTS uint64) (*itemWritePlan, error), -) (*itemWritePlan, error) { - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - var pending *reusableItemWrite - for range transactRetryMaxAttempts { - var ( - plan *itemWritePlan - err error - ) - if pending != nil { - plan, pending, err = d.itemWriteReuseAttempt(ctx, tableName, pending) - } else { - plan, pending, err = d.itemWriteFirstAttempt(ctx, tableName, prepare) - } - if err != nil { - // commitItemWrite already wraps dispatch errors; the attempt helpers - // return them raw, so return raw here too (no double WithStack). - if !isRetryableTransactWriteError(err) { - return nil, err - } - } else if plan != nil { - return plan, nil - } - if waitErr := waitRetryWithDeadline(ctx, deadline, backoff); waitErr != nil { - return nil, errors.WithStack(waitErr) - } - backoff = nextTransactRetryBackoff(backoff) - } - return nil, newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, exhaustedMessage) -} - -// itemWriteFirstAttempt runs the recompute branch of the dedup loop: a fresh -// read snapshot, a locally-allocated commit_ts, and a dispatch. On a retryable -// write error it returns a reusableItemWrite so the next iteration reuses this -// write set. Return shapes match itemWriteReuseAttempt (see -// retryItemWriteWithGenerationDedup). -func (d *DynamoDBServer) itemWriteFirstAttempt( - ctx context.Context, - tableName string, - prepare func(readTS uint64) (*itemWritePlan, error), -) (*itemWritePlan, *reusableItemWrite, error) { - readTS := d.nextTxnReadTS() - plan, err := prepare(readTS) - if err != nil { - return nil, nil, err - } - if plan.req == nil { - return plan, nil, nil - } - // NextFenced (not Next) honors the HLC physical-ceiling fence so a - // stale-leader window cannot mint a colliding commit_ts (HLC-4); - // ErrCeilingExpired is non-retryable and surfaces to the client. - commitTS, err := d.coordinator.Clock().NextFenced() - if err != nil { - return nil, nil, errors.Wrap(err, "dynamodb item-write first attempt: allocate commitTS") - } - plan.req.StartTS = readTS - plan.req.CommitTS = commitTS - if dispErr := d.commitItemWrite(ctx, plan.req); dispErr != nil { - // dispErr is already wrapped by commitItemWrite; return it raw. - if isRetryableTransactWriteError(dispErr) { - return nil, &reusableItemWrite{ - plan: plan, - commitTS: commitTS, - probeKey: kv.PrimaryKeyForElems(plan.req.Elems), - }, dispErr - } - return nil, nil, dispErr - } - return d.finishItemWriteAttempt(ctx, tableName, plan) -} - -// itemWriteReuseAttempt runs one reuse iteration: re-dispatch the captured write -// set under a fresh commit_ts carrying pending.commitTS as PrevCommitTS, so the -// FSM probes whether the prior attempt landed. -func (d *DynamoDBServer) itemWriteReuseAttempt( - ctx context.Context, - tableName string, - pending *reusableItemWrite, -) (*itemWritePlan, *reusableItemWrite, error) { - commitTS, err := d.coordinator.Clock().NextFenced() - if err != nil { - return nil, pending, errors.Wrap(err, "dynamodb item-write reuse: allocate commitTS") - } - pending.plan.req.CommitTS = commitTS - pending.plan.req.PrevCommitTS = pending.commitTS - dispErr := d.commitItemWrite(ctx, pending.plan.req) - if dispErr == nil { - return d.finishItemWriteAttempt(ctx, tableName, pending.plan) - } - if errors.Is(dispErr, store.ErrWriteConflict) { - return d.resolveReuseWriteConflict(ctx, tableName, pending, commitTS, dispErr) - } - if isRetryableTransactWriteError(dispErr) { - // Still ambiguous (e.g. TxnLocked): this reuse may itself have landed, - // so the next retry must probe THIS commit_ts. dispErr is already - // wrapped by commitItemWrite; return it raw. - pending.commitTS = commitTS - return nil, pending, dispErr - } - return nil, nil, dispErr -} - -// resolveReuseWriteConflict handles a WriteConflict from a reuse dispatch via -// the self-inflicted-conflict guard: probe whether THIS reuse's commit_ts -// actually landed (the apply may have committed but surfaced WriteConflict under -// churn). On a hit the conflict is against our own commit — return the cached -// plan, no double-apply. On a miss the write key is genuinely held by another -// txn — drop pending so the next iteration recomputes from a fresh read. -func (d *DynamoDBServer) resolveReuseWriteConflict( - ctx context.Context, - tableName string, - pending *reusableItemWrite, - commitTS uint64, - dispErr error, -) (*itemWritePlan, *reusableItemWrite, error) { - if len(pending.probeKey) > 0 { - landed, perr := d.store.CommittedVersionAt(ctx, pending.probeKey, commitTS) - if perr != nil { - // Fail closed: a probe read error makes "did our reuse land?" - // unknowable, and a blind recompute would double-append if it HAD - // landed. Surface the probe error instead of silently recomputing, - // matching the FSM-side dedupProbeOnePhase (kv/fsm.go) which also - // propagates probe errors. The wrapped error is non-retryable, so - // the loop returns it to the client rather than re-applying. - return nil, nil, errors.Wrap(perr, "dynamodb item-write: self-conflict probe") - } - if landed { - // The reuse landed at commitTS. Run the SAME generation fence + - // cleanup the normal success path runs (finishItemWriteAttempt), so - // a table dropped/recreated under us cleans up the landed write and - // recomputes instead of returning a stale plan (coderabbit major). - return d.finishItemWriteAttempt(ctx, tableName, pending.plan) - } - } - // Probe missed (or no probe key): a genuine cross-writer conflict. dispErr - // is already wrapped by commitItemWrite; return it raw so the loop recomputes. - return nil, nil, dispErr -} - -// finishItemWriteAttempt runs the table-generation fence after a successful -// commit. Returns (plan, nil, nil) when the write is durable; (nil, nil, nil) -// when the generation changed and the caller must recompute from a fresh read; -// (nil, nil, err) on a fence error. -func (d *DynamoDBServer) finishItemWriteAttempt( - ctx context.Context, - tableName string, - plan *itemWritePlan, -) (*itemWritePlan, *reusableItemWrite, error) { - retry, verifyErr := d.handleGenerationFenceResult( - ctx, - d.verifyTableGeneration(ctx, tableName, plan.generation), - plan.cleanup, - ) - if verifyErr != nil { - return nil, nil, verifyErr - } - if retry { - return nil, nil, nil - } - return plan, nil, nil -} - -func (d *DynamoDBServer) preparePutItemWrite(ctx context.Context, in putItemInput, readTS uint64) (*itemWritePlan, error) { - schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - if !exists { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - itemKey, err := schema.itemKeyFromAttributes(in.Item) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - keyAttrs, err := primaryKeyAttributes(schema.PrimaryKey, in.Item) - if err != nil { - return nil, err - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, keyAttrs, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var current map[string]attributeValue - if found { - current = currentLocation.item - } - if err := validateConditionOnItem( - in.ConditionExpression, - in.ExpressionAttributeNames, - in.ExpressionAttributeValues, - valueOrEmptyMap(current, found), - ); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) - } - req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, in.Item, currentLocation) - if err != nil { - return nil, err - } - return &itemWritePlan{ - req: req, - generation: schema.Generation, - cleanup: cleanup, - current: cloneAttributeValueMap(current), - next: cloneAttributeValueMap(in.Item), - }, nil -} - -func (d *DynamoDBServer) commitItemWrite(ctx context.Context, req *kv.OperationGroup[kv.OP]) error { - _, err := d.coordinator.Dispatch(ctx, req) - if err != nil { - return errors.WithStack(err) - } - return nil -} - -func (d *DynamoDBServer) parseGetItemInput(w http.ResponseWriter, r *http.Request) (getItemInput, bool) { - body, err := io.ReadAll(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return getItemInput{}, false - } - var in getItemInput - if err := json.Unmarshal(body, &in); err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return getItemInput{}, false - } - if strings.TrimSpace(in.TableName) == "" { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, "missing table name") - return getItemInput{}, false - } - if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { - writeDynamoErrorFromErr(w, err) - return getItemInput{}, false - } - return in, true -} - -func (d *DynamoDBServer) getItem(w http.ResponseWriter, r *http.Request) { - in, ok := d.parseGetItemInput(w, r) - if !ok { - return - } - // Tentative TS for schema resolution only; schemas change rarely - // so a slight pre-lease stale is acceptable. The item read below - // is sampled AFTER the lease check. - tentativeTS := d.resolveDynamoReadTS(in.ConsistentRead) - _, itemKey, ok := d.resolveGetItemTarget(w, r, in, tentativeTS) - if !ok { - return - } - // Lease-check the shard that actually owns the ITEM key with a - // bounded timeout so a stalled Raft cannot hang this handler - // indefinitely if the client never cancels. Use defer so the - // cancel runs even if LeaseReadForKey panics or a future - // refactor inserts an early return; the cost of keeping ctx - // alive until handler exit is negligible because the next - // in-handler calls are local store reads. - leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) - defer leaseCancel() - if _, err := kv.LeaseReadForKeyThrough(d.coordinator, leaseCtx, itemKey); err != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) - return - } - // Re-sample readTS AFTER the lease confirmation so that any write - // that completed on the same shard BEFORE the confirmation is - // visible. Sampling earlier would violate linearizability for - // ConsistentRead=false reads by returning a snapshot from before - // the most recent confirmed commit. - readTS := d.resolveDynamoReadTS(in.ConsistentRead) - // Pin readTS so concurrent MVCC GC cannot reclaim versions - // between the schema revalidation and the item read below; - // matches the pattern already used by queryItems / scanItems / - // transactGetItems. - readPin := d.pinReadTS(readTS) - defer readPin.Release() - - // Re-resolve schema + itemKey at readTS and verify that the key - // we lease-checked is STILL the key that will be read. A table - // migration that commits between the tentative schema load and - // the lease confirmation may shift the item to a different shard - // even if the request parameters are unchanged, so comparing the - // computed item keys (not just generation) catches any future - // schema change that alters item routing. - finalSchema, freshItemKey, ok := d.resolveGetItemTarget(w, r, in, readTS) - if !ok { - return - } - if !bytes.Equal(freshItemKey, itemKey) { - writeDynamoError(w, http.StatusServiceUnavailable, dynamoErrInternal, - "table routing changed during read; please retry") - return - } - - current, found, err := d.readLogicalItemAt(r.Context(), finalSchema, in.Key, readTS) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return - } - if !found { - writeDynamoJSON(w, map[string]any{}) - return - } - d.observeReadMetrics(r.Context(), in.TableName, 1, 1) - projected, err := projectItem(current.item, in.ProjectionExpression, in.ExpressionAttributeNames) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - writeDynamoJSON(w, map[string]any{"Item": projected}) -} - -// resolveGetItemTarget loads the schema and computes the item key whose -// shard must be lease-checked before the read. Returns false after -// writing an error response; the caller should simply return. -func (d *DynamoDBServer) resolveGetItemTarget(w http.ResponseWriter, r *http.Request, in getItemInput, readTS uint64) (*dynamoTableSchema, []byte, bool) { - schema, exists, err := d.loadTableSchemaAt(r.Context(), in.TableName, readTS) - if err != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) - return nil, nil, false - } - if !exists { - writeDynamoError(w, http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - return nil, nil, false - } - itemKey, err := schema.itemKeyFromAttributes(in.Key) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return nil, nil, false - } - return schema, itemKey, true -} - -func (d *DynamoDBServer) deleteItem(w http.ResponseWriter, r *http.Request) { - in, shouldReturnOld, err := decodeDeleteItemInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - lockKey, err := dynamoItemUpdateLockKey(in.TableName, in.Key) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return - } - unlock := d.lockItemUpdate(lockKey) - defer unlock() - plan, err := d.deleteItemWithRetry(r.Context(), in) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - if len(plan.current) > 0 { - d.observeWrittenItems(r.Context(), in.TableName, 1) - } - resp := map[string]any{} - if shouldReturnOld && len(plan.current) > 0 { - resp["Attributes"] = plan.current - } - writeDynamoJSON(w, resp) -} - -func decodeDeleteItemInput(bodyReader io.Reader) (deleteItemInput, bool, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return deleteItemInput{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in deleteItemInput - if err := json.Unmarshal(body, &in); err != nil { - return deleteItemInput{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return deleteItemInput{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - shouldReturnOld, err := parseDeleteItemReturnValues(in.ReturnValues) - if err != nil { - return deleteItemInput{}, false, err - } - return in, shouldReturnOld, nil -} - -func parseDeleteItemReturnValues(returnValues string) (bool, error) { - switch strings.TrimSpace(returnValues) { - case "", dynamoReturnValueNone: - return false, nil - case dynamoReturnValueAllOld: - return true, nil - default: - return false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported ReturnValues") - } -} - -type deleteItemPlan struct { - req *kv.OperationGroup[kv.OP] - generation uint64 - current map[string]attributeValue -} - -func (d *DynamoDBServer) deleteItemWithRetry(ctx context.Context, in deleteItemInput) (*deleteItemPlan, error) { - var deletePlan *deleteItemPlan - _, err := d.retryItemWriteWithGeneration( - ctx, - in.TableName, - "delete retry attempts exhausted", - func(readTS uint64) (*itemWritePlan, error) { - var err error - deletePlan, err = d.prepareDeleteItemWrite(ctx, in, readTS) - if err != nil { - return nil, err - } - return &itemWritePlan{ - req: deletePlan.req, - generation: deletePlan.generation, - }, nil - }, - ) - if err != nil { - return nil, err - } - return deletePlan, nil -} - -func (d *DynamoDBServer) prepareDeleteItemWrite(ctx context.Context, in deleteItemInput, readTS uint64) (*deleteItemPlan, error) { - schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - if !exists { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - current := map[string]attributeValue(nil) - if found { - current = currentLocation.item - } - if err := validateConditionOnItem( - in.ConditionExpression, - in.ExpressionAttributeNames, - in.ExpressionAttributeValues, - valueOrEmptyMap(current, found), - ); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) - } - if !found { - return &deleteItemPlan{current: nil}, nil - } - req, err := buildItemDeleteRequestWithSource(currentLocation) - if err != nil { - return nil, err - } - return &deleteItemPlan{ - req: req, - generation: schema.Generation, - current: cloneAttributeValueMap(current), - }, nil -} - -func (d *DynamoDBServer) updateItem(w http.ResponseWriter, r *http.Request) { - in, err := decodeUpdateItemInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - if err := d.ensureLegacyTableMigration(r.Context(), in.TableName); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - lockKey, err := dynamoItemUpdateLockKey(in.TableName, in.Key) - if err != nil { - writeDynamoError(w, http.StatusBadRequest, dynamoErrValidation, err.Error()) - return - } - unlock := d.lockItemUpdate(lockKey) - defer unlock() - plan, err := d.updateItemWithRetry(r.Context(), in) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - d.observeWrittenItems(r.Context(), in.TableName, 1) - resp := map[string]any{} - if attrs := updateItemReturnAttributes(in.ReturnValues, plan.current, plan.next); len(attrs) > 0 { - resp["Attributes"] = attrs - } - writeDynamoJSON(w, resp) -} - -func decodeUpdateItemInput(bodyReader io.Reader) (updateItemInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return updateItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in updateItemInput - if err := json.Unmarshal(body, &in); err != nil { - return updateItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return updateItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - if err := validateUpdateItemReturnValues(in.ReturnValues); err != nil { - return updateItemInput{}, err - } - return in, nil -} - -func validateUpdateItemReturnValues(returnValues string) error { - switch strings.TrimSpace(returnValues) { - case "", - dynamoReturnValueNone, - dynamoReturnValueAllOld, - dynamoReturnValueUpdatedOld, - dynamoReturnValueAllNew, - dynamoReturnValueUpdatedNew: - return nil - default: - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported ReturnValues") - } -} - -func (d *DynamoDBServer) updateItemWithRetry(ctx context.Context, in updateItemInput) (*itemWritePlan, error) { - return d.retryItemWriteWithGeneration( - ctx, - in.TableName, - "update retry attempts exhausted", - func(readTS uint64) (*itemWritePlan, error) { - return d.prepareUpdateItemWrite(ctx, in, readTS) - }, - ) -} - -func (d *DynamoDBServer) prepareUpdateItemWrite(ctx context.Context, in updateItemInput, readTS uint64) (*itemWritePlan, error) { - schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - if !exists { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - itemKey, err := schema.itemKeyFromAttributes(in.Key) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var current map[string]attributeValue - if !found { - current = map[string]attributeValue{} - } else { - current = currentLocation.item - } - nextItem, err := buildUpdatedItem(schema, in, current) - if err != nil { - return nil, err - } - req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, nextItem, currentLocation) - if err != nil { - return nil, err - } - return &itemWritePlan{ - req: req, - generation: schema.Generation, - cleanup: cleanup, - current: cloneAttributeValueMap(current), - next: cloneAttributeValueMap(nextItem), - }, nil -} - -func buildUpdatedItem(schema *dynamoTableSchema, in updateItemInput, current map[string]attributeValue) (map[string]attributeValue, error) { - if err := validateConditionOnItem(in.ConditionExpression, in.ExpressionAttributeNames, in.ExpressionAttributeValues, current); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) - } - nextItem := cloneAttributeValueMap(current) - maps.Copy(nextItem, in.Key) - if err := applyUpdateExpression(in.UpdateExpression, in.ExpressionAttributeNames, in.ExpressionAttributeValues, nextItem); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if err := ensurePrimaryKeyUnchanged(schema.PrimaryKey, in.Key, nextItem); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return nextItem, nil -} - -func ensurePrimaryKeyUnchanged(keySchema dynamoKeySchema, originalKey map[string]attributeValue, nextItem map[string]attributeValue) error { - if err := ensureSinglePrimaryKeyUnchanged(keySchema.HashKey, originalKey, nextItem); err != nil { - return err - } - if keySchema.RangeKey != "" { - if err := ensureSinglePrimaryKeyUnchanged(keySchema.RangeKey, originalKey, nextItem); err != nil { - return err - } - } - return nil -} - -func ensureSinglePrimaryKeyUnchanged(attrName string, originalKey map[string]attributeValue, nextItem map[string]attributeValue) error { - keyVal, ok := originalKey[attrName] - if !ok { - return errors.New("missing key attribute") - } - nextVal, ok := nextItem[attrName] - if !ok { - return errors.New("cannot remove key attribute") - } - if !attributeValueEqual(keyVal, nextVal) { - return errors.New("cannot update primary key attribute") - } - return nil -} - -type dynamoItemLocation struct { - schema *dynamoTableSchema - key []byte - item map[string]attributeValue -} - -func buildItemWriteRequestWithSource( - targetSchema *dynamoTableSchema, - targetKey []byte, - nextItem map[string]attributeValue, - current *dynamoItemLocation, -) (*kv.OperationGroup[kv.OP], [][]byte, error) { - payload, err := encodeStoredDynamoItem(nextItem) - if err != nil { - return nil, nil, errors.WithStack(err) - } - elems := []*kv.Elem[kv.OP]{{Op: kv.Put, Key: targetKey, Value: payload}} - cleanup := [][]byte{targetKey} - delKeys, putKeys, err := itemStorageDelta(targetSchema, targetKey, nextItem, current) - if err != nil { - return nil, nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - for _, key := range delKeys { - elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) - } - for _, key := range putKeys { - elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Put, Key: key, Value: targetKey}) - cleanup = append(cleanup, key) - } - return &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: elems, - }, cleanup, nil -} - -func itemStorageDelta( - targetSchema *dynamoTableSchema, - targetKey []byte, - nextItem map[string]attributeValue, - current *dynamoItemLocation, -) ([][]byte, [][]byte, error) { - oldKeys, err := itemStorageKeys(current) - if err != nil { - return nil, nil, err - } - newKeys, err := targetSchema.gsiEntryKeysForItem(nextItem) - if err != nil { - return nil, nil, err - } - newSet := bytesToSet(newKeys) - oldSet := bytesToSet(oldKeys) - delete(oldSet, string(targetKey)) - delKeys := make([][]byte, 0, len(oldKeys)) - for key, raw := range oldSet { - if _, ok := newSet[key]; ok { - continue - } - delKeys = append(delKeys, raw) - } - putKeys := make([][]byte, 0, len(newKeys)) - for key, raw := range newSet { - if _, ok := oldSet[key]; ok { - continue - } - putKeys = append(putKeys, raw) - } - return delKeys, putKeys, nil -} - -func itemStorageKeys(current *dynamoItemLocation) ([][]byte, error) { - if current == nil || len(current.item) == 0 { - return nil, nil - } - gsiKeys, err := current.schema.gsiEntryKeysForItem(current.item) - if err != nil { - return nil, err - } - out := make([][]byte, 0, len(gsiKeys)+1) - out = append(out, bytes.Clone(current.key)) - out = append(out, gsiKeys...) - return out, nil -} - -func bytesToSet(keys [][]byte) map[string][]byte { - out := make(map[string][]byte, len(keys)) - for _, key := range keys { - out[string(key)] = key - } - return out -} - -func putItemReturnAttributes(returnValues string, current map[string]attributeValue) map[string]attributeValue { - if !strings.EqualFold(strings.TrimSpace(returnValues), dynamoReturnValueAllOld) || len(current) == 0 { - return nil - } - return cloneAttributeValueMap(current) -} - -func updateItemReturnAttributes(returnValues string, current map[string]attributeValue, next map[string]attributeValue) map[string]attributeValue { - switch strings.TrimSpace(returnValues) { - case "", dynamoReturnValueNone: - return nil - case dynamoReturnValueAllOld: - if len(current) == 0 { - return nil - } - return cloneAttributeValueMap(current) - case dynamoReturnValueAllNew: - return cloneAttributeValueMap(next) - case dynamoReturnValueUpdatedOld: - return selectUpdatedAttributes(current, next, true) - case dynamoReturnValueUpdatedNew: - return selectUpdatedAttributes(current, next, false) - default: - return nil - } -} - -func selectUpdatedAttributes(current map[string]attributeValue, next map[string]attributeValue, oldValues bool) map[string]attributeValue { - keys := updatedAttributeNames(current, next) - if len(keys) == 0 { - return nil - } - out := make(map[string]attributeValue, len(keys)) - for _, key := range keys { - if oldValues { - if value, ok := current[key]; ok { - out[key] = value - } - continue - } - if value, ok := next[key]; ok { - out[key] = value - } - } - if len(out) == 0 { - return nil - } - return out -} - -func updatedAttributeNames(current map[string]attributeValue, next map[string]attributeValue) []string { - seen := make(map[string]struct{}, len(current)+len(next)) - for name := range current { - seen[name] = struct{}{} - } - for name := range next { - seen[name] = struct{}{} - } - names := make([]string, 0, len(seen)) - for name := range seen { - names = append(names, name) - } - sort.Strings(names) - out := make([]string, 0, len(names)) - for _, name := range names { - oldVal, oldOK := current[name] - newVal, newOK := next[name] - if !oldOK && !newOK { - continue - } - if oldOK && newOK && attributeValueEqual(oldVal, newVal) { - continue - } - out = append(out, name) - } - return out -} - -func (d *DynamoDBServer) query(w http.ResponseWriter, r *http.Request) { - in, err := decodeQueryInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - // Lease-check the shard the Query reads BEFORE queryItems samples - // readTS, so the quorum-freshness bound is established without - // changing read-snapshot semantics (sampling readTS only after the - // confirmation keeps any commit that landed before it visible). A - // base-table Query on a single partition key reads exactly one - // hash-key prefix, which routes to one shard group, so the check is - // routed by that prefix in a multi-group deployment. GSI queries and - // queries whose prefix cannot be resolved fall back to the keyless - // check, which spans every shard the range can touch. - if !d.leaseCheckQuery(w, r, in) { - return - } - out, err := d.queryItems(r.Context(), in) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - d.observeReadMetrics(r.Context(), in.TableName, out.count, out.scannedCount) - resp := map[string]any{ - "Items": out.items, - "Count": out.count, - "ScannedCount": out.scannedCount, - } - if len(out.lastEvaluatedKey) > 0 { - resp["LastEvaluatedKey"] = out.lastEvaluatedKey - } - writeDynamoJSON(w, resp) -} - -func (d *DynamoDBServer) scan(w http.ResponseWriter, r *http.Request) { - in, err := decodeScanInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - // A Scan reads the whole table and therefore spans every shard that - // holds any of its items. leaseCheckScan establishes the quorum-freshness - // bound across every group BEFORE scanItems samples readTS — but only for - // a request that passes the cheap table/GSI validation, so a scan that - // will deterministically 4xx is not masked by a degraded-lease 500 - // (codex #952 P2-A). - if !d.leaseCheckScan(w, r, in) { - return - } - out, err := d.scanItems(r.Context(), in) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - d.observeReadMetrics(r.Context(), in.TableName, out.count, out.scannedCount) - resp := map[string]any{ - "Items": out.items, - "Count": out.count, - "ScannedCount": out.scannedCount, - } - if len(out.lastEvaluatedKey) > 0 { - resp["LastEvaluatedKey"] = out.lastEvaluatedKey - } - writeDynamoJSON(w, resp) -} - -// leaseReadKeyless performs a keyless quorum-freshness lease check for -// multi-shard read handlers (Scan, GSI/whole-table Query fallback). These -// reads visit every shard the range intersects, so the check fences EVERY -// group the coordinator owns via LeaseReadAllGroupsThrough — a default-group- -// only lease would let a non-default group serve a stale snapshot. A -// single-group coordinator falls back to one LeaseRead, so single-group -// deployments still issue exactly one lease read. It bounds the wait with -// dynamoLeaseReadTimeout so a stalled Raft cannot hang the handler when the -// client never cancels, and writes the same InternalServerError that getItem -// produces on lease failure. Returns false after writing an error response; -// the caller should simply return. -// leaseReadKeyless fences every group via the keyless all-groups lease check. -// `leaseCtx` MUST be the SAME context the pre-pass armed (it bounds the entire -// pre-pass — schema read + the lease that lands here — by dynamoLeaseReadTimeout -// total; coderabbit Major on PR #952 round-4). Creating a fresh -// context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) here would re-arm -// the 5s budget per call, so a slow schema read followed by the keyless -// fallback could consume close to 10s end-to-end. Callers that do NOT have a -// pre-pass context must pass their own bounded ctx; r.Context() with the -// handler's own timeout-on-the-roundabout is the conservative choice. -func (d *DynamoDBServer) leaseReadKeyless(w http.ResponseWriter, leaseCtx context.Context) bool { - if err := kv.LeaseReadAllGroupsThrough(d.coordinator, leaseCtx); err != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) - return false - } - return true -} - -// leaseCheckScan runs the Scan lease pre-pass. A Scan reads the whole table, so -// a VALID scan must fence EVERY group via the keyless all-groups check. But a -// scan against a missing table, an unknown index, or a GSI with -// ConsistentRead=true never touches data: the read path rejects it with a -// deterministic 4xx, so establishing freshness is unnecessary and a failed -// all-groups fence on a degraded deployment would mask that 4xx with a 500 -// (codex #952 P2-A). leaseCheckScan therefore cheaply pre-validates the request -// (schema load + the same GSI read-option checks scanItems re-runs) at a -// tentative timestamp and skips the lease on a client-side validation error, -// while still failing closed (fencing every group) on a transient schema-read -// failure. Returns false after writing an error response; the caller returns. -func (d *DynamoDBServer) leaseCheckScan(w http.ResponseWriter, r *http.Request, in scanInput) bool { - // leaseCtx bounds the pre-validation schema read AND the lease read so a - // stalled schema read cannot block the handler past dynamoLeaseReadTimeout - // before the lease phase begins. leaseReadKeyless creates its own bounded - // context for the actual lease read. - leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) - defer leaseCancel() - schema, plan, err := d.multiShardReadLeasePlan(leaseCtx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) - if err != nil { - // Transient/internal schema-read failure: fail closed by fencing every - // group. leaseReadKeyless writes the same InternalServerError on its - // own failure. - return d.leaseReadKeyless(w, leaseCtx) - } - if plan == queryLeaseSkip { - // Client-side validation problem (table not found, unknown index, - // unsupported ConsistentRead): the read path re-runs the identical - // validation and surfaces the deterministic 4xx, so skip the lease so a - // degraded-lease failure cannot mask it with a 500 (codex #952 P2-A). - return true - } - // Malformed ExclusiveStartKey is a deterministic ValidationException the - // read path rejects in resolveTableReadBounds / resolveGSIReadBounds — - // before the iterator is constructed and before any store read. If we let - // the lease run first, a degraded shard's 500 would mask that 4xx - // (codex #952 P2 round-3). Pre-validate against the loaded schema and skip - // leasing on failure; the read path will surface the identical error. - if scanExclusiveStartKeyInvalid(schema, in) { - return true - } - // Same logic for a malformed ProjectionExpression: newReadPageState runs - // resolveProjectionAttributes before the iterator reads from the store, so a - // parse failure is a deterministic ValidationException the lease pre-pass - // must not mask (codex #952 P2 round-4 line 2346). validateGSIReadOptions - // already covers the GSI case; this catches the base-table path that the - // earlier ExclusiveStartKey check left exposed. - if projectionInvalid(in.ProjectionExpression, in.ExpressionAttributeNames) { - return true - } - // Valid whole-table read: fence every group (fail closed). - return d.leaseReadKeyless(w, leaseCtx) -} - -// projectionInvalid returns true when the ProjectionExpression cannot be -// parsed against the given ExpressionAttributeNames. resolveProjectionAttributes -// is the same validator newReadPageState runs before the iterator reads from -// the store, so a true result means the read path WILL reject the request with -// a deterministic ValidationException without touching data. Pre-pass uses -// this to skip leasing in that case (codex #952 P2 round-4 lines 2346, 2492). -// An empty ProjectionExpression is the common "project everything" case and -// returns false (no validation needed). -func projectionInvalid(projectionExpression string, names map[string]string) bool { - if strings.TrimSpace(projectionExpression) == "" { - return false - } - _, err := resolveProjectionAttributes(projectionExpression, names) - return err != nil -} - -// scanExclusiveStartKeyInvalid returns true when in.ExclusiveStartKey cannot be -// decoded against the table's primary key (Scan with no IndexName) or the named -// GSI (Scan with IndexName). It mirrors the validation resolveTableReadBounds / -// resolveGSIReadBounds run in scanItems so the lease pre-pass can route the -// invalid case to the same skip-lease path as table-not-found etc. A nil schema -// is treated as "not invalid" because multiShardReadLeasePlan already classified -// the request as queryLeaseSkip in that case and we never reach here. -func scanExclusiveStartKeyInvalid(schema *dynamoTableSchema, in scanInput) bool { - if schema == nil || len(in.ExclusiveStartKey) == 0 { - return false - } - if strings.TrimSpace(in.IndexName) == "" { - _, err := schema.itemKeyFromAttributes(in.ExclusiveStartKey) - return err != nil - } - _, _, err := schema.gsiKeyFromAttributes(in.IndexName, in.ExclusiveStartKey) - return err != nil -} - -// multiShardReadLeasePlan cheaply classifies whether a multi-shard read (Scan or -// a GSI/whole-table Query) is a VALID data read that must fence every group -// (queryLeaseAllGroups) or a CLIENT-side validation problem the read path -// rejects identically without touching data (queryLeaseSkip). It performs the -// same table-existence and GSI read-option checks prepareReadSchema runs, at a -// tentative timestamp (schema only, no readTS sampling), so the lease pre-pass -// never masks a deterministic 4xx with a degraded-lease 500. A transient/internal -// schema-read failure is returned as an error so the caller fails closed. -// -// The loaded schema is returned (nil when the table is missing or on error) so -// callers that need a further deterministic validation (the GSI Query -// KeyConditionExpression check) can reuse it without a second schema load. -// -// Validation failures are reported via queryLeaseSkip rather than an error: the -// read path re-runs the same resolution and reports the identical validation -// error, so error mapping is unchanged. -func (d *DynamoDBServer) multiShardReadLeasePlan( - ctx context.Context, - tableName string, - indexName string, - selectValue string, - projectionExpression string, - names map[string]string, - consistentRead *bool, -) (*dynamoTableSchema, queryLeasePlan, error) { - tentativeTS := snapshotTS(d.coordinator.Clock(), d.store) - schema, exists, err := d.loadTableSchemaAt(ctx, tableName, tentativeTS) - if err != nil { - // loadTableSchemaAt maps ErrKeyNotFound to (_, false, nil); any error - // reaching here is a transient store/context/decode failure, so fail - // closed. - return nil, queryLeaseAllGroups, errors.WithStack(err) - } - if !exists { - // Table not found is a deterministic ResourceNotFoundException the read - // path produces without touching data; skip the lease. - return nil, queryLeaseSkip, nil - } - // validateGSIReadOptions runs the identical unknown-index / GSI - // ConsistentRead / projection checks prepareReadSchema performs. Any failure - // is a *dynamoAPIError (a deterministic ValidationException), so classify it - // as a skip; a transient failure is impossible here (no store access). - if err := validateGSIReadOptions(schema, indexName, selectValue, projectionExpression, names, consistentRead); err != nil { - if dynamoErrIsTransient(err) { - // Defensive: validateGSIReadOptions returns only *dynamoAPIError, so - // this is unreachable. Fail closed if a future change adds a - // transient path. - return nil, queryLeaseAllGroups, errors.WithStack(err) - } - return schema, queryLeaseSkip, nil - } - return schema, queryLeaseAllGroups, nil -} - -// leaseCheckQuery lease-checks the shard a Query reads with a bounded -// timeout, writing the same InternalServerError getItem produces on -// failure. When the request resolves to a single base-table hash-key -// prefix (the common case), the check is routed to that prefix's owning -// group via LeaseReadForKey so a multi-group deployment confirms the -// shard that actually holds the data — not the default group. GSI -// queries and any request whose prefix cannot be resolved here fall back -// to the keyless check, which establishes freshness across every shard -// the range can touch. Returns false after writing an error response; -// the caller should simply return. -func (d *DynamoDBServer) leaseCheckQuery(w http.ResponseWriter, r *http.Request, in queryInput) bool { - // leaseCtx bounds the entire pre-pass — the schema read that resolves - // the lease key and the lease read itself — so a stalled schema read - // cannot block the handler past dynamoLeaseReadTimeout before the lease - // phase begins. The keyless fallback creates its own bounded context. - leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) - defer leaseCancel() - leaseKey, plan, err := d.queryLeaseKey(leaseCtx, in) - if err != nil { - // Transient/internal schema read failure: the routing key could - // not be resolved, so fail closed by fencing EVERY group via the - // keyless check (a strict superset of the single group this query - // would have routed to). leaseReadKeyless writes the same - // InternalServerError on its own failure. - return d.leaseReadKeyless(w, leaseCtx) - } - switch plan { - case queryLeaseSkip: - // Client-side validation problem (table not found, malformed/ - // unsupported KeyConditionExpression): the request touches no data, - // so establishing freshness is unnecessary. Skip the lease entirely - // and let queryItems re-run the identical resolution and surface the - // deterministic ResourceNotFoundException/ValidationException — a - // lease failure on the fallback must not mask that 4xx with a 500 in - // a degraded deployment (codex #952 P2). This matches getItem, which - // writes the 4xx before any lease read. - return true - case queryLeaseAllGroups: - // GSI / whole-table query: a VALID read that spans multiple shards, - // so the keyless all-groups check is the correct fence (fail closed). - return d.leaseReadKeyless(w, leaseCtx) - case queryLeaseSingleGroup: - if _, err := kv.LeaseReadForKeyThrough(d.coordinator, leaseCtx, leaseKey); err != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) - return false - } - return true - default: - // Unreachable: queryLeaseKey only returns the three plans above. Fail - // closed via the all-groups fence rather than silently proceeding. - return d.leaseReadKeyless(w, leaseCtx) - } -} - -// queryLeasePlan classifies how a Query lease pre-pass must fence the read. -type queryLeasePlan int - -const ( - // queryLeaseSingleGroup: the query routes to exactly one shard group; - // fence that group via the resolved leaseKey. - queryLeaseSingleGroup queryLeasePlan = iota - // queryLeaseAllGroups: a VALID multi-shard read (GSI query or whole-table - // prefix); fence every group via the keyless all-groups check. - queryLeaseAllGroups - // queryLeaseSkip: a CLIENT-side validation problem (table not found, - // malformed/unsupported KeyConditionExpression) that the read path rejects - // deterministically without touching data; skip the lease so the handler's - // 4xx is never masked by a transient lease failure. - queryLeaseSkip -) - -// queryLeaseKey resolves the single hash-key prefix a base-table Query reads, -// at a tentative timestamp (schema only, no readTS sampling), so the lease -// check can be routed to the owning shard group. It returns: -// - (prefix, queryLeaseSingleGroup, nil) when the query routes to exactly -// one shard group; -// - (nil, queryLeaseAllGroups, nil) for a VALID multi-shard read (GSI query -// or whole-table prefix) the caller must fence across every group; -// - (nil, queryLeaseSkip, nil) for a CLIENT-side validation problem (table -// not found, unknown index, GSI ConsistentRead, malformed/unsupported -// KeyConditionExpression) the read path rejects identically — the caller -// skips the lease so the deterministic 4xx is not masked by a transient -// lease failure (codex #952 P2). GSI queries are validated against the -// table schema here before being classified as a multi-shard read so an -// invalid index can never trigger the all-groups fence (codex #952 P2-B); -// - (nil, _, err) for a TRANSIENT/INTERNAL schema-read failure (leaseCtx -// deadline, Pebble error) so the caller fails closed. -// -// Validation failures are reported via queryLeaseSkip rather than an error: the -// read path re-runs the same resolution and reports the identical validation -// error, so error mapping is unchanged. -func (d *DynamoDBServer) queryLeaseKey(ctx context.Context, in queryInput) ([]byte, queryLeasePlan, error) { - if strings.TrimSpace(in.IndexName) != "" { - // A GSI query is a multi-shard read, but only when it passes the same - // validation the read path runs: a query against a missing table, - // unknown index, GSI ConsistentRead, or malformed KeyConditionExpression - // touches no data and the read path rejects it with a deterministic 4xx. - // Fencing every group before that validation would mask the 4xx with a - // degraded-lease 500, so classify those as a skip (codex #952 P2-B). - schema, plan, err := d.multiShardReadLeasePlan(ctx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) - if err != nil { - return nil, queryLeaseAllGroups, errors.WithStack(err) - } - if plan == queryLeaseSkip { - return nil, queryLeaseSkip, nil - } - // Malformed ExclusiveStartKey is a deterministic ValidationException the - // read path rejects before the iterator is constructed (codex #952 P2 - // round-3). Skip leasing on failure so a degraded shard cannot mask - // that 4xx with a 500. - if queryExclusiveStartKeyInvalid(schema, in) { - return nil, queryLeaseSkip, nil - } - // Malformed ProjectionExpression is the same kind of deterministic - // ValidationException newReadPageState raises before the iterator - // touches data (codex #952 P2 round-4 line 2492); skip the lease so a - // degraded shard cannot mask it. - if projectionInvalid(in.ProjectionExpression, in.ExpressionAttributeNames) { - return nil, queryLeaseSkip, nil - } - // Schema + GSI options are valid; the KeyConditionExpression is the last - // deterministic validation the read path runs before touching data. - return nil, gsiQueryLeasePlan(in, schema), nil - } - tentativeTS := snapshotTS(d.coordinator.Clock(), d.store) - schema, exists, err := d.loadTableSchemaAt(ctx, in.TableName, tentativeTS) - if err != nil { - // loadTableSchemaAt maps ErrKeyNotFound to (_, false, nil); any - // error reaching here is a transient store/context/decode failure, - // so fail closed. - return nil, queryLeaseSingleGroup, errors.WithStack(err) - } - if !exists { - // Table not found is a deterministic ResourceNotFoundException the - // read path produces without touching data; skip the lease. - return nil, queryLeaseSkip, nil - } - // Same ExclusiveStartKey pre-check as the GSI branch above (base table). - if queryExclusiveStartKeyInvalid(schema, in) { - return nil, queryLeaseSkip, nil - } - // Same ProjectionExpression pre-check (base-table path; codex #952 P2 round-4 - // line 2492). - if projectionInvalid(in.ProjectionExpression, in.ExpressionAttributeNames) { - return nil, queryLeaseSkip, nil - } - prefix, plan := queryLeasePrefix(in, schema) - return prefix, plan, nil -} - -// queryExclusiveStartKeyInvalid mirrors the validation -// resolveQueryExclusiveStartKey runs inside queryItems' read-bounds resolution -// (`adapter/dynamodb.go` resolveQueryExclusiveStartKey / resolveTableReadBounds / -// resolveGSIReadBounds): a malformed ExclusiveStartKey produces a deterministic -// ValidationException without touching any store. Returning true routes the -// lease pre-pass to queryLeaseSkip so a degraded-shard 500 cannot mask that 4xx -// (codex #952 P2 round-3). Mirrors scanExclusiveStartKeyInvalid for the -// Query path — kept separate because the GSI vs base-table dispatch differs -// from the Scan input. -func queryExclusiveStartKeyInvalid(schema *dynamoTableSchema, in queryInput) bool { - if schema == nil || len(in.ExclusiveStartKey) == 0 { - return false - } - if strings.TrimSpace(in.IndexName) == "" { - _, err := schema.itemKeyFromAttributes(in.ExclusiveStartKey) - return err != nil - } - _, _, err := schema.gsiKeyFromAttributes(in.IndexName, in.ExclusiveStartKey) - return err != nil -} - -// queryLeasePrefix resolves the single hash-key prefix a base-table Query -// reads, classifying the read into queryLeaseSingleGroup (resolved prefix), -// queryLeaseAllGroups (whole-table prefix: a valid multi-shard read), or -// queryLeaseSkip (malformed KeyConditionExpression: a validation error the -// read path rejects identically). The validation error is deliberately not -// surfaced — only the routing classification matters here, and the read path -// reports the identical error downstream. -func queryLeasePrefix(in queryInput, schema *dynamoTableSchema) ([]byte, queryLeasePlan) { - keySchema, cond, err := resolveQueryCondition(in, schema) - if err != nil { - // Malformed/unsupported KeyConditionExpression: a deterministic - // ValidationException the read path produces without touching data. - return nil, queryLeaseSkip - } - // A query whose key schema hash key differs from the primary hash - // key reads the whole-table prefix (see queryScanPrefix), which can - // span multiple shards; let the all-groups check cover them. - if keySchema.HashKey != schema.PrimaryKey.HashKey { - return nil, queryLeaseAllGroups - } - prefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) - if err != nil { - // queryScanPrefix only fails on an unparseable hash-key value — a - // ValidationException the read path rejects identically. - return nil, queryLeaseSkip - } - return prefix, queryLeaseSingleGroup -} - -// gsiQueryLeasePlan classifies a GSI Query (already known to name a valid index -// on an existing table) as the multi-shard all-groups read it is, unless its -// KeyConditionExpression is malformed — the last deterministic validation the -// read path runs before touching data. resolveQueryCondition does no store -// access and returns only *dynamoAPIError, so a failure is a ValidationException -// the read path rejects identically; classify it as a skip so the lease pre-pass -// cannot mask that 4xx with a degraded-lease 500 (codex #952 P2-B). Like -// queryLeasePrefix, the validation error is deliberately not surfaced — only the -// routing classification matters here. -func gsiQueryLeasePlan(in queryInput, schema *dynamoTableSchema) queryLeasePlan { - if _, _, err := resolveQueryCondition(in, schema); err != nil { - return queryLeaseSkip - } - return queryLeaseAllGroups -} - -func decodeQueryInput(bodyReader io.Reader) (queryInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return queryInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in queryInput - if err := json.Unmarshal(body, &in); err != nil { - return queryInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return queryInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - if err := validateReadSelect(in.Select); err != nil { - return queryInput{}, err - } - if _, _, err := resolveReadLimit(in.Limit); err != nil { - return queryInput{}, err - } - return in, nil -} - -func decodeScanInput(bodyReader io.Reader) (scanInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return scanInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in scanInput - if err := json.Unmarshal(body, &in); err != nil { - return scanInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.TableName) == "" { - return scanInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - if err := validateReadSelect(in.Select); err != nil { - return scanInput{}, err - } - if _, _, err := resolveReadLimit(in.Limit); err != nil { - return scanInput{}, err - } - return in, nil -} - -type queryOutput struct { - items []map[string]attributeValue - count int - scannedCount int - lastEvaluatedKey map[string]attributeValue -} - -type readPageOptions struct { - filterExpression string - projectionExpression string - expressionAttributeNames map[string]string - expressionAttributeValues map[string]attributeValue - exclusiveStartKey map[string]attributeValue - limit *int32 - selectValue string - lastEvaluatedKeyBuilder func(map[string]attributeValue) map[string]attributeValue -} - -type dynamoReadIterator interface { - Next(context.Context) (map[string]attributeValue, bool, error) -} - -type queryRangeOperator string - -const ( - queryRangeOpEqual queryRangeOperator = "=" - queryRangeOpLessThan queryRangeOperator = "<" - queryRangeOpLessOrEq queryRangeOperator = "<=" - queryRangeOpGreater queryRangeOperator = ">" - queryRangeOpGreaterEq queryRangeOperator = ">=" - queryRangeOpBetween queryRangeOperator = "BETWEEN" - queryRangeOpBeginsWith queryRangeOperator = "BEGINS_WITH" -) - -type queryRangeCondition struct { - attr string - op queryRangeOperator - value1 attributeValue - value2 attributeValue -} - -type queryCondition struct { - hashAttr string - hashValue attributeValue - rangeCond *queryRangeCondition -} - -func (d *DynamoDBServer) queryItems(ctx context.Context, in queryInput) (*queryOutput, error) { - schema, readTS, err := d.prepareReadSchema(ctx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) - if err != nil { - return nil, err - } - readPin := d.pinReadTS(readTS) - defer readPin.Release() - keySchema, cond, err := resolveQueryCondition(in, schema) - if err != nil { - return nil, err - } - opts := readPageOptions{ - filterExpression: in.FilterExpression, - projectionExpression: in.ProjectionExpression, - expressionAttributeNames: in.ExpressionAttributeNames, - expressionAttributeValues: in.ExpressionAttributeValues, - exclusiveStartKey: in.ExclusiveStartKey, - limit: in.Limit, - selectValue: in.Select, - lastEvaluatedKeyBuilder: func(item map[string]attributeValue) map[string]attributeValue { - return makeReadLastEvaluatedKey(schema.PrimaryKey, keySchema, item) - }, - } - if schema.MigratingFromGeneration == 0 { - if out, ok, err := d.streamQueryItems(ctx, in, schema, keySchema, cond, readTS, opts); ok || err != nil { - return out, err - } - } - items, err := d.loadQueryItemsWithMigration(ctx, in, schema, keySchema, cond, readTS) - if err != nil { - return nil, err - } - items, err = projectReadItemsForIndex(schema, in.IndexName, items) - if err != nil { - return nil, err - } - orderQueryItems(items, keySchema.RangeKey, in.ScanIndexForward) - return finalizeReadPage(schema, items, opts) -} - -func (d *DynamoDBServer) scanItems(ctx context.Context, in scanInput) (*queryOutput, error) { - schema, readTS, err := d.prepareReadSchema(ctx, in.TableName, in.IndexName, in.Select, in.ProjectionExpression, in.ExpressionAttributeNames, in.ConsistentRead) - if err != nil { - return nil, err - } - readPin := d.pinReadTS(readTS) - defer readPin.Release() - indexKeySchema := schema.PrimaryKey - if strings.TrimSpace(in.IndexName) != "" { - indexKeySchema, err = schema.keySchemaForQuery(in.IndexName) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - } - opts := readPageOptions{ - filterExpression: in.FilterExpression, - projectionExpression: in.ProjectionExpression, - expressionAttributeNames: in.ExpressionAttributeNames, - expressionAttributeValues: in.ExpressionAttributeValues, - exclusiveStartKey: in.ExclusiveStartKey, - limit: in.Limit, - selectValue: in.Select, - lastEvaluatedKeyBuilder: func(item map[string]attributeValue) map[string]attributeValue { - return makeReadLastEvaluatedKey(schema.PrimaryKey, indexKeySchema, item) - }, - } - if schema.MigratingFromGeneration == 0 { - if out, ok, err := d.streamScanItems(ctx, in, schema, readTS, opts); ok || err != nil { - return out, err - } - } - items, err := d.loadScanItemsWithMigration(ctx, in, schema, indexKeySchema, readTS) - if err != nil { - return nil, err - } - items, err = projectReadItemsForIndex(schema, in.IndexName, items) - if err != nil { - return nil, err - } - return finalizeReadPage(schema, items, opts) -} - -func (d *DynamoDBServer) prepareReadSchema( - ctx context.Context, - tableName string, - indexName string, - selectValue string, - projectionExpression string, - names map[string]string, - consistentRead *bool, -) (*dynamoTableSchema, uint64, error) { - if err := d.ensureLegacyTableMigration(ctx, tableName); err != nil { - return nil, 0, err - } - readTS := d.resolveDynamoReadTS(consistentRead) - schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) - if err != nil { - return nil, 0, errors.WithStack(err) - } - if !exists { - return nil, 0, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - if err := validateGSIReadOptions(schema, indexName, selectValue, projectionExpression, names, consistentRead); err != nil { - return nil, 0, err - } - return schema, readTS, nil -} - -func (d *DynamoDBServer) loadQueryItemsWithMigration( - ctx context.Context, - in queryInput, - schema *dynamoTableSchema, - keySchema dynamoKeySchema, - cond queryCondition, - readTS uint64, -) ([]map[string]attributeValue, error) { - items, err := d.queryItemsByKeyCondition(ctx, in, schema, keySchema, cond, readTS) - if err != nil { - return nil, err - } - return d.mergeReadItemsFromSourceSchema(schema, keySchema, items, func(sourceSchema *dynamoTableSchema) ([]map[string]attributeValue, error) { - return d.queryItemsByKeyCondition(ctx, in, sourceSchema, keySchema, cond, readTS) - }) -} - -func (d *DynamoDBServer) loadScanItemsWithMigration( - ctx context.Context, - in scanInput, - schema *dynamoTableSchema, - indexKeySchema dynamoKeySchema, - readTS uint64, -) ([]map[string]attributeValue, error) { - items, err := d.scanItemsBySource(ctx, in, schema, readTS) - if err != nil { - return nil, err - } - return d.mergeReadItemsFromSourceSchema(schema, indexKeySchema, items, func(sourceSchema *dynamoTableSchema) ([]map[string]attributeValue, error) { - return d.scanItemsBySource(ctx, in, sourceSchema, readTS) - }) -} - -func (d *DynamoDBServer) mergeReadItemsFromSourceSchema( - schema *dynamoTableSchema, - orderKey dynamoKeySchema, - items []map[string]attributeValue, - loadSource func(*dynamoTableSchema) ([]map[string]attributeValue, error), -) ([]map[string]attributeValue, error) { - sourceSchema := schema.migrationSourceSchema() - if sourceSchema == nil { - return items, nil - } - sourceItems, err := loadSource(sourceSchema) - if err != nil { - return nil, err - } - return mergeMigratingReadItems(schema.PrimaryKey, orderKey, items, sourceItems) -} - -// consistentReadLatestTS is a read timestamp used for ConsistentRead=true reads. -// The value is far above any realistic HLC timestamp (~Unix-nanosecond range, -// ≪ 10^19), so reading at this TS from the leader's Pebble store returns the -// most recently committed version of any key. It avoids the noStartTS -// sentinel (^uint64(0)) used by the coordinator. -// -// This sentinel is used on BOTH the leader and followers: -// - On a follower, the read is proxied to the leader via proxyRawGet with -// ts=consistentReadLatestTS, so the leader reads the absolute latest version. -// - On the leader, the LeaderRoutedStore performs a linearizable read fence -// (ensuring all committed Raft entries are applied) and then reads locally -// at consistentReadLatestTS, returning the latest committed version. -// -// Using store.LastCommitTS() instead would introduce a TOCTOU race: the -// timestamp is captured before the linearizable fence, so a write committed -// after LastCommitTS() but applied during the fence would be missed. -const consistentReadLatestTS = ^uint64(0) - 1 - -func (d *DynamoDBServer) resolveDynamoReadTS(consistentRead *bool) uint64 { - if consistentRead != nil && *consistentRead { - return consistentReadLatestTS - } - return snapshotTS(d.coordinator.Clock(), d.store) -} - -func validateGSIReadOptions( - schema *dynamoTableSchema, - indexName string, - selectValue string, - projectionExpression string, - names map[string]string, - consistentRead *bool, -) error { - if strings.TrimSpace(indexName) == "" { - return nil - } - attrs, err := resolveProjectionAttributes(projectionExpression, names) - if err != nil { - return err - } - return validateProjectedGSIRead(schema, indexName, selectValue, attrs, consistentRead) -} - -func validateProjectedGSIRead( - schema *dynamoTableSchema, - indexName string, - selectValue string, - attrs []string, - consistentRead *bool, -) error { - if err := validateGSIConsistentRead(consistentRead); err != nil { - return err - } - allProjected, projected, err := schema.gsiProjectedAttributeSet(indexName) - if err != nil { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if err := validateGSISelectValue(selectValue, allProjected); err != nil { - return err - } - return validateProjectedAttributes(attrs, projected, allProjected) -} - -func validateGSIConsistentRead(consistentRead *bool) error { - if consistentRead == nil || !*consistentRead { - return nil - } - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "ConsistentRead is not supported on global secondary indexes") -} - -func validateGSISelectValue(selectValue string, allProjected bool) error { - if !strings.EqualFold(strings.TrimSpace(selectValue), "ALL_ATTRIBUTES") || allProjected { - return nil - } - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "ALL_ATTRIBUTES is not supported for this index projection") -} - -func validateProjectedAttributes(attrs []string, projected map[string]struct{}, allProjected bool) error { - if allProjected || len(attrs) == 0 { - return nil - } - for _, attr := range attrs { - if _, ok := projected[attr]; ok { - continue - } - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "requested attribute is not projected into index") - } - return nil -} - -func projectReadItemsForIndex(schema *dynamoTableSchema, indexName string, items []map[string]attributeValue) ([]map[string]attributeValue, error) { - if strings.TrimSpace(indexName) == "" || len(items) == 0 { - return items, nil - } - out := make([]map[string]attributeValue, 0, len(items)) - for _, item := range items { - projected, err := schema.projectItemForIndex(indexName, item) - if err != nil { - return nil, err - } - out = append(out, projected) - } - return out, nil -} - -func resolveQueryCondition(in queryInput, schema *dynamoTableSchema) (dynamoKeySchema, queryCondition, error) { - keySchema, err := schema.keySchemaForQuery(in.IndexName) - if err != nil { - return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - keyExpr, err := replaceNames(in.KeyConditionExpression, in.ExpressionAttributeNames) - if err != nil { - return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - parsed, err := parseKeyConditionExpression(keyExpr) - if err != nil { - return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - cond, err := buildQueryCondition(keySchema, parsed, in.ExpressionAttributeValues) - if err != nil { - return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return keySchema, cond, nil -} - -func filterQueryItems(kvs []*store.KVPair, cond queryCondition) ([]map[string]attributeValue, error) { - items := make([]map[string]attributeValue, 0, len(kvs)) - for _, kvp := range kvs { - item, err := decodeStoredDynamoItem(kvp.Value) - if err != nil { - return nil, err - } - if !matchesQueryCondition(item, cond) { - continue - } - items = append(items, item) - } - return items, nil -} - -func orderQueryItems(items []map[string]attributeValue, rangeKey string, scanIndexForward *bool) { - if rangeKey != "" { - sort.SliceStable(items, func(i, j int) bool { - return compareAttributeValueSortKey(items[i][rangeKey], items[j][rangeKey]) < 0 - }) - } - scanForward := true - if scanIndexForward != nil { - scanForward = *scanIndexForward - } - if !scanForward { - reverseItems(items) - } -} - -func mergeMigratingReadItems( - primaryKey dynamoKeySchema, - orderKey dynamoKeySchema, - preferred []map[string]attributeValue, - source []map[string]attributeValue, -) ([]map[string]attributeValue, error) { - if len(source) == 0 { - return preferred, nil - } - out := make([]map[string]attributeValue, 0, len(preferred)+len(source)) - seen := make(map[string]struct{}, len(preferred)+len(source)) - appendItem := func(item map[string]attributeValue) error { - identity, err := itemPrimaryIdentity(primaryKey, item) - if err != nil { - return err - } - if _, ok := seen[identity]; ok { - return nil - } - seen[identity] = struct{}{} - out = append(out, item) - return nil - } - for _, item := range preferred { - if err := appendItem(item); err != nil { - return nil, err - } - } - for _, item := range source { - if err := appendItem(item); err != nil { - return nil, err - } - } - sort.SliceStable(out, func(i, j int) bool { - return compareReadOrder(orderKey, primaryKey, out[i], out[j]) < 0 - }) - return out, nil -} - -func itemPrimaryIdentity(keySchema dynamoKeySchema, item map[string]attributeValue) (string, error) { - var b strings.Builder - if err := appendIdentityPart(&b, item, keySchema.HashKey); err != nil { - return "", err - } - if keySchema.RangeKey != "" { - if err := appendIdentityPart(&b, item, keySchema.RangeKey); err != nil { - return "", err - } - } - return b.String(), nil -} - -func appendIdentityPart(b *strings.Builder, item map[string]attributeValue, attrName string) error { - attr, ok := item[attrName] - if !ok { - return errors.New("missing key attribute") - } - key, err := attributeValueAsKeySegment(attr) - if err != nil { - return err - } - b.WriteString(attrName) - b.WriteByte('=') - b.WriteString(base64.RawURLEncoding.EncodeToString(key)) - b.WriteByte('|') - return nil -} - -func compareReadOrder(orderKey dynamoKeySchema, primaryKey dynamoKeySchema, left map[string]attributeValue, right map[string]attributeValue) int { - if cmp := compareAttributeValueByName(orderKey.HashKey, left, right); cmp != 0 { - return cmp - } - if orderKey.RangeKey != "" { - if cmp := compareAttributeValueByName(orderKey.RangeKey, left, right); cmp != 0 { - return cmp - } - } - if cmp := compareAttributeValueByName(primaryKey.HashKey, left, right); cmp != 0 { - return cmp - } - if primaryKey.RangeKey != "" { - if cmp := compareAttributeValueByName(primaryKey.RangeKey, left, right); cmp != 0 { - return cmp - } - } - return 0 -} - -func compareAttributeValueByName(attrName string, left map[string]attributeValue, right map[string]attributeValue) int { - if attrName == "" { - return 0 - } - leftAttr, leftOK := left[attrName] - rightAttr, rightOK := right[attrName] - switch { - case !leftOK && !rightOK: - return 0 - case !leftOK: - return -1 - case !rightOK: - return 1 - default: - return compareAttributeValueSortKey(leftAttr, rightAttr) - } -} - -func validateReadSelect(selectValue string) error { - switch strings.TrimSpace(selectValue) { - case "", "ALL_ATTRIBUTES", "ALL_PROJECTED_ATTRIBUTES", "SPECIFIC_ATTRIBUTES", "COUNT": - return nil - default: - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported Select") - } -} - -func resolveReadLimit(limit *int32) (int, bool, error) { - if limit == nil { - return 0, false, nil - } - if *limit <= 0 { - return 0, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid Limit") - } - return int(*limit), true, nil -} - -func readIteratorPageLimit(limit *int32) int { - resolved, hasLimit, err := resolveReadLimit(limit) - if err != nil || !hasLimit { - return dynamoScanPageLimit - } - pageLimit := resolved + 1 - if pageLimit > dynamoScanPageLimit { - return dynamoScanPageLimit - } - if pageLimit < 1 { - return 1 - } - return pageLimit -} - -func (d *DynamoDBServer) streamQueryItems( - ctx context.Context, - in queryInput, - schema *dynamoTableSchema, - keySchema dynamoKeySchema, - cond queryCondition, - readTS uint64, - opts readPageOptions, -) (*queryOutput, bool, error) { - iter, ok, err := d.newQueryReadIterator(in, schema, keySchema, cond, readTS, opts) - if err != nil || !ok { - return nil, ok, err - } - out, err := finalizeReadIterator(ctx, schema, iter, opts) - if err != nil { - return nil, true, err - } - return out, true, nil -} - -func (d *DynamoDBServer) streamScanItems( - ctx context.Context, - in scanInput, - schema *dynamoTableSchema, - readTS uint64, - opts readPageOptions, -) (*queryOutput, bool, error) { - iter, ok, err := d.newScanReadIterator(in, schema, readTS, opts) - if err != nil || !ok { - return nil, ok, err - } - out, err := finalizeReadIterator(ctx, schema, iter, opts) - if err != nil { - return nil, true, err - } - return out, true, nil -} - -func finalizeReadIterator( - ctx context.Context, - schema *dynamoTableSchema, - iter dynamoReadIterator, - opts readPageOptions, -) (*queryOutput, error) { - state, err := newReadPageState(schema, opts) - if err != nil { - return nil, err - } - if err := state.consumeIterator(ctx, iter); err != nil { - return nil, err - } - return state.output(), nil -} - -func (d *DynamoDBServer) newQueryReadIterator( - in queryInput, - schema *dynamoTableSchema, - keySchema dynamoKeySchema, - cond queryCondition, - readTS uint64, - opts readPageOptions, -) (dynamoReadIterator, bool, error) { - projector := d.readItemProjector(schema, in.IndexName) - filter := itemReadFilter(func(item map[string]attributeValue) bool { - return matchesQueryCondition(item, cond) - }) - pageLimit := readIteratorPageLimit(opts.limit) - bounds, ok, err := resolveQueryReadBounds(schema, in, keySchema, cond, opts.exclusiveStartKey) - if err != nil || !ok { - return nil, ok, err - } - if strings.TrimSpace(in.IndexName) == "" { - return newTableReadIterator(d, bounds, readTS, pageLimit, projector, filter), true, nil - } - return newGSIReadIterator(d, bounds, readTS, pageLimit, projector, filter), true, nil -} - -func (d *DynamoDBServer) newScanReadIterator( - in scanInput, - schema *dynamoTableSchema, - readTS uint64, - opts readPageOptions, -) (dynamoReadIterator, bool, error) { - projector := d.readItemProjector(schema, in.IndexName) - pageLimit := readIteratorPageLimit(opts.limit) - if strings.TrimSpace(in.IndexName) == "" { - bounds, err := resolveTableReadBounds(schema, in.TableName, opts.exclusiveStartKey) - if err != nil { - return nil, false, err - } - return newTableReadIterator(d, bounds, readTS, pageLimit, projector, nil), true, nil - } - if _, err := schema.keySchemaForQuery(in.IndexName); err != nil { - return nil, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - bounds, ok, err := resolveGSIReadBounds(schema, in.TableName, in.IndexName, opts.exclusiveStartKey) - if err != nil { - return nil, false, err - } - if len(opts.exclusiveStartKey) > 0 && !ok { - return nil, false, nil - } - return newGSIReadIterator(d, bounds, readTS, pageLimit, projector, nil), true, nil -} - -func resolveTableReadBounds( - schema *dynamoTableSchema, - tableName string, - startKey map[string]attributeValue, -) (dynamoReadBounds, error) { - lower := dynamoItemPrefixForTable(tableName, schema.Generation) - upper := prefixScanEnd(lower) - if len(startKey) == 0 { - return dynamoReadBounds{lower: lower, upper: upper}, nil - } - key, err := schema.itemKeyFromAttributes(startKey) - if err != nil { - return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") - } - return dynamoReadBounds{lower: maxBytes(lower, nextScanCursor(key)), upper: upper}, nil -} - -func resolveGSIReadBounds( - schema *dynamoTableSchema, - tableName string, - indexName string, - startKey map[string]attributeValue, -) (dynamoReadBounds, bool, error) { - lower := dynamoGSIIndexPrefixForTable(tableName, schema.Generation, indexName) - upper := prefixScanEnd(lower) - if len(startKey) == 0 { - return dynamoReadBounds{lower: lower, upper: upper}, true, nil - } - key, ok, err := schema.gsiKeyFromAttributes(indexName, startKey) - if err != nil { - return dynamoReadBounds{}, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") - } - if !ok { - return dynamoReadBounds{}, false, nil - } - return dynamoReadBounds{lower: maxBytes(lower, nextScanCursor(key)), upper: upper}, true, nil -} - -func resolveQueryReadBounds( - schema *dynamoTableSchema, - in queryInput, - keySchema dynamoKeySchema, - cond queryCondition, - startKey map[string]attributeValue, -) (dynamoReadBounds, bool, error) { - if !schema.usesOrderedKeyEncoding() { - return dynamoReadBounds{}, false, nil - } - basePrefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) - if err != nil { - return dynamoReadBounds{}, false, err - } - bounds := dynamoReadBounds{ - lower: basePrefix, - upper: prefixScanEnd(basePrefix), - reverse: queryReadReverse(in.ScanIndexForward), - } - if keySchema.RangeKey != "" && cond.rangeCond != nil { - bounds, err = refineQueryReadBounds(bounds, basePrefix, *cond.rangeCond) - if err != nil { - return dynamoReadBounds{}, false, err - } - } - if len(startKey) == 0 { - return bounds, true, nil - } - startCursor, ok, err := resolveQueryExclusiveStartKey(schema, in, startKey) - if err != nil { - return dynamoReadBounds{}, false, err - } - if !ok { - return dynamoReadBounds{}, false, nil - } - if bounds.reverse { - bounds.upper = minBytes(bounds.upper, startCursor) - } else { - bounds.lower = maxBytes(bounds.lower, nextScanCursor(startCursor)) - } - return bounds, true, nil -} - -func resolveQueryExclusiveStartKey( - schema *dynamoTableSchema, - in queryInput, - startKey map[string]attributeValue, -) ([]byte, bool, error) { - if len(startKey) == 0 { - return nil, true, nil - } - if strings.TrimSpace(in.IndexName) == "" { - key, err := schema.itemKeyFromAttributes(startKey) - if err != nil { - return nil, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") - } - return key, true, nil - } - key, ok, err := schema.gsiKeyFromAttributes(in.IndexName, startKey) - if err != nil { - return nil, false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") - } - return key, ok, nil -} - -func queryReadReverse(scanIndexForward *bool) bool { - return scanIndexForward != nil && !*scanIndexForward -} - -type readItemProjector func(map[string]attributeValue) (map[string]attributeValue, error) - -func (d *DynamoDBServer) readItemProjector(schema *dynamoTableSchema, indexName string) readItemProjector { - if strings.TrimSpace(indexName) == "" { - return identityReadItemProjector - } - return func(item map[string]attributeValue) (map[string]attributeValue, error) { - return schema.projectItemForIndex(indexName, item) - } -} - -func identityReadItemProjector(item map[string]attributeValue) (map[string]attributeValue, error) { - return item, nil -} - -func finalizeReadPage(schema *dynamoTableSchema, items []map[string]attributeValue, opts readPageOptions) (*queryOutput, error) { - items, err := applyQueryExclusiveStartKey(schema, opts.exclusiveStartKey, items) - if err != nil { - return nil, err - } - state, err := newReadPageState(schema, opts) - if err != nil { - return nil, err - } - if err := state.consume(items); err != nil { - return nil, err - } - return state.output(), nil -} - -type readPageState struct { - schema *dynamoTableSchema - opts readPageOptions - projection []string - filterExpr string - includeItems bool - limit int - hasLimit bool - outItems []map[string]attributeValue - outCount int - scannedCount int - lastEvaluatedKey map[string]attributeValue -} - -type dynamoReadBounds struct { - lower []byte - upper []byte - reverse bool -} - -type keyRangeKVIterator struct { - server *DynamoDBServer - lower []byte - upper []byte - cursor []byte - readTS uint64 - pageLimit int - reverse bool - page []*store.KVPair - index int - done bool -} - -type emptyReadIterator struct{} - -type tableReadIterator struct { - kv *keyRangeKVIterator - projector readItemProjector - filter itemReadFilter -} - -type gsiReadIterator struct { - server *DynamoDBServer - kv *keyRangeKVIterator - readTS uint64 - projector readItemProjector - filter itemReadFilter - seen map[string]struct{} -} - -func newReadPageState(schema *dynamoTableSchema, opts readPageOptions) (*readPageState, error) { - limit, hasLimit, err := resolveReadLimit(opts.limit) - if err != nil { - return nil, err - } - projection, err := resolveProjectionAttributes(opts.projectionExpression, opts.expressionAttributeNames) - if err != nil { - return nil, err - } - filterExpr, err := replaceNames(opts.filterExpression, opts.expressionAttributeNames) - if err != nil { - return nil, err - } - return &readPageState{ - schema: schema, - opts: opts, - projection: projection, - filterExpr: strings.TrimSpace(filterExpr), - includeItems: !strings.EqualFold(strings.TrimSpace(opts.selectValue), dynamoSelectCount), - limit: limit, - hasLimit: hasLimit, - outItems: make([]map[string]attributeValue, 0), - }, nil -} - -func (s *readPageState) consume(items []map[string]attributeValue) error { - for i, item := range items { - if s.reachedLimit() { - break - } - if err := s.consumeItem(i, item, len(items)); err != nil { - return err - } - } - return nil -} - -func (s *readPageState) consumeIterator(ctx context.Context, iter dynamoReadIterator) error { - var lastItem map[string]attributeValue - for !s.reachedLimit() { - item, ok, err := iter.Next(ctx) - if err != nil { - return errors.WithStack(err) - } - if !ok { - return nil - } - if err := s.consumeReadItem(item); err != nil { - return err - } - lastItem = item - } - if lastItem == nil { - return nil - } - if nextItem, ok, err := iter.Next(ctx); err != nil { - return errors.WithStack(err) - } else if ok && nextItem != nil { - s.lastEvaluatedKey = s.buildLastEvaluatedKey(lastItem) - } - return nil -} - -func (s *readPageState) reachedLimit() bool { - return s.hasLimit && s.scannedCount == s.limit -} - -func (s *readPageState) consumeReadItem(item map[string]attributeValue) error { - s.scannedCount++ - match, err := matchesReadFilter(s.filterExpr, item, s.opts.expressionAttributeValues) - if err != nil { - return err - } - if match { - s.recordMatch(item) - } - return nil -} - -func (s *readPageState) consumeItem(i int, item map[string]attributeValue, totalItems int) error { - if err := s.consumeReadItem(item); err != nil { - return err - } - if s.shouldSetLastEvaluatedKey(i, totalItems) { - s.lastEvaluatedKey = s.buildLastEvaluatedKey(item) - } - return nil -} - -func (s *readPageState) recordMatch(item map[string]attributeValue) { - s.outCount++ - if !s.includeItems { - return - } - s.outItems = append(s.outItems, projectItemByAttributes(item, s.projection)) -} - -func (s *readPageState) shouldSetLastEvaluatedKey(i int, totalItems int) bool { - return s.hasLimit && s.scannedCount == s.limit && i < totalItems-1 -} - -func (s *readPageState) buildLastEvaluatedKey(item map[string]attributeValue) map[string]attributeValue { - if s.opts.lastEvaluatedKeyBuilder != nil { - return s.opts.lastEvaluatedKeyBuilder(item) - } - return makeLastEvaluatedKey(s.schema.PrimaryKey, item) -} - -func (s *readPageState) output() *queryOutput { - items := s.outItems - if !s.includeItems { - items = nil - } - return &queryOutput{ - items: items, - count: s.outCount, - scannedCount: s.scannedCount, - lastEvaluatedKey: s.lastEvaluatedKey, - } -} - -func (emptyReadIterator) Next(context.Context) (map[string]attributeValue, bool, error) { - return nil, false, nil -} - -func newKeyRangeKVIterator( - server *DynamoDBServer, - bounds dynamoReadBounds, - readTS uint64, - pageLimit int, -) *keyRangeKVIterator { - cursor := bytes.Clone(bounds.lower) - if bounds.reverse { - cursor = bytes.Clone(bounds.upper) - } - return &keyRangeKVIterator{ - server: server, - lower: bytes.Clone(bounds.lower), - upper: bytes.Clone(bounds.upper), - cursor: cursor, - readTS: readTS, - pageLimit: pageLimit, - reverse: bounds.reverse, - } -} - -func (it *keyRangeKVIterator) Next(ctx context.Context) (*store.KVPair, bool, error) { - for { - if it.index < len(it.page) { - kvp := it.page[it.index] - it.index++ - return kvp, true, nil - } - if it.done { - return nil, false, nil - } - if err := it.loadNextPage(ctx); err != nil { - return nil, false, err - } - } -} - -func (it *keyRangeKVIterator) loadNextPage(ctx context.Context) error { - if it.reverse { - return it.loadNextPageReverse(ctx) - } - return it.loadNextPageForward(ctx) -} - -func (it *keyRangeKVIterator) loadNextPageForward(ctx context.Context) error { - kvs, err := it.server.store.ScanAt(ctx, it.cursor, it.upper, it.pageLimit, it.readTS) - if err != nil { - return errors.WithStack(err) - } - if len(kvs) == 0 { - it.done = true - it.page = nil - return nil - } - it.page, it.done = filterBoundedKVPairsForward(kvs, it.lower, it.upper, it.pageLimit) - it.index = 0 - if !it.done { - it.cursor = nextScanCursor(kvs[len(kvs)-1].Key) - if it.upper != nil && bytes.Compare(it.cursor, it.upper) >= 0 { - it.done = true - } - } - return nil -} - -func (it *keyRangeKVIterator) loadNextPageReverse(ctx context.Context) error { - kvs, err := it.server.store.ReverseScanAt(ctx, it.lower, it.cursor, it.pageLimit, it.readTS) - if err != nil { - return errors.WithStack(err) - } - if len(kvs) == 0 { - it.done = true - it.page = nil - return nil - } - it.page, it.done = filterBoundedKVPairsReverse(kvs, it.lower, it.cursor, it.pageLimit) - it.index = 0 - if !it.done { - it.cursor = bytes.Clone(kvs[len(kvs)-1].Key) - } - return nil -} - -func filterBoundedKVPairsForward(kvs []*store.KVPair, lower []byte, upper []byte, pageLimit int) ([]*store.KVPair, bool) { - page := make([]*store.KVPair, 0, minInt(len(kvs), pageLimit)) - done := len(kvs) < pageLimit - for _, kvp := range kvs { - if lower != nil && bytes.Compare(kvp.Key, lower) < 0 { - continue - } - if upper != nil && bytes.Compare(kvp.Key, upper) >= 0 { - done = true - break - } - page = append(page, kvp) - } - if len(page) == 0 { - done = true - } - return page, done -} - -func filterBoundedKVPairsReverse(kvs []*store.KVPair, lower []byte, upper []byte, pageLimit int) ([]*store.KVPair, bool) { - page := make([]*store.KVPair, 0, minInt(len(kvs), pageLimit)) - done := len(kvs) < pageLimit - for _, kvp := range kvs { - if lower != nil && bytes.Compare(kvp.Key, lower) < 0 { - done = true - break - } - if upper != nil && bytes.Compare(kvp.Key, upper) >= 0 { - continue - } - page = append(page, kvp) - } - if len(page) == 0 { - done = true - } - return page, done -} - -func newTableReadIterator( - server *DynamoDBServer, - bounds dynamoReadBounds, - readTS uint64, - pageLimit int, - projector readItemProjector, - filter itemReadFilter, -) dynamoReadIterator { - if bounds.upper != nil && bytes.Compare(bounds.lower, bounds.upper) >= 0 { - return emptyReadIterator{} - } - return &tableReadIterator{ - kv: newKeyRangeKVIterator(server, bounds, readTS, pageLimit), - projector: projector, - filter: filter, - } -} - -func (it *tableReadIterator) Next(ctx context.Context) (map[string]attributeValue, bool, error) { - for { - kvp, ok, err := it.kv.Next(ctx) - if err != nil || !ok { - return nil, ok, err - } - item, err := decodeStoredDynamoItem(kvp.Value) - if err != nil { - return nil, false, err - } - item, err = it.projector(item) - if err != nil { - return nil, false, err - } - if it.filter != nil && !it.filter(item) { - continue - } - return item, true, nil - } -} - -func newGSIReadIterator( - server *DynamoDBServer, - bounds dynamoReadBounds, - readTS uint64, - pageLimit int, - projector readItemProjector, - filter itemReadFilter, -) dynamoReadIterator { - if bounds.upper != nil && bytes.Compare(bounds.lower, bounds.upper) >= 0 { - return emptyReadIterator{} - } - return &gsiReadIterator{ - server: server, - kv: newKeyRangeKVIterator(server, bounds, readTS, pageLimit), - readTS: readTS, - projector: projector, - filter: filter, - seen: map[string]struct{}{}, - } -} - -func (it *gsiReadIterator) Next(ctx context.Context) (map[string]attributeValue, bool, error) { - for { - kvp, ok, err := it.kv.Next(ctx) - if err != nil || !ok { - return nil, ok, err - } - itemKey := string(kvp.Value) - if _, exists := it.seen[itemKey]; exists { - continue - } - it.seen[itemKey] = struct{}{} - item, found, err := it.server.readItemAtKeyAt(ctx, kvp.Value, it.readTS) - if err != nil { - return nil, false, err - } - if !found { - continue - } - item, err = it.projector(item) - if err != nil { - return nil, false, err - } - if it.filter != nil && !it.filter(item) { - continue - } - return item, true, nil - } -} - -func matchesReadFilter(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - if strings.TrimSpace(expr) == "" { - return true, nil - } - ok, err := evalConditionExpression(expr, item, values) - if err != nil { - return false, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return ok, nil -} - -func resolveProjectionAttributes(expr string, names map[string]string) ([]string, error) { - projectionExpr, err := replaceNames(expr, names) - if err != nil { - return nil, err - } - projection := strings.TrimSpace(projectionExpr) - if projection == "" { - return nil, nil - } - parts, err := splitTopLevelByComma(projection) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ProjectionExpression") - } - attrs := make([]string, 0, len(parts)) - for _, part := range parts { - attr := strings.TrimSpace(part) - if attr == "" { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ProjectionExpression") - } - attrs = append(attrs, attr) - } - return attrs, nil -} - -func projectItem(item map[string]attributeValue, expr string, names map[string]string) (map[string]attributeValue, error) { - attrs, err := resolveProjectionAttributes(expr, names) - if err != nil { - return nil, err - } - return projectItemByAttributes(item, attrs), nil -} - -func projectItemByAttributes(item map[string]attributeValue, attrs []string) map[string]attributeValue { - if len(attrs) == 0 { - return cloneAttributeValueMap(item) - } - out := make(map[string]attributeValue, len(attrs)) - for _, attr := range attrs { - if value, ok := item[attr]; ok { - out[attr] = value - } - } - return out -} - -func decodeItemsFromKVPairs(kvs []*store.KVPair) ([]map[string]attributeValue, error) { - items := make([]map[string]attributeValue, 0, len(kvs)) - for _, kvp := range kvs { - item, err := decodeStoredDynamoItem(kvp.Value) - if err != nil { - return nil, err - } - items = append(items, item) - } - return items, nil -} - -func (d *DynamoDBServer) queryItemsByKeyCondition( - ctx context.Context, - in queryInput, - schema *dynamoTableSchema, - keySchema dynamoKeySchema, - cond queryCondition, - readTS uint64, -) ([]map[string]attributeValue, error) { - if strings.TrimSpace(in.IndexName) != "" { - return d.queryItemsByGSI(ctx, in, schema, cond, readTS) - } - scanPrefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) - if err != nil { - return nil, err - } - kvs, err := d.scanAllByPrefixAt(ctx, scanPrefix, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - items, err := filterQueryItems(kvs, cond) - if err != nil { - return nil, err - } - return items, nil -} - -func (d *DynamoDBServer) queryItemsByGSI( - ctx context.Context, - in queryInput, - schema *dynamoTableSchema, - cond queryCondition, - readTS uint64, -) ([]map[string]attributeValue, error) { - keySchema, err := schema.keySchemaForQuery(in.IndexName) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - prefix, err := queryScanPrefix(schema, in, keySchema, cond.hashValue) - if err != nil { - return nil, err - } - kvs, err := d.scanAllByPrefixAt(ctx, prefix, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - itemKeys := uniqueGSIItemKeys(kvs) - items, err := d.readItemsForGSIQuery(ctx, itemKeys, readTS, cond) - if err != nil { - return nil, err - } - return items, nil -} - -func (d *DynamoDBServer) scanItemsBySource( - ctx context.Context, - in scanInput, - schema *dynamoTableSchema, - readTS uint64, -) ([]map[string]attributeValue, error) { - if strings.TrimSpace(in.IndexName) == "" { - kvs, err := d.scanAllByPrefixAt(ctx, dynamoItemPrefixForTable(in.TableName, schema.Generation), readTS) - if err != nil { - return nil, errors.WithStack(err) - } - return decodeItemsFromKVPairs(kvs) - } - if _, err := schema.keySchemaForQuery(in.IndexName); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - kvs, err := d.scanAllByPrefixAt(ctx, dynamoGSIIndexPrefixForTable(in.TableName, schema.Generation, in.IndexName), readTS) - if err != nil { - return nil, errors.WithStack(err) - } - itemKeys := uniqueGSIItemKeys(kvs) - return d.readItemsAtKeys(ctx, itemKeys, readTS) -} - -func uniqueGSIItemKeys(kvs []*store.KVPair) [][]byte { - if len(kvs) == 0 { - return nil - } - out := make([][]byte, 0, len(kvs)) - seen := make(map[string]struct{}, len(kvs)) - for _, kvp := range kvs { - key := string(kvp.Value) - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - out = append(out, bytes.Clone(kvp.Value)) - } - return out -} - -type gsiReadJob struct { - index int - key []byte -} - -type gsiReadResult struct { - index int - item map[string]attributeValue - err error -} - -type itemReadFilter func(map[string]attributeValue) bool - -func resolveGSIReadWorkerCount(n int) int { - if n <= 0 { - return 0 - } - if n < gsiQueryReadWorkerCount { - return n - } - return gsiQueryReadWorkerCount -} - -func (d *DynamoDBServer) readItemsForGSIQuery( - ctx context.Context, - itemKeys [][]byte, - readTS uint64, - cond queryCondition, -) ([]map[string]attributeValue, error) { - return d.readItemsAtKeysMatching(ctx, itemKeys, readTS, func(item map[string]attributeValue) bool { - return matchesQueryCondition(item, cond) - }) -} - -func (d *DynamoDBServer) readItemsAtKeys( - ctx context.Context, - itemKeys [][]byte, - readTS uint64, -) ([]map[string]attributeValue, error) { - return d.readItemsAtKeysMatching(ctx, itemKeys, readTS, nil) -} - -func (d *DynamoDBServer) readItemsAtKeysMatching( - ctx context.Context, - itemKeys [][]byte, - readTS uint64, - filter itemReadFilter, -) ([]map[string]attributeValue, error) { - if len(itemKeys) == 0 { - return nil, nil - } - workerCount := resolveGSIReadWorkerCount(len(itemKeys)) - jobs := make(chan gsiReadJob) - results := make(chan gsiReadResult, len(itemKeys)) - workerCtx, cancel := context.WithCancel(ctx) - defer cancel() - - var wg sync.WaitGroup - d.startGSIReadWorkers(&wg, workerCount, workerCtx, readTS, filter, jobs, results, cancel) - enqueueGSIReadJobs(workerCtx, jobs, itemKeys) - close(jobs) - wg.Wait() - close(results) - - return collectOrderedGSIReadResults(itemKeys, results) -} - -func (d *DynamoDBServer) startGSIReadWorkers( - wg *sync.WaitGroup, - workerCount int, - ctx context.Context, - readTS uint64, - filter itemReadFilter, - jobs <-chan gsiReadJob, - results chan<- gsiReadResult, - cancel context.CancelFunc, -) { - for range workerCount { - wg.Go(func() { - d.gsiReadWorker(ctx, readTS, filter, jobs, results, cancel) - }) - } -} - -func enqueueGSIReadJobs(ctx context.Context, jobs chan<- gsiReadJob, itemKeys [][]byte) { -enqueueLoop: - for i, key := range itemKeys { - select { - case <-ctx.Done(): - break enqueueLoop - case jobs <- gsiReadJob{index: i, key: key}: - } - } -} - -func collectOrderedGSIReadResults( - itemKeys [][]byte, - results <-chan gsiReadResult, -) ([]map[string]attributeValue, error) { - indexed := make(map[int]map[string]attributeValue, len(itemKeys)) - for res := range results { - if res.err != nil { - return nil, res.err - } - if res.item != nil { - indexed[res.index] = res.item - } - } - items := make([]map[string]attributeValue, 0, len(indexed)) - for i := range itemKeys { - if item := indexed[i]; item != nil { - items = append(items, item) - } - } - return items, nil -} - -func (d *DynamoDBServer) gsiReadWorker( - ctx context.Context, - readTS uint64, - filter itemReadFilter, - jobs <-chan gsiReadJob, - results chan<- gsiReadResult, - cancel context.CancelFunc, -) { - for { - job, ok := nextGSIReadJob(ctx, jobs) - if !ok { - return - } - item, emit, err := d.executeGSIReadJob(ctx, readTS, filter, job.key) - if err != nil { - sendGSIReadError(results, err) - cancel() - return - } - if !emit { - continue - } - if !sendGSIReadResult(ctx, results, gsiReadResult{index: job.index, item: item}) { - return - } - } -} - -func nextGSIReadJob(ctx context.Context, jobs <-chan gsiReadJob) (gsiReadJob, bool) { - select { - case <-ctx.Done(): - return gsiReadJob{}, false - case job, ok := <-jobs: - if !ok { - return gsiReadJob{}, false - } - return job, true - } -} - -func (d *DynamoDBServer) executeGSIReadJob( - ctx context.Context, - readTS uint64, - filter itemReadFilter, - key []byte, -) (map[string]attributeValue, bool, error) { - item, found, err := d.readItemAtKeyAt(ctx, key, readTS) - if err != nil { - return nil, false, err - } - if !found { - return nil, false, nil - } - if filter != nil && !filter(item) { - return nil, false, nil - } - return item, true, nil -} - -func sendGSIReadError(results chan<- gsiReadResult, err error) { - select { - case results <- gsiReadResult{err: err}: - default: - } -} - -func sendGSIReadResult(ctx context.Context, results chan<- gsiReadResult, result gsiReadResult) bool { - select { - case results <- result: - return true - case <-ctx.Done(): - return false - } -} - -func queryScanPrefix(schema *dynamoTableSchema, in queryInput, keySchema dynamoKeySchema, hashValue attributeValue) ([]byte, error) { - if !schema.usesOrderedKeyEncoding() { - hashKey, err := attributeValueAsKey(hashValue) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.IndexName) != "" { - return legacyDynamoGSIHashPrefixForTable(in.TableName, schema.Generation, in.IndexName, hashKey), nil - } - if keySchema.HashKey != schema.PrimaryKey.HashKey { - return dynamoItemPrefixForTable(in.TableName, schema.Generation), nil - } - return legacyDynamoItemHashPrefixForTable(in.TableName, schema.Generation, hashKey), nil - } - hashKey, err := attributeValueAsKeySegment(hashValue) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if strings.TrimSpace(in.IndexName) != "" { - return dynamoGSIHashPrefixForTable(in.TableName, schema.Generation, in.IndexName, hashKey), nil - } - if keySchema.HashKey != schema.PrimaryKey.HashKey { - return dynamoItemPrefixForTable(in.TableName, schema.Generation), nil - } - return dynamoItemHashPrefixForTable(in.TableName, schema.Generation, hashKey), nil -} - -func refineQueryReadBounds( - bounds dynamoReadBounds, - basePrefix []byte, - cond queryRangeCondition, -) (dynamoReadBounds, error) { - switch cond.op { - case queryRangeOpEqual, queryRangeOpLessThan, queryRangeOpLessOrEq, queryRangeOpGreater, queryRangeOpGreaterEq: - return refineQueryComparisonBounds(bounds, basePrefix, cond) - case queryRangeOpBetween: - return refineQueryBetweenBounds(bounds, basePrefix, cond) - case queryRangeOpBeginsWith: - return refineQueryBeginsWithBounds(bounds, basePrefix, cond.value1) - default: - return bounds, nil - } -} - -func refineQueryComparisonBounds( - bounds dynamoReadBounds, - basePrefix []byte, - cond queryRangeCondition, -) (dynamoReadBounds, error) { - prefix, err := appendRangeConditionPrefix(basePrefix, cond.value1) - if err != nil { - return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if cond.op == queryRangeOpEqual { - bounds.lower = prefix - bounds.upper = prefixScanEnd(prefix) - return bounds, nil - } - if cond.op == queryRangeOpLessThan { - bounds.upper = minBytes(bounds.upper, prefix) - return bounds, nil - } - if cond.op == queryRangeOpLessOrEq { - bounds.upper = minBytes(bounds.upper, prefixScanEnd(prefix)) - return bounds, nil - } - if cond.op == queryRangeOpGreater { - bounds.lower = maxBytes(bounds.lower, prefixScanEnd(prefix)) - return bounds, nil - } - if cond.op == queryRangeOpGreaterEq { - bounds.lower = maxBytes(bounds.lower, prefix) - return bounds, nil - } - return bounds, nil -} - -func refineQueryBetweenBounds( - bounds dynamoReadBounds, - basePrefix []byte, - cond queryRangeCondition, -) (dynamoReadBounds, error) { - lower, err := appendRangeConditionPrefix(basePrefix, cond.value1) - if err != nil { - return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - upper, err := appendRangeConditionPrefix(basePrefix, cond.value2) - if err != nil { - return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - bounds.lower = maxBytes(bounds.lower, lower) - bounds.upper = minBytes(bounds.upper, prefixScanEnd(upper)) - return bounds, nil -} - -func refineQueryBeginsWithBounds( - bounds dynamoReadBounds, - basePrefix []byte, - value attributeValue, -) (dynamoReadBounds, error) { - prefix, err := appendRangeConditionPrefixMatch(basePrefix, value) - if err != nil { - return dynamoReadBounds{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - bounds.lower = maxBytes(bounds.lower, prefix) - bounds.upper = minBytes(bounds.upper, prefixScanEnd(prefix)) - return bounds, nil -} - -func appendRangeConditionPrefix(basePrefix []byte, value attributeValue) ([]byte, error) { - segment, err := attributeValueAsKeySegment(value) - if err != nil { - return nil, err - } - return append(bytes.Clone(basePrefix), segment...), nil -} - -func appendRangeConditionPrefixMatch(basePrefix []byte, value attributeValue) ([]byte, error) { - raw, err := attributeValueAsKeyBytes(value) - if err != nil { - return nil, err - } - segment := encodeDynamoKeySegmentPrefix(raw) - return append(bytes.Clone(basePrefix), segment...), nil -} - -func maxBytes(left []byte, right []byte) []byte { - if left == nil { - return bytes.Clone(right) - } - if right == nil { - return bytes.Clone(left) - } - if bytes.Compare(left, right) >= 0 { - return bytes.Clone(left) - } - return bytes.Clone(right) -} - -func minBytes(left []byte, right []byte) []byte { - if left == nil { - return bytes.Clone(right) - } - if right == nil { - return bytes.Clone(left) - } - if bytes.Compare(left, right) <= 0 { - return bytes.Clone(left) - } - return bytes.Clone(right) -} - -func applyQueryExclusiveStartKey(schema *dynamoTableSchema, startKey map[string]attributeValue, items []map[string]attributeValue) ([]map[string]attributeValue, error) { - if len(startKey) == 0 { - return items, nil - } - startItemKey, err := schema.itemKeyFromAttributes(startKey) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid ExclusiveStartKey") - } - descending, hasDirection := queryItemOrderDirection(schema, items) - for i, item := range items { - if remaining, ok := exclusiveStartRemainingItems(schema, item, items, i, startItemKey, descending, hasDirection); ok { - return remaining, nil - } - } - return nil, nil -} - -func exclusiveStartRemainingItems( - schema *dynamoTableSchema, - item map[string]attributeValue, - items []map[string]attributeValue, - index int, - startItemKey []byte, - descending bool, - hasDirection bool, -) ([]map[string]attributeValue, bool) { - itemKey, err := schema.itemKeyFromAttributes(item) - if err != nil { - return nil, false - } - if bytes.Equal(itemKey, startItemKey) { - return items[index+1:], true - } - if !hasDirection || !exclusiveStartShouldAdvance(descending, itemKey, startItemKey) { - return nil, false - } - return items[index:], true -} - -func exclusiveStartShouldAdvance(descending bool, itemKey []byte, startItemKey []byte) bool { - cmp := bytes.Compare(itemKey, startItemKey) - return (!descending && cmp > 0) || (descending && cmp < 0) -} - -func queryItemOrderDirection(schema *dynamoTableSchema, items []map[string]attributeValue) (bool, bool) { - var previous []byte - for _, item := range items { - itemKey, err := schema.itemKeyFromAttributes(item) - if err != nil { - continue - } - if previous == nil { - previous = itemKey - continue - } - cmp := bytes.Compare(itemKey, previous) - if cmp == 0 { - continue - } - return cmp < 0, true - } - return false, false -} - -func makeLastEvaluatedKey(keySchema dynamoKeySchema, item map[string]attributeValue) map[string]attributeValue { - out := map[string]attributeValue{} - if hash, ok := item[keySchema.HashKey]; ok { - out[keySchema.HashKey] = hash - } - if keySchema.RangeKey != "" { - if rk, ok := item[keySchema.RangeKey]; ok { - out[keySchema.RangeKey] = rk - } - } - if len(out) == 0 { - return nil - } - return out -} - -func makeReadLastEvaluatedKey(primary dynamoKeySchema, index dynamoKeySchema, item map[string]attributeValue) map[string]attributeValue { - out := makeLastEvaluatedKey(primary, item) - if len(out) == 0 { - out = map[string]attributeValue{} - } - if hash, ok := item[index.HashKey]; ok { - out[index.HashKey] = hash - } - if index.RangeKey != "" { - if rk, ok := item[index.RangeKey]; ok { - out[index.RangeKey] = rk - } - } - if len(out) == 0 { - return nil - } - return out -} - -func (d *DynamoDBServer) batchWriteItem(w http.ResponseWriter, r *http.Request) { - in, err := decodeBatchWriteItemInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - unprocessed, err := d.batchWriteItems(r.Context(), in) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - for table, written := range batchWriteCommittedCounts(in, unprocessed) { - d.observeWrittenItems(r.Context(), table, written) - } - writeDynamoJSON(w, map[string]any{"UnprocessedItems": unprocessed}) -} - -func decodeBatchWriteItemInput(bodyReader io.Reader) (batchWriteItemInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in batchWriteItemInput - if err := json.Unmarshal(body, &in); err != nil { - return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if len(in.RequestItems) == 0 { - return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing RequestItems") - } - total := 0 - for tableName, requests := range in.RequestItems { - if strings.TrimSpace(tableName) == "" { - return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - total += len(requests) - } - if total == 0 { - return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing write requests") - } - if total > batchWriteItemMaxItems { - return batchWriteItemInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "too many items in BatchWriteItem") - } - return in, nil -} - -func batchWriteCommittedCounts(in batchWriteItemInput, unprocessed map[string][]batchWriteRequest) map[string]int { - out := make(map[string]int, len(in.RequestItems)) - for table, requests := range in.RequestItems { - written := len(requests) - len(unprocessed[table]) - if written > 0 { - out[table] = written - } - } - return out -} - -func (d *DynamoDBServer) batchWriteItems( - ctx context.Context, - in batchWriteItemInput, -) (map[string][]batchWriteRequest, error) { - tableNames := make([]string, 0, len(in.RequestItems)) - for tableName := range in.RequestItems { - tableNames = append(tableNames, tableName) - } - sort.Strings(tableNames) - unlock := d.lockTableOperations(tableNames) - defer unlock() - for _, tableName := range tableNames { - if err := d.ensureLegacyTableMigrationLocked(ctx, tableName); err != nil { - return nil, err - } - } - if err := d.validateBatchWriteRequests(ctx, tableNames, in.RequestItems); err != nil { - return nil, err - } - unprocessed := make(map[string][]batchWriteRequest) - for _, tableName := range tableNames { - requests := in.RequestItems[tableName] - for _, request := range requests { - err := d.executeBatchWriteRequest(ctx, tableName, request) - if err == nil { - continue - } - if ctx.Err() != nil { - return nil, errors.WithStack(ctx.Err()) - } - unprocessed[tableName] = append(unprocessed[tableName], request) - } - } - return unprocessed, nil -} - -func (d *DynamoDBServer) validateBatchWriteRequests( - ctx context.Context, - tableNames []string, - requestItems map[string][]batchWriteRequest, -) error { - for _, tableName := range tableNames { - schema, exists, err := d.loadTableSchema(ctx, tableName) - if err != nil { - return errors.WithStack(err) - } - if !exists { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - seenKeys := make(map[string]struct{}, len(requestItems[tableName])) - for _, request := range requestItems[tableName] { - key, err := validateBatchWriteRequestForSchema(schema, request) - if err != nil { - return err - } - if _, ok := seenKeys[string(key)]; ok { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "duplicate item in BatchWriteItem") - } - seenKeys[string(key)] = struct{}{} - } - } - return nil -} - -func validateBatchWriteRequestForSchema(schema *dynamoTableSchema, request batchWriteRequest) ([]byte, error) { - switch countBatchWriteActions(request) { - case 1: - default: - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") - } - switch { - case request.PutRequest != nil: - key, err := schema.itemKeyFromAttributes(request.PutRequest.Item) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return key, nil - case request.DeleteRequest != nil: - key, err := schema.itemKeyFromAttributes(request.DeleteRequest.Key) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return key, nil - default: - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") - } -} - -func (d *DynamoDBServer) executeBatchWriteRequest( - ctx context.Context, - tableName string, - request batchWriteRequest, -) error { - schema, exists, err := d.loadTableSchema(ctx, tableName) - if err != nil { - return errors.WithStack(err) - } - if !exists { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - keyAttrs, err := batchWriteRequestKey(schema, request) - if err != nil { - return err - } - lockKey, err := dynamoItemUpdateLockKey(tableName, keyAttrs) - if err != nil { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - unlock := d.lockItemUpdate(lockKey) - defer unlock() - switch countBatchWriteActions(request) { - case 1: - default: - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") - } - switch { - case request.PutRequest != nil: - _, err := d.putItemWithRetry(ctx, putItemInput{ - TableName: tableName, - Item: request.PutRequest.Item, - }) - return err - case request.DeleteRequest != nil: - _, err := d.deleteItemWithRetry(ctx, deleteItemInput{ - TableName: tableName, - Key: request.DeleteRequest.Key, - }) - return err - default: - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") - } -} - -func batchWriteRequestKey(schema *dynamoTableSchema, request batchWriteRequest) (map[string]attributeValue, error) { - switch { - case request.PutRequest != nil: - return primaryKeyAttributes(schema.PrimaryKey, request.PutRequest.Item) - case request.DeleteRequest != nil: - return primaryKeyAttributes(schema.PrimaryKey, request.DeleteRequest.Key) - default: - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "invalid batch write request") - } -} - -func primaryKeyAttributes(keySchema dynamoKeySchema, attrs map[string]attributeValue) (map[string]attributeValue, error) { - out := make(map[string]attributeValue, primaryKeyAttributeCapacity(keySchema)) - hash, ok := attrs[keySchema.HashKey] - if !ok { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing hash key attribute") - } - out[keySchema.HashKey] = hash - if keySchema.RangeKey != "" { - rangeValue, ok := attrs[keySchema.RangeKey] - if !ok { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing range key attribute") - } - out[keySchema.RangeKey] = rangeValue - } - return out, nil -} - -func primaryKeyAttributeCapacity(keySchema dynamoKeySchema) int { - size := 1 - if keySchema.RangeKey != "" { - size++ - } - return size -} - -func countBatchWriteActions(request batchWriteRequest) int { - count := 0 - if request.PutRequest != nil { - count++ - } - if request.DeleteRequest != nil { - count++ - } - return count -} - -func (d *DynamoDBServer) transactWriteItems(w http.ResponseWriter, r *http.Request) { - in, err := decodeTransactWriteItemsInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - if err := d.transactWriteItemsWithRetry(r.Context(), in); err != nil { - writeDynamoErrorFromErr(w, err) - return - } - for table, written := range transactWriteWrittenCounts(in) { - d.observeWrittenItems(r.Context(), table, written) - } - writeDynamoJSON(w, map[string]any{}) -} - -func decodeTransactWriteItemsInput(bodyReader io.Reader) (transactWriteItemsInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return transactWriteItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in transactWriteItemsInput - if err := json.Unmarshal(body, &in); err != nil { - return transactWriteItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if len(in.TransactItems) == 0 { - return transactWriteItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing transact items") - } - return in, nil -} - -// transactGetItems implements TransactGetItems: reads multiple items atomically -// at a single snapshot timestamp, guaranteeing a consistent view across all keys. -func (d *DynamoDBServer) transactGetItems(w http.ResponseWriter, r *http.Request) { - in, err := decodeTransactGetItemsInput(maxDynamoBodyReader(w, r)) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - - // Lease-check every shard this transaction will read BEFORE the single - // snapshot timestamp is resolved, so the quorum-freshness bound is - // established without changing the single-snapshot-ts semantics. The - // timestamp below is still sampled once and shared by all items. - if !d.leaseCheckTransactGetItems(w, r, in) { - return - } - - // Acquire a single read timestamp for all items to guarantee a consistent snapshot. - readTS := d.nextTxnReadTS() - pin := d.pinReadTS(readTS) - defer pin.Release() - - responses, tableMetrics, err := d.buildTransactGetItemsResponses(r.Context(), in, readTS) - if err != nil { - writeDynamoErrorFromErr(w, err) - return - } - for table, m := range tableMetrics { - d.observeReadMetrics(r.Context(), table, m.found, m.requested) - } - writeDynamoJSON(w, map[string]any{"Responses": responses}) -} - -// leaseCheckTransactGetItems performs a quorum-freshness lease check on every -// shard the TransactGetItems request will read, with a bounded timeout, BEFORE -// the caller resolves the single snapshot timestamp. Item keys are resolved at a -// tentative timestamp (schemas change rarely, so a slight pre-lease stale schema -// is acceptable) used only to route the lease check; the actual snapshot -// timestamp is sampled by the caller afterwards. Items whose schema or key -// cannot be resolved here are skipped: they never reach a store read, and -// buildTransactGetItemsResponses surfaces the identical validation error -// downstream so error mapping is unchanged. When every item is skipped no -// shard is touched, so the function returns true without a lease read. -// -// Keys are first deduplicated by value, then collapsed to one representative key -// per owning Raft group, so a transaction touching up to transactGetItemsMaxItems -// keys that share a group issues a single lease read instead of one per key. -// Each group maintains its own lease, so checking one key per group still -// establishes freshness for every shard the transaction reads. Returns false -// after writing the same InternalServerError getItem produces on lease failure; -// the caller should simply return. -func (d *DynamoDBServer) leaseCheckTransactGetItems(w http.ResponseWriter, r *http.Request, in transactGetItemsInput) bool { - // leaseCtx bounds the entire pre-pass — both the per-item schema reads - // that resolve keys and the lease reads themselves — so a stalled - // schema read (Pebble backpressure, iterator leak) cannot block the - // handler past dynamoLeaseReadTimeout before the lease phase begins. - leaseCtx, leaseCancel := context.WithTimeout(r.Context(), dynamoLeaseReadTimeout) - defer leaseCancel() - uniqueKeys, skipLease, transientErr := d.resolveTransactGetItemKeys(leaseCtx, in) - if transientErr != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, transientErr.Error()) - return false - } - if skipLease { - return true - } - groupKeys := kv.LeaseReadGroupKeys(d.coordinator, uniqueKeys) - if leaseErr := d.leaseReadGroupKeys(leaseCtx, groupKeys); leaseErr != nil { - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, leaseErr.Error()) - return false - } - return true -} - -// resolveTransactGetItemKeys runs the per-item schema resolution that the -// lease pre-pass needs. Returns (uniqueKeys, skipLease, transientErr) where -// skipLease is true when the read path will surface a deterministic 4xx via -// buildTransactGetItemsResponses without touching any store — leasing the -// valid items in that case only risks masking that 4xx with a degraded-shard -// 500 (codex P2 #952). skipLease covers three cases: -// - (a) every item was malformed (nothing to fence) -// - (b) at least one item was malformed and at least one was valid -// (buildTransactGetItemsResponses returns a ValidationException for the -// malformed item; the valid items never reach a store read) -// - (c) the request contains a duplicate (table, key) pair — DynamoDB -// rejects this with `Transaction request cannot include multiple -// operations on one item`, a deterministic ValidationException the read -// path produces before touching data, so the lease must be skipped for -// the same reason malformed-mixed-with-valid is skipped. -// -// transientErr is the schema-read failure the caller MUST fail closed on -// (CLAUDE.md: the slow conditions the fence targets are exactly when a -// silently-dropped item would let a stale snapshot through). -func (d *DynamoDBServer) resolveTransactGetItemKeys(ctx context.Context, in transactGetItemsInput) ([][]byte, bool, error) { - tentativeTS := snapshotTS(d.coordinator.Clock(), d.store) - schemaCache := make(map[string]*dynamoTableSchema) - seenKeys := make(map[string]struct{}, len(in.TransactItems)) - uniqueKeys := make([][]byte, 0, len(in.TransactItems)) - hasMalformed := false - hasDuplicate := false - for _, item := range in.TransactItems { - itemKey, ok, err := d.transactGetItemKey(ctx, item, schemaCache, tentativeTS) - if err != nil { - return nil, false, err - } - if !ok { - hasMalformed = true - continue - } - if _, dup := seenKeys[string(itemKey)]; dup { - hasDuplicate = true - continue - } - seenKeys[string(itemKey)] = struct{}{} - uniqueKeys = append(uniqueKeys, itemKey) - } - if hasMalformed || hasDuplicate || len(uniqueKeys) == 0 { - return nil, true, nil - } - return uniqueKeys, false, nil -} - -// leaseReadGroupKeys fences every group whose key appears in groupKeys. The -// single-group case stays on the calling goroutine; multi-group fan-out is -// concurrent so a 100-item TransactGetItems does not serialize into 100 Raft -// round-trips and blow dynamoLeaseReadTimeout (gemini HIGH on PR #952). The -// fan-out is bounded by len(groupKeys) ≤ transactGetItemsMaxItems (100), so a -// per-call goroutine pool is unnecessary. Returns the first error seen across -// all goroutines (the rest are dropped to preserve the single-response -// contract at the HTTP layer). -func (d *DynamoDBServer) leaseReadGroupKeys(ctx context.Context, groupKeys [][]byte) error { - if len(groupKeys) == 0 { - return nil - } - if len(groupKeys) == 1 { - _, err := kv.LeaseReadForKeyThrough(d.coordinator, ctx, groupKeys[0]) - return errors.WithStack(err) - } - // Derive a cancellable child so the first error cancels the sibling lease - // reads instead of letting them run out the full dynamoLeaseReadTimeout - // budget (coderabbit Major on PR #952 round-4). The siblings observe the - // cancellation via the LeaseReadForKeyThrough's own context check. - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - errCh := make(chan error, len(groupKeys)) - var wg sync.WaitGroup - for _, itemKey := range groupKeys { - wg.Add(1) - go func(k []byte) { - defer wg.Done() - if _, err := kv.LeaseReadForKeyThrough(d.coordinator, cancelCtx, k); err != nil { - select { - case errCh <- err: - cancel() // unwind the remaining goroutines on the first error. - default: - } - } - }(itemKey) - } - wg.Wait() - close(errCh) - for err := range errCh { - if err != nil { - return err - } - } - return nil -} - -// transactGetItemKey resolves the storage key for one TransactGetItems Get at -// the tentative timestamp. It returns (key, true, nil) on success, -// (nil, false, nil) when the item is MALFORMED (nil Get, empty/unknown table, -// or an invalid key) — the read path rejects those identically, so the lease -// pre-pass may safely skip them — and (nil, false, err) for a TRANSIENT or -// INTERNAL schema-read failure (leaseCtx deadline, Pebble error) that the -// caller MUST fail closed on rather than skip, otherwise the item's shard goes -// unfenced and a stale read can slip through. The malformed/transient split -// keys off dynamoErrIsTransient: validation errors are *dynamoAPIError, -// everything else is treated as transient. It never writes a response: -// validation is left to the read path so error mapping stays identical. -func (d *DynamoDBServer) transactGetItemKey(ctx context.Context, item transactGetItem, schemaCache map[string]*dynamoTableSchema, tentativeTS uint64) ([]byte, bool, error) { - if item.Get == nil || strings.TrimSpace(item.Get.TableName) == "" { - return nil, false, nil - } - schema, err := d.resolveTransactTableSchema(ctx, schemaCache, item.Get.TableName, tentativeTS) - if err != nil { - if dynamoErrIsTransient(err) { - return nil, false, errors.WithStack(err) - } - // Validation error (table not found): the read path rejects it - // identically, so skip rather than fail closed. - return nil, false, nil - } - // itemKeyFromAttributes only fails on malformed key attributes - // (missing/unparseable hash or range key), a pure validation error the - // read path rejects identically; transactGetItemKeyFromSchema swallows - // it to ok=false so the item is skipped, not failed closed. - itemKey, ok := transactGetItemKeyFromSchema(schema, item.Get.Key) - return itemKey, ok, nil -} - -// transactGetItemKeyFromSchema computes the storage key for a TransactGetItems -// Get, returning ok=false when the key attributes are malformed. The error is -// deliberately discarded: it is a validation failure the read path reports -// downstream, and the lease pre-pass only needs the routing key. -func transactGetItemKeyFromSchema(schema *dynamoTableSchema, key map[string]attributeValue) ([]byte, bool) { - itemKey, err := schema.itemKeyFromAttributes(key) - if err != nil { - return nil, false - } - return itemKey, true -} - -func decodeTransactGetItemsInput(bodyReader io.Reader) (transactGetItemsInput, error) { - body, err := io.ReadAll(bodyReader) - if err != nil { - return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var in transactGetItemsInput - if err := json.Unmarshal(body, &in); err != nil { - return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - if len(in.TransactItems) == 0 { - return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing transact items") - } - if len(in.TransactItems) > transactGetItemsMaxItems { - return transactGetItemsInput{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, - "Too many items in TransactGetItems: "+strconv.Itoa(len(in.TransactItems))+" (max "+strconv.Itoa(transactGetItemsMaxItems)+")") - } - return in, nil -} - -// collectTransactGetTableNames returns a deduplicated list of table names referenced -// in the TransactGetItems input. Used to run ensureLegacyTableMigration once per table. -func collectTransactGetTableNames(in transactGetItemsInput) []string { - seen := make(map[string]struct{}, len(in.TransactItems)) - names := make([]string, 0, len(in.TransactItems)) - for _, item := range in.TransactItems { - if item.Get == nil { - continue - } - t := item.Get.TableName - if _, exists := seen[t]; exists { - continue - } - seen[t] = struct{}{} - names = append(names, t) - } - return names -} - -// transactGetItemsMetrics holds per-table counts for metrics reporting. -type transactGetItemsMetrics struct { - requested int - found int -} - -// buildTransactGetItemsResponses reads each requested item at the given readTS -// and returns the ordered response list and a per-table metrics map. -// schemaCache avoids redundant storage reads when multiple items share the same table. -// seenItemKeys enforces the DynamoDB rule that a transaction may not reference the -// same item more than once. -// ensureLegacyTableMigration is called once per unique table before any item is read. -func (d *DynamoDBServer) buildTransactGetItemsResponses(ctx context.Context, in transactGetItemsInput, readTS uint64) ([]map[string]any, map[string]*transactGetItemsMetrics, error) { - tableNames := collectTransactGetTableNames(in) - for _, tableName := range tableNames { - if err := d.ensureLegacyTableMigration(ctx, tableName); err != nil { - return nil, nil, err - } - } - schemaCache := make(map[string]*dynamoTableSchema) - seenItemKeys := make(map[transactGetSeenKey]struct{}, len(in.TransactItems)) - tableMetrics := make(map[string]*transactGetItemsMetrics) - responses := make([]map[string]any, 0, len(in.TransactItems)) - for _, item := range in.TransactItems { - entry, itemFound, tableName, err := d.readTransactGetItem(ctx, item, schemaCache, seenItemKeys, readTS) - if err != nil { - return nil, nil, err - } - responses = append(responses, entry) - m := tableMetrics[tableName] - if m == nil { - m = &transactGetItemsMetrics{} - tableMetrics[tableName] = m - } - m.requested++ - if itemFound { - m.found++ - } - } - return responses, tableMetrics, nil -} - -// transactGetSeenKey is the map key used for duplicate-item detection in -// TransactGetItems. Using a struct avoids separator-collision risks from -// string concatenation and is more idiomatic Go. -type transactGetSeenKey struct { - tableName string - keyStr string -} - -// readTransactGetItem validates and reads a single item in a TransactGetItems request. -// ensureLegacyTableMigration must be called for g.TableName before invoking this function. -// Returns the response entry, whether the item was found, the table name, and any error. -// Returning the table name avoids the caller having to re-access item.Get after the call. -func (d *DynamoDBServer) readTransactGetItem(ctx context.Context, item transactGetItem, schemaCache map[string]*dynamoTableSchema, seenItemKeys map[transactGetSeenKey]struct{}, readTS uint64) (map[string]any, bool, string, error) { - if item.Get == nil { - return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "TransactGetItems only supports Get operations") - } - g := item.Get - if strings.TrimSpace(g.TableName) == "" { - return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing TableName in Get") - } - schema, err := d.resolveTransactTableSchema(ctx, schemaCache, g.TableName, readTS) - if err != nil { - return nil, false, "", err - } - // Reject duplicate item keys to match real DynamoDB behavior. - // canonicalPrimaryKeyStr reads only hash/range key attributes from g.Key - // by schema name, so extra attributes in the map are safely ignored — - // no separate primaryKeyAttributes extraction is needed. - keyStr, err := canonicalPrimaryKeyStr(schema.PrimaryKey, g.Key) - if err != nil { - return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - seenKey := transactGetSeenKey{tableName: g.TableName, keyStr: keyStr} - if _, dup := seenItemKeys[seenKey]; dup { - return nil, false, "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, - "Transaction request cannot include multiple operations on one item") - } - seenItemKeys[seenKey] = struct{}{} - loc, found, err := d.readLogicalItemAt(ctx, schema, g.Key, readTS) - if err != nil { - // Return the error as-is: storage errors from readItemAtKeyAt surface as - // InternalServerError (500) via writeDynamoErrorFromErr in the HTTP handler. - return nil, false, "", err - } - entry := map[string]any{} - if found { - projected, err := projectItem(loc.item, g.ProjectionExpression, g.ExpressionAttributeNames) - if err != nil { - return nil, false, "", err - } - entry["Item"] = projected - } - return entry, found, g.TableName, nil -} - -// canonicalPrimaryKeyStr returns a collision-free canonical string of primary -// key attributes for duplicate-item detection in TransactGetItems and -// TransactWriteItems. Shared between both operations to avoid duplicated logic. -// -// Takes the table's keySchema so it can write hash key then range key in a -// fixed schema-defined order, avoiding a slice allocation and sort — DynamoDB -// primary keys have at most two attributes, so direct lookup beats sorting. -// -// Format per attribute: "=::", separated by \x1f. -// The length prefix makes the format collision-free: a string value that -// contains \x1f cannot be confused with the inter-attribute separator because -// the decoder knows exactly how many bytes belong to each value. -// Numeric values are normalised; binary values are base64-encoded. -func canonicalPrimaryKeyStr(keySchema dynamoKeySchema, key map[string]attributeValue) (string, error) { - var buf strings.Builder - hashVal, ok := key[keySchema.HashKey] - if !ok { - return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing hash key attribute") - } - buf.WriteString(keySchema.HashKey) - buf.WriteByte('=') - if err := writeCanonicalAttrValue(&buf, hashVal); err != nil { - return "", err - } - if keySchema.RangeKey != "" { - rangeVal, ok := key[keySchema.RangeKey] - if !ok { - return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing range key attribute") - } - buf.WriteByte('\x1f') - buf.WriteString(keySchema.RangeKey) - buf.WriteByte('=') - if err := writeCanonicalAttrValue(&buf, rangeVal); err != nil { - return "", err - } - } - return buf.String(), nil -} - -// writeCanonicalAttrValue appends a length-prefixed typed value for a single -// primary key attribute to buf. Format: "::". -// Supports S (string), N (normalised number), and B (base64-encoded binary). -// The length prefix prevents collisions when string values contain \x1f. -func writeCanonicalAttrValue(buf *strings.Builder, v attributeValue) error { - switch { - case v.S != nil: - buf.WriteString("S:") - buf.WriteString(strconv.Itoa(len(*v.S))) - buf.WriteByte(':') - buf.WriteString(*v.S) - case v.N != nil: - n := canonicalNumberString(*v.N) - buf.WriteString("N:") - buf.WriteString(strconv.Itoa(len(n))) - buf.WriteByte(':') - buf.WriteString(n) - case v.B != nil: - encoded := base64.StdEncoding.EncodeToString(v.B) - buf.WriteString("B:") - buf.WriteString(strconv.Itoa(len(encoded))) - buf.WriteByte(':') - buf.WriteString(encoded) - default: - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported key attribute type for duplicate detection") - } - return nil -} - -func collectTransactWriteTableNames(in transactWriteItemsInput) ([]string, error) { - seen := map[string]struct{}{} - names := make([]string, 0, len(in.TransactItems)) - for _, item := range in.TransactItems { - tableName, err := transactWriteItemTableName(item) - if err != nil { - return nil, err - } - if tableName == "" { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing table name") - } - if _, exists := seen[tableName]; exists { - continue - } - seen[tableName] = struct{}{} - names = append(names, tableName) - } - return names, nil -} - -func transactWriteWrittenCounts(in transactWriteItemsInput) map[string]int { - out := make(map[string]int) - for _, item := range in.TransactItems { - tableName, err := transactWriteItemTableName(item) - if err != nil || strings.TrimSpace(tableName) == "" { - continue - } - switch { - case item.Put != nil, item.Update != nil, item.Delete != nil: - out[tableName]++ - } - } - return out -} - -func transactWriteItemTableName(item transactWriteItem) (string, error) { - switch countTransactWriteItemActions(item) { - case 0: - return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing transact action") - case 1: - default: - return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "multiple transact actions are not supported") - } - switch { - case item.Put != nil: - return strings.TrimSpace(item.Put.TableName), nil - case item.Update != nil: - return strings.TrimSpace(item.Update.TableName), nil - case item.Delete != nil: - return strings.TrimSpace(item.Delete.TableName), nil - default: - return strings.TrimSpace(item.ConditionCheck.TableName), nil - } -} - -func countTransactWriteItemActions(item transactWriteItem) int { - count := 0 - if item.Put != nil { - count++ - } - if item.Update != nil { - count++ - } - if item.Delete != nil { - count++ - } - if item.ConditionCheck != nil { - count++ - } - return count -} - -func (d *DynamoDBServer) transactWriteItemsWithRetry(ctx context.Context, in transactWriteItemsInput) error { - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - var lastErr error - for range transactRetryMaxAttempts { - reqs, generations, cleanupKeys, err := d.buildTransactWriteItemsRequest(ctx, in) - if err != nil { - return err - } - done, retryErr, fatalErr := d.runTransactWriteAttempt(ctx, reqs, generations, cleanupKeys) - if fatalErr != nil { - return fatalErr - } - if done { - return nil - } - if retryErr != nil { - lastErr = retryErr - } - if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { - if lastErr != nil { - combined := errors.Join(err, lastErr) - return errors.Wrap(combined, "transact write retry canceled") - } - return errors.WithStack(err) - } - backoff = nextTransactRetryBackoff(backoff) - } - if lastErr != nil { - return errors.Wrapf(lastErr, "transact write retry attempts exhausted after %s", transactRetryMaxDuration) - } - return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "transact write retry attempts exhausted") -} - -func (d *DynamoDBServer) runTransactWriteAttempt( - ctx context.Context, - reqs *kv.OperationGroup[kv.OP], - generations map[string]uint64, - cleanupKeys [][]byte, -) (bool, error, error) { - if len(reqs.Elems) == 0 { - return true, nil, nil - } - if _, err := d.coordinator.Dispatch(ctx, reqs); err != nil { - wrapped := errors.WithStack(err) - if !isRetryableTransactWriteError(err) { - return false, nil, wrapped - } - return false, wrapped, nil - } - retry, verifyErr := d.handleGenerationFenceResult( - ctx, - d.verifyTableGenerations(ctx, generations), - cleanupKeys, - ) - if verifyErr != nil { - return false, nil, verifyErr - } - if !retry { - return true, nil, nil - } - return false, nil, nil -} - -func (d *DynamoDBServer) buildTransactWriteItemsRequest(ctx context.Context, in transactWriteItemsInput) (*kv.OperationGroup[kv.OP], map[string]uint64, [][]byte, error) { - tableNames, err := collectTransactWriteTableNames(in) - if err != nil { - return nil, nil, nil, err - } - for _, tableName := range tableNames { - if err := d.ensureLegacyTableMigration(ctx, tableName); err != nil { - return nil, nil, nil, err - } - } - readTS := d.nextTxnReadTS() - reqs := &kv.OperationGroup[kv.OP]{ - IsTxn: true, - // Keep transaction start aligned with the snapshot used to evaluate - // ConditionCheck/ConditionExpression so concurrent writes after readTS - // are detected as write conflicts at commit time. - StartTS: readTS, - } - schemaCache := make(map[string]*dynamoTableSchema) - tableGenerations := make(map[string]uint64) - cleanup := make([][]byte, 0, len(in.TransactItems)) - // seenItemKeys tracks (tableName, primaryKey) pairs to detect duplicates. - // Real DynamoDB rejects requests with multiple operations on the same item. - seenItemKeys := make(map[string]struct{}, len(in.TransactItems)) - for _, item := range in.TransactItems { - if err := d.processTransactWriteItem(ctx, item, readTS, reqs, schemaCache, seenItemKeys, tableGenerations, &cleanup); err != nil { - return nil, nil, nil, err - } - } - return reqs, tableGenerations, cleanup, nil -} - -// processTransactWriteItem validates and plans a single item within a -// TransactWriteItems request, appending the resulting ops to reqs and cleanup. -func (d *DynamoDBServer) processTransactWriteItem( - ctx context.Context, - item transactWriteItem, - readTS uint64, - reqs *kv.OperationGroup[kv.OP], - schemaCache map[string]*dynamoTableSchema, - seenItemKeys map[string]struct{}, - tableGenerations map[string]uint64, - cleanup *[][]byte, -) error { - tableName, err := transactWriteItemTableName(item) - if err != nil { - return err - } - schema, err := d.resolveTransactTableSchema(ctx, schemaCache, tableName, readTS) - if err != nil { - return err - } - // Reject duplicate item keys to match real DynamoDB behavior. - itemKeyStr, err := transactWriteItemPrimaryKeyStr(schema, item) - if err != nil { - return err - } - compositeKey := tableName + "\x00" + itemKeyStr - if _, dup := seenItemKeys[compositeKey]; dup { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, - "Transaction request cannot include multiple operations on one item") - } - seenItemKeys[compositeKey] = struct{}{} - plan, err := d.buildTransactWriteItemPlan(ctx, schema, item, readTS) - if err != nil { - // Real DynamoDB wraps per-item condition failures in - // TransactionCanceledException rather than surfacing the raw - // ConditionalCheckFailedException to the caller. - var apiErr *dynamoAPIError - if errors.As(err, &apiErr) && apiErr.errorType == dynamoErrConditionalFailed { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrTransactionCanceled, apiErr.message) - } - return err - } - reqs.Elems = append(reqs.Elems, plan.elems...) - reqs.ReadKeys = append(reqs.ReadKeys, plan.readKeys...) - if !plan.writes { - return nil - } - tableGenerations[tableName] = schema.Generation - *cleanup = append(*cleanup, plan.cleanup...) - return nil -} - -// transactWriteItemPrimaryKeyStr returns a canonical string of the item's -// primary key attributes, used to detect duplicate-item violations in -// TransactWriteItems (real DynamoDB returns ValidationException for these). -// Delegates to canonicalPrimaryKeyStr for the actual serialization. -// primaryKeyAttributes is applied uniformly across all operation types so that -// only hash/range key fields are used for duplicate detection, regardless of -// whether the operation carries a full Item (Put) or a Key-only map (Update/Delete/ConditionCheck). -func transactWriteItemPrimaryKeyStr(schema *dynamoTableSchema, item transactWriteItem) (string, error) { - var rawAttrs map[string]attributeValue - switch { - case item.Update != nil: - rawAttrs = item.Update.Key - case item.Delete != nil: - rawAttrs = item.Delete.Key - case item.ConditionCheck != nil: - rawAttrs = item.ConditionCheck.Key - case item.Put != nil: - rawAttrs = item.Put.Item - default: - return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported transact item") - } - keyAttrs, err := primaryKeyAttributes(schema.PrimaryKey, rawAttrs) - if err != nil { - // primaryKeyAttributes already returns a dynamoAPIError; return it directly - // to preserve its status code and error type. - return "", err - } - keyStr, err := canonicalPrimaryKeyStr(schema.PrimaryKey, keyAttrs) - if err != nil { - return "", newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return keyStr, nil -} - -type transactWriteItemPlan struct { - elems []*kv.Elem[kv.OP] - cleanup [][]byte - writes bool - // readKeys contains the raw storage keys that were read during plan - // construction. They are propagated into OperationGroup.ReadKeys so the - // FSM can validate read-write conflicts atomically at commit time, - // preventing lost-update and G0 anomalies on concurrent transactions that - // read the same item at a stale timestamp. - readKeys [][]byte -} - -func (d *DynamoDBServer) buildTransactWriteItemPlan( - ctx context.Context, - schema *dynamoTableSchema, - item transactWriteItem, - readTS uint64, -) (*transactWriteItemPlan, error) { - switch { - case item.Put != nil: - return d.buildTransactPutPlan(ctx, schema, *item.Put, readTS) - case item.Update != nil: - return d.buildTransactUpdatePlan(ctx, schema, *item.Update, readTS) - case item.Delete != nil: - return d.buildTransactDeletePlan(ctx, schema, *item.Delete, readTS) - case item.ConditionCheck != nil: - return d.buildTransactConditionCheckPlan(ctx, schema, *item.ConditionCheck, readTS) - default: - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "unsupported transact item") - } -} - -func (d *DynamoDBServer) buildTransactPutPlan( - ctx context.Context, - schema *dynamoTableSchema, - in putItemInput, - readTS uint64, -) (*transactWriteItemPlan, error) { - itemKey, err := schema.itemKeyFromAttributes(in.Item) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - keyAttrs, err := primaryKeyAttributes(schema.PrimaryKey, in.Item) - if err != nil { - return nil, err - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, keyAttrs, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var current map[string]attributeValue - if found { - current = currentLocation.item - } - if err := validateConditionOnItem( - in.ConditionExpression, - in.ExpressionAttributeNames, - in.ExpressionAttributeValues, - valueOrEmptyMap(current, found), - ); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) - } - req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, in.Item, currentLocation) - if err != nil { - return nil, err - } - return &transactWriteItemPlan{ - elems: req.Elems, - cleanup: cleanup, - writes: true, - readKeys: [][]byte{itemKey}, - }, nil -} - -func (d *DynamoDBServer) buildTransactUpdatePlan( - ctx context.Context, - schema *dynamoTableSchema, - in transactUpdateInput, - readTS uint64, -) (*transactWriteItemPlan, error) { - itemKey, err := schema.itemKeyFromAttributes(in.Key) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - var current map[string]attributeValue - if !found { - current = map[string]attributeValue{} - } else { - current = currentLocation.item - } - updateIn := updateItemInput{ - TableName: in.TableName, - Key: in.Key, - UpdateExpression: in.UpdateExpression, - ConditionExpression: in.ConditionExpression, - ExpressionAttributeNames: in.ExpressionAttributeNames, - ExpressionAttributeValues: in.ExpressionAttributeValues, - } - nextItem, err := buildUpdatedItem(schema, updateIn, current) - if err != nil { - return nil, err - } - req, cleanup, err := buildItemWriteRequestWithSource(schema, itemKey, nextItem, currentLocation) - if err != nil { - return nil, err - } - return &transactWriteItemPlan{ - elems: req.Elems, - cleanup: cleanup, - writes: true, - readKeys: [][]byte{itemKey}, - }, nil -} - -func (d *DynamoDBServer) buildTransactDeletePlan( - ctx context.Context, - schema *dynamoTableSchema, - in transactDeleteInput, - readTS uint64, -) (*transactWriteItemPlan, error) { - itemKey, err := schema.itemKeyFromAttributes(in.Key) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - current := map[string]attributeValue(nil) - if found { - current = currentLocation.item - } - if err := validateConditionOnItem( - in.ConditionExpression, - in.ExpressionAttributeNames, - in.ExpressionAttributeValues, - valueOrEmptyMap(current, found), - ); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) - } - if !found { - // Item does not exist at readTS. Track the key so a concurrent create - // after our snapshot can be detected if the overall transaction - // includes write elems and therefore reaches FSM validation. In a - // pure no-op transaction, these read keys may not be validated. - return &transactWriteItemPlan{readKeys: [][]byte{itemKey}}, nil - } - req, err := buildItemDeleteRequestWithSource(currentLocation) - if err != nil { - return nil, err - } - return &transactWriteItemPlan{ - elems: req.Elems, - writes: true, - readKeys: [][]byte{itemKey}, - }, nil -} - -func (d *DynamoDBServer) buildTransactConditionCheckPlan( - ctx context.Context, - schema *dynamoTableSchema, - in transactConditionInput, - readTS uint64, -) (*transactWriteItemPlan, error) { - if strings.TrimSpace(in.ConditionExpression) == "" { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, "missing condition expression") - } - itemKey, err := schema.itemKeyFromAttributes(in.Key) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - currentLocation, found, err := d.readLogicalItemAt(ctx, schema, in.Key, readTS) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - current := map[string]attributeValue(nil) - if found { - current = currentLocation.item - } - if err := validateConditionOnItem( - in.ConditionExpression, - in.ExpressionAttributeNames, - in.ExpressionAttributeValues, - valueOrEmptyMap(current, found), - ); err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrConditionalFailed, err.Error()) - } - lockKey := itemKey - if currentLocation != nil { - lockKey = currentLocation.key - } - lockReq, lockCleanup, err := buildConditionCheckLockRequest(lockKey, current, found) - if err != nil { - return nil, err - } - return &transactWriteItemPlan{ - elems: lockReq.Elems, - cleanup: lockCleanup, - writes: true, - readKeys: [][]byte{itemKey}, - }, nil -} - -func valueOrEmptyMap(item map[string]attributeValue, found bool) map[string]attributeValue { - if found { - return item - } - return map[string]attributeValue{} -} - -func buildItemDeleteRequestWithSource(current *dynamoItemLocation) (*kv.OperationGroup[kv.OP], error) { - if current == nil { - return &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: nil, - }, nil - } - elems := []*kv.Elem[kv.OP]{{Op: kv.Del, Key: current.key}} - delKeys, err := itemStorageKeys(current) - if err != nil { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - for _, key := range delKeys { - if bytes.Equal(key, current.key) { - continue - } - elems = append(elems, &kv.Elem[kv.OP]{Op: kv.Del, Key: key}) - } - return &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: elems, - }, nil -} - -func buildConditionCheckLockRequest( - itemKey []byte, - current map[string]attributeValue, - found bool, -) (*kv.OperationGroup[kv.OP], [][]byte, error) { - if !found { - // Item does not exist: no write is needed. - // Include itemKey in ReadKeys only so OCC conflict detection fires - // if a concurrent writer commits to this key between our startTS and commitTS. - // Writing a Del tombstone here would shadow any concurrently committed Put - // at a higher timestamp, causing G-single-item-realtime anomalies. - // Return nil cleanup since nothing was written by this condition check. - return &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: nil, - }, - nil, - nil - } - payload, err := encodeStoredDynamoItem(current) - if err != nil { - return nil, nil, errors.WithStack(err) - } - elems := []*kv.Elem[kv.OP]{{Op: kv.Put, Key: itemKey, Value: payload}} - return &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: 0, - Elems: elems, - }, - [][]byte{itemKey}, - nil -} - -func (d *DynamoDBServer) resolveTransactTableSchema(ctx context.Context, cache map[string]*dynamoTableSchema, tableName string, readTS uint64) (*dynamoTableSchema, error) { - if schema := cache[tableName]; schema != nil { - return schema, nil - } - schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - if !exists { - return nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - cache[tableName] = schema - return schema, nil -} - -func isRetryableTransactWriteError(err error) bool { - return errors.Is(err, store.ErrWriteConflict) || errors.Is(err, kv.ErrTxnLocked) -} - -func waitTransactRetryBackoff(ctx context.Context, delay time.Duration) error { - timer := time.NewTimer(delay) - defer timer.Stop() - - select { - case <-ctx.Done(): - return errors.WithStack(ctx.Err()) - case <-timer.C: - return nil - } -} - -func waitRetryWithDeadline(ctx context.Context, deadline time.Time, backoff time.Duration) error { - remaining := time.Until(deadline) - if remaining <= 0 { - return errors.New("retry timeout") - } - delay := min(backoff, remaining) - return waitTransactRetryBackoff(ctx, delay) -} - -func nextTransactRetryBackoff(current time.Duration) time.Duration { - next := current * transactRetryBackoffFactor - if next > transactRetryMaxBackoff { - return transactRetryMaxBackoff - } - return next -} - -var errTableGenerationChanged = errors.New("table generation changed") - -func (d *DynamoDBServer) verifyTableGeneration(ctx context.Context, tableName string, expectedGeneration uint64) error { - // Use consistentReadLatestTS to always read the latest committed schema. - // Using a stale snapshotTS can cause false "table not found" results when - // this node's LastCommitTS is behind the table creation timestamp, which - // would erroneously trigger cleanupCommittedKeys and delete live item data. - schema, exists, err := d.loadTableSchemaAt(ctx, tableName, consistentReadLatestTS) - if err != nil { - return errors.WithStack(err) - } - if !exists { - return newDynamoAPIError(http.StatusBadRequest, dynamoErrResourceNotFound, "table not found") - } - if schema.Generation != expectedGeneration { - return errors.Wrapf(errTableGenerationChanged, - "table generation changed (table=%s expected=%d actual=%d)", - tableName, expectedGeneration, schema.Generation, - ) - } - return nil -} - -func (d *DynamoDBServer) verifyTableGenerations(ctx context.Context, generations map[string]uint64) error { - for tableName, generation := range generations { - if err := d.verifyTableGeneration(ctx, tableName, generation); err != nil { - return err - } - } - return nil -} - -func isGenerationFenceFailure(err error) bool { - return errors.Is(err, errTableGenerationChanged) || isTableNotFoundError(err) -} - -func (d *DynamoDBServer) handleGenerationFenceResult(ctx context.Context, err error, cleanupKeys [][]byte) (bool, error) { - if err == nil { - return false, nil - } - if !isGenerationFenceFailure(err) { - return false, err - } - if cleanupErr := d.cleanupCommittedKeys(ctx, cleanupKeys); cleanupErr != nil { - return false, cleanupErr - } - if errors.Is(err, errTableGenerationChanged) { - return true, nil - } - return false, err -} - -func isTableNotFoundError(err error) bool { - var apiErr *dynamoAPIError - if !errors.As(err, &apiErr) { - return false - } - return apiErr.errorType == dynamoErrResourceNotFound -} - -func (d *DynamoDBServer) cleanupCommittedKeys(ctx context.Context, keys [][]byte) error { - uniq := uniqueKeys(keys) - if len(uniq) == 0 { - return nil - } - return d.dispatchDeleteBatch(ctx, uniq) -} - -func uniqueKeys(keys [][]byte) [][]byte { - seen := make(map[string]struct{}, len(keys)) - out := make([][]byte, 0, len(keys)) - for _, key := range keys { - s := string(key) - if _, ok := seen[s]; ok { - continue - } - seen[s] = struct{}{} - out = append(out, key) - } - return out -} - -type dynamoAPIError struct { - status int - errorType string - message string -} - -func (e *dynamoAPIError) Error() string { - if e == nil { - return "" - } - if e.message != "" { - return e.message - } - return http.StatusText(e.status) -} - -func newDynamoAPIError(status int, errorType string, message string) error { - return &dynamoAPIError{ - status: status, - errorType: errorType, - message: message, - } -} - -func writeDynamoErrorFromErr(w http.ResponseWriter, err error) { - var apiErr *dynamoAPIError - if errors.As(err, &apiErr) { - writeDynamoError(w, apiErr.status, apiErr.errorType, apiErr.message) - return - } - writeDynamoError(w, http.StatusInternalServerError, dynamoErrInternal, err.Error()) -} - -// dynamoErrIsTransient reports whether err is a transient/internal failure -// (Pebble error, context deadline, decode failure) as opposed to a structured -// validation/malformed-input error. A *dynamoAPIError is always a deliberate -// validation result (it carries an HTTP status + error type), so it is NOT -// transient; everything else — a raw wrapped store/context error — is. The -// lease pre-pass uses this to decide whether an unresolvable item must fail -// closed (transient) or may be skipped (validation, rejected identically by -// the read path). -func dynamoErrIsTransient(err error) bool { - if err == nil { - return false - } - var apiErr *dynamoAPIError - return !errors.As(err, &apiErr) -} - -func writeDynamoError(w http.ResponseWriter, status int, errorType string, message string) { - if message == "" { - message = http.StatusText(status) - } - - resp := map[string]string{"message": message} - if errorType != "" { - resp["__type"] = errorType - w.Header().Set("x-amzn-ErrorType", errorType) - } - w.Header().Set("Content-Type", "application/x-amz-json-1.0") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(resp) -} - -func writeDynamoJSON(w http.ResponseWriter, payload any) { - w.Header().Set("Content-Type", "application/x-amz-json-1.0") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(payload) -} - -func replaceNames(expr string, names map[string]string) (string, error) { - if expr == "" || len(names) == 0 { - return expr, nil - } - if err := validateExpressionAttributeNames(names); err != nil { - return "", err - } - keys := make([]string, 0, len(names)) - for k := range names { - keys = append(keys, k) - } - sort.Slice(keys, func(i, j int) bool { - if len(keys[i]) == len(keys[j]) { - return keys[i] < keys[j] - } - return len(keys[i]) > len(keys[j]) - }) - - // DynamoDB expression attribute names are substituted once. - args := make([]string, 0, len(keys)*replacerArgPairSize) - for _, key := range keys { - args = append(args, key, names[key]) - } - return strings.NewReplacer(args...).Replace(expr), nil -} - -func validateExpressionAttributeNames(names map[string]string) error { - for placeholder, name := range names { - if !isExpressionAttributePlaceholder(placeholder) { - return errors.Errorf("invalid expression attribute placeholder %q", placeholder) - } - if !isExpressionAttributeName(name) { - return errors.Errorf("invalid expression attribute name %q for placeholder %q", name, placeholder) - } - } - return nil -} - -func isExpressionAttributePlaceholder(s string) bool { - if len(s) <= 1 || s[0] != '#' { - return false - } - return isExpressionPlaceholderIdentifier(s[1:]) -} - -func isExpressionPlaceholderIdentifier(s string) bool { - if s == "" { - return false - } - for i := 0; i < len(s); i++ { - if isExpressionPlaceholderIdentByte(s[i]) { - continue - } - return false - } - return true -} - -func isExpressionPlaceholderIdentByte(b byte) bool { - return b == '_' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') -} - -func isExpressionAttributeName(s string) bool { - if s == "" { - return false - } - for i := 0; i < len(s); i++ { - if isExpressionAttributeNameByte(s[i]) { - continue - } - return false - } - return true -} - -func isExpressionAttributeNameByte(b byte) bool { - return b == '_' || b == '.' || b == '-' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') -} - -func applyUpdateExpression(expr string, names map[string]string, values map[string]attributeValue, item map[string]attributeValue) error { - updExpr, err := replaceNames(expr, names) - if err != nil { - return err - } - updExpr = strings.TrimSpace(updExpr) - sections, err := parseUpdateExpressionSections(updExpr) - if err != nil { - return err - } - for _, section := range sections { - if err := applyUpdateExpressionSection(section, values, item); err != nil { - return err - } - } - return nil -} - -type updateExpressionSection struct { - action string - body string -} - -func parseUpdateExpressionSections(expr string) ([]updateExpressionSection, error) { - if strings.TrimSpace(expr) == "" { - return nil, errors.New("unsupported update expression") - } - sections := make([]updateExpressionSection, 0, updateSplitCount) - seen := map[string]struct{}{} - i := skipSpaces(expr, 0) - for i < len(expr) { - action, nextPos, ok := parseUpdateActionToken(expr, i) - if !ok { - return nil, errors.New("unsupported update expression") - } - if _, exists := seen[action]; exists { - return nil, errors.New("duplicate update action") - } - seen[action] = struct{}{} - bodyStart := skipSpaces(expr, nextPos) - bodyEnd := findNextUpdateAction(expr, bodyStart) - if bodyEnd < 0 { - bodyEnd = len(expr) - } - body := strings.TrimSpace(expr[bodyStart:bodyEnd]) - if body == "" { - return nil, errors.New("unsupported update expression") - } - sections = append(sections, updateExpressionSection{action: action, body: body}) - if bodyEnd >= len(expr) { - break - } - i = bodyEnd - } - if len(sections) == 0 { - return nil, errors.New("unsupported update expression") - } - return sections, nil -} - -func applyUpdateExpressionSection(section updateExpressionSection, values map[string]attributeValue, item map[string]attributeValue) error { - switch section.action { - case "SET": - return applySetUpdateAction(section.body, values, item) - case "REMOVE": - return applyRemoveUpdateAction(section.body, item) - case "ADD": - return applyAddUpdateAction(section.body, values, item) - case "DELETE": - return applyDeleteUpdateAction(section.body, values, item) - default: - return errors.New("unsupported update action") - } -} - -func parseUpdateActionToken(expr string, pos int) (string, int, bool) { - actions := []string{"SET", "REMOVE", "ADD", "DELETE"} - for _, action := range actions { - end := pos + len(action) - if end > len(expr) { - continue - } - if !strings.EqualFold(expr[pos:end], action) { - continue - } - if !isLogicalKeywordBoundary(expr, pos-1) || !isLogicalKeywordBoundary(expr, end) { - continue - } - return action, end, true - } - return "", 0, false -} - -func skipSpaces(expr string, pos int) int { - for pos < len(expr) && (expr[pos] == ' ' || expr[pos] == '\t' || expr[pos] == '\n' || expr[pos] == '\r') { - pos++ - } - return pos -} - -func findNextUpdateAction(expr string, start int) int { - depth := 0 - for i := start; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth != 0 { - continue - } - _, _, ok := parseUpdateActionToken(expr, i) - if ok { - return i - } - } - return -1 -} - -func applySetUpdateAction(body string, values map[string]attributeValue, item map[string]attributeValue) error { - assignments, err := splitTopLevelByComma(body) - if err != nil { - return errors.New("invalid update expression") - } - for _, assignment := range assignments { - if err := applySingleSetAssignment(assignment, values, item); err != nil { - return err - } - } - return nil -} - -func applySingleSetAssignment(assignment string, values map[string]attributeValue, item map[string]attributeValue) error { - parts := strings.SplitN(assignment, "=", updateSplitCount) - if len(parts) != updateSplitCount { - return errors.New("invalid update expression") - } - path := strings.TrimSpace(parts[0]) - if path == "" { - return errors.New("invalid update expression attribute") - } - valueExpr := strings.TrimSpace(parts[1]) - valueAttr, err := evalUpdateValueExpression(valueExpr, values, item) - if err != nil { - return err - } - return setDocumentPath(item, path, valueAttr) -} - -func applyRemoveUpdateAction(body string, item map[string]attributeValue) error { - attrs, err := splitTopLevelByComma(body) - if err != nil { - return errors.New("invalid update expression") - } - for _, attr := range attrs { - path := strings.TrimSpace(attr) - if path == "" { - return errors.New("invalid update expression attribute") - } - if err := removeDocumentPath(item, path); err != nil { - return err - } - } - return nil -} - -func applyAddUpdateAction(body string, values map[string]attributeValue, item map[string]attributeValue) error { - terms, err := splitTopLevelByComma(body) - if err != nil { - return errors.New("invalid update expression") - } - for _, term := range terms { - if err := applySingleAddTerm(term, values, item); err != nil { - return err - } - } - return nil -} - -func applySingleAddTerm(term string, values map[string]attributeValue, item map[string]attributeValue) error { - parts := strings.Fields(term) - if len(parts) != updateSplitCount { - return errors.New("invalid update expression") - } - path := strings.TrimSpace(parts[0]) - placeholder := strings.TrimSpace(parts[1]) - if path == "" || !strings.HasPrefix(placeholder, ":") { - return errors.New("invalid update expression") - } - addValue, ok := values[placeholder] - if !ok { - return errors.New("missing value attribute") - } - current, exists, err := resolveDocumentPath(item, path) - if err != nil { - return err - } - next, err := addAttributeValue(current, exists, addValue) - if err != nil { - return err - } - return setDocumentPath(item, path, next) -} - -func addNumericAttributeValues(left string, right string) (string, error) { - leftRat, rightRat := &big.Rat{}, &big.Rat{} - if _, ok := leftRat.SetString(strings.TrimSpace(left)); !ok { - return "", errors.New("invalid number attribute") - } - if _, ok := rightRat.SetString(strings.TrimSpace(right)); !ok { - return "", errors.New("invalid number attribute") - } - sum := &big.Rat{} - sum.Add(leftRat, rightRat) - if sum.IsInt() { - return sum.Num().String(), nil - } - out := strings.TrimRight(sum.FloatString(numericUpdateScaleDigits), "0") - out = strings.TrimRight(out, ".") - if out == "" { - return "0", nil - } - return out, nil -} - -func applyDeleteUpdateAction(body string, values map[string]attributeValue, item map[string]attributeValue) error { - terms, err := splitTopLevelByComma(body) - if err != nil { - return errors.New("invalid update expression") - } - for _, term := range terms { - if err := applySingleDeleteTerm(term, values, item); err != nil { - return err - } - } - return nil -} - -func applySingleDeleteTerm(term string, values map[string]attributeValue, item map[string]attributeValue) error { - fields := strings.Fields(strings.TrimSpace(term)) - switch len(fields) { - case 0: - return errors.New("invalid update expression") - case 1: - return removeDocumentPath(item, fields[0]) - case updateSplitCount: - return applyDeleteSetTerm(fields[0], fields[1], values, item) - default: - return errors.New("invalid update expression") - } -} - -func applyDeleteSetTerm(pathExpr string, placeholderExpr string, values map[string]attributeValue, item map[string]attributeValue) error { - path := strings.TrimSpace(pathExpr) - placeholder := strings.TrimSpace(placeholderExpr) - if path == "" || !strings.HasPrefix(placeholder, ":") { - return errors.New("invalid update expression") - } - deleteValue, ok := values[placeholder] - if !ok { - return errors.New("missing value attribute") - } - current, found, err := resolveDocumentPath(item, path) - if err != nil || !found { - return err - } - next, removeAttr, err := deleteAttributeValueElements(current, deleteValue) - if err != nil { - return err - } - if removeAttr { - return removeDocumentPath(item, path) - } - return setDocumentPath(item, path, next) -} - -func splitTopLevelByComma(expr string) ([]string, error) { - depth := 0 - last := 0 - parts := make([]string, 0, splitPartsInitialCapacity) - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth < 0 { - return nil, errors.New("invalid expression") - } - if depth != 0 || expr[i] != ',' { - continue - } - part := strings.TrimSpace(expr[last:i]) - if part == "" { - return nil, errors.New("invalid expression") - } - parts = append(parts, part) - last = i + 1 - } - if depth != 0 { - return nil, errors.New("invalid expression") - } - tail := strings.TrimSpace(expr[last:]) - if tail == "" { - return nil, errors.New("invalid expression") - } - return append(parts, tail), nil -} - -type documentPathToken struct { - attr string - index int - isIndex bool -} - -func parseDocumentPath(path string) ([]documentPathToken, error) { - path = strings.TrimSpace(path) - if path == "" { - return nil, errors.New("invalid document path") - } - tokens := make([]documentPathToken, 0, updateSplitCount) - for pos := 0; pos < len(path); { - nextPos, token, err := consumeDocumentPathToken(path, pos) - if err != nil { - return nil, err - } - pos = nextPos - if token.attr != "" || token.isIndex { - tokens = append(tokens, token) - } - } - if len(tokens) == 0 { - return nil, errors.New("invalid document path") - } - return tokens, nil -} - -func consumeDocumentPathToken(path string, pos int) (int, documentPathToken, error) { - switch path[pos] { - case '.': - return pos + 1, documentPathToken{}, nil - case '[': - return consumeDocumentPathIndex(path, pos) - default: - return consumeDocumentPathAttr(path, pos) - } -} - -func consumeDocumentPathIndex(path string, pos int) (int, documentPathToken, error) { - end := strings.IndexByte(path[pos:], ']') - if end <= 1 { - return 0, documentPathToken{}, errors.New("invalid document path") - } - indexValue, err := strconv.Atoi(path[pos+1 : pos+end]) - if err != nil || indexValue < 0 { - return 0, documentPathToken{}, errors.New("invalid document path") - } - return pos + end + 1, documentPathToken{index: indexValue, isIndex: true}, nil -} - -func consumeDocumentPathAttr(path string, pos int) (int, documentPathToken, error) { - start := pos - for pos < len(path) && path[pos] != '.' && path[pos] != '[' { - pos++ - } - attr := strings.TrimSpace(path[start:pos]) - if attr == "" { - return 0, documentPathToken{}, errors.New("invalid document path") - } - return pos, documentPathToken{attr: attr}, nil -} - -func resolveDocumentPath(item map[string]attributeValue, path string) (attributeValue, bool, error) { - tokens, err := parseDocumentPath(path) - if err != nil { - return attributeValue{}, false, err - } - current := attributeValue{M: item} - found := true - for _, token := range tokens { - current, found = nextDocumentPathValue(current, found, token) - if !found { - return attributeValue{}, false, nil - } - } - return cloneAttributeValue(current), true, nil -} - -func nextDocumentPathValue(current attributeValue, found bool, token documentPathToken) (attributeValue, bool) { - if !found { - return attributeValue{}, false - } - if token.isIndex { - if !current.hasListType() || token.index >= len(current.L) { - return attributeValue{}, false - } - return current.L[token.index], true - } - if !current.hasMapType() { - return attributeValue{}, false - } - value, ok := current.M[token.attr] - if !ok { - return attributeValue{}, false - } - return value, true -} - -func setDocumentPath(item map[string]attributeValue, path string, value attributeValue) error { - tokens, err := parseDocumentPath(path) - if err != nil { - return err - } - root, err := setDocumentPathValue(attributeValue{M: cloneAttributeValueMap(item)}, true, tokens, value) - if err != nil { - return err - } - replaceAttributeValueMap(item, root.M) - return nil -} - -func setDocumentPathValue(current attributeValue, exists bool, tokens []documentPathToken, value attributeValue) (attributeValue, error) { - if len(tokens) == 0 { - return cloneAttributeValue(value), nil - } - token := tokens[0] - if token.isIndex { - return setDocumentPathIndex(current, exists, token, tokens[1:], value) - } - return setDocumentPathAttribute(current, exists, token, tokens[1:], value) -} - -func setDocumentPathIndex(current attributeValue, exists bool, token documentPathToken, rest []documentPathToken, value attributeValue) (attributeValue, error) { - if !exists || !current.hasListType() { - return attributeValue{}, errors.New("invalid document path") - } - list := cloneAttributeValueList(current.L) - if token.index > len(list) { - return attributeValue{}, errors.New("invalid document path") - } - if token.index == len(list) { - return appendDocumentPathIndex(list, rest, value) - } - nextValue, err := setDocumentPathValue(list[token.index], true, rest, value) - if err != nil { - return attributeValue{}, err - } - list[token.index] = nextValue - return attributeValue{L: list}, nil -} - -func appendDocumentPathIndex(list []attributeValue, rest []documentPathToken, value attributeValue) (attributeValue, error) { - child := value - if len(rest) > 0 { - var err error - child, err = setDocumentPathValue(newDocumentContainer(rest[0]), true, rest, value) - if err != nil { - return attributeValue{}, err - } - } - list = append(list, cloneAttributeValue(child)) - return attributeValue{L: list}, nil -} - -func setDocumentPathAttribute(current attributeValue, exists bool, token documentPathToken, rest []documentPathToken, value attributeValue) (attributeValue, error) { - var object map[string]attributeValue - if exists { - if !current.hasMapType() { - return attributeValue{}, errors.New("invalid document path") - } - object = cloneAttributeValueMap(current.M) - } else { - object = map[string]attributeValue{} - } - child, childExists := object[token.attr] - if !childExists && len(rest) > 0 { - child = newDocumentContainer(rest[0]) - childExists = true - } - nextValue, err := setDocumentPathValue(child, childExists, rest, value) - if err != nil { - return attributeValue{}, err - } - object[token.attr] = nextValue - return attributeValue{M: object}, nil -} - -func newDocumentContainer(next documentPathToken) attributeValue { - if next.isIndex { - return attributeValue{L: []attributeValue{}} - } - return attributeValue{M: map[string]attributeValue{}} -} - -func removeDocumentPath(item map[string]attributeValue, path string) error { - tokens, err := parseDocumentPath(path) - if err != nil { - return err - } - root, err := removeDocumentPathValue(attributeValue{M: cloneAttributeValueMap(item)}, true, tokens) - if err != nil { - return err - } - replaceAttributeValueMap(item, root.M) - return nil -} - -func removeDocumentPathValue(current attributeValue, exists bool, tokens []documentPathToken) (attributeValue, error) { - if !exists || len(tokens) == 0 { - return current, nil - } - token := tokens[0] - if token.isIndex { - return removeDocumentPathIndex(current, token, tokens[1:]) - } - return removeDocumentPathAttribute(current, token, tokens[1:]) -} - -func removeDocumentPathIndex(current attributeValue, token documentPathToken, rest []documentPathToken) (attributeValue, error) { - if !current.hasListType() || token.index >= len(current.L) { - return current, nil - } - list := cloneAttributeValueList(current.L) - if len(rest) == 0 { - list = append(list[:token.index], list[token.index+1:]...) - return attributeValue{L: list}, nil - } - nextValue, err := removeDocumentPathValue(list[token.index], true, rest) - if err != nil { - return attributeValue{}, err - } - list[token.index] = nextValue - return attributeValue{L: list}, nil -} - -func removeDocumentPathAttribute(current attributeValue, token documentPathToken, rest []documentPathToken) (attributeValue, error) { - if !current.hasMapType() { - return current, nil - } - object := cloneAttributeValueMap(current.M) - child, ok := object[token.attr] - if !ok { - return current, nil - } - if len(rest) == 0 { - delete(object, token.attr) - return attributeValue{M: object}, nil - } - nextValue, err := removeDocumentPathValue(child, true, rest) - if err != nil { - return attributeValue{}, err - } - object[token.attr] = nextValue - return attributeValue{M: object}, nil -} - -func replaceAttributeValueMap(dst map[string]attributeValue, src map[string]attributeValue) { - clear(dst) - maps.Copy(dst, src) -} - -func deleteAttributeValueElements(current attributeValue, deleteValue attributeValue) (attributeValue, bool, error) { - currentKind, _ := detectAttributeValueKind(current) - deleteKind, _ := detectAttributeValueKind(deleteValue) - if currentKind != deleteKind { - return attributeValue{}, false, errors.New("DELETE supports only matching set attribute types") - } - switch currentKind { - case attributeValueKindStringSet: - return buildDeleteSetResult(attributeValue{SS: subtractStringSet(current.SS, deleteValue.SS)}) - case attributeValueKindNumberSet: - return buildDeleteSetResult(attributeValue{NS: subtractNumberSet(current.NS, deleteValue.NS)}) - case attributeValueKindBinarySet: - return buildDeleteSetResult(attributeValue{BS: subtractBinarySet(current.BS, deleteValue.BS)}) - case attributeValueKindInvalid, - attributeValueKindString, - attributeValueKindNumber, - attributeValueKindBinary, - attributeValueKindBool, - attributeValueKindNull, - attributeValueKindList, - attributeValueKindMap: - return attributeValue{}, false, errors.New("DELETE supports only matching set attribute types") - } - return attributeValue{}, false, errors.New("DELETE supports only matching set attribute types") -} - -func buildDeleteSetResult(next attributeValue) (attributeValue, bool, error) { - if next.hasStringSetType() && len(next.SS) == 0 { - return attributeValue{}, true, nil - } - if next.hasNumberSetType() && len(next.NS) == 0 { - return attributeValue{}, true, nil - } - if next.hasBinarySetType() && len(next.BS) == 0 { - return attributeValue{}, true, nil - } - return next, false, nil -} - -func subtractStringSet(current []string, remove []string) []string { - removeSet := make(map[string]struct{}, len(remove)) - for _, value := range remove { - removeSet[value] = struct{}{} - } - out := make([]string, 0, len(current)) - for _, value := range current { - if _, ok := removeSet[value]; ok { - continue - } - out = append(out, value) - } - return out -} - -func subtractNumberSet(current []string, remove []string) []string { - removeSet := make(map[string]struct{}, len(remove)) - for _, value := range remove { - removeSet[canonicalNumberString(value)] = struct{}{} - } - out := make([]string, 0, len(current)) - for _, value := range current { - if _, ok := removeSet[canonicalNumberString(value)]; ok { - continue - } - out = append(out, value) - } - return out -} - -func subtractBinarySet(current [][]byte, remove [][]byte) [][]byte { - removeSet := make(map[string]struct{}, len(remove)) - for _, value := range remove { - removeSet[string(value)] = struct{}{} - } - out := make([][]byte, 0, len(current)) - for _, value := range current { - if _, ok := removeSet[string(value)]; ok { - continue - } - out = append(out, bytes.Clone(value)) - } - return out -} - -func evalUpdateValueExpression(expr string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { - expr = strings.TrimSpace(expr) - if expr == "" { - return attributeValue{}, errors.New("invalid update expression") - } - if value, handled, err := evalArithmeticUpdateOperand(expr, values, item); handled { - return value, err - } - if value, handled, err := evalNamedUpdateFunction(expr, values, item, "if_not_exists", evalIfNotExistsUpdateValue); handled { - return value, err - } - if value, handled, err := evalNamedUpdateFunction(expr, values, item, "list_append", evalListAppendUpdateValue); handled { - return value, err - } - return evalUpdateTerminalValue(expr, values, item) -} - -func evalArithmeticUpdateOperand(expr string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, bool, error) { - index, op, ok := findTopLevelArithmeticOperator(expr) - if !ok { - return attributeValue{}, false, nil - } - left, err := evalUpdateValueExpression(expr[:index], values, item) - if err != nil { - return attributeValue{}, true, err - } - right, err := evalUpdateValueExpression(expr[index+1:], values, item) - if err != nil { - return attributeValue{}, true, err - } - value, err := applyArithmeticUpdateValue(left, right, op) - return value, true, err -} - -func evalNamedUpdateFunction( - expr string, - values map[string]attributeValue, - item map[string]attributeValue, - name string, - eval func([]string, map[string]attributeValue, map[string]attributeValue) (attributeValue, error), -) (attributeValue, bool, error) { - args, ok, err := parseExpressionFunctionArgs(expr, name) - if err != nil { - return attributeValue{}, true, err - } - if !ok { - return attributeValue{}, false, nil - } - value, err := eval(args, values, item) - return value, true, err -} - -func evalUpdateTerminalValue(expr string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { - if strings.HasPrefix(expr, ":") { - value, ok := values[expr] - if !ok { - return attributeValue{}, errors.New("missing value attribute") - } - return cloneAttributeValue(value), nil - } - value, found, err := resolveDocumentPath(item, expr) - if err != nil { - return attributeValue{}, err - } - if !found { - return attributeValue{}, errors.New("missing value attribute") - } - return value, nil -} - -func evalIfNotExistsUpdateValue(args []string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { - if len(args) != updateSplitCount { - return attributeValue{}, errors.New("invalid update expression") - } - current, found, err := resolveDocumentPath(item, strings.TrimSpace(args[0])) - if err != nil { - return attributeValue{}, err - } - if found { - return current, nil - } - return evalUpdateValueExpression(args[1], values, item) -} - -func evalListAppendUpdateValue(args []string, values map[string]attributeValue, item map[string]attributeValue) (attributeValue, error) { - if len(args) != updateSplitCount { - return attributeValue{}, errors.New("invalid update expression") - } - left, err := evalUpdateValueExpression(args[0], values, item) - if err != nil { - return attributeValue{}, err - } - right, err := evalUpdateValueExpression(args[1], values, item) - if err != nil { - return attributeValue{}, err - } - if !left.hasListType() || !right.hasListType() { - return attributeValue{}, errors.New("list_append supports only list attributes") - } - out := make([]attributeValue, 0, len(left.L)+len(right.L)) - for _, value := range left.L { - out = append(out, cloneAttributeValue(value)) - } - for _, value := range right.L { - out = append(out, cloneAttributeValue(value)) - } - return attributeValue{L: out}, nil -} - -func applyArithmeticUpdateValue(left attributeValue, right attributeValue, op byte) (attributeValue, error) { - if !left.hasNumberType() || !right.hasNumberType() { - return attributeValue{}, errors.New("arithmetic update supports only number attributes") - } - rightValue := right.numberValue() - if op == '-' { - rightValue = "-" + rightValue - } - sum, err := addNumericAttributeValues(left.numberValue(), rightValue) - if err != nil { - return attributeValue{}, err - } - return attributeValue{N: &sum}, nil -} - -func findTopLevelArithmeticOperator(expr string) (int, byte, bool) { - depth := 0 - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth != 0 { - continue - } - switch expr[i] { - case '+', '-': - if i == 0 { - continue - } - return i, expr[i], true - } - } - return 0, 0, false -} - -func parseExpressionFunctionArgs(expr string, funcName string) ([]string, bool, error) { - prefix := funcName + "(" - if !strings.HasPrefix(strings.ToLower(expr), strings.ToLower(prefix)) || !strings.HasSuffix(expr, ")") { - return nil, false, nil - } - inner := strings.TrimSpace(expr[len(prefix) : len(expr)-1]) - parts, err := splitTopLevelByComma(inner) - if err != nil { - return nil, true, errors.New("invalid expression") - } - for i := range parts { - parts[i] = strings.TrimSpace(parts[i]) - } - return parts, true, nil -} - -func addAttributeValue(current attributeValue, exists bool, addValue attributeValue) (attributeValue, error) { - if addValue.hasNumberType() { - return addNumericUpdateValue(current, exists, addValue) - } - if addValue.hasStringSetType() || addValue.hasNumberSetType() || addValue.hasBinarySetType() { - return addSetUpdateValue(current, exists, addValue) - } - return attributeValue{}, errors.New("ADD supports only number or set attributes") -} - -func addNumericUpdateValue(current attributeValue, exists bool, addValue attributeValue) (attributeValue, error) { - if !exists { - return cloneAttributeValue(addValue), nil - } - if !current.hasNumberType() { - return attributeValue{}, errors.New("ADD supports only number attributes") - } - sum, err := addNumericAttributeValues(current.numberValue(), addValue.numberValue()) - if err != nil { - return attributeValue{}, err - } - return attributeValue{N: &sum}, nil -} - -func addSetUpdateValue(current attributeValue, exists bool, addValue attributeValue) (attributeValue, error) { - if !exists { - return cloneAttributeValue(addValue), nil - } - switch { - case current.hasStringSetType() && addValue.hasStringSetType(): - return attributeValue{SS: mergeStringSet(current.SS, addValue.SS)}, nil - case current.hasNumberSetType() && addValue.hasNumberSetType(): - return attributeValue{NS: mergeNumberSet(current.NS, addValue.NS)}, nil - case current.hasBinarySetType() && addValue.hasBinarySetType(): - return attributeValue{BS: mergeBinarySet(current.BS, addValue.BS)}, nil - default: - return attributeValue{}, errors.New("ADD supports only matching set attribute types") - } -} - -func mergeStringSet(current []string, add []string) []string { - out := make([]string, 0, len(current)+len(add)) - seen := make(map[string]struct{}, len(current)+len(add)) - for _, value := range append(append([]string(nil), current...), add...) { - if _, ok := seen[value]; ok { - continue - } - seen[value] = struct{}{} - out = append(out, value) - } - return out -} - -func mergeNumberSet(current []string, add []string) []string { - out := make([]string, 0, len(current)+len(add)) - seen := make(map[string]struct{}, len(current)+len(add)) - for _, value := range append(append([]string(nil), current...), add...) { - key := canonicalNumberString(value) - if _, ok := seen[key]; ok { - continue - } - seen[key] = struct{}{} - out = append(out, value) - } - return out -} - -func mergeBinarySet(current [][]byte, add [][]byte) [][]byte { - out := make([][]byte, 0, len(current)+len(add)) - seen := make(map[string]struct{}, len(current)+len(add)) - for _, value := range append(cloneBinarySet(current), add...) { - key := string(value) - if _, ok := seen[key]; ok { - continue - } - seen[key] = struct{}{} - out = append(out, bytes.Clone(value)) - } - return out -} - -func validateConditionOnItem(expr string, names map[string]string, values map[string]attributeValue, item map[string]attributeValue) error { - cond, err := replaceNames(expr, names) - if err != nil { - return err - } - cond = strings.TrimSpace(cond) - if cond == "" { - return nil - } - ok, err := evalConditionExpression(cond, item, values) - if err != nil { - return err - } - if !ok { - return errors.New("conditional check failed") - } - return nil -} - -func evalConditionExpression(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - expr = trimOuterParens(strings.TrimSpace(expr)) - if expr == "" { - return true, nil - } - if ok, handled, err := evalLogicalCondition(expr, "OR", item, values); handled { - return ok, err - } - if ok, handled, err := evalLogicalCondition(expr, "AND", item, values); handled { - return ok, err - } - if rest, ok := trimLeadingKeyword(expr, "NOT"); ok { - ok, err := evalConditionExpression(rest, item, values) - if err != nil { - return false, err - } - return !ok, nil - } - return evalAtomicCondition(expr, item, values) -} - -func trimOuterParens(expr string) string { - for { - expr = strings.TrimSpace(expr) - if !hasOuterParens(expr) { - return expr - } - expr = expr[1 : len(expr)-1] - } -} - -func splitTopLevelByKeyword(expr string, keyword string) []string { - if expr == "" { - return nil - } - upper := strings.ToUpper(expr) - target := strings.ToUpper(keyword) - targetLen := len(target) - if targetLen == 0 { - return nil - } - depth := 0 - last := 0 - betweenPending := false - parts := make([]string, 0, splitPartsInitialCapacity) - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth != 0 { - continue - } - if nextIndex, handled, nextPending := consumeBetweenSplitState(expr, upper, keyword, i, betweenPending); handled { - betweenPending = nextPending - i = nextIndex - continue - } - if !shouldSplitKeywordAt(expr, upper, target, targetLen, i) { - continue - } - part, ok := trimmedNonEmpty(expr[last:i]) - if !ok { - return nil - } - parts = append(parts, part) - i += targetLen - 1 - last = i + 1 - } - return finalizeKeywordSplit(expr[last:], parts) -} - -func consumeBetweenSplitState(expr string, upper string, keyword string, index int, betweenPending bool) (int, bool, bool) { - if !strings.EqualFold(keyword, "AND") { - return index, false, betweenPending - } - if matchesLogicalKeyword(expr, upper, "BETWEEN", index) { - return index + len("BETWEEN") - 1, true, true - } - if betweenPending && matchesLogicalKeyword(expr, upper, "AND", index) { - return index + len("AND") - 1, true, false - } - return index, false, betweenPending -} - -func shouldSplitKeywordAt(expr string, upper string, target string, targetLen int, index int) bool { - return matchesKeywordTokenAt(upper, target, index) && - isLogicalKeywordBoundary(expr, index-1) && - isLogicalKeywordBoundary(expr, index+targetLen) -} - -func matchesLogicalKeyword(expr string, upper string, keyword string, index int) bool { - return matchesKeywordTokenAt(upper, keyword, index) && - isLogicalKeywordBoundary(expr, index-1) && - isLogicalKeywordBoundary(expr, index+len(keyword)) -} - -func evalLogicalCondition(expr string, keyword string, item map[string]attributeValue, values map[string]attributeValue) (bool, bool, error) { - parts := splitTopLevelByKeyword(expr, keyword) - if len(parts) == 0 { - return false, false, nil - } - if strings.EqualFold(keyword, "OR") { - ok, err := evalConditionAny(parts, item, values) - return ok, true, err - } - ok, err := evalConditionAll(parts, item, values) - return ok, true, err -} - -func evalConditionAny(parts []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - for _, part := range parts { - ok, err := evalConditionExpression(part, item, values) - if err != nil { - return false, err - } - if ok { - return true, nil - } - } - return false, nil -} - -func evalConditionAll(parts []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - for _, part := range parts { - ok, err := evalConditionExpression(part, item, values) - if err != nil { - return false, err - } - if !ok { - return false, nil - } - } - return true, nil -} - -func evalAtomicCondition(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - for _, handler := range conditionFunctionHandlers { - if ok, handled, err := evalNamedConditionFunction(expr, item, values, handler); handled { - return ok, err - } - } - if ok, handled, err := evalConditionBetween(expr, item, values); handled { - return ok, err - } - if ok, handled, err := evalConditionIn(expr, item, values); handled { - return ok, err - } - return evalConditionComparison(expr, item, values) -} - -type conditionFunctionHandler struct { - name string - eval func([]string, map[string]attributeValue, map[string]attributeValue) (bool, error) -} - -var conditionFunctionHandlers = []conditionFunctionHandler{ - { - name: "attribute_exists", - eval: func(args []string, item map[string]attributeValue, _ map[string]attributeValue) (bool, error) { - return evalAttributeExistsCondition(args, item) - }, - }, - { - name: "attribute_not_exists", - eval: func(args []string, item map[string]attributeValue, _ map[string]attributeValue) (bool, error) { - return evalAttributeNotExistsCondition(args, item) - }, - }, - {name: "attribute_type", eval: evalAttributeTypeCondition}, - {name: "begins_with", eval: evalBeginsWithCondition}, - {name: "contains", eval: evalContainsCondition}, -} - -func evalNamedConditionFunction( - expr string, - item map[string]attributeValue, - values map[string]attributeValue, - handler conditionFunctionHandler, -) (bool, bool, error) { - args, ok, err := parseExpressionFunctionArgs(expr, handler.name) - if err != nil { - return false, true, err - } - if !ok { - return false, false, nil - } - value, err := handler.eval(args, item, values) - return value, true, err -} - -func evalAttributeExistsCondition(args []string, item map[string]attributeValue) (bool, error) { - if len(args) != 1 { - return false, errors.New("unsupported condition expression") - } - _, found, err := resolveDocumentPath(item, args[0]) - if err != nil { - return false, err - } - return found, nil -} - -func evalAttributeNotExistsCondition(args []string, item map[string]attributeValue) (bool, error) { - ok, err := evalAttributeExistsCondition(args, item) - if err != nil { - return false, err - } - return !ok, nil -} - -func evalAttributeTypeCondition(args []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - if len(args) != updateSplitCount { - return false, errors.New("unsupported condition expression") - } - value, found, err := resolveDocumentPath(item, args[0]) - if err != nil || !found { - return false, err - } - typeValue, ok := values[strings.TrimSpace(args[1])] - if !ok || !typeValue.hasStringType() { - return false, errors.New("unsupported condition expression") - } - return dynamoAttributeType(value) == typeValue.stringValue(), nil -} - -func evalBeginsWithCondition(args []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - if len(args) != updateSplitCount { - return false, errors.New("unsupported condition expression") - } - left, found, err := resolveDocumentPath(item, args[0]) - if err != nil || !found { - return false, err - } - right, ok := values[strings.TrimSpace(args[1])] - if !ok { - return false, errors.New("missing condition value") - } - switch { - case left.hasStringType() && right.hasStringType(): - return strings.HasPrefix(left.stringValue(), right.stringValue()), nil - case left.hasBinaryType() && right.hasBinaryType(): - return bytes.HasPrefix(left.B, right.B), nil - default: - return false, nil - } -} - -func evalContainsCondition(args []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - if len(args) != updateSplitCount { - return false, errors.New("unsupported condition expression") - } - left, found, err := resolveDocumentPath(item, args[0]) - if err != nil || !found { - return false, err - } - right, ok := values[strings.TrimSpace(args[1])] - if !ok { - return false, errors.New("missing condition value") - } - return attributeValueContains(left, right), nil -} - -func attributeValueContains(left attributeValue, right attributeValue) bool { - for _, eval := range attributeValueContainsEvaluators { - if handled, ok := eval(left, right); handled { - return ok - } - } - return false -} - -type attributeValueContainsEvaluator func(attributeValue, attributeValue) (bool, bool) - -var attributeValueContainsEvaluators = []attributeValueContainsEvaluator{ - containsStringAttributeValue, - containsBinaryAttributeValue, - containsListAttributeValue, - containsStringSetAttributeValue, - containsNumberSetAttributeValue, - containsBinarySetAttributeValue, -} - -func containsStringAttributeValue(left attributeValue, right attributeValue) (bool, bool) { - if !left.hasStringType() || !right.hasStringType() { - return false, false - } - return true, strings.Contains(left.stringValue(), right.stringValue()) -} - -func containsBinaryAttributeValue(left attributeValue, right attributeValue) (bool, bool) { - if !left.hasBinaryType() || !right.hasBinaryType() { - return false, false - } - return true, bytes.Contains(left.B, right.B) -} - -func containsListAttributeValue(left attributeValue, right attributeValue) (bool, bool) { - if !left.hasListType() { - return false, false - } - return true, listContainsAttributeValue(left.L, right) -} - -func containsStringSetAttributeValue(left attributeValue, right attributeValue) (bool, bool) { - if !left.hasStringSetType() || !right.hasStringType() { - return false, false - } - return true, stringSetContains(left.SS, right.stringValue()) -} - -func containsNumberSetAttributeValue(left attributeValue, right attributeValue) (bool, bool) { - if !left.hasNumberSetType() || !right.hasNumberType() { - return false, false - } - return true, numberSetContains(left.NS, right.numberValue()) -} - -func containsBinarySetAttributeValue(left attributeValue, right attributeValue) (bool, bool) { - if !left.hasBinarySetType() || !right.hasBinaryType() { - return false, false - } - return true, binarySetContains(left.BS, right.B) -} - -func listContainsAttributeValue(values []attributeValue, needle attributeValue) bool { - for _, value := range values { - if attributeValueEqual(value, needle) { - return true - } - } - return false -} - -func stringSetContains(values []string, needle string) bool { - return slices.Contains(values, needle) -} - -func numberSetContains(values []string, needle string) bool { - for _, value := range values { - if cmp, ok := compareNumericAttributeString(value, needle); ok && cmp == 0 { - return true - } - } - return false -} - -func binarySetContains(values [][]byte, needle []byte) bool { - for _, value := range values { - if bytes.Equal(value, needle) { - return true - } - } - return false -} - -func evalConditionBetween(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, bool, error) { - betweenIndex := findTopLevelKeywordIndex(expr, "BETWEEN") - if betweenIndex < 0 { - return false, false, nil - } - leftExpr := strings.TrimSpace(expr[:betweenIndex]) - rest := strings.TrimSpace(expr[betweenIndex+len("BETWEEN"):]) - andIndex := findTopLevelKeywordIndex(rest, "AND") - if andIndex < 0 { - return false, true, errors.New("unsupported condition expression") - } - lowerExpr := strings.TrimSpace(rest[:andIndex]) - upperExpr := strings.TrimSpace(rest[andIndex+len("AND"):]) - left, found, err := resolveConditionOperand(leftExpr, item, values) - if err != nil || !found { - return false, true, err - } - lower, found, err := resolveConditionOperand(lowerExpr, item, values) - if err != nil || !found { - return false, true, err - } - upper, found, err := resolveConditionOperand(upperExpr, item, values) - if err != nil || !found { - return false, true, err - } - return compareAttributeValueSortKey(left, lower) >= 0 && compareAttributeValueSortKey(left, upper) <= 0, true, nil -} - -func evalConditionIn(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, bool, error) { - inIndex := findTopLevelKeywordIndex(expr, "IN") - if inIndex < 0 { - return false, false, nil - } - left, parts, err := parseConditionInOperands(expr, inIndex, item, values) - if err != nil { - return false, true, err - } - ok, err := conditionInListContains(left, parts, item, values) - return ok, true, err -} - -func parseConditionInOperands(expr string, inIndex int, item map[string]attributeValue, values map[string]attributeValue) (attributeValue, []string, error) { - leftExpr := strings.TrimSpace(expr[:inIndex]) - rest := strings.TrimSpace(expr[inIndex+len("IN"):]) - if !strings.HasPrefix(rest, "(") || !strings.HasSuffix(rest, ")") { - return attributeValue{}, nil, errors.New("unsupported condition expression") - } - left, found, err := resolveConditionOperand(leftExpr, item, values) - if err != nil || !found { - return attributeValue{}, nil, err - } - parts, err := splitTopLevelByComma(rest[1 : len(rest)-1]) - if err != nil { - return attributeValue{}, nil, errors.New("unsupported condition expression") - } - return left, parts, nil -} - -func conditionInListContains(left attributeValue, parts []string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - for _, part := range parts { - candidate, found, err := resolveConditionOperand(part, item, values) - if err != nil { - return false, err - } - if found && attributeValueEqual(left, candidate) { - return true, nil - } - } - return false, nil -} - -func evalConditionComparison(expr string, item map[string]attributeValue, values map[string]attributeValue) (bool, error) { - index, operator, ok := findTopLevelConditionComparator(expr) - if !ok { - return false, errors.New("unsupported condition expression") - } - left, right, err := resolveConditionComparisonOperands(expr, index, operator, item, values) - if err != nil { - return false, err - } - return compareConditionValues(operator, left, right) -} - -func resolveConditionComparisonOperands( - expr string, - index int, - operator string, - item map[string]attributeValue, - values map[string]attributeValue, -) (attributeValue, attributeValue, error) { - leftExpr := strings.TrimSpace(expr[:index]) - rightExpr := strings.TrimSpace(expr[index+len(operator):]) - left, found, err := resolveConditionOperand(leftExpr, item, values) - if err != nil || !found { - return attributeValue{}, attributeValue{}, err - } - right, found, err := resolveConditionOperand(rightExpr, item, values) - if err != nil || !found { - return attributeValue{}, attributeValue{}, err - } - return left, right, nil -} - -func compareConditionValues(operator string, left attributeValue, right attributeValue) (bool, error) { - switch operator { - case "=": - return attributeValueEqual(left, right), nil - case "<>": - return !attributeValueEqual(left, right), nil - case "<": - return compareAttributeValueSortKey(left, right) < 0, nil - case "<=": - return compareAttributeValueSortKey(left, right) <= 0, nil - case ">": - return compareAttributeValueSortKey(left, right) > 0, nil - case ">=": - return compareAttributeValueSortKey(left, right) >= 0, nil - default: - return false, errors.New("unsupported condition expression") - } -} - -func resolveConditionOperand(expr string, item map[string]attributeValue, values map[string]attributeValue) (attributeValue, bool, error) { - expr = strings.TrimSpace(expr) - if expr == "" { - return attributeValue{}, false, errors.New("unsupported condition expression") - } - if args, ok, err := parseExpressionFunctionArgs(expr, "size"); ok || err != nil { - if err != nil { - return attributeValue{}, false, err - } - return resolveConditionSizeOperand(args, item) - } - if strings.HasPrefix(expr, ":") { - value, ok := values[expr] - if !ok { - return attributeValue{}, false, errors.New("missing condition value") - } - return cloneAttributeValue(value), true, nil - } - value, found, err := resolveDocumentPath(item, expr) - if err != nil { - return attributeValue{}, false, err - } - return value, found, nil -} - -func resolveConditionSizeOperand(args []string, item map[string]attributeValue) (attributeValue, bool, error) { - if len(args) != 1 { - return attributeValue{}, false, errors.New("unsupported condition expression") - } - value, found, err := resolveDocumentPath(item, args[0]) - if err != nil || !found { - return attributeValue{}, false, err - } - size := attributeValueSize(value) - sizeString := strconv.Itoa(size) - return attributeValue{N: &sizeString}, true, nil -} - -func attributeValueSize(value attributeValue) int { - switch { - case value.hasStringType(): - return len(value.stringValue()) - case value.hasBinaryType(): - return len(value.B) - case value.hasStringSetType(): - return len(value.SS) - case value.hasNumberSetType(): - return len(value.NS) - case value.hasBinarySetType(): - return len(value.BS) - case value.hasListType(): - return len(value.L) - case value.hasMapType(): - return len(value.M) - default: - return 0 - } -} - -func findTopLevelKeywordIndex(expr string, keyword string) int { - upper := strings.ToUpper(expr) - target := strings.ToUpper(keyword) - depth := 0 - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth != 0 || !matchesKeywordTokenAt(upper, target, i) { - continue - } - if !isLogicalKeywordBoundary(expr, i-1) || !isLogicalKeywordBoundary(expr, i+len(target)) { - continue - } - return i - } - return -1 -} - -func trimLeadingKeyword(expr string, keyword string) (string, bool) { - upper := strings.ToUpper(strings.TrimSpace(expr)) - keyword = strings.ToUpper(keyword) - if !strings.HasPrefix(upper, keyword) { - return "", false - } - trimmed := strings.TrimSpace(expr) - if !isLogicalKeywordBoundary(trimmed, len(keyword)) { - return "", false - } - return strings.TrimSpace(trimmed[len(keyword):]), true -} - -func findTopLevelConditionComparator(expr string) (int, string, bool) { - operators := []string{"<>", "<=", ">=", "=", "<", ">"} - depth := 0 - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth != 0 { - continue - } - for _, operator := range operators { - if strings.HasPrefix(expr[i:], operator) { - return i, operator, true - } - } - } - return 0, "", false -} - -func dynamoAttributeType(value attributeValue) string { - kind, count := detectAttributeValueKind(value) - if count != 1 { - return "" - } - return string(kind) -} - -func hasOuterParens(expr string) bool { - if len(expr) < 2 || expr[0] != '(' || expr[len(expr)-1] != ')' { - return false - } - depth := 0 - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth == 0 && i != len(expr)-1 { - return false - } - if depth < 0 { - return false - } - } - return depth == 0 -} - -func nextParenDepth(depth int, ch byte) int { - switch ch { - case '(': - return depth + 1 - case ')': - return depth - 1 - default: - return depth - } -} - -func matchesKeywordTokenAt(upperExpr string, target string, pos int) bool { - end := pos + len(target) - if end > len(upperExpr) { - return false - } - return upperExpr[pos:end] == target -} - -func isLogicalKeywordBoundary(s string, pos int) bool { - if pos < 0 || pos >= len(s) { - return true - } - ch := s[pos] - // Keep identifier-style characters as token characters so expressions like - // "MY_AND_VAR" or "a-OR-b" are not split at logical keyword substrings. - if isExpressionAttributeNameByte(ch) { - return false - } - return true -} - -func trimmedNonEmpty(s string) (string, bool) { - trimmed := strings.TrimSpace(s) - return trimmed, trimmed != "" -} - -func finalizeKeywordSplit(tailExpr string, parts []string) []string { - if len(parts) == 0 { - return nil - } - tail, ok := trimmedNonEmpty(tailExpr) - if !ok { - return nil - } - return append(parts, tail) -} - -type parsedKeyConditionTerm struct { - attr string - op queryRangeOperator - placeholder1 string - placeholder2 string -} - -func parseKeyConditionExpression(expr string) ([]parsedKeyConditionTerm, error) { - expr = strings.TrimSpace(expr) - if expr == "" { - return nil, errors.New("unsupported key condition expression") - } - parts, err := splitKeyConditionTerms(expr) - if err != nil { - return nil, err - } - if len(parts) > updateSplitCount { - return nil, errors.New("unsupported key condition expression") - } - terms := make([]parsedKeyConditionTerm, 0, len(parts)) - for _, part := range parts { - term, err := parseKeyConditionTerm(part) - if err != nil { - return nil, err - } - terms = append(terms, term) - } - return terms, nil -} - -func splitKeyConditionTerms(expr string) ([]string, error) { - upper := strings.ToUpper(expr) - depth := 0 - last := 0 - betweenPending := false - parts := make([]string, 0, splitPartsInitialCapacity) - for i := 0; i < len(expr); i++ { - depth = nextParenDepth(depth, expr[i]) - if depth != 0 { - continue - } - keyword := keyConditionKeywordAt(expr, upper, i) - if keyword == "" { - continue - } - if keyword == "BETWEEN" { - betweenPending = true - i += len(keyword) - 1 - continue - } - if betweenPending { - betweenPending = false - i += len(keyword) - 1 - continue - } - part, ok := trimmedNonEmpty(expr[last:i]) - if !ok { - return nil, errors.New("unsupported key condition expression") - } - parts = append(parts, part) - i += len(keyword) - 1 - last = i + 1 - } - if betweenPending { - return nil, errors.New("unsupported key condition expression") - } - tail, ok := trimmedNonEmpty(expr[last:]) - if !ok { - return nil, errors.New("unsupported key condition expression") - } - if len(parts) == 0 { - return []string{tail}, nil - } - return append(parts, tail), nil -} - -func keyConditionKeywordAt(expr string, upper string, pos int) string { - if matchesKeywordTokenAt(upper, "BETWEEN", pos) && - isLogicalKeywordBoundary(expr, pos-1) && - isLogicalKeywordBoundary(expr, pos+len("BETWEEN")) { - return "BETWEEN" - } - if matchesKeywordTokenAt(upper, "AND", pos) && - isLogicalKeywordBoundary(expr, pos-1) && - isLogicalKeywordBoundary(expr, pos+len("AND")) { - return "AND" - } - return "" -} - -func parseKeyConditionTerm(term string) (parsedKeyConditionTerm, error) { - term = strings.TrimSpace(term) - if t, ok, err := parseBeginsWithKeyConditionTerm(term); ok || err != nil { - return t, err - } - if t, ok, err := parseBetweenKeyConditionTerm(term); ok || err != nil { - return t, err - } - return parseComparisonKeyConditionTerm(term) -} - -func parseBeginsWithKeyConditionTerm(term string) (parsedKeyConditionTerm, bool, error) { - const prefix = "BEGINS_WITH(" - upper := strings.ToUpper(term) - if !strings.HasPrefix(upper, prefix) { - return parsedKeyConditionTerm{}, false, nil - } - if !strings.HasSuffix(term, ")") { - return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") - } - inner := strings.TrimSpace(term[len(prefix) : len(term)-1]) - parts := strings.SplitN(inner, ",", updateSplitCount) - if len(parts) != updateSplitCount { - return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") - } - attrName := strings.TrimSpace(parts[0]) - placeholder := strings.TrimSpace(parts[1]) - if attrName == "" || !strings.HasPrefix(placeholder, ":") { - return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") - } - return parsedKeyConditionTerm{ - attr: attrName, - op: queryRangeOpBeginsWith, - placeholder1: placeholder, - }, true, nil -} - -func parseBetweenKeyConditionTerm(term string) (parsedKeyConditionTerm, bool, error) { - upper := strings.ToUpper(term) - betweenIdx := strings.Index(upper, " BETWEEN ") - if betweenIdx < 0 { - return parsedKeyConditionTerm{}, false, nil - } - attrName := strings.TrimSpace(term[:betweenIdx]) - rest := strings.TrimSpace(term[betweenIdx+len(" BETWEEN "):]) - andIdx := strings.Index(strings.ToUpper(rest), " AND ") - if andIdx < 0 { - return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") - } - placeholder1 := strings.TrimSpace(rest[:andIdx]) - placeholder2 := strings.TrimSpace(rest[andIdx+len(" AND "):]) - if attrName == "" || !strings.HasPrefix(placeholder1, ":") || !strings.HasPrefix(placeholder2, ":") { - return parsedKeyConditionTerm{}, true, errors.New("unsupported key condition expression") - } - return parsedKeyConditionTerm{ - attr: attrName, - op: queryRangeOpBetween, - placeholder1: placeholder1, - placeholder2: placeholder2, - }, true, nil -} - -func parseComparisonKeyConditionTerm(term string) (parsedKeyConditionTerm, error) { - operators := []queryRangeOperator{ - queryRangeOpLessOrEq, - queryRangeOpGreaterEq, - queryRangeOpLessThan, - queryRangeOpGreater, - queryRangeOpEqual, - } - for _, op := range operators { - if t, ok := splitComparisonTerm(term, op); ok { - return t, nil - } - } - return parsedKeyConditionTerm{}, errors.New("unsupported key condition expression") -} - -func splitComparisonTerm(term string, op queryRangeOperator) (parsedKeyConditionTerm, bool) { - opStr := string(op) - before, after, ok := strings.Cut(term, opStr) - if !ok { - return parsedKeyConditionTerm{}, false - } - left := strings.TrimSpace(before) - right := strings.TrimSpace(after) - if left == "" || !strings.HasPrefix(right, ":") { - return parsedKeyConditionTerm{}, false - } - return parsedKeyConditionTerm{ - attr: left, - op: op, - placeholder1: right, - }, true -} - -func buildQueryCondition(keySchema dynamoKeySchema, terms []parsedKeyConditionTerm, values map[string]attributeValue) (queryCondition, error) { - hashTerm, rangeTerm, err := classifyQueryConditionTerms(keySchema, terms) - if err != nil { - return queryCondition{}, err - } - hashValue, ok := values[hashTerm.placeholder1] - if !ok { - return queryCondition{}, errors.New("missing key condition value") - } - cond := queryCondition{ - hashAttr: keySchema.HashKey, - hashValue: hashValue, - } - if rangeTerm == nil { - return cond, nil - } - value1, ok := values[rangeTerm.placeholder1] - if !ok { - return queryCondition{}, errors.New("missing key condition value") - } - rangeCond := &queryRangeCondition{ - attr: keySchema.RangeKey, - op: rangeTerm.op, - value1: value1, - } - if rangeTerm.op == queryRangeOpBetween { - value2, ok := values[rangeTerm.placeholder2] - if !ok { - return queryCondition{}, errors.New("missing key condition value") - } - rangeCond.value2 = value2 - } - cond.rangeCond = rangeCond - return cond, nil -} - -func classifyQueryConditionTerms( - keySchema dynamoKeySchema, - terms []parsedKeyConditionTerm, -) (parsedKeyConditionTerm, *parsedKeyConditionTerm, error) { - if len(terms) == 0 || len(terms) > updateSplitCount { - return parsedKeyConditionTerm{}, nil, errors.New("unsupported key condition") - } - hashTerm, ok := findHashConditionTerm(keySchema.HashKey, terms) - if !ok { - return parsedKeyConditionTerm{}, nil, errors.New("unsupported key condition") - } - if len(terms) == 1 { - return hashTerm, nil, nil - } - rangeTerm, ok := findRangeConditionTerm(keySchema.RangeKey, terms, hashTerm) - if !ok { - return parsedKeyConditionTerm{}, nil, errors.New("unsupported key condition") - } - return hashTerm, &rangeTerm, nil -} - -func findHashConditionTerm(hashKey string, terms []parsedKeyConditionTerm) (parsedKeyConditionTerm, bool) { - var hashTerm parsedKeyConditionTerm - found := false - for _, term := range terms { - if term.attr != hashKey || term.op != queryRangeOpEqual { - continue - } - if found { - return parsedKeyConditionTerm{}, false - } - hashTerm = term - found = true - } - return hashTerm, found -} - -func findRangeConditionTerm( - rangeKey string, - terms []parsedKeyConditionTerm, - hashTerm parsedKeyConditionTerm, -) (parsedKeyConditionTerm, bool) { - if strings.TrimSpace(rangeKey) == "" { - return parsedKeyConditionTerm{}, false - } - for _, term := range terms { - if term == hashTerm { - continue - } - if term.attr != rangeKey { - return parsedKeyConditionTerm{}, false - } - return term, true - } - return parsedKeyConditionTerm{}, false -} - -func matchesQueryCondition(item map[string]attributeValue, cond queryCondition) bool { - hashAttr, ok := item[cond.hashAttr] - if !ok || !attributeValueEqual(hashAttr, cond.hashValue) { - return false - } - if cond.rangeCond == nil { - return true - } - rangeAttr, ok := item[cond.rangeCond.attr] - if !ok { - return false - } - return matchesQueryRangeCondition(rangeAttr, *cond.rangeCond) -} - -func matchesQueryRangeCondition(attr attributeValue, cond queryRangeCondition) bool { - if cond.op == queryRangeOpBeginsWith { - return matchesQueryRangeBeginsWith(attr, cond.value1) - } - if cond.op == queryRangeOpBetween { - return matchesQueryRangeBetween(attr, cond.value1, cond.value2) - } - return matchesQueryRangeCompare(attr, cond.value1, cond.op) -} - -func matchesQueryRangeCompare(attr attributeValue, right attributeValue, op queryRangeOperator) bool { - switch op { - case queryRangeOpEqual: - return attributeValueEqual(attr, right) - case queryRangeOpLessThan: - return compareAttributeValueSortKey(attr, right) < 0 - case queryRangeOpLessOrEq: - return compareAttributeValueSortKey(attr, right) <= 0 - case queryRangeOpGreater: - return compareAttributeValueSortKey(attr, right) > 0 - case queryRangeOpGreaterEq: - return compareAttributeValueSortKey(attr, right) >= 0 - case queryRangeOpBetween, queryRangeOpBeginsWith: - return false - default: - return false - } -} - -func matchesQueryRangeBetween(attr attributeValue, lower attributeValue, upper attributeValue) bool { - return compareAttributeValueSortKey(attr, lower) >= 0 && - compareAttributeValueSortKey(attr, upper) <= 0 -} - -func matchesQueryRangeBeginsWith(attr attributeValue, prefixValue attributeValue) bool { - attrKey, err := attributeValueAsKey(attr) - if err != nil { - return false - } - prefix, err := attributeValueAsKey(prefixValue) - if err != nil { - return false - } - return strings.HasPrefix(attrKey, prefix) -} - -func parseCreateTableKeySchema(elems []createTableKeySchemaElement) (dynamoKeySchema, error) { - var ks dynamoKeySchema - for _, e := range elems { - switch strings.ToUpper(strings.TrimSpace(e.KeyType)) { - case "HASH": - ks.HashKey = e.AttributeName - case "RANGE": - ks.RangeKey = e.AttributeName - } - } - if strings.TrimSpace(ks.HashKey) == "" { - return dynamoKeySchema{}, errors.New("missing HASH key schema") - } - return ks, nil -} - -func (t *dynamoTableSchema) keySchemaForQuery(indexName string) (dynamoKeySchema, error) { - if strings.TrimSpace(indexName) == "" { - return t.PrimaryKey, nil - } - gsi, ok := t.GlobalSecondaryIndexes[indexName] - if !ok { - return dynamoKeySchema{}, errors.New("unknown index") - } - return gsi.KeySchema, nil -} - -func (t *dynamoTableSchema) gsiProjectedAttributeSet(indexName string) (bool, map[string]struct{}, error) { - gsi, ok := t.GlobalSecondaryIndexes[indexName] - if !ok { - return false, nil, errors.New("unknown index") - } - if strings.EqualFold(gsi.Projection.ProjectionType, "ALL") { - return true, nil, nil - } - out := map[string]struct{}{ - t.PrimaryKey.HashKey: {}, - gsi.KeySchema.HashKey: {}, - } - if t.PrimaryKey.RangeKey != "" { - out[t.PrimaryKey.RangeKey] = struct{}{} - } - if gsi.KeySchema.RangeKey != "" { - out[gsi.KeySchema.RangeKey] = struct{}{} - } - for _, attr := range gsi.Projection.NonKeyAttributes { - out[attr] = struct{}{} - } - return false, out, nil -} - -func (t *dynamoTableSchema) projectItemForIndex(indexName string, item map[string]attributeValue) (map[string]attributeValue, error) { - allProjected, projected, err := t.gsiProjectedAttributeSet(indexName) - if err != nil { - return nil, err - } - if allProjected { - return cloneAttributeValueMap(item), nil - } - out := make(map[string]attributeValue, len(projected)) - for attr := range projected { - if value, ok := item[attr]; ok { - out[attr] = cloneAttributeValue(value) - } - } - return out, nil -} - -func (t *dynamoTableSchema) usesOrderedKeyEncoding() bool { - return t != nil && t.KeyEncodingVersion >= dynamoOrderedKeyEncodingV2 -} - -func (t *dynamoTableSchema) needsLegacyKeyMigration() bool { - return t != nil && (!t.usesOrderedKeyEncoding() || t.MigratingFromGeneration != 0) -} - -func (t *dynamoTableSchema) migrationSourceSchema() *dynamoTableSchema { - if t == nil || t.MigratingFromGeneration == 0 { - return nil - } - return &dynamoTableSchema{ - TableName: t.TableName, - AttributeDefinitions: t.AttributeDefinitions, - PrimaryKey: t.PrimaryKey, - GlobalSecondaryIndexes: t.GlobalSecondaryIndexes, - KeyEncodingVersion: 0, - Generation: t.MigratingFromGeneration, - } -} - -func (t *dynamoTableSchema) itemKeyFromAttributes(attrs map[string]attributeValue) ([]byte, error) { - if !t.usesOrderedKeyEncoding() { - return t.legacyItemKeyFromAttributes(attrs) - } - primary, err := t.primaryKeyValues(attrs) - if err != nil { - return nil, err - } - return dynamoItemKey(t.TableName, t.Generation, primary.hash, primary.rangeKey), nil -} - -func (t *dynamoTableSchema) legacyItemKeyFromAttributes(attrs map[string]attributeValue) ([]byte, error) { - hashAttr, ok := attrs[t.PrimaryKey.HashKey] - if !ok { - return nil, errors.New("missing hash key attribute") - } - hashKey, err := attributeValueAsKey(hashAttr) - if err != nil { - return nil, err - } - rangeKey := "" - if t.PrimaryKey.RangeKey != "" { - rangeAttr, ok := attrs[t.PrimaryKey.RangeKey] - if !ok { - return nil, errors.New("missing range key attribute") - } - rangeKey, err = attributeValueAsKey(rangeAttr) - if err != nil { - return nil, err - } - } - return legacyDynamoItemKey(t.TableName, t.Generation, hashKey, rangeKey), nil -} - -func (t *dynamoTableSchema) gsiKeyFromAttributes(indexName string, attrs map[string]attributeValue) ([]byte, bool, error) { - if !t.usesOrderedKeyEncoding() { - return t.legacyGSIKeyFromAttributes(indexName, attrs) - } - gsi, ok := t.GlobalSecondaryIndexes[indexName] - if !ok { - return nil, false, errors.New("global secondary index not found") - } - primary, err := t.primaryKeyValues(attrs) - if err != nil { - return nil, false, err - } - index, include, err := gsiKeyValues(attrs, gsi.KeySchema) - if err != nil || !include { - return nil, include, err - } - return dynamoGSIKey(t.TableName, t.Generation, indexName, index.hash, index.rangeKey, primary.hash, primary.rangeKey), true, nil -} - -func (t *dynamoTableSchema) legacyGSIKeyFromAttributes(indexName string, attrs map[string]attributeValue) ([]byte, bool, error) { - gsi, ok := t.GlobalSecondaryIndexes[indexName] - if !ok { - return nil, false, errors.New("global secondary index not found") - } - pkHash, pkRange, err := t.legacyPrimaryKeyValues(attrs) - if err != nil { - return nil, false, err - } - indexHash, indexRange, include, err := legacyGSIKeyValues(attrs, gsi.KeySchema) - if err != nil || !include { - return nil, include, err - } - return legacyDynamoGSIKey(t.TableName, t.Generation, indexName, indexHash, indexRange, pkHash, pkRange), true, nil -} - -func (t *dynamoTableSchema) gsiEntryKeysForItem(attrs map[string]attributeValue) ([][]byte, error) { - if len(t.GlobalSecondaryIndexes) == 0 || len(attrs) == 0 { - return nil, nil - } - if !t.usesOrderedKeyEncoding() { - return t.legacyGSIEntryKeysForItem(attrs) - } - primary, err := t.primaryKeyValues(attrs) - if err != nil { - return nil, err - } - indexNames := sortedGSIIndexNames(t.GlobalSecondaryIndexes) - keys := make([][]byte, 0, len(indexNames)) - for _, indexName := range indexNames { - gsi := t.GlobalSecondaryIndexes[indexName] - index, include, err := gsiKeyValues(attrs, gsi.KeySchema) - if err != nil { - return nil, err - } - if !include { - continue - } - keys = append(keys, dynamoGSIKey(t.TableName, t.Generation, indexName, index.hash, index.rangeKey, primary.hash, primary.rangeKey)) - } - return keys, nil -} - -func (t *dynamoTableSchema) legacyGSIEntryKeysForItem(attrs map[string]attributeValue) ([][]byte, error) { - primaryHash, primaryRange, err := t.legacyPrimaryKeyValues(attrs) - if err != nil { - return nil, err - } - indexNames := sortedGSIIndexNames(t.GlobalSecondaryIndexes) - keys := make([][]byte, 0, len(indexNames)) - for _, indexName := range indexNames { - gsi := t.GlobalSecondaryIndexes[indexName] - indexHash, indexRange, include, err := legacyGSIKeyValues(attrs, gsi.KeySchema) - if err != nil { - return nil, err - } - if !include { - continue - } - keys = append(keys, legacyDynamoGSIKey(t.TableName, t.Generation, indexName, indexHash, indexRange, primaryHash, primaryRange)) - } - return keys, nil -} - -func (t *dynamoTableSchema) legacyPrimaryKeyValues(attrs map[string]attributeValue) (string, string, error) { - hashAttr, ok := attrs[t.PrimaryKey.HashKey] - if !ok { - return "", "", errors.New("missing hash key attribute") - } - hash, err := attributeValueAsKey(hashAttr) - if err != nil { - return "", "", err - } - rangeKey := "" - if t.PrimaryKey.RangeKey != "" { - rangeAttr, ok := attrs[t.PrimaryKey.RangeKey] - if !ok { - return "", "", errors.New("missing range key attribute") - } - rangeKey, err = attributeValueAsKey(rangeAttr) - if err != nil { - return "", "", err - } - } - return hash, rangeKey, nil -} - -type dynamoEncodedKeyValues struct { - hash []byte - rangeKey []byte -} - -func (t *dynamoTableSchema) primaryKeyValues(attrs map[string]attributeValue) (dynamoEncodedKeyValues, error) { - hashAttr, ok := attrs[t.PrimaryKey.HashKey] - if !ok { - return dynamoEncodedKeyValues{}, errors.New("missing hash key attribute") - } - hash, err := attributeValueAsKeySegment(hashAttr) - if err != nil { - return dynamoEncodedKeyValues{}, err - } - var rangeKey []byte - if t.PrimaryKey.RangeKey != "" { - rangeAttr, ok := attrs[t.PrimaryKey.RangeKey] - if !ok { - return dynamoEncodedKeyValues{}, errors.New("missing range key attribute") - } - rangeKey, err = attributeValueAsKeySegment(rangeAttr) - if err != nil { - return dynamoEncodedKeyValues{}, err - } - } - return dynamoEncodedKeyValues{hash: hash, rangeKey: rangeKey}, nil -} - -func gsiKeyValues(attrs map[string]attributeValue, ks dynamoKeySchema) (dynamoEncodedKeyValues, bool, error) { - hashAttr, ok := attrs[ks.HashKey] - if !ok { - return dynamoEncodedKeyValues{}, false, nil - } - hash, err := attributeValueAsKeySegment(hashAttr) - if err != nil { - return dynamoEncodedKeyValues{}, false, err - } - var rangeKey []byte - if ks.RangeKey != "" { - rangeAttr, ok := attrs[ks.RangeKey] - if !ok { - return dynamoEncodedKeyValues{}, false, nil - } - rangeKey, err = attributeValueAsKeySegment(rangeAttr) - if err != nil { - return dynamoEncodedKeyValues{}, false, err - } - } - return dynamoEncodedKeyValues{hash: hash, rangeKey: rangeKey}, true, nil -} - -func legacyGSIKeyValues(attrs map[string]attributeValue, ks dynamoKeySchema) (string, string, bool, error) { - hashAttr, ok := attrs[ks.HashKey] - if !ok { - return "", "", false, nil - } - hash, err := attributeValueAsKey(hashAttr) - if err != nil { - return "", "", false, err - } - rangeKey := "" - if ks.RangeKey != "" { - rangeAttr, ok := attrs[ks.RangeKey] - if !ok { - return "", "", false, nil - } - rangeKey, err = attributeValueAsKey(rangeAttr) - if err != nil { - return "", "", false, err - } - } - return hash, rangeKey, true, nil -} - -func sortedGSIIndexNames(indexes map[string]dynamoGlobalSecondaryIndex) []string { - names := make([]string, 0, len(indexes)) - for name := range indexes { - names = append(names, name) - } - sort.Strings(names) - return names -} - -var attributeValueKeyExtractors = map[attributeValueKind]func(attributeValue) string{ - attributeValueKindString: func(attr attributeValue) string { return attr.stringValue() }, - attributeValueKindNumber: func(attr attributeValue) string { return attr.numberValue() }, - attributeValueKindBinary: func(attr attributeValue) string { return string(attr.B) }, -} - -var attributeValueKeyByteExtractors = map[attributeValueKind]func(attributeValue) []byte{ - attributeValueKindString: func(attr attributeValue) []byte { - return []byte(attr.stringValue()) - }, - attributeValueKindBinary: func(attr attributeValue) []byte { - return bytes.Clone(attr.B) - }, -} - -var attributeValueScalarEqualityComparators = map[attributeValueKind]func(attributeValue, attributeValue) bool{ - attributeValueKindString: func(left attributeValue, right attributeValue) bool { return left.stringValue() == right.stringValue() }, - attributeValueKindNumber: numberAttributeValueEqual, - attributeValueKindBinary: func(left attributeValue, right attributeValue) bool { return bytes.Equal(left.B, right.B) }, - attributeValueKindBool: func(left attributeValue, right attributeValue) bool { return *left.BOOL == *right.BOOL }, - attributeValueKindNull: func(left attributeValue, right attributeValue) bool { return *left.NULL == *right.NULL }, - attributeValueKindStringSet: func(left attributeValue, right attributeValue) bool { - return unorderedStringSlicesEqual(left.SS, right.SS) - }, - attributeValueKindNumberSet: func(left attributeValue, right attributeValue) bool { - return unorderedNumberSlicesEqual(left.NS, right.NS) - }, - attributeValueKindBinarySet: func(left attributeValue, right attributeValue) bool { - return unorderedBinarySlicesEqual(left.BS, right.BS) - }, -} - -var attributeValueSortFormatters = map[attributeValueKind]func(attributeValue) string{ - attributeValueKindString: func(attr attributeValue) string { return attr.stringValue() }, - attributeValueKindNumber: func(attr attributeValue) string { return attr.numberValue() }, - attributeValueKindBinary: func(attr attributeValue) string { return base64.RawURLEncoding.EncodeToString(attr.B) }, - attributeValueKindBool: formatBoolAttributeValue, - attributeValueKindNull: func(attributeValue) string { return "" }, - attributeValueKindStringSet: func(attr attributeValue) string { return strings.Join(sortedStringSlice(attr.SS), "\x00") }, - attributeValueKindNumberSet: func(attr attributeValue) string { return strings.Join(sortedNumberStrings(attr.NS), "\x00") }, - attributeValueKindBinarySet: func(attr attributeValue) string { return strings.Join(sortedBinaryStrings(attr.BS), "\x00") }, -} - -func attributeValueAsKey(attr attributeValue) (string, error) { - kind, count := detectAttributeValueKind(attr) - if count != 1 { - return "", errors.New("unsupported key attribute type") - } - extract, ok := attributeValueKeyExtractors[kind] - if !ok { - return "", errors.New("unsupported key attribute type") - } - return extract(attr), nil -} - -func attributeValueAsKeyBytes(attr attributeValue) ([]byte, error) { - kind, count := detectAttributeValueKind(attr) - if count != 1 { - return nil, errors.New("unsupported key attribute type") - } - if kind == attributeValueKindNumber { - return encodeNumericKeyBytes(attr.numberValue()) - } - extract, ok := attributeValueKeyByteExtractors[kind] - if !ok { - return nil, errors.New("unsupported key attribute type") - } - return extract(attr), nil -} - -func attributeValueAsKeySegment(attr attributeValue) ([]byte, error) { - raw, err := attributeValueAsKeyBytes(attr) - if err != nil { - return nil, err - } - return encodeDynamoKeySegment(raw), nil -} - -type numericKeyParts struct { - negative bool - exponent int64 - digits []byte -} - -func encodeNumericKeyBytes(v string) ([]byte, error) { - parts, err := parseNumericKeyParts(v) - if err != nil { - return nil, err - } - if len(parts.digits) == 0 { - return []byte{0x01}, nil - } - body := encodeOrderedSignedInt64(parts.exponent) - body = append(body, dynamoKeyEscapeByte) - body = append(body, parts.digits...) - if !parts.negative { - return append([]byte{0x02}, body...), nil - } - return append([]byte{0x00}, invertBytes(body)...), nil -} - -func parseNumericKeyParts(v string) (numericKeyParts, error) { - trimmed, negative, exp10, err := parseNumericKeyLiteral(v) - if err != nil { - return numericKeyParts{}, err - } - digits, exponent, zero, err := normalizeNumericKeyParts(trimmed, exp10) - if err != nil { - return numericKeyParts{}, err - } - if zero { - return numericKeyParts{}, nil - } - return numericKeyParts{ - negative: negative, - exponent: exponent, - digits: digits, - }, nil -} - -func parseNumericKeyLiteral(v string) (string, bool, int64, error) { - trimmed := strings.TrimSpace(v) - if trimmed == "" { - return "", false, 0, errors.New("unsupported key attribute type") - } - - negative := false - switch trimmed[0] { - case '+': - trimmed = trimmed[1:] - case '-': - negative = true - trimmed = trimmed[1:] - } - if trimmed == "" { - return "", false, 0, errors.New("unsupported key attribute type") - } - - exp10 := int64(0) - if idx := strings.IndexAny(trimmed, "eE"); idx >= 0 { - expPart := strings.TrimSpace(trimmed[idx+1:]) - trimmed = trimmed[:idx] - parsedExp, err := parseNumericExponent(expPart) - if err != nil { - return "", false, 0, err - } - exp10 = parsedExp - } - return trimmed, negative, exp10, nil -} - -func parseNumericExponent(expPart string) (int64, error) { - if expPart == "" { - return 0, errors.New("unsupported key attribute type") - } - parsedExp, err := strconv.ParseInt(expPart, 10, 64) - if err != nil { - return 0, errors.New("unsupported key attribute type") - } - return parsedExp, nil -} - -func normalizeNumericKeyParts(trimmed string, exp10 int64) ([]byte, int64, bool, error) { - intPart, fracPart, err := splitNumericMantissa(trimmed) - if err != nil { - return nil, 0, false, err - } - combined := intPart + fracPart - leadingZeros := leadingZeroCount(combined) - if leadingZeros == len(combined) { - return nil, 0, true, nil - } - digits := []byte(strings.TrimRight(combined[leadingZeros:], "0")) - if len(digits) == 0 { - return nil, 0, true, nil - } - exponent := int64(len(intPart)) + exp10 - int64(leadingZeros) - return digits, exponent, false, nil -} - -func splitNumericMantissa(trimmed string) (string, string, error) { - if strings.Count(trimmed, ".") > 1 { - return "", "", errors.New("unsupported key attribute type") - } - intPart := trimmed - fracPart := "" - if before, after, ok := strings.Cut(trimmed, "."); ok { - intPart = before - fracPart = after - } - if intPart == "" && fracPart == "" { - return "", "", errors.New("unsupported key attribute type") - } - if !decimalDigitsOnly(intPart) || !decimalDigitsOnly(fracPart) { - return "", "", errors.New("unsupported key attribute type") - } - return intPart, fracPart, nil -} - -func leadingZeroCount(v string) int { - count := 0 - for count < len(v) && v[count] == '0' { - count++ - } - return count -} - -func decimalDigitsOnly(v string) bool { - for i := range v { - if v[i] < '0' || v[i] > '9' { - return false - } - } - return true -} - -func encodeOrderedSignedInt64(v int64) []byte { - switch { - case v < 0: - return append([]byte{0x00}, invertBytes(encodeOrderedUint64(signedMagnitude(v)))...) - case v == 0: - return []byte{0x01} - default: - return append([]byte{0x02}, encodeOrderedUint64(uint64(v))...) - } -} - -func signedMagnitude(v int64) uint64 { - if v >= 0 { - return uint64(v) - } - abs := big.NewInt(v) - abs.Abs(abs) - return abs.Uint64() -} - -var orderedUint64LengthPrefix = [...]byte{0, 1, 2, 3, 4, 5, 6, 7, 8} - -func encodeOrderedUint64(v uint64) []byte { - var buf [8]byte - binary.BigEndian.PutUint64(buf[:], v) - start := 0 - for start < len(buf)-1 && buf[start] == 0 { - start++ - } - width := len(buf) - start - out := make([]byte, 0, width+1) - out = append(out, orderedUint64LengthPrefix[width]) - out = append(out, buf[start:]...) - return out -} - -func invertBytes(in []byte) []byte { - out := make([]byte, len(in)) - for i := range in { - out[i] = ^in[i] - } - return out -} - -func encodeDynamoKeySegment(raw []byte) []byte { - out := encodeDynamoKeySegmentPrefix(raw) - out = append(out, dynamoKeyEscapeByte, dynamoKeyTerminatorByte) - return out -} - -func encodeDynamoKeySegmentPrefix(raw []byte) []byte { - return appendEscapedDynamoKeyBytes(make([]byte, 0, len(raw)+dynamoKeySegmentOverhead), raw) -} - -func appendEscapedDynamoKeyBytes(dst []byte, raw []byte) []byte { - for _, b := range raw { - if b == dynamoKeyEscapeByte { - dst = append(dst, dynamoKeyEscapeByte, dynamoKeyEscapedZeroByte) - continue - } - dst = append(dst, b) - } - return dst -} - -func attributeValueEqual(left attributeValue, right attributeValue) bool { - leftKind, leftCount := detectAttributeValueKind(left) - rightKind, rightCount := detectAttributeValueKind(right) - if leftCount == 0 && rightCount == 0 { - return true - } - if leftCount != 1 || rightCount != 1 || leftKind != rightKind { - return false - } - if leftKind == attributeValueKindMap { - return mapAttributeValueEqual(left, right) - } - if leftKind == attributeValueKindList { - return listAttributeValueEqual(left, right) - } - compare, ok := attributeValueScalarEqualityComparators[leftKind] - if !ok { - return false - } - return compare(left, right) -} - -func numberAttributeValueEqual(left attributeValue, right attributeValue) bool { - cmp, ok := compareNumericAttributeString(left.numberValue(), right.numberValue()) - if !ok { - return left.numberValue() == right.numberValue() - } - return cmp == 0 -} - -func mapAttributeValueEqual(left attributeValue, right attributeValue) bool { - if len(left.M) != len(right.M) { - return false - } - for key, leftValue := range left.M { - rightValue, ok := right.M[key] - if !ok || !attributeValueEqual(leftValue, rightValue) { - return false - } - } - return true -} - -func listAttributeValueEqual(left attributeValue, right attributeValue) bool { - if len(left.L) != len(right.L) { - return false - } - for i := range left.L { - if !attributeValueEqual(left.L[i], right.L[i]) { - return false - } - } - return true -} - -func compareAttributeValueSortKey(left attributeValue, right attributeValue) int { - if left.hasNumberType() && right.hasNumberType() { - if cmp, ok := compareNumericAttributeString(left.numberValue(), right.numberValue()); ok { - return cmp - } - } - if left.hasBinaryType() && right.hasBinaryType() { - return bytes.Compare(left.B, right.B) - } - return strings.Compare(attributeValueSortFallback(left), attributeValueSortFallback(right)) -} - -func compareNumericAttributeString(left string, right string) (int, bool) { - leftRat := &big.Rat{} - rightRat := &big.Rat{} - if _, ok := leftRat.SetString(strings.TrimSpace(left)); !ok { - return 0, false - } - if _, ok := rightRat.SetString(strings.TrimSpace(right)); !ok { - return 0, false - } - return leftRat.Cmp(rightRat), true -} - -func attributeValueSortFallback(attr attributeValue) string { - kind, count := detectAttributeValueKind(attr) - if count != 1 { - return "" - } - format, ok := attributeValueSortFormatters[kind] - if !ok { - return "" - } - return format(attr) -} - -func formatBoolAttributeValue(attr attributeValue) string { - if *attr.BOOL { - return "1" - } - return "0" -} - -func unorderedStringSlicesEqual(left []string, right []string) bool { - if len(left) != len(right) { - return false - } - lv := sortedStringSlice(left) - rv := sortedStringSlice(right) - for i := range lv { - if lv[i] != rv[i] { - return false - } - } - return true -} - -func unorderedNumberSlicesEqual(left []string, right []string) bool { - if len(left) != len(right) { - return false - } - lv := sortedNumberStrings(left) - rv := sortedNumberStrings(right) - for i := range lv { - if lv[i] != rv[i] { - return false - } - } - return true -} - -func unorderedBinarySlicesEqual(left [][]byte, right [][]byte) bool { - if len(left) != len(right) { - return false - } - lv := sortedBinaryStrings(left) - rv := sortedBinaryStrings(right) - for i := range lv { - if lv[i] != rv[i] { - return false - } - } - return true -} - -func sortedStringSlice(in []string) []string { - out := append([]string(nil), in...) - sort.Strings(out) - return out -} - -func sortedNumberStrings(in []string) []string { - out := make([]string, len(in)) - for i := range in { - out[i] = canonicalNumberString(in[i]) - } - sort.Strings(out) - return out -} - -func sortedBinaryStrings(in [][]byte) []string { - out := make([]string, len(in)) - for i := range in { - out[i] = base64.RawURLEncoding.EncodeToString(in[i]) - } - sort.Strings(out) - return out -} - -func canonicalNumberString(v string) string { - rat := &big.Rat{} - if _, ok := rat.SetString(strings.TrimSpace(v)); !ok { - return strings.TrimSpace(v) - } - return rat.RatString() -} - -func reverseItems(items []map[string]attributeValue) { - for i, j := 0, len(items)-1; i < j; i, j = i+1, j-1 { - items[i], items[j] = items[j], items[i] - } -} - -func (d *DynamoDBServer) lockItemUpdate(lockKey string) func() { - idx := stripeIndex(lockKey, itemUpdateLockStripeCount) - d.itemUpdateLocks[idx].Lock() - return d.itemUpdateLocks[idx].Unlock -} - -func (d *DynamoDBServer) lockTableOperations(tableNames []string) func() { - if len(tableNames) == 0 { - return func() {} - } - idxs := make([]int, 0, len(tableNames)) - seen := map[int]struct{}{} - for _, tableName := range tableNames { - idx := stripeIndex(tableName, tableLockStripeCount) - if _, ok := seen[idx]; ok { - continue - } - seen[idx] = struct{}{} - idxs = append(idxs, idx) - } - sort.Ints(idxs) - for _, idx := range idxs { - d.tableLocks[idx].Lock() - } - return func() { - for i := len(idxs) - 1; i >= 0; i-- { - d.tableLocks[idxs[i]].Unlock() - } - } -} - -func stripeIndex(key string, stripeCount uint32) int { - if stripeCount == 0 { - return 0 - } - h := fnv.New32a() - _, _ = h.Write([]byte(key)) - return int(h.Sum32() % stripeCount) -} - -func dynamoItemUpdateLockKey(tableName string, key map[string]attributeValue) (string, error) { - parts := make([]string, 0, len(key)) - for name := range key { - parts = append(parts, name) - } - sort.Strings(parts) - var b strings.Builder - b.WriteString(tableName) - b.WriteByte('|') - for _, name := range parts { - val, err := attributeValueAsKey(key[name]) - if err != nil { - return "", errors.WithStack(err) - } - b.WriteString(name) - b.WriteByte('=') - b.WriteString(val) - b.WriteByte('|') - } - return b.String(), nil -} - -func describeTableShape(t *dynamoTableSchema) map[string]any { - attrDefs := make([]map[string]string, 0, len(t.AttributeDefinitions)) - for name, typ := range t.AttributeDefinitions { - attrDefs = append(attrDefs, map[string]string{ - "AttributeName": name, - "AttributeType": typ, - }) - } - sort.Slice(attrDefs, func(i, j int) bool { - return attrDefs[i]["AttributeName"] < attrDefs[j]["AttributeName"] - }) - - keySchema := []map[string]string{{ - "AttributeName": t.PrimaryKey.HashKey, - "KeyType": "HASH", - }} - if t.PrimaryKey.RangeKey != "" { - keySchema = append(keySchema, map[string]string{ - "AttributeName": t.PrimaryKey.RangeKey, - "KeyType": "RANGE", - }) - } - - resp := map[string]any{ - "TableName": t.TableName, - "TableStatus": "ACTIVE", - "KeySchema": keySchema, - "AttributeDefinitions": attrDefs, - } - - if len(t.GlobalSecondaryIndexes) > 0 { - gsis := make([]map[string]any, 0, len(t.GlobalSecondaryIndexes)) - indexNames := make([]string, 0, len(t.GlobalSecondaryIndexes)) - for name := range t.GlobalSecondaryIndexes { - indexNames = append(indexNames, name) - } - sort.Strings(indexNames) - for _, name := range indexNames { - gsi := t.GlobalSecondaryIndexes[name] - ks := gsi.KeySchema - projection := map[string]any{ - "ProjectionType": gsi.Projection.ProjectionType, - } - if len(gsi.Projection.NonKeyAttributes) > 0 { - projection["NonKeyAttributes"] = append([]string(nil), gsi.Projection.NonKeyAttributes...) - } - idxKeySchema := []map[string]string{{ - "AttributeName": ks.HashKey, - "KeyType": "HASH", - }} - if ks.RangeKey != "" { - idxKeySchema = append(idxKeySchema, map[string]string{ - "AttributeName": ks.RangeKey, - "KeyType": "RANGE", - }) - } - indexDesc := map[string]any{ - "IndexName": name, - "IndexStatus": "ACTIVE", - "KeySchema": idxKeySchema, - "Projection": projection, - } - gsis = append(gsis, indexDesc) - } - resp["GlobalSecondaryIndexes"] = gsis - } - - return resp -} - -func cloneAttributeValueMap(in map[string]attributeValue) map[string]attributeValue { - if in == nil { - return nil - } - out := make(map[string]attributeValue, len(in)) - for k, v := range in { - out[k] = cloneAttributeValue(v) - } - return out -} - -func cloneAttributeValueList(in []attributeValue) []attributeValue { - if in == nil { - return nil - } - out := make([]attributeValue, 0, len(in)) - for _, value := range in { - out = append(out, cloneAttributeValue(value)) - } - return out -} - -func cloneAttributeValue(in attributeValue) attributeValue { - out := attributeValue{} - if in.S != nil { - s := *in.S - out.S = &s - } - if in.N != nil { - n := *in.N - out.N = &n - } - if in.B != nil { - out.B = bytes.Clone(in.B) - } - if in.BOOL != nil { - b := *in.BOOL - out.BOOL = &b - } - if in.NULL != nil { - n := *in.NULL - out.NULL = &n - } - out.SS = cloneStringSlice(in.SS) - out.NS = cloneStringSlice(in.NS) - out.BS = cloneBinarySet(in.BS) - if in.L != nil { - out.L = make([]attributeValue, len(in.L)) - for i := range in.L { - out.L[i] = cloneAttributeValue(in.L[i]) - } - } - if in.M != nil { - out.M = cloneAttributeValueMap(in.M) - } - return out -} - -// globalLastCommitTSProvider is the interface satisfied by stores (e.g. -// LeaderRoutedStore) that can proxy the leader's LastCommitTS. Defined as a -// local interface so DynamoDBServer avoids a hard dependency on the concrete -// kv type. -type globalLastCommitTSProvider interface { - GlobalLastCommitTS(ctx context.Context) uint64 -} - -func (d *DynamoDBServer) nextTxnReadTS() uint64 { - // On a follower the local store.LastCommitTS() may lag behind the leader. - // Use GlobalLastCommitTS so ConsistentRead snapshots and transaction - // start timestamps are aligned with the leader's committed watermark, - // preventing stale pre-reads that cause false Jepsen anomalies and - // unnecessary WriteConflict retries on every follower request. - maxTS := uint64(0) - if p, ok := d.store.(globalLastCommitTSProvider); ok { - maxTS = p.GlobalLastCommitTS(context.Background()) - } else if d.store != nil { - maxTS = d.store.LastCommitTS() - } - - // Advance the HLC so subsequent commitTS calls produce values > maxTS, - // but return maxTS directly as the snapshot — NOT clock.Next(). - // - // clock.Next() can be ahead of store.LastCommitTS() because concurrent - // dispatchTxn calls advance the HLC before their Raft entry is applied. - // If readTS = clock.Next() = T and a concurrent write obtained - // commitTS = T-1 (still in the Raft pipeline), the version at T-1 is - // not yet in Pebble. Reads would see stale data and the FSM conflict - // check (latestTS > startTS: T-1 > T → false) would silently pass, - // allowing corrupted writes. Returning maxTS closes this gap: every - // version at ≤ maxTS is guaranteed visible, and any concurrent write at - // > maxTS triggers a WriteConflict and a retry. - clock := d.coordinator.Clock() - if clock != nil && maxTS > 0 { - clock.Observe(maxTS) - } - if maxTS == 0 { - return 1 - } - return maxTS -} - -func (d *DynamoDBServer) pinReadTS(ts uint64) *kv.ActiveTimestampToken { - if d == nil || d.readTracker == nil { - return &kv.ActiveTimestampToken{} - } - return d.readTracker.Pin(ts) -} - -func (d *DynamoDBServer) loadTableSchema(ctx context.Context, tableName string) (*dynamoTableSchema, bool, error) { - return d.loadTableSchemaAt(ctx, tableName, snapshotTS(d.coordinator.Clock(), d.store)) -} - -func (d *DynamoDBServer) loadTableSchemaAt(ctx context.Context, tableName string, ts uint64) (*dynamoTableSchema, bool, error) { - b, err := d.store.GetAt(ctx, dynamoTableMetaKey(tableName), ts) - if err != nil { - if errors.Is(err, store.ErrKeyNotFound) { - return nil, false, nil - } - return nil, false, errors.WithStack(err) - } - schema, err := decodeStoredDynamoTableSchema(b) - if err != nil { - return nil, false, err - } - d.observeTables(ctx, schema.TableName) - return schema, true, nil -} - -func (d *DynamoDBServer) loadTableGenerationAt(ctx context.Context, tableName string, ts uint64) (uint64, error) { - b, err := d.store.GetAt(ctx, dynamoTableGenerationKey(tableName), ts) - if err != nil { - if errors.Is(err, store.ErrKeyNotFound) { - return 0, nil - } - return 0, errors.WithStack(err) - } - gen, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) - if err != nil { - return 0, errors.WithStack(err) - } - return gen, nil -} - -func (d *DynamoDBServer) readItemAtKeyAt(ctx context.Context, key []byte, ts uint64) (map[string]attributeValue, bool, error) { - b, err := d.store.GetAt(ctx, key, ts) - if err != nil { - if errors.Is(err, store.ErrKeyNotFound) { - return nil, false, nil - } - return nil, false, errors.WithStack(err) - } - item, err := decodeStoredDynamoItem(b) - if err != nil { - return nil, false, err - } - return item, true, nil -} - -func (d *DynamoDBServer) readLogicalItemAt( - ctx context.Context, - schema *dynamoTableSchema, - key map[string]attributeValue, - ts uint64, -) (*dynamoItemLocation, bool, error) { - itemKey, err := schema.itemKeyFromAttributes(key) - if err != nil { - return nil, false, err - } - item, found, err := d.readItemAtKeyAt(ctx, itemKey, ts) - if err != nil { - return nil, false, err - } - if found { - return &dynamoItemLocation{schema: schema, key: itemKey, item: item}, true, nil - } - sourceSchema := schema.migrationSourceSchema() - if sourceSchema == nil { - return nil, false, nil - } - sourceKey, err := sourceSchema.itemKeyFromAttributes(key) - if err != nil { - return nil, false, err - } - item, found, err = d.readItemAtKeyAt(ctx, sourceKey, ts) - if err != nil { - return nil, false, err - } - if !found { - return nil, false, nil - } - return &dynamoItemLocation{schema: sourceSchema, key: sourceKey, item: item}, true, nil -} - -func (d *DynamoDBServer) ensureLegacyTableMigration(ctx context.Context, tableName string) error { - unlock := d.lockTableOperations([]string{tableName}) - defer unlock() - return d.ensureLegacyTableMigrationLocked(ctx, tableName) -} - -func (d *DynamoDBServer) ensureLegacyTableMigrationLocked(ctx context.Context, tableName string) error { - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - for range transactRetryMaxAttempts { - readTS := d.nextTxnReadTS() - schema, exists, err := d.loadTableSchemaAt(ctx, tableName, readTS) - if err != nil { - return errors.WithStack(err) - } - if !exists || !schema.needsLegacyKeyMigration() { - return nil - } - // Admin read-only callers (AdminScanTable) must not trigger - // migration writes. Their own pre-check at the admin readTS - // already rejects needs-migration tables, but the schema can - // transition between that check and this one (Codex r8 P2 on - // PR #805) — refuse rather than racing into write-path code. - if isAdminReadOnlyContext(ctx) { - return errors.Wrap(ErrAdminDynamoValidation, - "table requires a one-time legacy-key migration before admin read endpoints are available; migrate via the SigV4 surface first") - } - if !schema.usesOrderedKeyEncoding() { - err = d.startLegacyTableKeyMigration(ctx, schema, readTS) - } else { - err = d.migrateLegacyTableGeneration(ctx, schema) - } - if err == nil { - continue - } - if !isRetryableTransactWriteError(err) { - return err - } - if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { - return errors.WithStack(err) - } - backoff = nextTransactRetryBackoff(backoff) - } - return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "legacy table migration retry attempts exhausted") -} - -func (d *DynamoDBServer) startLegacyTableKeyMigration( - ctx context.Context, - schema *dynamoTableSchema, - readTS uint64, -) error { - if schema == nil || schema.usesOrderedKeyEncoding() { - return nil - } - nextGeneration, err := d.nextTableGenerationAt(ctx, schema.TableName, readTS) - if err != nil { - return err - } - req, err := makeCreateTableRequest(&dynamoTableSchema{ - TableName: schema.TableName, - AttributeDefinitions: schema.AttributeDefinitions, - PrimaryKey: schema.PrimaryKey, - GlobalSecondaryIndexes: schema.GlobalSecondaryIndexes, - KeyEncodingVersion: dynamoOrderedKeyEncodingV2, - MigratingFromGeneration: schema.Generation, - }, nextGeneration) - if err != nil { - return err - } - req.StartTS = readTS - if _, err := d.coordinator.Dispatch(ctx, req); err != nil { - return errors.WithStack(err) - } - return nil -} - -func (d *DynamoDBServer) migrateLegacyTableGeneration(ctx context.Context, schema *dynamoTableSchema) error { - sourceSchema := schema.migrationSourceSchema() - if sourceSchema == nil { - return nil - } - sourceReadTS := snapshotTS(d.coordinator.Clock(), d.store) - if err := d.migrateLegacySourceItems(ctx, schema, sourceSchema, sourceReadTS); err != nil { - return err - } - empty, err := d.isTableGenerationEmpty(ctx, schema.TableName, sourceSchema.Generation) - if err != nil { - return err - } - if !empty { - return nil - } - return d.finalizeLegacyTableMigration(ctx, schema) -} - -func (d *DynamoDBServer) migrateLegacySourceItems( - ctx context.Context, - targetSchema *dynamoTableSchema, - sourceSchema *dynamoTableSchema, - readTS uint64, -) error { - readPin := d.pinReadTS(readTS) - defer readPin.Release() - - prefix := dynamoItemPrefixForTable(targetSchema.TableName, sourceSchema.Generation) - upper := prefixScanEnd(prefix) - cursor := bytes.Clone(prefix) - for { - kvs, err := d.scanLegacyMigrationPage(ctx, cursor, upper, readTS) - if err != nil { - return err - } - nextCursor, done, err := d.migrateLegacySourcePage(ctx, targetSchema, sourceSchema, prefix, upper, kvs) - if err != nil { - return err - } - if done { - return nil - } - cursor = nextCursor - } -} - -func (d *DynamoDBServer) migrateLegacyItem( - ctx context.Context, - targetSchema *dynamoTableSchema, - sourceSchema *dynamoTableSchema, - sourceKey []byte, - sourceItem map[string]attributeValue, -) error { - lockKey, targetKey, err := resolveLegacyMigrationTarget(targetSchema, sourceItem) - if err != nil { - return err - } - unlock := d.lockItemUpdate(lockKey) - defer unlock() - - backoff := transactRetryInitialBackoff - deadline := time.Now().Add(transactRetryMaxDuration) - for range transactRetryMaxAttempts { - readTS := d.nextTxnReadTS() - req, done, err := d.buildLegacyMigrationRequest(ctx, targetSchema, sourceSchema, targetKey, sourceKey, readTS) - if err != nil { - return err - } - if done { - return nil - } - req.StartTS = readTS - if _, err := d.coordinator.Dispatch(ctx, req); err == nil { - return nil - } else if !isRetryableTransactWriteError(err) { - return errors.WithStack(err) - } - if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { - return errors.WithStack(err) - } - backoff = nextTransactRetryBackoff(backoff) - } - return newDynamoAPIError(http.StatusInternalServerError, dynamoErrInternal, "legacy item migration retry attempts exhausted") -} - -func (d *DynamoDBServer) scanLegacyMigrationPage( - ctx context.Context, - cursor []byte, - upper []byte, - readTS uint64, -) ([]*store.KVPair, error) { - kvs, err := d.store.ScanAt(ctx, cursor, upper, dynamoScanPageLimit, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - return kvs, nil -} - -func (d *DynamoDBServer) migrateLegacySourcePage( - ctx context.Context, - targetSchema *dynamoTableSchema, - sourceSchema *dynamoTableSchema, - prefix []byte, - upper []byte, - kvs []*store.KVPair, -) ([]byte, bool, error) { - if len(kvs) == 0 { - return nil, true, nil - } - for _, kvp := range kvs { - if !bytes.HasPrefix(kvp.Key, prefix) { - return nil, true, nil - } - item, err := decodeStoredDynamoItem(kvp.Value) - if err != nil { - return nil, false, err - } - if err := d.migrateLegacyItem(ctx, targetSchema, sourceSchema, kvp.Key, item); err != nil { - return nil, false, err - } - } - if len(kvs) < dynamoScanPageLimit { - return nil, true, nil - } - cursor := nextScanCursor(kvs[len(kvs)-1].Key) - if upper != nil && bytes.Compare(cursor, upper) >= 0 { - return nil, true, nil - } - return cursor, false, nil -} - -func resolveLegacyMigrationTarget(targetSchema *dynamoTableSchema, sourceItem map[string]attributeValue) (string, []byte, error) { - keyAttrs, err := primaryKeyAttributes(targetSchema.PrimaryKey, sourceItem) - if err != nil { - return "", nil, err - } - lockKey, err := dynamoItemUpdateLockKey(targetSchema.TableName, keyAttrs) - if err != nil { - return "", nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - targetKey, err := targetSchema.itemKeyFromAttributes(keyAttrs) - if err != nil { - return "", nil, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) - } - return lockKey, targetKey, nil -} - -func (d *DynamoDBServer) buildLegacyMigrationRequest( - ctx context.Context, - targetSchema *dynamoTableSchema, - sourceSchema *dynamoTableSchema, - targetKey []byte, - sourceKey []byte, - readTS uint64, -) (*kv.OperationGroup[kv.OP], bool, error) { - _, targetFound, err := d.readItemAtKeyAt(ctx, targetKey, readTS) - if err != nil { - return nil, false, err - } - currentSource, sourceFound, err := d.readItemAtKeyAt(ctx, sourceKey, readTS) - if err != nil { - return nil, false, err - } - if !sourceFound { - return nil, true, nil - } - currentLocation := &dynamoItemLocation{ - schema: sourceSchema, - key: sourceKey, - item: currentSource, - } - if targetFound { - req, err := buildItemDeleteRequestWithSource(currentLocation) - return req, false, err - } - req, _, err := buildItemWriteRequestWithSource(targetSchema, targetKey, currentSource, currentLocation) - return req, false, err -} - -func (d *DynamoDBServer) isTableGenerationEmpty(ctx context.Context, tableName string, generation uint64) (bool, error) { - prefix := dynamoItemPrefixForTable(tableName, generation) - kvs, err := d.store.ScanAt(ctx, prefix, prefixScanEnd(prefix), 1, snapshotTS(d.coordinator.Clock(), d.store)) - if err != nil { - return false, errors.WithStack(err) - } - for _, kvp := range kvs { - if bytes.HasPrefix(kvp.Key, prefix) { - return false, nil - } - } - return true, nil -} - -func (d *DynamoDBServer) finalizeLegacyTableMigration(ctx context.Context, schema *dynamoTableSchema) error { - if schema == nil || schema.MigratingFromGeneration == 0 { - return nil - } - oldGeneration := schema.MigratingFromGeneration - finalized := *schema - finalized.MigratingFromGeneration = 0 - body, err := encodeStoredDynamoTableSchema(&finalized) - if err != nil { - return errors.WithStack(err) - } - readTS := d.nextTxnReadTS() - req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, - StartTS: readTS, - Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Put, Key: dynamoTableMetaKey(schema.TableName), Value: body}, - }, - } - if _, err := d.coordinator.Dispatch(ctx, req); err != nil { - return errors.WithStack(err) - } - d.launchDeletedTableCleanup(schema.TableName, oldGeneration) - return nil -} - -func (d *DynamoDBServer) scanAllByPrefix(ctx context.Context, prefix []byte) ([]*store.KVPair, error) { - return d.scanAllByPrefixAt(ctx, prefix, snapshotTS(d.coordinator.Clock(), d.store)) -} - -func (d *DynamoDBServer) scanAllByPrefixAt(ctx context.Context, prefix []byte, readTS uint64) ([]*store.KVPair, error) { - readPin := d.pinReadTS(readTS) - defer readPin.Release() - - end := prefixScanEnd(prefix) - start := bytes.Clone(prefix) - - out := make([]*store.KVPair, 0, dynamoScanPageLimit) - for { - kvs, err := d.store.ScanAt(ctx, start, end, dynamoScanPageLimit, readTS) - if err != nil { - return nil, errors.WithStack(err) - } - if len(kvs) == 0 { - break - } - for _, kvp := range kvs { - if !bytes.HasPrefix(kvp.Key, prefix) { - return out, nil - } - out = append(out, kvp) - } - if len(kvs) < dynamoScanPageLimit { - break - } - start = nextScanCursor(kvs[len(kvs)-1].Key) - if end != nil && bytes.Compare(start, end) > 0 { - break - } - } - return out, nil -} - -func nextScanCursor(lastKey []byte) []byte { - next := make([]byte, len(lastKey)+1) - copy(next, lastKey) - return next -} - -func minInt(a int, b int) int { - if a < b { - return a - } - return b -} - -func dynamoTableMetaKey(tableName string) []byte { - return []byte(dynamoTableMetaPrefix + encodeDynamoSegment(tableName)) -} - -func dynamoTableGenerationKey(tableName string) []byte { - return []byte(dynamoTableGenerationPrefix + encodeDynamoSegment(tableName)) -} - -func dynamoItemPrefixForTable(tableName string, generation uint64) []byte { - return []byte(dynamoItemPrefix + encodeDynamoSegment(tableName) + "|" + strconv.FormatUint(generation, 10) + "|") -} - -func dynamoItemHashPrefixForTable(tableName string, generation uint64, hashKey []byte) []byte { - base := dynamoItemPrefixForTable(tableName, generation) - return append(base, hashKey...) -} - -func legacyDynamoItemHashPrefixForTable(tableName string, generation uint64, hashKey string) []byte { - return []byte( - dynamoItemPrefix + - encodeDynamoSegment(tableName) + "|" + - strconv.FormatUint(generation, 10) + "|" + - encodeDynamoSegment(hashKey) + "|", - ) -} - -func dynamoGSIPrefixForTable(tableName string, generation uint64) []byte { - return []byte( - dynamoGSIPrefix + - encodeDynamoSegment(tableName) + "|" + - strconv.FormatUint(generation, 10) + "|", - ) -} - -func dynamoGSIIndexPrefixForTable(tableName string, generation uint64, indexName string) []byte { - return []byte( - dynamoGSIPrefix + - encodeDynamoSegment(tableName) + "|" + - strconv.FormatUint(generation, 10) + "|" + - encodeDynamoSegment(indexName) + "|", - ) -} - -func dynamoGSIHashPrefixForTable(tableName string, generation uint64, indexName string, hashKey []byte) []byte { - base := dynamoGSIIndexPrefixForTable(tableName, generation, indexName) - return append(base, hashKey...) -} - -func legacyDynamoGSIHashPrefixForTable(tableName string, generation uint64, indexName string, hashKey string) []byte { - return []byte( - dynamoGSIPrefix + - encodeDynamoSegment(tableName) + "|" + - strconv.FormatUint(generation, 10) + "|" + - encodeDynamoSegment(indexName) + "|" + - encodeDynamoSegment(hashKey) + "|", - ) -} - -func dynamoGSIKey( - tableName string, - generation uint64, - indexName string, - indexHash []byte, - indexRange []byte, - pkHash []byte, - pkRange []byte, -) []byte { - key := dynamoGSIIndexPrefixForTable(tableName, generation, indexName) - key = append(key, indexHash...) - key = append(key, indexRange...) - key = append(key, pkHash...) - key = append(key, pkRange...) - return key -} - -func legacyDynamoGSIKey(tableName string, generation uint64, indexName string, indexHash string, indexRange string, pkHash string, pkRange string) []byte { - return []byte( - dynamoGSIPrefix + - encodeDynamoSegment(tableName) + "|" + - strconv.FormatUint(generation, 10) + "|" + - encodeDynamoSegment(indexName) + "|" + - encodeDynamoSegment(indexHash) + "|" + - encodeDynamoSegment(indexRange) + "|" + - encodeDynamoSegment(pkHash) + "|" + - encodeDynamoSegment(pkRange), - ) -} - -func dynamoItemKey(tableName string, generation uint64, hashKey []byte, rangeKey []byte) []byte { - key := dynamoItemPrefixForTable(tableName, generation) - key = append(key, hashKey...) - key = append(key, rangeKey...) - return key -} - -func legacyDynamoItemKey(tableName string, generation uint64, hashKey, rangeKey string) []byte { - return []byte( - dynamoItemPrefix + - encodeDynamoSegment(tableName) + "|" + - strconv.FormatUint(generation, 10) + "|" + - encodeDynamoSegment(hashKey) + "|" + - encodeDynamoSegment(rangeKey), - ) -} - -func encodeDynamoSegment(v string) string { - return base64.RawURLEncoding.EncodeToString([]byte(v)) -} - -func decodeDynamoSegment(v string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(v) - if err != nil { - return "", errors.WithStack(err) - } - return string(b), nil -} - -func tableNameFromMetaKey(key []byte) (string, bool) { - enc, ok := strings.CutPrefix(string(key), dynamoTableMetaPrefix) - if !ok || enc == "" { - return "", false - } - name, err := decodeDynamoSegment(enc) - if err != nil { - return "", false - } - return name, true -}