diff --git a/.github/workflows/dependabot-auto-merge.yml b/.github/workflows/dependabot-auto-merge.yml deleted file mode 100644 index a048a3ec6..000000000 --- a/.github/workflows/dependabot-auto-merge.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Dependabot auto-merge -on: pull_request - -permissions: - contents: write - -jobs: - dependabot-merge: - runs-on: cx-public-ubuntu-x64 - if: ${{ github.actor == 'dependabot[bot]' }} - steps: - - name: Dependabot metadata - id: metadata - uses: step-security/dependabot-fetch-metadata@bf8fb6e0be0a711c669dc236de6e7f7374ba626e # v3.1.0 - with: - github-token: "${{ secrets.GH_TOKEN }}" - - name: Enable auto-merge for Dependabot PRs - env: - PR_URL: ${{github.event.pull_request.html_url}} - GITHUB_TOKEN: ${{secrets.GH_TOKEN}} - run: gh pr merge --auto --merge "$PR_URL" - - name: Auto approve dependabot PRs - uses: step-security/auto-approve-action@0c28339628c8e79ab2f6813291e7e6cd584b4d30 # v4.0.0 - with: - github-token: ${{ secrets.PERSONAL_ACCESS_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 91b41c260..f384b7b70 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,21 +4,62 @@ on: workflow_call: inputs: tag: + description: "Next release tag" description: "Next release tag" required: true type: string dev: + description: "Is dev build" description: "Is dev build" required: false default: true type: boolean + secrets: + AC_PASSWORD: + required: true + AC_USER: + required: true + APPLE_DEVELOPER_CERTIFICATE_P12_BASE64: + required: true + APPLE_DEVELOPER_CERTIFICATE_PASSWORD: + required: true + AWS_ASSUME_ROLE_ARN: + required: true + AWS_ASSUME_ROLE_REGION: + required: true + COSIGN_PASSWORD: + required: true + COSIGN_PRIVATE_KEY: + required: true + COSIGN_PUBLIC_KEY: + required: true + DOCKER_PASSWORD: + required: true + DOCKER_USERNAME: + required: true + PERSONAL_ACCESS_TOKEN: + required: true + S3_BUCKET_NAME: + required: true + S3_BUCKET_REGION: + required: 
true + SIGNING_HSM_CREDS: + required: true + SIGNING_REMOTE_SSH_HOST: + required: true + SIGNING_REMOTE_SSH_PRIVATE_KEY: + required: true + SIGNING_REMOTE_SSH_USER: + required: true workflow_dispatch: inputs: tag: + description: "Next release tag" description: "Next release tag" required: true type: string dev: + description: "Is dev build" description: "Is dev build" required: false default: true @@ -40,6 +81,11 @@ jobs: COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }} COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }} steps: + - name: Install Harden Runner + uses: checkmarx/harden-runner-action@9af89fc71515a100421586dfdb3dc9c984fbf411 #v2.19.4 + with: + use-policy-store: true + api-key: ${{ secrets.STEP_SECURITY_API_KEY }} - name: Install Harden Runner uses: checkmarx/harden-runner-action@9af89fc71515a100421586dfdb3dc9c984fbf411 #v2.19.4 with: diff --git a/.gitignore b/.gitignore index afb99c7b8..f77686f22 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ *.exe *.exe~ *.dll +/cx *.so *.dylib diff --git a/cmd/main.go b/cmd/main.go index ae5d46ceb..4352c6149 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -30,6 +30,7 @@ func main() { bindKeysToEnvAndDefault() err = configuration.LoadConfiguration() exitIfError(err) + wrappers.LoadActiveCredential() scans := viper.GetString(params.ScansPathKey) groups := viper.GetString(params.GroupsPathKey) logs := viper.GetString(params.LogsPathKey) diff --git a/go.mod b/go.mod index 2b69ab38d..b38020ab4 100644 --- a/go.mod +++ b/go.mod @@ -9,17 +9,19 @@ require ( github.com/Checkmarx/gen-ai-wrapper v1.0.3 github.com/Checkmarx/manifest-parser v0.1.2 github.com/Checkmarx/secret-detection v1.2.1 + github.com/CheckmarxDev/ast-cx-hooks v1.0.3 github.com/MakeNowJust/heredoc v1.0.0 github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 github.com/bouk/monkey v1.0.0 github.com/checkmarx/2ms/v3 v3.21.0 github.com/gofrs/flock v0.13.0 - github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/golang-jwt/jwt/v5 v5.3.1 
github.com/gomarkdown/markdown v0.0.0-20260417124207-7d523f7318df github.com/google/uuid v1.6.0 github.com/gookit/color v1.6.0 github.com/jcmturner/gokrb5/v8 v8.4.4 github.com/jsumners/go-getport v1.0.0 + github.com/modelcontextprotocol/go-sdk v1.6.1 github.com/mssola/user_agent v0.6.0 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.10.2 @@ -44,6 +46,53 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clipperhouse/displaywidth v0.10.0 // indirect github.com/clipperhouse/uax29/v2 v2.6.0 // indirect + github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect + github.com/containerd/containerd/v2 v2.2.3 // indirect + github.com/containerd/plugin v1.0.0 // indirect + github.com/diskfs/go-diskfs v1.7.0 // indirect + github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.4 // indirect + github.com/goccy/go-yaml v1.19.2 // indirect + github.com/gohugoio/hashstructure v0.6.0 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.17.0 // indirect + github.com/gpustack/gguf-parser-go v0.24.0 // indirect + github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-getter v1.8.6 // indirect + github.com/hashicorp/go-version v1.8.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/henvic/httpretty v0.1.4 // indirect + github.com/mholt/archives v0.1.5 // indirect + github.com/mikelolasagasti/xz v1.0.1 // indirect + github.com/minio/minlz v1.0.1 // indirect + github.com/moby/moby/api v1.54.1 // indirect + github.com/moby/moby/client v0.4.0 // indirect + github.com/nix-community/go-nix v0.0.0-20250101154619-4bdde671e0a1 // indirect + 
github.com/nwaples/rardecode/v2 v2.2.0 // indirect + github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect + github.com/olekukonko/errors v1.2.0 // indirect + github.com/olekukonko/ll v0.1.6 // indirect + github.com/pkg/xattr v0.4.9 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect + github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect + github.com/sorairolake/lzip-go v0.3.8 // indirect + github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect + go.opentelemetry.io/otel/sdk v1.43.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect + go4.org v0.0.0-20230225012048-214862532bf5 // indirect + gonum.org/v1/gonum v0.16.0 // indirect + google.golang.org/api v0.271.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 // indirect + sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect github.com/coreos/go-systemd/v22 v22.7.0 // indirect github.com/distribution/distribution/v3 v3.1.1 // indirect github.com/docker/docker v28.0.3+incompatible // indirect @@ -103,7 +152,7 @@ require ( github.com/becheran/wildmatch-go v1.0.0 // indirect github.com/bitnami/go-version v0.0.0-20250324202741-04b9d491e744 // indirect github.com/blang/semver/v4 v4.0.0 // indirect - github.com/bmatcuk/doublestar/v4 v4.10.0 // indirect + github.com/bmatcuk/doublestar/v4 v4.10.0 github.com/bwmarrin/discordgo v0.27.1 // indirect github.com/chai2010/gettext-go v1.0.3 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect diff --git a/go.sum b/go.sum index 134ad80ff..fee9ca81c 100644 --- a/go.sum +++ b/go.sum @@ -81,6 +81,8 @@ 
github.com/Checkmarx/manifest-parser v0.1.2 h1:Sh2xkpeOWKu56Y7wo+ljckNGHAQX1uITE github.com/Checkmarx/manifest-parser v0.1.2/go.mod h1:hh5FX5FdDieU8CKQEkged4hfOaSylpJzub8PRFXa4kA= github.com/Checkmarx/secret-detection v1.2.1 h1:Hzpz74dcN/L14Q86ARvPOZpKBnERzGTpy6sl1RXKOTo= github.com/Checkmarx/secret-detection v1.2.1/go.mod h1:kbXbtIQisDdB/TNuV7r9HPclEznUyBHLQ5yr7IX7vBQ= +github.com/CheckmarxDev/ast-cx-hooks v1.0.3 h1:zMz6Ony8iWgKqjgUFvYhhqm5dr29sEO6r2pBl7fi/OM= +github.com/CheckmarxDev/ast-cx-hooks v1.0.3/go.mod h1:BNFcjgHhjxiPnKGHqiaWQycMMrkeT+DqokG/l7d9gs8= github.com/CycloneDX/cyclonedx-go v0.10.0 h1:7xyklU7YD+CUyGzSFIARG18NYLsKVn4QFg04qSsu+7Y= github.com/CycloneDX/cyclonedx-go v0.10.0/go.mod h1:vUvbCXQsEm48OI6oOlanxstwNByXjCZ2wuleUlwGEO8= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= @@ -438,6 +440,10 @@ github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8 github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/gohugoio/hashstructure v0.6.0 h1:7wMB/2CfXoThFYhdWRGv3u3rUM761Cq29CxUW+NltUg= +github.com/gohugoio/hashstructure v0.6.0/go.mod h1:lapVLk9XidheHG1IQ4ZSbyYrXcaILU1ZEP/+vno5rBQ= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -503,6 +509,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/go-containerregistry v0.21.5 
h1:KTJG9Pn/jC0VdZR6ctV3/jcN+q6/Iqlx0sTVz3ywZlM= github.com/google/go-containerregistry v0.21.5/go.mod h1:ySvMuiWg+dOsRW0Hw8GYwfMwBlNRTmpYBFJPlkco5zU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/licensecheck v0.3.1 h1:QoxgoDkaeC4nFrtGN1jV7IPmDCHFNIVh54e5hSt6sPs= github.com/google/licensecheck v0.3.1/go.mod h1:ORkR35t/JjW+emNKtfJDII0zlciG9JgbT7SmsohlHmY= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -760,6 +768,8 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= +github.com/modelcontextprotocol/go-sdk v1.6.1 h1:0zOSupjKUxPKSocPT1Wtago+mUHU2/uZ4xSOY0FGReU= +github.com/modelcontextprotocol/go-sdk v1.6.1/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -899,6 +909,10 @@ github.com/scylladb/go-set v1.0.3-0.20200225121959-cc7b2070d91e/go.mod h1:DkpGd7 github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sebdah/goldie/v2 v2.8.0 h1:dZb9wR8q5++oplmEiJT+U/5KyotVD+HNGCAc5gNr8rc= github.com/sebdah/goldie/v2 v2.8.0/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= 
+github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/secDre4mer/pkcs7 v0.0.0-20240322103146-665324a4461d h1:RQqyEogx5J6wPdoxqL132b100j8KjcVHO1c0KLRoIhc= github.com/secDre4mer/pkcs7 v0.0.0-20240322103146-665324a4461d/go.mod h1:PegD7EVqlN88z7TpCqH92hHP+GBpfomGCCnw1PFtNOA= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -1006,6 +1020,8 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/internal/commands/agenthooks.go b/internal/commands/agenthooks.go new file mode 100644 index 000000000..f16d69ac7 --- /dev/null +++ b/internal/commands/agenthooks.go @@ -0,0 +1,217 @@ +package commands + +import ( + "fmt" + "os" + "strings" + + "github.com/MakeNowJust/heredoc" + cxhooks "github.com/checkmarx/ast-cli/internal/commands/agenthooks/cx" + "github.com/checkmarx/ast-cli/internal/logger" + "github.com/checkmarx/ast-cli/internal/params" + 
"github.com/checkmarx/ast-cli/internal/wrappers"
+	"github.com/checkmarx/ast-cli/internal/wrappers/configuration"
+	"github.com/pkg/errors"
+	"github.com/spf13/cobra"
+)
+
+// =============================================================================
+// License check
+// =============================================================================
+
+// isLicensed reports whether the token behind jwt is allowed to use at least
+// one of the AI feature engines (CxOne Assist, AI Protection, Developer Assist).
+//
+// The CLI configuration is loaded first; a load failure is only reported
+// through verbose logging and does not stop the check. Engines are queried in
+// a fixed order and the first allowed engine ends the search. A query error is
+// logged once (verbose) and that engine is skipped.
+func isLicensed(jwt wrappers.JWTWrapper) bool {
+	if loadErr := configuration.LoadConfiguration(); loadErr != nil {
+		logger.PrintIfVerbose(fmt.Sprintf("hooks: config load warning - %v", loadErr))
+	}
+
+	licenseEngines := []string{
+		params.CheckmarxOneAssistType,
+		params.AIProtectionType,
+		params.CheckmarxDevAssistType,
+	}
+
+	sawAuthError := false
+	for _, name := range licenseEngines {
+		allowed, checkErr := jwt.IsAllowedEngine(name)
+		switch {
+		case checkErr != nil:
+			// Only the first failure is reported; later ones are skipped quietly.
+			if !sawAuthError {
+				sawAuthError = true
+				logger.PrintIfVerbose(fmt.Sprintf("hooks: authentication failed - %v", checkErr))
+			}
+		case allowed:
+			logger.PrintIfVerbose(fmt.Sprintf("hooks: AI feature license found (%s)", name))
+			return true
+		}
+	}
+
+	// No engine was allowed: explain why the hooks fall back to pass-through.
+	reason := "hooks: running in pass-through mode (no AI feature license)"
+	if sawAuthError {
+		reason = "hooks: running in pass-through mode (not authenticated)"
+	}
+	logger.PrintIfVerbose(reason)
+	return false
+}
+
+// =============================================================================
+// Hook dispatch commands — hidden subcommands for all AI agent hook events.
+//
+// Agents invoke: cx hooks ROUTE (ROUTE being the Use name of a declared route).
+// Each route reads JSON from stdin and writes the verdict as JSON to stdout.
+// Routes are declared per-agent in cxhooks.Agents (cx package).
+// =============================================================================
+
+// HookDispatchCommands builds one hidden cobra command per route of every
+// agent listed in cxhooks.Agents, in declaration order.
+//
+// Running any of these commands performs the license check, registers either
+// the guardrail handlers or the pass-through handlers, and then dispatches the
+// route named by the command's Use string.
+func HookDispatchCommands(jwt wrappers.JWTWrapper, featureFlags wrappers.FeatureFlagsWrapper, realtimeScanner wrappers.RealtimeScannerWrapper, telemetryWrapper wrappers.TelemetryWrapper) []*cobra.Command {
+	// Overrides the root PersistentPreRunE — any stdout from config loading
+	// would corrupt the JSON response the agent expects.
+	skipRootPreRun := func(*cobra.Command, []string) error { return nil }
+
+	// All routes share this handler; the route name comes from cmd.Use.
+	runRoute := func(cmd *cobra.Command, _ []string) {
+		route := cmd.Use
+		if !isLicensed(jwt) {
+			logger.PrintIfVerbose(fmt.Sprintf("hooks: registering pass-through for %s", route))
+			cxhooks.RegisterPassThrough()
+		} else {
+			logger.PrintIfVerbose(fmt.Sprintf("hooks: registering security guardrails for %s", route))
+			cxhooks.RegisterGuardrails(jwt, featureFlags, realtimeScanner, telemetryWrapper)
+		}
+		cxhooks.DispatchRoute(route)
+	}
+
+	var routeCmds []*cobra.Command
+	for i := range cxhooks.Agents {
+		for _, route := range cxhooks.Agents[i].Routes {
+			routeCmds = append(routeCmds, &cobra.Command{
+				Use:               route.Use,
+				Short:             route.Short,
+				Hidden:            true,
+				PersistentPreRunE: skipRootPreRun,
+				Run:               runRoute,
+			})
+		}
+	}
+	return routeCmds
+}
+
+// =============================================================================
+// Management command — cx hooks agenthooks install [agent]
+// =============================================================================
+
+// NewAgentHooksCommand returns the "agenthooks" management command with its
+// "install" subcommand attached.
+func NewAgentHooksCommand() *cobra.Command {
+	examples := heredoc.Doc(`
+		$ cx hooks agenthooks install # install for all agents
+		$ cx hooks agenthooks install cursor # install for Cursor only
+		$ cx hooks agenthooks install claude # install for Claude Code only
+	`)
+	root := &cobra.Command{
+		Use:     "agenthooks",
+		Short:   "Manage AI coding agent hook configuration",
+		Long:    "Configure AI coding agent hooks to invoke cx directly. Supports " + agentDisplayList() + ".",
+		Example: examples,
+	}
+	root.AddCommand(agentHooksInstallCmd())
+	return root
+}
+
+// agentHooksInstallCmd builds the "install" command. Run without a subcommand
+// it installs hooks for every agent; each agent ID is also registered as a
+// subcommand that installs hooks for that agent only.
+func agentHooksInstallCmd() *cobra.Command {
+	description := heredoc.Doc(`
+		Patches the hook configuration for all supported AI coding agents
+		so they invoke "cx hooks " on hook events.
+
+		Supported agents and their config files:
+	`) + agentConfigTable()
+
+	installAll := &cobra.Command{
+		Use:     "install",
+		Short:   "Write hook config for all AI coding agents",
+		Long:    description,
+		Example: " $ cx hooks agenthooks install",
+		RunE: func(*cobra.Command, []string) error {
+			return runInstall(allAgentIDs()...)
+		},
+	}
+
+	for i := range cxhooks.Agents {
+		// Copy what the closure needs so it never observes loop state.
+		id, label := cxhooks.Agents[i].ID, cxhooks.Agents[i].DisplayName
+		installAll.AddCommand(&cobra.Command{
+			Use:     id,
+			Short:   "Write hook config for " + label,
+			Example: " $ cx hooks agenthooks install " + id,
+			RunE: func(*cobra.Command, []string) error {
+				return runInstall(id)
+			},
+		})
+	}
+
+	return installAll
+}
+
+// runInstall installs hooks for the given agent IDs. With a single ID it
+// returns errors directly (so cobra surfaces them as the command failure);
+// with multiple IDs it attempts all agents and aggregates failures.
+//
+// Every successful install prints a "✓" line to stdout. In the multi-agent
+// mode each failure prints a "✗" line to stderr and the returned error carries
+// only the number of agents that failed.
+func runInstall(ids ...string) error {
+	binPath, err := os.Executable()
+	if err != nil {
+		return errors.Wrap(err, "resolving cx binary path")
+	}
+	homeDir, err := os.UserHomeDir()
+	if err != nil {
+		return errors.Wrap(err, "finding home directory")
+	}
+	routeCmd := cxhooks.CxCmdFor(binPath)
+
+	// Single-agent mode: the first problem becomes the command's error.
+	if len(ids) == 1 {
+		only := cxhooks.FindAgent(ids[0])
+		if only == nil {
+			return fmt.Errorf("unknown agent %q", ids[0])
+		}
+		if installErr := only.Install(homeDir, routeCmd); installErr != nil {
+			return fmt.Errorf("%s: %w", only.DisplayName, installErr)
+		}
+		fmt.Fprintf(os.Stdout, "✓ %s configured\n", only.DisplayName)
+		return nil
+	}
+
+	// Multi-agent mode: try every agent, report each failure, then summarize.
+	failures := 0
+	for _, id := range ids {
+		target := cxhooks.FindAgent(id)
+		if target == nil {
+			fmt.Fprintf(os.Stderr, "✗ %s: no installer registered\n", id)
+			failures++
+			continue
+		}
+		installErr := target.Install(homeDir, routeCmd)
+		if installErr == nil {
+			fmt.Fprintf(os.Stdout, "✓ %s configured\n", target.DisplayName)
+			continue
+		}
+		fmt.Fprintf(os.Stderr, "✗ %s: %v\n", target.DisplayName, installErr)
+		failures++
+	}
+	if failures > 0 {
+		return fmt.Errorf("%d agent(s) failed to configure", failures)
+	}
+	return nil
+}
+
+// allAgentIDs returns the ID of every entry in cxhooks.Agents, in order.
+func allAgentIDs() []string {
+	ids := make([]string, 0, len(cxhooks.Agents))
+	for i := range cxhooks.Agents {
+		ids = append(ids, cxhooks.Agents[i].ID)
+	}
+	return ids
+}
+
+// agentDisplayList formats agent display names as a natural-language list:
+// "A, B, C, and D" (Oxford comma, "and" before the last entry). Two names are
+// joined with " and "; one name is returned as is; no names yield "".
+func agentDisplayList() string {
+	labels := make([]string, 0, len(cxhooks.Agents))
+	for i := range cxhooks.Agents {
+		labels = append(labels, cxhooks.Agents[i].DisplayName)
+	}
+
+	count := len(labels)
+	if count == 0 {
+		return ""
+	}
+	if count <= 2 {
+		// One name joins to itself; two names need no comma.
+		return strings.Join(labels, " and ")
+	}
+	head, last := labels[:count-1], labels[count-1]
+	return strings.Join(head, ", ") + ", and " + last
+}
+
+// agentConfigTable formats agent ID + config path as an indented two-column
+// table for inclusion in command help text.
+func agentConfigTable() string {
+	var b strings.Builder
+	for _, a := range cxhooks.Agents {
+		// One row per agent: the ID left-aligned in an 8-character column,
+		// followed by the agent's config path.
+		fmt.Fprintf(&b, " %-8s %s\n", a.ID, a.ConfigPath)
+	}
+	return b.String()
+}
diff --git a/internal/commands/agenthooks/cx/dispatch.go b/internal/commands/agenthooks/cx/dispatch.go
new file mode 100644
index 000000000..05ab3aacc
--- /dev/null
+++ b/internal/commands/agenthooks/cx/dispatch.go
@@ -0,0 +1,28 @@
+package cx
+
+import (
+	"os"
+
+	agenthooks "github.com/CheckmarxDev/ast-cx-hooks"
+)
+
+// DispatchRoute calls the handler registered under route without requiring the
+// caller to mutate os.Args. Use this when the route name is already known (e.g.
+// from cmd.Use after Cobra has consumed it from os.Args).
+//
+// os.Args is replaced with {program, route} only while agenthooks.Dispatch
+// runs. The restore is deferred so the caller's arguments are put back even if
+// a dispatched handler panics, and an empty os.Args is tolerated instead of
+// causing an index-out-of-range panic.
+func DispatchRoute(route string) {
+	saved := os.Args
+	defer func() { os.Args = saved }()
+
+	// Guard against an empty os.Args: fall back to an empty program name.
+	prog := ""
+	if len(saved) > 0 {
+		prog = saved[0]
+	}
+	os.Args = []string{prog, route}
+	agenthooks.Dispatch()
+}
diff --git a/internal/commands/agenthooks/cx/hooks.go b/internal/commands/agenthooks/cx/hooks.go
new file mode 100644
index 000000000..91651ecf9
--- /dev/null
+++ b/internal/commands/agenthooks/cx/hooks.go
@@ -0,0 +1,258 @@
+package cx
+
+import (
+	"log"
+	"os"
+	"strings"
+
+	agenthooks "github.com/CheckmarxDev/ast-cx-hooks"
+	"github.com/CheckmarxDev/ast-cx-hooks/cursor"
+	"github.com/checkmarx/ast-cli/internal/commands/agenthooks/guardrails"
+	"github.com/checkmarx/ast-cli/internal/commands/agenthooks/guardrails/kics"
+	"github.com/checkmarx/ast-cli/internal/commands/agenthooks/sca"
+	"github.com/checkmarx/ast-cli/internal/wrappers"
+)
+
+// scaScanner is the package-level SCA scanner used by the guardrail handlers.
+// It is set by RegisterGuardrails so the handlers (free functions registered
+// with the agenthooks library) can reach it without an injection mechanism.
+var scaScanner *sca.Scanner
+
+// kicsScanner is the package-level KICS scanner used by the file-edit guardrail.
+// It is set by RegisterGuardrails and cleared by RegisterPassThrough.
+var kicsScanner *kics.Scanner
+
+// telemetryWrapper is the package-level telemetry sink. It is set by
+// RegisterGuardrails, handed to the ASCA scan, and used by
+// logRemediationTelemetry, which does nothing while it is nil.
+var telemetryWrapper wrappers.TelemetryWrapper
+
+// cxWhenAgentIdle: agent finished its turn. Nothing to enforce yet.
+func cxWhenAgentIdle(_ agenthooks.AgentIdleEvent) agenthooks.IdleVerdict {
+	return agenthooks.Resume()
+}
+
+// cxBeforeToolCall gates shell execution against the organization's blacklist,
+// tool rules, and the SCA guardrail (malicious / vulnerable package installs).
+// Non-shell tool calls are always allowed.
+func cxBeforeToolCall(ev agenthooks.ToolCallEvent) agenthooks.ToolVerdict {
+	if !ev.IsShell() {
+		return agenthooks.Allow()
+	}
+	blocked, needsConfirm, reason := guardrails.CheckShellCommand(ev.Command, ev.WorkDir)
+	if blocked {
+		// A blocked command is either escalated to the user or denied outright.
+		if needsConfirm {
+			return agenthooks.AskUser(reason)
+		}
+		return agenthooks.Deny(reason)
+	}
+	if scaScanner != nil {
+		if finding, remediation := scaScanner.CheckBashInstall(ev.Command, ev.WorkDir); finding != "" {
+			agent := agentToString(ev.Agent)
+			logRemediationTelemetry(agent, "SCA", finding, remediation)
+			return agenthooks.DenyWithContext(finding, remediation)
+		}
+	}
+	return agenthooks.Allow()
+}
+
+// cxBeforeFileEdit gates two distinct events the library multiplexes through
+// the same handler signature:
+//
+//  1. File EDITS (Claude / Windsurf / Droid / Gemini) — ev.Changes is populated.
+//     Enforce blast_radius_limit, files_limits.max_total_file_size_kb, the ASCA
+//     guardrail (AI-introduced code vulnerabilities), the KICS guardrail
+//     (IaC security vulnerabilities), and the SCA guardrail
+//     (malicious / vulnerable manifest additions) before any bytes are written
+//     to disk. MultiEdit and multi-file edits are handled uniformly by iterating
+//     ev.Changes.
+//
+//  2. Cursor file READS (beforeReadFile) — ev.Changes is empty and ev.FilePath
+//     points to a file the agent is about to ingest into the LLM context.
+//     Cursor's hook payload carries only the path, so we open the file and run
+//     the 2ms scanner over its contents. Blocks the read if secrets are found
+//     or if the file exceeds the policy size cap. Reads do NOT count toward
+//     the blast-radius budget (that limit is about writes).
+func cxBeforeFileEdit(ev agenthooks.FileEditEvent) agenthooks.FileEditVerdict {
+	if ev.Agent == agenthooks.AgentCursor && len(ev.Changes) == 0 {
+		if reason := guardrails.ScanFileForSecrets(ev.FilePath); reason != "" {
+			return agenthooks.RejectEdit(reason)
+		}
+		return agenthooks.AcceptEdit()
+	}
+
+	if blocked, reason := guardrails.CheckAndIncrementBlastRadius(); blocked {
+		return agenthooks.RejectEdit(reason)
+	}
+	// The size budget is charged with the combined size of all proposed contents.
+	var totalBytes int64
+	for _, diff := range ev.Changes {
+		totalBytes += int64(len(diff.After))
+	}
+	if blocked, reason := guardrails.CheckAndIncrementTotalFileSize(totalBytes); blocked {
+		return agenthooks.RejectEdit(reason)
+	}
+	agent := agentToString(ev.Agent)
+	// NOTE(review): no asca import is visible in this file's import block in
+	// this change — confirm the package is imported (or declared elsewhere in
+	// package cx); otherwise this call does not compile.
+	if blocked, reason, context := asca.ScanFileEdit(ev, telemetryWrapper, agent); blocked {
+		logRemediationTelemetry(agent, "Asca", reason, context)
+		return agenthooks.RejectEditWithContext(reason, context)
+	}
+	if kicsScanner != nil {
+		// NOTE(review): unlike the ASCA and SCA branches, no remediation
+		// telemetry is logged for a KICS rejection — confirm this is intended.
+		if blocked, reason, context := kics.ScanFileEdit(ev, kicsScanner); blocked {
+			return agenthooks.RejectEditWithContext(reason, context)
+		}
+	}
+	if scaScanner != nil {
+		for _, diff := range ev.Changes {
+			// NOTE(review): ev.WorkDir is passed for both trailing arguments and
+			// ev.FilePath is reused for every diff — verify against
+			// CheckManifestEdit's signature. The engine label here is "Oss" while
+			// cxBeforeToolCall reports the same scanner as "SCA" — confirm.
+			if finding, remediation := scaScanner.CheckManifestEdit(ev.FilePath, fullAfterContent(ev.FilePath, diff), ev.WorkDir, ev.WorkDir); finding != "" {
+				logRemediationTelemetry(agent, "Oss", finding, remediation)
+				return agenthooks.RejectEditWithContext(finding, remediation)
+			}
+		}
+	}
+	return agenthooks.AcceptEdit()
+}
+
+// fullAfterContent returns the complete new file content for a diff.
+// Write ops set diff.Before to "" and diff.After to the full new content.
+// Edit ops set diff.After only to the replacement snippet, so we
+// reconstruct by applying the replacement to the current file on disk.
+//
+// Reconstruction must not depend on the checkout's line-ending style. An
+// agent/editor may send the replaced region with LF endings while the file on
+// disk uses CRLF (Windows / git core.autocrlf=true), or vice versa. A
+// byte-exact strings.Replace then finds no match, returns the file unchanged,
+// and any newly added (possibly vulnerable) dependency slips past the scanner
+// silently. We therefore try an exact match first, then fall back to a
+// line-ending–normalized match. Line endings are irrelevant to manifest
+// dependency parsing, so scanning the normalized content is safe.
+//
+// Whether the region was located is decided with strings.Contains, not by
+// comparing the replaced text with its input: an edit whose replacement equals
+// the text it replaces leaves the content unchanged but was still located, and
+// must yield the full file rather than the snippet-only fallback.
+//
+// If the file cannot be read, diff.After alone is returned.
+func fullAfterContent(filePath string, diff agenthooks.FileDiff) []byte {
+	if diff.Before == "" {
+		return []byte(diff.After)
+	}
+	current, err := os.ReadFile(filePath)
+	if err != nil {
+		return []byte(diff.After)
+	}
+	cur := string(current)
+
+	// 1) Exact replacement (fast path; preserves original bytes).
+	if strings.Contains(cur, diff.Before) {
+		return []byte(strings.Replace(cur, diff.Before, diff.After, 1))
+	}
+
+	// 2) Line-ending–agnostic replacement. Normalize both the file and the
+	// diff region to LF, then replace. This makes reconstruction independent
+	// of CRLF vs LF differences between machines and checkouts.
+	curN := normalizeNewlines(cur)
+	beforeN := normalizeNewlines(diff.Before)
+	if strings.Contains(curN, beforeN) {
+		return []byte(strings.Replace(curN, beforeN, normalizeNewlines(diff.After), 1))
+	}
+
+	// 3) Fail-safe: the replaced region could not be located even after
+	// normalization. Do not silently accept by returning the unchanged file.
+	// Surface the anomaly and fall back to scanning the proposed snippet so a
+	// newly added dependency is still given a chance to be detected.
+	log.Printf("sca guardrail: could not locate edited region in %q (line-ending or whitespace mismatch); scanning proposed snippet as fallback", filePath)
+	return []byte(normalizeNewlines(diff.After))
+}
+
+// normalizeNewlines converts CRLF and lone CR line endings to LF.
+func normalizeNewlines(s string) string {
+	// CRLF pairs are folded first so the lone-CR pass cannot split them into
+	// two line breaks.
+	crlfFolded := strings.ReplaceAll(s, "\r\n", "\n")
+	return strings.ReplaceAll(crlfFolded, "\r", "\n")
+}
+
+// cxBeforePrompt runs all prompt guardrails before the prompt reaches the AI
+// agent: the prompt text itself, the files it references, and workspace files
+// matched by name. The first guardrail that returns a reason rejects the
+// prompt; later guardrails are not run.
+func cxBeforePrompt(ev agenthooks.PromptEvent) agenthooks.PromptVerdict {
+	rejection := guardrails.ScanPrompt(ev.Text)
+	if rejection == "" {
+		// Workspace roots are only needed by the file-based guardrails.
+		roots := promptWorkspaceRoots(ev.Raw)
+		rejection = guardrails.ScanReferencedFiles(ev.Text, roots)
+		if rejection == "" {
+			rejection = guardrails.ScanWorkspaceFilesByPromptName(ev.Text, roots)
+		}
+	}
+	if rejection != "" {
+		return agenthooks.RejectPrompt(rejection)
+	}
+	return agenthooks.AcceptPrompt()
+}
+
+// promptWorkspaceRoots returns the anchor(s) for resolving relative file paths
+// in the prompt. When raw is a Cursor prompt event carrying a non-empty
+// WorkspaceRoots list, that list is returned directly. Otherwise the hook
+// process's working directory is used, or nil if it cannot be determined.
+func promptWorkspaceRoots(raw any) []string {
+	cursorEv, isCursor := raw.(*cursor.PromptPreEvent)
+	if isCursor && len(cursorEv.WorkspaceRoots) > 0 {
+		return cursorEv.WorkspaceRoots
+	}
+	if wd, err := os.Getwd(); err == nil {
+		return []string{wd}
+	}
+	return nil
+}
+
+// RegisterGuardrails wires the four guardrail handlers and instantiates the
+// SCA and KICS scanners used by the Bash and FileEdit handlers. It also stores
+// tel as the package-level telemetry wrapper.
+func RegisterGuardrails(jwt wrappers.JWTWrapper, ff wrappers.FeatureFlagsWrapper, rt wrappers.RealtimeScannerWrapper, tel wrappers.TelemetryWrapper) {
+	// Package-level dependencies read by the handler functions.
+	scaScanner = sca.NewScanner(jwt, ff, rt)
+	kicsScanner = kics.NewScanner(jwt, ff)
+	telemetryWrapper = tel
+
+	// One handler per hook event kind.
+	agenthooks.WhenAgentIdle(cxWhenAgentIdle)
+	agenthooks.BeforeToolCall(cxBeforeToolCall)
+	agenthooks.BeforeFileEdit(cxBeforeFileEdit)
+	agenthooks.BeforePrompt(cxBeforePrompt)
+}
+
+// RegisterPassThrough wires no-op handlers that always allow the action.
+// Used when the license check fails so we still emit valid JSON (fail-open).
+//
+// Every package-level dependency set by RegisterGuardrails (both scanners and
+// the telemetry wrapper) is cleared, so nothing from an earlier guardrail
+// registration stays reachable while the pass-through handlers are active.
+func RegisterPassThrough() {
+	scaScanner = nil
+	kicsScanner = nil
+	telemetryWrapper = nil
+	agenthooks.WhenAgentIdle(func(_ agenthooks.AgentIdleEvent) agenthooks.IdleVerdict { return agenthooks.Resume() })
+	agenthooks.BeforeToolCall(func(_ agenthooks.ToolCallEvent) agenthooks.ToolVerdict { return agenthooks.Allow() })
+	agenthooks.BeforeFileEdit(func(_ agenthooks.FileEditEvent) agenthooks.FileEditVerdict { return agenthooks.AcceptEdit() })
+	agenthooks.BeforePrompt(func(_ agenthooks.PromptEvent) agenthooks.PromptVerdict { return agenthooks.AcceptPrompt() })
+}
+
+// logRemediationTelemetry sends telemetry when remediation context is delivered to the agent.
+//
+// It is a no-op while no telemetry wrapper is registered. The record carries
+// the agent name, the engine (also lower-cased as the scan type), a unique ID
+// and the fixed type/subtype pair "hooks-remediate" / "fixWithAIAssist".
+// finding and remediationContext are accepted for the call sites but are not
+// part of the record that is sent.
+func logRemediationTelemetry(agent, engine, finding, remediationContext string) {
+	if telemetryWrapper == nil {
+		return
+	}
+
+	// Field conventions, carried over from the original notes: the agent name
+	// doubles as the AI provider. NOTE(review): the notes also mention a
+	// "hooks-detect" type with subtype "scan" for detection events; nothing in
+	// this function sends it — confirm where detection telemetry is emitted.
+	telemetryData := &wrappers.DataForAITelemetry{
+		AIProvider: agent,
+		Agent:      agent + "-cli",
+		Engine:     engine,
+		ScanType:   strings.ToLower(engine),
+		UniqueID:   wrappers.GetUniqueID(),
+		Type:       "hooks-remediate",
+		SubType:    "fixWithAIAssist",
+	}
+
+	// Fail-open: telemetry is best-effort, so a send failure is deliberately
+	// discarded and must never change the verdict returned to the agent.
+	_ = telemetryWrapper.SendAIDataToLog(telemetryData)
+}
+
+// agentToString converts agenthooks.AgentID enum to string representation for telemetry.
+// Agents without a dedicated case are reported as "Unknown".
+func agentToString(agent agenthooks.AgentID) string {
+	label := "Unknown"
+	switch agent {
+	case agenthooks.AgentClaude:
+		label = "Claude"
+	case agenthooks.AgentCopilot:
+		label = "Copilot"
+	case agenthooks.AgentCursor:
+		label = "Cursor"
+	case agenthooks.AgentGemini:
+		label = "Gemini"
+	case agenthooks.AgentDroid:
+		label = "Droid"
+	case agenthooks.AgentWindsurf:
+		label = "Windsurf"
+	}
+	return label
+}
diff --git a/internal/commands/agenthooks/cx/install.go b/internal/commands/agenthooks/cx/install.go
new file mode 100644
index 000000000..d11296642
--- /dev/null
+++ b/internal/commands/agenthooks/cx/install.go
@@ -0,0 +1,125 @@
+package cx
+
+import (
+	"github.com/CheckmarxDev/ast-cx-hooks/install"
+)
+
+// Route is one hook dispatch route, exposed as a hidden `cx hooks ROUTE`
+// subcommand where ROUTE is the Use string.
+type Route struct {
+	// Use is the route name the agent invokes.
+	Use string
+	// Short is the cobra short description of the subcommand.
+	Short string
+}
+
+// Agent describes one supported AI coding agent: how to install its hook
+// config and which dispatch routes its hooks invoke. The Routes list must
+// match the route names the corresponding Install function writes into the
+// agent's config file — the canonical installers live in ast-cx-hooks's
+// install package; here we just thread the cx-CLI route prefix through.
+type Agent struct {
+	// ID is the short identifier used as the install subcommand name.
+	ID string
+	// DisplayName is the human-readable agent name shown in help and output.
+	DisplayName string
+	// ConfigPath is the agent's hook config location shown in help text.
+	ConfigPath string
+	// Install writes the hook config; cmdFor maps a route to its command.
+	Install func(home string, cmdFor install.CmdForFunc) error
+	// Routes lists the dispatch routes this agent's hooks invoke.
+	Routes []Route
+}
+
+// Agents is the single source of truth for supported AI coding agents.
+// Adding a new agent is one entry here plus one installer in ast-cx-hooks.
+var Agents = []Agent{ + { + ID: "claude", + DisplayName: "Claude Code", + ConfigPath: "~/.claude/settings.json", + Install: install.InstallClaude, + Routes: []Route{ + {"claude-stop", "Claude Code agent finished"}, + {"claude-pre-tool-use", "Gate Claude Code tool use"}, + {"claude-pre-file-write", "Gate Claude Code file write"}, + {"claude-user-prompt-submit", "Gate Claude Code prompt"}, + }, + }, + { + ID: "cursor", + DisplayName: "Cursor", + ConfigPath: "~/.cursor/hooks.json", + Install: install.InstallCursor, + Routes: []Route{ + {"cursor-stop", "Cursor agent finished"}, + {"cursor-before-shell", "Gate Cursor shell execution"}, + {"cursor-before-mcp", "Gate Cursor MCP execution"}, + {"cursor-before-file-read", "Gate Cursor file read"}, + {"cursor-after-file-edit", "React to Cursor file edit"}, + {"cursor-before-submit-prompt", "Gate Cursor prompt"}, + }, + }, + { + ID: "windsurf", + DisplayName: "Windsurf", + ConfigPath: "~/.codeium/windsurf/hooks.json", + Install: install.InstallWindsurf, + Routes: []Route{ + {"windsurf-pre-run-command", "Gate Windsurf shell execution"}, + {"windsurf-pre-mcp-tool-use", "Gate Windsurf MCP execution"}, + {"windsurf-pre-user-prompt", "Gate Windsurf prompt"}, + {"windsurf-pre-write-code", "Gate Windsurf file write"}, + {"windsurf-post-cascade-response", "Windsurf agent finished"}, + }, + }, + { + ID: "droid", + DisplayName: "Factory Droid", + ConfigPath: "~/.factory/settings.json", + Install: install.InstallDroid, + Routes: []Route{ + {"droid-stop", "Factory Droid agent finished"}, + {"droid-pre-tool-use", "Gate Factory Droid tool use"}, + {"droid-pre-file-write", "Gate Factory Droid file write"}, + {"droid-user-prompt-submit", "Gate Factory Droid prompt"}, + }, + }, + { + ID: "gemini", + DisplayName: "Gemini CLI", + ConfigPath: "~/.gemini/settings.json", + Install: install.InstallGemini, + Routes: []Route{ + {"gemini-before-agent", "Gemini CLI agent starting"}, + {"gemini-before-tool", "Gate Gemini CLI tool execution"}, + 
{"gemini-before-file-tool", "Gate Gemini CLI file write"}, + {"gemini-after-agent", "Gemini CLI agent finished"}, + }, + }, + { + ID: "copilot", + DisplayName: "GitHub Copilot CLI", + ConfigPath: "~/.copilot/hooks/agenthooks.json", + Install: install.InstallCopilotCLI, + Routes: []Route{ + {"copilot-cli-stop", "GitHub Copilot CLI agent finished"}, + {"copilot-cli-pre-tool-use", "Gate GitHub Copilot CLI tool use"}, + {"copilot-cli-pre-file-write", "Gate GitHub Copilot CLI file write"}, + {"copilot-cli-user-prompt-submit", "Gate GitHub Copilot CLI prompt"}, + }, + }, +} + +// FindAgent returns the Agent with the given ID, or nil if not found. +func FindAgent(id string) *Agent { + for i := range Agents { + if Agents[i].ID == id { + return &Agents[i] + } + } + return nil +} + +// CxCmdFor returns a CmdForFunc that maps each route to the shell command +// " hooks " — the form ast-cli uses when wiring its own routes +// into agent config files. +func CxCmdFor(cxPath string) install.CmdForFunc { + return func(route string) string { + return install.FormatCommand(cxPath, "hooks", route) + } +} diff --git a/internal/commands/agenthooks/cx/install_test.go b/internal/commands/agenthooks/cx/install_test.go new file mode 100644 index 000000000..1692bd601 --- /dev/null +++ b/internal/commands/agenthooks/cx/install_test.go @@ -0,0 +1,50 @@ +//go:build !integration + +package cx + +import "testing" + +// TestFindAgentCopilot pins the GitHub Copilot CLI agent entry: its config path +// and the curated route set the installer mirrors. The route Use names must match +// the copilot-cli-* routes ast-cx-hooks registers, or `cx hooks agenthooks install +// copilot` would write commands that don't resolve. 
+func TestFindAgentCopilot(t *testing.T) { + agent := FindAgent("copilot") + if agent == nil { + t.Fatal("FindAgent(\"copilot\") returned nil; Copilot agent not registered") + } + if agent.DisplayName != "GitHub Copilot CLI" { + t.Errorf("DisplayName = %q, want %q", agent.DisplayName, "GitHub Copilot CLI") + } + if agent.ConfigPath != "~/.copilot/hooks/agenthooks.json" { + t.Errorf("ConfigPath = %q, want %q", agent.ConfigPath, "~/.copilot/hooks/agenthooks.json") + } + if agent.Install == nil { + t.Error("Install func is nil") + } + + wantRoutes := []string{ + "copilot-cli-stop", + "copilot-cli-pre-tool-use", + "copilot-cli-pre-file-write", + "copilot-cli-user-prompt-submit", + } + if len(agent.Routes) != len(wantRoutes) { + t.Fatalf("got %d routes, want %d: %+v", len(agent.Routes), len(wantRoutes), agent.Routes) + } + for i, want := range wantRoutes { + if agent.Routes[i].Use != want { + t.Errorf("Routes[%d].Use = %q, want %q", i, agent.Routes[i].Use, want) + } + if agent.Routes[i].Short == "" { + t.Errorf("Routes[%d] (%q) has empty Short description", i, want) + } + } +} + +// TestFindAgentUnknown verifies FindAgent returns nil for an unregistered id. 
+func TestFindAgentUnknown(t *testing.T) { + if a := FindAgent("not-a-real-agent"); a != nil { + t.Errorf("FindAgent of unknown id = %+v, want nil", a) + } +} diff --git a/internal/commands/agenthooks/guardrails/asca/asca.go b/internal/commands/agenthooks/guardrails/asca/asca.go new file mode 100644 index 000000000..16f1f8d19 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/asca/asca.go @@ -0,0 +1,137 @@ +package asca + +import ( + "os" + "path/filepath" + "strings" + + agenthooks "github.com/CheckmarxDev/ast-cx-hooks" + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/services" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/grpcs" + "github.com/spf13/viper" +) + +// ascaSupportedExtensions lists file extensions for languages ASCA can scan: +// Java, JavaScript (Node.js), C#, Go, and Python. +var ascaSupportedExtensions = map[string]struct{}{ + ".java": {}, ".js": {}, ".jsx": {}, ".ts": {}, ".tsx": {}, ".mjs": {}, ".cjs": {}, + ".cs": {}, ".go": {}, ".py": {}, ".pyw": {}, +} + +// isSupportedByASCA returns true when the file's extension is one ASCA can scan. +func isSupportedByASCA(filePath string) bool { + ext := strings.ToLower(filepath.Ext(filePath)) + _, ok := ascaSupportedExtensions[ext] + return ok +} + +// ScanFileEdit runs ASCA on the proposed post-edit content. +// Returns blocked=true with a formatted reason and remediation context when ASCA +// finds *new* vulnerabilities introduced by ev.Changes (delta-detection for edits; +// any-vuln for new writes). Findings the user already suppressed via +// `cx ignore-vulnerability` (the realtime ignore file) are filtered out before the +// verdict. Fail-open on infrastructure errors (ASCA install fail, engine unavailable, panic). 
+func ScanFileEdit(ev agenthooks.FileEditEvent) (blocked bool, reason, context string) { + defer func() { + if r := recover(); r != nil { + blocked = false + reason = "" + context = "" + } + }() + + if !isSupportedByASCA(ev.FilePath) { + return false, "", "" + } + + newContent, originalContent, err := ProposedContent(ev.FilePath, ev.Changes) + if err != nil || newContent == "" { + return false, "", "" + } + + wrapperParams := services.AscaWrappersParam{ + JwtWrapper: wrappers.NewJwtWrapper(), + ASCAWrapper: grpcs.NewASCAGrpcWrapper(viper.GetInt(params.ASCAPortKey)), + } + + ascaParams := services.AscaScanParams{ + ASCAUpdateVersion: shouldUpdateVersion(), + IsDefaultAgent: true, // license already verified upstream in agenthooks.go + // Honor findings the user suppressed via `cx ignore-vulnerability` so the hook + // stops blocking them. Only set when the file exists: the ASCA service treats a + // configured-but-missing ignore path as a scan error, which would fail-open the + // guardrail entirely. 
+ IgnoredFilePath: existingIgnoreFilePath(ev.WorkDir), + } + + // Stage and scan the proposed (new) content + stagedNew, cleanupNew, err := stageForScan(ev.FilePath, newContent, ev.SessionID) + if err != nil { + return false, "", "" + } + defer cleanupNew() + + ascaParams.FilePath = stagedNew + newResult, err := services.CreateASCAScanRequest(ascaParams, wrapperParams) + if err != nil || newResult == nil { + return false, "", "" + } + if newResult.Error != nil { + return false, "", "" + } + if len(newResult.ScanDetails) == 0 { + return false, "", "" + } + + // For new files (no original content), every finding is new + if originalContent == "" { + r, c := formatFindings(ev.FilePath, newResult.ScanDetails, ev.WorkDir) + return true, r, c + } + + // Delta: scan original content and find only newly introduced findings + stagedOrig, cleanupOrig, err := stageForScan(ev.FilePath, originalContent, ev.SessionID) + if err != nil { + return false, "", "" + } + defer cleanupOrig() + + ascaParams.FilePath = stagedOrig + origResult, err := services.CreateASCAScanRequest(ascaParams, wrapperParams) + if err != nil || origResult == nil { + return false, "", "" + } + var origDetails []grpcs.ScanDetail + if origResult.Error == nil { + origDetails = origResult.ScanDetails + } + + newFindings := NewFindings(origDetails, newResult.ScanDetails) + if len(newFindings) == 0 { + return false, "", "" + } + + r, c := formatFindings(ev.FilePath, newFindings, ev.WorkDir) + return true, r, c +} + +// existingIgnoreFilePath returns the default realtime ignore-file path only when it +// exists on disk. The ASCA service short-circuits the scan with an error when a +// configured ignore path is missing, so we pass it only once the user has created it +// via `cx ignore-vulnerability`; otherwise the scan runs without ignore filtering. 
+func existingIgnoreFilePath(workDir string) string { + p := ignore.PathFor(workDir) + if _, err := os.Stat(p); err == nil { + return p + } + return "" +} + +// shouldUpdateVersion returns whether ASCA should check for a newer version. +func shouldUpdateVersion() bool { + v := viper.GetString(params.DisableASCALatestVersionKey) + return v != "true" +} diff --git a/internal/commands/agenthooks/guardrails/asca/asca_test.go b/internal/commands/agenthooks/guardrails/asca/asca_test.go new file mode 100644 index 000000000..ae3ef4897 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/asca/asca_test.go @@ -0,0 +1,327 @@ +//go:build !integration + +package asca + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + agenthooks "github.com/CheckmarxDev/ast-cx-hooks" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/wrappers/grpcs" +) + +// ── ProposedContent ───────────────────────────────────────────────────────── + +func TestProposedContent_FullFileWrite(t *testing.T) { + newContent, _, err := ProposedContent("/nonexistent/auth.py", []agenthooks.FileDiff{ + {Before: "", After: "print('hello')"}, + }) + if err != nil { + t.Fatal(err) + } + if newContent != "print('hello')" { + t.Fatalf("want %q, got %q", "print('hello')", newContent) + } +} + +func TestProposedContent_FullFileWrite_OriginalEmpty_WhenFileAbsent(t *testing.T) { + _, orig, err := ProposedContent("/nonexistent/auth.py", []agenthooks.FileDiff{ + {Before: "", After: "new content"}, + }) + if err != nil { + t.Fatal(err) + } + if orig != "" { + t.Fatalf("expected empty originalContent for absent file, got %q", orig) + } +} + +func TestProposedContent_StringReplaceEdit(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "app.py") + if err := os.WriteFile(path, []byte("x = 1\ny = 2\n"), 0o600); err != nil { + t.Fatal(err) + } + + newContent, origContent, err := ProposedContent(path, []agenthooks.FileDiff{ + 
{Before: "y = 2", After: "y = 99"}, + }) + if err != nil { + t.Fatal(err) + } + if origContent != "x = 1\ny = 2\n" { + t.Fatalf("unexpected orig: %q", origContent) + } + if newContent != "x = 1\ny = 99\n" { + t.Fatalf("unexpected new: %q", newContent) + } +} + +func TestProposedContent_MissingBeforeFailsOpen(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "app.py") + if err := os.WriteFile(path, []byte("a = 1\n"), 0o600); err != nil { + t.Fatal(err) + } + + // Before string not present → returns original unchanged + newContent, origContent, err := ProposedContent(path, []agenthooks.FileDiff{ + {Before: "NOTHERE", After: "replacement"}, + }) + if err != nil { + t.Fatal(err) + } + if newContent != origContent { + t.Fatalf("expected content unchanged, got %q", newContent) + } +} + +func TestProposedContent_MultiEdit(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "app.py") + if err := os.WriteFile(path, []byte("a\nb\nc\n"), 0o600); err != nil { + t.Fatal(err) + } + + newContent, _, err := ProposedContent(path, []agenthooks.FileDiff{ + {Before: "a", After: "A"}, + {Before: "b", After: "B"}, + }) + if err != nil { + t.Fatal(err) + } + if newContent != "A\nB\nc\n" { + t.Fatalf("unexpected multi-edit result: %q", newContent) + } +} + +// ── stageForScan / safeSessionTag ─────────────────────────────────────────── + +func TestSafeSessionTag_Empty(t *testing.T) { + if got := safeSessionTag(""); got != "anon" { + t.Fatalf("want anon, got %q", got) + } +} + +func TestSafeSessionTag_AllSpecialChars(t *testing.T) { + if got := safeSessionTag("!!!???"); got != "anon" { + t.Fatalf("want anon, got %q", got) + } +} + +func TestSafeSessionTag_UUID(t *testing.T) { + got := safeSessionTag("550e8400-e29b-41d4-a716-446655440000") + if len(got) > 8 { + t.Fatalf("expected ≤8 chars, got %q (len %d)", got, len(got)) + } + for _, r := range got { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '-' || r == '_') { 
+ t.Fatalf("unexpected char %q in tag %q", r, got) + } + } +} + +func TestStageForScan_CreatesFileWithOriginalBasename(t *testing.T) { + staged, cleanup, err := stageForScan("/some/path/auth.py", "content", "sess123") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + if filepath.Base(staged) != "auth.py" { + t.Fatalf("expected basename auth.py, got %q", filepath.Base(staged)) + } + data, err := os.ReadFile(staged) + if err != nil { + t.Fatal(err) + } + if string(data) != "content" { + t.Fatalf("file content mismatch: %q", string(data)) + } +} + +func TestStageForScan_DirNameContainsSessionTag(t *testing.T) { + staged, cleanup, err := stageForScan("/tmp/foo.py", "x", "abc123") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + dir := filepath.Dir(staged) + base := filepath.Base(dir) + if !strings.Contains(base, "asca-hook-") { + t.Fatalf("expected dir name to contain asca-hook-, got %q", base) + } + if !strings.Contains(base, "abc123") { + t.Fatalf("expected dir name to contain session tag, got %q", base) + } +} + +func TestStageForScan_CleanupRemovesDir(t *testing.T) { + staged, cleanup, err := stageForScan("/tmp/foo.py", "x", "sess") + if err != nil { + t.Fatal(err) + } + dir := filepath.Dir(staged) + cleanup() + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatal("expected temp dir to be removed after cleanup") + } +} + +func TestStageForScan_FileMode(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix permission bits (0600) are not enforced on Windows; validated on Linux/macOS CI") + } + staged, cleanup, err := stageForScan("/tmp/secret.py", "secret", "s1") + if err != nil { + t.Fatal(err) + } + defer cleanup() + + info, err := os.Stat(staged) + if err != nil { + t.Fatal(err) + } + if perm := info.Mode().Perm(); perm != 0o600 { + t.Fatalf("expected mode 0600, got %04o", perm) + } +} + +// ── NewFindings ────────────────────────────────────────────────────────────── + +func scanDetail(ruleID uint32, line string) 
grpcs.ScanDetail { + return grpcs.ScanDetail{ + RuleID: ruleID, + ProblematicLine: line, + Severity: "HIGH", + RuleName: "test-rule", + } +} + +func TestNewFindings_NilOriginalReturnsAll(t *testing.T) { + newScan := []grpcs.ScanDetail{scanDetail(1, "bad code")} + got := NewFindings(nil, newScan) + if len(got) != 1 { + t.Fatalf("expected 1 finding, got %d", len(got)) + } +} + +func TestNewFindings_IdenticalScansReturnsEmpty(t *testing.T) { + scan := []grpcs.ScanDetail{scanDetail(42, "subprocess.run(cmd, shell=True)")} + got := NewFindings(scan, scan) + if len(got) != 0 { + t.Fatalf("expected 0 new findings, got %d", len(got)) + } +} + +func TestNewFindings_NewVulnReturned(t *testing.T) { + orig := []grpcs.ScanDetail{scanDetail(1, "line A")} + newScan := []grpcs.ScanDetail{ + scanDetail(1, "line A"), + scanDetail(2, "line B"), + } + got := NewFindings(orig, newScan) + if len(got) != 1 || got[0].RuleID != 2 { + t.Fatalf("expected finding for ruleID 2, got %v", got) + } +} + +func TestNewFindings_OldVulnNotInNewIsIgnored(t *testing.T) { + orig := []grpcs.ScanDetail{scanDetail(99, "old line")} + newScan := []grpcs.ScanDetail{scanDetail(1, "new line")} + got := NewFindings(orig, newScan) + if len(got) != 1 || got[0].RuleID != 1 { + t.Fatalf("unexpected findings: %v", got) + } +} + +func TestNewFindings_TrimSpaceDeduplication(t *testing.T) { + // Same rule + same line but with different surrounding whitespace → treated as same + orig := []grpcs.ScanDetail{scanDetail(5, " shell=True ")} + newScan := []grpcs.ScanDetail{scanDetail(5, "shell=True")} + got := NewFindings(orig, newScan) + if len(got) != 0 { + t.Fatalf("expected trimspace dedup, got %d findings", len(got)) + } +} + +func TestNewFindings_EmptyNewScanReturnsEmpty(t *testing.T) { + orig := []grpcs.ScanDetail{scanDetail(1, "x")} + got := NewFindings(orig, nil) + if len(got) != 0 { + t.Fatalf("expected 0 findings, got %d", len(got)) + } +} + +// ── additionalContext 
──────────────────────────────────────────────────────── + +func TestAdditionalContext_SingleFinding_PreFilledCommand(t *testing.T) { + findings := []grpcs.ScanDetail{ + {FileName: "billing.py", Line: 5, RuleID: 4059}, + } + ctx := additionalContext("billing.py", "cx", findings, "") + if !strings.Contains(ctx, "ignore-vulnerability") { + t.Errorf("expected ignore-vulnerability command, got %q", ctx) + } + if !strings.Contains(ctx, `"FileName":"billing.py"`) { + t.Errorf("expected FileName in command, got %q", ctx) + } + if !strings.Contains(ctx, `"Line":5`) { + t.Errorf("expected Line in command, got %q", ctx) + } + if !strings.Contains(ctx, `"RuleID":4059`) { + t.Errorf("expected RuleID in command, got %q", ctx) + } +} + +func TestAdditionalContext_MultipleFindings_EachGetsCommand(t *testing.T) { + findings := []grpcs.ScanDetail{ + {FileName: "billing.py", Line: 5, RuleID: 4059}, + {FileName: "billing.py", Line: 12, RuleID: 4027}, + } + ctx := additionalContext("billing.py", "cx", findings, "") + if strings.Count(ctx, "ignore-vulnerability") != 2 { + t.Errorf("expected 2 ignore commands for 2 findings, got: %q", ctx) + } + if !strings.Contains(ctx, `"RuleID":4059`) { + t.Errorf("expected RuleID 4059, got %q", ctx) + } + if !strings.Contains(ctx, `"RuleID":4027`) { + t.Errorf("expected RuleID 4027, got %q", ctx) + } +} + +func TestAdditionalContext_EmptyFindings_StillContainsRemediationInstruction(t *testing.T) { + ctx := additionalContext("main.py", "cx", nil, "") + if !strings.Contains(ctx, "mcp__Checkmarx__codeRemediation") { + t.Errorf("expected codeRemediation instruction even with no findings, got %q", ctx) + } +} + +func TestAdditionalContext_PinsIgnoredFilePathToWorkDir(t *testing.T) { + findings := []grpcs.ScanDetail{ + {FileName: "billing.py", Line: 5, RuleID: 4059}, + } + workDir := filepath.Join("repo", "ws") + ctx := additionalContext("billing.py", "cx", findings, workDir) + want := "--ignored-file-path '" + ignore.PathFor(workDir) + "'" + if 
!strings.Contains(ctx, want) { + t.Errorf("expected context to pin %q, got %q", want, ctx) + } +} + +func TestAdditionalContext_EmptyWorkDirOmitsIgnoredFilePath(t *testing.T) { + findings := []grpcs.ScanDetail{ + {FileName: "billing.py", Line: 5, RuleID: 4059}, + } + ctx := additionalContext("billing.py", "cx", findings, "") + if strings.Contains(ctx, "--ignored-file-path") { + t.Errorf("expected no ignored-file-path flag for empty workDir, got %q", ctx) + } +} diff --git a/internal/commands/agenthooks/guardrails/asca/delta.go b/internal/commands/agenthooks/guardrails/asca/delta.go new file mode 100644 index 000000000..03c917f65 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/asca/delta.go @@ -0,0 +1,136 @@ +package asca + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/wrappers/grpcs" +) + +// findingKey is the deduplication tuple used for delta detection. +// Mirrors the cx-security plugin's matching logic. +type findingKey struct { + ruleID uint32 + problematicLine string // TrimSpace applied +} + +func keyOf(d grpcs.ScanDetail) findingKey { + return findingKey{ + ruleID: d.RuleID, + problematicLine: strings.TrimSpace(d.ProblematicLine), + } +} + +// NewFindings returns scan details present in newScan that have no matching key in originalScan. +// A new file (originalScan == nil) returns newScan unchanged — any vuln is "new". +func NewFindings(originalScan, newScan []grpcs.ScanDetail) []grpcs.ScanDetail { + if originalScan == nil { + return newScan + } + baseline := make(map[findingKey]struct{}, len(originalScan)) + for _, d := range originalScan { + baseline[keyOf(d)] = struct{}{} + } + var out []grpcs.ScanDetail + for _, d := range newScan { + if _, exists := baseline[keyOf(d)]; !exists { + out = append(out, d) + } + } + return out +} + +// findingsSummary returns the bullet list shared by both message fields. 
Each line +// carries the file_name (basename), line, rule_id, severity and remediation so the +// agent has everything needed to suppress a confirmed false positive via +// `cx ignore-vulnerability` without having to re-scan to recover the rule_id. +func findingsSummary(findings []grpcs.ScanDetail) string { + var sb strings.Builder + for _, f := range findings { + remediation := f.Remediation + if remediation == "" { + remediation = "No remediation provided" + } + fmt.Fprintf(&sb, " - %s line %d [%s] %s (rule_id %d) — %s\n", + f.FileName, f.Line, f.Severity, f.RuleName, f.RuleID, remediation) + } + return sb.String() +} + +// formatFindings builds the two verdict fields delivered to the agent: the +// human-readable deny reason (rendered as permissionDecisionReason) and the +// remediation guidance injected into the agent's context (additionalContext). +// ast-cx-hooks v1.0.3 carries these as distinct fields via RejectEditWithContext. +func formatFindings(filePath string, findings []grpcs.ScanDetail, workDir string) (reason, context string) { + summary := findingsSummary(findings) + cxExe, err := os.Executable() + cxBinary := "cx" + if err == nil { + cxBinary = cxExe + } + return permissionDecisionReason(filePath, summary), additionalContext(filePath, cxBinary, findings, workDir) +} + +// ignoredFilePathFlag returns the " --ignored-file-path ''" fragment that pins +// the suppression command to the workspace ignore file, anchored at the hook event's +// workDir. This keeps the write (cx ignore-vulnerability) and the later read (the hook) +// on the same absolute file regardless of either process's CWD — without it, a host CLI +// that runs the agent's shell from a different directory than the hook (e.g. Copilot CLI) +// would write and read different files. Returns "" when workDir is unknown so the command +// falls back to its CWD-relative default. 
+func ignoredFilePathFlag(workDir string) string { + if workDir == "" { + return "" + } + return fmt.Sprintf(" --ignored-file-path '%s'", ignore.PathFor(workDir)) +} + +// permissionDecisionReason is the human-readable deny message shown to the user. +// Contains only the findings — no agent instructions. +func permissionDecisionReason(filePath, summary string) string { + return fmt.Sprintf( + "ASCA security scan detected vulnerabilities in %s.\nFindings:\n%s", + filePath, summary, + ) +} + +// additionalContext is injected into the agent's context window to drive remediation. +// Contains all action instructions — not shown directly to the user. +func additionalContext(filePath, cxBinary string, findings []grpcs.ScanDetail, workDir string) string { + ignoreFlag := ignoredFilePathFlag(workDir) + var suppressCmds strings.Builder + for _, f := range findings { + data, _ := json.Marshal(grpcs.AscaIgnoreFinding{ + FileName: f.FileName, + Line: f.Line, + RuleID: f.RuleID, + }) + fmt.Fprintf(&suppressCmds, " %s ignore-vulnerability --scan-type asca --data '%s'%s\n", cxBinary, string(data), ignoreFlag) + } + return fmt.Sprintf( + "ASCA detected vulnerabilities in %s. "+ + "Do not bypass the scan by writing the same content through another tool or shell command. "+ + "ANALYZE each finding to determine if it is a real vulnerability or a false positive "+ + "caused by ASCA's single-file scope (it cannot see imported modules or helper files). "+ + "For each real finding, invoke the cx-security:cx-security-asca skill — "+ + "the findings are already in context so it will skip the scan and go directly to "+ + "MCP-driven remediation; the skill also handles MCP unavailability and self-recovery. 
"+ + "If that skill is not available in this session, call mcp__Checkmarx__codeRemediation directly:\n"+ + " {\n"+ + " \"language\": \"[auto-detected programming language]\",\n"+ + " \"metadata\": {\n"+ + " \"ruleId\": \"[rule_name from scan]\",\n"+ + " \"description\": \"[description from scan]\",\n"+ + " \"remediationAdvice\": \"[remediationAdvise from scan]\"\n"+ + " },\n"+ + " \"type\": \"sast\"\n"+ + " }\n"+ + "Use the remediation guidance returned by the tool to fix the vulnerability, then retry the write. "+ + "If a finding is a confirmed false positive, suppress it by running the corresponding command below, then retry the write:\n"+ + suppressCmds.String(), + filePath, + ) +} diff --git a/internal/commands/agenthooks/guardrails/kics/content.go b/internal/commands/agenthooks/guardrails/kics/content.go new file mode 100644 index 000000000..9eed7e369 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/content.go @@ -0,0 +1,38 @@ +package kics + +import ( + "os" + "strings" + + agenthooks "github.com/CheckmarxDev/ast-cx-hooks" +) + +// proposedContent returns the file content that would exist after ev.Changes are applied. +// Returns (newContent, originalContent, err). 
+// - Full-file write: Changes = [{Before:"", After:X}] → newContent=X, originalContent= +// - String-replace edit: read disk, apply each FileDiff.Before→After in order +// - File doesn't exist on disk: originalContent="" +func proposedContent(filePath string, changes []agenthooks.FileDiff) (newContent, originalContent string, err error) { + diskBytes, readErr := os.ReadFile(filePath) + if readErr == nil { + originalContent = string(diskBytes) + } + // readErr means file doesn't exist yet — originalContent stays "" + + // Full-file write: single diff with empty Before + if len(changes) == 1 && changes[0].Before == "" { + return changes[0].After, originalContent, nil + } + + // String-replace: apply each diff in order against current content + current := originalContent + for _, diff := range changes { + idx := strings.Index(current, diff.Before) + if idx < 0 { + // Before not found — malformed edit; fail-open, let agent's tool surface it + continue + } + current = current[:idx] + diff.After + current[idx+len(diff.Before):] + } + return current, originalContent, nil +} diff --git a/internal/commands/agenthooks/guardrails/kics/delta.go b/internal/commands/agenthooks/guardrails/kics/delta.go new file mode 100644 index 000000000..9503b9498 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/delta.go @@ -0,0 +1,112 @@ +package kics + +import ( + "fmt" + "strings" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/iacrealtime" +) + +// findingKey is the deduplication tuple used for delta detection. +// Mirrors the ignore-file key used by RunIacRealtimeScan: Title + "_" + SimilarityID. +type findingKey struct { + title string + similarityID string +} + +func keyOf(r iacrealtime.IacRealtimeResult) findingKey { + return findingKey{ + title: r.Title, + similarityID: r.SimilarityID, + } +} + +// NewFindings returns results present in newScan that have no matching key in originalScan. 
+// A new file (originalScan == nil) returns newScan unchanged — any finding is "new". +func NewFindings(originalScan, newScan []iacrealtime.IacRealtimeResult) []iacrealtime.IacRealtimeResult { + if originalScan == nil { + return newScan + } + baseline := make(map[findingKey]struct{}, len(originalScan)) + for _, r := range originalScan { + baseline[keyOf(r)] = struct{}{} + } + var out []iacrealtime.IacRealtimeResult + for _, r := range newScan { + if _, exists := baseline[keyOf(r)]; !exists { + out = append(out, r) + } + } + return out +} + +// findingsSummary returns the bullet list of findings for human display. +func findingsSummary(filePath string, findings []iacrealtime.IacRealtimeResult) string { + var sb strings.Builder + for _, f := range findings { + line := 0 + if len(f.Locations) > 0 { + line = f.Locations[0].Line + } + description := f.Description + if description == "" { + description = "No description provided" + } + fmt.Fprintf(&sb, " - %s line %d [%s] %s — %s\n", + filePath, line, f.Severity, f.Title, description) + } + return sb.String() +} + +// formatFindings builds the two verdict fields delivered to the agent. +func formatFindings(filePath string, findings []iacrealtime.IacRealtimeResult) (reason, context string) { + summary := findingsSummary(filePath, findings) + return permissionDecisionReason(filePath, summary), additionalContext(filePath, findings) +} + +// permissionDecisionReason is the human-readable deny message shown to the user. +func permissionDecisionReason(filePath, summary string) string { + return fmt.Sprintf( + "KICS security scan detected IaC vulnerabilities in %s.\nFindings:\n%s", + filePath, summary, + ) +} + +// additionalContext is injected into the agent's context window to drive remediation. +// KICS is a deterministic IaC rule engine: unlike ASCA, its findings are not caused by +// missing cross-file context, so the agent is NOT given discretion to treat findings as +// false positives. Every new finding must be fixed. 
+func additionalContext(filePath string, findings []iacrealtime.IacRealtimeResult) string { + var findingList strings.Builder + for _, f := range findings { + line := 0 + if len(f.Locations) > 0 { + line = f.Locations[0].Line + } + fmt.Fprintf(&findingList, " - line %d [%s] %s: %s\n", + line, f.Severity, f.Title, f.Description) + } + return fmt.Sprintf( + "KICS detected IaC misconfigurations in %s. These are deterministic rule matches "+ + "against the configuration itself — they are NOT false positives caused by code "+ + "the scanner cannot see. Do not skip, suppress, or dismiss any finding as a false "+ + "positive, and do not bypass the scan by writing the same content through another "+ + "tool or shell command.\n"+ + "Fix every finding below, then retry the write:\n"+ + "%s"+ + "For each finding, call the mcp__Checkmarx__codeRemediation tool with:\n"+ + " {\n"+ + " \"type\": \"iac\",\n"+ + " \"metadata\": {\n"+ + " \"title\": \"[Title from finding]\",\n"+ + " \"description\": \"[Description from finding]\",\n"+ + " \"remediationAdvice\": \"[how to harden this configuration]\"\n"+ + " }\n"+ + " }\n"+ + "Apply the remediation guidance the tool returns, then retry the write. 
If a fix "+ + "genuinely requires resources outside this file (for example a separate KMS key or "+ + "a centrally-managed policy), add them as part of your change rather than skipping "+ + "the finding.", + filePath, findingList.String(), + ) +} diff --git a/internal/commands/agenthooks/guardrails/kics/delta_test.go b/internal/commands/agenthooks/guardrails/kics/delta_test.go new file mode 100644 index 000000000..09f6c476f --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/delta_test.go @@ -0,0 +1,125 @@ +//go:build !integration + +package kics + +import ( + "strings" + "testing" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/iacrealtime" +) + +func iacResult(title, similarityID, severity string, line int) iacrealtime.IacRealtimeResult { + return iacrealtime.IacRealtimeResult{ + Title: title, + SimilarityID: similarityID, + Severity: severity, + Description: "test description", + Locations: []realtimeengine.Location{{Line: line}}, + } +} + +// ── NewFindings ─────────────────────────────────────────────────────────────── + +func TestNewFindings_NilOriginalReturnsAll(t *testing.T) { + newScan := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + got := NewFindings(nil, newScan) + if len(got) != 1 { + t.Fatalf("expected 1 finding, got %d", len(got)) + } +} + +func TestNewFindings_IdenticalScansReturnsEmpty(t *testing.T) { + scan := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + got := NewFindings(scan, scan) + if len(got) != 0 { + t.Fatalf("expected 0 new findings, got %d", len(got)) + } +} + +func TestNewFindings_NewVulnReturned(t *testing.T) { + orig := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + newScan := []iacrealtime.IacRealtimeResult{ + iacResult("PrivilegedContainer", "sim1", "HIGH", 5), + iacResult("OpenSecurityGroup", "sim2", "CRITICAL", 
10), + } + got := NewFindings(orig, newScan) + if len(got) != 1 || got[0].Title != "OpenSecurityGroup" { + t.Fatalf("expected finding for OpenSecurityGroup, got %v", got) + } +} + +func TestNewFindings_PreExistingFindingNotReturned(t *testing.T) { + orig := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + newScan := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + got := NewFindings(orig, newScan) + if len(got) != 0 { + t.Fatalf("expected 0 findings (pre-existing), got %d", len(got)) + } +} + +func TestNewFindings_EmptyNewScanReturnsEmpty(t *testing.T) { + orig := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + got := NewFindings(orig, nil) + if len(got) != 0 { + t.Fatalf("expected 0 findings, got %d", len(got)) + } +} + +func TestNewFindings_DeltaDedup_SameKeyNotDoubled(t *testing.T) { + orig := []iacrealtime.IacRealtimeResult{iacResult("RuleA", "simA", "HIGH", 1)} + newScan := []iacrealtime.IacRealtimeResult{ + iacResult("RuleA", "simA", "HIGH", 1), // pre-existing + iacResult("RuleB", "simB", "MEDIUM", 2), // new + } + got := NewFindings(orig, newScan) + if len(got) != 1 || got[0].Title != "RuleB" { + t.Fatalf("expected only RuleB as new finding, got %v", got) + } +} + +// ── formatFindings ──────────────────────────────────────────────────────────── + +func TestFormatFindings_ReasonContainsKICS(t *testing.T) { + findings := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + reason, _ := formatFindings("/project/Dockerfile", findings) + if !strings.Contains(reason, "KICS") { + t.Errorf("reason should contain KICS, got: %q", reason) + } +} + +func TestFormatFindings_ReasonContainsFilePath(t *testing.T) { + findings := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + reason, _ := formatFindings("/project/Dockerfile", findings) + if !strings.Contains(reason, 
"/project/Dockerfile") { + t.Errorf("reason should contain file path, got: %q", reason) + } +} + +func TestFormatFindings_ReasonContainsSeverityAndTitle(t *testing.T) { + findings := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + reason, _ := formatFindings("/project/Dockerfile", findings) + if !strings.Contains(reason, "HIGH") { + t.Errorf("reason should contain severity, got: %q", reason) + } + if !strings.Contains(reason, "PrivilegedContainer") { + t.Errorf("reason should contain finding title, got: %q", reason) + } +} + +func TestFormatFindings_ContextContainsFixInstruction(t *testing.T) { + findings := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + _, ctx := formatFindings("/project/Dockerfile", findings) + if !strings.Contains(ctx, "fix") && !strings.Contains(ctx, "Fix") { + t.Errorf("context should contain fix instruction, got: %q", ctx) + } +} + +func TestFormatFindings_ContextContainsDoNotBypass(t *testing.T) { + findings := []iacrealtime.IacRealtimeResult{iacResult("PrivilegedContainer", "sim1", "HIGH", 5)} + _, ctx := formatFindings("/project/Dockerfile", findings) + if !strings.Contains(ctx, "bypass") { + t.Errorf("context should warn against bypass, got: %q", ctx) + } +} diff --git a/internal/commands/agenthooks/guardrails/kics/kics.go b/internal/commands/agenthooks/guardrails/kics/kics.go new file mode 100644 index 000000000..5ff568f04 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/kics.go @@ -0,0 +1,116 @@ +package kics + +import ( + "os" + "path/filepath" + "strings" + + agenthooks "github.com/CheckmarxDev/ast-cx-hooks" + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" +) + +// isSupportedByKICS returns true when the file matches a KICS-supported extension or basename. 
+// Mirrors params.KicsBaseFilters: basename match for Dockerfile/.dockerfile, +// extension match for .tf/.yaml/.yml/.json/.auto.tfvars/.terraform.tfvars/.proto. +func isSupportedByKICS(filePath string) bool { + base := filepath.Base(filePath) + baseLower := strings.ToLower(base) + + // Check basenames (Dockerfile and .dockerfile are basenames, not extensions) + for _, filter := range params.KicsBaseFilters { + filterLower := strings.ToLower(filter) + // Basename matches (no dot prefix means it's a filename, not an extension) + if !strings.HasPrefix(filterLower, ".") { + if baseLower == filterLower { + return true + } + continue + } + // Extension matches — check if the file path ends with the filter string + // (handles compound extensions like .auto.tfvars and .terraform.tfvars) + if strings.HasSuffix(strings.ToLower(filePath), filterLower) { + return true + } + } + return false +} + +// ScanFileEdit runs KICS on the proposed post-edit content. +// Returns blocked=true with a formatted reason and remediation context when KICS +// finds *new* vulnerabilities introduced by ev.Changes (delta-detection for edits; +// any-vuln for new writes). Findings the user already suppressed via +// `cx ignore-vulnerability` (the realtime ignore file) are filtered out before the +// verdict. Fail-open on infrastructure errors (Docker unavailable, image pull fail, panic). 
+func ScanFileEdit(ev agenthooks.FileEditEvent, svc *Scanner) (blocked bool, reason, context string) { + defer func() { + if r := recover(); r != nil { + blocked = false + reason = "" + context = "" + } + }() + + if !isSupportedByKICS(ev.FilePath) { + return false, "", "" + } + + newContent, originalContent, err := proposedContent(ev.FilePath, ev.Changes) + if err != nil || newContent == "" { + return false, "", "" + } + + // Stage and scan the proposed (new) content + stagedNew, cleanupNew, err := stageForScan(ev.FilePath, newContent, ev.SessionID) + if err != nil { + return false, "", "" + } + defer cleanupNew() + + newResults, err := svc.scan(stagedNew) + if err != nil { + // Fail open: Docker unavailable, image pull failure, feature flag disabled, etc. + return false, "", "" + } + if len(newResults) == 0 { + return false, "", "" + } + + // For new files (no original content), every finding is new + if originalContent == "" { + r, c := formatFindings(ev.FilePath, newResults) + return true, r, c + } + + // Delta: scan original content and find only newly introduced findings + stagedOrig, cleanupOrig, err := stageForScan(ev.FilePath, originalContent, ev.SessionID) + if err != nil { + return false, "", "" + } + defer cleanupOrig() + + origResults, err := svc.scan(stagedOrig) + if err != nil { + // Fail open on original scan error + return false, "", "" + } + + newFindings := NewFindings(origResults, newResults) + if len(newFindings) == 0 { + return false, "", "" + } + + r, c := formatFindings(ev.FilePath, newFindings) + return true, r, c +} + +// existingIgnoreFilePath returns the default realtime ignore-file path only when it +// exists on disk. The IaC realtime service logs a warning and skips ignore filtering +// when a missing path is passed, but we keep the pattern consistent with ASCA. 
+func existingIgnoreFilePath() string { + p := ignore.DefaultPath() + if _, err := os.Stat(p); err == nil { + return p + } + return "" +} diff --git a/internal/commands/agenthooks/guardrails/kics/kics_test.go b/internal/commands/agenthooks/guardrails/kics/kics_test.go new file mode 100644 index 000000000..4fe57e72e --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/kics_test.go @@ -0,0 +1,199 @@ +//go:build !integration + +package kics + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + agenthooks "github.com/CheckmarxDev/ast-cx-hooks" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/iacrealtime" +) + +// ── isSupportedByKICS ──────────────────────────────────────────────────────── + +func TestIsSupportedByKICS_TerraformFile(t *testing.T) { + if !isSupportedByKICS("/project/main.tf") { + t.Error("expected .tf to be supported") + } +} + +func TestIsSupportedByKICS_YamlFile(t *testing.T) { + if !isSupportedByKICS("/project/k8s/deployment.yaml") { + t.Error("expected .yaml to be supported") + } +} + +func TestIsSupportedByKICS_YmlFile(t *testing.T) { + if !isSupportedByKICS("/project/compose.yml") { + t.Error("expected .yml to be supported") + } +} + +func TestIsSupportedByKICS_JsonFile(t *testing.T) { + if !isSupportedByKICS("/project/policy.json") { + t.Error("expected .json to be supported") + } +} + +func TestIsSupportedByKICS_Dockerfile(t *testing.T) { + if !isSupportedByKICS("/project/Dockerfile") { + t.Error("expected Dockerfile to be supported") + } +} + +func TestIsSupportedByKICS_DockerfileUppercase(t *testing.T) { + if !isSupportedByKICS("/project/DOCKERFILE") { + t.Error("expected DOCKERFILE (case-insensitive) to be supported") + } +} + +func TestIsSupportedByKICS_AutoTfvars(t *testing.T) { + if !isSupportedByKICS("/project/prod.auto.tfvars") { + t.Error("expected .auto.tfvars to be supported") + } +} + +func 
TestIsSupportedByKICS_TerraformTfvars(t *testing.T) { + if !isSupportedByKICS("/project/env.terraform.tfvars") { + t.Error("expected .terraform.tfvars to be supported") + } +} + +func TestIsSupportedByKICS_GoFileNotSupported(t *testing.T) { + if isSupportedByKICS("/project/main.go") { + t.Error("expected .go to NOT be supported") + } +} + +func TestIsSupportedByKICS_TxtFileNotSupported(t *testing.T) { + if isSupportedByKICS("/project/README.txt") { + t.Error("expected .txt to NOT be supported") + } +} + +func TestIsSupportedByKICS_PyFileNotSupported(t *testing.T) { + if isSupportedByKICS("/project/app.py") { + t.Error("expected .py to NOT be supported") + } +} + +// ── ScanFileEdit ───────────────────────────────────────────────────────────── + +func makeResult(title, similarityID, severity, description string, line int) iacrealtime.IacRealtimeResult { + return iacrealtime.IacRealtimeResult{ + Title: title, + SimilarityID: similarityID, + Severity: severity, + Description: description, + Locations: []realtimeengine.Location{{Line: line}}, + } +} + +func TestScanFileEdit_NewFileWithFinding_Blocked(t *testing.T) { + finding := makeResult("Privileged Container", "sim123", "HIGH", "Container runs as privileged", 5) + svc := NewScannerWithFunc(func(_ string) ([]iacrealtime.IacRealtimeResult, error) { + return []iacrealtime.IacRealtimeResult{finding}, nil + }) + + ev := agenthooks.FileEditEvent{ + FilePath: "/project/Dockerfile", + SessionID: "test-sess", + Changes: []agenthooks.FileDiff{{Before: "", After: "FROM ubuntu\nUSER root\n"}}, + } + + blocked, reason, ctx := ScanFileEdit(ev, svc) + if !blocked { + t.Fatal("expected edit to be blocked") + } + if reason == "" { + t.Error("expected non-empty reason") + } + if ctx == "" { + t.Error("expected non-empty context") + } + if !strings.Contains(reason, "KICS") { + t.Errorf("reason should mention KICS, got: %q", reason) + } +} + +func TestScanFileEdit_EditWithNoNewFindings_NotBlocked(t *testing.T) { + existingFinding := 
makeResult("Privileged Container", "sim123", "HIGH", "Container runs as privileged", 5) + svc := NewScannerWithFunc(func(_ string) ([]iacrealtime.IacRealtimeResult, error) { + // Both original and new have the same finding — delta is empty + return []iacrealtime.IacRealtimeResult{existingFinding}, nil + }) + + dir := t.TempDir() + filePath := filepath.Join(dir, "Dockerfile") + if err := os.WriteFile(filePath, []byte("FROM ubuntu\nUSER root\n"), 0o600); err != nil { + t.Fatal(err) + } + + ev := agenthooks.FileEditEvent{ + FilePath: filePath, + SessionID: "test-sess", + Changes: []agenthooks.FileDiff{{Before: "FROM ubuntu", After: "FROM ubuntu:22.04"}}, + } + + blocked, _, _ := ScanFileEdit(ev, svc) + if blocked { + t.Fatal("expected edit to NOT be blocked when no new findings") + } +} + +func TestScanFileEdit_ScanError_FailOpen(t *testing.T) { + svc := NewScannerWithFunc(func(_ string) ([]iacrealtime.IacRealtimeResult, error) { + return nil, fmt.Errorf("docker daemon not running") + }) + + ev := agenthooks.FileEditEvent{ + FilePath: "/project/main.tf", + SessionID: "test-sess", + Changes: []agenthooks.FileDiff{{Before: "", After: "resource \"aws_s3_bucket\" \"bad\" {}"}}, + } + + blocked, _, _ := ScanFileEdit(ev, svc) + if blocked { + t.Fatal("expected fail-open (not blocked) on scan error") + } +} + +func TestScanFileEdit_UnsupportedFile_NotBlocked(t *testing.T) { + svc := NewScannerWithFunc(func(_ string) ([]iacrealtime.IacRealtimeResult, error) { + t.Error("scan should not be called for unsupported file types") + return nil, nil + }) + + ev := agenthooks.FileEditEvent{ + FilePath: "/project/main.go", + SessionID: "test-sess", + Changes: []agenthooks.FileDiff{{Before: "", After: "package main"}}, + } + + blocked, _, _ := ScanFileEdit(ev, svc) + if blocked { + t.Fatal("expected NOT blocked for unsupported file") + } +} + +func TestScanFileEdit_EmptyNewContent_NotBlocked(t *testing.T) { + svc := NewScannerWithFunc(func(_ string) ([]iacrealtime.IacRealtimeResult, 
error) { + return nil, nil + }) + + ev := agenthooks.FileEditEvent{ + FilePath: "/project/main.tf", + SessionID: "test-sess", + Changes: []agenthooks.FileDiff{{Before: "", After: ""}}, + } + + blocked, _, _ := ScanFileEdit(ev, svc) + if blocked { + t.Fatal("expected NOT blocked for empty content") + } +} diff --git a/internal/commands/agenthooks/guardrails/kics/scanner.go b/internal/commands/agenthooks/guardrails/kics/scanner.go new file mode 100644 index 000000000..e1e99a12d --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/scanner.go @@ -0,0 +1,33 @@ +package kics + +import ( + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/iacrealtime" + "github.com/checkmarx/ast-cli/internal/wrappers" +) + +// Scanner runs IaC realtime scans on behalf of the KICS guardrail. It holds +// the wrappers needed to construct an IacRealtimeService per call. Tests +// substitute scan via NewScannerWithFunc. +type Scanner struct { + jwt wrappers.JWTWrapper + ff wrappers.FeatureFlagsWrapper + scan func(path string) ([]iacrealtime.IacRealtimeResult, error) +} + +// NewScanner returns a Scanner backed by the given wrappers. +func NewScanner(jwt wrappers.JWTWrapper, ff wrappers.FeatureFlagsWrapper) *Scanner { + s := &Scanner{jwt: jwt, ff: ff} + s.scan = s.runRealScan + return s +} + +// NewScannerWithFunc returns a Scanner whose scan call is replaced with f. +// For unit tests only. 
+func NewScannerWithFunc(f func(path string) ([]iacrealtime.IacRealtimeResult, error)) *Scanner { + return &Scanner{scan: f} +} + +func (s *Scanner) runRealScan(path string) ([]iacrealtime.IacRealtimeResult, error) { + svc := iacrealtime.NewIacRealtimeService(s.jwt, s.ff, iacrealtime.NewContainerManager()) + return svc.RunIacRealtimeScan(path, "", existingIgnoreFilePath()) +} diff --git a/internal/commands/agenthooks/guardrails/kics/stage.go b/internal/commands/agenthooks/guardrails/kics/stage.go new file mode 100644 index 000000000..538c1f357 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/kics/stage.go @@ -0,0 +1,67 @@ +package kics + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// noop is a no-op cleanup func returned on error paths so callers can always defer cleanup(). +var noop = func() {} + +// stageForScan writes content to a fresh temp directory under os.TempDir(), +// preserving the original basename so KICS's file-type detection works correctly +// and findings report a sensible file path. The dir name includes a short, +// sanitized prefix of sessionID so concurrent agent sessions are visibly +// separated and orphaned dirs can be traced back to the session that created them. +// +// Returns the staged path and a cleanup func. Caller must defer cleanup(). +func stageForScan(originalPath, content, sessionID string) (stagedPath string, cleanup func(), err error) { + pattern := fmt.Sprintf("kics-hook-%s-*", safeSessionTag(sessionID)) + tempDir, err := os.MkdirTemp("", pattern) + if err != nil { + return "", noop, err + } + + base := filepath.Base(originalPath) + if base == "." || base == ".." 
|| base == "" || base == string(filepath.Separator) { + _ = os.RemoveAll(tempDir) + return "", noop, fmt.Errorf("invalid basename %q", base) + } + + staged := filepath.Join(tempDir, base) + // Path-traversal guard, copied from iacrealtime/file_handler.go:62 + if !strings.HasPrefix(filepath.Clean(staged), filepath.Clean(tempDir)) { + _ = os.RemoveAll(tempDir) + return "", noop, fmt.Errorf("path traversal in %q", base) + } + + if err := os.WriteFile(staged, []byte(content), 0o600); err != nil { + _ = os.RemoveAll(tempDir) + return "", noop, err + } + return staged, func() { _ = os.RemoveAll(tempDir) }, nil +} + +// safeSessionTag returns up to 8 chars of [a-zA-Z0-9_-] from sid, or "anon" if +// sid is empty or has no usable characters. Keeps the dir name short and shell-safe. +func safeSessionTag(sid string) string { + if sid == "" { + return "anon" + } + var b strings.Builder + for _, r := range sid { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '-' || r == '_' { + b.WriteRune(r) + if b.Len() >= 8 { + break + } + } + } + if b.Len() == 0 { + return "anon" + } + return b.String() +} diff --git a/internal/commands/agenthooks/guardrails/match.go b/internal/commands/agenthooks/guardrails/match.go new file mode 100644 index 000000000..7c95d75f9 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/match.go @@ -0,0 +1,87 @@ +package guardrails + +import ( + "path/filepath" + "strings" + + "github.com/bmatcuk/doublestar/v4" +) + +// hasGlobMeta reports whether s contains glob metacharacters (*, ?, [). +func hasGlobMeta(s string) bool { + return strings.ContainsAny(s, "*?[") +} + +// normalizeForMatch lower-cases path and converts separators to forward slashes, +// mirroring how the rest of the guardrails package normalises paths. +func normalizeForMatch(s string) string { + return strings.ToLower(filepath.ToSlash(s)) +} + +// matchFilePattern reports whether target matches pattern. 
+// +// A match is any of: +// - literal equality of the normalized forms +// - basename equality (so a policy entry like "kubeconfig" matches any path whose leaf is kubeconfig) +// - suffix equality after "/" (so ".env" matches "/app/.env") — preserved for back-compat +// - doublestar glob match against the full normalized target +// - doublestar glob match against just the basename (so "*.pem" matches "foo.pem") +// +// Normalization is lowercase + forward-slash; patterns authored with backslashes +// on Windows still work. +func matchFilePattern(pattern, target string) bool { + p := normalizeForMatch(pattern) + t := normalizeForMatch(target) + base := filepath.Base(t) + + if p == t || p == base { + return true + } + if strings.HasSuffix(t, "/"+p) { + return true + } + if hasGlobMeta(p) { + if ok, _ := doublestar.Match(p, t); ok { + return true + } + if ok, _ := doublestar.Match(p, base); ok { + return true + } + } + return false +} + +// anyPatternMatchesFile returns true when target matches at least one entry in patterns +// via matchFilePattern. Convenience wrapper for call sites that just need a boolean. +func anyPatternMatchesFile(patterns []string, target string) bool { + for _, p := range patterns { + if matchFilePattern(p, target) { + return true + } + } + return false +} + +// matchDirContains reports whether target is pattern itself or sits under it. +// +// For literal patterns this is the classic "target == dir OR target starts with dir/". +// For glob patterns it additionally accepts target when doublestar matches +// either the pattern directly (target is the dir) or "pattern/**" +// (target is a file inside the glob-matched dir). 
+func matchDirContains(pattern, target string) bool { + p := strings.TrimSuffix(normalizeForMatch(pattern), "/") + t := normalizeForMatch(target) + + if t == p || strings.HasPrefix(t, p+"/") { + return true + } + if hasGlobMeta(p) { + if ok, _ := doublestar.Match(p, t); ok { + return true + } + if ok, _ := doublestar.Match(p+"/**", t); ok { + return true + } + } + return false +} diff --git a/internal/commands/agenthooks/guardrails/policy.go b/internal/commands/agenthooks/guardrails/policy.go new file mode 100644 index 000000000..420827d6b --- /dev/null +++ b/internal/commands/agenthooks/guardrails/policy.go @@ -0,0 +1,500 @@ +package guardrails + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "sync/atomic" +) + +// utf8BOM is the byte-order mark some editors (notably PowerShell's +// Out-File/Set-Content defaults) prepend to UTF-8 files. encoding/json +// rejects it, so we strip it before unmarshalling. +var utf8BOM = []byte{0xEF, 0xBB, 0xBF} + +// DenyMessage is the firm instruction appended to every denial. It tells the +// agent to stop — no retries, no workarounds, no alternative approaches. +const DenyMessage = "\n\n" + + "IMPORTANT: This action was blocked by your organization's security policy. " + + "Do NOT attempt alternative commands, workarounds, or equivalent approaches to achieve the same result. " + + "Any alternative method to accomplish this blocked action may be equally harmful to the system and the developer. " + + "Instead, inform the user that this operation is restricted by organizational policy and suggest they contact their administrator if they need this action performed." + +// HooksPolicy is the root shape of ~/.checkmarx/policyhooks.json +type HooksPolicy struct { + DefaultPolicy DefaultPolicy `json:"default_policy"` + Tools ToolsPolicy `json:"tools"` +} + +// DefaultPolicy holds all policy sections that apply at the global scope. 
type DefaultPolicy struct {
	// BlacklistTools is the shell command blacklist. LoadBlacklistedCommands
	// ignores Tools entirely unless Enabled is true.
	BlacklistTools struct {
		Enabled bool              `json:"enabled"`
		Tools   []BlacklistedTool `json:"tools"`
	} `json:"blacklist_tools"`
	// Global restricted/allowed path lists. Each is resolved for the current
	// OS through GetOSPaths; tool rules may merge with or override them.
	RestrictedDirectories PathPolicy `json:"restricted_directories"`
	RestrictedFiles       PathPolicy `json:"restricted_files"`
	AllowedDirectories    PathPolicy `json:"allowed_directories"`
	AllowedFiles          PathPolicy `json:"allowed_files"`
	// ContextPolicy groups the checks on what may enter the AI context.
	ContextPolicy ContextPolicy `json:"context_policy"`
	// BlastRadiusLimit caps the number of file writes per session.
	BlastRadiusLimit BlastRadiusLimit `json:"blast_radius_limit"`
}

// ContextPolicy controls what data may enter the AI's context window.
type ContextPolicy struct {
	// Enabled is the master switch: LoadFilesLimits and LoadBlockedExtensions
	// return nil when it is false, regardless of the sub-section flags.
	Enabled           bool              `json:"enabled"`
	FilesLimits       FilesLimits       `json:"files_limits"`
	ContentScanning   ContentScanning   `json:"content_scanning"`
	BlockedExtensions BlockedExtensions `json:"blocked_extensions"`
}

// FilesLimits restricts how many (and how large) files may be referenced in an AI context.
type FilesLimits struct {
	Enabled       bool `json:"enabled"`
	MaxFileCount  int  `json:"max_file_count"`
	MaxFileSizeKB int  `json:"max_file_size_kb"`
	// MaxTotalFileSizeKB is the session-wide cap checked by
	// CheckAndIncrementTotalFileSize; a value <= 0 disables that check.
	MaxTotalFileSizeKB int `json:"max_total_file_size_kb"`
}

// BlockedExtensions lists file extensions that must never enter the AI context.
type BlockedExtensions struct {
	Enabled    bool     `json:"enabled"`
	Extensions []string `json:"extensions"`
}

// BlastRadiusLimit caps how many files the AI may write during a single session.
type BlastRadiusLimit struct {
	Enabled bool `json:"enabled"`
	// Threshold is the maximum number of writes allowed; CheckAndIncrementBlastRadius
	// treats a value <= 0 as "no limit".
	Threshold int `json:"threshold"`
}

// ContentScanning holds the scanning configuration.
type ContentScanning struct {
	Enabled  bool                 `json:"enabled"`
	Patterns []ContentScanPattern `json:"patterns"`
}

// ContentScanPattern is a single regex rule that blocks sensitive content in prompts.
type ContentScanPattern struct {
	ID          string `json:"id"`
	Pattern     string `json:"pattern"` // regular expression source; NOTE(review): regex dialect is not visible here
	Description string `json:"description"`
}

// PathPolicy holds OS-specific path lists for restricted or allowed files/directories.
type PathPolicy struct {
	// Enabled gates the whole entry: GetOSPaths returns nil when it is false.
	Enabled bool     `json:"enabled"`
	Linux   []string `json:"linux"`   // used when runtime.GOOS == "linux"
	Windows []string `json:"windows"` // used when runtime.GOOS == "windows"
	Mac     []string `json:"mac"`     // used when runtime.GOOS == "darwin"
}

// BlacklistedTool is a single entry in the shell command blacklist.
type BlacklistedTool struct {
	// Name is the command name; LoadBlacklistedCommands keys it lower-cased.
	Name string `json:"name"`
	// OS lists the platforms the entry applies to. MatchesOS compares the
	// labels verbatim against runtime.GOOS, except that "mac" means "darwin".
	OS       []string `json:"os"`
	Category string   `json:"category"`
	Risk     string   `json:"risk"`
}

// ToolsPolicy is the root of the per-tool rule section.
type ToolsPolicy struct {
	// Enabled gates every rule: FindMatchingToolRule returns nil when it is false.
	Enabled         bool       `json:"enabled"`
	DefaultAuditLog bool       `json:"default_audit_log"`
	Rules           []ToolRule `json:"rules"`
}

// ToolRule defines restrictions and permissions for a specific shell tool.
// Enabled uses *bool so a missing field (nil) is treated as active, preserving
// backward-compatibility with older policy files that don't set the flag.
type ToolRule struct {
	Enabled *bool  `json:"enabled,omitempty"`
	ID      string `json:"id"`
	// Tool holds the command names that select this rule; matching is done on
	// whole tokens of the lower-cased command line (see FindMatchingToolRule).
	Tool []string `json:"tool"`
	// OS restricts the rule to the listed platforms; an empty list means all.
	OS                    []string   `json:"os"`
	ArgsInclude           []string   `json:"args_include"`
	ArgsExclude           []string   `json:"args_exclude"`
	RestrictedDirectories PathPolicy `json:"restricted_directories"`
	RestrictedFiles       PathPolicy `json:"restricted_files"`
	AllowedDirectories    PathPolicy `json:"allowed_directories"`
	AllowedFiles          PathPolicy `json:"allowed_files"`
	// MergeStrategy says, per path list, how the rule's entries combine with
	// the global default_policy entries.
	MergeStrategy MergeStrategy `json:"merge_strategy"`
	AuditLog      bool          `json:"audit_log"`
}

// MergeStrategy controls how a tool rule's path lists are combined with the
// global default_policy values. Applied independently per field.
+// +// merge = global ∪ rule list +// override = rule list only (replaces global) +// default = global list only (rule's values ignored) +type MergeStrategy struct { + RestrictedDirectories string `json:"restricted_directories"` + RestrictedFiles string `json:"restricted_files"` + AllowedDirectories string `json:"allowed_directories"` + AllowedFiles string `json:"allowed_files"` +} + +// blastRadiusCount tracks how many files have been written during this session. +// Kept as a package-level atomic so concurrent BeforeFileEdit handlers are safe. +var blastRadiusCount int32 + +// totalFileSizeBytes accumulates the byte length of all proposed file edits this session. +// Incremented in BeforeFileEdit before any bytes are written to disk. +var totalFileSizeBytes int64 + +// ResetBlastRadiusCount resets the session-level file write counter. Exposed for tests. +func ResetBlastRadiusCount() { + atomic.StoreInt32(&blastRadiusCount, 0) +} + +// ResetTotalFileSizeCount resets the session-level total bytes counter. Exposed for tests. +func ResetTotalFileSizeCount() { + atomic.StoreInt64(&totalFileSizeBytes, 0) +} + +// CheckAndIncrementBlastRadius increments the file-write counter and returns +// blocked=true with a reason if the configured threshold has been exceeded. +func CheckAndIncrementBlastRadius() (blocked bool, reason string) { + limit := LoadBlastRadiusLimit() + if limit == nil || !limit.Enabled || limit.Threshold <= 0 { + return false, "" + } + count := int(atomic.AddInt32(&blastRadiusCount, 1)) + if count > limit.Threshold { + return true, fmt.Sprintf( + "Blocked by Checkmarx: blast radius limit exceeded. "+ + "This session has written %d files, exceeding the policy threshold of %d.%s", + count, limit.Threshold, DenyMessage, + ) + } + return false, "" +} + +// CheckAndIncrementTotalFileSize adds sizeBytes to the running total and returns +// blocked=true if the configured max_total_file_size_kb would be exceeded. 
+func CheckAndIncrementTotalFileSize(sizeBytes int64) (blocked bool, reason string) { + limits := LoadFilesLimits() + if limits == nil || !limits.Enabled || limits.MaxTotalFileSizeKB <= 0 { + return false, "" + } + limitBytes := int64(limits.MaxTotalFileSizeKB) * 1024 + newTotal := atomic.AddInt64(&totalFileSizeBytes, sizeBytes) + if newTotal > limitBytes { + return true, fmt.Sprintf( + "Blocked by Checkmarx: total file size limit exceeded. "+ + "This session has written %d KB, exceeding the policy threshold of %d KB.%s", + newTotal/1024, limits.MaxTotalFileSizeKB, DenyMessage, + ) + } + return false, "" +} + +// ShellPolicyPath returns the path to the policy file: ~/.checkmarx/policyhooks.json +func ShellPolicyPath() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(home, ".checkmarx", "policyhooks.json") +} + +// LoadPolicy reads and parses ~/.checkmarx/policyhooks.json. +// Returns nil on any error (fail-open: a missing or malformed policy should never block the developer). +func LoadPolicy() *HooksPolicy { + data, err := os.ReadFile(ShellPolicyPath()) + if err != nil { + return nil + } + data = bytes.TrimPrefix(data, utf8BOM) + var p HooksPolicy + if err := json.Unmarshal(data, &p); err != nil { + return nil + } + return &p +} + +// GetOSPaths returns the path list for the current OS from a PathPolicy entry. +func GetOSPaths(pp PathPolicy) []string { + if !pp.Enabled { + return nil + } + switch runtime.GOOS { + case "linux": + return pp.Linux + case "darwin": + return pp.Mac + case "windows": + return pp.Windows + default: + return nil + } +} + +// MatchesOS returns true when any of the tool's OS labels match the current OS. 
+func MatchesOS(toolOS []string, currentOS string) bool { + for _, o := range toolOS { + mapped := o + if o == "mac" { + mapped = "darwin" + } + if mapped == currentOS { + return true + } + } + return false +} + +// LoadBlacklistedCommands reads the policy file and returns all command names +// (lowercased) that are blacklisted on the current OS, together with their metadata. +func LoadBlacklistedCommands() map[string]BlacklistedTool { + blacklisted := map[string]BlacklistedTool{} + policy := LoadPolicy() + if policy == nil { + return blacklisted // fail-open + } + if !policy.DefaultPolicy.BlacklistTools.Enabled { + return blacklisted + } + for _, t := range policy.DefaultPolicy.BlacklistTools.Tools { + if !MatchesOS(t.OS, runtime.GOOS) { + continue + } + blacklisted[strings.ToLower(t.Name)] = t + } + return blacklisted +} + +// LoadRestrictedPaths returns the OS-specific restricted file and directory +// lists from the policy file. +func LoadRestrictedPaths() (files []string, dirs []string) { + policy := LoadPolicy() + if policy == nil { + return nil, nil + } + return GetOSPaths(policy.DefaultPolicy.RestrictedFiles), + GetOSPaths(policy.DefaultPolicy.RestrictedDirectories) +} + +// LoadEffectiveRestrictedPaths returns the union of the global default +// restricted_files / restricted_directories and each enabled tool rule's +// effective restricted lists, combined per that rule's merge_strategy. +// +// Used by prompt-side checks where no specific tool is matched but any tool +// rule's restriction may still be relevant. Rules disabled, scoped to other +// OSes, or with merge_strategy == "default" contribute nothing beyond the +// global lists; "merge" rules contribute their entries; "override" rules +// contribute their entries (without re-adding global, since global is already +// included once). 
+func LoadEffectiveRestrictedPaths() (files []string, dirs []string) { + globalFiles, globalDirs := LoadRestrictedPaths() + + seenF := map[string]struct{}{} + seenD := map[string]struct{}{} + add := func(seen map[string]struct{}, dst *[]string, src []string) { + for _, s := range src { + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + *dst = append(*dst, s) + } + } + add(seenF, &files, globalFiles) + add(seenD, &dirs, globalDirs) + + policy := LoadPolicy() + if policy == nil || !policy.Tools.Enabled { + return files, dirs + } + for i := range policy.Tools.Rules { + rule := &policy.Tools.Rules[i] + if rule.Enabled != nil && !*rule.Enabled { + continue + } + if len(rule.OS) > 0 && !MatchesOS(rule.OS, runtime.GOOS) { + continue + } + ef := ResolveRestrictedPaths(globalFiles, GetOSPaths(rule.RestrictedFiles), rule.MergeStrategy.RestrictedFiles) + ed := ResolveRestrictedPaths(globalDirs, GetOSPaths(rule.RestrictedDirectories), rule.MergeStrategy.RestrictedDirectories) + add(seenF, &files, ef) + add(seenD, &dirs, ed) + } + return files, dirs +} + +// LoadAllowedPaths returns the OS-specific allowed file and directory +// lists from the policy file. +func LoadAllowedPaths() (files []string, dirs []string) { + policy := LoadPolicy() + if policy == nil { + return nil, nil + } + return GetOSPaths(policy.DefaultPolicy.AllowedFiles), + GetOSPaths(policy.DefaultPolicy.AllowedDirectories) +} + +// LoadBlastRadiusLimit returns the blast-radius limit config, or nil if disabled / absent. +func LoadBlastRadiusLimit() *BlastRadiusLimit { + policy := LoadPolicy() + if policy == nil { + return nil + } + limit := policy.DefaultPolicy.BlastRadiusLimit + if !limit.Enabled { + return nil + } + return &limit +} + +// LoadBlockedExtensions returns the list of file extensions blocked from AI context. +// Returns nil when the feature is disabled or the policy is absent. 
+func LoadBlockedExtensions() []string { + policy := LoadPolicy() + if policy == nil { + return nil + } + cp := policy.DefaultPolicy.ContextPolicy + if !cp.Enabled || !cp.BlockedExtensions.Enabled { + return nil + } + return cp.BlockedExtensions.Extensions +} + +// LoadFilesLimits returns the files-limits config, or nil if disabled / absent. +func LoadFilesLimits() *FilesLimits { + policy := LoadPolicy() + if policy == nil { + return nil + } + cp := policy.DefaultPolicy.ContextPolicy + if !cp.Enabled || !cp.FilesLimits.Enabled { + return nil + } + fl := cp.FilesLimits + return &fl +} + +// FindMatchingToolRule returns the first tool rule whose tool list contains a +// command name appearing anywhere in the command as a whole token, and whose +// OS list matches the current OS. Returns nil if no rule matches, the tools +// section is disabled, or the rule is explicitly disabled. +// +// Scanning the whole command (not just fields[0]) is what lets compound +// invocations like `cd /foo && mvn deploy` match the `mvn` rule so its +// args_exclude can be enforced. +func FindMatchingToolRule(command string) *ToolRule { + policy := LoadPolicy() + if policy == nil || !policy.Tools.Enabled { + return nil + } + if strings.TrimSpace(command) == "" { + return nil + } + cmdLower := strings.ToLower(command) + for i := range policy.Tools.Rules { + rule := &policy.Tools.Rules[i] + // Explicit `enabled: false` disables a rule; nil (absent) keeps it active. + if rule.Enabled != nil && !*rule.Enabled { + continue + } + if len(rule.OS) > 0 && !MatchesOS(rule.OS, runtime.GOOS) { + continue + } + for _, name := range rule.Tool { + if containsAsToken(cmdLower, strings.ToLower(name)) { + return rule + } + } + } + return nil +} + +// containsAsToken reports whether needle appears in haystack as a whole token — +// flanked on both sides by start-of-string, end-of-string, or any non-word byte +// (space, ;, &, |, /, \, ", ', etc.). 
// Prevents false positives such as matching the tool name "mvn" inside paths
// like "/opt/foo-mvn-bar/". Every occurrence of needle is examined, so an
// embedded hit early in haystack does not hide a later whole-token hit. An
// empty needle never matches.
// Both inputs MUST already be lowercased.
func containsAsToken(haystack, needle string) bool {
	if needle == "" {
		return false
	}
	for start := 0; start <= len(haystack)-len(needle); {
		idx := strings.Index(haystack[start:], needle)
		if idx < 0 {
			return false
		}
		lo := start + idx
		if isTokenBoundary(haystack, lo, lo+len(needle)) {
			return true
		}
		// Resume one byte past this hit so later occurrences are still checked.
		start = lo + 1
	}
	return false
}

// isTokenBoundary reports whether s[lo:hi] is bounded on both sides by a
// non-word byte (or start/end of string). Word bytes are a-z, 0-9, '_', '-'.
// The dash is treated as part of a word so "mvn" does NOT match inside
// "foo-mvn" or "mvn-helper". Uppercase letters are not word bytes here, which
// is why callers pass lowercased input.
func isTokenBoundary(s string, lo, hi int) bool {
	isWordByte := func(b byte) bool {
		return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '_' || b == '-'
	}
	if lo > 0 && isWordByte(s[lo-1]) {
		return false
	}
	if hi < len(s) && isWordByte(s[hi]) {
		return false
	}
	return true
}

// ResolveAllowedPaths combines globalPaths and rulePaths according to strategy.
// Valid strategies: "merge", "override", "default" (anything else acts as "default").
//
//   - "merge" returns a new slice holding globalPaths followed by rulePaths,
//     with duplicates removed and first-seen order kept.
//   - "override" returns rulePaths as given.
//   - "default" (or any unrecognised value) returns globalPaths as given.
//
// Neither input slice is modified.
func ResolveAllowedPaths(globalPaths, rulePaths []string, strategy string) []string {
	switch strategy {
	case "merge":
		total := len(globalPaths) + len(rulePaths)
		seen := make(map[string]struct{}, total)
		result := make([]string, 0, total)
		// Walk the two lists separately instead of appending one onto the
		// other: append(globalPaths, ...) could write into the caller's
		// backing array when globalPaths has spare capacity.
		for _, list := range [][]string{globalPaths, rulePaths} {
			for _, p := range list {
				if _, dup := seen[p]; dup {
					continue
				}
				seen[p] = struct{}{}
				result = append(result, p)
			}
		}
		return result
	case "override":
		return rulePaths
	default: // "default" or anything unrecognised
		return globalPaths
	}
}

// ResolveRestrictedPaths combines global and tool-level restricted paths per strategy.
+// Semantics are identical to ResolveAllowedPaths; the two names exist for readability +// at call sites that work with different path categories. +func ResolveRestrictedPaths(globalPaths, rulePaths []string, strategy string) []string { + return ResolveAllowedPaths(globalPaths, rulePaths, strategy) +} + +// NormalizeWorkspaceRoot canonicalises a workspace root so it can be compared +// against policy path entries. Cursor reports Windows roots as "/c:/foo/bar"; +// strip the leading slash before a drive letter so PathUnderAny's prefix match +// lines up with policy entries like "C:\\foo\\bar\\". +func NormalizeWorkspaceRoot(root string) string { + r := filepath.ToSlash(root) + if len(r) >= 3 && r[0] == '/' && isASCIILetter(r[1]) && r[2] == ':' { + r = r[1:] + } + return r +} + +func isASCIILetter(b byte) bool { + return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') +} diff --git a/internal/commands/agenthooks/guardrails/policy_test.go b/internal/commands/agenthooks/guardrails/policy_test.go new file mode 100644 index 000000000..24b207ec3 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/policy_test.go @@ -0,0 +1,1620 @@ +//go:build !integration + +package guardrails_test + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/checkmarx/ast-cli/internal/commands/agenthooks/guardrails" +) + +// setHomeDir redirects os.UserHomeDir() to dir and returns a cleanup function. +func setHomeDir(dir string) func() { + if runtime.GOOS == "windows" { + orig, had := os.LookupEnv("USERPROFILE") + os.Setenv("USERPROFILE", dir) + return func() { + if had { + os.Setenv("USERPROFILE", orig) + } else { + os.Unsetenv("USERPROFILE") + } + } + } + orig, had := os.LookupEnv("HOME") + os.Setenv("HOME", dir) + return func() { + if had { + os.Setenv("HOME", orig) + } else { + os.Unsetenv("HOME") + } + } +} + +// writePolicy writes a HooksPolicy to a temp file and sets the home dir so +// LoadPolicy picks it up. 
Returns a cleanup function. +func writePolicy(t *testing.T, policy guardrails.HooksPolicy) func() { + t.Helper() + data, err := json.Marshal(policy) + if err != nil { + t.Fatalf("marshal policy: %v", err) + } + dir := t.TempDir() + cxDir := filepath.Join(dir, ".checkmarx") + if err := os.MkdirAll(cxDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(cxDir, "policyhooks.json"), data, 0o644); err != nil { + t.Fatalf("write policy: %v", err) + } + return setHomeDir(dir) +} + +// currentOS returns the policy OS label for the current platform. +func currentOS() string { + switch runtime.GOOS { + case "darwin": + return "mac" + case "windows": + return "windows" + default: + return "linux" + } +} + +// -------------------------------------------------------------------------- +// LoadPolicy +// -------------------------------------------------------------------------- + +func TestLoadPolicy_MissingFile(t *testing.T) { + dir := t.TempDir() + defer setHomeDir(dir)() + + if got := guardrails.LoadPolicy(); got != nil { + t.Fatal("expected nil for missing file") + } +} + +func TestLoadPolicy_MalformedJSON(t *testing.T) { + dir := t.TempDir() + cxDir := filepath.Join(dir, ".checkmarx") + os.MkdirAll(cxDir, 0o755) + os.WriteFile(filepath.Join(cxDir, "policyhooks.json"), []byte("not-json{{{"), 0o644) + defer setHomeDir(dir)() + + if got := guardrails.LoadPolicy(); got != nil { + t.Fatal("expected nil for malformed JSON") + } +} + +func TestLoadPolicy_UTF8BOMPrefix(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + body, err := json.Marshal(policy) + if err != nil { + t.Fatalf("marshal policy: %v", err) + } + + dir := t.TempDir() + cxDir := filepath.Join(dir, ".checkmarx") + if err := os.MkdirAll(cxDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + withBOM := append([]byte{0xEF, 0xBB, 0xBF}, body...) 
+ if err := os.WriteFile(filepath.Join(cxDir, "policyhooks.json"), withBOM, 0o644); err != nil { + t.Fatalf("write policy: %v", err) + } + defer setHomeDir(dir)() + + got := guardrails.LoadPolicy() + if got == nil { + t.Fatal("expected non-nil policy when file is BOM-prefixed; LoadPolicy should strip the UTF-8 BOM") + } + if !got.DefaultPolicy.BlacklistTools.Enabled { + t.Fatal("BlacklistTools.Enabled should be true after BOM strip") + } +} + +func TestLoadPolicy_ValidJSON(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + cleanup := writePolicy(t, policy) + defer cleanup() + + got := guardrails.LoadPolicy() + if got == nil { + t.Fatal("expected non-nil policy") + } + if !got.DefaultPolicy.BlacklistTools.Enabled { + t.Fatal("BlacklistTools.Enabled should be true") + } +} + +// -------------------------------------------------------------------------- +// LoadBlacklistedCommands +// -------------------------------------------------------------------------- + +func TestLoadBlacklistedCommands_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = false + policy.DefaultPolicy.BlacklistTools.Tools = []guardrails.BlacklistedTool{ + {Name: "rm -rf", OS: []string{currentOS()}, Category: "destructive", Risk: "bad"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + got := guardrails.LoadBlacklistedCommands() + if len(got) != 0 { + t.Fatalf("expected empty map when disabled, got %d entries", len(got)) + } +} + +func TestLoadBlacklistedCommands_OSMatch(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + policy.DefaultPolicy.BlacklistTools.Tools = []guardrails.BlacklistedTool{ + {Name: "danger-cmd", OS: []string{currentOS()}, Category: "test", Risk: "none"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + got := guardrails.LoadBlacklistedCommands() + if _, ok := got["danger-cmd"]; !ok 
{ + t.Fatal("expected danger-cmd in blacklist for current OS") + } +} + +func TestLoadBlacklistedCommands_OSNoMatch(t *testing.T) { + wrongOS := "linux" + if runtime.GOOS == "linux" { + wrongOS = "windows" + } + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + policy.DefaultPolicy.BlacklistTools.Tools = []guardrails.BlacklistedTool{ + {Name: "other-os-cmd", OS: []string{wrongOS}, Category: "test", Risk: "none"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + got := guardrails.LoadBlacklistedCommands() + if _, ok := got["other-os-cmd"]; ok { + t.Fatal("should not include tool for wrong OS") + } +} + +func TestLoadBlacklistedCommands_CaseInsensitive(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + policy.DefaultPolicy.BlacklistTools.Tools = []guardrails.BlacklistedTool{ + {Name: "RM -RF", OS: []string{currentOS()}, Category: "destructive", Risk: "bad"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + got := guardrails.LoadBlacklistedCommands() + if _, ok := got["rm -rf"]; !ok { + t.Fatal("expected lowercased key in blacklist map") + } +} + +// -------------------------------------------------------------------------- +// CheckShellCommand — blacklist +// -------------------------------------------------------------------------- + +func TestCheckShellCommand_Blacklisted(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + policy.DefaultPolicy.BlacklistTools.Tools = []guardrails.BlacklistedTool{ + {Name: "rm -rf", OS: []string{currentOS()}, Category: "destructive", Risk: "wipes files"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, needsConfirm, reason := guardrails.CheckShellCommand("rm -rf /tmp/foo", "") + if !blocked { + t.Fatal("expected blocked=true for blacklisted command") + } + if needsConfirm { + t.Fatal("expected needsConfirm=false for blacklisted 
command") + } + if reason == "" { + t.Fatal("expected non-empty reason") + } +} + +func TestCheckShellCommand_Clean(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlacklistTools.Enabled = true + policy.DefaultPolicy.BlacklistTools.Tools = []guardrails.BlacklistedTool{ + {Name: "rm -rf", OS: []string{currentOS()}, Category: "destructive", Risk: "bad"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, _, _ := guardrails.CheckShellCommand("ls -la", "") + if blocked { + t.Fatal("expected clean command to pass") + } +} + +// -------------------------------------------------------------------------- +// CheckShellCommand — tool rules +// -------------------------------------------------------------------------- + +func makeToolRulePolicy(rule guardrails.ToolRule) guardrails.HooksPolicy { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{rule} + return policy +} + +func TestCheckShellCommand_ToolRule_ExcludedArg(t *testing.T) { + rule := guardrails.ToolRule{ + ID: "t1", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsExclude: []string{"deploy"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + blocked, needsConfirm, reason := guardrails.CheckShellCommand("mvn deploy", "/project") + if !blocked { + t.Fatal("expected blocked for excluded arg") + } + if needsConfirm { + t.Fatal("excluded arg should hard-deny, not ask") + } + if reason == "" { + t.Fatal("expected reason") + } +} + +func TestCheckShellCommand_ToolRule_UnknownArg_AlwaysAsks(t *testing.T) { + // Unmatched args always produce needsConfirm=true regardless of rule.Action. 
+ for _, action := range []string{"ask", "block", "allow", ""} { + rule := guardrails.ToolRule{ + ID: "t2", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsInclude: []string{"compile", "test"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn unknown-goal", "") + if !blocked { + t.Fatalf("action=%q: expected blocked for arg not in whitelist", action) + } + if !needsConfirm { + t.Fatalf("action=%q: unmatched arg must always ask (needsConfirm=true), not hard-block", action) + } + cleanup() + } +} + +func TestCheckShellCommand_ToolRule_AllowedArg(t *testing.T) { + rule := guardrails.ToolRule{ + ID: "t4", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsInclude: []string{"compile", "test"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + blocked, _, _ := guardrails.CheckShellCommand("mvn compile", "") + if blocked { + t.Fatal("expected allowed for whitelisted arg") + } +} + +func TestCheckShellCommand_ToolRule_GlobMatch_Allowed(t *testing.T) { + rule := guardrails.ToolRule{ + ID: "tg1", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsInclude: []string{"compile", "-D*", "--*"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + // "-Dmaven.test.skip=true" matches glob "-D*" → allowed + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile -Dmaven.test.skip=true", ""); blocked { + t.Fatal("expected -D* glob to allow -Dmaven.test.skip=true") + } + // "--offline" matches glob "--*" → allowed + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile --offline", ""); blocked { + t.Fatal("expected --* glob to allow --offline") + } +} + +func TestCheckShellCommand_ToolRule_GlobMatch_MissAsks(t *testing.T) { + rule := guardrails.ToolRule{ + ID: "tg2", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsInclude: []string{"compile", "-D*"}, + } + cleanup := writePolicy(t, 
makeToolRulePolicy(rule)) + defer cleanup() + + // "-Pfoo" does not match any pattern → ask (not hard block) + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn compile -Pfoo", "") + if !blocked { + t.Fatal("expected blocked for arg not matching any glob") + } + if !needsConfirm { + t.Fatal("unmatched glob should ask, not hard-block (even when action=block)") + } +} + +func TestCheckShellCommand_ToolRule_ExcludeBeatsInclude(t *testing.T) { + // Even if an arg is in args_include, args_exclude takes precedence. + rule := guardrails.ToolRule{ + ID: "t5", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsInclude: []string{"deploy"}, + ArgsExclude: []string{"deploy"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn deploy", "") + if !blocked || needsConfirm { + t.Fatal("args_exclude must hard-deny before args_include whitelist is checked") + } +} + +// -------------------------------------------------------------------------- +// CheckShellCommand — allowed_directories with merge strategy +// -------------------------------------------------------------------------- + +func TestCheckShellCommand_AllowedDirs_Merge(t *testing.T) { + globalDir := "/global/allowed" + ruleDir := "/rule/allowed" + + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.AllowedDirectories.Enabled = true + policy.DefaultPolicy.AllowedDirectories.Linux = []string{globalDir} + policy.DefaultPolicy.AllowedDirectories.Mac = []string{globalDir} + policy.DefaultPolicy.AllowedDirectories.Windows = []string{globalDir} + + rule := guardrails.ToolRule{ + ID: "t6", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + AllowedDirectories: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{ruleDir}, + Mac: []string{ruleDir}, + Windows: []string{ruleDir}, + }, + + MergeStrategy: guardrails.MergeStrategy{AllowedDirectories: "merge"}, + } + policy.Tools.Enabled = true + 
policy.Tools.Rules = []guardrails.ToolRule{rule} + + cleanup := writePolicy(t, policy) + defer cleanup() + + // Both global and rule dirs should be allowed. + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", globalDir); blocked { + t.Fatal("global dir should be allowed (merge)") + } + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", ruleDir); blocked { + t.Fatal("rule dir should be allowed (merge)") + } + // A dir outside both should trigger ask. + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn compile", "/other") + if !blocked || !needsConfirm { + t.Fatal("dir outside merge set should trigger ask") + } +} + +func TestCheckShellCommand_AllowedDirs_Override(t *testing.T) { + globalDir := "/global/allowed" + ruleDir := "/rule/only" + + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.AllowedDirectories.Enabled = true + policy.DefaultPolicy.AllowedDirectories.Linux = []string{globalDir} + policy.DefaultPolicy.AllowedDirectories.Mac = []string{globalDir} + policy.DefaultPolicy.AllowedDirectories.Windows = []string{globalDir} + + rule := guardrails.ToolRule{ + ID: "t7", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + AllowedDirectories: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{ruleDir}, + Mac: []string{ruleDir}, + Windows: []string{ruleDir}, + }, + + MergeStrategy: guardrails.MergeStrategy{AllowedDirectories: "override"}, + } + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{rule} + + cleanup := writePolicy(t, policy) + defer cleanup() + + // Global dir is no longer in the effective set (override). 
+ if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", ruleDir); blocked { + t.Fatal("rule dir should be allowed (override)") + } + blocked, _, _ := guardrails.CheckShellCommand("mvn compile", globalDir) + if !blocked { + t.Fatal("global dir should be blocked (override replaces global list)") + } +} + +func TestCheckShellCommand_AllowedDirs_Default(t *testing.T) { + globalDir := "/global/allowed" + ruleDir := "/rule/ignored" + + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.AllowedDirectories.Enabled = true + policy.DefaultPolicy.AllowedDirectories.Linux = []string{globalDir} + policy.DefaultPolicy.AllowedDirectories.Mac = []string{globalDir} + policy.DefaultPolicy.AllowedDirectories.Windows = []string{globalDir} + + rule := guardrails.ToolRule{ + ID: "t8", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + AllowedDirectories: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{ruleDir}, + Mac: []string{ruleDir}, + Windows: []string{ruleDir}, + }, + + MergeStrategy: guardrails.MergeStrategy{AllowedDirectories: "default"}, + } + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{rule} + + cleanup := writePolicy(t, policy) + defer cleanup() + + // Only global dir allowed. + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", globalDir); blocked { + t.Fatal("global dir should be allowed (default strategy)") + } + // Rule dir is ignored. 
+ if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", ruleDir); !blocked { + t.Fatal("rule dir should not be allowed when strategy=default") + } +} + +// -------------------------------------------------------------------------- +// LoadRestrictedPaths +// -------------------------------------------------------------------------- + +func TestLoadRestrictedPaths_Nil(t *testing.T) { + dir := t.TempDir() + defer setHomeDir(dir)() + + files, dirs := guardrails.LoadRestrictedPaths() + if files != nil || dirs != nil { + t.Fatal("expected nil for missing policy") + } +} + +func TestLoadRestrictedPaths_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedFiles.Enabled = false + policy.DefaultPolicy.RestrictedDirectories.Enabled = false + cleanup := writePolicy(t, policy) + defer cleanup() + + files, dirs := guardrails.LoadRestrictedPaths() + if len(files) != 0 || len(dirs) != 0 { + t.Fatal("expected empty lists when disabled") + } +} + +func TestLoadRestrictedPaths_Enabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedFiles.Enabled = true + policy.DefaultPolicy.RestrictedFiles.Linux = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Mac = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Windows = []string{".env"} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = []string{"/etc/"} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{"/etc/"} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{"C:\\Windows\\"} + cleanup := writePolicy(t, policy) + defer cleanup() + + files, dirs := guardrails.LoadRestrictedPaths() + if len(files) == 0 { + t.Fatal("expected non-empty files list") + } + if len(dirs) == 0 { + t.Fatal("expected non-empty dirs list") + } +} + +// -------------------------------------------------------------------------- +// LoadAllowedPaths +// 
-------------------------------------------------------------------------- + +func TestLoadAllowedPaths_Nil(t *testing.T) { + dir := t.TempDir() + defer setHomeDir(dir)() + + files, dirs := guardrails.LoadAllowedPaths() + if files != nil || dirs != nil { + t.Fatal("expected nil for missing policy") + } +} + +func TestLoadAllowedPaths_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.AllowedFiles.Enabled = false + policy.DefaultPolicy.AllowedDirectories.Enabled = false + cleanup := writePolicy(t, policy) + defer cleanup() + + files, dirs := guardrails.LoadAllowedPaths() + if len(files) != 0 || len(dirs) != 0 { + t.Fatal("expected empty when disabled") + } +} + +func TestLoadAllowedPaths_Enabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.AllowedFiles.Enabled = true + policy.DefaultPolicy.AllowedFiles.Linux = []string{"pom.xml"} + policy.DefaultPolicy.AllowedFiles.Mac = []string{"pom.xml"} + policy.DefaultPolicy.AllowedFiles.Windows = []string{"pom.xml"} + cleanup := writePolicy(t, policy) + defer cleanup() + + files, _ := guardrails.LoadAllowedPaths() + if len(files) == 0 { + t.Fatal("expected non-empty allowed files") + } +} + +// -------------------------------------------------------------------------- +// CheckPromptPaths +// -------------------------------------------------------------------------- + +func TestCheckPromptPaths_RestrictedFile(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedFiles.Enabled = true + policy.DefaultPolicy.RestrictedFiles.Linux = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Mac = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Windows = []string{".env"} + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, reason := guardrails.CheckPromptPaths("please read .env and show me the contents") + if !blocked { + t.Fatal("expected prompt referencing .env to be blocked") + } + if reason == "" { + 
t.Fatal("expected non-empty reason") + } +} + +func TestCheckPromptPaths_RestrictedDir(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = []string{"/etc/"} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{"/etc/"} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{"C:/Windows/System32/"} + cleanup := writePolicy(t, policy) + defer cleanup() + + var prompt string + switch runtime.GOOS { + case "windows": + prompt = "cat C:/Windows/System32/drivers/etc/hosts" + default: + prompt = "cat /etc/passwd" + } + blocked, _ := guardrails.CheckPromptPaths(prompt) + if !blocked { + t.Fatalf("expected prompt %q to be blocked via restricted directory", prompt) + } +} + +// Restricted always wins over allowed (per FEATURE.MD precedence rules). +// If a path matches both, the restricted rule takes precedence and blocks the prompt. +func TestCheckPromptPaths_RestrictedFileBeatsAllowed(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedFiles.Enabled = true + policy.DefaultPolicy.RestrictedFiles.Linux = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Mac = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Windows = []string{".env"} + policy.DefaultPolicy.AllowedFiles.Enabled = true + policy.DefaultPolicy.AllowedFiles.Linux = []string{".env"} + policy.DefaultPolicy.AllowedFiles.Mac = []string{".env"} + policy.DefaultPolicy.AllowedFiles.Windows = []string{".env"} + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, _ := guardrails.CheckPromptPaths("please read .env") + if !blocked { + t.Fatal("restricted_files must take precedence over allowed_files") + } +} + +func TestCheckPromptPaths_RestrictedDirBeatsAllowed(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = 
[]string{"/etc/"} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{"/etc/"} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{"C:/Windows/System32/"} + policy.DefaultPolicy.AllowedDirectories.Enabled = true + policy.DefaultPolicy.AllowedDirectories.Linux = []string{"/etc/"} + policy.DefaultPolicy.AllowedDirectories.Mac = []string{"/etc/"} + policy.DefaultPolicy.AllowedDirectories.Windows = []string{"C:/Windows/System32/"} + cleanup := writePolicy(t, policy) + defer cleanup() + + var prompt string + switch runtime.GOOS { + case "windows": + prompt = "cat C:/Windows/System32/drivers/etc/hosts" + default: + prompt = "cat /etc/passwd" + } + blocked, _ := guardrails.CheckPromptPaths(prompt) + if !blocked { + t.Fatal("restricted_directories must take precedence over allowed_directories") + } +} + +func TestCheckPromptPaths_CleanPrompt(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedFiles.Enabled = true + policy.DefaultPolicy.RestrictedFiles.Linux = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Mac = []string{".env"} + policy.DefaultPolicy.RestrictedFiles.Windows = []string{".env"} + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, _ := guardrails.CheckPromptPaths("show me the main.go file") + if blocked { + t.Fatal("clean prompt should not be blocked") + } +} + +// -------------------------------------------------------------------------- +// Glob pattern support (doublestar) +// -------------------------------------------------------------------------- + +// restrictedFilesPolicy is a helper building a HooksPolicy whose restricted_files list +// applies on every OS — avoids repeating the boilerplate across glob tests. 
+func restrictedFilesPolicy(patterns []string) guardrails.HooksPolicy { + p := guardrails.HooksPolicy{} + p.DefaultPolicy.RestrictedFiles.Enabled = true + p.DefaultPolicy.RestrictedFiles.Linux = patterns + p.DefaultPolicy.RestrictedFiles.Mac = patterns + p.DefaultPolicy.RestrictedFiles.Windows = patterns + return p +} + +// restrictedDirsPolicy mirrors restrictedFilesPolicy for restricted_directories. +func restrictedDirsPolicy(patterns []string) guardrails.HooksPolicy { + p := guardrails.HooksPolicy{} + p.DefaultPolicy.RestrictedDirectories.Enabled = true + p.DefaultPolicy.RestrictedDirectories.Linux = patterns + p.DefaultPolicy.RestrictedDirectories.Mac = patterns + p.DefaultPolicy.RestrictedDirectories.Windows = patterns + return p +} + +func TestCheckPromptPaths_GlobBasename_StarDotPem(t *testing.T) { + cleanup := writePolicy(t, restrictedFilesPolicy([]string{"*.pem"})) + defer cleanup() + + blocked, _ := guardrails.CheckPromptPaths("please read cert.pem for the deploy") + if !blocked { + t.Fatal("expected *.pem to match cert.pem via basename glob") + } +} + +func TestCheckPromptPaths_DoubleStar_AnywherePem(t *testing.T) { + cleanup := writePolicy(t, restrictedFilesPolicy([]string{"**/*.pem"})) + defer cleanup() + + blocked, _ := guardrails.CheckPromptPaths("please inspect /srv/keys/cert.pem now") + if !blocked { + t.Fatal("expected **/*.pem to match /srv/keys/cert.pem") + } +} + +func TestCheckPromptPaths_GlobDir_PerUserSSH(t *testing.T) { + var patterns []string + var prompt string + switch runtime.GOOS { + case "windows": + patterns = []string{"C:/Users/*/.ssh"} + prompt = "grab C:/Users/alice/.ssh/id_rsa" + default: + patterns = []string{"/home/*/.ssh"} + prompt = "grab /home/alice/.ssh/id_rsa" + } + cleanup := writePolicy(t, restrictedDirsPolicy(patterns)) + defer cleanup() + + blocked, _ := guardrails.CheckPromptPaths(prompt) + if !blocked { + t.Fatalf("expected per-user .ssh glob to block %q", prompt) + } +} + +func 
TestCheckPromptPaths_DoubleStar_SecretsAnywhere(t *testing.T) { + cleanup := writePolicy(t, restrictedDirsPolicy([]string{"**/secrets/**"})) + defer cleanup() + + blocked, _ := guardrails.CheckPromptPaths("look at /repo/services/secrets/db.yaml please") + if !blocked { + t.Fatal("expected **/secrets/** to match a file inside any secrets dir") + } +} + +func TestCheckPromptPaths_LiteralBasename_StillWorks(t *testing.T) { + cleanup := writePolicy(t, restrictedFilesPolicy([]string{"kubeconfig", "terraform.tfstate"})) + defer cleanup() + + if blocked, _ := guardrails.CheckPromptPaths("merge /etc/kubeconfig please"); !blocked { + t.Fatal("literal basename kubeconfig should still block") + } + if blocked, _ := guardrails.CheckPromptPaths("read terraform.tfstate for audit"); !blocked { + t.Fatal("literal basename terraform.tfstate should still block") + } +} + +func TestPathUnderAny_GlobDir(t *testing.T) { + switch runtime.GOOS { + case "windows": + if !guardrails.PathUnderAny("C:/Users/alice/.ssh/id_rsa", []string{"C:/Users/*/.ssh"}) { + t.Fatal("expected glob dir to match nested path") + } + if guardrails.PathUnderAny("C:/Users/alice/Documents/report.txt", []string{"C:/Users/*/.ssh"}) { + t.Fatal("glob dir must not match unrelated path") + } + default: + if !guardrails.PathUnderAny("/home/alice/.ssh/id_rsa", []string{"/home/*/.ssh"}) { + t.Fatal("expected glob dir to match nested path") + } + if guardrails.PathUnderAny("/home/alice/Documents/report.txt", []string{"/home/*/.ssh"}) { + t.Fatal("glob dir must not match unrelated path") + } + } +} + +func TestCheckShellCommand_RestrictedFilesGlob(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedFiles.Enabled = true + policy.DefaultPolicy.RestrictedFiles.Linux = []string{"**/*.pem"} + policy.DefaultPolicy.RestrictedFiles.Mac = []string{"**/*.pem"} + policy.DefaultPolicy.RestrictedFiles.Windows = []string{"**/*.pem"} + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, 
needsConfirm, _ := guardrails.CheckShellCommand("cat /tmp/secrets/foo.pem", "") + if !blocked { + t.Fatal("expected **/*.pem to block a command referencing the file") + } + if needsConfirm { + t.Fatal("restricted-files match should be a hard block, not an ask") + } +} + +func TestCheckShellCommand_AllowedFilesGlob(t *testing.T) { + rule := guardrails.ToolRule{ + ID: "glob-allow", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + AllowedFiles: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{"**/pom.xml", "*.java"}, + Mac: []string{"**/pom.xml", "*.java"}, + Windows: []string{"**/pom.xml", "*.java"}, + }, + MergeStrategy: guardrails.MergeStrategy{AllowedFiles: "override"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + // Foo.java matches "*.java" via basename glob — should be allowed. + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile Foo.java", ""); blocked { + t.Fatal("expected *.java glob to allow Foo.java") + } + // ./sub/pom.xml matches "**/pom.xml" via full-path glob — should be allowed. + if blocked, _, _ := guardrails.CheckShellCommand("mvn -f ./sub/pom.xml compile", ""); blocked { + t.Fatal("expected **/pom.xml glob to allow ./sub/pom.xml") + } + // script.sh matches nothing — should ask. + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn compile script.sh", "") + if !blocked || !needsConfirm { + t.Fatal("unknown file must trigger ask, not allow or hard-block") + } +} + +func TestResolveRestrictedPaths_MergeWithEmptyRule(t *testing.T) { + // Empty rule list + merge strategy must safely fall back to the global list. 
+ got := guardrails.ResolveRestrictedPaths([]string{"/a", "/b"}, nil, "merge") + if len(got) != 2 || got[0] != "/a" || got[1] != "/b" { + t.Fatalf("empty rule + merge should return global list verbatim, got %v", got) + } +} + +// -------------------------------------------------------------------------- +// CheckWorkspaceRoots +// -------------------------------------------------------------------------- + +func TestCheckWorkspaceRoots_Blocked(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = []string{"/restricted/"} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{"/restricted/"} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{"C:\\Cx-Flow\\"} + cleanup := writePolicy(t, policy) + defer cleanup() + + var roots []string + switch runtime.GOOS { + case "windows": + // Cursor reports Windows roots with a leading slash before the drive letter. + roots = []string{"/c:/Cx-Flow/Test/JavaVulnerabilityLabE"} + default: + roots = []string{"/restricted/project"} + } + + blocked, reason := guardrails.CheckWorkspaceRoots(roots) + if !blocked { + t.Fatalf("expected workspace %v to be blocked", roots) + } + if reason == "" { + t.Fatal("expected non-empty reason") + } +} + +func TestCheckWorkspaceRoots_Allowed(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = []string{"/restricted/"} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{"/restricted/"} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{"C:\\Cx-Flow\\"} + cleanup := writePolicy(t, policy) + defer cleanup() + + var roots []string + switch runtime.GOOS { + case "windows": + roots = []string{"/d:/Projects/safe"} + default: + roots = []string{"/home/user/safe"} + } + + blocked, _ := guardrails.CheckWorkspaceRoots(roots) + if blocked { + 
t.Fatalf("expected workspace %v to be allowed", roots) + } +} + +func TestCheckWorkspaceRoots_EmptyList(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = []string{"/restricted/"} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{"/restricted/"} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{"C:\\Cx-Flow\\"} + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, _ := guardrails.CheckWorkspaceRoots(nil) + if blocked { + t.Fatal("empty workspace root list must not block") + } +} + +// -------------------------------------------------------------------------- +// NormalizeWorkspaceRoot +// -------------------------------------------------------------------------- + +func TestNormalizeWorkspaceRoot(t *testing.T) { + tests := []struct { + name, in, want string + }{ + {"cursor-windows-leading-slash", "/c:/Cx-Flow/Test", "c:/Cx-Flow/Test"}, + {"already-normalized-windows", "C:/Cx-Flow/Test", "C:/Cx-Flow/Test"}, + {"windows-backslashes", "C:\\Cx-Flow\\Test", "C:/Cx-Flow/Test"}, + {"unix-absolute", "/etc/secrets", "/etc/secrets"}, + {"empty", "", ""}, + {"slash-only", "/", "/"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := guardrails.NormalizeWorkspaceRoot(tc.in); got != tc.want { + t.Fatalf("NormalizeWorkspaceRoot(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// -------------------------------------------------------------------------- +// FindMatchingToolRule +// -------------------------------------------------------------------------- + +func TestFindMatchingToolRule_Match(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r1", Tool: []string{"mvn", "mvnw"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + rule := 
guardrails.FindMatchingToolRule("mvn clean") + if rule == nil { + t.Fatal("expected matching rule for mvn") + } + if rule.ID != "r1" { + t.Fatalf("expected rule r1, got %q", rule.ID) + } +} + +func TestFindMatchingToolRule_AltName(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r2", Tool: []string{"mvn", "mvnw"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + rule := guardrails.FindMatchingToolRule("mvnw compile") + if rule == nil || rule.ID != "r2" { + t.Fatal("expected match on alternate tool name mvnw") + } +} + +func TestFindMatchingToolRule_OSMismatch(t *testing.T) { + wrongOS := "linux" + if runtime.GOOS == "linux" { + wrongOS = "windows" + } + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r3", Tool: []string{"mvn"}, OS: []string{wrongOS}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("mvn compile"); rule != nil { + t.Fatal("should not match rule for wrong OS") + } +} + +func TestFindMatchingToolRule_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = false + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r4", Tool: []string{"mvn"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("mvn compile"); rule != nil { + t.Fatal("should not match rule when tools section is disabled") + } +} + +func TestFindMatchingToolRule_NoPolicy(t *testing.T) { + dir := t.TempDir() + defer setHomeDir(dir)() + + if rule := guardrails.FindMatchingToolRule("mvn compile"); rule != nil { + t.Fatal("should return nil when no policy file exists") + } +} + +// -------------------------------------------------------------------------- +// ResolveAllowedPaths +// 
-------------------------------------------------------------------------- + +func TestResolveAllowedPaths_Merge(t *testing.T) { + global := []string{"/a", "/b"} + rule := []string{"/c"} + got := guardrails.ResolveAllowedPaths(global, rule, "merge") + want := map[string]bool{"/a": true, "/b": true, "/c": true} + if len(got) != 3 { + t.Fatalf("expected 3 paths, got %d: %v", len(got), got) + } + for _, p := range got { + if !want[p] { + t.Fatalf("unexpected path %q in merge result", p) + } + } +} + +func TestResolveAllowedPaths_Merge_Deduplicates(t *testing.T) { + global := []string{"/a", "/b"} + rule := []string{"/b", "/c"} + got := guardrails.ResolveAllowedPaths(global, rule, "merge") + if len(got) != 3 { + t.Fatalf("expected 3 deduplicated paths, got %d: %v", len(got), got) + } +} + +func TestResolveAllowedPaths_Override(t *testing.T) { + global := []string{"/a", "/b"} + rule := []string{"/c"} + got := guardrails.ResolveAllowedPaths(global, rule, "override") + if len(got) != 1 || got[0] != "/c" { + t.Fatalf("expected only rule paths on override, got %v", got) + } +} + +func TestResolveAllowedPaths_Default(t *testing.T) { + global := []string{"/a", "/b"} + rule := []string{"/c"} + got := guardrails.ResolveAllowedPaths(global, rule, "default") + if len(got) != 2 || got[0] != "/a" || got[1] != "/b" { + t.Fatalf("expected only global paths on default, got %v", got) + } +} + +func TestResolveAllowedPaths_UnknownStrategyActsAsDefault(t *testing.T) { + global := []string{"/a"} + rule := []string{"/b"} + got := guardrails.ResolveAllowedPaths(global, rule, "unknown-strategy") + if len(got) != 1 || got[0] != "/a" { + t.Fatalf("unknown strategy should act as default, got %v", got) + } +} + +// -------------------------------------------------------------------------- +// ResolveRestrictedPaths (delegates to same logic as allowed) +// -------------------------------------------------------------------------- + +func TestResolveRestrictedPaths_Merge(t *testing.T) { + got := 
guardrails.ResolveRestrictedPaths([]string{"/a"}, []string{"/b"}, "merge") + if len(got) != 2 { + t.Fatalf("expected 2 merged paths, got %v", got) + } +} + +func TestResolveRestrictedPaths_Override(t *testing.T) { + got := guardrails.ResolveRestrictedPaths([]string{"/a"}, []string{"/b"}, "override") + if len(got) != 1 || got[0] != "/b" { + t.Fatalf("expected override to yield rule paths only, got %v", got) + } +} + +func TestResolveRestrictedPaths_Default(t *testing.T) { + got := guardrails.ResolveRestrictedPaths([]string{"/a"}, []string{"/b"}, "default") + if len(got) != 1 || got[0] != "/a" { + t.Fatalf("expected default to yield global paths only, got %v", got) + } +} + +// -------------------------------------------------------------------------- +// Tool-level restricted paths (shell.go) +// -------------------------------------------------------------------------- + +// When ask_on_restricted=false (default), tool-level restricted paths hard-block. +func TestCheckShellCommand_ToolRestrictedDir_HardBlock(t *testing.T) { + restricted := "/prod" + rule := guardrails.ToolRule{ + ID: "trd1", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + RestrictedDirectories: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{restricted}, + Mac: []string{restricted}, + Windows: []string{restricted}, + }, + MergeStrategy: guardrails.MergeStrategy{RestrictedDirectories: "override"}, + } + policy := makeToolRulePolicy(rule) + cleanup := writePolicy(t, policy) + defer cleanup() + + blocked, needsConfirm, reason := guardrails.CheckShellCommand("mvn compile", restricted) + if !blocked { + t.Fatal("expected blocked for restricted workDir") + } + if needsConfirm { + t.Fatal("expected hard block when ask_on_restricted=false") + } + if reason == "" { + t.Fatal("expected non-empty reason") + } +} + +// Restricted paths always hard-block regardless of any flag — needsConfirm is never true. 
+func TestCheckShellCommand_ToolRestrictedDir_AlwaysHardBlock(t *testing.T) { + restricted := "/prod" + rule := guardrails.ToolRule{ + ID: "trd2", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + RestrictedDirectories: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{restricted}, + Mac: []string{restricted}, + Windows: []string{restricted}, + }, + MergeStrategy: guardrails.MergeStrategy{RestrictedDirectories: "override"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn compile", restricted) + if !blocked { + t.Fatal("expected blocked=true for restricted workDir") + } + if needsConfirm { + t.Fatal("restricted paths must always hard-block (needsConfirm=false)") + } +} + +func TestCheckShellCommand_ToolRestrictedFile_HardBlock(t *testing.T) { + rule := guardrails.ToolRule{ + ID: "trf1", + Tool: []string{"cat"}, + OS: []string{currentOS()}, + RestrictedFiles: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{"secret.key"}, + Mac: []string{"secret.key"}, + Windows: []string{"secret.key"}, + }, + MergeStrategy: guardrails.MergeStrategy{RestrictedFiles: "override"}, + } + cleanup := writePolicy(t, makeToolRulePolicy(rule)) + defer cleanup() + + blocked, needsConfirm, _ := guardrails.CheckShellCommand("cat ./secret.key", "") + if !blocked || needsConfirm { + t.Fatal("expected hard block for restricted file arg") + } +} + +// Tool-level restricted paths merge with the global list when strategy=merge. 
+func TestCheckShellCommand_ToolRestrictedDir_MergeStrategy(t *testing.T) { + globalRestricted := "/global-prod" + ruleRestricted := "/rule-prod" + + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.RestrictedDirectories.Enabled = true + policy.DefaultPolicy.RestrictedDirectories.Linux = []string{globalRestricted} + policy.DefaultPolicy.RestrictedDirectories.Mac = []string{globalRestricted} + policy.DefaultPolicy.RestrictedDirectories.Windows = []string{globalRestricted} + + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{{ + ID: "merge-r", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + RestrictedDirectories: guardrails.PathPolicy{ + Enabled: true, + Linux: []string{ruleRestricted}, + Mac: []string{ruleRestricted}, + Windows: []string{ruleRestricted}, + }, + MergeStrategy: guardrails.MergeStrategy{RestrictedDirectories: "merge"}, + }} + cleanup := writePolicy(t, policy) + defer cleanup() + + // Both global and rule restricted dirs should block. 
+ if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", globalRestricted); !blocked { + t.Fatal("global restricted dir should block under merge strategy") + } + if blocked, _, _ := guardrails.CheckShellCommand("mvn compile", ruleRestricted); !blocked { + t.Fatal("rule restricted dir should block under merge strategy") + } +} + +// -------------------------------------------------------------------------- +// ToolRule.Enabled (pointer-based opt-out) +// -------------------------------------------------------------------------- + +func TestFindMatchingToolRule_ExplicitlyDisabled(t *testing.T) { + disabled := false + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "off", Enabled: &disabled, Tool: []string{"mvn"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("mvn compile"); rule != nil { + t.Fatal("rule with enabled=false should be skipped") + } +} + +func TestFindMatchingToolRule_ExplicitlyEnabled(t *testing.T) { + enabled := true + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "on", Enabled: &enabled, Tool: []string{"mvn"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("mvn compile"); rule == nil || rule.ID != "on" { + t.Fatal("rule with enabled=true should match") + } +} + +// -------------------------------------------------------------------------- +// BlastRadiusLimit +// -------------------------------------------------------------------------- + +func TestBlastRadiusLimit_BlocksAfterThreshold(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlastRadiusLimit = guardrails.BlastRadiusLimit{ + Enabled: true, + Threshold: 2, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + 
guardrails.ResetBlastRadiusCount() + + // First two writes are allowed. + if blocked, _ := guardrails.CheckAndIncrementBlastRadius(); blocked { + t.Fatal("first write should be allowed") + } + if blocked, _ := guardrails.CheckAndIncrementBlastRadius(); blocked { + t.Fatal("second write should be allowed") + } + // Third write exceeds the threshold. + blocked, reason := guardrails.CheckAndIncrementBlastRadius() + if !blocked { + t.Fatal("third write should be blocked by blast radius limit") + } + if reason == "" { + t.Fatal("expected non-empty reason") + } +} + +func TestBlastRadiusLimit_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.BlastRadiusLimit = guardrails.BlastRadiusLimit{ + Enabled: false, + Threshold: 1, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + guardrails.ResetBlastRadiusCount() + for i := 0; i < 10; i++ { + if blocked, _ := guardrails.CheckAndIncrementBlastRadius(); blocked { + t.Fatalf("write %d should be allowed when disabled", i+1) + } + } +} + +// -------------------------------------------------------------------------- +// MaxTotalFileSizeKB +// -------------------------------------------------------------------------- + +func TestMaxTotalFileSizeKB_BlocksAfterThreshold(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = guardrails.FilesLimits{ + Enabled: true, + MaxTotalFileSizeKB: 1, // 1 KB = 1024 bytes + } + cleanup := writePolicy(t, policy) + defer cleanup() + + guardrails.ResetTotalFileSizeCount() + + // First 512 bytes allowed. + if blocked, _ := guardrails.CheckAndIncrementTotalFileSize(512); blocked { + t.Fatal("first write (512 B) should be allowed") + } + // Second 512 bytes brings total to exactly 1024 — still within limit. 
+ if blocked, _ := guardrails.CheckAndIncrementTotalFileSize(512); blocked { + t.Fatal("second write (512 B, total 1024 B) should be allowed") + } + // One more byte pushes total over 1024. + blocked, reason := guardrails.CheckAndIncrementTotalFileSize(1) + if !blocked { + t.Fatal("write exceeding max_total_file_size_kb should be blocked") + } + if reason == "" { + t.Fatal("expected non-empty reason") + } +} + +func TestMaxTotalFileSizeKB_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = guardrails.FilesLimits{ + Enabled: false, + MaxTotalFileSizeKB: 1, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + guardrails.ResetTotalFileSizeCount() + for i := 0; i < 10; i++ { + if blocked, _ := guardrails.CheckAndIncrementTotalFileSize(1024 * 1024); blocked { + t.Fatalf("write %d should be allowed when disabled", i+1) + } + } +} + +// -------------------------------------------------------------------------- +// BlockedExtensions +// -------------------------------------------------------------------------- + +func TestCheckBlockedExtensions_Blocks(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.BlockedExtensions = guardrails.BlockedExtensions{ + Enabled: true, + Extensions: []string{".env", ".pem"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if reason := guardrails.CheckBlockedExtensions("please read secrets.pem"); reason == "" { + t.Fatal("expected .pem reference to be blocked") + } +} + +func TestCheckBlockedExtensions_Clean(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.BlockedExtensions = guardrails.BlockedExtensions{ + Enabled: true, + Extensions: []string{".env"}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if reason := 
guardrails.CheckBlockedExtensions("please read main.go"); reason != "" { + t.Fatal("clean prompt should not be blocked") + } +} + +func TestCheckBlockedExtensions_DisabledNoPolicy(t *testing.T) { + dir := t.TempDir() + defer setHomeDir(dir)() + + if reason := guardrails.CheckBlockedExtensions("please read secrets.pem"); reason != "" { + t.Fatal("should not block when no policy is configured") + } +} + +// -------------------------------------------------------------------------- +// FilesLimits +// -------------------------------------------------------------------------- + +func TestCheckFilesLimits_Blocks(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = guardrails.FilesLimits{ + Enabled: true, + MaxFileCount: 2, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + prompt := "read a.go, b.go, and c.go" + if reason := guardrails.CheckFilesLimits(prompt); reason == "" { + t.Fatal("expected prompt referencing 3 files to exceed max_file_count=2") + } +} + +func TestCheckFilesLimits_UnderLimit(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = guardrails.FilesLimits{ + Enabled: true, + MaxFileCount: 5, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if reason := guardrails.CheckFilesLimits("read a.go and b.go"); reason != "" { + t.Fatal("prompt under the limit should not be blocked") + } +} + +// -------------------------------------------------------------------------- +// ScanForPolicyPatterns +// -------------------------------------------------------------------------- + +func TestScanForPolicyPatterns_Matches(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.ContentScanning = guardrails.ContentScanning{ + Enabled: true, + Patterns: 
[]guardrails.ContentScanPattern{ + {ID: "no-prod", Pattern: `prod\.example\.com`, Description: "Production URLs"}, + }, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if reason := guardrails.ScanForPolicyPatterns("see https://prod.example.com/api"); reason == "" { + t.Fatal("expected prompt matching policy pattern to be rejected") + } +} + +func TestScanForPolicyPatterns_NoMatch(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.ContentScanning = guardrails.ContentScanning{ + Enabled: true, + Patterns: []guardrails.ContentScanPattern{ + {ID: "no-prod", Pattern: `prod\.example\.com`, Description: "Production URLs"}, + }, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if reason := guardrails.ScanForPolicyPatterns("see https://staging.example.com"); reason != "" { + t.Fatal("clean prompt should not match") + } +} + +func TestScanForPolicyPatterns_Disabled(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.ContentScanning = guardrails.ContentScanning{ + Enabled: false, + Patterns: []guardrails.ContentScanPattern{ + {ID: "no-prod", Pattern: `prod`, Description: "Production"}, + }, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if reason := guardrails.ScanForPolicyPatterns("prod is a keyword"); reason != "" { + t.Fatal("disabled scanner should not produce findings") + } +} + +// -------------------------------------------------------------------------- +// FindMatchingToolRule — compound command / token-boundary behavior (Fix B) +// -------------------------------------------------------------------------- + +func TestFindMatchingToolRule_TokenInChain(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r-chain", Tool: []string{"mvn"}, OS: []string{currentOS()}}, + 
} + cleanup := writePolicy(t, policy) + defer cleanup() + + rule := guardrails.FindMatchingToolRule(`cd "c:\foo" && mvn deploy`) + if rule == nil || rule.ID != "r-chain" { + t.Fatalf("expected r-chain match for chained mvn, got %+v", rule) + } +} + +func TestFindMatchingToolRule_TokenAtStart(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r-start", Tool: []string{"mvn"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("mvn package"); rule == nil || rule.ID != "r-start" { + t.Fatalf("regression: expected r-start match for plain `mvn package`, got %+v", rule) + } +} + +func TestFindMatchingToolRule_TokenNotSubstringMatch(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r-nosub", Tool: []string{"mvn"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("cd /opt/foo-mvn-bar/ls"); rule != nil { + t.Fatalf("expected nil for `mvn` substring inside path token, got %+v", rule) + } +} + +func TestFindMatchingToolRule_TokenBoundaryUnderscore(t *testing.T) { + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + {ID: "r-underscore", Tool: []string{"mvn"}, OS: []string{currentOS()}}, + } + cleanup := writePolicy(t, policy) + defer cleanup() + + if rule := guardrails.FindMatchingToolRule("mvn_helper run"); rule != nil { + t.Fatalf("expected nil for `mvn_helper` (underscore is a word byte), got %+v", rule) + } +} + +// -------------------------------------------------------------------------- +// CheckShellCommand — compound + parent-command args_exclude (Fix B) +// -------------------------------------------------------------------------- + +func mvnDeployExcludeFixture(t 
*testing.T) func() { + t.Helper() + policy := guardrails.HooksPolicy{} + policy.Tools.Enabled = true + policy.Tools.Rules = []guardrails.ToolRule{ + { + ID: "mvn-deploy-excluded", + Tool: []string{"mvn"}, + OS: []string{currentOS()}, + ArgsExclude: []string{"deploy"}, + }, + } + return writePolicy(t, policy) +} + +func TestCheckShellCommand_Compound_ArgsExcludeAfterCd(t *testing.T) { + defer mvnDeployExcludeFixture(t)() + + blocked, needsConfirm, reason := guardrails.CheckShellCommand(`cd "c:\foo" && mvn deploy`, "") + if !blocked || needsConfirm { + t.Fatalf("expected hard deny on chained `mvn deploy`, got blocked=%v needsConfirm=%v reason=%q", + blocked, needsConfirm, reason) + } + if !strings.Contains(reason, "deploy") { + t.Fatalf("expected reason to cite `deploy`, got %q", reason) + } +} + +func TestCheckShellCommand_Compound_ArgsExcludeWithExtraFlags(t *testing.T) { + defer mvnDeployExcludeFixture(t)() + + blocked, needsConfirm, _ := guardrails.CheckShellCommand(`cd "c:\foo" && mvn deploy -Djava=11`, "") + if !blocked || needsConfirm { + t.Fatalf("expected hard deny on chained `mvn deploy -Djava=11`, got blocked=%v needsConfirm=%v", + blocked, needsConfirm) + } +} + +func TestCheckShellCommand_SingleCmd_ArgsExcludeWithExtraFlags(t *testing.T) { + defer mvnDeployExcludeFixture(t)() + + blocked, needsConfirm, _ := guardrails.CheckShellCommand("mvn deploy -Djava=11", "") + if !blocked || needsConfirm { + t.Fatalf("expected hard deny on `mvn deploy -Djava=11`, got blocked=%v needsConfirm=%v", + blocked, needsConfirm) + } +} + +func TestCheckShellCommand_DenyMessageAppended(t *testing.T) { + defer mvnDeployExcludeFixture(t)() + + _, _, reason := guardrails.CheckShellCommand("mvn deploy", "") + if !strings.Contains(reason, "Do NOT attempt alternative commands") { + t.Fatalf("expected DenyMessage no-workaround text in reason, got %q", reason) + } +} diff --git a/internal/commands/agenthooks/guardrails/prompt.go b/internal/commands/agenthooks/guardrails/prompt.go new 
file mode 100644 index 000000000..5512de321 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/prompt.go @@ -0,0 +1,858 @@ +package guardrails + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" + + scanner "github.com/checkmarx/2ms/v3/pkg" +) + +// defaultReferencedFileMaxBytes caps per-file reads when the policy does not +// configure files_limits.max_file_size_kb. 1 MB is well above any realistic +// config/source file while preventing accidental reads of large binaries. +const defaultReferencedFileMaxBytes = 1 << 20 + +// filePathRegexps extracts file/directory paths from free-form text such as user prompts. +// Patterns are tried in order; each must produce the path in the last capture group (or m[0]). +var filePathRegexps = []*regexp.Regexp{ + // @-mention (Cursor/IDE file reference): @.env @src/config.js @/absolute/path + regexp.MustCompile(`@([^\s"'` + "`" + `<>|*?,;:]+)`), + // Unix absolute: /path/to/file + regexp.MustCompile(`(?:^|[\s"'` + "`" + `])(/[^\s"'` + "`" + `<>|*?]+)`), + // Windows absolute: C:\path or C:/path + regexp.MustCompile(`[A-Za-z]:[\\\/][^\s"'` + "`" + `<>|*?]+`), + // Explicit relative: ./foo or ../foo + regexp.MustCompile(`(?:^|[\s"'` + "`" + `])(\.{1,2}/[^\s"'` + "`" + `<>|*?]+)`), + // Bare dotfile or named file with extension, preceded by space/@/quote/backtick + // Matches: .env .env.local credentials.json secrets.yaml id_rsa + regexp.MustCompile(`(?:^|[\s"'` + "`" + `@])(\.[a-zA-Z0-9][a-zA-Z0-9_.-]*|[a-zA-Z0-9_-]+\.[a-zA-Z0-9][a-zA-Z0-9_.-]*)(?:[\s"'` + "`" + `,;:!?]|$)`), +} + +// globMetaStripper replaces wildcard metacharacters with a space so that +// glob-shaped references like "*.env", ".env*", "**/secrets/**", or "id_rsa*" +// degrade to plain path tokens the regexes below already understand. +// +// Spaces (rather than empty strings) preserve word/path boundaries: "file*name" +// becomes "file name" — two separate tokens — instead of merging into a single +// false token "filename". 
The character class regex anchors then continue to +// fire correctly on the cleaned text. +var globMetaStripper = strings.NewReplacer("*", " ", "?", " ") + +// stripGlobMeta returns text with glob metacharacters replaced by spaces. +func stripGlobMeta(text string) string { + return globMetaStripper.Replace(text) +} + +// extractFilePaths returns all file/directory paths found in text, deduplicated. +// Glob metacharacters are stripped first so that wildcarded references in user +// prompts (e.g. "modify *.env") still surface the underlying file/extension. +func extractFilePaths(text string) []string { + cleaned := stripGlobMeta(text) + seen := map[string]struct{}{} + var paths []string + for _, re := range filePathRegexps { + for _, m := range re.FindAllStringSubmatch(cleaned, -1) { + p := strings.TrimSpace(m[len(m)-1]) + if _, ok := seen[p]; !ok { + seen[p] = struct{}{} + paths = append(paths, p) + } + } + } + return paths +} + +// extractLiteralAnchors derives bare-name anchors from policy entries by +// stripping glob metacharacters and reducing each entry to its final path +// component. The resulting anchors are bare filenames (e.g. "kubeconfig", +// "id_rsa") that the path-extraction regex cannot detect on its own — they +// have no extension, no leading dot, and no path separator — so a separate +// word-boundary scan of the prompt is needed to surface them. +// +// Glob entries like "*.pem" or "**/secrets/**" reduce to ".pem" / "secrets" +// — the path regexes already handle those, so duplicates here are harmless +// (the caller deduplicates against extracted paths). 
+func extractLiteralAnchors(entries []string) []string { + cleaner := strings.NewReplacer("*", "", "?", "") + seen := map[string]struct{}{} + var anchors []string + for _, e := range entries { + c := cleaner.Replace(e) + c = strings.Trim(c, "/\\") + if c == "" { + continue + } + if i := strings.LastIndexAny(c, "/\\"); i >= 0 { + c = c[i+1:] + } + if c == "" { + continue + } + if _, ok := seen[c]; ok { + continue + } + seen[c] = struct{}{} + anchors = append(anchors, c) + } + return anchors +} + +// findLiteralAnchorsInText returns the subset of anchors that appear in text +// at a word boundary (case-insensitive). Used to surface bare-name policy +// entries the path regexes miss. +func findLiteralAnchorsInText(text string, anchors []string) []string { + if len(anchors) == 0 { + return nil + } + cleaned := stripGlobMeta(text) + seen := map[string]struct{}{} + var hits []string + for _, a := range anchors { + re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(a) + `\b`) + if re.MatchString(cleaned) { + if _, ok := seen[a]; !ok { + seen[a] = struct{}{} + hits = append(hits, a) + } + } + } + return hits +} + +// severityFromValidation maps 2ms validation status to a severity label. +func severityFromValidation(status string) string { + switch status { + case "Valid": + return "Critical" + case "Invalid": + return "Medium" + default: // "Unknown" or anything else + return "High" + } +} + +// ScanForSecrets runs the 2ms secret scanner on arbitrary text (e.g. a prompt). +// Returns a human-readable rejection reason, or "" when the text is clean. 
+func ScanForSecrets(text string) string { + content := text + report, err := scanner.NewScanner().Scan( + []scanner.ScanItem{{Content: &content, Source: "prompt"}}, + scanner.ScanConfig{WithValidation: true}, + ) + if err != nil { + return "" // fail-open: scanner error should not block the developer + } + + var findings []string + for _, group := range report.Results { + for _, secret := range group { + severity := severityFromValidation(string(secret.ValidationStatus)) + findings = append(findings, fmt.Sprintf(" - %s (severity: %s)", secret.RuleID, severity)) + } + } + if len(findings) == 0 { + return "" + } + return fmt.Sprintf( + "Blocked by Checkmarx: prompt contains %d secret(s):\n%s\nRemove the secrets and try again.", + len(findings), strings.Join(findings, "\n"), + ) +} + +// maxGlobFallbackMatches caps the number of files a single ambiguous prompt +// reference may expand to via the glob fallback. Beyond this we drop the +// fallback entirely rather than scan a directory's worth of unrelated files. +const maxGlobFallbackMatches = 20 + +// resolveReferencedFile returns absolute paths to readable regular files for +// the given prompt-referenced path. A literal match wins; if the path doesn't +// exist on disk, a one-level glob in the parent directory (`*`) is tried +// so that references like `application-jira` still resolve to `application-jira.yml`. +// +// Absolute paths are used as-is; relative paths are tried against each +// workspace root in order, returning the first root that yields any matches. +// Directories, symlinks to directories, and missing entries return nil. +func resolveReferencedFile(p string, workspaceRoots []string) []string { + if filepath.IsAbs(p) { + return resolveOne(p) + } + // Cursor sometimes reports Windows roots as "/c:/foo"; normalise before joining. 
+ for _, root := range workspaceRoots { + normalized := NormalizeWorkspaceRoot(root) + if normalized == "" { + continue + } + if resolved := resolveOne(filepath.Join(normalized, p)); len(resolved) > 0 { + return resolved + } + } + return nil +} + +// resolveOne returns the regular file at absPath if it exists, otherwise the +// glob fallback `*` capped at maxGlobFallbackMatches regular files. +// A typed path that is itself a directory returns nil (we never expand a +// directory reference into its contents). +func resolveOne(absPath string) []string { + if info, err := os.Stat(absPath); err == nil { + if info.Mode().IsRegular() { + return []string{absPath} + } + return nil // directory or other non-regular entry + } + return resolveByGlob(absPath) +} + +// resolveByGlob expands `absPath*` to sibling regular files. Returns nil when +// the parent directory doesn't exist or the match count would exceed +// maxGlobFallbackMatches — refusing to scan is safer than scanning the wrong +// thing on a broad prefix. +func resolveByGlob(absPath string) []string { + parent := filepath.Dir(absPath) + parentInfo, err := os.Stat(parent) + if err != nil || !parentInfo.IsDir() { + return nil + } + matches, err := filepath.Glob(absPath + "*") + if err != nil || len(matches) == 0 { + return nil + } + var regular []string + for _, m := range matches { + info, err := os.Lstat(m) + if err != nil { + continue + } + if !info.Mode().IsRegular() { + continue + } + regular = append(regular, m) + if len(regular) > maxGlobFallbackMatches { + return nil + } + } + return regular +} + +// ScanReferencedFiles resolves file paths mentioned in text against the given +// workspace roots, reads each one, and runs the 2ms secret scanner over its +// contents. Returns a human-readable rejection reason that lists findings per +// file, or "" when no referenced file contains secrets. 
+// +// Missing files, directories, and files that exceed the configured size cap +// are silently skipped — this is a best-effort guardrail, not a filesystem +// audit, and must not block the developer on unrelated I/O errors. +func ScanReferencedFiles(text string, workspaceRoots []string) string { + paths := extractFilePaths(text) + if len(paths) == 0 { + return "" + } + + // policyCapBytes: explicit policy size limit; >0 means files larger than this + // are blocked outright (size violation) without inspecting contents. + // scanBudgetBytes: memory ceiling for the scanner. If the policy cap is set, + // it doubles as the budget; otherwise fall back to defaultReferencedFileMaxBytes. + var policyCapBytes int64 + scanBudgetBytes := int64(defaultReferencedFileMaxBytes) + if limits := LoadFilesLimits(); limits != nil && limits.MaxFileSizeKB > 0 { + policyCapBytes = int64(limits.MaxFileSizeKB) * 1024 + scanBudgetBytes = policyCapBytes + } + + seen := map[string]struct{}{} + var perFile []string + var oversize []string + sc := scanner.NewScanner() + + for _, p := range paths { + for _, resolved := range resolveReferencedFile(p, workspaceRoots) { + if _, dup := seen[resolved]; dup { + continue + } + seen[resolved] = struct{}{} + + info, err := os.Stat(resolved) + if err != nil { + continue + } + if policyCapBytes > 0 && info.Size() > policyCapBytes { + oversize = append(oversize, fmt.Sprintf( + " %s (%d KB exceeds policy limit of %d KB)", + resolved, info.Size()/1024, policyCapBytes/1024, + )) + continue + } + if info.Size() > scanBudgetBytes { + continue + } + + data, err := os.ReadFile(resolved) + if err != nil { + continue + } + content := string(data) + + report, err := sc.Scan( + []scanner.ScanItem{{Content: &content, Source: resolved}}, + scanner.ScanConfig{WithValidation: true}, + ) + if err != nil { + continue // fail-open per scanner + } + + var findings []string + for _, group := range report.Results { + for _, secret := range group { + severity := 
severityFromValidation(string(secret.ValidationStatus)) + findings = append(findings, fmt.Sprintf(" - %s (severity: %s)", secret.RuleID, severity)) + } + } + if len(findings) == 0 { + continue + } + perFile = append(perFile, + fmt.Sprintf(" %s (%d secret(s)):\n%s", resolved, len(findings), strings.Join(findings, "\n"))) + } + } + + if len(perFile) == 0 && len(oversize) == 0 { + return "" + } + var sections []string + if len(perFile) > 0 { + sections = append(sections, + "file(s) containing secret(s):\n"+strings.Join(perFile, "\n")) + } + if len(oversize) > 0 { + sections = append(sections, + "file(s) exceeding the configured size limit:\n"+strings.Join(oversize, "\n")) + } + return fmt.Sprintf( + "Blocked by Checkmarx: referenced %s\nRemove the references from your prompt, reduce file size, or remove the secrets.%s", + strings.Join(sections, "\nand "), DenyMessage, + ) +} + +// Tunable bounds for ScanWorkspaceFilesByPromptName. Generous enough for a +// typical project, tight enough that a misconfigured workspace root pointing +// at $HOME does not stall the prompt submit. +const ( + maxWorkspaceWalkFiles = 5000 + maxWorkspaceWalkDepth = 8 +) + +// skipWorkspaceWalkDirs is the set of directory names that +// ScanWorkspaceFilesByPromptName never descends into. These are package +// manager caches, build outputs, and VCS metadata — none of which the user +// would name in a prompt, and all of which can hold millions of files. +var skipWorkspaceWalkDirs = map[string]struct{}{ + ".git": {}, "node_modules": {}, "target": {}, "build": {}, + "dist": {}, "out": {}, "vendor": {}, ".gradle": {}, ".idea": {}, + ".vscode": {}, "bin": {}, "obj": {}, "__pycache__": {}, + ".next": {}, ".nuxt": {}, ".cache": {}, ".pytest_cache": {}, + ".venv": {}, "venv": {}, ".tox": {}, +} + +// extractPromptTokens splits text into the set of distinct lowercase word +// tokens. Word bytes are a-z, 0-9, '_', '-'; any other byte is a separator. 
// The dot is a separator so that "kedar.json" yields the tokens {"kedar","json"}
// — the same shape produced when splitting a filename for matching.
func extractPromptTokens(text string) map[string]struct{} {
	// Anything outside the ASCII word set ends the current token. Every
	// non-ASCII rune (and every invalid byte) falls outside that set, so
	// splitting by rune gives the same tokens as a byte-wise scan.
	isSeparator := func(r rune) bool {
		switch {
		case r >= 'a' && r <= 'z',
			r >= 'A' && r <= 'Z',
			r >= '0' && r <= '9',
			r == '_', r == '-':
			return false
		}
		return true
	}
	tokens := map[string]struct{}{}
	for _, word := range strings.FieldsFunc(text, isSeparator) {
		tokens[strings.ToLower(word)] = struct{}{}
	}
	return tokens
}

// filenameNameParts lists the lowercase dot-separated segments of basename
// that may be matched against prompt tokens. The empty piece in front of a
// leading dot and the final segment (the extension) are discarded, unless
// only one segment is left:
//
//	"Kedar"             → ["kedar"]
//	"kedar.json"        → ["kedar"]
//	".env"              → ["env"]
//	".env.local"        → ["env"]
//	"config.local.json" → ["config", "local"]
//	"Makefile"          → ["makefile"]
//	"id_rsa"            → ["id_rsa"]
//
// Empty and repeated segments are dropped; order of first appearance is kept.
// The result is never nil.
func filenameNameParts(basename string) []string {
	// Split with a non-empty separator always yields at least one element.
	segments := strings.Split(strings.ToLower(basename), ".")
	if segments[0] == "" {
		segments = segments[1:] // dotfile: nothing precedes the leading dot
	}
	if n := len(segments); n > 1 {
		segments = segments[:n-1] // the last segment is the extension
	}

	out := make([]string, 0, len(segments))
nextSegment:
	for _, seg := range segments {
		if seg == "" {
			continue
		}
		for _, kept := range out {
			if kept == seg {
				continue nextSegment
			}
		}
		out = append(out, seg)
	}
	return out
}

// ScanWorkspaceFilesByPromptName walks each workspace root and scans any
// regular file whose name (basename or stem) appears in the prompt as a
// whole, case-insensitive token.
Returns a rejection reason listing files +// that contain secrets or exceed the size policy, or "" when clean. +// +// Why this guardrail exists: prompts like "check kedar file" do not contain +// an @-mention, a path separator, or a file extension, so none of the path +// regexes fire and ScanReferencedFiles never opens the workspace file named +// "Kedar". If that file holds a JWT, sending the prompt would still leak the +// secret because the model resolves the reference on the fly. This catches +// the case at prompt-submit time. Explicit path references (absolute paths +// or @-mentions) are handled separately by ScanReferencedFiles regardless of +// whether the path is inside the workspace. +// +// Algorithm: tokenize the prompt into a set, then walk the workspace and +// match each file's name parts against that set. Cost is O(words + files) +// rather than O(words × files). +// +// The walk is bounded by maxWorkspaceWalkFiles, maxWorkspaceWalkDepth, and +// the skipWorkspaceWalkDirs prune list. File reads are gated by the policy +// size cap (block on violation) and a memory-budget fallback (skip silently). +// Filesystem errors fail-open — a guardrail must not block the developer on +// I/O noise. +func ScanWorkspaceFilesByPromptName(text string, workspaceRoots []string) string { + if strings.TrimSpace(text) == "" || len(workspaceRoots) == 0 { + return "" + } + promptTokens := extractPromptTokens(text) + if len(promptTokens) == 0 { + return "" + } + + // policyCapBytes: explicit policy size limit; >0 means files larger than this + // are blocked outright (size violation) without inspecting contents. + // scanBudgetBytes: memory ceiling for the scanner. If the policy cap is set, + // it doubles as the budget; otherwise fall back to defaultReferencedFileMaxBytes. 
+ var policyCapBytes int64 + scanBudgetBytes := int64(defaultReferencedFileMaxBytes) + if limits := LoadFilesLimits(); limits != nil && limits.MaxFileSizeKB > 0 { + policyCapBytes = int64(limits.MaxFileSizeKB) * 1024 + scanBudgetBytes = policyCapBytes + } + + seen := map[string]struct{}{} + var perFile []string + var oversize []string + sc := scanner.NewScanner() + walked := 0 + + for _, root := range workspaceRoots { + normalized := NormalizeWorkspaceRoot(root) + if normalized == "" { + continue + } + // NormalizeWorkspaceRoot converts to forward slashes, but WalkDir + // reports `path` in the native form (backslashes on Windows). Count + // against forward-slash projections so the depth check is correct on + // both platforms. + rootSlashCount := strings.Count(filepath.ToSlash(normalized), "/") + + _ = filepath.WalkDir(normalized, func(path string, d fs.DirEntry, err error) error { + if err != nil { + if d != nil && d.IsDir() { + return fs.SkipDir + } + return nil + } + if d.IsDir() { + if path != normalized { + if _, skip := skipWorkspaceWalkDirs[strings.ToLower(d.Name())]; skip { + return fs.SkipDir + } + if strings.Count(filepath.ToSlash(path), "/")-rootSlashCount > maxWorkspaceWalkDepth { + return fs.SkipDir + } + } + return nil + } + if !d.Type().IsRegular() { + return nil + } + walked++ + if walked > maxWorkspaceWalkFiles { + return fs.SkipAll + } + + matched := false + for _, part := range filenameNameParts(d.Name()) { + if _, ok := promptTokens[part]; ok { + matched = true + break + } + } + if !matched { + return nil + } + if _, dup := seen[path]; dup { + return nil + } + seen[path] = struct{}{} + + info, err := d.Info() + if err != nil { + return nil + } + + // Size-policy violation wins regardless of file contents: the policy + // says this file may not enter the AI context at all. 
+ if policyCapBytes > 0 && info.Size() > policyCapBytes { + oversize = append(oversize, fmt.Sprintf( + " %s (%d KB exceeds policy limit of %d KB)", + path, info.Size()/1024, policyCapBytes/1024, + )) + return nil + } + + // Otherwise enforce the scan budget purely as a memory ceiling. + if info.Size() > scanBudgetBytes { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + content := string(data) + + report, scanErr := sc.Scan( + []scanner.ScanItem{{Content: &content, Source: path}}, + scanner.ScanConfig{WithValidation: true}, + ) + if scanErr != nil { + return nil + } + var findings []string + for _, group := range report.Results { + for _, secret := range group { + severity := severityFromValidation(string(secret.ValidationStatus)) + findings = append(findings, fmt.Sprintf(" - %s (severity: %s)", secret.RuleID, severity)) + } + } + if len(findings) > 0 { + perFile = append(perFile, + fmt.Sprintf(" %s (%d secret(s)):\n%s", path, len(findings), strings.Join(findings, "\n"))) + } + return nil + }) + } + + if len(perFile) == 0 && len(oversize) == 0 { + return "" + } + var sections []string + if len(perFile) > 0 { + sections = append(sections, + "file(s) containing secret(s):\n"+strings.Join(perFile, "\n")) + } + if len(oversize) > 0 { + sections = append(sections, + "file(s) exceeding the configured size limit:\n"+strings.Join(oversize, "\n")) + } + return fmt.Sprintf( + "Blocked by Checkmarx: prompt names workspace %s\nRemove the references from your prompt, reduce file size, or remove the secrets.%s", + strings.Join(sections, "\nand "), DenyMessage, + ) +} + +// ScanFileForSecrets reads the file at path and runs the 2ms secret scanner +// over its contents. Returns a human-readable rejection reason, or "" when the +// file is clean, missing, or unreadable (fail-open). 
When the policy's +// files_limits.max_file_size_kb is set and the file exceeds it, the file is +// blocked on the size violation alone without inspecting contents — the +// policy already says it must not enter AI context. +// +// This is the content-bearing companion to Cursor's beforeReadFile hook: +// Cursor sends only the path, so we open the file ourselves before the agent +// ingests it into the LLM context. +func ScanFileForSecrets(path string) string { + if path == "" { + return "" + } + info, err := os.Stat(path) + if err != nil || !info.Mode().IsRegular() { + return "" + } + + var policyCapBytes int64 + scanBudgetBytes := int64(defaultReferencedFileMaxBytes) + if limits := LoadFilesLimits(); limits != nil && limits.MaxFileSizeKB > 0 { + policyCapBytes = int64(limits.MaxFileSizeKB) * 1024 + scanBudgetBytes = policyCapBytes + } + + if policyCapBytes > 0 && info.Size() > policyCapBytes { + return fmt.Sprintf( + "Blocked by Checkmarx: file %q (%d KB) exceeds the policy size limit of %d KB and may not enter the AI context.%s", + path, info.Size()/1024, policyCapBytes/1024, DenyMessage, + ) + } + if info.Size() > scanBudgetBytes { + return "" // too big to scan but within policy — fail-open on memory ceiling + } + + data, err := os.ReadFile(path) + if err != nil { + return "" + } + content := string(data) + + report, err := scanner.NewScanner().Scan( + []scanner.ScanItem{{Content: &content, Source: path}}, + scanner.ScanConfig{WithValidation: true}, + ) + if err != nil { + return "" // fail-open per scanner contract + } + var findings []string + for _, group := range report.Results { + for _, secret := range group { + severity := severityFromValidation(string(secret.ValidationStatus)) + findings = append(findings, fmt.Sprintf(" - %s (severity: %s)", secret.RuleID, severity)) + } + } + if len(findings) == 0 { + return "" + } + return fmt.Sprintf( + "Blocked by Checkmarx: file %q contains %d secret(s) and must not enter the AI context:\n%s\nRemove the secrets 
from the file before letting the agent read it.%s", + path, len(findings), strings.Join(findings, "\n"), DenyMessage, + ) +} + +// ScanForPolicyPatterns runs the custom regex patterns defined in the policy's +// context_policy.content_scanning section against the prompt text. +// Returns a human-readable rejection reason, or "" when the text is clean. +func ScanForPolicyPatterns(text string) string { + policy := LoadPolicy() + if policy == nil { + return "" + } + cp := policy.DefaultPolicy.ContextPolicy + if !cp.Enabled || !cp.ContentScanning.Enabled { + return "" + } + + var findings []string + for _, p := range cp.ContentScanning.Patterns { + re, err := regexp.Compile(p.Pattern) + if err != nil { + continue // skip malformed patterns — fail-open + } + if re.MatchString(text) { + findings = append(findings, fmt.Sprintf(" - %s: %s", p.ID, p.Description)) + } + } + if len(findings) == 0 { + return "" + } + return fmt.Sprintf( + "Blocked by Checkmarx: prompt contains sensitive content detected by policy:\n%s\nRemove the sensitive content and try again.%s", + strings.Join(findings, "\n"), DenyMessage, + ) +} + +// CheckPromptPaths checks file/directory paths mentioned in a prompt against +// the organization's restricted_files and restricted_directories policy. +// +// The effective restricted lists union the global default_policy entries with +// each enabled tool rule's restricted_files / restricted_directories combined +// per the rule's merge_strategy ("merge" / "override" / "default"). This means +// a prompt that references a path restricted by ANY tool rule is blocked, +// regardless of which tool the agent might eventually invoke. +// +// Detection sources: +// - Path-shaped tokens extracted from the prompt (after stripping glob meta). +// - Bare-name word-boundary hits derived from non-glob policy entries +// (e.g. "kubeconfig", "id_rsa") that the path-extraction regexes miss. +// +// Precedence: restricted always wins over allowed. 
A path that matches both a +// restricted and an allowed list is still blocked. +// +// Returns (true, reason) if blocked, (false, "") if allowed. +func CheckPromptPaths(text string) (bool, string) { + restrictedFiles, restrictedDirs := LoadEffectiveRestrictedPaths() + if len(restrictedFiles) == 0 && len(restrictedDirs) == 0 { + return false, "" + } + + files := extractFilePaths(text) + // Bare-name policy entries (e.g. "kubeconfig") aren't surfaced by the path + // regex, so word-boundary scan for them and feed any hits through the same + // matchFilePattern path below. + for _, hit := range findLiteralAnchorsInText(text, extractLiteralAnchors(restrictedFiles)) { + files = append(files, hit) + } + seen := map[string]struct{}{} + var violations []string + + for _, file := range files { + // restricted_files: literal, basename, suffix, or doublestar glob. + for _, rf := range restrictedFiles { + if matchFilePattern(rf, file) { + if _, ok := seen[file]; !ok { + seen[file] = struct{}{} + violations = append(violations, fmt.Sprintf(" - %s (restricted file)", file)) + } + break + } + } + + // restricted_directories: containment match, with glob support. + if _, already := seen[file]; already { + continue + } + for _, rd := range restrictedDirs { + if matchDirContains(rd, file) { + seen[file] = struct{}{} + violations = append(violations, fmt.Sprintf(" - %s (restricted directory)", file)) + break + } + } + } + + if len(violations) == 0 { + return false, "" + } + return true, fmt.Sprintf( + "Blocked by Checkmarx: the following files or folders are restricted by policy:\n%s\nContact your administrator if you need access to these resources.%s", + strings.Join(violations, "\n"), DenyMessage, + ) +} + +// CheckWorkspaceRoots rejects a prompt whose workspace is within a restricted directory. +// Policy entries are interpreted per-OS via LoadRestrictedPaths; the prefix match +// makes a workspace at C:\foo\bar illegal when C:\foo\ is restricted. 
+// Returns (true, reason) if any root violates policy, (false, "") otherwise. +func CheckWorkspaceRoots(roots []string) (bool, string) { + if len(roots) == 0 { + return false, "" + } + _, restrictedDirs := LoadRestrictedPaths() + if len(restrictedDirs) == 0 { + return false, "" + } + for _, root := range roots { + normalized := NormalizeWorkspaceRoot(root) + if PathUnderAny(normalized, restrictedDirs) { + return true, fmt.Sprintf( + "Blocked by Checkmarx: workspace %q is restricted by policy.%s", + root, DenyMessage, + ) + } + } + return false, "" +} + +// CheckBlockedExtensions rejects prompts that reference files with a blocked extension +// (e.g. .env, .pem, .key). Returns a rejection reason, or "" when the prompt is clean. +func CheckBlockedExtensions(text string) string { + extensions := LoadBlockedExtensions() + if len(extensions) == 0 { + return "" + } + extSet := make(map[string]struct{}, len(extensions)) + for _, e := range extensions { + extSet[strings.ToLower(e)] = struct{}{} + } + + seen := map[string]struct{}{} + var hits []string + for _, p := range extractFilePaths(text) { + ext := strings.ToLower(filepath.Ext(p)) + if ext == "" { + continue + } + if _, ok := extSet[ext]; !ok { + continue + } + if _, already := seen[p]; already { + continue + } + seen[p] = struct{}{} + hits = append(hits, fmt.Sprintf(" - %s (extension %s)", p, ext)) + } + if len(hits) == 0 { + return "" + } + return fmt.Sprintf( + "Blocked by Checkmarx: prompt references files with blocked extensions:\n%s\nThese file types must not enter the AI context.%s", + strings.Join(hits, "\n"), DenyMessage, + ) +} + +// CheckFilesLimits rejects prompts that reference more files than the policy allows. +// Returns a rejection reason, or "" when the prompt is within the limit. 
+func CheckFilesLimits(text string) string { + limits := LoadFilesLimits() + if limits == nil || limits.MaxFileCount <= 0 { + return "" + } + paths := extractFilePaths(text) + if len(paths) <= limits.MaxFileCount { + return "" + } + return fmt.Sprintf( + "Blocked by Checkmarx: prompt references %d files, exceeding the policy limit of %d.%s", + len(paths), limits.MaxFileCount, DenyMessage, + ) +} + +// ScanPrompt runs all prompt guardrails in order: +// 1. 2ms secret scanner — detects structured secrets (API keys, tokens, PEM blocks) +// 2. Policy content scanner — detects sensitive content via custom regex patterns +// 3. Path guardrail — blocks prompts referencing restricted files/directories +// 4. Blocked extensions — blocks prompts referencing files with blocked extensions +// 5. Files-limits guardrail — rejects prompts that reference too many files +// +// Returns a human-readable rejection reason, or "" when the text is clean. +func ScanPrompt(text string) string { + if reason := ScanForSecrets(text); reason != "" { + return reason + } + if reason := ScanForPolicyPatterns(text); reason != "" { + return reason + } + if blocked, reason := CheckPromptPaths(text); blocked { + return reason + } + if reason := CheckBlockedExtensions(text); reason != "" { + return reason + } + if reason := CheckFilesLimits(text); reason != "" { + return reason + } + return "" +} diff --git a/internal/commands/agenthooks/guardrails/prompt_test.go b/internal/commands/agenthooks/guardrails/prompt_test.go new file mode 100644 index 000000000..eddb3181e --- /dev/null +++ b/internal/commands/agenthooks/guardrails/prompt_test.go @@ -0,0 +1,561 @@ +//go:build !integration + +package guardrails + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "testing" +) + +// sampleJWT is a well-known test JWT (no real value) used to give 2ms a +// concrete secret to detect when we want to assert a block. +const sampleJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." 
+ + "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." + + "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + +// resolveReferencedFile is the resolver behind ScanReferencedFiles. We exercise +// it directly because the scanner integration is unchanged — only the resolver +// logic shifted from "literal stat" to "literal stat + glob fallback". + +func TestResolveReferencedFile_LiteralAbsoluteHit(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "config.yml") + mustWrite(t, target, "k: v") + + got := resolveReferencedFile(target, nil) + if len(got) != 1 || got[0] != target { + t.Fatalf("expected [%q], got %v", target, got) + } +} + +func TestResolveReferencedFile_GlobFallbackFindsSibling(t *testing.T) { + dir := t.TempDir() + mustWrite(t, filepath.Join(dir, "application-jira.yml"), "k: v") + + typed := filepath.Join(dir, "application-jira") // no extension + got := resolveReferencedFile(typed, nil) + + if len(got) != 1 || filepath.Base(got[0]) != "application-jira.yml" { + t.Fatalf("expected glob fallback to find application-jira.yml, got %v", got) + } +} + +func TestResolveReferencedFile_GlobFallbackNoSibling(t *testing.T) { + dir := t.TempDir() + // parent exists but nothing matches the prefix + typed := filepath.Join(dir, "application-jira") + if got := resolveReferencedFile(typed, nil); got != nil { + t.Fatalf("expected nil when nothing matches, got %v", got) + } +} + +func TestResolveReferencedFile_GlobFallbackBailsOnTooManyMatches(t *testing.T) { + dir := t.TempDir() + for i := 0; i <= maxGlobFallbackMatches; i++ { + mustWrite(t, filepath.Join(dir, "common-prefix-"+itoa(i)+".log"), "x") + } + + typed := filepath.Join(dir, "common-prefix") + if got := resolveReferencedFile(typed, nil); got != nil { + t.Fatalf("expected nil when match count exceeds cap, got %d entries", len(got)) + } +} + +func TestResolveReferencedFile_TypedPathIsDirectory(t *testing.T) { + dir := t.TempDir() + subdir := filepath.Join(dir, "secrets") + 
if err := os.Mkdir(subdir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + mustWrite(t, filepath.Join(subdir, "creds.yml"), "k: v") + + if got := resolveReferencedFile(subdir, nil); got != nil { + t.Fatalf("expected nil for directory reference, got %v", got) + } +} + +func TestResolveReferencedFile_RelativePathResolvesAgainstWorkspaceRoot(t *testing.T) { + dir := t.TempDir() + mustWrite(t, filepath.Join(dir, "application-jira.yml"), "k: v") + + got := resolveReferencedFile("application-jira", []string{dir}) + if len(got) != 1 || filepath.Base(got[0]) != "application-jira.yml" { + t.Fatalf("expected glob fallback under workspace root to find application-jira.yml, got %v", got) + } +} + +func TestResolveReferencedFile_RelativeStopsAtFirstMatchingRoot(t *testing.T) { + rootA := t.TempDir() + rootB := t.TempDir() + mustWrite(t, filepath.Join(rootA, "config.yml"), "a") + mustWrite(t, filepath.Join(rootB, "config.yml"), "b") + + got := resolveReferencedFile("config.yml", []string{rootA, rootB}) + if len(got) != 1 || filepath.Dir(got[0]) != rootA { + t.Fatalf("expected resolution to stop at rootA, got %v", got) + } +} + +func TestResolveReferencedFile_CursorStyleWindowsRootNormalised(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Cursor /c:/ root form is Windows-specific") + } + dir := t.TempDir() + mustWrite(t, filepath.Join(dir, "application-jira.yml"), "k: v") + + // Cursor reports Windows roots as "/c:/foo"; NormalizeWorkspaceRoot strips + // the leading slash. Confirm the resolver still finds the file via glob. + cursorRoot := "/" + filepath.ToSlash(dir) + got := resolveReferencedFile("application-jira", []string{cursorRoot}) + if len(got) != 1 || filepath.Base(got[0]) != "application-jira.yml" { + t.Fatalf("expected glob fallback under Cursor-style root, got %v", got) + } +} + +func TestResolveReferencedFile_GlobMatchesMixedRegularAndDir(t *testing.T) { + // A directory whose name shares the prefix must not be returned as a file. 
+ dir := t.TempDir() + mustWrite(t, filepath.Join(dir, "app.yml"), "k: v") + if err := os.Mkdir(filepath.Join(dir, "app-data"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + + typed := filepath.Join(dir, "app") + got := resolveReferencedFile(typed, nil) + sort.Strings(got) + if len(got) != 1 || filepath.Base(got[0]) != "app.yml" { + t.Fatalf("expected only the regular file, got %v", got) + } +} + +func mustWrite(t *testing.T, path, content string) { + t.Helper() + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("write %s: %v", path, err) + } +} + +func itoa(i int) string { + if i == 0 { + return "0" + } + var b [20]byte + pos := len(b) + for i > 0 { + pos-- + b[pos] = byte('0' + i%10) + i /= 10 + } + return string(b[pos:]) +} + +// -------------------------------------------------------------------------- +// ScanWorkspaceFilesByPromptName — bare-word filename matching (Fix C) +// -------------------------------------------------------------------------- + +// writePolicyHelper writes a HooksPolicy to a temp ~/.checkmarx/policyhooks.json +// and redirects the home dir so LoadPolicy() picks it up. Returns a cleanup +// function that must be invoked (typically via defer) to restore the env. 
+func writePolicyHelper(t *testing.T, policy HooksPolicy) func() { + t.Helper() + data, err := json.Marshal(policy) + if err != nil { + t.Fatalf("marshal policy: %v", err) + } + dir := t.TempDir() + cxDir := filepath.Join(dir, ".checkmarx") + if err := os.MkdirAll(cxDir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(cxDir, "policyhooks.json"), data, 0o644); err != nil { + t.Fatalf("write policy: %v", err) + } + if runtime.GOOS == "windows" { + orig, had := os.LookupEnv("USERPROFILE") + os.Setenv("USERPROFILE", dir) + return func() { + if had { + os.Setenv("USERPROFILE", orig) + } else { + os.Unsetenv("USERPROFILE") + } + } + } + orig, had := os.LookupEnv("HOME") + os.Setenv("HOME", dir) + return func() { + if had { + os.Setenv("HOME", orig) + } else { + os.Unsetenv("HOME") + } + } +} + +// makeWorkspace writes a workspace directory containing the given files +// (path → contents) and returns the workspace root. Parent directories are +// created automatically. Use to set up ScanWorkspaceFilesByPromptName tests. 
+func makeWorkspace(t *testing.T, files map[string]string) string { + t.Helper() + ws := filepath.Join(t.TempDir(), "workspace") + for rel, content := range files { + full := filepath.Join(ws, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", full, err) + } + mustWrite(t, full, content) + } + return ws +} + +func TestScanWorkspaceFilesByPromptName_BasenameMatch_BlocksOnJWT(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "Kedar": "token = " + sampleJWT, + }) + reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}) + if reason == "" { + t.Fatal("expected block: workspace file Kedar contains a JWT and the prompt names it") + } + if !strings.Contains(strings.ToLower(reason), "kedar") { + t.Fatalf("reason should cite the offending file path, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_CaseInsensitive(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "Kedar": "secret = " + sampleJWT, + }) + for _, prompt := range []string{ + "check kedar file", + "Check Kedar File", + "please review the KEDAR doc", + } { + if reason := ScanWorkspaceFilesByPromptName(prompt, []string{ws}); reason == "" { + t.Fatalf("expected block for prompt %q (case-insensitive match)", prompt) + } + } +} + +func TestScanWorkspaceFilesByPromptName_NoAtSymbolRequired(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "kedar.json": `{"jwt":"` + sampleJWT + `"}`, + }) + if reason := ScanWorkspaceFilesByPromptName("explain kedar to me", []string{ws}); reason == "" { + t.Fatal("expected block on a plain word `kedar` matching kedar.json by stem") + } +} + +func TestScanWorkspaceFilesByPromptName_StemMatchWithExtension(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "kedar.yaml": "token: " + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("check kedar configs", []string{ws}); reason == "" { + t.Fatal("expected block: prompt `kedar` should match `kedar.yaml` 
via stem") + } +} + +func TestScanWorkspaceFilesByPromptName_CleanFile_DoesNotBlock(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "Kedar": "just notes, nothing sensitive here", + }) + if reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}); reason != "" { + t.Fatalf("expected no block when matched file has no secrets, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_NoMatch_DoesNotBlock(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "Kedar": "token = " + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("show me the latest tests", []string{ws}); reason != "" { + t.Fatalf("expected no block when prompt does not name any workspace file, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_SubstringInsideWord_NotMatched(t *testing.T) { + // File "mvn" inside a basename like "foo-mvn-bar" must not be matched + // when the prompt contains those joined word characters. + ws := makeWorkspace(t, map[string]string{ + "foo-mvn-bar/secret.txt": "token = " + sampleJWT, + }) + // The prompt mentions "mvn" but the only file with secrets is named + // "secret.txt"; "mvn" appears only inside a parent dir name and is not a + // basename token, so no scan should match. + if reason := ScanWorkspaceFilesByPromptName("run mvn build", []string{ws}); reason != "" { + t.Fatalf("expected no block: `mvn` is inside a directory name, not a basename, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_ShortFilenameStillScanned(t *testing.T) { + // A 1-char filename should still be detected when it appears as a standalone + // token in the prompt — the token-boundary check is what prevents over-block. 
+ ws := makeWorkspace(t, map[string]string{ + "a": "token = " + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("review file a please", []string{ws}); reason == "" { + t.Fatal("expected block: 1-char filename `a` appears as a standalone token in the prompt") + } +} + +func TestScanWorkspaceFilesByPromptName_ShortFilenameInsideWord_NotMatched(t *testing.T) { + // File "a" must NOT match `a` inside "apple", "have", etc. — token boundary protects. + ws := makeWorkspace(t, map[string]string{ + "a": "token = " + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("there are apples here, have one", []string{ws}); reason != "" { + t.Fatalf("expected no block: `a` only appears inside word characters, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_BothBasenameAndStem_BothBlock(t *testing.T) { + // Workspace has BOTH `kedar` (no extension) and `Kedar.json`. The prompt + // names `kedar`; both files match (one by basename, one by stem) and both + // contain secrets — the rejection must cite both. + ws := makeWorkspace(t, map[string]string{ + "kedar": "token1 = " + sampleJWT, + "Kedar.json": `{"jwt":"` + sampleJWT + `"}`, + }) + reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}) + if reason == "" { + t.Fatal("expected block: both `kedar` and `Kedar.json` should be detected") + } + if !strings.Contains(reason, "kedar") || !strings.Contains(reason, "Kedar.json") { + t.Fatalf("rejection should cite BOTH files, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_SizePolicyViolation_BlocksWithoutSecrets(t *testing.T) { + // Policy max_file_size_kb=3. A 5 KB file with no secrets should be blocked + // purely on the size violation: policy says it cannot enter AI context. 
+ policy := HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = FilesLimits{Enabled: true, MaxFileSizeKB: 3} + defer writePolicyHelper(t, policy)() + + ws := makeWorkspace(t, map[string]string{ + "Kedar.txt": strings.Repeat("a", 5*1024), // 5 KB, no secrets + }) + reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}) + if reason == "" { + t.Fatal("expected block: 5 KB file exceeds 3 KB policy cap") + } + if !strings.Contains(reason, "exceeds policy limit") { + t.Fatalf("reason should cite size policy violation, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_SizePolicyAtCap_NotBlocked(t *testing.T) { + policy := HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = FilesLimits{Enabled: true, MaxFileSizeKB: 3} + defer writePolicyHelper(t, policy)() + + ws := makeWorkspace(t, map[string]string{ + "Kedar.txt": strings.Repeat("a", 3*1024), // exactly at cap + }) + if reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}); reason != "" { + t.Fatalf("expected no block at exactly the policy cap, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_SkipsIgnoredDirs(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "node_modules/kedar.json": `{"jwt":"` + sampleJWT + `"}`, + ".git/kedar": "token = " + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("look at kedar", []string{ws}); reason != "" { + t.Fatalf("expected no block: files only inside node_modules/.git should be pruned, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_NoWorkspaceRoots_NoOp(t *testing.T) { + if reason := ScanWorkspaceFilesByPromptName("check kedar file", nil); reason != "" { + t.Fatalf("expected no-op with empty workspace roots, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_CursorStyleWindowsRoot(t *testing.T) { + if runtime.GOOS != 
"windows" { + t.Skip("Cursor /c:/foo root form is Windows-specific") + } + ws := makeWorkspace(t, map[string]string{ + "Kedar": "token = " + sampleJWT, + }) + // Convert "C:\path\workspace" -> "/c:/path/workspace" (Cursor's form). + slashy := filepath.ToSlash(ws) + cursorRoot := "/" + strings.ToLower(slashy[:2]) + slashy[2:] + if reason := ScanWorkspaceFilesByPromptName("check kedar", []string{cursorRoot}); reason == "" { + t.Fatalf("expected block: Cursor-style root %q should normalize", cursorRoot) + } +} + +func TestScanWorkspaceFilesByPromptName_RecursiveSubdirMatch(t *testing.T) { + // File is nested several levels deep under the workspace root and not in + // any skipped directory. The recursive walk should still find it. + ws := makeWorkspace(t, map[string]string{ + "src/auth/internal/Kedar.txt": "token = " + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}); reason == "" { + t.Fatal("expected block: nested file should be found by recursive walk") + } +} + +func TestScanWorkspaceFilesByPromptName_MultiDotFilename(t *testing.T) { + // "config.local.json": name parts ["config","local"], either token + // in the prompt is enough to flag the file. + ws := makeWorkspace(t, map[string]string{ + "config.local.json": `{"jwt":"` + sampleJWT + `"}`, + }) + if reason := ScanWorkspaceFilesByPromptName("show me the local override", []string{ws}); reason == "" { + t.Fatal("expected block: `local` is a name part of config.local.json") + } +} + +func TestScanWorkspaceFilesByPromptName_DotfileLeadingDotStripped(t *testing.T) { + // ".env" → name part "env"; prompt mentioning "env" as a token matches. 
+ ws := makeWorkspace(t, map[string]string{ + ".env": "TOKEN=" + sampleJWT, + }) + if reason := ScanWorkspaceFilesByPromptName("review env settings", []string{ws}); reason == "" { + t.Fatal("expected block: leading dot in .env should not prevent matching `env`") + } +} + +func TestScanWorkspaceFilesByPromptName_ExtensionAloneNotMatched(t *testing.T) { + // Generic extensions like "json" must not flag every json file in the repo + // — the trailing extension piece is dropped from filenameNameParts. + ws := makeWorkspace(t, map[string]string{ + "kedar.json": `{"jwt":"` + sampleJWT + `"}`, + }) + if reason := ScanWorkspaceFilesByPromptName("what is a json document", []string{ws}); reason != "" { + t.Fatalf("expected no block: extension `json` should not match by itself, got %q", reason) + } +} + +func TestExtractPromptTokens(t *testing.T) { + got := extractPromptTokens("check Kedar.json and id_rsa, also @secret-config!") + want := []string{"check", "kedar", "json", "and", "id_rsa", "also", "secret-config"} + for _, w := range want { + if _, ok := got[w]; !ok { + t.Errorf("missing token %q in %v", w, got) + } + } +} + +func TestFilenameNameParts(t *testing.T) { + cases := map[string][]string{ + "Kedar": {"kedar"}, + "kedar.json": {"kedar"}, + ".env": {"env"}, + ".env.local": {"env"}, + "config.local.json": {"config", "local"}, + "Makefile": {"makefile"}, + "id_rsa": {"id_rsa"}, + "": nil, + ".": nil, + } + for in, want := range cases { + got := filenameNameParts(in) + if len(got) != len(want) { + t.Errorf("filenameNameParts(%q) = %v; want %v", in, got, want) + continue + } + for i := range want { + if got[i] != want[i] { + t.Errorf("filenameNameParts(%q)[%d] = %q; want %q", in, i, got[i], want[i]) + } + } + } +} + +// -------------------------------------------------------------------------- +// ScanFileForSecrets — Cursor beforeReadFile content gate (Fix D) +// -------------------------------------------------------------------------- + +func 
TestScanFileForSecrets_BlocksOnJWT(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "Kedar.txt") + mustWrite(t, path, "token = "+sampleJWT) + + reason := ScanFileForSecrets(path) + if reason == "" { + t.Fatal("expected block: file contains a JWT") + } + if !strings.Contains(reason, "Kedar.txt") { + t.Fatalf("reason should cite the file path, got %q", reason) + } + if !strings.Contains(reason, "Do NOT attempt alternative commands") { + t.Fatalf("reason should include DenyMessage, got %q", reason) + } +} + +func TestScanFileForSecrets_CleanFile_NoBlock(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "notes.txt") + mustWrite(t, path, "no secrets here, just notes") + + if reason := ScanFileForSecrets(path); reason != "" { + t.Fatalf("expected no block for clean file, got %q", reason) + } +} + +func TestScanFileForSecrets_MissingFile_FailOpen(t *testing.T) { + if reason := ScanFileForSecrets(filepath.Join(t.TempDir(), "does-not-exist")); reason != "" { + t.Fatalf("expected fail-open for missing file, got %q", reason) + } +} + +func TestScanFileForSecrets_EmptyPath_NoOp(t *testing.T) { + if reason := ScanFileForSecrets(""); reason != "" { + t.Fatalf("expected no-op on empty path, got %q", reason) + } +} + +func TestScanFileForSecrets_OverPolicyCap_BlocksOnSize(t *testing.T) { + policy := HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = FilesLimits{Enabled: true, MaxFileSizeKB: 3} + defer writePolicyHelper(t, policy)() + + dir := t.TempDir() + path := filepath.Join(dir, "big.txt") + mustWrite(t, path, strings.Repeat("a", 5*1024)) // 5 KB, no secrets + + reason := ScanFileForSecrets(path) + if reason == "" { + t.Fatal("expected block: 5 KB file exceeds 3 KB policy cap") + } + if !strings.Contains(reason, "exceeds the policy size limit") { + t.Fatalf("reason should cite size violation, got %q", reason) + } +} + +func TestScanFileForSecrets_AtPolicyCap_Allowed(t *testing.T) 
{ + policy := HooksPolicy{} + policy.DefaultPolicy.ContextPolicy.Enabled = true + policy.DefaultPolicy.ContextPolicy.FilesLimits = FilesLimits{Enabled: true, MaxFileSizeKB: 3} + defer writePolicyHelper(t, policy)() + + dir := t.TempDir() + path := filepath.Join(dir, "exact.txt") + mustWrite(t, path, strings.Repeat("a", 3*1024)) // exactly at cap, no secrets + + if reason := ScanFileForSecrets(path); reason != "" { + t.Fatalf("expected no block at exactly the policy cap, got %q", reason) + } +} + +func TestScanWorkspaceFilesByPromptName_DenyMessageAppended(t *testing.T) { + ws := makeWorkspace(t, map[string]string{ + "Kedar": "token = " + sampleJWT, + }) + reason := ScanWorkspaceFilesByPromptName("check kedar file", []string{ws}) + if !strings.Contains(reason, "Do NOT attempt alternative commands") { + t.Fatalf("expected DenyMessage no-workaround text in reason, got %q", reason) + } +} diff --git a/internal/commands/agenthooks/guardrails/shell.go b/internal/commands/agenthooks/guardrails/shell.go new file mode 100644 index 000000000..6aede4950 --- /dev/null +++ b/internal/commands/agenthooks/guardrails/shell.go @@ -0,0 +1,238 @@ +package guardrails + +import ( + "fmt" + "path" + "strings" +) + +// CheckShellCommand checks a shell command against the blacklist and per-tool rules. +// +// Returns: +// - blocked=true, needsConfirm=false → deny the command +// - blocked=true, needsConfirm=true → ask the user for confirmation +// - blocked=false → permit the command +// +// Enforcement order (first hit wins): +// 1. Global blacklist_tools → hard block +// 2. args_exclude → hard block +// 3. Tool-level restricted_dirs / files → hard block +// 4. Global restricted_dirs / files → hard block +// 5. args_include whitelist → ask on first unmatched token +// 6. allowed_directories → ask if workDir not in effective list +// 7. 
allowed_files → ask if file arg not in effective list +func CheckShellCommand(command, workDir string) (blocked bool, needsConfirm bool, reason string) { + // 1. Global blacklist check. + blacklisted := LoadBlacklistedCommands() + cmdLower := strings.ToLower(command) + for name, tool := range blacklisted { + if strings.Contains(cmdLower, name) { + return true, false, fmt.Sprintf( + "Blocked by Checkmarx: command %q is not allowed.\nCategory: %s\nReason: %s%s", + name, tool.Category, tool.Risk, DenyMessage, + ) + } + } + + // 2. Per-tool rule enforcement. + rule := FindMatchingToolRule(command) + if rule == nil { + // No tool rule matched — still enforce global restricted paths. + return checkGlobalRestrictedPaths(command, workDir) + } + + // 2a. args_exclude — hard deny, takes precedence over everything else. + for _, excluded := range rule.ArgsExclude { + if strings.Contains(cmdLower, strings.ToLower(excluded)) { + return true, false, fmt.Sprintf( + "Blocked by Checkmarx: argument %q is not permitted for this tool.\nContact your administrator if you need this operation.%s", + excluded, DenyMessage, + ) + } + } + + // 2b. Tool-level restricted_directories / restricted_files (merged per strategy). + if blocked, needsConfirm, reason := checkToolRestrictedPaths(command, workDir, rule); blocked { + return blocked, needsConfirm, reason + } + + // 2c. args_include whitelist — any token not matched by an entry (exact or glob) + // triggers a confirmation request regardless of rule.Action. This is intentionally + // "ask" rather than "block" because an unknown arg is not necessarily dangerous; + // it is simply unknown to the policy author. 
+ if len(rule.ArgsInclude) > 0 { + tokens := strings.Fields(command) + if len(tokens) > 1 { // skip the command name itself (tokens[0]) + for _, tok := range tokens[1:] { + if !argMatchesAny(tok, rule.ArgsInclude) { + msg := fmt.Sprintf( + "Argument %q is not in the approved list for this tool.%s", + tok, DenyMessage, + ) + return true, true, msg // always ask on first unmatched token + } + } + } + } + + // 2d. Allowed-directories check — if a list is defined, workDir must be inside it. + // Unknown (not in list) → ask. + if workDir != "" { + _, globalDirs := LoadAllowedPaths() + effectiveDirs := ResolveAllowedPaths(globalDirs, GetOSPaths(rule.AllowedDirectories), rule.MergeStrategy.AllowedDirectories) + if len(effectiveDirs) > 0 && !PathUnderAny(workDir, effectiveDirs) { + return true, true, fmt.Sprintf( + "Working directory %q is not in the allowed list for this tool.%s", + workDir, DenyMessage, + ) + } + } + + // 2e. Allowed-files check — file-like tokens in the command must match at least + // one entry in the effective list. Matching supports literal paths, basenames, + // and doublestar globs (e.g. "**/pom.xml", "*.java"). Unknown → ask. + globalFiles, _ := LoadAllowedPaths() + effectiveFiles := ResolveAllowedPaths(globalFiles, GetOSPaths(rule.AllowedFiles), rule.MergeStrategy.AllowedFiles) + if len(effectiveFiles) > 0 { + tokens := strings.Fields(command) + if len(tokens) > 1 { + for _, token := range tokens[1:] { + // Only check tokens that look like file names (contain a path separator or a dot). 
+ if !strings.ContainsAny(token, "./\\") { + continue + } + if !anyPatternMatchesFile(effectiveFiles, token) { + return true, true, fmt.Sprintf( + "File %q is not in the allowed list for this tool.%s", + token, DenyMessage, + ) + } + } + } + } + + return false, false, "" +} + +// checkToolRestrictedPaths enforces restricted_directories and restricted_files for a +// matched tool rule, merging tool-level paths with the global lists per the rule's +// merge_strategy. Violations are always a hard block (needsConfirm=false). +func checkToolRestrictedPaths(command, workDir string, rule *ToolRule) (bool, bool, string) { + globalFiles, globalDirs := LoadRestrictedPaths() + + // Restricted directories: workDir must not fall under any effective restricted dir. + effectiveDirs := ResolveRestrictedPaths(globalDirs, GetOSPaths(rule.RestrictedDirectories), rule.MergeStrategy.RestrictedDirectories) + if workDir != "" && len(effectiveDirs) > 0 && PathUnderAny(workDir, effectiveDirs) { + return true, false, fmt.Sprintf( + "Blocked by Checkmarx: working directory %q is restricted by policy and not permitted for this tool.%s", + workDir, DenyMessage, + ) + } + + // Restricted files: no file arg may match an effective restricted file. + effectiveFiles := ResolveRestrictedPaths(globalFiles, GetOSPaths(rule.RestrictedFiles), rule.MergeStrategy.RestrictedFiles) + if len(effectiveFiles) > 0 { + if hit := findRestrictedFileInCommand(command, effectiveFiles); hit != "" { + return true, false, fmt.Sprintf( + "Blocked by Checkmarx: file %q is restricted by policy and not permitted for this tool.%s", + hit, DenyMessage, + ) + } + } + return false, false, "" +} + +// checkGlobalRestrictedPaths enforces the global restricted_directories and +// restricted_files when no tool rule matches the command. Always a hard block. 
+func checkGlobalRestrictedPaths(command, workDir string) (bool, bool, string) { + globalFiles, globalDirs := LoadRestrictedPaths() + if len(globalFiles) == 0 && len(globalDirs) == 0 { + return false, false, "" + } + + if workDir != "" && len(globalDirs) > 0 && PathUnderAny(workDir, globalDirs) { + return true, false, fmt.Sprintf( + "Blocked by Checkmarx: working directory %q is restricted by policy.%s", + workDir, DenyMessage, + ) + } + if hit := findRestrictedFileInCommand(command, globalFiles); hit != "" { + return true, false, fmt.Sprintf( + "Blocked by Checkmarx: file %q is restricted by policy.%s", + hit, DenyMessage, + ) + } + return false, false, "" +} + +// findRestrictedFileInCommand returns the first token in the command (skipping +// the command name itself) that matches any entry in restrictedFiles. +// Patterns may be literal paths, basenames, or doublestar globs (e.g. "**/*.pem"). +// Returns "" when no token matches. +// +// Two passes: +// 1. Path-shaped tokens (containing ./\) match against any policy entry via +// matchFilePattern (literal, basename, suffix, or doublestar glob). +// 2. Bare-word tokens match against non-glob (literal) policy entries by +// case-insensitive equality. This catches cases like `cat kubeconfig` +// where the file argument has no path separator or extension. 
+func findRestrictedFileInCommand(command string, restrictedFiles []string) string { + tokens := strings.Fields(command) + if len(tokens) <= 1 { + return "" + } + for _, token := range tokens[1:] { + if !strings.ContainsAny(token, "./\\") { + continue + } + for _, rf := range restrictedFiles { + if matchFilePattern(rf, token) { + return token + } + } + } + literalAnchors := extractLiteralAnchors(restrictedFiles) + if len(literalAnchors) == 0 { + return "" + } + for _, token := range tokens[1:] { + if strings.ContainsAny(token, "./\\") { + continue + } + for _, a := range literalAnchors { + if strings.EqualFold(token, a) { + return token + } + } + } + return "" +} + +// argMatchesAny returns true when arg matches at least one entry in patterns +// using case-insensitive exact match or path.Match glob syntax (e.g. "-D*", "--*"). +func argMatchesAny(arg string, patterns []string) bool { + argLower := strings.ToLower(arg) + for _, p := range patterns { + pLower := strings.ToLower(p) + if argLower == pLower { + return true + } + if matched, _ := path.Match(pLower, argLower); matched { + return true + } + } + return false +} + +// PathUnderAny returns true when path falls within at least one of the candidate +// directories. Directory patterns may be literal paths or doublestar globs +// (e.g. "/home/*/.ssh", "**/secrets/**"); in the glob case a target matches +// when it equals the glob-matched directory or is nested under it. 
+func PathUnderAny(path string, dirs []string) bool { + for _, d := range dirs { + if matchDirContains(d, path) { + return true + } + } + return false +} diff --git a/internal/commands/agenthooks/mcp/bridge.go b/internal/commands/agenthooks/mcp/bridge.go new file mode 100644 index 000000000..b6ebe39e8 --- /dev/null +++ b/internal/commands/agenthooks/mcp/bridge.go @@ -0,0 +1,750 @@ +package mcp + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/checkmarx/ast-cli/internal/logger" + commonParams "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/configuration" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// The bridge is a transparent stdio<->HTTP proxy to the remote Checkmarx Security +// MCP. It exists so the Claude Code plugin's .mcp.json can launch the remediation +// MCP with `command: "cx", args: ["mcp", "bridge"]` — using cx itself, the one +// binary guaranteed present, instead of bash/node/python (none of which are +// guaranteed across Windows/macOS/Linux or on a native, Bun-based Claude install). +// +// It reads the credential cx already resolved (env CX_APIKEY / cx config, loaded +// at startup), derives the realm-scoped URL from the credential's JWT `iss` +// claim, and forwards newline-delimited JSON-RPC between stdin/stdout and the +// remote MCP's Streamable HTTP endpoint (application/json + text/event-stream, +// Mcp-Session-Id). The credential is sent ONLY in the Authorization header (the +// server exchanges it; no client-side OAuth flow runs here) and is never written +// to stdout/stderr. +// +// Resilient/self-healing: if no usable credential exists at startup (e.g. the +// developer hasn't logged in yet, or the stored token is expired/malformed in a way +// that prevents URL derivation), the bridge does NOT exit. 
It answers the MCP +// handshake LOCALLY so the client shows the server "Connected" with an empty tool +// list (capabilities.tools.listChanged=true), and a background watcher polls cx +// config on disk; the moment `cx auth login` writes a usable credential, the bridge +// establishes the remote session and pushes notifications/tools/list_changed so the +// client auto-fetches the real tools — no /reload-plugins, no /mcp reconnect. +const ( + securityMCPPathPrefix = "/api/security-mcp/mcp/" + bridgeRequestTimeout = 120 * time.Second + jsonrpcInternalError = -32000 // JSON-RPC reserved server-error code + httpAccepted = 202 // POST of a notification/response: no body to relay +) + +// bridgeState is the bridge's connection lifecycle: stateUnauth answers the MCP +// handshake locally and runs a credential watcher; stateConnected proxies to the +// remote (the only state the authed-at-startup path is ever in). +type bridgeState int + +const ( + stateUnauth bridgeState = iota + stateConnected +) + +// supportedProtocolVersions are the MCP protocol versions the local handshake can +// advertise. The first is the preferred default when a client omits one. +var supportedProtocolVersions = []string{"2025-06-18", "2024-11-05"} + +func defaultProtocolVersion() string { return supportedProtocolVersions[0] } + +// bridgeSession holds the MCP session state. Fields touched by both the read loop +// and the watcher goroutine (state/apiKey/mcpURL/clientProto/remoteReady) are +// guarded by mu; id/proto are mutated only after promotion (single-threaded). 
type bridgeSession struct {
	mu          sync.Mutex
	state       bridgeState
	apiKey      string // raw credential forwarded to the remote (Authorization)
	mcpURL      string // realm-scoped Security MCP URL
	clientProto string // protocolVersion the client requested at initialize
	remoteReady bool   // the remote initialize handshake has completed
	version     string // cx binary version, for the synthetic serverInfo
	urlOverride string // value of the --mcp-url flag (the CX_MCP_URL env var is read separately by deriveMCPURL); immutable after construction
	writer      *syncWriter

	// resolveKey returns the credential cx resolved. Injected (not a package
	// global) so the concurrent self-heal test can stub it without racing viper.
	resolveKey func() string

	id    string // Mcp-Session-Id, echoed back on every subsequent request
	proto string // negotiated protocolVersion, sent as MCP-Protocol-Version
}

// syncWriter serializes all stdout writes so a full JSON-RPC line is written+flushed
// atomically — the read loop and the watcher goroutine share it, and the MCP stdio
// transport forbids interleaved/partial lines.
type syncWriter struct {
	mu sync.Mutex    // held for the whole write+newline+flush of one line
	w  *bufio.Writer // buffered view of the underlying output stream
}

// newSyncWriter wraps w in a buffered, mutex-guarded line writer.
func newSyncWriter(w io.Writer) *syncWriter { return &syncWriter{w: bufio.NewWriter(w)} }

// emitLine writes raw followed by a single '\n' and flushes, all under the
// writer's mutex so concurrent callers can never interleave partial lines.
// Write and flush errors are deliberately discarded; the method has no error
// return.
func (sw *syncWriter) emitLine(raw []byte) {
	sw.mu.Lock()
	defer sw.mu.Unlock()
	_, _ = sw.w.Write(raw)
	_ = sw.w.WriteByte('\n')
	_ = sw.w.Flush()
}

// mcpURLFlag is the name of the command-line flag that overrides the Security
// MCP URL.
const mcpURLFlag = "mcp-url"

// Test seams. getAccessToken stubs the (network) refresh_token->access_token
// exchange. reloadConfig re-reads cx config FROM DISK into viper — essential for
// self-heal, since viper is a one-shot startup snapshot and would otherwise never
// see a credential written by a later `cx auth login`. invalidateTokenCache forces
// a fresh access-token exchange. credentialPollInterval paces the watcher.
+var ( + getAccessToken = wrappers.GetAccessToken + + reloadConfig = func() { + // viper is not concurrency-safe; serialize disk re-reads (which mutate + // viper) against the credential resolver's viper reads. Both the read + // loop (on 401/403) and the watcher goroutine call this. + configMu.Lock() + defer configMu.Unlock() + _ = configuration.LoadConfiguration() + wrappers.LoadActiveCredential() + } + invalidateTokenCache = wrappers.InvalidateAccessTokenCache + credentialPollInterval = 3 * time.Second +) + +// configMu serializes all viper access performed by the bridge's two goroutines +// (the stdin read loop and the credential watcher): reloadConfig mutates viper, +// productionResolveAPIKey reads it. viper itself has no internal locking. +var configMu sync.Mutex + +// productionResolveAPIKey returns the credential cx resolved (CX_APIKEY env / cx +// config / active session), falling back to CHECKMARX_API_KEY for parity with the +// previous Python bridge. Callers that need a credential written AFTER startup must +// call reloadConfig() first (viper is a one-shot startup snapshot). The viper read +// is guarded by configMu so it never races a concurrent reloadConfig. +func productionResolveAPIKey() string { + configMu.Lock() + k := strings.TrimSpace(viper.GetString(commonParams.AstAPIKey)) + configMu.Unlock() + if k != "" { + return k + } + if k := strings.TrimSpace(os.Getenv("CHECKMARX_API_KEY")); k != "" { + return k + } + return "" +} + +// NewBridgeCommand creates the hidden "cx mcp bridge" subcommand. version is the cx +// binary version, surfaced in the synthetic serverInfo during the degraded handshake. +func NewBridgeCommand(version string) *cobra.Command { + cmd := &cobra.Command{ + Use: "bridge", + Short: "Proxy stdio MCP traffic to the Checkmarx Security remediation MCP", + Long: `Run a stdio<->HTTP bridge to the remote Checkmarx Security MCP. 
+ +Intended to be launched by an AI coding assistant as an MCP server: + + { + "mcpServers": { + "Checkmarx": { "command": "cx", "args": ["mcp", "bridge"] } + } + } + +The credential is read from cx config (or CX_APIKEY). The realm-scoped URL is +resolved by, in order: the --mcp-url flag, the CX_MCP_URL env var, the +authoritative "ast-base-url" claim from the exchanged access token (works for +any region/on-prem), then an offline IAM->AST host swap. Override with --mcp-url +(preferred for MCP clients that pass args but not env) or CX_MCP_URL only for +air-gapped setups where the token endpoint is unreachable at startup. + +If no usable credential exists at startup the bridge stays up in a degraded state +and connects automatically once you run 'cx auth login' — no restart needed.`, + Hidden: true, // internal plumbing for the Claude Code / IDE plugins + RunE: func(cmd *cobra.Command, _ []string) error { + urlOverride, _ := cmd.Flags().GetString(mcpURLFlag) + return runBridge(version, urlOverride) + }, + } + cmd.Flags().String(mcpURLFlag, "", "Override the Security MCP URL (dev/on-prem where it can't be derived from the credential)") + return cmd +} + +// newBridgeClient builds the HTTP client used to reach the remote Security MCP. When a cx proxy is +// configured — the config-file cx_http_proxy or the HTTP_PROXY/CX_HTTP_PROXY env, including NTLM / +// Kerberos auth types — it reuses the CLI's proxy-aware client (wrappers.GetClient), so a single +// `cx configure`-based proxy setup covers both the CLI and the MCP. Otherwise it keeps the default +// transport, which already honors HTTPS_PROXY / HTTP_PROXY / NO_PROXY — so no existing env-var proxy +// behavior is lost. 
func newBridgeClient() *http.Client {
	// A configured cx proxy takes the CLI's proxy-aware client; otherwise a plain
	// client with the bridge timeout is used.
	if strings.TrimSpace(viper.GetString(commonParams.ProxyKey)) != "" {
		c := wrappers.GetClient(uint(bridgeRequestTimeout / time.Second))
		c.Timeout = bridgeRequestTimeout // preserve the exact bridge timeout
		return c
	}
	return &http.Client{Timeout: bridgeRequestTimeout}
}

// runBridge is the production entry point: it runs the bridge over the process's
// stdin/stdout with the default HTTP client and the production credential
// resolver.
func runBridge(version, urlOverride string) error {
	return runBridgeIO(os.Stdin, os.Stdout, newBridgeClient(), version, urlOverride, productionResolveAPIKey)
}

// runBridgeIO is the testable core: it wires the session to the given streams,
// decides the startup state, runs the watcher when degraded, and pumps the stdin
// read loop. It never exits the process on a missing credential. resolveKey is the
// credential resolver, injected so tests can stub it without racing viper.
//
// Each non-blank input line is handed to the session's dispatch method. The loop
// ends on the first read error and the function then returns nil.
// NOTE(review): a read error other than io.EOF also ends the loop with a nil
// return — confirm that swallowing it is intended.
func runBridgeIO(in io.Reader, out io.Writer, client *http.Client, version, urlOverride string, resolveKey func() string) error {
	sess := &bridgeSession{writer: newSyncWriter(out), version: version, resolveKey: resolveKey, urlOverride: urlOverride}

	// Start connected only when there is both a credential and a derivable URL;
	// an override URL without a credential still starts degraded.
	apiKey := sess.resolveKey()
	mcpURL, err := deriveMCPURL(apiKey, urlOverride)
	if apiKey != "" && err == nil {
		sess.state = stateConnected
		sess.apiKey = apiKey
		sess.mcpURL = mcpURL
	} else {
		sess.state = stateUnauth
		// Degraded notice goes to STDERR only (stdout is the protocol channel).
		fmt.Fprintln(os.Stderr, "cx mcp bridge: no usable Checkmarx credential yet — serving in a degraded state. "+
			"Log in with 'cx auth login' (or /cx-cli-setup) using the default (yaml) or '--session global' mode "+
			"(NOT '--session local', which this process can't see); Checkmarx tools appear automatically once authenticated. "+
			"For on-prem/custom domains, set CX_MCP_URL or pass --mcp-url.")
		stop := make(chan struct{})
		var wg sync.WaitGroup
		wg.Add(1)
		go func() { defer wg.Done(); sess.watchForCredential(client, urlOverride, stop) }()
		// Clean shutdown on EOF: stop the watcher and wait for it to fully exit before
		// returning, so no goroutine outlives the bridge or touches config concurrently.
		defer func() { close(stop); wg.Wait() }()
	}

	reader := bufio.NewReader(in)
	for {
		line, readErr := reader.ReadString('\n')
		// A final line without a trailing newline arrives together with the read
		// error, so dispatch before checking the error.
		if body := bytes.TrimSpace([]byte(line)); len(body) > 0 {
			sess.dispatch(client, body)
		}
		if readErr != nil {
			break // EOF — the client closed the connection
		}
	}
	return nil
}

// deriveMCPURL builds the realm-scoped Security MCP URL, region/tenant/on-prem
// agnostic. Resolution order (top wins):
//  1. the --mcp-url flag (explicit override),
//  2. the CX_MCP_URL env var (explicit override),
//  3. the authoritative "ast-base-url" claim from the exchanged ACCESS token —
//     the AST base the IAM server itself issued (works for any region/on-prem),
//  4. an offline IAM->AST host swap from the credential's `iss` claim, for
//     standard cloud regions when the token exchange is unavailable.
//
// The realm (tenant) always comes from the stored credential's `iss` claim and is
// independent of how the AST base host is resolved.
//
// An explicit override is returned as-is (trailing slashes trimmed) without
// consulting the credential. Otherwise an empty apiKey, an unreadable `iss`
// claim, an underivable realm, or an unmappable IAM host yields an error.
func deriveMCPURL(apiKey, override string) (string, error) {
	override = strings.TrimSpace(override)
	if override == "" {
		override = strings.TrimSpace(os.Getenv("CX_MCP_URL"))
	}
	if override != "" {
		return strings.TrimRight(override, "/"), nil
	}
	if apiKey == "" {
		return "", errors.New("no API key")
	}

	issuer, err := wrappers.ExtractFromTokenClaims(apiKey, "iss")
	if err != nil {
		return "", err
	}
	realm, err := realmFromIssuer(issuer)
	if err != nil {
		return "", err
	}

	// 3. Authoritative: AST base from the access token's ast-base-url claim. On any
	// failure (token endpoint unreachable, claim absent) fall through to the swap.
	if base := astBaseFromAccessToken(); base != "" {
		return joinSecurityMCPURL(base, realm), nil
	}

	// 4. Offline fallback for standard cloud regions.
	astBase, err := astBaseFromIAMHost(issuer)
	if err != nil {
		return "", err
	}
	return joinSecurityMCPURL(astBase, realm), nil
}

// astBaseFromAccessToken exchanges the stored refresh token for an access token
// (cached, non-interactive grant_type=refresh_token) and reads the authoritative
// "ast-base-url" claim from it — the claim is present only on the access token,
// never on the stored refresh token. Returns "" (logging at verbose) on any
// failure so the caller can fall back to the offline host swap. The access token
// is used ONLY for URL discovery here; it is never sent to the MCP server.
func astBaseFromAccessToken() string {
	// getAccessToken is a package-level seam so tests can stub the exchange.
	accessToken, err := getAccessToken()
	if err != nil {
		logger.PrintIfVerbose("cx mcp bridge: access-token exchange failed, falling back to IAM->AST host swap: " + err.Error())
		return ""
	}
	base, err := wrappers.ExtractFromTokenClaims(accessToken, wrappers.BaseURLKey)
	if err != nil {
		logger.PrintIfVerbose("cx mcp bridge: ast-base-url claim unavailable, falling back to IAM->AST host swap: " + err.Error())
		return ""
	}
	return strings.TrimSpace(base)
}

// buildSecurityMCPURL maps a JWT issuer (e.g. https://eu.iam.checkmarx.net/auth/realms/<tenant>)
// to the realm-scoped Security MCP URL via the offline IAM->AST host swap
// (https://eu.ast.checkmarx.net/api/security-mcp/mcp/<tenant>).
+func buildSecurityMCPURL(issuer string) (string, error) { + astBase, err := astBaseFromIAMHost(issuer) + if err != nil { + return "", err + } + realm, err := realmFromIssuer(issuer) + if err != nil { + return "", err + } + return joinSecurityMCPURL(astBase, realm), nil +} + +// astBaseFromIAMHost maps a JWT issuer's IAM host to the AST base URL for standard +// cloud regions (.iam -> .ast, or iam. -> ast.). Dev/on-prem/custom +// hosts are not mappable and return an error (use ast-base-url or CX_MCP_URL). +func astBaseFromIAMHost(issuer string) (string, error) { + parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(issuer), "/")) + if err != nil || parsed.Host == "" { + return "", errors.New("could not parse issuer host from API key") + } + switch host := parsed.Host; { + case strings.Contains(host, ".iam."): + return "https://" + strings.Replace(host, ".iam.", ".ast.", 1), nil + case strings.HasPrefix(host, "iam."): + return "https://ast." + strings.TrimPrefix(host, "iam."), nil + default: + return "", errors.New("could not map IAM host to AST host; set CX_MCP_URL for on-prem/custom domains") + } +} + +// realmFromIssuer extracts the realm (tenant) from a JWT issuer URL whose path ends +// in .../auth/realms/. +func realmFromIssuer(issuer string) (string, error) { + parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(issuer), "/")) + if err != nil { + return "", errors.New("could not parse issuer from API key") + } + segments := strings.Split(strings.Trim(parsed.Path, "/"), "/") + realm := segments[len(segments)-1] + if realm == "" { + return "", errors.New("could not derive realm from issuer") + } + return realm, nil +} + +// joinSecurityMCPURL composes the realm-scoped Security MCP endpoint from an AST +// base URL and a realm. +func joinSecurityMCPURL(base, realm string) string { + return strings.TrimRight(strings.TrimSpace(base), "/") + securityMCPPathPrefix + realm +} + +// dispatch handles one inbound JSON-RPC line. 
When unauthenticated it answers the +// MCP handshake locally; when connected it proxies to the remote (unchanged). +func (s *bridgeSession) dispatch(client *http.Client, body []byte) { + s.mu.Lock() + state := s.state + apiKey := s.apiKey + mcpURL := s.mcpURL + s.mu.Unlock() + + if state == stateUnauth { + s.dispatchLocal(body) + return + } + s.proxy(client, mcpURL, apiKey, body) +} + +// dispatchLocal answers the MCP handshake without any network access so the client +// shows the server Connected while we wait for a credential. initialize advertises +// tools.listChanged=true (the contract that lets the client accept an empty list now +// and re-fetch after notifications/tools/list_changed); tools/list returns empty; +// ping is answered; any other request returns a clear "not authenticated" error. +func (s *bridgeSession) dispatchLocal(body []byte) { + var req struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params struct { + ProtocolVersion string `json:"protocolVersion"` + } `json:"params"` + } + if err := json.Unmarshal(body, &req); err != nil { + return + } + switch req.Method { + case "initialize": + s.emitLocalInitialize(req.ID, req.Params.ProtocolVersion) + case "notifications/initialized": + // client readiness signal — no response + case "tools/list": + s.emitResult(req.ID, map[string]interface{}{"tools": []interface{}{}}) + case "ping": + s.emitResult(req.ID, map[string]interface{}{}) + default: + if len(req.ID) > 0 { + s.writeError(body, "not authenticated yet — run 'cx auth login' (or /cx-cli-setup); Checkmarx tools enable automatically once authenticated") + } + } +} + +// emitLocalInitialize answers initialize locally, echoing the client's requested +// protocolVersion (or a supported default) and remembering it for the eventual +// remote handshake. 
+func (s *bridgeSession) emitLocalInitialize(id json.RawMessage, clientProto string) { + proto := strings.TrimSpace(clientProto) + if proto == "" { + proto = defaultProtocolVersion() + } + s.mu.Lock() + s.clientProto = proto + s.mu.Unlock() + s.emitResult(id, map[string]interface{}{ + "protocolVersion": proto, + "capabilities": map[string]interface{}{"tools": map[string]interface{}{"listChanged": true}}, + "serverInfo": map[string]interface{}{"name": "Checkmarx Security", "version": s.version}, + "instructions": "Checkmarx Security is initializing — its tools appear once you authenticate (run 'cx auth login' or /cx-cli-setup).", + }) +} + +// emitResult writes a JSON-RPC result for the given id. +func (s *bridgeSession) emitResult(id json.RawMessage, result interface{}) { + reply, err := json.Marshal(map[string]interface{}{"jsonrpc": "2.0", "id": id, "result": result}) + if err != nil { + return + } + s.writer.emitLine(reply) +} + +// proxy forwards one JSON-RPC request to the remote and relays the response. On +// 401/403 it re-reads cx config from disk and retries once with the refreshed +// credential — a rotated `cx auth login` token self-heals without a restart, while a +// dead key fails fast instead of looping. +func (s *bridgeSession) proxy(client *http.Client, mcpURL, apiKey string, body []byte) { + resp, err := s.post(client, mcpURL, apiKey, body) + if err != nil { + fmt.Fprintf(os.Stderr, "cx mcp bridge: request failed: %v\n", err) + s.writeError(body, "") + return + } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + resp.Body.Close() + reloadConfig() // re-read disk so a token rotated by another process is visible + reloaded := s.resolveKey() + if reloaded != "" && reloaded != apiKey { + // A rotated credential may target a different tenant/realm, so + // re-derive the realm-scoped URL rather than replaying the stale + // one. 
Invalidate the cached access token first (as tryHeal does) + // so the ast-base-url claim is re-fetched for the new tenant instead + // of reusing the old tenant's cached base. If the URL actually + // changes, the previous Mcp-Session-Id is scoped to the old tenant + // and must not be reused — drop it so the remote issues a fresh + // session for the new tenant. + invalidateTokenCache() + retryURL := mcpURL + if derived, derr := deriveMCPURL(reloaded, s.urlOverride); derr == nil && derived != "" { + retryURL = derived + } + s.mu.Lock() + s.apiKey = reloaded + if retryURL != s.mcpURL { + s.mcpURL = retryURL + s.id = "" + } + s.mu.Unlock() + retry, retryErr := s.post(client, retryURL, reloaded, body) + if retryErr != nil { + fmt.Fprintf(os.Stderr, "cx mcp bridge: retry after credential reload failed: %v\n", retryErr) + s.writeError(body, "") + return + } + s.finish(retry, body) + return + } + fmt.Fprintf(os.Stderr, "cx mcp bridge: HTTP %d from MCP endpoint (no fresh credential to retry with)\n", resp.StatusCode) + s.writeError(body, fmt.Sprintf("HTTP %d (auth failed — run /cx-cli-setup to re-authenticate)", resp.StatusCode)) + return + } + + s.finish(resp, body) +} + +func (s *bridgeSession) post(client *http.Client, mcpURL, apiKey string, body []byte) (*http.Response, error) { + req, err := http.NewRequest(http.MethodPost, mcpURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Accept-Encoding", "identity") + // Forward the RAW stored credential (API key / refresh token); the server + // exchanges it server-side. Never send the access token fetched in + // deriveMCPURL — that is for URL discovery only. 
+ req.Header.Set("Authorization", apiKey) + if s.id != "" { + req.Header.Set("Mcp-Session-Id", s.id) + } + if s.proto != "" { + req.Header.Set("MCP-Protocol-Version", s.proto) + } + return client.Do(req) +} + +func (s *bridgeSession) finish(resp *http.Response, body []byte) { + if resp.StatusCode >= 400 { + fmt.Fprintf(os.Stderr, "cx mcp bridge: HTTP %d from MCP endpoint\n", resp.StatusCode) + s.writeError(body, fmt.Sprintf("HTTP %d", resp.StatusCode)) + resp.Body.Close() + return + } + s.handleResponse(resp) +} + +func (s *bridgeSession) handleResponse(resp *http.Response) { + defer resp.Body.Close() + if sid := resp.Header.Get("Mcp-Session-Id"); sid != "" { + s.id = sid + } + if resp.StatusCode == httpAccepted { + _, _ = io.Copy(io.Discard, resp.Body) + return + } + if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") { + s.pumpSSE(resp.Body) + return + } + raw, _ := io.ReadAll(resp.Body) + s.emit(raw) +} + +// pumpSSE parses an SSE stream and emits each buffered `data:` payload as a single +// JSON-RPC line until the stream ends. 
+func (s *bridgeSession) pumpSSE(body io.Reader) { + reader := bufio.NewReader(body) + var dataLines []string + flush := func() { + if len(dataLines) > 0 { + s.emit([]byte(strings.Join(dataLines, "\n"))) + dataLines = dataLines[:0] + } + } + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + switch trimmed := strings.TrimRight(line, "\r\n"); { + case trimmed == "": // blank line dispatches the buffered event + flush() + case strings.HasPrefix(trimmed, ":"): // SSE comment / keep-alive + case strings.HasPrefix(trimmed, "data:"): + dataLines = append(dataLines, strings.TrimLeft(strings.TrimPrefix(trimmed, "data:"), " \t")) + default: // event:/id:/retry: are not needed for JSON-RPC transport + } + } + if err != nil { + break + } + } + flush() // flush a trailing event with no terminating blank line +} + +// emit writes one validated JSON-RPC message to stdout as a single line, capturing +// the negotiated protocol version from an initialize result. stdout is the MCP +// channel, so non-JSON payloads are dropped rather than corrupting the stream. +func (s *bridgeSession) emit(raw []byte) { + raw = bytes.TrimSpace(raw) + if len(raw) == 0 || !json.Valid(raw) { + return + } + var probe struct { + Result struct { + ProtocolVersion string `json:"protocolVersion"` + } `json:"result"` + } + if err := json.Unmarshal(raw, &probe); err == nil && probe.Result.ProtocolVersion != "" { + s.proto = probe.Result.ProtocolVersion + } + s.writer.emitLine(raw) +} + +// writeError surfaces a JSON-RPC error for a request id so the client never hangs on +// a failed call. Notifications and unparseable/batch lines have no single id to +// reply to and are skipped. 
+func (s *bridgeSession) writeError(requestLine []byte, detail string) { + var req struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + } + if err := json.Unmarshal(requestLine, &req); err != nil || len(req.ID) == 0 || req.Method == "" { + return + } + message := "Checkmarx MCP request failed" + if detail != "" { + message = "Checkmarx MCP " + detail + } + reply, err := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": req.ID, + "error": map[string]interface{}{"code": jsonrpcInternalError, "message": message}, + }) + if err != nil { + return + } + s.writer.emitLine(reply) +} + +// watchForCredential polls cx config on disk while the bridge is degraded. As soon +// as a usable credential appears (e.g. the developer ran `cx auth login`), it +// establishes the remote session, flips to connected, and pushes a +// notifications/tools/list_changed so the client auto-fetches the real tools — then +// the watcher exits. It also exits on EOF (stop closed). +func (s *bridgeSession) watchForCredential(client *http.Client, urlOverride string, stop <-chan struct{}) { + ticker := time.NewTicker(credentialPollInterval) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + if s.tryHeal(client, urlOverride) { + return + } + } + } +} + +// tryHeal attempts one self-heal cycle. Returns true once the bridge is connected. 
+func (s *bridgeSession) tryHeal(client *http.Client, urlOverride string) bool { + if s.isConnected() { + return true + } + reloadConfig() // the definitive disk re-read; viper alone is a stale startup snapshot + key := s.resolveKey() + if key == "" { + return false // still no credential — stay degraded + } + invalidateTokenCache() // a fresh login may target a different tenant + mcpURL, err := deriveMCPURL(key, urlOverride) + if err != nil || mcpURL == "" { + return false + } + if !s.establishRemoteSession(client, mcpURL, key) { + return false // remote not reachable / credential not yet valid — retry next tick + } + s.promote(key, mcpURL) + s.notifyToolsChanged() + return true +} + +// establishRemoteSession performs the remote MCP initialize handshake on the +// bridge's behalf (the client's initialize was answered locally, so the remote never +// saw it). It sends a MINIMAL synthesized initialize (clientInfo=cx-bridge, no +// client capabilities), captures the Mcp-Session-Id + negotiated protocolVersion, +// drives notifications/initialized, and DISCARDS the response body (the client +// already received its handshake result). Returns false if the remote is unreachable +// or rejects the credential, so the watcher retries. 
+func (s *bridgeSession) establishRemoteSession(client *http.Client, mcpURL, apiKey string) bool { + s.mu.Lock() + proto := s.clientProto + s.mu.Unlock() + if proto == "" { + proto = defaultProtocolVersion() + } + + initReq, err := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": proto, + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{"name": "cx-bridge", "version": s.version}, + }, + }) + if err != nil { + return false + } + + resp, err := s.post(client, mcpURL, apiKey, initReq) + if err != nil { + return false + } + sid := resp.Header.Get("Mcp-Session-Id") + code := resp.StatusCode + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if code >= 400 { + return false + } + + // Session id + the version we negotiated drive every subsequent proxied request. + s.id = sid + s.proto = proto + + // Complete the remote lifecycle so it accepts tools/list and tool calls. + notif, _ := json.Marshal(map[string]interface{}{"jsonrpc": "2.0", "method": "notifications/initialized"}) + if nresp, nerr := s.post(client, mcpURL, apiKey, notif); nerr == nil { + _, _ = io.Copy(io.Discard, nresp.Body) + nresp.Body.Close() + } + // remoteReady is shared with the read loop; guard the write with mu (the + // struct contract lists it among the mutex-protected fields). + s.mu.Lock() + s.remoteReady = true + s.mu.Unlock() + return true +} + +// promote publishes the resolved credential/URL and flips the bridge to connected. 
+func (s *bridgeSession) promote(key, mcpURL string) { + s.mu.Lock() + s.apiKey = key + s.mcpURL = mcpURL + s.state = stateConnected + s.mu.Unlock() +} + +func (s *bridgeSession) isConnected() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.state == stateConnected +} + +// notifyToolsChanged pushes an unsolicited notifications/tools/list_changed so the +// client re-fetches tools/list (now served by the connected remote) with no reload. +func (s *bridgeSession) notifyToolsChanged() { + notif, err := json.Marshal(map[string]interface{}{"jsonrpc": "2.0", "method": "notifications/tools/list_changed"}) + if err != nil { + return + } + s.writer.emitLine(notif) +} diff --git a/internal/commands/agenthooks/mcp/bridge_test.go b/internal/commands/agenthooks/mcp/bridge_test.go new file mode 100644 index 000000000..62b7a11b0 --- /dev/null +++ b/internal/commands/agenthooks/mcp/bridge_test.go @@ -0,0 +1,644 @@ +//go:build !integration + +package mcp + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" + "time" + + commonParams "github.com/checkmarx/ast-cli/internal/params" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func TestNewBridgeClient(t *testing.T) { + t.Run("no proxy configured keeps the default transport", func(t *testing.T) { + viper.Set(commonParams.ProxyKey, "") + c := newBridgeClient() + assert.Equal(t, bridgeRequestTimeout, c.Timeout) + // nil transport → Go's default, which honors HTTPS_PROXY / HTTP_PROXY / NO_PROXY at request time. 
+ assert.Nil(t, c.Transport) + }) + + t.Run("configured proxy routes through the proxy-aware client", func(t *testing.T) { + viper.Set(commonParams.ProxyKey, "http://proxy.corp:8080") + defer viper.Set(commonParams.ProxyKey, "") + c := newBridgeClient() + assert.Equal(t, bridgeRequestTimeout, c.Timeout) + tr, ok := c.Transport.(*http.Transport) + assert.True(t, ok, "expected a proxy-aware *http.Transport") + assert.NotNil(t, tr.Proxy, "expected a proxy resolver") + req, err := http.NewRequest(http.MethodGet, "https://mcp.example.com", nil) + assert.NoError(t, err) + proxyURL, err := tr.Proxy(req) + assert.NoError(t, err) + assert.NotNil(t, proxyURL) + assert.Equal(t, "http://proxy.corp:8080", proxyURL.String()) + }) +} + +func TestBuildSecurityMCPURL(t *testing.T) { + tests := []struct { + name string + issuer string + want string + wantErr bool + }{ + { + name: "regional iam host maps to ast", + issuer: "https://eu.iam.checkmarx.net/auth/realms/cx_seg", + want: "https://eu.ast.checkmarx.net/api/security-mcp/mcp/cx_seg", + }, + { + name: "us no-prefix iam host maps to ast", + issuer: "https://iam.checkmarx.net/auth/realms/myorg", + want: "https://ast.checkmarx.net/api/security-mcp/mcp/myorg", + }, + { + name: "trailing slash tolerated", + issuer: "https://deu.iam.checkmarx.net/auth/realms/tenant1/", + want: "https://deu.ast.checkmarx.net/api/security-mcp/mcp/tenant1", + }, + { + name: "non-iam host cannot be mapped (use CX_MCP_URL)", + issuer: "https://example.com/auth/realms/x", + wantErr: true, + }, + { + // Dev/on-prem hosts like iam-dev.dev.cxast.net do NOT follow the + // .iam.checkmarx.net pattern, so derivation must fail and the + // caller is expected to set CX_MCP_URL (see TestDeriveMCPURL_CXMCPURLOverride). 
+ name: "dev host is not auto-mappable", + issuer: "https://iam-dev.dev.cxast.net/auth/realms/dev_tenant", + wantErr: true, + }, + { + name: "missing realm segment", + issuer: "https://eu.iam.checkmarx.net", + wantErr: true, + }, + { + name: "empty issuer", + issuer: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := buildSecurityMCPURL(tt.issuer) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +// TestDeriveMCPURL_FlagOverride covers the --mcp-url flag (preferred for MCP +// clients that pass args but not env), which must win even over CX_MCP_URL. +func TestDeriveMCPURL_FlagOverride(t *testing.T) { + t.Setenv("CX_MCP_URL", "https://from-env.example.com/api/security-mcp/mcp/x") + got, err := deriveMCPURL("ignored-because-override-set", + "https://ast-master-components.dev.cxast.net/api/security-mcp/mcp/dev_tenant/") + assert.NoError(t, err) + assert.Equal(t, "https://ast-master-components.dev.cxast.net/api/security-mcp/mcp/dev_tenant", got) +} + +// TestDeriveMCPURL_CXMCPURLOverride covers the env escape hatch used for dev/on-prem +// environments (e.g. ast-master-components.dev.cxast.net / dev_tenant) whose host +// naming the iam->ast mapping cannot derive. +func TestDeriveMCPURL_CXMCPURLOverride(t *testing.T) { + t.Setenv("CX_MCP_URL", "https://ast-master-components.dev.cxast.net/api/security-mcp/mcp/dev_tenant/") + got, err := deriveMCPURL("ignored-because-override-set", "") + assert.NoError(t, err) + assert.Equal(t, "https://ast-master-components.dev.cxast.net/api/security-mcp/mcp/dev_tenant", got) +} + +func TestDeriveMCPURL_NoCredential(t *testing.T) { + _ = os.Unsetenv("CX_MCP_URL") + _, err := deriveMCPURL("", "") + assert.Error(t, err) +} + +// makeJWT builds an unsigned (alg=none) JWT carrying the given claims. 
ParseUnverified +// (used by wrappers.ExtractFromTokenClaims) does not check the signature, so this is +// sufficient to exercise claim extraction in tests. +func makeJWT(claims map[string]interface{}) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadBytes, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(payloadBytes) + return header + "." + payload + ".sig" +} + +// stubAccessToken overrides the getAccessToken seam for the duration of the test. +func stubAccessToken(t *testing.T, token string, err error) { + t.Helper() + prev := getAccessToken + getAccessToken = func() (string, error) { return token, err } + t.Cleanup(func() { getAccessToken = prev }) +} + +// TestDeriveMCPURL_AstBaseURLClaim: the authoritative path — the access token's +// ast-base-url claim supplies the AST base; the realm comes from the refresh token. +func TestDeriveMCPURL_AstBaseURLClaim(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + refresh := makeJWT(map[string]interface{}{"iss": "https://eu.iam.checkmarx.net/auth/realms/cx_seg"}) + stubAccessToken(t, makeJWT(map[string]interface{}{"ast-base-url": "https://eu.ast.checkmarx.net"}), nil) + + got, err := deriveMCPURL(refresh, "") + assert.NoError(t, err) + assert.Equal(t, "https://eu.ast.checkmarx.net/api/security-mcp/mcp/cx_seg", got) +} + +// TestDeriveMCPURL_AstBaseURLClaimTrailingSlash: a base URL with a trailing slash +// must not produce a doubled separator. 
+func TestDeriveMCPURL_AstBaseURLClaimTrailingSlash(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + refresh := makeJWT(map[string]interface{}{"iss": "https://eu.iam.checkmarx.net/auth/realms/cx_seg"}) + stubAccessToken(t, makeJWT(map[string]interface{}{"ast-base-url": "https://eu.ast.checkmarx.net/"}), nil) + + got, err := deriveMCPURL(refresh, "") + assert.NoError(t, err) + assert.Equal(t, "https://eu.ast.checkmarx.net/api/security-mcp/mcp/cx_seg", got) +} + +// TestDeriveMCPURL_DevHostResolvesViaAstBaseURL: a dev host whose iam->ast swap is +// NOT mappable (regression vs TestBuildSecurityMCPURL "dev host is not auto-mappable") +// now resolves automatically because the access token carries ast-base-url. +func TestDeriveMCPURL_DevHostResolvesViaAstBaseURL(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + refresh := makeJWT(map[string]interface{}{"iss": "https://iam-dev.dev.cxast.net/auth/realms/dev_tenant"}) + stubAccessToken(t, makeJWT(map[string]interface{}{"ast-base-url": "https://ast-master-components.dev.cxast.net"}), nil) + + got, err := deriveMCPURL(refresh, "") + assert.NoError(t, err) + assert.Equal(t, "https://ast-master-components.dev.cxast.net/api/security-mcp/mcp/dev_tenant", got) +} + +// TestDeriveMCPURL_FallsBackToSwap: when the access-token exchange fails OR the claim +// is absent, a standard cloud host still resolves via the offline iam->ast swap. 
+func TestDeriveMCPURL_FallsBackToSwap(t *testing.T) { + refresh := makeJWT(map[string]interface{}{"iss": "https://eu.iam.checkmarx.net/auth/realms/cx_seg"}) + want := "https://eu.ast.checkmarx.net/api/security-mcp/mcp/cx_seg" + + t.Run("exchange error", func(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + stubAccessToken(t, "", errors.New("token endpoint unreachable")) + got, err := deriveMCPURL(refresh, "") + assert.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("claim absent", func(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + stubAccessToken(t, makeJWT(map[string]interface{}{"sub": "no-base-url-here"}), nil) + got, err := deriveMCPURL(refresh, "") + assert.NoError(t, err) + assert.Equal(t, want, got) + }) +} + +// TestDeriveMCPURL_LadderPrecedence asserts flag > CX_MCP_URL > ast-base-url(access) +// > iam->ast swap by removing one tier at a time. The ast-base-url and swap tiers are +// disambiguated by using a dev host (swap fails) so a correct result proves the claim +// path was taken, not the swap. 
+func TestDeriveMCPURL_LadderPrecedence(t *testing.T) { + devRefresh := makeJWT(map[string]interface{}{"iss": "https://iam-dev.dev.cxast.net/auth/realms/dev_tenant"}) + cloudRefresh := makeJWT(map[string]interface{}{"iss": "https://eu.iam.checkmarx.net/auth/realms/cx_seg"}) + claimToken := makeJWT(map[string]interface{}{"ast-base-url": "https://ast-master-components.dev.cxast.net"}) + + t.Run("flag wins over env and claim", func(t *testing.T) { + t.Setenv("CX_MCP_URL", "https://from-env.example.com/api/security-mcp/mcp/x") + stubAccessToken(t, claimToken, nil) + got, err := deriveMCPURL(devRefresh, "https://from-flag.example.com/api/security-mcp/mcp/x/") + assert.NoError(t, err) + assert.Equal(t, "https://from-flag.example.com/api/security-mcp/mcp/x", got) + }) + + t.Run("env wins over claim", func(t *testing.T) { + t.Setenv("CX_MCP_URL", "https://from-env.example.com/api/security-mcp/mcp/x") + stubAccessToken(t, claimToken, nil) + got, err := deriveMCPURL(devRefresh, "") + assert.NoError(t, err) + assert.Equal(t, "https://from-env.example.com/api/security-mcp/mcp/x", got) + }) + + t.Run("claim wins over swap", func(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + stubAccessToken(t, claimToken, nil) + got, err := deriveMCPURL(devRefresh, "") + assert.NoError(t, err) + assert.Equal(t, "https://ast-master-components.dev.cxast.net/api/security-mcp/mcp/dev_tenant", got) + }) + + t.Run("swap is last resort", func(t *testing.T) { + t.Setenv("CX_MCP_URL", "") + stubAccessToken(t, "", errors.New("no exchange")) + got, err := deriveMCPURL(cloudRefresh, "") + assert.NoError(t, err) + assert.Equal(t, "https://eu.ast.checkmarx.net/api/security-mcp/mcp/cx_seg", got) + }) +} + +func TestRealmFromIssuer(t *testing.T) { + tests := []struct { + name string + issuer string + want string + wantErr bool + }{ + {name: "standard realm", issuer: "https://eu.iam.checkmarx.net/auth/realms/cx_seg", want: "cx_seg"}, + {name: "trailing slash", issuer: 
"https://eu.iam.checkmarx.net/auth/realms/tenant1/", want: "tenant1"}, + {name: "no path", issuer: "https://eu.iam.checkmarx.net", wantErr: true}, + {name: "empty issuer", issuer: "", wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := realmFromIssuer(tt.issuer) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +// ---- self-heal / resilience test support ---- + +// syncBuffer is a thread-safe buffer so the read loop, watcher goroutine, and the +// test can touch stdout output without a data race. +type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *syncBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *syncBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + +// setupBridgeTest isolates global state: clears the credential, stubs the disk/cache +// seams to no-ops, and speeds the watcher poll. All originals are restored via +// t.Cleanup. +func setupBridgeTest(t *testing.T) { + t.Helper() + prevKey := viper.GetString(commonParams.AstAPIKey) + prevReload, prevInval, prevPoll := reloadConfig, invalidateTokenCache, credentialPollInterval + t.Cleanup(func() { + viper.Set(commonParams.AstAPIKey, prevKey) + reloadConfig = prevReload + invalidateTokenCache = prevInval + credentialPollInterval = prevPoll + }) + viper.Set(commonParams.AstAPIKey, "") + reloadConfig = func() {} + invalidateTokenCache = func() {} + credentialPollInterval = time.Millisecond + t.Setenv("CHECKMARX_API_KEY", "") + t.Setenv("CX_MCP_URL", "") +} + +// waitFor polls the output buffer until it contains substr or times out. 
+func waitFor(t *testing.T, buf *syncBuffer, substr string) { + t.Helper() + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + if strings.Contains(buf.String(), substr) { + return + } + time.Sleep(2 * time.Millisecond) + } + t.Fatalf("timeout waiting for %q in output:\n%s", substr, buf.String()) +} + +func decodeLines(t *testing.T, s string) []map[string]interface{} { + t.Helper() + var out []map[string]interface{} + for _, line := range strings.Split(strings.TrimRight(s, "\n"), "\n") { + if strings.TrimSpace(line) == "" { + continue + } + var m map[string]interface{} + if err := json.Unmarshal([]byte(line), &m); err != nil { + t.Fatalf("invalid JSON-RPC line %q: %v", line, err) + } + out = append(out, m) + } + return out +} + +// TestRunBridge_UnauthAnswersInitializeLocally: with no credential, initialize is +// answered locally so the client sees the server Connected (listChanged=true). +func TestRunBridge_UnauthAnswersInitializeLocally(t *testing.T) { + setupBridgeTest(t) + in := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18"}}` + "\n") + var out syncBuffer + err := runBridgeIO(in, &out, &http.Client{}, "1.2.3", "", func() string { return "" }) + assert.NoError(t, err) + + lines := decodeLines(t, out.String()) + assert.Len(t, lines, 1) + result := lines[0]["result"].(map[string]interface{}) + assert.Equal(t, "2025-06-18", result["protocolVersion"]) + tools := result["capabilities"].(map[string]interface{})["tools"].(map[string]interface{}) + assert.Equal(t, true, tools["listChanged"]) + si := result["serverInfo"].(map[string]interface{}) + assert.Equal(t, "Checkmarx Security", si["name"]) + assert.Equal(t, "1.2.3", si["version"]) +} + +// TestRunBridge_UnauthToolsListEmpty: while unauth, tools/list returns an empty list +// (so the client shows Connected with no tools, ready for the later list_changed). 
+func TestRunBridge_UnauthToolsListEmpty(t *testing.T) { + setupBridgeTest(t) + in := strings.NewReader( + `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18"}}` + "\n" + + `{"jsonrpc":"2.0","id":2,"method":"tools/list"}` + "\n") + var out syncBuffer + assert.NoError(t, runBridgeIO(in, &out, &http.Client{}, "1.0", "", func() string { return "" })) + + lines := decodeLines(t, out.String()) + assert.Len(t, lines, 2) + toolsResult := lines[1]["result"].(map[string]interface{}) + assert.Empty(t, toolsResult["tools"]) +} + +// TestWatcher_SelfHeal_EndToEnd is the headline test: unauth -> Connected (empty) -> +// simulate `cx auth login` -> watcher heals -> notifications/tools/list_changed -> +// client re-fetches tools/list -> REAL tools, with NO manual reconnect. +func TestWatcher_SelfHeal_EndToEnd(t *testing.T) { + setupBridgeTest(t) + + // Simulate the credential appearing via the injected resolveKey seam (mutex-guarded) + // so the watcher's poll never races the test write — viper itself is not concurrent-safe. 
+ var credMu sync.Mutex + cred := "" + resolveKey := func() string { credMu.Lock(); defer credMu.Unlock(); return cred } + setCred := func(v string) { credMu.Lock(); cred = v; credMu.Unlock() } + + var mu sync.Mutex + var seenInit, seenInitialized, seenToolsList bool + var authHeader string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + var req struct { + Method string `json:"method"` + } + _ = json.Unmarshal(raw, &req) + mu.Lock() + authHeader = r.Header.Get("Authorization") + switch req.Method { + case "initialize": + seenInit = true + case "notifications/initialized": + seenInitialized = true + case "tools/list": + seenToolsList = true + } + mu.Unlock() + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sess-123") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":0,"result":{"protocolVersion":"2025-06-18","serverInfo":{"name":"security-mcp"}}}`)) + case "notifications/initialized": + w.WriteHeader(http.StatusAccepted) + case "tools/list": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":3,"result":{"tools":[{"name":"codeRemediation"},{"name":"triggerScan"}]}}`)) + default: + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":0,"result":{}}`)) + } + })) + defer srv.Close() + t.Setenv("CX_MCP_URL", srv.URL) // deriveMCPURL returns this; no token exchange needed + + pr, pw := io.Pipe() + var out syncBuffer + done := make(chan struct{}) + go func() { + _ = runBridgeIO(pr, &out, &http.Client{}, "9.9.9", "", resolveKey) + close(done) + }() + + // 1. initialize -> local synthetic init (empty tools, listChanged) + _, _ = io.WriteString(pw, `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18"}}`+"\n") + waitFor(t, &out, `"listChanged":true`) + // 2. 
tools/list while unauth -> empty + _, _ = io.WriteString(pw, `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`+"\n") + waitFor(t, &out, `"tools":[]`) + // 3. simulate `cx auth login` writing a credential another process would have written + setCred(makeJWT(map[string]interface{}{"iss": "https://eu.iam.checkmarx.net/auth/realms/cx_seg"})) + // 4. watcher heals and pushes list_changed — no reconnect + waitFor(t, &out, "notifications/tools/list_changed") + // 5. client re-fetches tools/list -> now proxied to the remote -> REAL tools + _, _ = io.WriteString(pw, `{"jsonrpc":"2.0","id":3,"method":"tools/list"}`+"\n") + waitFor(t, &out, "codeRemediation") + _ = pw.Close() + <-done + + mu.Lock() + defer mu.Unlock() + assert.True(t, seenInit, "remote received the bridge-driven initialize") + assert.True(t, seenInitialized, "remote received notifications/initialized") + assert.True(t, seenToolsList, "remote received tools/list after heal") + assert.NotContains(t, authHeader, "Bearer", "credential forwarded raw, no Bearer prefix") + assert.NotEmpty(t, authHeader) +} + +// TestWatcher_StaysDegraded_NoCredential: with no credential the watcher never +// promotes, writes nothing, and exits cleanly on stop. +func TestWatcher_StaysDegraded_NoCredential(t *testing.T) { + setupBridgeTest(t) + var out syncBuffer + s := &bridgeSession{writer: newSyncWriter(&out), state: stateUnauth, resolveKey: func() string { return "" }} + stop := make(chan struct{}) + done := make(chan struct{}) + go func() { + s.watchForCredential(&http.Client{}, "", stop) + close(done) + }() + time.Sleep(25 * time.Millisecond) // several poll ticks + assert.False(t, s.isConnected()) + assert.Empty(t, out.String()) + close(stop) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("watcher did not exit on stop") + } +} + +// TestSyncWriter_NoInterleave: concurrent emits never produce a partial/interleaved +// line — every output line is independently valid JSON. 
+func TestSyncWriter_NoInterleave(t *testing.T) { + var out syncBuffer + sw := newSyncWriter(&out) + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + line, _ := json.Marshal(map[string]interface{}{"jsonrpc": "2.0", "id": n, "method": "notifications/tools/list_changed"}) + sw.emitLine(line) + }(i) + } + wg.Wait() + for _, line := range strings.Split(strings.TrimRight(out.String(), "\n"), "\n") { + assert.True(t, json.Valid([]byte(line)), "line not valid JSON: %q", line) + } +} + +// TestDispatch_AuthedPathUnchanged covers the connected-state 401/403 single-retry +// credential reload — behavior preserved from before the resilience change. The raw +// credential (no Bearer) is forwarded and the access token is never sent. +func TestDispatch_AuthedPathUnchanged(t *testing.T) { + const initialKey = "initial-key" + const reloadedKey = "reloaded-key" + body := []byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list"}`) + + run := func(srvURL string) *syncBuffer { + var out syncBuffer + s := &bridgeSession{state: stateConnected, apiKey: initialKey, mcpURL: srvURL, writer: newSyncWriter(&out), resolveKey: productionResolveAPIKey} + s.dispatch(&http.Client{}, body) + return &out + } + + t.Run("401 then success after reload forwards raw new credential", func(t *testing.T) { + setupBridgeTest(t) + viper.Set(commonParams.AstAPIKey, reloadedKey) // simulate the rotated token visible on disk + var seenAuth []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = append(seenAuth, r.Header.Get("Authorization")) + if len(seenAuth) == 1 { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`)) + })) + defer srv.Close() + + out := run(srv.URL) + assert.Equal(t, []string{initialKey, reloadedKey}, seenAuth) // retried once, raw, no Bearer + assert.Contains(t, 
out.String(), `"ok":true`) + }) + + t.Run("no retry when reloaded credential is unchanged", func(t *testing.T) { + setupBridgeTest(t) + viper.Set(commonParams.AstAPIKey, initialKey) // same as the one already in use + var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + out := run(srv.URL) + assert.Equal(t, 1, calls) // no retry — credential did not change + assert.Contains(t, out.String(), `"error"`) + }) + + t.Run("exactly one retry then a JSON-RPC error", func(t *testing.T) { + setupBridgeTest(t) + viper.Set(commonParams.AstAPIKey, reloadedKey) + var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + out := run(srv.URL) + assert.Equal(t, 2, calls) // initial + exactly one retry, no loop + assert.Contains(t, out.String(), `"error"`) + }) +} + +// TestAuthedSelfHeal_ReReadsDisk proves the 401/403 path re-reads config from DISK +// (reloadConfig) BEFORE resolveKey — the new key only becomes visible after the +// disk re-read, so a token rotated by another process is actually picked up (this +// fails without reloadConfig because viper is a stale startup snapshot). 
+func TestAuthedSelfHeal_ReReadsDisk(t *testing.T) { + setupBridgeTest(t) + const oldKey = "old-key" + const newKey = "new-key" + viper.Set(commonParams.AstAPIKey, oldKey) // stale in-memory value + reloadConfig = func() { viper.Set(commonParams.AstAPIKey, newKey) } // disk re-read brings the new token + + var seenAuth []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = append(seenAuth, r.Header.Get("Authorization")) + if len(seenAuth) == 1 { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`)) + })) + defer srv.Close() + + var out syncBuffer + s := &bridgeSession{state: stateConnected, apiKey: oldKey, mcpURL: srv.URL, writer: newSyncWriter(&out), resolveKey: productionResolveAPIKey} + s.dispatch(&http.Client{}, []byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list"}`)) + + assert.Equal(t, []string{oldKey, newKey}, seenAuth) // retry used the disk-refreshed key + assert.Contains(t, out.String(), `"ok":true`) +} + +// TestEstablishRemoteSession_DoesNotEmitInitResult: the bridge-driven remote +// initialize captures the session id + proto and drives notifications/initialized, +// but must NOT emit an init result to the client (it already got the local one). 
+func TestEstablishRemoteSession_DoesNotEmitInitResult(t *testing.T) { + setupBridgeTest(t) + var seenInitialized bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + var req struct { + Method string `json:"method"` + } + _ = json.Unmarshal(raw, &req) + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sid-9") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":0,"result":{"protocolVersion":"2025-06-18"}}`)) + case "notifications/initialized": + seenInitialized = true + w.WriteHeader(http.StatusAccepted) + } + })) + defer srv.Close() + + var out syncBuffer + s := &bridgeSession{writer: newSyncWriter(&out), clientProto: "2025-06-18", version: "1.0"} + ok := s.establishRemoteSession(&http.Client{}, srv.URL, "raw-key") + + assert.True(t, ok) + assert.Equal(t, "sid-9", s.id) + assert.True(t, s.remoteReady) + assert.True(t, seenInitialized) + assert.Empty(t, out.String(), "no init result should be emitted to the client") +} diff --git a/internal/commands/agenthooks/mcp/server.go b/internal/commands/agenthooks/mcp/server.go new file mode 100644 index 000000000..4f3a8e333 --- /dev/null +++ b/internal/commands/agenthooks/mcp/server.go @@ -0,0 +1,103 @@ +package mcp + +import ( + "context" + "log" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/spf13/cobra" + + "github.com/checkmarx/ast-cli/internal/commands/agenthooks/guardrails" + "github.com/checkmarx/ast-cli/internal/commands/agenthooks/mcp/tools" +) + +// NewMCPCommand creates the "cx mcp" cobra command. +// +// version is the binary version string (e.g. params.Version from the caller). +// licensed reports whether the current token carries a valid guardrail licence; +// when false, all guardrails run as pass-through (fail-open), matching the +// behaviour of the Cursor hook path. 
+func NewMCPCommand(version string, licensed func() bool) *cobra.Command { + cmd := &cobra.Command{ + Use: "mcp", + Short: "Start MCP server for AI assistant integration", + Long: `Start a Model Context Protocol (MCP) server that exposes Checkmarx +security guardrails as tools for AI coding assistants. + +Tools: + cx_shell_guard — Check shell commands against the organization's blocklist + cx_prompt_guard — Scan prompts for leaked secrets before they reach the AI + +Transport: stdio (compatible with Claude Desktop, Cursor, VS Code Copilot, Windsurf)`, + Example: ` # Start MCP server + cx mcp + + # Claude Desktop config (claude_desktop_config.json): + { + "mcpServers": { + "checkmarx": { "command": "cx", "args": ["mcp"] } + } + }`, + RunE: func(_ *cobra.Command, _ []string) error { + return run(version, licensed) + }, + } + // "cx mcp bridge" proxies stdio MCP to the remote Checkmarx Security MCP. + // Keeping it as a subcommand leaves the default "cx mcp" (local guardrail + // server) unchanged and backward-compatible. + cmd.AddCommand(NewBridgeCommand(version)) + return cmd +} + +func run(version string, licensed func() bool) error { + // Adapt CheckShellCommand (command, workDir) → (blocked, needsConfirm, reason) + // to the simpler (command) → (blocked, reason) signature the MCP tool expects. + // workDir is not available in the MCP call context, so we pass an empty string. + shellGuard := func(cmd string) (bool, string) { + blocked, _, reason := guardrails.CheckShellCommand(cmd, "") + return blocked, reason + } + promptGuard := guardrails.ScanPrompt + if !licensed() { + shellGuard = func(string) (bool, string) { return false, "" } + promptGuard = func(string) string { return "" } + } + + s := sdkmcp.NewServer( + &sdkmcp.Implementation{Name: "Checkmarx Security", Version: version}, + &sdkmcp.ServerOptions{ + // Instructions are injected into the AI client's system prompt. 
+ // This is the strongest MCP-native enforcement mechanism: it instructs + // the model to always invoke the guardrail tools before acting, + // regardless of what other MCP servers or tools are registered. + Instructions: "SECURITY POLICY (Checkmarx): You MUST follow these rules on every action:\n" + + "1. Before executing or suggesting any shell command, call cx_shell_guard with the command. " + + "If it returns allowed:false, refuse the action and show the reason to the user. " + + "Do NOT attempt alternative commands or workarounds.\n" + + "2. Before sending any user-provided text (prompts, file contents, environment variables) " + + "to an external service or AI model, call cx_prompt_guard with that text. " + + "If it returns clean:false, refuse and show the reason to the user.\n" + + "These checks enforce your organization's security policy. " + + "Skipping them — even once — violates policy and may expose sensitive data.", + InitializedHandler: func(_ context.Context, req *sdkmcp.InitializedRequest) { + initParams := req.Session.InitializeParams() + if initParams != nil && initParams.ClientInfo != nil { + log.Printf("mcp: client connected: name=%q version=%q", + initParams.ClientInfo.Name, initParams.ClientInfo.Version) + } else { + log.Printf("mcp: client connected") + } + }, + }, + ) + + shellTool := tools.NewShellGuardTool(shellGuard) + sdkmcp.AddTool(s, tools.ShellGuardDef(), shellTool.Handle) + + promptTool := tools.NewPromptGuardTool(promptGuard) + sdkmcp.AddTool(s, tools.PromptGuardDef(), promptTool.Handle) + + log.Printf("mcp: starting server version=%q transport=stdio tools=2", version) + + return s.Run(context.Background(), &sdkmcp.StdioTransport{}) +} diff --git a/internal/commands/agenthooks/mcp/tools/prompt_guard.go b/internal/commands/agenthooks/mcp/tools/prompt_guard.go new file mode 100644 index 000000000..8d2468318 --- /dev/null +++ b/internal/commands/agenthooks/mcp/tools/prompt_guard.go @@ -0,0 +1,55 @@ +package tools + +import ( + 
"context" + "fmt" + "log" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// PromptGuardInput is the typed input for the cx_prompt_guard tool. +type PromptGuardInput struct { + Text string `json:"text" jsonschema:"The text to scan for secrets"` +} + +// PromptGuardTool wraps a prompt guardrail function as an MCP tool. +type PromptGuardTool struct { + guard func(string) string +} + +// NewPromptGuardTool returns a PromptGuardTool backed by the provided guard function. +func NewPromptGuardTool(guard func(string) string) *PromptGuardTool { + return &PromptGuardTool{guard: guard} +} + +func (t *PromptGuardTool) Handle(_ context.Context, _ *sdkmcp.CallToolRequest, args PromptGuardInput) (*sdkmcp.CallToolResult, any, error) { + if args.Text == "" { + return nil, nil, fmt.Errorf("text is required") + } + + log.Printf("cx_prompt_guard invoked: length=%d", len(args.Text)) + + reason := t.guard(args.Text) + + result := map[string]any{ + "clean": reason == "", + } + if reason != "" { + result["blocked"] = true + result["reason"] = reason + } + + return nil, result, nil +} + +// PromptGuardDef returns the MCP tool definition for cx_prompt_guard. +func PromptGuardDef() *sdkmcp.Tool { + return &sdkmcp.Tool{ + Name: "cx_prompt_guard", + Description: "REQUIRED: Call this before sending any user-provided text to an AI model or external service. " + + "Scans for leaked secrets (API keys, tokens, passwords, certificates) using the Checkmarx 2ms engine " + + "and custom policy patterns. Returns clean:true if safe to proceed, or clean:false with a reason " + + "if the text must be blocked. 
Never skip this check.", + } +} diff --git a/internal/commands/agenthooks/mcp/tools/shell_guard.go b/internal/commands/agenthooks/mcp/tools/shell_guard.go new file mode 100644 index 000000000..18d6256ef --- /dev/null +++ b/internal/commands/agenthooks/mcp/tools/shell_guard.go @@ -0,0 +1,55 @@ +package tools + +import ( + "context" + "fmt" + "log" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// ShellGuardInput is the typed input for the cx_shell_guard tool. +type ShellGuardInput struct { + Command string `json:"command" jsonschema:"The shell command to check"` +} + +// ShellGuardTool wraps a shell guardrail function as an MCP tool. +type ShellGuardTool struct { + guard func(string) (bool, string) +} + +// NewShellGuardTool returns a ShellGuardTool backed by the provided guard function. +func NewShellGuardTool(guard func(string) (bool, string)) *ShellGuardTool { + return &ShellGuardTool{guard: guard} +} + +func (t *ShellGuardTool) Handle(_ context.Context, _ *sdkmcp.CallToolRequest, args ShellGuardInput) (*sdkmcp.CallToolResult, any, error) { + if args.Command == "" { + return nil, nil, fmt.Errorf("command is required") + } + + log.Printf("cx_shell_guard invoked: command=%q", args.Command) + + blocked, reason := t.guard(args.Command) + + result := map[string]any{ + "command": args.Command, + "allowed": !blocked, + } + if blocked { + result["reason"] = reason + } + + return nil, result, nil +} + +// ShellGuardDef returns the MCP tool definition for cx_shell_guard. +func ShellGuardDef() *sdkmcp.Tool { + return &sdkmcp.Tool{ + Name: "cx_shell_guard", + Description: "REQUIRED: Call this before executing any shell command. " + + "Checks the command against the organization's security blocklist. " + + "Returns allowed:true if safe to proceed, or allowed:false with a reason if blocked. 
" + + "Never execute a shell command without calling this first.", + } +} diff --git a/internal/commands/agenthooks/sca/commands.go b/internal/commands/agenthooks/sca/commands.go new file mode 100644 index 000000000..7447499bc --- /dev/null +++ b/internal/commands/agenthooks/sca/commands.go @@ -0,0 +1,625 @@ +package sca + +import ( + "path/filepath" + "strings" +) + +// Manager identifies a package manager whose install commands the Bash hook +// recognises. +type Manager int + +const ( + ManagerUnknown Manager = iota + ManagerNpm + ManagerPypi + ManagerDotnet + ManagerGo + ManagerMaven +) + +// Format returns the manifest format that pairs with this manager — used to +// synthesise a temp manifest for oss-realtime. +func (m Manager) Format() Format { + switch m { + case ManagerNpm: + return FormatNpmPackageJson + case ManagerPypi: + return FormatPypiRequirements + case ManagerDotnet: + return FormatDotnetCsproj + case ManagerGo: + return FormatGoMod + case ManagerMaven: + return FormatMavenPom + } + return FormatUnknown +} + +// Package is a parsed install target. +type Package struct { + Name string + Version string // "" → unspecified (oss-realtime defaults to "latest") +} + +// InstallRequest is one recognised install invocation extracted from a shell +// command. A compound command may produce multiple requests. +type InstallRequest struct { + Manager Manager + Packages []Package + ManifestRef string // set for `pip install -r ` — Packages is empty +} + +// ParseInstall returns every recognised install invocation in command, or +// nil if none. Compound commands (&&, ||, ;, |) are split before matching, so +// `cd /repo && npm install lodash` returns one npm request, and +// `npm install lodash && pip install requests` returns two. 
+func ParseInstall(command string) []InstallRequest { + var requests []InstallRequest + for _, segment := range splitTopLevel(command) { + if req := parseSegment(segment); req != nil { + requests = append(requests, *req) + } + // Also descend into $(...) and `...` subshells inside the segment. + for _, sub := range extractSubshells(segment) { + requests = append(requests, ParseInstall(sub)...) + } + } + return requests +} + +// splitTopLevel splits command on top-level shell operators (&&, ||, ;, |), +// honouring single/double quotes and paren/$()/backtick nesting so we don't +// split inside string literals or subshells. +func splitTopLevel(command string) []string { + var ( + segments []string + current strings.Builder + sq, dq bool + paren int // ( and $( nesting depth + bt int // backtick nesting (0 or 1 — backticks don't nest in practice) + ) + flush := func() { + seg := strings.TrimSpace(current.String()) + if seg != "" { + segments = append(segments, seg) + } + current.Reset() + } + for i := 0; i < len(command); i++ { + c := command[i] + switch { + case sq: + if c == '\'' { + sq = false + } + current.WriteByte(c) + case dq: + if c == '"' { + dq = false + } else if c == '\\' && i+1 < len(command) { + current.WriteByte(c) + i++ + current.WriteByte(command[i]) + continue + } + current.WriteByte(c) + case c == '\'': + sq = true + current.WriteByte(c) + case c == '"': + dq = true + current.WriteByte(c) + case c == '`': + if bt == 0 { + bt = 1 + } else { + bt = 0 + } + current.WriteByte(c) + case bt > 0: + current.WriteByte(c) + case c == '$' && i+1 < len(command) && command[i+1] == '(': + paren++ + current.WriteByte(c) + i++ + current.WriteByte(command[i]) // the '(' + case c == '(': + paren++ + current.WriteByte(c) + case c == ')': + if paren > 0 { + paren-- + } + current.WriteByte(c) + case paren > 0: + current.WriteByte(c) + case c == '&' && i+1 < len(command) && command[i+1] == '&': + flush() + i++ + case c == '|' && i+1 < len(command) && command[i+1] == 
'|': + flush() + i++ + case c == ';': + flush() + case c == '|': + flush() + default: + current.WriteByte(c) + } + } + flush() + return segments +} + +// extractSubshells pulls out the bodies of $(...) and `...` subshells inside a +// segment so we can recursively parse them. Nested $() is respected. +func extractSubshells(segment string) []string { + var bodies []string + for i := 0; i < len(segment); i++ { + c := segment[i] + if c == '$' && i+1 < len(segment) && segment[i+1] == '(' { + depth := 1 + j := i + 2 + for j < len(segment) && depth > 0 { + switch segment[j] { + case '(': + depth++ + case ')': + depth-- + } + if depth == 0 { + break + } + j++ + } + if j > i+2 { + bodies = append(bodies, segment[i+2:j]) + } + i = j + } else if c == '`' { + j := i + 1 + for j < len(segment) && segment[j] != '`' { + j++ + } + if j > i+1 && j < len(segment) { + bodies = append(bodies, segment[i+1:j]) + } + i = j + } + } + return bodies +} + +// parseSegment recognises a single (non-compound) install command. Returns +// nil if the segment is not an install invocation. +func parseSegment(segment string) *InstallRequest { + tokens := tokenize(segment) + tokens = dropLeadingNoOps(tokens) + if len(tokens) == 0 { + return nil + } + + mgr, rest := matchManager(tokens) + if mgr == ManagerUnknown { + return nil + } + + switch mgr { + case ManagerNpm: + return parseNpmArgs(rest) + case ManagerPypi: + return parsePypiArgs(rest) + case ManagerDotnet: + return parseDotnetArgs(tokens, rest) + case ManagerGo: + return parseGoArgs(rest) + case ManagerMaven: + return parseMavenArgs(rest) + } + return nil +} + +// tokenize splits a command into whitespace-separated tokens, preserving +// quoted strings, $(...) subshells, and `...` backticks as single tokens so +// downstream parsing isn't confused by internal whitespace inside shell +// expansions. 
Surrounding single/double quotes are stripped; subshell +// markers ($(, ), `) are preserved verbatim so callers can recognise them +// and skip those tokens. +func tokenize(segment string) []string { + var ( + tokens []string + cur strings.Builder + sq, dq bool + paren int + inBT bool + ) + flush := func() { + if cur.Len() > 0 { + tokens = append(tokens, cur.String()) + cur.Reset() + } + } + for i := 0; i < len(segment); i++ { + c := segment[i] + switch { + case sq: + if c == '\'' { + sq = false + } else { + cur.WriteByte(c) + } + case dq: + if c == '"' { + dq = false + } else { + cur.WriteByte(c) + } + case inBT: + cur.WriteByte(c) + if c == '`' { + inBT = false + } + case c == '\'': + sq = true + case c == '"': + dq = true + case c == '`': + cur.WriteByte(c) + inBT = true + case c == '$' && i+1 < len(segment) && segment[i+1] == '(': + paren++ + cur.WriteByte(c) + i++ + cur.WriteByte(segment[i]) // '(' + case paren > 0: + cur.WriteByte(c) + if c == ')' { + paren-- + } else if c == '(' { + paren++ + } + case c == ' ' || c == '\t' || c == '\n': + flush() + default: + cur.WriteByte(c) + } + } + flush() + return tokens +} + +// isShellExpansion reports whether tok is a $(...) or `...` expansion that +// should be skipped during package parsing (we can't statically know what +// the expansion evaluates to). +func isShellExpansion(tok string) bool { + return strings.HasPrefix(tok, "$(") || strings.HasPrefix(tok, "`") +} + +// dropLeadingNoOps strips command prefixes that don't change which install +// command runs: `sudo`, `time`, `nice`, env-style `FOO=bar`, and a leading +// `cd ` chain (though `cd && ...` is split out at the operator level — +// this catches `(cd /foo; ...)`-style stripped patterns and standalone +// `bash -c ""` already-unwrapped by the segment split). 
+func dropLeadingNoOps(tokens []string) []string { + for len(tokens) > 0 { + t := tokens[0] + switch { + case t == "sudo", t == "time", t == "nice": + tokens = tokens[1:] + case strings.Contains(t, "=") && !strings.HasPrefix(t, "-") && isEnvAssignment(t): + tokens = tokens[1:] + default: + return tokens + } + } + return tokens +} + +func isEnvAssignment(t string) bool { + idx := strings.Index(t, "=") + if idx <= 0 { + return false + } + for i := 0; i < idx; i++ { + c := t[i] + if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_') { + return false + } + } + return true +} + +// matchManager finds the install verb at the head of tokens (after no-op +// stripping). Returns (Manager, remaining tokens after the verb). +func matchManager(tokens []string) (Manager, []string) { + if len(tokens) < 2 { + return ManagerUnknown, nil + } + t0 := tokens[0] + + // Multi-token verbs first. + if len(tokens) >= 3 && (t0 == "python" || t0 == "python3") && tokens[1] == "-m" && tokens[2] == "pip" { + if len(tokens) >= 4 && tokens[3] == "install" { + return ManagerPypi, tokens[4:] + } + return ManagerUnknown, nil + } + if t0 == "uv" && tokens[1] == "pip" { + if len(tokens) >= 3 && tokens[2] == "install" { + return ManagerPypi, tokens[3:] + } + return ManagerUnknown, nil + } + if t0 == "dotnet" && tokens[1] == "add" { + if len(tokens) >= 3 && tokens[2] == "package" { + return ManagerDotnet, tokens[3:] + } + return ManagerUnknown, nil + } + + // Single-verb forms. 
+ verb := tokens[1] + switch t0 { + case "npm": + if verb == "install" || verb == "i" || verb == "add" { + return ManagerNpm, tokens[2:] + } + case "yarn": + if verb == "add" { + return ManagerNpm, tokens[2:] + } + case "pnpm": + if verb == "add" || verb == "install" || verb == "i" { + return ManagerNpm, tokens[2:] + } + case "pip", "pip3": + if verb == "install" { + return ManagerPypi, tokens[2:] + } + case "pipenv": + if verb == "install" { + return ManagerPypi, tokens[2:] + } + case "poetry": + if verb == "add" { + return ManagerPypi, tokens[2:] + } + case "uv": + if verb == "add" { + return ManagerPypi, tokens[2:] + } + case "nuget": + if verb == "install" { + return ManagerDotnet, tokens[2:] + } + case "go": + if verb == "get" || verb == "install" { + return ManagerGo, tokens[2:] + } + case "mvn": + if verb == "dependency:get" { + return ManagerMaven, tokens[2:] + } + } + return ManagerUnknown, nil +} + +// --- per-manager argument parsing --- + +func parseNpmArgs(args []string) *InstallRequest { + var pkgs []Package + for _, a := range args { + if skipArg(a) { + continue + } + pkgs = append(pkgs, parseNpmSpec(a)) + } + if len(pkgs) == 0 { + return nil + } + return &InstallRequest{Manager: ManagerNpm, Packages: pkgs} +} + +// parseNpmSpec handles bare names, name@version, scoped packages (@scope/pkg +// and @scope/pkg@version). +func parseNpmSpec(spec string) Package { + if strings.HasPrefix(spec, "@") { + rest := spec[1:] + idx := strings.Index(rest, "@") + if idx < 0 { + return Package{Name: spec} + } + return Package{Name: "@" + rest[:idx], Version: normalizeSemver(rest[idx+1:])} + } + idx := strings.LastIndex(spec, "@") + if idx <= 0 { + return Package{Name: spec} + } + return Package{Name: spec[:idx], Version: normalizeSemver(spec[idx+1:])} +} + +// normalizeSemver pads a bare numeric version with missing segments so the +// SCA scanner can look it up — e.g. "4.10" → "4.10.0", "4" → "4.0.0". 
+// Versions that contain non-numeric characters (ranges, pre-releases) are +// returned unchanged. +func normalizeSemver(v string) string { + if v == "" { + return v + } + parts := strings.Split(v, ".") + if len(parts) >= 3 { + return v + } + for _, p := range parts { + for _, c := range p { + if c < '0' || c > '9' { + return v + } + } + } + for len(parts) < 3 { + parts = append(parts, "0") + } + return strings.Join(parts, ".") +} + +// parsePypiArgs handles bare names, name==ver, name>=ver style, and -r/-c +// requirement-file references. +func parsePypiArgs(args []string) *InstallRequest { + var ( + pkgs []Package + refs []string + skip bool + ) + for i, a := range args { + if skip { + skip = false + continue + } + switch a { + case "-r", "--requirement", "-c", "--constraint": + if i+1 < len(args) { + refs = append(refs, args[i+1]) + skip = true + } + continue + case "-e", "--editable": + // editable install of a local path — skip the next arg (the path). + skip = true + continue + } + if skipArg(a) { + continue + } + pkgs = append(pkgs, parsePypiSpec(a)) + } + // We collapse to a single request here. Multi-ref pip (-r a.txt -r b.txt) + // only emits the first ref; multi-ref is rare and supporting it cleanly + // would require parseSegment to return []*InstallRequest. + if len(refs) > 0 { + return &InstallRequest{Manager: ManagerPypi, ManifestRef: refs[0]} + } + if len(pkgs) > 0 { + return &InstallRequest{Manager: ManagerPypi, Packages: pkgs} + } + return nil +} + +func parsePypiSpec(spec string) Package { + // Strip any comparator/version specifier; keep only exact-match versions. + // requirements.txt convention: pkg==ver. Anything else → version unknown. 
+ for _, op := range []string{"==", ">=", "<=", "~=", "!=", ">", "<"} { + if idx := strings.Index(spec, op); idx >= 0 { + name := spec[:idx] + if op == "==" { + return Package{Name: name, Version: spec[idx+2:]} + } + return Package{Name: name} + } + } + return Package{Name: spec} +} + +// parseDotnetArgs supports both `dotnet add package [-v ]` and +// `nuget install [-Version ]`. tokens is the full token slice so +// we can decide which verb was used. +func parseDotnetArgs(tokens, args []string) *InstallRequest { + isNuget := len(tokens) > 0 && tokens[0] == "nuget" + + var ( + name, version string + skip bool + ) + for i, a := range args { + if skip { + skip = false + continue + } + switch { + case a == "-v" || a == "--version" || (isNuget && (a == "-Version" || a == "-version")): + if i+1 < len(args) { + version = args[i+1] + skip = true + } + continue + case skipArg(a): + continue + } + if name == "" { + name = a + } + } + if name == "" { + return nil + } + return &InstallRequest{Manager: ManagerDotnet, Packages: []Package{{Name: name, Version: version}}} +} + +// parseGoArgs handles `go get pkg`, `go get pkg@v1`, multi-pkg variants. Bare +// `go get` / `go install` with no positional pkg → no request. +func parseGoArgs(args []string) *InstallRequest { + var pkgs []Package + for _, a := range args { + if skipArg(a) { + continue + } + if a == "." || a == "./..." { + continue + } + idx := strings.LastIndex(a, "@") + if idx <= 0 { + pkgs = append(pkgs, Package{Name: a}) + } else { + pkgs = append(pkgs, Package{Name: a[:idx], Version: a[idx+1:]}) + } + } + if len(pkgs) == 0 { + return nil + } + return &InstallRequest{Manager: ManagerGo, Packages: pkgs} +} + +// parseMavenArgs extracts the artifact spec from +// `mvn dependency:get -Dartifact=groupId:artifactId:version`. 
+func parseMavenArgs(args []string) *InstallRequest { + for _, a := range args { + if !strings.HasPrefix(a, "-Dartifact=") { + continue + } + spec := strings.TrimPrefix(a, "-Dartifact=") + parts := strings.Split(spec, ":") + if len(parts) < 2 { + continue + } + name := parts[0] + ":" + parts[1] + ver := "" + if len(parts) >= 3 { + ver = parts[2] + } + return &InstallRequest{Manager: ManagerMaven, Packages: []Package{{Name: name, Version: ver}}} + } + return nil +} + +// isFlag reports whether the token starts with '-' (treating '-' alone or +// e.g. `--no-progress` and `-D` as flags) but excluding "@scope/pkg" style. +func isFlag(t string) bool { + return strings.HasPrefix(t, "-") +} + +// skipArg reports whether the token should be ignored entirely when scanning +// for package names — flags and unresolved shell expansions. +func skipArg(t string) bool { + return isFlag(t) || isShellExpansion(t) +} + +// resolveRef returns an absolute (or working-directory-relative) path for a +// requirements file reference. Kept here so the Bash hook can hand it to the +// scanner directly. +func resolveRef(ref, workDir string) string { + if filepath.IsAbs(ref) || workDir == "" { + return ref + } + return filepath.Join(workDir, ref) +} diff --git a/internal/commands/agenthooks/sca/commands_test.go b/internal/commands/agenthooks/sca/commands_test.go new file mode 100644 index 000000000..0833559bf --- /dev/null +++ b/internal/commands/agenthooks/sca/commands_test.go @@ -0,0 +1,345 @@ +//go:build !integration + +package sca + +import ( + "reflect" + "sort" + "testing" +) + +func sortedPackages(pkgs []Package) []Package { + out := append([]Package(nil), pkgs...) 
+ sort.Slice(out, func(i, j int) bool { + if out[i].Name != out[j].Name { + return out[i].Name < out[j].Name + } + return out[i].Version < out[j].Version + }) + return out +} + +func wantPackages(t *testing.T, got, want []Package) { + t.Helper() + if !reflect.DeepEqual(sortedPackages(got), sortedPackages(want)) { + t.Errorf("packages mismatch\ngot: %#v\nwant: %#v", got, want) + } +} + +func TestParseInstall_SimpleNpm(t *testing.T) { + tests := []struct { + command string + want []Package + }{ + {"npm install lodash", []Package{{Name: "lodash"}}}, + {"npm i lodash", []Package{{Name: "lodash"}}}, + {"npm add lodash@4.17.21", []Package{{Name: "lodash", Version: "4.17.21"}}}, + {"yarn add react", []Package{{Name: "react"}}}, + {"pnpm add react@18.0.0", []Package{{Name: "react", Version: "18.0.0"}}}, + {"pnpm install lodash", []Package{{Name: "lodash"}}}, + {"npm install @types/node@18.0.0", []Package{{Name: "@types/node", Version: "18.0.0"}}}, + {"npm install @types/node", []Package{{Name: "@types/node"}}}, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", tt.command, len(got)) + continue + } + if got[0].Manager != ManagerNpm { + t.Errorf("%q: got manager %v, want npm", tt.command, got[0].Manager) + } + wantPackages(t, got[0].Packages, tt.want) + } +} + +func TestParseInstall_SimplePypi(t *testing.T) { + tests := []struct { + command string + want []Package + }{ + {"pip install requests", []Package{{Name: "requests"}}}, + {"pip install requests==2.25.1", []Package{{Name: "requests", Version: "2.25.1"}}}, + {"pip install requests>=2.0", []Package{{Name: "requests"}}}, + {"pip3 install requests", []Package{{Name: "requests"}}}, + {"python -m pip install requests", []Package{{Name: "requests"}}}, + {"python3 -m pip install requests==2.25.1", []Package{{Name: "requests", Version: "2.25.1"}}}, + {"pipenv install requests", []Package{{Name: "requests"}}}, + {"poetry add requests", 
[]Package{{Name: "requests"}}}, + {"uv add requests", []Package{{Name: "requests"}}}, + {"uv pip install requests", []Package{{Name: "requests"}}}, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", tt.command, len(got)) + continue + } + if got[0].Manager != ManagerPypi { + t.Errorf("%q: got manager %v, want pypi", tt.command, got[0].Manager) + } + wantPackages(t, got[0].Packages, tt.want) + } +} + +func TestParseInstall_SimpleDotnet(t *testing.T) { + tests := []struct { + command string + want []Package + }{ + {"dotnet add package Newtonsoft.Json", []Package{{Name: "Newtonsoft.Json"}}}, + {"dotnet add package Newtonsoft.Json -v 13.0.1", []Package{{Name: "Newtonsoft.Json", Version: "13.0.1"}}}, + {"dotnet add package Newtonsoft.Json --version 13.0.1", []Package{{Name: "Newtonsoft.Json", Version: "13.0.1"}}}, + {"nuget install Newtonsoft.Json", []Package{{Name: "Newtonsoft.Json"}}}, + {"nuget install Newtonsoft.Json -Version 13.0.1", []Package{{Name: "Newtonsoft.Json", Version: "13.0.1"}}}, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", tt.command, len(got)) + continue + } + if got[0].Manager != ManagerDotnet { + t.Errorf("%q: got manager %v, want dotnet", tt.command, got[0].Manager) + } + wantPackages(t, got[0].Packages, tt.want) + } +} + +func TestParseInstall_SimpleGo(t *testing.T) { + tests := []struct { + command string + want []Package + }{ + {"go get github.com/pkg/errors", []Package{{Name: "github.com/pkg/errors"}}}, + {"go get github.com/pkg/errors@v0.9.1", []Package{{Name: "github.com/pkg/errors", Version: "v0.9.1"}}}, + {"go install golang.org/x/tools/cmd/goimports@latest", []Package{{Name: "golang.org/x/tools/cmd/goimports", Version: "latest"}}}, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", tt.command, len(got)) 
+ continue + } + if got[0].Manager != ManagerGo { + t.Errorf("%q: got manager %v, want go", tt.command, got[0].Manager) + } + wantPackages(t, got[0].Packages, tt.want) + } +} + +func TestParseInstall_SimpleMaven(t *testing.T) { + got := ParseInstall("mvn dependency:get -Dartifact=org.apache.commons:commons-lang3:3.12.0") + if len(got) != 1 { + t.Fatalf("got %d requests, want 1", len(got)) + } + if got[0].Manager != ManagerMaven { + t.Errorf("got manager %v, want maven", got[0].Manager) + } + want := []Package{{Name: "org.apache.commons:commons-lang3", Version: "3.12.0"}} + wantPackages(t, got[0].Packages, want) +} + +func TestParseInstall_MultiPackage(t *testing.T) { + tests := []struct { + command string + wantMgr Manager + wantPkgs []Package + }{ + { + "npm install lodash axios express", + ManagerNpm, + []Package{{Name: "lodash"}, {Name: "axios"}, {Name: "express"}}, + }, + { + "npm install lodash@4.0.0 axios@latest express", + ManagerNpm, + []Package{{Name: "lodash", Version: "4.0.0"}, {Name: "axios", Version: "latest"}, {Name: "express"}}, + }, + { + "yarn add a b c", + ManagerNpm, + []Package{{Name: "a"}, {Name: "b"}, {Name: "c"}}, + }, + { + "pnpm add a b c", + ManagerNpm, + []Package{{Name: "a"}, {Name: "b"}, {Name: "c"}}, + }, + { + "pip install pkg1==1.0 pkg2>=2.0 pkg3", + ManagerPypi, + []Package{{Name: "pkg1", Version: "1.0"}, {Name: "pkg2"}, {Name: "pkg3"}}, + }, + { + "poetry add a b", + ManagerPypi, + []Package{{Name: "a"}, {Name: "b"}}, + }, + { + "uv add a b", + ManagerPypi, + []Package{{Name: "a"}, {Name: "b"}}, + }, + { + "go get pkg1 pkg2@v1.0 pkg3", + ManagerGo, + []Package{{Name: "pkg1"}, {Name: "pkg2", Version: "v1.0"}, {Name: "pkg3"}}, + }, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", tt.command, len(got)) + continue + } + if got[0].Manager != tt.wantMgr { + t.Errorf("%q: got manager %v, want %v", tt.command, got[0].Manager, tt.wantMgr) + } + wantPackages(t, 
got[0].Packages, tt.wantPkgs) + } +} + +func TestParseInstall_Compound(t *testing.T) { + tests := []struct { + command string + wantCount int + wantMgrs []Manager + }{ + {"cd /repo && npm install lodash", 1, []Manager{ManagerNpm}}, + {"npm install lodash && npm test", 1, []Manager{ManagerNpm}}, + {"npm install lodash; npm install axios", 2, []Manager{ManagerNpm, ManagerNpm}}, + {"pip install lodash || echo failed", 1, []Manager{ManagerPypi}}, + {"echo \"starting\" && npm install lodash@4.0.0", 1, []Manager{ManagerNpm}}, + {"npm install lodash && yarn add axios", 2, []Manager{ManagerNpm, ManagerNpm}}, + {"npm install a b && pip install x y", 2, []Manager{ManagerNpm, ManagerPypi}}, + {"git pull && pip install requests", 1, []Manager{ManagerPypi}}, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != tt.wantCount { + t.Errorf("%q: got %d requests, want %d (%v)", tt.command, len(got), tt.wantCount, got) + continue + } + for i, m := range tt.wantMgrs { + if got[i].Manager != m { + t.Errorf("%q: request[%d] manager %v, want %v", tt.command, i, got[i].Manager, m) + } + } + } +} + +func TestParseInstall_PipRequirementRef(t *testing.T) { + got := ParseInstall("pip install -r requirements.txt") + if len(got) != 1 { + t.Fatalf("got %d requests, want 1", len(got)) + } + if got[0].ManifestRef != "requirements.txt" { + t.Errorf("got ref %q, want %q", got[0].ManifestRef, "requirements.txt") + } + if len(got[0].Packages) != 0 { + t.Errorf("got %d packages, want 0", len(got[0].Packages)) + } +} + +func TestParseInstall_Negative(t *testing.T) { + negatives := []string{ + "", + "npm run build", + "npm test", + "pip uninstall pkg", + "pip list", + "git clone https://example.com/repo", + "npm install", // bare + "pip install", // bare + "ls -la", + "go build ./...", + "docker run --rm img", + } + for _, cmd := range negatives { + got := ParseInstall(cmd) + if len(got) != 0 { + t.Errorf("%q: got %d requests, want 0 (%v)", cmd, len(got), got) + } + } +} + 
+func TestParseInstall_Flags(t *testing.T) { + tests := []struct { + command string + want []Package + }{ + {"npm install --save-dev typescript", []Package{{Name: "typescript"}}}, + {"npm install -g pkg", []Package{{Name: "pkg"}}}, + {"npm install -D typescript prettier", []Package{{Name: "typescript"}, {Name: "prettier"}}}, + {"pip install --upgrade requests", []Package{{Name: "requests"}}}, + } + for _, tt := range tests { + got := ParseInstall(tt.command) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", tt.command, len(got)) + continue + } + wantPackages(t, got[0].Packages, tt.want) + } +} + +func TestParseInstall_LeadingNoOps(t *testing.T) { + tests := []string{ + "sudo npm install lodash", + "time npm install lodash", + "NODE_ENV=production npm install lodash", + "sudo NODE_ENV=production npm install lodash", + } + for _, cmd := range tests { + got := ParseInstall(cmd) + if len(got) != 1 { + t.Errorf("%q: got %d requests, want 1", cmd, len(got)) + continue + } + if got[0].Manager != ManagerNpm { + t.Errorf("%q: got manager %v, want npm", cmd, got[0].Manager) + } + wantPackages(t, got[0].Packages, []Package{{Name: "lodash"}}) + } +} + +func TestParseInstall_ShellExpansionDropped(t *testing.T) { + // $(cat req.txt) is opaque — we cannot statically know the packages, so + // the segment should resolve to zero install requests rather than emit + // garbage package names. + got := ParseInstall("pip install $(cat req.txt)") + if len(got) != 0 { + t.Errorf("expected $() to drop, got %d requests (%v)", len(got), got) + } + + got = ParseInstall("pip install `echo lodash`") + if len(got) != 0 { + t.Errorf("expected backtick to drop, got %d requests (%v)", len(got), got) + } + + // Mixed: real package + shell expansion → keep the real one. 
+ got = ParseInstall("pip install requests $(cat extras)") + if len(got) != 1 { + t.Fatalf("expected 1 request, got %d (%v)", len(got), got) + } + if len(got[0].Packages) != 1 || got[0].Packages[0].Name != "requests" { + t.Errorf("got %v, want [requests]", got[0].Packages) + } +} + +func TestParseInstall_QuotedStrings(t *testing.T) { + // Strings that *contain* an install verb but aren't installs. + got := ParseInstall(`echo "npm install lodash"`) + if len(got) != 0 { + t.Errorf("quoted install verb should not match, got %d requests", len(got)) + } + + // Subshell containing an install. + got = ParseInstall(`bash -c "echo hello && npm install lodash"`) + // We don't recursively parse `-c "..."` arg payloads — but $() and `` we do. + // So this is a no-op (we don't dive into bash -c). Document the behaviour. + if len(got) != 0 { + t.Logf("bash -c payload: got %d requests (currently no-op by design)", len(got)) + } +} diff --git a/internal/commands/agenthooks/sca/diff.go b/internal/commands/agenthooks/sca/diff.go new file mode 100644 index 000000000..e51e83188 --- /dev/null +++ b/internal/commands/agenthooks/sca/diff.go @@ -0,0 +1,75 @@ +package sca + +import ( + "os" + "path/filepath" + + "github.com/Checkmarx/manifest-parser/pkg/parser" +) + +// AddedPackages returns the set of packages present in after but not in +// before, keyed by name+version. A version bump on an existing package is +// reported as added (its new version is new). +// +// Both before and after are parsed via manifest-parser's factory; we write +// them to temp files using a name that matches the format so the factory +// selects the right parser. An empty/missing before parses as zero packages +// (so every after-package is "added"). 
+func AddedPackages(format Format, before, after []byte) ([]Package, error) { + beforePkgs, err := parseManifestBytes(format, before) + if err != nil { + return nil, err + } + afterPkgs, err := parseManifestBytes(format, after) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}, len(beforePkgs)) + for _, p := range beforePkgs { + seen[pkgKey(p)] = struct{}{} + } + var added []Package + for _, p := range afterPkgs { + if _, ok := seen[pkgKey(p)]; ok { + continue + } + added = append(added, p) + } + return added, nil +} + +func parseManifestBytes(format Format, content []byte) ([]Package, error) { + if len(content) == 0 { + return nil, nil + } + dir, err := os.MkdirTemp("", "sca-diff-") + if err != nil { + return nil, err + } + defer os.RemoveAll(dir) + + path := filepath.Join(dir, format.SynthFileName()) + if writeErr := os.WriteFile(path, content, 0600); writeErr != nil { + return nil, writeErr + } + + p := parser.ParsersFactory(path) + if p == nil { + return nil, nil + } + rawPkgs, err := p.Parse(path) + if err != nil { + return nil, err + } + + out := make([]Package, 0, len(rawPkgs)) + for _, rp := range rawPkgs { + out = append(out, Package{Name: rp.PackageName, Version: rp.Version}) + } + return out, nil +} + +func pkgKey(p Package) string { + return p.Name + "\x00" + p.Version +} diff --git a/internal/commands/agenthooks/sca/diff_test.go b/internal/commands/agenthooks/sca/diff_test.go new file mode 100644 index 000000000..19844ab53 --- /dev/null +++ b/internal/commands/agenthooks/sca/diff_test.go @@ -0,0 +1,74 @@ +//go:build !integration + +package sca + +import ( + "testing" +) + +func TestAddedPackages_Npm_AddedNewPackage(t *testing.T) { + before := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.21"}}`) + after := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.21","axios":"1.0.0"}}`) + added, err := AddedPackages(FormatNpmPackageJson, before, after) + if err != nil { + 
t.Fatalf("AddedPackages: %v", err) + } + if len(added) != 1 || added[0].Name != "axios" { + t.Errorf("got added=%v, want [axios]", added) + } +} + +func TestAddedPackages_Npm_VersionBumpCountsAsAdded(t *testing.T) { + before := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.0"}}`) + after := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.21"}}`) + added, err := AddedPackages(FormatNpmPackageJson, before, after) + if err != nil { + t.Fatalf("AddedPackages: %v", err) + } + if len(added) != 1 || added[0].Name != "lodash" || added[0].Version != "4.17.21" { + t.Errorf("got added=%v, want [lodash@4.17.21]", added) + } +} + +func TestAddedPackages_Npm_RemovedPackageIsIgnored(t *testing.T) { + before := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.21","axios":"1.0.0"}}`) + after := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.21"}}`) + added, err := AddedPackages(FormatNpmPackageJson, before, after) + if err != nil { + t.Fatalf("AddedPackages: %v", err) + } + if len(added) != 0 { + t.Errorf("got %d added, want 0 (%v)", len(added), added) + } +} + +func TestAddedPackages_Npm_NewFile(t *testing.T) { + after := []byte(`{"name":"x","version":"1.0.0","dependencies":{"lodash":"4.17.21","axios":"1.0.0"}}`) + added, err := AddedPackages(FormatNpmPackageJson, nil, after) + if err != nil { + t.Fatalf("AddedPackages: %v", err) + } + if len(added) != 2 { + t.Errorf("got %d added, want 2 (%v)", len(added), added) + } +} + +func TestAddedPackages_Pypi_AddedPackage(t *testing.T) { + before := []byte("requests==2.25.1\n") + after := []byte("requests==2.25.1\nflask==2.0.0\n") + added, err := AddedPackages(FormatPypiRequirements, before, after) + if err != nil { + t.Fatalf("AddedPackages: %v", err) + } + if len(added) != 1 || added[0].Name != "flask" { + t.Errorf("got added=%v, want [flask]", added) + } +} + +func TestAddedPackages_UnparseableContent(t *testing.T) { + // Note: behaviour for 
unparseable content depends on the upstream parser. + // We assert that errors flow back to the caller, not that any specific + // content causes an error — the caller's contract is "treat errors as + // fail-open" so callers don't depend on a particular outcome here. + _, _ = AddedPackages(FormatNpmPackageJson, nil, []byte("{not valid json")) +} diff --git a/internal/commands/agenthooks/sca/manifests.go b/internal/commands/agenthooks/sca/manifests.go new file mode 100644 index 000000000..0d6702346 --- /dev/null +++ b/internal/commands/agenthooks/sca/manifests.go @@ -0,0 +1,98 @@ +// Package sca implements Open Source Software (OSS) realtime guardrails for +// AI coding agents: install-command interception and manifest-edit gating. +package sca + +import ( + "path/filepath" + "strings" +) + +// Format identifies one of the manifest shapes that the SCA guardrails care +// about. Each format lines up 1:1 with a format the Checkmarx manifest-parser +// library recognises. +type Format int + +const ( + FormatUnknown Format = iota + FormatNpmPackageJson + FormatPypiRequirements + FormatGoMod + FormatMavenPom + FormatDotnetCsproj + FormatDotnetDirectoryPackagesProps + FormatDotnetPackagesConfig +) + +// IsManifest reports whether path names a manifest file the OSS realtime +// scanner can analyse. 
The rules mirror manifest-parser's selectManifestFile: +// +// - *.csproj → Dotnet csproj +// - requirements*.txt, packages*.txt → Pypi requirements +// - pom.xml → Maven +// - package.json → Npm +// - Directory.Packages.props → Dotnet central package management +// - packages.config → Dotnet legacy +// - go.mod → Go modules +func IsManifest(path string) (Format, bool) { + base := filepath.Base(path) + ext := filepath.Ext(base) + + switch { + case ext == ".csproj": + return FormatDotnetCsproj, true + case ext == ".txt" && (strings.HasPrefix(base, "requirement") || strings.HasPrefix(base, "packages")): + return FormatPypiRequirements, true + case base == "pom.xml": + return FormatMavenPom, true + case base == "package.json": + return FormatNpmPackageJson, true + case base == "Directory.Packages.props": + return FormatDotnetDirectoryPackagesProps, true + case base == "packages.config": + return FormatDotnetPackagesConfig, true + case base == "go.mod": + return FormatGoMod, true + } + return FormatUnknown, false +} + +// ManagerName returns the package-manager string oss-realtime uses for the +// given format (matching the PackageManager field on OssPackage). +func (f Format) ManagerName() string { + switch f { + case FormatNpmPackageJson: + return "npm" + case FormatPypiRequirements: + return "pypi" + case FormatGoMod: + return "go" + case FormatMavenPom: + return "maven" + case FormatDotnetCsproj, FormatDotnetDirectoryPackagesProps, FormatDotnetPackagesConfig: + return "nuget" + } + return "" +} + +// SynthFileName returns a filename to use when writing a temp manifest in the +// given format. The name matters because manifest-parser's factory selects a +// parser by basename/extension. 
+func (f Format) SynthFileName() string { + switch f { + case FormatNpmPackageJson: + return "package.json" + case FormatPypiRequirements: + return "requirements.txt" + case FormatGoMod: + return "go.mod" + case FormatMavenPom: + return "pom.xml" + case FormatDotnetCsproj: + return "synth.csproj" + case FormatDotnetDirectoryPackagesProps: + return "Directory.Packages.props" + case FormatDotnetPackagesConfig: + return "packages.config" + } + return "" +} diff --git a/internal/commands/agenthooks/sca/manifests_test.go b/internal/commands/agenthooks/sca/manifests_test.go new file mode 100644 index 000000000..801f22e58 --- /dev/null +++ b/internal/commands/agenthooks/sca/manifests_test.go @@ -0,0 +1,40 @@ +//go:build !integration + +package sca + +import "testing" + +func TestIsManifest(t *testing.T) { + tests := []struct { + path string + wantOK bool + wantFmt Format + }{ + {"package.json", true, FormatNpmPackageJson}, + {"/repo/package.json", true, FormatNpmPackageJson}, + {"requirements.txt", true, FormatPypiRequirements}, + {"requirements-dev.txt", true, FormatPypiRequirements}, + {"packages.txt", true, FormatPypiRequirements}, + {"go.mod", true, FormatGoMod}, + {"pom.xml", true, FormatMavenPom}, + {"app.csproj", true, FormatDotnetCsproj}, + {"Project.csproj", true, FormatDotnetCsproj}, + {"Directory.Packages.props", true, FormatDotnetDirectoryPackagesProps}, + {"packages.config", true, FormatDotnetPackagesConfig}, + + // Negatives. 
+ {"main.go", false, FormatUnknown}, + {"Dockerfile", false, FormatUnknown}, + {"README.md", false, FormatUnknown}, + {"random.txt", false, FormatUnknown}, + {"", false, FormatUnknown}, + } + for _, tt := range tests { + gotFmt, gotOK := IsManifest(tt.path) + if gotOK != tt.wantOK || gotFmt != tt.wantFmt { + t.Errorf("IsManifest(%q) = (%v, %v), want (%v, %v)", + tt.path, gotFmt, gotOK, tt.wantFmt, tt.wantOK) + } + } +} + diff --git a/internal/commands/agenthooks/sca/prompts.go b/internal/commands/agenthooks/sca/prompts.go new file mode 100644 index 000000000..d80268d0b --- /dev/null +++ b/internal/commands/agenthooks/sca/prompts.go @@ -0,0 +1,125 @@ +package sca + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ossrealtime" +) + +// DenyMalicious returns the finding and remediation strings for one or more +// packages classified as Malicious. +func DenyMalicious(pkgs []ossrealtime.OssPackage) (finding, remediation string) { + var b strings.Builder + b.WriteString("Checkmarx SCA scan detected MALICIOUS package(s):\n") + for _, p := range pkgs { + fmt.Fprintf(&b, " - %s: Known supply chain attack.\n", pkgLabel(p)) + } + b.WriteString("\nDo NOT proceed with the installation.") + return b.String(), remediationNote("malicious package", "safest available version") +} + +// DenyVulnerable returns the finding and remediation strings for one or more +// packages with known vulnerabilities. workDir anchors the suppression command's +// --ignored-file-path to the workspace ignore file (see vulnerableRemediationNote). 
+func DenyVulnerable(pkgs []ossrealtime.OssPackage, workDir string) (finding, remediation string) { + var b strings.Builder + b.WriteString("Checkmarx SCA scan found vulnerabilities:\n") + for _, p := range pkgs { + fmt.Fprintf(&b, " - %s: %s\n", pkgLabel(p), vulnDetail(p)) + } + return b.String(), vulnerableRemediationNote(pkgs, workDir) +} + +// remediationNote returns the action steps to include as additionalContext for malicious packages. +// Remediation must go through the MCP tool; if it is unavailable Claude must self-recover +// by running the registration helper rather than asking the user. +func remediationNote(subject, goal string) string { + return fmt.Sprintf( + "Action required:\n"+ + " 1. Use mcp__Checkmarx__packageRemediation for each %s to find the %s.\n"+ + " This is the only supported remediation path — do not attempt manual version selection.\n"+ + " 2. If mcp__Checkmarx__packageRemediation is unavailable, recover it yourself — do NOT ask\n"+ + " the user. Run: bash \"${CLAUDE_PLUGIN_ROOT}/mcp/cx_mcp_register.sh\"\n"+ + " Then tell the user: 'I have registered the Checkmarx MCP. Please run /mcp (or restart\n"+ + " Claude Code) to load it, then retry.' Do not proceed until the MCP is available.\n"+ + " 3. If no safe version exists, the user can acknowledge the finding via\n"+ + " the Checkmarx Dev Assist interface.", + subject, goal) +} + +// vulnerableRemediationNote returns the action steps for vulnerable packages. +// When no safe version is found, the agent runs the per-package ignore command +// and informs the user. 
+func vulnerableRemediationNote(pkgs []ossrealtime.OssPackage, workDir string) string { + cxBinary := cxExecutable() + ignoreFlag := ignoredFilePathFlag(workDir) + var suppressCmds strings.Builder + for _, p := range pkgs { + data, _ := json.Marshal([]map[string]string{{ + "PackageManager": p.PackageManager, + "PackageName": p.PackageName, + "PackageVersion": p.PackageVersion, + }}) + fmt.Fprintf(&suppressCmds, " %s ignore-vulnerability --scan-type sca --data '%s'%s\n", cxBinary, string(data), ignoreFlag) + } + return fmt.Sprintf( + "Action required:\n"+ + " 1. Use mcp__Checkmarx__packageRemediation for each affected package to find the non-vulnerable version.\n"+ + " This is the only supported remediation path — do not attempt manual version selection.\n"+ + " 2. If mcp__Checkmarx__packageRemediation is unavailable, recover it yourself — do NOT ask\n"+ + " the user. Run: bash \"${CLAUDE_PLUGIN_ROOT}/mcp/cx_mcp_register.sh\"\n"+ + " Then tell the user: 'I have registered the Checkmarx MCP. Please run /mcp (or restart\n"+ + " Claude Code) to load it, then retry.' Do not proceed until the MCP is available.\n"+ + " 3. If no safe version exists for a package, suppress it by running the corresponding command\n"+ + " and inform the user that no safer version is available:\n%s", + suppressCmds.String()) +} + +// ignoredFilePathFlag returns the " --ignored-file-path ''" fragment that +// pins the suppression command to the workspace ignore file, anchored at the hook +// event's workDir. This keeps the write (cx ignore-vulnerability) and the later +// read (the hook) on the same absolute file regardless of either process's CWD — +// without it, a host CLI that runs the agent's shell from a different directory +// than the hook (e.g. Copilot CLI) would write and read different files. Returns +// "" when workDir is unknown so the command falls back to its CWD-relative default. 
+func ignoredFilePathFlag(workDir string) string { + if workDir == "" { + return "" + } + return fmt.Sprintf(" --ignored-file-path '%s'", ignore.PathFor(workDir)) +} + +func cxExecutable() string { + cxExe, err := os.Executable() + if err != nil { + return "cx" + } + return cxExe +} + +func pkgLabel(p ossrealtime.OssPackage) string { + if p.PackageVersion == "" { + return p.PackageName + } + return p.PackageName + "@" + p.PackageVersion +} + +func vulnDetail(p ossrealtime.OssPackage) string { + if len(p.Vulnerabilities) == 0 { + return "vulnerability detected" + } + v := p.Vulnerabilities[0] + cve := v.CVE + if cve == "" { + cve = "unknown" + } + desc := v.Description + if desc == "" { + desc = "vulnerability detected" + } + return fmt.Sprintf("%s — %s", cve, desc) +} diff --git a/internal/commands/agenthooks/sca/sca.go b/internal/commands/agenthooks/sca/sca.go new file mode 100644 index 000000000..af01f0d0f --- /dev/null +++ b/internal/commands/agenthooks/sca/sca.go @@ -0,0 +1,78 @@ +package sca + +import ( + "os" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ossrealtime" +) + +// CheckBashInstall is the entry point for the pre-Bash-tool guardrail. It +// returns ("", "") to allow the command, or (finding, remediation) to block. +// Errors fail open — we never block on infrastructure failures. +// +// Compound commands produce multiple install requests; we scan each and +// return on the first finding (malicious takes precedence over vulnerable). 
+func (s *Scanner) CheckBashInstall(command, workDir string) (finding, remediation string) { + s.workDir = workDir + for _, req := range ParseInstall(command) { + mal, vuln, err := s.scanRequest(req, workDir) + if err != nil { + continue + } + if f, r := denyFrom(mal, vuln, workDir); f != "" { + return f, r + } + } + return "", "" +} + +func (s *Scanner) scanRequest(req InstallRequest, workDir string) (malicious, vulnerable []ossrealtime.OssPackage, err error) { + if req.ManifestRef != "" { + path := resolveRef(req.ManifestRef, workDir) + if _, statErr := os.Stat(path); statErr != nil { + return nil, nil, nil + } + return s.ScanFile(path) + } + if len(req.Packages) == 0 { + return nil, nil, nil + } + return s.ScanPackages(req.Manager.Format(), req.Packages) +} + +// CheckManifestEdit is the entry point for the pre-file-edit guardrail. It +// returns ("", "") to accept the edit, or (finding, remediation) to reject. +// +// Non-manifest paths are a no-op. For manifest paths we diff before/after, +// scan only the newly-added packages, and reject if any are malicious or +// vulnerable. +func (s *Scanner) CheckManifestEdit(filePath string, afterContent []byte, workDir string) (finding, remediation string) { + s.workDir = workDir + format, ok := IsManifest(filePath) + if !ok { + return "", "" + } + before, _ := os.ReadFile(filePath) // missing → empty before + added, err := AddedPackages(format, before, afterContent) + if err != nil || len(added) == 0 { + return "", "" + } + mal, vuln, err := s.ScanPackages(format, added) + if err != nil { + return "", "" + } + return denyFrom(mal, vuln, workDir) +} + +// denyFrom builds the (finding, remediation) pair. workDir anchors the +// `cx ignore-vulnerability` suppression command emitted for vulnerable packages +// to the workspace ignore file so the agent writes where the hook later reads. 
+func denyFrom(malicious, vulnerable []ossrealtime.OssPackage, workDir string) (finding, remediation string) { + if len(malicious) > 0 { + return DenyMalicious(malicious) + } + if len(vulnerable) > 0 { + return DenyVulnerable(vulnerable, workDir) + } + return "", "" +} diff --git a/internal/commands/agenthooks/sca/sca_test.go b/internal/commands/agenthooks/sca/sca_test.go new file mode 100644 index 000000000..7fe717ee5 --- /dev/null +++ b/internal/commands/agenthooks/sca/sca_test.go @@ -0,0 +1,290 @@ +//go:build !integration + +package sca + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ossrealtime" +) + +func scannerWith(pkgs ...ossrealtime.OssPackage) *Scanner { + return NewScannerWithFunc(func(string) (*ossrealtime.OssPackageResults, error) { + return &ossrealtime.OssPackageResults{Packages: pkgs}, nil + }) +} + +func TestCheckBashInstall_NoInstallCommand(t *testing.T) { + s := scannerWith(ossrealtime.OssPackage{PackageName: "anything", Status: "Malicious"}) + finding, _ := s.CheckBashInstall("ls -la", "") + if finding != "" { + t.Errorf("expected empty for non-install command, got %q", finding) + } +} + +func TestCheckBashInstall_CleanInstall(t *testing.T) { + s := scannerWith(ossrealtime.OssPackage{PackageName: "lodash", Status: "OK"}) + finding, _ := s.CheckBashInstall("npm install lodash", "") + if finding != "" { + t.Errorf("expected empty for clean install, got %q", finding) + } +} + +func TestCheckBashInstall_MaliciousMentionsMCP(t *testing.T) { + s := scannerWith(ossrealtime.OssPackage{PackageName: "lodash", PackageVersion: "4.17.21", Status: "Malicious"}) + finding, remediation := s.CheckBashInstall("npm install lodash@4.17.21", "") + if !strings.Contains(finding, "MALICIOUS") { + t.Errorf("expected finding to mention MALICIOUS, got %q", finding) + } + if !strings.Contains(remediation, 
"mcp__Checkmarx__packageRemediation") { + t.Errorf("expected remediation to reference MCP tool, got %q", remediation) + } + if !strings.Contains(remediation, "install or enable") { + t.Errorf("expected remediation to mention installing/enabling MCP, got %q", remediation) + } + if !strings.Contains(remediation, "Dev Assist") { + t.Errorf("expected remediation to mention Dev Assist fallback, got %q", remediation) + } +} + +func TestCheckBashInstall_Vulnerable(t *testing.T) { + s := scannerWith(ossrealtime.OssPackage{PackageName: "axios", PackageVersion: "0.21.0", Status: "Vulnerable"}) + finding, remediation := s.CheckBashInstall("npm install axios@0.21.0", "") + if !strings.Contains(finding, "vulnerabilities") { + t.Errorf("expected vulnerable finding, got %q", finding) + } + if !strings.Contains(remediation, "mcp__Checkmarx__packageRemediation") { + t.Errorf("expected remediation to reference MCP tool, got %q", remediation) + } + if !strings.Contains(remediation, "ignore-vulnerability") { + t.Errorf("expected remediation to include ignore command, got %q", remediation) + } + if !strings.Contains(remediation, "axios") { + t.Errorf("expected remediation to include package name, got %q", remediation) + } +} + +func TestCheckBashInstall_MaliciousTakesPrecedence(t *testing.T) { + s := scannerWith( + ossrealtime.OssPackage{PackageName: "vuln", Status: "Vulnerable"}, + ossrealtime.OssPackage{PackageName: "bad", Status: "Malicious"}, + ) + finding, _ := s.CheckBashInstall("npm install vuln bad", "") + if !strings.Contains(finding, "MALICIOUS") { + t.Errorf("expected MALICIOUS message when both present, got %q", finding) + } +} + +func TestCheckBashInstall_CompoundWithCleanThenBad(t *testing.T) { + // First call clean, second call malicious. 
+ call := 0 + s := NewScannerWithFunc(func(string) (*ossrealtime.OssPackageResults, error) { + call++ + if call == 1 { + return &ossrealtime.OssPackageResults{Packages: []ossrealtime.OssPackage{ + {PackageName: "lodash", Status: "OK"}, + }}, nil + } + return &ossrealtime.OssPackageResults{Packages: []ossrealtime.OssPackage{ + {PackageName: "evil", Status: "Malicious"}, + }}, nil + }) + finding, _ := s.CheckBashInstall("npm install lodash && pip install evil", "") + if !strings.Contains(finding, "MALICIOUS") { + t.Errorf("expected deny on second segment, got %q", finding) + } + if call != 2 { + t.Errorf("expected 2 scans, got %d", call) + } +} + +func TestCheckBashInstall_FailOpenOnScannerError(t *testing.T) { + s := NewScannerWithFunc(func(string) (*ossrealtime.OssPackageResults, error) { + return nil, errBoom + }) + finding, _ := s.CheckBashInstall("npm install lodash", "") + if finding != "" { + t.Errorf("expected fail-open empty on scanner error, got %q", finding) + } +} + +func TestCheckManifestEdit_NonManifestNoop(t *testing.T) { + s := scannerWith(ossrealtime.OssPackage{PackageName: "x", Status: "Malicious"}) + finding, _ := s.CheckManifestEdit("/repo/main.go", []byte("anything"), "") + if finding != "" { + t.Errorf("non-manifest: expected empty, got %q", finding) + } +} + +func TestCheckManifestEdit_NewMaliciousAddition(t *testing.T) { + dir, err := os.MkdirTemp("", "ck-edit-test-") + if err != nil { + t.Fatalf("mkdtemp: %v", err) + } + defer os.RemoveAll(dir) + pkgJSON := filepath.Join(dir, "package.json") + // Pre-existing state: lodash already installed. + if err := os.WriteFile(pkgJSON, []byte(`{"name":"x","dependencies":{"lodash":"4.17.21"}}`), 0600); err != nil { + t.Fatalf("write: %v", err) + } + + // Proposed edit: add evil-pkg. 
+ after := []byte(`{"name":"x","dependencies":{"lodash":"4.17.21","evil-pkg":"1.0.0"}}`) + + s := scannerWith(ossrealtime.OssPackage{PackageName: "evil-pkg", PackageVersion: "1.0.0", Status: "Malicious"}) + finding, remediation := s.CheckManifestEdit(pkgJSON, after, "") + if !strings.Contains(finding, "MALICIOUS") { + t.Errorf("expected MALICIOUS finding, got %q", finding) + } + if !strings.Contains(remediation, "mcp__Checkmarx__packageRemediation") { + t.Errorf("expected remediation to reference MCP tool, got %q", remediation) + } +} + +func TestCheckManifestEdit_OnlyVersionBumpOfCleanPkg(t *testing.T) { + dir, err := os.MkdirTemp("", "ck-edit-test-") + if err != nil { + t.Fatalf("mkdtemp: %v", err) + } + defer os.RemoveAll(dir) + pkgJSON := filepath.Join(dir, "package.json") + if err := os.WriteFile(pkgJSON, []byte(`{"name":"x","dependencies":{"lodash":"4.17.0"}}`), 0600); err != nil { + t.Fatalf("write: %v", err) + } + after := []byte(`{"name":"x","dependencies":{"lodash":"4.17.21"}}`) + + // Even though it's only a bump, the new version is "new" and gets scanned. + // If the scanner returns OK, the edit is accepted. 
+ s := scannerWith(ossrealtime.OssPackage{PackageName: "lodash", PackageVersion: "4.17.21", Status: "OK"}) + finding, _ := s.CheckManifestEdit(pkgJSON, after, "") + if finding != "" { + t.Errorf("expected accept for clean version bump, got %q", finding) + } +} + +func TestDenyVulnerable_IgnoreCommandIncludesPackageData(t *testing.T) { + pkgs := []ossrealtime.OssPackage{ + {PackageManager: "pip", PackageName: "requests", PackageVersion: "2.19.0"}, + } + _, remediation := DenyVulnerable(pkgs, "") + if !strings.Contains(remediation, "ignore-vulnerability") { + t.Errorf("expected ignore-vulnerability in remediation, got %q", remediation) + } + if !strings.Contains(remediation, "requests") { + t.Errorf("expected package name in remediation, got %q", remediation) + } + if !strings.Contains(remediation, "2.19.0") { + t.Errorf("expected package version in remediation, got %q", remediation) + } + if !strings.Contains(remediation, "pip") { + t.Errorf("expected package manager in remediation, got %q", remediation) + } +} + +func TestDenyVulnerable_MultiplePackages_EachGetsIgnoreCommand(t *testing.T) { + pkgs := []ossrealtime.OssPackage{ + {PackageManager: "npm", PackageName: "lodash", PackageVersion: "4.17.0"}, + {PackageManager: "npm", PackageName: "axios", PackageVersion: "0.21.0"}, + } + _, remediation := DenyVulnerable(pkgs, "") + if strings.Count(remediation, "ignore-vulnerability") != 2 { + t.Errorf("expected 2 ignore commands for 2 packages, got %q", remediation) + } + if !strings.Contains(remediation, "lodash") { + t.Errorf("expected lodash in remediation, got %q", remediation) + } + if !strings.Contains(remediation, "axios") { + t.Errorf("expected axios in remediation, got %q", remediation) + } +} + +func TestDenyVulnerable_PinsIgnoredFilePathToWorkDir(t *testing.T) { + pkgs := []ossrealtime.OssPackage{ + {PackageManager: "npm", PackageName: "axios", PackageVersion: "0.21.0"}, + } + workDir := filepath.Join("repo", "ws") + _, remediation := DenyVulnerable(pkgs, 
workDir) + want := "--ignored-file-path '" + ignore.PathFor(workDir) + "'" + if !strings.Contains(remediation, want) { + t.Errorf("expected remediation to pin %q, got %q", want, remediation) + } +} + +func TestDenyVulnerable_EmptyWorkDirOmitsIgnoredFilePath(t *testing.T) { + pkgs := []ossrealtime.OssPackage{ + {PackageManager: "npm", PackageName: "axios", PackageVersion: "0.21.0"}, + } + _, remediation := DenyVulnerable(pkgs, "") + if strings.Contains(remediation, "--ignored-file-path") { + t.Errorf("expected no ignored-file-path flag for empty workDir, got %q", remediation) + } +} + +func TestDenyMalicious_StillMentionsDevAssist(t *testing.T) { + pkgs := []ossrealtime.OssPackage{ + {PackageName: "evil-pkg", PackageVersion: "1.0.0"}, + } + _, remediation := DenyMalicious(pkgs) + if !strings.Contains(remediation, "Dev Assist") { + t.Errorf("malicious remediation should still mention Dev Assist, got %q", remediation) + } +} + +func TestCheckManifestEdit_VulnerableContainsIgnoreCommand(t *testing.T) { + dir, err := os.MkdirTemp("", "ck-vuln-ignore-test-") + if err != nil { + t.Fatalf("mkdtemp: %v", err) + } + defer os.RemoveAll(dir) + pkgJSON := filepath.Join(dir, "package.json") + if err := os.WriteFile(pkgJSON, []byte(`{"name":"x","dependencies":{}}`), 0600); err != nil { + t.Fatalf("write: %v", err) + } + after := []byte(`{"name":"x","dependencies":{"axios":"0.21.0"}}`) + + s := scannerWith(ossrealtime.OssPackage{ + PackageManager: "npm", PackageName: "axios", PackageVersion: "0.21.0", Status: "Vulnerable", + }) + finding, remediation := s.CheckManifestEdit(pkgJSON, after, "") + if !strings.Contains(finding, "vulnerabilities") { + t.Errorf("expected vulnerable finding, got %q", finding) + } + if !strings.Contains(remediation, "ignore-vulnerability") { + t.Errorf("expected ignore command in remediation, got %q", remediation) + } + if !strings.Contains(remediation, "axios") { + t.Errorf("expected package name in remediation, got %q", remediation) + } +} + +func 
TestExistingIgnoreFilePath_ResolvesUnderWorkDir(t *testing.T) { + workDir := t.TempDir() + dir := filepath.Join(workDir, ".checkmarx") + if err := os.MkdirAll(dir, 0o750); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "checkmarxIgnoredTempList.json"), []byte("[]"), 0o600); err != nil { + t.Fatalf("write: %v", err) + } + got := existingIgnoreFilePath(workDir) + if want := ignore.PathFor(workDir); got != want { + t.Errorf("expected ignore path %q under workDir, got %q", want, got) + } +} + +func TestExistingIgnoreFilePath_MissingReturnsEmpty(t *testing.T) { + // Empty workspace: no .checkmarx file → no filtering path passed to the scanner. + if got := existingIgnoreFilePath(t.TempDir()); got != "" { + t.Errorf("expected empty for missing ignore file, got %q", got) + } +} + +var errBoom = stringError("boom") + +type stringError string + +func (e stringError) Error() string { return string(e) } diff --git a/internal/commands/agenthooks/sca/scan.go b/internal/commands/agenthooks/sca/scan.go new file mode 100644 index 000000000..e78ab6b54 --- /dev/null +++ b/internal/commands/agenthooks/sca/scan.go @@ -0,0 +1,119 @@ +package sca + +import ( + "os" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ossrealtime" + "github.com/checkmarx/ast-cli/internal/wrappers" +) + +// statusOK and statusUnknown (plus an empty Status) are the only values the SCA +// hooks treat as "clean". Everything else is either Malicious (escalates the deny +// message) or generic Vulnerable. The literal strings come from the upstream +// realtime-scanner API and match the values used in oss-realtime tests. +const ( + statusOK = "OK" + statusMalicious = "Malicious" + statusUnknown = "Unknown" +) + +// Scanner runs oss-realtime scans on behalf of the SCA guardrails. It holds +// the wrappers needed to construct an OssRealtimeService per call.
Tests +// substitute scan via NewScannerWithFunc. +type Scanner struct { + JWT wrappers.JWTWrapper + FF wrappers.FeatureFlagsWrapper + RT wrappers.RealtimeScannerWrapper + scan func(path string) (*ossrealtime.OssPackageResults, error) + // workDir is the workspace root reported by the hook event (its "cwd"). The + // Check* entry points set it so the ignore-file lookup is anchored to the + // workspace rather than the hook process's own working directory. A hook runs + // as a one-shot process handling a single event, so this single field is safe. + workDir string +} + +// NewScanner returns a Scanner backed by the given wrappers. The scan call +// goes through ossrealtime.NewOssRealtimeService.RunOssRealtimeScan. +func NewScanner(jwt wrappers.JWTWrapper, ff wrappers.FeatureFlagsWrapper, rt wrappers.RealtimeScannerWrapper) *Scanner { + s := &Scanner{JWT: jwt, FF: ff, RT: rt} + s.scan = s.runRealScan + return s +} + +// NewScannerWithFunc returns a Scanner whose scan call is replaced with f. +// For unit tests only. +func NewScannerWithFunc(f func(path string) (*ossrealtime.OssPackageResults, error)) *Scanner { + return &Scanner{scan: f} +} + +func (s *Scanner) runRealScan(path string) (*ossrealtime.OssPackageResults, error) { + svc := ossrealtime.NewOssRealtimeService(s.JWT, s.FF, s.RT) + return svc.RunOssRealtimeScan(path, existingIgnoreFilePath(s.workDir)) +} + +// existingIgnoreFilePath returns the realtime ignore-file path (anchored at the +// hook event's workDir) when it exists on disk, and "" otherwise. Passing a +// missing path to RunOssRealtimeScan would be harmless, but returning "" keeps +// to the ASCA pattern of only enabling filtering once the file exists. +func existingIgnoreFilePath(workDir string) string { + p := ignore.PathFor(workDir) + if _, err := os.Stat(p); err == nil { + return p + } + return "" +} + +// ScanPackages synthesises a temp manifest from pkgs and scans it. Returns +// (malicious, vulnerable) buckets.
On error the buckets are nil and the error +// is propagated — callers fail open by treating errors as "no findings". +func (s *Scanner) ScanPackages(format Format, pkgs []Package) (malicious, vulnerable []ossrealtime.OssPackage, err error) { + if len(pkgs) == 0 { + return nil, nil, nil + } + dir, err := os.MkdirTemp("", "sca-scan-") + if err != nil { + return nil, nil, err + } + defer os.RemoveAll(dir) + + // Versions are passed through exactly as parsed. The realtime scanner matches + // on the literal version string, so padding (e.g. Maven 1.7 -> 1.7.0) makes the + // backend mismatch / time out. Callers that need bare-version handling (bash + // installs, e.g. parseNpmSpec) normalize at parse time instead. + path, err := Synthesize(format, pkgs, dir) + if err != nil { + return nil, nil, err + } + return s.scanAndBucket(path) +} + +// ScanFile scans an existing manifest at path (used for `pip install -r ...` +// and for the Cursor post-write audit). +func (s *Scanner) ScanFile(path string) (malicious, vulnerable []ossrealtime.OssPackage, err error) { + return s.scanAndBucket(path) +} + +func (s *Scanner) scanAndBucket(path string) (malicious, vulnerable []ossrealtime.OssPackage, err error) { + if s == nil || s.scan == nil { + return nil, nil, nil + } + results, err := s.scan(path) + if err != nil { + return nil, nil, err + } + if results == nil { + return nil, nil, nil + } + for _, p := range results.Packages { + switch p.Status { + case statusMalicious: + malicious = append(malicious, p) + case statusOK, statusUnknown, "": + // clean / not classified — ignore + default: + vulnerable = append(vulnerable, p) + } + } + return malicious, vulnerable, nil +} diff --git a/internal/commands/agenthooks/sca/scan_test.go b/internal/commands/agenthooks/sca/scan_test.go new file mode 100644 index 000000000..ab361ed38 --- /dev/null +++ b/internal/commands/agenthooks/sca/scan_test.go @@ -0,0 +1,85 @@ +//go:build !integration + +package sca + +import ( + "errors" + "testing" + 
+ "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ossrealtime" +) + +func fakeResults(pkgs ...ossrealtime.OssPackage) func(string) (*ossrealtime.OssPackageResults, error) { + return func(string) (*ossrealtime.OssPackageResults, error) { + return &ossrealtime.OssPackageResults{Packages: pkgs}, nil + } +} + +func TestScanner_BucketsByStatus(t *testing.T) { + s := NewScannerWithFunc(fakeResults( + ossrealtime.OssPackage{PackageName: "ok", Status: "OK"}, + ossrealtime.OssPackage{PackageName: "bad", Status: "Malicious"}, + ossrealtime.OssPackage{PackageName: "vuln", Status: "Vulnerable"}, + ossrealtime.OssPackage{PackageName: "huh", Status: "Unknown"}, + )) + mal, vuln, err := s.ScanPackages(FormatNpmPackageJson, []Package{{Name: "x"}}) + if err != nil { + t.Fatalf("ScanPackages: %v", err) + } + if len(mal) != 1 || mal[0].PackageName != "bad" { + t.Errorf("malicious=%v, want [bad]", mal) + } + if len(vuln) != 1 || vuln[0].PackageName != "vuln" { + t.Errorf("vulnerable=%v, want [vuln]", vuln) + } +} + +func TestScanner_AllClean(t *testing.T) { + s := NewScannerWithFunc(fakeResults( + ossrealtime.OssPackage{PackageName: "a", Status: "OK"}, + ossrealtime.OssPackage{PackageName: "b", Status: "OK"}, + )) + mal, vuln, err := s.ScanPackages(FormatNpmPackageJson, []Package{{Name: "a"}, {Name: "b"}}) + if err != nil { + t.Fatalf("ScanPackages: %v", err) + } + if len(mal) != 0 || len(vuln) != 0 { + t.Errorf("expected no findings, got mal=%v vuln=%v", mal, vuln) + } +} + +func TestScanner_UpstreamErrorPropagates(t *testing.T) { + wantErr := errors.New("boom") + s := NewScannerWithFunc(func(string) (*ossrealtime.OssPackageResults, error) { + return nil, wantErr + }) + _, _, err := s.ScanPackages(FormatNpmPackageJson, []Package{{Name: "x"}}) + if !errors.Is(err, wantErr) { + t.Errorf("got err %v, want %v", err, wantErr) + } +} + +func TestScanner_EmptyPackagesIsNoop(t *testing.T) { + called := false + s := NewScannerWithFunc(func(string) 
(*ossrealtime.OssPackageResults, error) { + called = true + return nil, nil + }) + mal, vuln, err := s.ScanPackages(FormatNpmPackageJson, nil) + if err != nil || len(mal) != 0 || len(vuln) != 0 { + t.Errorf("expected zero results no error, got mal=%v vuln=%v err=%v", mal, vuln, err) + } + if called { + t.Errorf("scan should not be invoked for empty package list") + } +} + +func TestScanner_NilResultsAreSafe(t *testing.T) { + s := NewScannerWithFunc(func(string) (*ossrealtime.OssPackageResults, error) { + return nil, nil + }) + mal, vuln, err := s.ScanPackages(FormatNpmPackageJson, []Package{{Name: "x"}}) + if err != nil || len(mal) != 0 || len(vuln) != 0 { + t.Errorf("expected zero results no error, got mal=%v vuln=%v err=%v", mal, vuln, err) + } +} diff --git a/internal/commands/agenthooks/sca/synth.go b/internal/commands/agenthooks/sca/synth.go new file mode 100644 index 000000000..f0f99f34f --- /dev/null +++ b/internal/commands/agenthooks/sca/synth.go @@ -0,0 +1,149 @@ +package sca + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" +) + +// Synthesize writes a minimal manifest of the given format containing pkgs +// into dir, and returns the written file's path. The synthesised file's basename +// matches what manifest-parser expects so its parser factory picks the right +// parser when the scanner reads the file back.
+func Synthesize(format Format, pkgs []Package, dir string) (string, error) { + name := format.SynthFileName() + if name == "" { + return "", fmt.Errorf("sca: unsupported synth format %d", format) + } + path := filepath.Join(dir, name) + + var content []byte + var err error + switch format { + case FormatNpmPackageJson: + content, err = synthNpm(pkgs) + case FormatPypiRequirements: + content = synthPypi(pkgs) + case FormatGoMod: + content = synthGoMod(pkgs) + case FormatMavenPom: + content = synthMavenPom(pkgs) + case FormatDotnetCsproj: + content = synthCsproj(pkgs) + case FormatDotnetDirectoryPackagesProps: + content = synthDirectoryPackagesProps(pkgs) + case FormatDotnetPackagesConfig: + content = synthPackagesConfig(pkgs) + default: + return "", fmt.Errorf("sca: unsupported synth format %d", format) + } + if err != nil { + return "", err + } + if writeErr := os.WriteFile(path, content, 0600); writeErr != nil { + return "", writeErr + } + return path, nil +} + +func synthNpm(pkgs []Package) ([]byte, error) { + deps := make(map[string]string, len(pkgs)) + for _, p := range pkgs { + v := p.Version + if v == "" { + v = "latest" + } + deps[p.Name] = v + } + manifest := map[string]any{ + "name": "sca-scan-temp", + "version": "1.0.0", + "dependencies": deps, + } + return json.MarshalIndent(manifest, "", " ") +} + +func synthPypi(pkgs []Package) []byte { + var b strings.Builder + for _, p := range pkgs { + if p.Version != "" { + fmt.Fprintf(&b, "%s==%s\n", p.Name, p.Version) + } else { + fmt.Fprintf(&b, "%s\n", p.Name) + } + } + return []byte(b.String()) +} + +func synthGoMod(pkgs []Package) []byte { + var b strings.Builder + b.WriteString("module sca-scan-temp\n\ngo 1.21\n\nrequire (\n") + for _, p := range pkgs { + v := p.Version + if v == "" { + v = "latest" + } + fmt.Fprintf(&b, "\t%s %s\n", p.Name, v) + } + b.WriteString(")\n") + return []byte(b.String()) +} + +func synthMavenPom(pkgs []Package) []byte { + var b strings.Builder + b.WriteString(` + + 4.0.0 + sca + 
sca-scan-temp + 1.0.0 + +`) + for _, p := range pkgs { + // Name is "groupId:artifactId". + group, artifact := splitMavenName(p.Name) + fmt.Fprintf(&b, " \n %s\n %s\n %s\n \n", group, artifact, p.Version) + } + b.WriteString(" \n\n") + return []byte(b.String()) +} + +func splitMavenName(name string) (string, string) { + idx := strings.Index(name, ":") + if idx < 0 { + return name, name + } + return name[:idx], name[idx+1:] +} + +func synthCsproj(pkgs []Package) []byte { + var b strings.Builder + b.WriteString("\n \n") + for _, p := range pkgs { + fmt.Fprintf(&b, " \n", p.Name, p.Version) + } + b.WriteString(" \n\n") + return []byte(b.String()) +} + +func synthDirectoryPackagesProps(pkgs []Package) []byte { + var b strings.Builder + b.WriteString("\n \n") + for _, p := range pkgs { + fmt.Fprintf(&b, " \n", p.Name, p.Version) + } + b.WriteString(" \n\n") + return []byte(b.String()) +} + +func synthPackagesConfig(pkgs []Package) []byte { + var b strings.Builder + b.WriteString("\n\n") + for _, p := range pkgs { + fmt.Fprintf(&b, " \n", p.Name, p.Version) + } + b.WriteString("\n") + return []byte(b.String()) +} diff --git a/internal/commands/agenthooks/sca/synth_test.go b/internal/commands/agenthooks/sca/synth_test.go new file mode 100644 index 000000000..3cf8fa346 --- /dev/null +++ b/internal/commands/agenthooks/sca/synth_test.go @@ -0,0 +1,96 @@ +//go:build !integration + +package sca + +import ( + "os" + "testing" + + "github.com/Checkmarx/manifest-parser/pkg/parser" +) + +// roundTrip runs Synthesize, then re-parses the file via manifest-parser, and +// asserts that the set of name+version pairs matches the input. 
+func roundTrip(t *testing.T, format Format, pkgs []Package) { + t.Helper() + dir, err := os.MkdirTemp("", "synth-test-") + if err != nil { + t.Fatalf("mkdtemp: %v", err) + } + defer os.RemoveAll(dir) + + path, err := Synthesize(format, pkgs, dir) + if err != nil { + t.Fatalf("Synthesize: %v", err) + } + + p := parser.ParsersFactory(path) + if p == nil { + t.Fatalf("manifest-parser has no parser for %s", path) + } + parsed, err := p.Parse(path) + if err != nil { + t.Fatalf("Parse(%s): %v", path, err) + } + + want := make(map[string]string, len(pkgs)) + for _, pkg := range pkgs { + want[pkg.Name] = pkg.Version + } + for _, parsedPkg := range parsed { + v, ok := want[parsedPkg.PackageName] + if !ok { + t.Errorf("unexpected package after parse: %s@%s", parsedPkg.PackageName, parsedPkg.Version) + continue + } + if v != "" && parsedPkg.Version != v { + t.Errorf("%s: version %q after parse, want %q", parsedPkg.PackageName, parsedPkg.Version, v) + } + delete(want, parsedPkg.PackageName) + } + for n := range want { + t.Errorf("package %s missing after parse", n) + } +} + +func TestSynthesize_Npm(t *testing.T) { + roundTrip(t, FormatNpmPackageJson, []Package{ + {Name: "lodash", Version: "4.17.21"}, + {Name: "axios", Version: "1.0.0"}, + {Name: "@types/node", Version: "18.0.0"}, + }) +} + +func TestSynthesize_Pypi(t *testing.T) { + roundTrip(t, FormatPypiRequirements, []Package{ + {Name: "requests", Version: "2.25.1"}, + {Name: "flask", Version: "2.0.0"}, + }) +} + +func TestSynthesize_GoMod(t *testing.T) { + roundTrip(t, FormatGoMod, []Package{ + {Name: "github.com/pkg/errors", Version: "v0.9.1"}, + }) +} + +func TestSynthesize_Csproj(t *testing.T) { + roundTrip(t, FormatDotnetCsproj, []Package{ + {Name: "Newtonsoft.Json", Version: "13.0.1"}, + }) +} + +func TestSynthesize_PackagesConfig(t *testing.T) { + roundTrip(t, FormatDotnetPackagesConfig, []Package{ + {Name: "Newtonsoft.Json", Version: "13.0.1"}, + }) +} + +func TestSynthesize_UnsupportedFormat(t *testing.T) { + dir, 
_ := os.MkdirTemp("", "synth-test-") + defer os.RemoveAll(dir) + _, err := Synthesize(FormatUnknown, nil, dir) + if err == nil { + t.Errorf("Synthesize(FormatUnknown) returned nil error, want non-nil") + } +} diff --git a/internal/commands/auth.go b/internal/commands/auth.go index 362b74763..dea970040 100644 --- a/internal/commands/auth.go +++ b/internal/commands/auth.go @@ -113,7 +113,7 @@ func NewAuthCommand(authWrapper wrappers.AuthWrapper, telemetryWrapper wrappers. }, RunE: validLogin(telemetryWrapper), } - authCmd.AddCommand(createClientCmd, validLoginCmd) + authCmd.AddCommand(createClientCmd, validLoginCmd, newAuthLoginCommand(), newAuthLogoutCommand()) return authCmd } diff --git a/internal/commands/auth_login.go b/internal/commands/auth_login.go new file mode 100644 index 000000000..2c74e6aca --- /dev/null +++ b/internal/commands/auth_login.go @@ -0,0 +1,256 @@ +package commands + +import ( + "context" + "fmt" + "os" + + "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/configuration" + "github.com/pkg/errors" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// defaultLoginClientID matches the Keycloak client used by the Checkmarx One +// VS Code extension's OAuth flow. Confirmed via the official extension source +// (Checkmarx/ast-vscode-extension, packages/core/src/services/authService.ts). +// This client has localhost callbacks whitelisted across production tenants. +const defaultLoginClientID = "ide-integration" + +// configFilePerm restricts the yaml config file to owner read/write only, since +// after login it holds a long-lived refresh token. Mirrors the 0o600 used for +// the global-session and active-mode files in the wrappers package. 
+const configFilePerm = 0o600 + +func newAuthLoginCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "login", + Short: "Authenticate to Checkmarx One via browser-based OAuth", + Long: "Opens the default browser, walks the user through the Checkmarx One IAM login " + + "(including MFA), and persists the resulting refresh token. The --session flag picks " + + "the storage mode: default (yaml) for backward-compatible cross-shell persistence, " + + "'local' for current-shell env-only via Invoke-Expression / eval, or 'global' for a " + + "dedicated disk file shared across shells. Every login revokes any existing token " + + "server-side and clears file storage before issuing the new credential.", + Example: heredoc.Doc(` + # Default (yaml) — saves refresh token to ~/.checkmarx/checkmarxcli.yaml + $ cx auth login --tenant my-tenant + + # Local session mode — refresh token lives in current shell's env var only + # PowerShell: + $ Invoke-Expression (cx auth login --tenant my-tenant --session local) + # bash / zsh: + $ eval "$(cx auth login --tenant my-tenant --session local)" + + # Global session mode — refresh token persists in ~/.checkmarx/session_global, + # accessible to every shell, until explicit logout + $ cx auth login --tenant my-tenant --session global + `), + Annotations: map[string]string{ + "command:doc": heredoc.Doc(` + https://checkmarx.com/resource/documents/en/34965-68627-auth.html + `), + }, + RunE: runAuthLogin, + } + cmd.Flags().Int(params.LoginPortFlag, 0, params.LoginPortFlagUsage) + cmd.Flags().Bool(params.LoginNoBrowserFlag, false, params.LoginNoBrowserFlagUsage) + cmd.Flags().String(params.SessionFlag, "", params.SessionLoginFlagUsage) + return cmd +} + +func runAuthLogin(cmd *cobra.Command, _ []string) error { + // cx auth login starts a new login session. 
The user's explicit --tenant / + // --base-auth-uri flags must win over the realm URL embedded in any existing + // API key's JWT claims — they may be logging into a different tenant than + // their current credential is for. + viper.Set(params.ApikeyOverrideFlag, true) + + sessionMode, _ := cmd.Flags().GetString(params.SessionFlag) + if err := validateSessionFlag(sessionMode); err != nil { + return err + } + + realmURL, err := wrappers.GetRealmURL() + if err != nil { + return errors.Wrap(err, "failed to resolve IAM realm URL") + } + + // revokeClientID is used ONLY for the best-effort revocation of any + // PRE-EXISTING stored tokens during the nuke phase. It intentionally keeps + // the CX_CLIENT_ID fallback so that a credential originally issued to that + // client can still be revoked. It is NOT used for the interactive login + // below (see the ClientID note on LoginWithPKCE). + revokeClientID := viper.GetString(params.AccessKeyIDConfigKey) + if revokeClientID == "" { + revokeClientID = defaultLoginClientID + } + + port, _ := cmd.Flags().GetInt(params.LoginPortFlag) + noBrowser, _ := cmd.Flags().GetBool(params.LoginNoBrowserFlag) + + // Authenticate FIRST and only touch existing credentials once we hold a + // fresh refresh token. If the browser flow fails or is cancelled (closed + // tab, timeout, port clash, network blip), the user's existing credential + // is left completely intact instead of being wiped before login even runs. + // + // The interactive PKCE flow MUST use the public 'ide-integration' client + // (its localhost callbacks are whitelisted and it needs no client secret). + // CX_CLIENT_ID is a confidential service-account client and cannot complete + // an Authorization Code + PKCE flow, so it is deliberately NOT used here. 
+ tokens, err := wrappers.LoginWithPKCE(context.Background(), wrappers.PKCELoginOptions{ + RealmURL: realmURL, + ClientID: defaultLoginClientID, + Port: port, + OpenBrowser: !noBrowser, + }) + if err != nil { + return err + } + + // Nuke phase: now that a new credential exists, revoke every prior refresh + // token server-side and clear the file storages. Combined with the persist + // step below this leaves exactly one active credential in the storage + // matching --session. + nukeAllStorages(revokeClientID) + + switch sessionMode { + case params.SessionLocalValue: + return persistLocalLogin(cmd, tokens.RefreshToken) + case params.SessionGlobalValue: + return persistGlobalLogin(cmd, tokens.RefreshToken) + default: + return persistYamlLogin(cmd, tokens.RefreshToken) + } +} + +// validateSessionFlag enforces that --session is either unset, "local", or +// "global". Any other value gets a clear error instead of silently falling +// through to default-mode behavior. +func validateSessionFlag(sessionMode string) error { + switch sessionMode { + case "", params.SessionLocalValue, params.SessionGlobalValue: + return nil + default: + return errors.Errorf("invalid --session value %q: must be %q or %q", + sessionMode, params.SessionLocalValue, params.SessionGlobalValue) + } +} + +// nukeAllStorages revokes the tokens the CLI actually owns — the yaml config +// file and the global session file — at IAM (best-effort, via the OAuth 2.0 +// revocation endpoint) and clears those file storages. +// +// The CX_APIKEY environment variable is deliberately left untouched: a child +// process cannot clear a parent shell's env var, and that env value is most +// often a deliberately-provided CI / long-lived credential. Silently revoking +// it server-side would break the caller's pipeline, so we never revoke env. 
+// +// This runs on every login (regardless of mode), once the new credential has +// been obtained and before it is persisted, and on every logout, ensuring that +// the CLI's own file storages hold at most one active credential afterwards. +func nukeAllStorages(clientID string) { + // Revoke yaml's token first — read the yaml file directly to bypass any + // stale env shadowing in viper's normal lookup. + if yamlRT := wrappers.ReadYamlAPIKey(); yamlRT != "" { + revokeOldRefreshToken(yamlRT, clientID, "yaml") + } + if globalRT, err := wrappers.ReadSessionGlobal(); err == nil && globalRT != "" { + revokeOldRefreshToken(globalRT, clientID, "global") + } + clearFileStorages() +} + +// revokeOldRefreshToken POSTs the given refresh token to the realm extracted +// from its own JWT "aud" claim. Best-effort — failures are logged at verbose +// level so a missing realm claim or a non-2xx response doesn't block the new +// login. +func revokeOldRefreshToken(refreshToken, clientID, sourceLabel string) { + realmURL, err := wrappers.ExtractFromTokenClaims(refreshToken, audClaim) + if err != nil || realmURL == "" { + logger.PrintIfVerbose(fmt.Sprintf("could not extract realm from %s refresh token (skipping revoke): %v", sourceLabel, err)) + return + } + if err := wrappers.RevokeRefreshToken(context.Background(), realmURL, clientID, refreshToken); err != nil { + logger.PrintIfVerbose(fmt.Sprintf("revoke of %s refresh token failed (continuing): %v", sourceLabel, err)) + } +} + +// clearFileStorages empties the yaml cx_apikey field and deletes the global +// session file. Best-effort — failures are logged at verbose level. Env is +// not touched here; that's done via shell-eval emission for local-mode +// logins or by the user closing their shell.
+func clearFileStorages() { + if configPath, err := configuration.GetConfigFilePath(); err == nil { + if writeErr := configuration.SafeWriteSingleConfigKeyString(configPath, params.AstAPIKey, ""); writeErr != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to clear yaml cx_apikey: %v", writeErr)) + } + } + if err := wrappers.ClearSessionGlobal(); err != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to clear global session file: %v", err)) + } +} + +// persistYamlLogin writes the new refresh token to the yaml config file and +// records yaml as the active mode. The token is NOT echoed to stdout — it is +// already persisted to the config file, and printing it would leak the +// credential into shell history / CI logs. +func persistYamlLogin(cmd *cobra.Command, refreshToken string) error { + configPath, err := configuration.GetConfigFilePath() + if err != nil { + return errors.Wrap(err, "failed to resolve config file path") + } + if err := configuration.SafeWriteSingleConfigKeyString(configPath, params.AstAPIKey, refreshToken); err != nil { + return errors.Wrap(err, "failed to save refresh token to config file") + } + // The config file now holds a long-lived refresh token; restrict it to + // owner read/write only (matching the 0o600 used for the global session + // and active-mode files). On Windows this is a best-effort no-op. + if chErr := os.Chmod(configPath, configFilePerm); chErr != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to restrict config file permissions: %v", chErr)) + } + if err := wrappers.WriteActiveMode(params.SessionYamlValue); err != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to write active-mode file: %v", err)) + } + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Authenticated. Token saved to %s\n", configPath) + return nil +} + +// persistGlobalLogin writes the new refresh token to the dedicated global +// session file and records global as the active mode. 
No env-var emission — +// global mode is a plain command (no Invoke-Expression wrapper). +func persistGlobalLogin(cmd *cobra.Command, refreshToken string) error { + if err := wrappers.WriteSessionGlobal(refreshToken); err != nil { + return errors.Wrap(err, "failed to write global session file") + } + if err := wrappers.WriteActiveMode(params.SessionGlobalValue); err != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to write active-mode file: %v", err)) + } + path, _ := wrappers.SessionGlobalFilePath() + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Authenticated. Token saved to %s (global session — persists across shells until explicit logout).\n", path) + return nil +} + +// persistLocalLogin records local as the active mode and emits a single +// shell-evaluable line to stdout: a defensive reset of CX_APIKEY followed by +// the new refresh-token assignment, separated by `;` so the whole emission +// stays on one line. PowerShell's Invoke-Expression accepts only a single +// string argument, so multi-line stdout would be captured as a string array +// and rejected. Bash's `eval` and fish's `;` statement separator handle the +// same single-line form correctly. Informational text goes to stderr to +// keep stdout strictly evaluable. +func persistLocalLogin(cmd *cobra.Command, refreshToken string) error { + if err := wrappers.WriteActiveMode(params.SessionLocalValue); err != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to write active-mode file: %v", err)) + } + shell := detectShell() + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s; %s\n", + formatEnvAssignment(shell, params.AstAPIKeyEnv, ""), + formatEnvAssignment(shell, params.AstAPIKeyEnv, refreshToken)) + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Authenticated. 
Wrap with Invoke-Expression (PowerShell) or eval (bash) to apply.") + return nil +} diff --git a/internal/commands/auth_login_test.go b/internal/commands/auth_login_test.go new file mode 100644 index 000000000..1c160ef07 --- /dev/null +++ b/internal/commands/auth_login_test.go @@ -0,0 +1,186 @@ +//go:build !integration + +package commands + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/configuration" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// Full runAuthLogin coverage is out of scope for unit tests: it opens a browser +// and runs the PKCE network exchange (wrappers.LoginWithPKCE). These tests cover +// the deterministic, network-free pieces — the persist* writers, clearFileStorages, +// nukeAllStorages' env-safety, and the universal logout — which is also where the +// security fixes (#4 no token to stdout, #5 env token never revoked) live. + +// withTempConfigDir points viper at a temp config file for one test so the auth +// storage helpers operate on a sandbox instead of the real ~/.checkmarx. Also +// clears CX_APIKEY so env never shadows the sandbox. +func withTempConfigDir(t *testing.T) string { + t.Helper() + dir := t.TempDir() + prev := viper.GetString(params.ConfigFilePathKey) + viper.Set(params.ConfigFilePathKey, filepath.Join(dir, "checkmarxcli.yaml")) + t.Setenv(params.AstAPIKeyEnv, "") + t.Cleanup(func() { viper.Set(params.ConfigFilePathKey, prev) }) + return dir +} + +// newBufferedCmd returns a cobra command whose stdout/stderr are captured. 
+func newBufferedCmd() (*cobra.Command, *bytes.Buffer, *bytes.Buffer) { + cmd := &cobra.Command{} + var out, errOut bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&errOut) + return cmd, &out, &errOut +} + +// TestPersistYamlLogin_DoesNotPrintToken locks in fix #4: the refresh token must +// be saved to the yaml file but never echoed to stdout (it would leak into shell +// history / CI logs). +func TestPersistYamlLogin_DoesNotPrintToken(t *testing.T) { + withTempConfigDir(t) + const token = "super-secret-refresh-token" + + cmd, out, _ := newBufferedCmd() + if err := persistYamlLogin(cmd, token); err != nil { + t.Fatalf("persistYamlLogin failed: %v", err) + } + + stdout := out.String() + if strings.Contains(stdout, token) { + t.Errorf("refresh token leaked to stdout: %q", stdout) + } + if !strings.Contains(stdout, "Authenticated. Token saved to") { + t.Errorf("expected confirmation line, got: %q", stdout) + } + // Token must actually be persisted to the yaml file. + if got := wrappers.ReadYamlAPIKey(); got != token { + t.Errorf("expected token persisted to yaml, got %q", got) + } + if mode, _ := wrappers.ReadActiveMode(); mode != params.SessionYamlValue { + t.Errorf("expected active mode %q, got %q", params.SessionYamlValue, mode) + } +} + +// TestPersistGlobalLogin_WritesFileAndNoToken: global mode persists to the global +// session file and prints only the path — never the token. 
+func TestPersistGlobalLogin_WritesFileAndNoToken(t *testing.T) { + withTempConfigDir(t) + const token = "global-refresh-token" + + cmd, out, _ := newBufferedCmd() + if err := persistGlobalLogin(cmd, token); err != nil { + t.Fatalf("persistGlobalLogin failed: %v", err) + } + + if strings.Contains(out.String(), token) { + t.Errorf("refresh token leaked to stdout: %q", out.String()) + } + if got, err := wrappers.ReadSessionGlobal(); err != nil || got != token { + t.Errorf("expected token in global session file, got %q (err=%v)", got, err) + } + if mode, _ := wrappers.ReadActiveMode(); mode != params.SessionGlobalValue { + t.Errorf("expected active mode %q, got %q", params.SessionGlobalValue, mode) + } +} + +// TestPersistLocalLogin_EmitsShellEval: local mode intentionally emits a single +// shell-evaluable line (reset + assignment) to stdout — the token IS present +// there by design (it lives only in the shell env). +func TestPersistLocalLogin_EmitsShellEval(t *testing.T) { + withTempConfigDir(t) + const token = "local-refresh-token" + + cmd, out, errOut := newBufferedCmd() + if err := persistLocalLogin(cmd, token); err != nil { + t.Fatalf("persistLocalLogin failed: %v", err) + } + + stdout := out.String() + if !strings.Contains(stdout, params.AstAPIKeyEnv) { + t.Errorf("expected env-var name in stdout, got: %q", stdout) + } + if !strings.Contains(stdout, token) { + t.Errorf("local mode must emit the token for eval, got: %q", stdout) + } + if !strings.Contains(errOut.String(), "Authenticated") { + t.Errorf("expected info line on stderr, got: %q", errOut.String()) + } + if mode, _ := wrappers.ReadActiveMode(); mode != params.SessionLocalValue { + t.Errorf("expected active mode %q, got %q", params.SessionLocalValue, mode) + } +} + +// TestClearFileStorages_ClearsYamlAndGlobal: clearing empties the yaml cx_apikey +// and removes the global session file. 
+func TestClearFileStorages_ClearsYamlAndGlobal(t *testing.T) { + dir := withTempConfigDir(t) + configPath := filepath.Join(dir, "checkmarxcli.yaml") + if err := configuration.SafeWriteSingleConfigKeyString(configPath, params.AstAPIKey, "yaml-token"); err != nil { + t.Fatalf("setup yaml write failed: %v", err) + } + if err := wrappers.WriteSessionGlobal("global-token"); err != nil { + t.Fatalf("setup global write failed: %v", err) + } + + clearFileStorages() + + if got := wrappers.ReadYamlAPIKey(); got != "" { + t.Errorf("expected yaml cx_apikey cleared, got %q", got) + } + if got, _ := wrappers.ReadSessionGlobal(); got != "" { + t.Errorf("expected global session cleared, got %q", got) + } +} + +// TestNukeAllStorages_DoesNotRevokeOrClearEnv locks in fix #5: an env-var token is +// neither cleared nor touched by the nuke (the CLI can't clear a parent shell's +// env, and it is often a deliberate CI credential). With empty file storages this +// makes no network call. +func TestNukeAllStorages_DoesNotRevokeOrClearEnv(t *testing.T) { + withTempConfigDir(t) + const envToken = "ci-env-refresh-token" + t.Setenv(params.AstAPIKeyEnv, envToken) + + // No yaml or global token present → no revocation network call is attempted. + nukeAllStorages(defaultLoginClientID) + + // The env var must remain exactly as the caller set it. + if got := os.Getenv(params.AstAPIKeyEnv); got != envToken { + t.Errorf("nukeAllStorages must not alter the env token: got %q, want %q", got, envToken) + } +} + +// TestRunAuthLogout_EmptyStorage emits a shell-clear of CX_APIKEY, clears the +// active mode, and makes no network call when no token is stored. 
+func TestRunAuthLogout_EmptyStorage(t *testing.T) { + withTempConfigDir(t) + if err := wrappers.WriteActiveMode(params.SessionYamlValue); err != nil { + t.Fatalf("setup active mode failed: %v", err) + } + + cmd, out, errOut := newBufferedCmd() + if err := runAuthLogout(cmd, nil); err != nil { + t.Fatalf("runAuthLogout failed: %v", err) + } + + if !strings.Contains(out.String(), params.AstAPIKeyEnv) { + t.Errorf("expected shell-clear of %s on stdout, got: %q", params.AstAPIKeyEnv, out.String()) + } + if !strings.Contains(errOut.String(), "Logged out") { + t.Errorf("expected logout info on stderr, got: %q", errOut.String()) + } + if mode, _ := wrappers.ReadActiveMode(); mode != "" { + t.Errorf("expected active mode cleared after logout, got %q", mode) + } +} diff --git a/internal/commands/auth_logout.go b/internal/commands/auth_logout.go new file mode 100644 index 000000000..441decda8 --- /dev/null +++ b/internal/commands/auth_logout.go @@ -0,0 +1,65 @@ +package commands + +import ( + "fmt" + + "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// audClaim is the OIDC "audience" JWT claim. For Keycloak refresh tokens it +// holds the realm URL — exactly the URL we POST to for revocation. +const audClaim = "aud" + +func newAuthLogoutCommand() *cobra.Command { + return &cobra.Command{ + Use: "logout", + Short: "Revoke the current refresh token and clear stored credentials", + Long: "Revokes the current refresh token at Checkmarx One IAM and clears every storage " + + "location: yaml cx_apikey, the global session file, and emits a shell-evaluable " + + "clear of CX_APIKEY for users who logged in via --session local. 
One universal " + + "logout — no --session flag needed; the active mode tells the CLI what to clean up.", + Example: heredoc.Doc(` + # Default usage (clears yaml and the global file, revokes server-side) + $ cx auth logout + + # If the current shell was logged in via --session local, also wrap the + # logout with Invoke-Expression so $env:CX_APIKEY gets cleared too + # PowerShell: + $ Invoke-Expression (cx auth logout) + # bash / zsh: + $ eval "$(cx auth logout)" + `), + RunE: runAuthLogout, + } +} + +// runAuthLogout is the universal logout: it nukes every storage location's +// credential (server-side revoke + local clear), deletes the active-mode +// metadata file, and emits a shell-clear line so users who wrap the call +// with Invoke-Expression / eval also have CX_APIKEY cleared in their shell. +func runAuthLogout(cmd *cobra.Command, _ []string) error { + clientID := viper.GetString(params.AccessKeyIDConfigKey) + if clientID == "" { + clientID = defaultLoginClientID + } + + nukeAllStorages(clientID) + + if err := wrappers.ClearActiveMode(); err != nil { + logger.PrintIfVerbose(fmt.Sprintf("failed to remove active-mode file: %v", err)) + } + + // Always emit a shell-clear of CX_APIKEY to stdout. Wrapping the logout + // with Invoke-Expression (PowerShell) or eval (bash) clears the env var + // in the current shell. Without the wrapper the line just prints — no + // harm done for users who didn't use --session local. + shell := detectShell() + _, _ = fmt.Fprintln(cmd.OutOrStdout(), formatEnvAssignment(shell, params.AstAPIKeyEnv, "")) + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Logged out. 
If you used --session local in this shell, wrap with Invoke-Expression (PowerShell) or eval (bash) to clear CX_APIKEY.") + return nil +} diff --git a/internal/commands/auth_session_test.go b/internal/commands/auth_session_test.go new file mode 100644 index 000000000..82af65ee9 --- /dev/null +++ b/internal/commands/auth_session_test.go @@ -0,0 +1,45 @@ +//go:build !integration + +package commands + +import ( + "strings" + "testing" + + "github.com/checkmarx/ast-cli/internal/params" +) + +func TestValidateSessionFlag(t *testing.T) { + cases := []struct { + name string + value string + wantErr bool + errMatch string // substring expected in error message + }{ + {name: "empty is valid (default yaml mode)", value: "", wantErr: false}, + {name: "local is valid", value: params.SessionLocalValue, wantErr: false}, + {name: "global is valid", value: params.SessionGlobalValue, wantErr: false}, + {name: "rejects unknown value", value: "yolo", wantErr: true, errMatch: "invalid --session value"}, + {name: "rejects empty-looking but not equal", value: " ", wantErr: true, errMatch: "invalid --session value"}, + {name: "rejects case mismatch", value: "Local", wantErr: true, errMatch: "invalid --session value"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateSessionFlag(tc.value) + if tc.wantErr { + if err == nil { + t.Errorf("expected error for value %q, got nil", tc.value) + return + } + if tc.errMatch != "" && !strings.Contains(err.Error(), tc.errMatch) { + t.Errorf("expected error containing %q, got %q", tc.errMatch, err.Error()) + } + return + } + if err != nil { + t.Errorf("expected no error for value %q, got: %v", tc.value, err) + } + }) + } +} diff --git a/internal/commands/hooks.go b/internal/commands/hooks.go index cfb304b19..567dfd3d1 100644 --- a/internal/commands/hooks.go +++ b/internal/commands/hooks.go @@ -2,6 +2,7 @@ package commands import ( "github.com/MakeNowJust/heredoc" + "github.com/checkmarx/ast-cli/internal/logger" 
"github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/pkg/errors" @@ -9,16 +10,17 @@ import ( ) // NewHooksCommand creates the hooks command with pre-commit subcommand -func NewHooksCommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrappers.FeatureFlagsWrapper) *cobra.Command { +func NewHooksCommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrappers.FeatureFlagsWrapper, realtimeScannerWrapper wrappers.RealtimeScannerWrapper, telemetryWrapper wrappers.TelemetryWrapper) *cobra.Command { hooksCmd := &cobra.Command{ Use: "hooks", - Short: "Manage Git hooks", - Long: "The hooks command enables the ability to manage Git hooks for Checkmarx One.", + Short: "Manage Git hooks and AI coding agent hooks", + Long: "The hooks command manages Git hooks for secret detection and AI coding agent hooks for Claude, Cursor, Windsurf, Factory Droid, Gemini, and GitHub Copilot CLI.", Example: heredoc.Doc( ` $ cx hooks pre-commit secrets-install-git-hook $ cx hooks pre-commit secrets-scan $ cx hooks pre-receive secrets-scan + $ cx hooks agenthooks install `, ), Annotations: map[string]string{ @@ -30,9 +32,17 @@ func NewHooksCommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrapper }, } - // Add pre-commit and pre-receive subcommand + // Add pre-commit, pre-receive, and agenthooks management subcommands hooksCmd.AddCommand(PreCommitCommand(jwtWrapper, featureFlagsWrapper)) hooksCmd.AddCommand(PreReceiveCommand(jwtWrapper, featureFlagsWrapper)) + hooksCmd.AddCommand(NewAgentHooksCommand()) + + // Register all hidden hook dispatch subcommands so that cx itself acts as + // the hook binary. Agents invoke: cx hooks + // e.g. 
cx hooks claude-pre-tool-use + for _, dispatchCmd := range HookDispatchCommands(jwtWrapper, featureFlagsWrapper, realtimeScannerWrapper, telemetryWrapper) { + hooksCmd.AddCommand(dispatchCmd) + } return hooksCmd } @@ -46,12 +56,17 @@ func validateLicense(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrapper licenseName = params.EnterpriseSecretsLabel } + logger.PrintIfVerbose("hooks: checking license for " + licenseName) + allowed, err := jwtWrapper.IsAllowedEngine(licenseName) if err != nil { + logger.PrintIfVerbose("hooks: authentication failed during license check - " + err.Error()) return errors.Wrapf(err, "Failed checking license") } if !allowed { + logger.PrintIfVerbose("hooks: license validation failed - " + licenseName + " not found in allowed engines") return errors.Errorf("Error: License validation failed. Please verify your CxOne license includes %s.", licenseName) } + logger.PrintIfVerbose("hooks: license validated successfully for " + licenseName) return nil } diff --git a/internal/commands/ignore_vulnerability.go b/internal/commands/ignore_vulnerability.go new file mode 100644 index 000000000..6e2d9f35c --- /dev/null +++ b/internal/commands/ignore_vulnerability.go @@ -0,0 +1,185 @@ +package commands + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/MakeNowJust/heredoc" + commonParams "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/utils" + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +// NewIgnoreVulnerabilityCommand creates the `cx ignore-vulnerability` command, which creates or +// updates the realtime ignore file from a scan finding (ignore), or removes a matching entry +// (--remove, i.e. revive/review). The file it writes is consumed by the realtime scans via +// --ignored-file-path. 
This is the local realtime ignore — distinct from platform `cx triage`. +func NewIgnoreVulnerabilityCommand(telemetryWrapper wrappers.TelemetryWrapper) *cobra.Command { + cmd := &cobra.Command{ + Use: "ignore-vulnerability", + Hidden: true, + Short: "Ignore (or revive) a realtime-scan vulnerability via the ignore file", + Long: heredoc.Doc(` + Create or update the realtime ignore file from a scan finding so the realtime scans + (oss-realtime, secrets-realtime, containers-realtime, iac-realtime, asca) suppress it. + + Pass the finding JSON emitted by the scan via --data (inline, @file, or - for stdin). + Use --remove to revive (un-ignore) a previously ignored finding. + `), + Example: heredoc.Doc(` + $ cx ignore-vulnerability --scan-type oss --data '{"PackageManager":"npm","PackageName":"lodash","PackageVersion":"4.17.20"}' + $ cx ignore-vulnerability --scan-type asca --data @finding.json + $ cx ignore-vulnerability --remove --scan-type oss --data '{"PackageManager":"npm","PackageName":"lodash","PackageVersion":"4.17.20"}' + `), + RunE: runIgnoreVulnerability(telemetryWrapper), + } + + cmd.Flags().String(commonParams.ScanTypeFlag, "", "Scan type of the finding: oss (alias sca), secrets, containers, iac, asca") + cmd.Flags().String(commonParams.IgnoreDataFlag, "", "Finding JSON from the realtime scan output. 
Inline JSON, @ to read a file, or - for stdin") + cmd.Flags().Bool(commonParams.IgnoreRemoveFlag, false, "Revive (un-ignore): remove the matching entry from the ignore file") + cmd.Flags().String(commonParams.IgnoredFilePathFlag, ignore.DefaultPath(), "Path to the ignore file to create/update") + + _ = cmd.MarkFlagRequired(commonParams.ScanTypeFlag) + _ = cmd.MarkFlagRequired(commonParams.IgnoreDataFlag) + + return cmd +} + +func runIgnoreVulnerability(telemetryWrapper wrappers.TelemetryWrapper) func(cmd *cobra.Command, args []string) error { + return func(cmd *cobra.Command, _ []string) error { + scanType, _ := cmd.Flags().GetString(commonParams.ScanTypeFlag) + dataArg, _ := cmd.Flags().GetString(commonParams.IgnoreDataFlag) + remove, _ := cmd.Flags().GetBool(commonParams.IgnoreRemoveFlag) + ignoredFilePath, _ := cmd.Flags().GetString(commonParams.IgnoredFilePathFlag) + + if !ignore.IsValidScanType(scanType) { + return errors.Errorf("invalid --scan-type %q; expected one of: oss, sca, secrets, containers, iac, asca", scanType) + } + + data, err := readIgnoreData(cmd, dataArg) + if err != nil { + return err + } + + entries, err := ignore.BuildEntries(scanType, data) + if err != nil { + return err + } + + list, err := ignore.Load(ignoredFilePath) + if err != nil { + return errors.Wrapf(err, "failed to read ignore file %s", ignoredFilePath) + } + + changed := 0 + for _, entry := range entries { + var ok bool + if remove { + list, ok, err = ignore.Remove(list, entry) + } else { + list, ok, err = ignore.Append(list, entry) + } + if err != nil { + return err + } + if ok { + changed++ + } + } + + if err = ignore.Save(ignoredFilePath, list); err != nil { + return errors.Wrapf(err, "failed to write ignore file %s", ignoredFilePath) + } + + action := "ignored" + if remove { + action = "revived" + } + if !remove && changed > 0 { + logIgnoreTelemetry(telemetryWrapper, scanType) + } + _, _ = fmt.Fprintf(cmd.OutOrStdout(), + "%s %d vulnerability(ies); %d entr%s now in %s\n", + 
action, changed, len(list), plural(len(list)), ignoredFilePath) + return nil + } +} + +// readIgnoreData resolves --data from inline JSON, @, or - (stdin). +func readIgnoreData(cmd *cobra.Command, dataArg string) ([]byte, error) { + switch { + case dataArg == "": + return nil, errors.New("--data is required") + case dataArg == "-": + data, err := io.ReadAll(cmd.InOrStdin()) + if err != nil { + return nil, errors.Wrap(err, "failed to read --data from stdin") + } + return data, nil + case strings.HasPrefix(dataArg, "@"): + path := strings.TrimPrefix(dataArg, "@") + data, err := os.ReadFile(path) + if err != nil { + return nil, errors.Wrapf(err, "failed to read --data file %s", path) + } + return data, nil + default: + return []byte(dataArg), nil + } +} + +func plural(n int) string { + if n == 1 { + return "y" + } + return "ies" +} + +// logIgnoreTelemetry sends telemetry when an agent ignores a finding via the realtime ignore file. +func logIgnoreTelemetry(telemetryWrapper wrappers.TelemetryWrapper, scanType string) { + if telemetryWrapper == nil { + return + } + + aiProvider := utils.GetOptionalParam("aiProvider") + engine := engineName(scanType) + telemetryData := &wrappers.DataForAITelemetry{ + AIProvider: aiProvider, + Agent: aiProvider + "-cli", + Engine: engine, + ScanType: strings.ToLower(engine), + UniqueID: wrappers.GetUniqueID(), + Type: "hooks-ignore", + SubType: "ignorePackage", + } + + if err := telemetryWrapper.SendAIDataToLog(telemetryData); err != nil { + // fail-open + } +} + +// engineName maps a --scan-type value to the capitalized engine label used in telemetry +// (matching the convention in agenthooks/cx/hooks.go, e.g. "Oss", "Asca", "Sca"). 
+func engineName(scanType string) string { + switch strings.ToLower(strings.TrimSpace(scanType)) { + case ignore.ScanTypeOSS: + return "Oss" + case ignore.ScanTypeSCA: + return "Sca" + case ignore.ScanTypeSecrets: + return "Secrets" + case ignore.ScanTypeContainers: + return "Containers" + case ignore.ScanTypeIaC: + return "Iac" + case ignore.ScanTypeASCA: + return "Asca" + default: + return scanType + } +} diff --git a/internal/commands/ignore_vulnerability_test.go b/internal/commands/ignore_vulnerability_test.go new file mode 100644 index 000000000..fb03eee2c --- /dev/null +++ b/internal/commands/ignore_vulnerability_test.go @@ -0,0 +1,112 @@ +//go:build !integration + +package commands + +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ignore" + "github.com/checkmarx/ast-cli/internal/wrappers/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func runIgnoreVulnCmd(stdin string, args ...string) (string, error) { + cmd := NewIgnoreVulnerabilityCommand(mock.TelemetryMockWrapper{}) + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + if stdin != "" { + cmd.SetIn(bytes.NewBufferString(stdin)) + } + cmd.SetArgs(args) + err := cmd.Execute() + return out.String(), err +} + +func TestIgnoreVulnerability_AddIsIdempotentThenRemove(t *testing.T) { + file := filepath.Join(t.TempDir(), "ignore.json") + finding := `{"PackageManager":"npm","PackageName":"lodash","PackageVersion":"4.17.20"}` + + _, err := runIgnoreVulnCmd("", "--scan-type", "oss", "--data", finding, "--ignored-file-path", file) + require.NoError(t, err) + list, err := ignore.Load(file) + require.NoError(t, err) + assert.Len(t, list, 1) + + // adding the same finding again does not duplicate + _, err = runIgnoreVulnCmd("", "--scan-type", "oss", "--data", finding, "--ignored-file-path", file) + require.NoError(t, err) + list, _ = ignore.Load(file) + assert.Len(t, list, 1) + 
+ // remove (revive) clears it + _, err = runIgnoreVulnCmd("", "--scan-type", "oss", "--data", finding, "--ignored-file-path", file, "--remove") + require.NoError(t, err) + list, _ = ignore.Load(file) + assert.Empty(t, list) +} + +func TestIgnoreVulnerability_DataFromFile(t *testing.T) { + dir := t.TempDir() + findingFile := filepath.Join(dir, "finding.json") + require.NoError(t, os.WriteFile(findingFile, []byte(`{"Title":"github-pat","SecretValue":"ghp_x"}`), 0o600)) + file := filepath.Join(dir, "ignore.json") + + _, err := runIgnoreVulnCmd("", "--scan-type", "secrets", "--data", "@"+findingFile, "--ignored-file-path", file) + require.NoError(t, err) + list, _ := ignore.Load(file) + assert.Len(t, list, 1) +} + +func TestIgnoreVulnerability_DataFromStdin(t *testing.T) { + file := filepath.Join(t.TempDir(), "ignore.json") + _, err := runIgnoreVulnCmd(`{"ImageName":"ubuntu","ImageTag":"14.04"}`, + "--scan-type", "containers", "--data", "-", "--ignored-file-path", file) + require.NoError(t, err) + list, _ := ignore.Load(file) + assert.Len(t, list, 1) +} + +func TestIgnoreVulnerability_InvalidScanType(t *testing.T) { + file := filepath.Join(t.TempDir(), "ignore.json") + _, err := runIgnoreVulnCmd("", "--scan-type", "sast", "--data", `{}`, "--ignored-file-path", file) + require.Error(t, err) +} + +// Mixed-file regression: all 5 engine shapes coexist in one file (the contract the IDE plugins +// and the realtime engines rely on), and the ASCA entry is PascalCase with numeric fields. 
+func TestIgnoreVulnerability_MixedFile_AllEngines(t *testing.T) { + file := filepath.Join(t.TempDir(), "ignore.json") + cases := []struct{ scanType, data string }{ + {"oss", `{"PackageManager":"npm","PackageName":"lodash","PackageVersion":"4.17.20"}`}, + {"secrets", `{"Title":"github-pat","SecretValue":"ghp_x"}`}, + {"containers", `{"ImageName":"ubuntu","ImageTag":"14.04"}`}, + {"iac", `{"Title":"Missing User Instruction","SimilarityID":"abc123"}`}, + {"asca", `{"file_name":"server.py","line":77,"rule_id":5004}`}, + } + for _, c := range cases { + _, err := runIgnoreVulnCmd("", "--scan-type", c.scanType, "--data", c.data, "--ignored-file-path", file) + require.NoErrorf(t, err, "scan-type %s", c.scanType) + } + + list, err := ignore.Load(file) + require.NoError(t, err) + require.Len(t, list, 5) + + ascaFound := false + for _, e := range list { + var m map[string]any + require.NoError(t, json.Unmarshal(e, &m)) + if fn, ok := m["FileName"]; ok { + assert.Equal(t, "server.py", fn) + assert.EqualValues(t, 5004, m["RuleID"]) + ascaFound = true + } + } + assert.True(t, ascaFound, "ASCA entry must be present as PascalCase FileName/RuleID") +} diff --git a/internal/commands/pre_commit_test.go b/internal/commands/pre_commit_test.go index c77c54c50..00001f798 100644 --- a/internal/commands/pre_commit_test.go +++ b/internal/commands/pre_commit_test.go @@ -1,3 +1,5 @@ +//go:build !integration + package commands import ( @@ -11,7 +13,9 @@ import ( func TestNewHooksCommand(t *testing.T) { mockJWT := &mock.JWTMockWrapper{} mockFF := &mock.FeatureFlagsMockWrapper{} - cmd := NewHooksCommand(mockJWT, mockFF) + mockRealtime := &mock.RealtimeScannerMockWrapper{} + mockTelemetry := &mock.TelemetryMockWrapper{} + cmd := NewHooksCommand(mockJWT, mockFF, mockRealtime, mockTelemetry) assert.NotNil(t, cmd) assert.Equal(t, "hooks", cmd.Use) diff --git a/internal/commands/root.go b/internal/commands/root.go index 6f3503036..bad2e9ed3 100644 --- a/internal/commands/root.go +++ 
b/internal/commands/root.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/MakeNowJust/heredoc" + cxmcp "github.com/checkmarx/ast-cli/internal/commands/agenthooks/mcp" "github.com/checkmarx/ast-cli/internal/commands/dast" "github.com/checkmarx/ast-cli/internal/commands/util" "github.com/checkmarx/ast-cli/internal/commands/util/printer" @@ -246,8 +247,14 @@ func NewAstCLI( triageCmd := NewResultsPredicatesCommand(resultsPredicatesWrapper, featureFlagsWrapper, customStatesWrapper) chatCmd := NewChatCommand(chatWrapper, tenantWrapper) - hooksCmd := NewHooksCommand(jwtWrapper, featureFlagsWrapper) + hooksCmd := NewHooksCommand(jwtWrapper, featureFlagsWrapper, realTimeWrapper, telemetryWrapper) telemetryCmd := NewTelemetryCommand(telemetryWrapper) + ignoreVulnerabilityCmd := NewIgnoreVulnerabilityCommand() + + // MCP server — directly uses the exported guardrail functions from agenthooks.go. + mcpServerCmd := cxmcp.NewMCPCommand(params.Version, func() bool { return isLicensed(jwtWrapper) }) + + ignoreVulnerabilityCmd := NewIgnoreVulnerabilityCommand(telemetryWrapper) rootCmd.AddCommand( scanCmd, projectCmd, @@ -261,6 +268,8 @@ func NewAstCLI( chatCmd, hooksCmd, telemetryCmd, + ignoreVulnerabilityCmd, + mcpServerCmd, ) rootCmd.SilenceUsage = true @@ -429,7 +438,7 @@ func setLogOutputFromFlag(flag, dirPath string) error { } else { multiWriter = io.MultiWriter(file) } - log.SetOutput(multiWriter) + logger.SetOutput(multiWriter) return nil } func CheckPreferredCredentials(cmd *cobra.Command) { diff --git a/internal/commands/scan.go b/internal/commands/scan.go index 1ea696e8f..74a6c5afe 100644 --- a/internal/commands/scan.go +++ b/internal/commands/scan.go @@ -486,7 +486,7 @@ func scanASCASubCommand(jwtWrapper wrappers.JWTWrapper, featureFlagsWrapper wrap "The file source should be the path to a single file", ) - scanASCACmd.PersistentFlags().String(commonParams.IgnoredFilePathFlag, "", "Path to ignored secrets file") + 
scanASCACmd.PersistentFlags().String(commonParams.IgnoredFilePathFlag, "", "Path to a JSON file listing ignored ASCA findings") scanASCACmd.PersistentFlags().String(commonParams.ASCALocationFlag, "", "Path to custom location where ASCA engine is installed") _ = viper.BindPFlag(commonParams.ASCALocationKey, scanASCACmd.PersistentFlags().Lookup(commonParams.ASCALocationFlag)) diff --git a/internal/commands/scan_test.go b/internal/commands/scan_test.go index 6d029bb34..332958806 100644 --- a/internal/commands/scan_test.go +++ b/internal/commands/scan_test.go @@ -13,7 +13,6 @@ import ( "reflect" "strings" "testing" - "time" "github.com/checkmarx/ast-cli/internal/commands/util" errorConstants "github.com/checkmarx/ast-cli/internal/constants/errors" @@ -5123,14 +5122,12 @@ func TestGetGitCommitHistoryValue_WithWarnings(t *testing.T) { } func setupMockAccessToken() { - wrappers.CachedAccessToken = "mock-token-for-testing" - wrappers.CachedAccessTime = time.Now() + wrappers.SetCachedAccessTokenForTest("mock-token-for-testing") viper.Set(commonParams.TokenExpirySecondsKey, 300) } func cleanupMockAccessToken() { - wrappers.CachedAccessToken = "" - wrappers.CachedAccessTime = time.Time{} + wrappers.SetCachedAccessTokenForTest("") wrappers.ClearCache() // Reset to default value (300 seconds as per params/binds.go) diff --git a/internal/commands/shell_output.go b/internal/commands/shell_output.go new file mode 100644 index 000000000..5091afd73 --- /dev/null +++ b/internal/commands/shell_output.go @@ -0,0 +1,47 @@ +package commands + +import ( + "fmt" + "os" + "runtime" + "strings" +) + +// detectShell returns the user's likely shell so session-mode login/logout +// can emit env-var assignment lines in the right syntax. PowerShell is +// detected via PSModulePath (present in PowerShell sessions, absent in +// cmd.exe and *nix shells). Bash/zsh/fish are detected via SHELL. +// Defaults: PowerShell on Windows, bash elsewhere. 
+func detectShell() string { + if os.Getenv("PSModulePath") != "" { + return "powershell" + } + shell := strings.ToLower(os.Getenv("SHELL")) + switch { + case strings.Contains(shell, "fish"): + return "fish" + case strings.Contains(shell, "bash"), strings.Contains(shell, "zsh"): + return "bash" + } + if runtime.GOOS == "windows" { + return "powershell" + } + return "bash" +} + +// formatEnvAssignment returns a shell-evaluable env var assignment line. +// Examples: +// +// powershell → $env:CX_APIKEY = "value" +// bash/zsh → export CX_APIKEY="value" +// fish → set -gx CX_APIKEY "value" +func formatEnvAssignment(shell, name, value string) string { + switch shell { + case "powershell": + return fmt.Sprintf(`$env:%s = "%s"`, name, value) + case "fish": + return fmt.Sprintf(`set -gx %s "%s"`, name, value) + default: + return fmt.Sprintf(`export %s="%s"`, name, value) + } +} diff --git a/internal/commands/shell_output_test.go b/internal/commands/shell_output_test.go new file mode 100644 index 000000000..1fc24d93d --- /dev/null +++ b/internal/commands/shell_output_test.go @@ -0,0 +1,60 @@ +//go:build !integration + +package commands + +import ( + "testing" +) + +func TestDetectShell(t *testing.T) { + cases := []struct { + name string + psModule string // value for PSModulePath env + shell string // value for SHELL env + want string + }{ + {name: "PSModulePath set → powershell", psModule: "C:\\Program Files\\PowerShell\\Modules", shell: "", want: "powershell"}, + {name: "PSModulePath wins over SHELL", psModule: "C:\\Program Files\\PowerShell\\Modules", shell: "/usr/bin/bash", want: "powershell"}, + {name: "bash via SHELL", psModule: "", shell: "/usr/bin/bash", want: "bash"}, + {name: "zsh via SHELL", psModule: "", shell: "/usr/bin/zsh", want: "bash"}, + {name: "fish via SHELL", psModule: "", shell: "/usr/local/bin/fish", want: "fish"}, + {name: "fish wins over bash substring matching", psModule: "", shell: "/usr/local/bin/fishtank", want: "fish"}, + } + + for _, tc := 
range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("PSModulePath", tc.psModule) + t.Setenv("SHELL", tc.shell) + got := detectShell() + if got != tc.want { + t.Errorf("detectShell() = %q, want %q", got, tc.want) + } + }) + } +} + +func TestFormatEnvAssignment(t *testing.T) { + cases := []struct { + name string + shell string + key string + value string + want string + }{ + {name: "powershell with token", shell: "powershell", key: "CX_APIKEY", value: "abc.def", want: `$env:CX_APIKEY = "abc.def"`}, + {name: "powershell with empty value clears", shell: "powershell", key: "CX_APIKEY", value: "", want: `$env:CX_APIKEY = ""`}, + {name: "bash with token", shell: "bash", key: "CX_APIKEY", value: "abc.def", want: `export CX_APIKEY="abc.def"`}, + {name: "bash with empty value clears", shell: "bash", key: "CX_APIKEY", value: "", want: `export CX_APIKEY=""`}, + {name: "fish with token", shell: "fish", key: "CX_APIKEY", value: "abc.def", want: `set -gx CX_APIKEY "abc.def"`}, + {name: "unknown shell falls back to bash syntax", shell: "made-up-shell", key: "CX_APIKEY", value: "abc.def", want: `export CX_APIKEY="abc.def"`}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := formatEnvAssignment(tc.shell, tc.key, tc.value) + if got != tc.want { + t.Errorf("formatEnvAssignment(%q, %q, %q) = %q, want %q", tc.shell, tc.key, tc.value, got, tc.want) + } + }) + } +} diff --git a/internal/commands/util/configuration_test.go b/internal/commands/util/configuration_test.go index a465ccd68..560cd84f6 100644 --- a/internal/commands/util/configuration_test.go +++ b/internal/commands/util/configuration_test.go @@ -139,6 +139,73 @@ func TestWriteSingleConfigKeyStringNonExistingFile_CreatingTheFileAndWritesTheKe asserts.NotNil(t, file) } +func TestLoadConfigEmptyFile_ZeroByteConfig_ReturnsEmptyMapWithoutError(t *testing.T) { + configFilePath := "empty-config-file.yaml" + + // A zero-byte config file is the state that previously caused + // "error decoding 
YAML: EOF" when cx auth login persisted a fresh token. + file, err := os.Create(configFilePath) + assert.NilError(t, err) + _ = file.Close() + defer func() { + _ = os.Remove(configFilePath) + _ = os.Remove(configFilePath + ".lock") + }() + + loaded, err := configuration.LoadConfig(configFilePath) + assert.NilError(t, err) + asserts.Equal(t, 0, len(loaded), "an empty file must load as an empty config, not an error") +} + +func TestWriteSingleConfigKeyStringEmptyFile_ZeroByteConfig_WritesKeyWithoutEOFError(t *testing.T) { + configFilePath := "empty-config-write.yaml" + + file, err := os.Create(configFilePath) + assert.NilError(t, err) + _ = file.Close() + defer func() { + _ = os.Remove(configFilePath) + _ = os.Remove(configFilePath + ".lock") + }() + + // Previously failed with "error loading config: error decoding YAML: EOF". + err = configuration.SafeWriteSingleConfigKeyString(configFilePath, cxScsScanOverviewPath, defaultScsScanOverviewPath) + assert.NilError(t, err) + + config, err := configuration.LoadConfig(configFilePath) + assert.NilError(t, err) + asserts.Equal(t, defaultScsScanOverviewPath, config[cxScsScanOverviewPath]) +} + +func TestLoadConfigMalformedYaml_CorruptConfig_StillReturnsError(t *testing.T) { + configFilePath := "malformed-config.yaml" + // Only the empty-file (io.EOF) case is special-cased. A genuinely corrupt + // config must still error — this pins that narrow behavior so a future + // change can't broaden it into silently swallowing (and then truncating) + // a real config. 
+ err := os.WriteFile(configFilePath, []byte("foo: [bar"), 0600) + assert.NilError(t, err) + defer func() { _ = os.Remove(configFilePath) }() + + cfg, err := configuration.LoadConfig(configFilePath) + asserts.NotNil(t, err) + asserts.Nil(t, cfg) + asserts.True(t, strings.Contains(err.Error(), "error decoding YAML"), "corrupt YAML must still report a decode error") +} + +func TestLoadConfigCommentOnlyFile_EffectivelyEmpty_ReturnsEmptyMap(t *testing.T) { + configFilePath := "comment-only-config.yaml" + // A comment-only / whitespace-only file decodes to io.EOF in yaml.v3, the + // same as a zero-byte file — it must load as an empty config, not an error. + err := os.WriteFile(configFilePath, []byte("# only a comment\n"), 0600) + assert.NilError(t, err) + defer func() { _ = os.Remove(configFilePath) }() + + cfg, err := configuration.LoadConfig(configFilePath) + assert.NilError(t, err) + asserts.Equal(t, 0, len(cfg)) +} + func TestChangedOnlyScsScanOverviewPathInConfigFile_ConfigFileExistsWithDefaultValues_OnlyScsScanOverviewPathChangedSuccess(t *testing.T) { err := configuration.LoadConfiguration() assert.NilError(t, err) diff --git a/internal/params/flags.go b/internal/params/flags.go index 622c8a010..2b362bdc9 100644 --- a/internal/params/flags.go +++ b/internal/params/flags.go @@ -2,6 +2,19 @@ package params // Flags const ( + // OAuth browser login (cx auth login) + session storage modes + LoginPortFlag = "port" + LoginPortFlagUsage = "Local port for the OAuth callback listener (0 = pick a free port)" + LoginNoBrowserFlag = "no-browser" + LoginNoBrowserFlagUsage = "Print the authorization URL instead of opening a browser" + SessionFlag = "session" + SessionLocalValue = "local" + SessionGlobalValue = "global" + SessionYamlValue = "yaml" + SessionGlobalFileName = "session_global" + ActiveModeFileName = "active_mode" + SessionLoginFlagUsage = "Session mode: 'local' keeps the refresh token only in the current shell's environment (requires Invoke-Expression / eval 
wrapper); 'global' persists it to a dedicated file readable by every shell on the machine until explicit logout." + AllStatesFlag = "all" AgentFlag = "agent" AiProviderFlag = "ai-provider" @@ -111,6 +124,8 @@ const ( ProjectName = "project-name" ScanTypes = "scan-types" ScanTypeFlag = "scan-type" + IgnoreDataFlag = "data" + IgnoreRemoveFlag = "remove" ScanResubmit = "resubmit" KicsRealtimeFile = "file" KicsRealtimeEngine = "engine" diff --git a/internal/services/asca.go b/internal/services/asca.go index 9f051acb0..fb55c664e 100644 --- a/internal/services/asca.go +++ b/internal/services/asca.go @@ -149,7 +149,9 @@ func executeScan(ascaWrapper grpcs.AscaWrapper, filePath, ignoredFilePath string if ignoredFilePath != "" { ignoredFindings, err := loadIgnoredAscaFindings(ignoredFilePath) - if err == nil { + if err != nil { + logger.PrintfIfVerbose("asca: failed to load ignore file %s: %v; continuing without ignore filtering", ignoredFilePath, err) + } else { ignoreMap := buildAscaIgnoreMap(ignoredFindings) scanResult.ScanDetails = filterIgnoredAscaFindings(scanResult.ScanDetails, ignoreMap) } diff --git a/internal/services/realtimeengine/containersrealtime/containers-realtime.go b/internal/services/realtimeengine/containersrealtime/containers-realtime.go index a035e80a1..afdc026ed 100644 --- a/internal/services/realtimeengine/containersrealtime/containers-realtime.go +++ b/internal/services/realtimeengine/containersrealtime/containers-realtime.go @@ -106,10 +106,11 @@ func (c *ContainersRealtimeService) RunContainersRealtimeScan(filePath, ignoredF if ignoredFilePath != "" { ignored, err := loadIgnoredContainerFindings(ignoredFilePath) if err != nil { - return nil, errorconstants.NewRealtimeEngineError("failed to load ignored containers").Error() + logger.PrintfIfVerbose("containers-realtime: failed to load ignore file %s: %v; continuing without ignore filtering", ignoredFilePath, err) + } else { + ignoreMap := buildContainerIgnoreMap(ignored) + results.Images = 
filterIgnoredContainers(results.Images, ignoreMap) } - ignoreMap := buildContainerIgnoreMap(ignored) - results.Images = filterIgnoredContainers(results.Images, ignoreMap) } return results, nil diff --git a/internal/services/realtimeengine/iacrealtime/iac-realtime.go b/internal/services/realtimeengine/iacrealtime/iac-realtime.go index 078276033..87c63677a 100644 --- a/internal/services/realtimeengine/iacrealtime/iac-realtime.go +++ b/internal/services/realtimeengine/iacrealtime/iac-realtime.go @@ -10,6 +10,7 @@ import ( "runtime" "time" + "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/services/realtimeengine" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/pkg/errors" @@ -102,10 +103,11 @@ func (svc *IacRealtimeService) RunIacRealtimeScan(filePath, engine, ignoredFileP if ignoredFilePath != "" { ignored, err := loadIgnoredIacFindings(ignoredFilePath) if err != nil { - return nil, errorconstants.NewRealtimeEngineError("failed to load ignored IaC findings").Error() + logger.PrintfIfVerbose("iac-realtime: failed to load ignore file %s: %v; continuing without ignore filtering", ignoredFilePath, err) + } else { + ignoreMap := buildIgnoreMap(ignored) + results = filterIgnoredFindings(results, ignoreMap) } - ignoreMap := buildIgnoreMap(ignored) - results = filterIgnoredFindings(results, ignoreMap) } return results, nil diff --git a/internal/services/realtimeengine/ignore/builder.go b/internal/services/realtimeengine/ignore/builder.go new file mode 100644 index 000000000..09eb914a1 --- /dev/null +++ b/internal/services/realtimeengine/ignore/builder.go @@ -0,0 +1,169 @@ +package ignore + +import ( + "bytes" + "encoding/json" + "strings" + + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/containersrealtime" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/iacrealtime" + "github.com/checkmarx/ast-cli/internal/services/realtimeengine/ossrealtime" + 
"github.com/checkmarx/ast-cli/internal/services/realtimeengine/secretsrealtime" + "github.com/checkmarx/ast-cli/internal/wrappers/grpcs" + "github.com/pkg/errors" +) + +// Scan-type identifiers accepted by --scan-type. "sca" is an alias for "oss". +const ( + ScanTypeOSS = "oss" + ScanTypeSCA = "sca" + ScanTypeSecrets = "secrets" + ScanTypeContainers = "containers" + ScanTypeIaC = "iac" + ScanTypeASCA = "asca" +) + +// wrapperKeys are the array fields under which the realtime scans nest their findings. A finding +// payload may be the full scan output, a bare array, or a single finding object. +var wrapperKeys = []string{"Packages", "Images", "scan_details"} + +// IsValidScanType reports whether s is an accepted --scan-type value. +func IsValidScanType(s string) bool { + switch normalizeScanType(s) { + case ScanTypeOSS, ScanTypeSecrets, ScanTypeContainers, ScanTypeIaC, ScanTypeASCA: + return true + default: + return false + } +} + +func normalizeScanType(s string) string { + s = strings.ToLower(strings.TrimSpace(s)) + if s == ScanTypeSCA { + return ScanTypeOSS + } + return s +} + +// BuildEntries converts a finding payload (single finding, array, or full scan output) into the +// lean ignore entries for the given scan type. +func BuildEntries(scanType string, data []byte) ([]any, error) { + st := normalizeScanType(scanType) + findings, err := extractFindings(data) + if err != nil { + return nil, err + } + if len(findings) == 0 { + return nil, errors.New("no findings found in --data") + } + entries := make([]any, 0, len(findings)) + for _, f := range findings { + entry, bErr := buildOne(st, f) + if bErr != nil { + return nil, bErr + } + entries = append(entries, entry) + } + return entries, nil +} + +// extractFindings normalizes the payload into a list of raw finding objects. 
+func extractFindings(data []byte) ([]json.RawMessage, error) { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return nil, errors.New("--data is empty") + } + switch data[0] { + case '[': + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + return nil, errors.Wrap(err, "parsing findings array") + } + return arr, nil + case '{': + var obj map[string]json.RawMessage + if err := json.Unmarshal(data, &obj); err != nil { + return nil, errors.Wrap(err, "parsing finding object") + } + for _, key := range wrapperKeys { + if raw, ok := obj[key]; ok { + var arr []json.RawMessage + if err := json.Unmarshal(raw, &arr); err != nil { + return nil, errors.Wrapf(err, "parsing %q array", key) + } + return arr, nil + } + } + return []json.RawMessage{data}, nil // a single finding object + default: + return nil, errors.New("--data is not valid JSON (expected an object or array)") + } +} + +func buildOne(scanType string, raw json.RawMessage) (any, error) { + switch scanType { + case ScanTypeOSS: + var e ossrealtime.IgnoredPackage + if err := json.Unmarshal(raw, &e); err != nil { + return nil, errors.Wrap(err, "parsing oss finding") + } + if e.PackageManager == "" || e.PackageName == "" || e.PackageVersion == "" { + return nil, missingFieldsErr(ScanTypeOSS, "PackageManager, PackageName, PackageVersion") + } + return e, nil + case ScanTypeSecrets: + var e secretsrealtime.IgnoredSecret + if err := json.Unmarshal(raw, &e); err != nil { + return nil, errors.Wrap(err, "parsing secrets finding") + } + if e.Title == "" || e.SecretValue == "" { + return nil, missingFieldsErr(ScanTypeSecrets, "Title, SecretValue") + } + return e, nil + case ScanTypeContainers: + var e containersrealtime.IgnoredContainersFinding + if err := json.Unmarshal(raw, &e); err != nil { + return nil, errors.Wrap(err, "parsing containers finding") + } + if e.ImageName == "" || e.ImageTag == "" { + return nil, missingFieldsErr(ScanTypeContainers, "ImageName, ImageTag") + } + return e, 
nil + case ScanTypeIaC: + var e iacrealtime.IgnoredIacFinding + if err := json.Unmarshal(raw, &e); err != nil { + return nil, errors.Wrap(err, "parsing iac finding") + } + if e.Title == "" || e.SimilarityID == "" { + return nil, missingFieldsErr(ScanTypeIaC, "Title, SimilarityID") + } + return e, nil + case ScanTypeASCA: + return buildAscaEntry(raw) + default: + return nil, errors.Errorf("unsupported scan type %q", scanType) + } +} + +// buildAscaEntry maps an ASCA finding to its ignore entry. The realtime scan emits snake_case +// (file_name/line/rule_id via grpcs.ScanDetail) while the ignore entry is PascalCase +// (FileName/Line/RuleID) — so a direct unmarshal would silently drop FileName/RuleID. We unmarshal +// the scan-output shape first, then fall back to the ignore-entry shape (e.g. for --remove input). +func buildAscaEntry(raw json.RawMessage) (any, error) { + var sd grpcs.ScanDetail + _ = json.Unmarshal(raw, &sd) + fileName, line, ruleID := sd.FileName, sd.Line, sd.RuleID + if fileName == "" && ruleID == 0 { + var ig grpcs.AscaIgnoreFinding + _ = json.Unmarshal(raw, &ig) + fileName, line, ruleID = ig.FileName, ig.Line, ig.RuleID + } + if fileName == "" || ruleID == 0 { + return nil, missingFieldsErr(ScanTypeASCA, "file_name/FileName, rule_id/RuleID") + } + return grpcs.AscaIgnoreFinding{FileName: fileName, Line: line, RuleID: ruleID}, nil +} + +func missingFieldsErr(scanType, fields string) error { + return errors.Errorf("invalid %s finding: missing required field(s): %s", scanType, fields) +} diff --git a/internal/services/realtimeengine/ignore/builder_test.go b/internal/services/realtimeengine/ignore/builder_test.go new file mode 100644 index 000000000..c62768d56 --- /dev/null +++ b/internal/services/realtimeengine/ignore/builder_test.go @@ -0,0 +1,130 @@ +package ignore + +import ( + "encoding/json" + "testing" + + "github.com/checkmarx/ast-cli/internal/wrappers/grpcs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + 
+func entryMap(t *testing.T, v any) map[string]any { + t.Helper() + b, err := json.Marshal(v) + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(b, &m)) + return m +} + +func TestBuildEntries_OSS_DropsExtraScanFields(t *testing.T) { + finding := `{"PackageManager":"npm","PackageName":"lodash","PackageVersion":"4.17.20", + "FilePath":"package.json","Vulnerabilities":[{"CVE":"CVE-2021-23337","Severity":"High"}]}` + entries, err := BuildEntries("oss", []byte(finding)) + require.NoError(t, err) + require.Len(t, entries, 1) + + m := entryMap(t, entries[0]) + assert.Equal(t, "npm", m["PackageManager"]) + assert.Equal(t, "lodash", m["PackageName"]) + assert.Equal(t, "4.17.20", m["PackageVersion"]) + _, hasVulns := m["Vulnerabilities"] + assert.False(t, hasVulns, "only key fields should be persisted") +} + +func TestBuildEntries_SCAAlias(t *testing.T) { + entries, err := BuildEntries("sca", []byte(`{"PackageManager":"npm","PackageName":"lodash","PackageVersion":"4.17.20"}`)) + require.NoError(t, err) + assert.Len(t, entries, 1) +} + +func TestBuildEntries_Secrets(t *testing.T) { + entries, err := BuildEntries("secrets", []byte(`{"Title":"github-pat","SecretValue":"ghp_x","FilePath":"a.js","Severity":"High"}`)) + require.NoError(t, err) + m := entryMap(t, entries[0]) + assert.Equal(t, "github-pat", m["Title"]) + assert.Equal(t, "ghp_x", m["SecretValue"]) +} + +func TestBuildEntries_Containers(t *testing.T) { + entries, err := BuildEntries("containers", []byte(`{"ImageName":"ubuntu","ImageTag":"14.04","Status":"OK"}`)) + require.NoError(t, err) + m := entryMap(t, entries[0]) + assert.Equal(t, "ubuntu", m["ImageName"]) + assert.Equal(t, "14.04", m["ImageTag"]) +} + +func TestBuildEntries_IaC(t *testing.T) { + entries, err := BuildEntries("iac", []byte(`{"Title":"Missing User Instruction","SimilarityID":"abc123","Severity":"High"}`)) + require.NoError(t, err) + m := entryMap(t, entries[0]) + assert.Equal(t, "Missing User Instruction", 
m["Title"]) + assert.Equal(t, "abc123", m["SimilarityID"]) +} + +// The critical ASCA case: scan output is snake_case (file_name/line/rule_id), the ignore entry is +// PascalCase (FileName/Line/RuleID). A naive direct unmarshal would silently lose FileName/RuleID. +func TestBuildEntries_ASCA_MapsSnakeCaseScanOutput(t *testing.T) { + finding := `{"rule_id":5004,"rule_name":"Insecure Logging","file_name":"server.py","line":77, + "problematicLine":"log(pw)","severity":"High"}` + entries, err := BuildEntries("asca", []byte(finding)) + require.NoError(t, err) + require.Len(t, entries, 1) + + ig, ok := entries[0].(grpcs.AscaIgnoreFinding) + require.True(t, ok) + assert.Equal(t, "server.py", ig.FileName) + assert.Equal(t, uint32(77), ig.Line) + assert.Equal(t, uint32(5004), ig.RuleID) + + m := entryMap(t, entries[0]) + assert.Equal(t, "server.py", m["FileName"]) + assert.EqualValues(t, 5004, m["RuleID"]) +} + +func TestBuildEntries_ASCA_AcceptsIgnoreEntryShape(t *testing.T) { + entries, err := BuildEntries("asca", []byte(`{"FileName":"server.py","Line":77,"RuleID":5004}`)) + require.NoError(t, err) + ig := entries[0].(grpcs.AscaIgnoreFinding) + assert.Equal(t, "server.py", ig.FileName) + assert.Equal(t, uint32(5004), ig.RuleID) +} + +func TestBuildEntries_FullPayloadWrappers(t *testing.T) { + ossEntries, err := BuildEntries("oss", []byte(`{"Packages":[ + {"PackageManager":"npm","PackageName":"a","PackageVersion":"1"}, + {"PackageManager":"npm","PackageName":"b","PackageVersion":"2"}]}`)) + require.NoError(t, err) + assert.Len(t, ossEntries, 2) + + ascaEntries, err := BuildEntries("asca", []byte(`{"scan_details":[{"file_name":"x.py","line":1,"rule_id":10}]}`)) + require.NoError(t, err) + assert.Len(t, ascaEntries, 1) + + secretEntries, err := BuildEntries("secrets", []byte(`[{"Title":"t","SecretValue":"s"}]`)) + require.NoError(t, err) + assert.Len(t, secretEntries, 1) +} + +func TestBuildEntries_MissingRequiredFields_Errors(t *testing.T) { + _, err := BuildEntries("oss", 
[]byte(`{"PackageManager":"npm","PackageName":"lodash"}`)) + require.Error(t, err) + + _, err = BuildEntries("asca", []byte(`{"file_name":"server.py","line":77}`)) // no rule_id + require.Error(t, err) +} + +func TestBuildEntries_BadJSON_Errors(t *testing.T) { + _, err := BuildEntries("oss", []byte(`not json`)) + require.Error(t, err) +} + +func TestIsValidScanType(t *testing.T) { + for _, s := range []string{"oss", "sca", "secrets", "containers", "iac", "asca", "OSS", "ASCA", " sca "} { + assert.Truef(t, IsValidScanType(s), "expected %q valid", s) + } + for _, s := range []string{"", "foo", "sast", "kics"} { + assert.Falsef(t, IsValidScanType(s), "expected %q invalid", s) + } +} diff --git a/internal/services/realtimeengine/ignore/ignorefile.go b/internal/services/realtimeengine/ignore/ignorefile.go new file mode 100644 index 000000000..4cb2b610d --- /dev/null +++ b/internal/services/realtimeengine/ignore/ignorefile.go @@ -0,0 +1,134 @@ +// Package ignore creates and updates the realtime-scan ignore file (the lean +// ".checkmarxIgnoredTempList.json" temp-list) that the realtime engines consume via +// --ignored-file-path. It is the write side of the ignore flow; the engines own the read/filter side. +package ignore + +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" +) + +const ( + // defaultDir / defaultFileName form the CLI's default ignore file: + // /.checkmarx/checkmarxIgnoredTempList.json — written under the current working + // directory (the folder where the agent/Claude runs). The content format matches the IDE temp-list. + defaultDir = ".checkmarx" + defaultFileName = "checkmarxIgnoredTempList.json" + + dirPerm = 0o750 + filePerm = 0o600 +) + +// DefaultPath returns the default ignore-file path: ".checkmarx/checkmarxIgnoredTempList.json" +// under the current working directory (the project root where Claude / the agent runs). 
+func DefaultPath() string { + return filepath.Join(defaultDir, defaultFileName) +} + +// PathFor returns the ignore-file path anchored at workDir — the workspace root the hook +// event reports via its "cwd" field — i.e. /.checkmarx/checkmarxIgnoredTempList.json. +// When workDir is empty it falls back to the CWD-relative DefaultPath. +// +// Anchoring to workDir (rather than relying on the hook process's own working directory) +// keeps the read side aligned with the write side regardless of where the host CLI launches +// the hook subprocess: Claude Code happens to run hooks from the workspace root, but Copilot +// CLI launches them from a different directory, so a CWD-relative lookup misses the file the +// ignore command wrote under the workspace and the suppression is silently dropped. +func PathFor(workDir string) string { + if workDir == "" { + return DefaultPath() + } + return filepath.Join(workDir, defaultDir, defaultFileName) +} + +// Load reads the ignore file as a list of raw JSON entries. A missing or empty file yields an +// empty list (not an error) so the first ignore creates the file cleanly. +func Load(path string) ([]json.RawMessage, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return []json.RawMessage{}, nil + } + return nil, err + } + if len(bytes.TrimSpace(data)) == 0 { + return []json.RawMessage{}, nil + } + var list []json.RawMessage + if err := json.Unmarshal(data, &list); err != nil { + return nil, err + } + return list, nil +} + +// Append adds entry to the list unless an entry with identical content already exists. +// Returns the (possibly extended) list and whether a new entry was added. 
+func Append(list []json.RawMessage, entry any) ([]json.RawMessage, bool, error) { + raw, err := json.Marshal(entry) + if err != nil { + return list, false, err + } + target, err := canonical(raw) + if err != nil { + return list, false, err + } + for _, existing := range list { + if c, cErr := canonical(existing); cErr == nil && c == target { + return list, false, nil // already ignored + } + } + return append(list, json.RawMessage(raw)), true, nil +} + +// Remove deletes any entry whose content matches the given one (the revive / review operation). +// Returns the (possibly shortened) list and whether anything was removed. +func Remove(list []json.RawMessage, entry any) ([]json.RawMessage, bool, error) { + raw, err := json.Marshal(entry) + if err != nil { + return list, false, err + } + target, err := canonical(raw) + if err != nil { + return list, false, err + } + out := make([]json.RawMessage, 0, len(list)) + removed := false + for _, existing := range list { + if c, cErr := canonical(existing); cErr == nil && c == target { + removed = true + continue + } + out = append(out, existing) + } + return out, removed, nil +} + +// Save writes the list as pretty-printed JSON, creating the parent directory if needed. +func Save(path string, list []json.RawMessage) error { + if dir := filepath.Dir(path); dir != "" && dir != "." { + if err := os.MkdirAll(dir, dirPerm); err != nil { + return err + } + } + data, err := json.MarshalIndent(list, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, filePerm) +} + +// canonical normalizes a JSON object so two entries that differ only in key order or whitespace +// compare equal (json.Marshal of a map sorts keys alphabetically). 
+func canonical(raw []byte) (string, error) { + var m map[string]any + if err := json.Unmarshal(raw, &m); err != nil { + return "", err + } + out, err := json.Marshal(m) + if err != nil { + return "", err + } + return string(out), nil +} diff --git a/internal/services/realtimeengine/ignore/ignorefile_test.go b/internal/services/realtimeengine/ignore/ignorefile_test.go new file mode 100644 index 000000000..6024ed2d7 --- /dev/null +++ b/internal/services/realtimeengine/ignore/ignorefile_test.go @@ -0,0 +1,102 @@ +package ignore + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ossEntry struct { + PackageManager string `json:"PackageManager"` + PackageName string `json:"PackageName"` + PackageVersion string `json:"PackageVersion"` +} + +func TestLoad_MissingFile_ReturnsEmpty(t *testing.T) { + list, err := Load(filepath.Join(t.TempDir(), "does-not-exist.json")) + require.NoError(t, err) + assert.Empty(t, list) +} + +func TestLoad_EmptyFile_ReturnsEmpty(t *testing.T) { + p := filepath.Join(t.TempDir(), "empty.json") + require.NoError(t, os.WriteFile(p, []byte(" \n"), 0o600)) + list, err := Load(p) + require.NoError(t, err) + assert.Empty(t, list) +} + +func TestAppend_DeDupes(t *testing.T) { + e := ossEntry{"npm", "lodash", "4.17.20"} + + list, added, err := Append(nil, e) + require.NoError(t, err) + assert.True(t, added) + assert.Len(t, list, 1) + + list, added, err = Append(list, e) + require.NoError(t, err) + assert.False(t, added, "identical entry must not be added twice") + assert.Len(t, list, 1) + + list, added, err = Append(list, ossEntry{"npm", "lodash", "4.17.21"}) + require.NoError(t, err) + assert.True(t, added) + assert.Len(t, list, 2) +} + +func TestAppend_DeDupe_IgnoresKeyOrder(t *testing.T) { + existing := json.RawMessage(`{"PackageVersion":"4.17.20","PackageName":"lodash","PackageManager":"npm"}`) + list, added, err := 
Append([]json.RawMessage{existing}, ossEntry{"npm", "lodash", "4.17.20"}) + require.NoError(t, err) + assert.False(t, added, "key order must not affect de-dupe") + assert.Len(t, list, 1) +} + +func TestRemove(t *testing.T) { + e := ossEntry{"npm", "lodash", "4.17.20"} + list, _, _ := Append(nil, e) + + list, removed, err := Remove(list, e) + require.NoError(t, err) + assert.True(t, removed) + assert.Empty(t, list) + + _, removed, err = Remove(list, e) + require.NoError(t, err) + assert.False(t, removed, "removing a missing entry is a no-op") +} + +func TestSaveLoad_RoundTrip_CreatesParentDir(t *testing.T) { + p := filepath.Join(t.TempDir(), ".checkmarx", ".checkmarxIgnoredTempList.json") + list, _, _ := Append(nil, ossEntry{"npm", "lodash", "4.17.20"}) + require.NoError(t, Save(p, list)) + + loaded, err := Load(p) + require.NoError(t, err) + require.Len(t, loaded, 1) + + var got ossEntry + require.NoError(t, json.Unmarshal(loaded[0], &got)) + assert.Equal(t, "lodash", got.PackageName) +} + +func TestDefaultPath(t *testing.T) { + assert.Equal(t, filepath.Join(".checkmarx", "checkmarxIgnoredTempList.json"), DefaultPath()) +} + +func TestPathFor_AnchorsAtWorkDir(t *testing.T) { + workDir := filepath.Join("some", "workspace") + assert.Equal(t, + filepath.Join(workDir, ".checkmarx", "checkmarxIgnoredTempList.json"), + PathFor(workDir), + ) +} + +func TestPathFor_EmptyWorkDirFallsBackToDefault(t *testing.T) { + assert.Equal(t, DefaultPath(), PathFor("")) +} diff --git a/internal/services/realtimeengine/ossrealtime/oss-realtime.go b/internal/services/realtimeengine/ossrealtime/oss-realtime.go index b2c5e0cfb..6411d81ff 100644 --- a/internal/services/realtimeengine/ossrealtime/oss-realtime.go +++ b/internal/services/realtimeengine/ossrealtime/oss-realtime.go @@ -89,11 +89,11 @@ func (o *OssRealtimeService) RunOssRealtimeScan(filePath, ignoredFilePath string if ignoredFilePath != "" { ignoredPkgs, err := loadIgnoredPackages(ignoredFilePath) if err != nil { - return nil, 
errorconstants.NewRealtimeEngineError("failed to load ignored packages").Error() + logger.PrintfIfVerbose("oss-realtime: failed to load ignore file %s: %v; continuing without ignore filtering", ignoredFilePath, err) + } else { + ignoreMap := buildIgnoreMap(ignoredPkgs) + response.Packages = filterIgnoredPackages(response.Packages, ignoreMap) } - - ignoreMap := buildIgnoreMap(ignoredPkgs) - response.Packages = filterIgnoredPackages(response.Packages, ignoreMap) } return response, nil diff --git a/internal/services/realtimeengine/secretsrealtime/secrets-realtime.go b/internal/services/realtimeengine/secretsrealtime/secrets-realtime.go index 05527804a..a4ab60914 100644 --- a/internal/services/realtimeengine/secretsrealtime/secrets-realtime.go +++ b/internal/services/realtimeengine/secretsrealtime/secrets-realtime.go @@ -115,7 +115,8 @@ func (s *SecretsRealtimeService) RunSecretsRealtimeScan(filePath, ignoredFilePat } ignoredSecrets, err := loadIgnoredSecrets(ignoredFilePath) if err != nil { - return nil, errorconstants.NewRealtimeEngineError("failed to load ignored secrets").Error() + logger.PrintfIfVerbose("secrets-realtime: failed to load ignore file %s: %v; continuing without ignore filtering", ignoredFilePath, err) + return results, nil } ignoreMap := buildIgnoreMap(ignoredSecrets) results = filterIgnoredSecrets(results, ignoreMap) diff --git a/internal/wrappers/active_mode.go b/internal/wrappers/active_mode.go new file mode 100644 index 000000000..26fd4cfd4 --- /dev/null +++ b/internal/wrappers/active_mode.go @@ -0,0 +1,88 @@ +package wrappers + +import ( + "os" + "path/filepath" + "strings" + + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers/configuration" + "github.com/pkg/errors" +) + +// File permission: owner read/write only. Active-mode metadata doesn't hold a +// credential itself, but it does reveal where the user's credential currently +// lives — keep it owner-only to avoid leaking that signal. 
+const activeModeFilePerm = 0o600 + +// ActiveModeFilePath returns the absolute path to the active-mode metadata +// file. Derived from the same config directory as the existing yaml so a +// custom --config-file-path is respected. +func ActiveModeFilePath() (string, error) { + configPath, err := configuration.GetConfigFilePath() + if err != nil { + return "", errors.Wrap(err, "failed to resolve config file path for active-mode file") + } + return filepath.Join(filepath.Dir(configPath), params.ActiveModeFileName), nil +} + +// ReadActiveMode returns the currently active session mode — one of +// params.SessionYamlValue, params.SessionLocalValue, or +// params.SessionGlobalValue. Returns ("", nil) if the file does not exist, +// which means "no active session" — every read path falls back to whatever +// the user has set directly (env var or yaml). +func ReadActiveMode() (string, error) { + path, err := ActiveModeFilePath() + if err != nil { + return "", err + } + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", errors.Wrap(err, "failed to read active-mode file") + } + mode := strings.TrimSpace(string(data)) + switch mode { + case params.SessionYamlValue, params.SessionLocalValue, params.SessionGlobalValue, "": + return mode, nil + default: + // Unknown value — treat as no active mode so the CLI doesn't get + // confused by a corrupt file. Caller can still fall back to defaults. + return "", nil + } +} + +// WriteActiveMode persists the active session mode. Creates the config +// directory if needed so the first-ever login on a fresh machine works. 
+func WriteActiveMode(mode string) error { + if mode != params.SessionYamlValue && mode != params.SessionLocalValue && mode != params.SessionGlobalValue { + return errors.Errorf("invalid active mode %q: must be %q, %q, or %q", + mode, params.SessionYamlValue, params.SessionLocalValue, params.SessionGlobalValue) + } + path, err := ActiveModeFilePath() + if err != nil { + return err + } + if mkErr := os.MkdirAll(filepath.Dir(path), 0o700); mkErr != nil { + return errors.Wrap(mkErr, "failed to create config directory for active-mode file") + } + if writeErr := os.WriteFile(path, []byte(mode), activeModeFilePerm); writeErr != nil { + return errors.Wrap(writeErr, "failed to write active-mode file") + } + return nil +} + +// ClearActiveMode removes the active-mode file. Returns nil if the file +// already does not exist (logout is idempotent). +func ClearActiveMode() error { + path, err := ActiveModeFilePath() + if err != nil { + return err + } + if rmErr := os.Remove(path); rmErr != nil && !os.IsNotExist(rmErr) { + return errors.Wrap(rmErr, "failed to remove active-mode file") + } + return nil +} diff --git a/internal/wrappers/active_mode_test.go b/internal/wrappers/active_mode_test.go new file mode 100644 index 000000000..baf20af7b --- /dev/null +++ b/internal/wrappers/active_mode_test.go @@ -0,0 +1,95 @@ +package wrappers + +import ( + "os" + "path/filepath" + "testing" + + "github.com/checkmarx/ast-cli/internal/params" +) + +func TestActiveModeFilePath_ReturnsPathInConfigDir(t *testing.T) { + dir := withTempConfigDir(t) + got, err := ActiveModeFilePath() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.Join(dir, params.ActiveModeFileName) + if got != want { + t.Errorf("ActiveModeFilePath() = %q, want %q", got, want) + } +} + +func TestReadActiveMode_EmptyWhenFileMissing(t *testing.T) { + withTempConfigDir(t) + got, err := ReadActiveMode() + if err != nil { + t.Fatalf("expected nil error when file absent, got: %v", err) + } + if got != 
"" { + t.Errorf("expected empty string, got %q", got) + } +} + +func TestWriteAndReadActiveMode_RoundTrip(t *testing.T) { + withTempConfigDir(t) + for _, mode := range []string{params.SessionYamlValue, params.SessionLocalValue, params.SessionGlobalValue} { + if err := WriteActiveMode(mode); err != nil { + t.Fatalf("WriteActiveMode(%q) failed: %v", mode, err) + } + got, err := ReadActiveMode() + if err != nil { + t.Fatalf("ReadActiveMode after writing %q failed: %v", mode, err) + } + if got != mode { + t.Errorf("round-trip mismatch: wrote %q, read %q", mode, got) + } + } +} + +func TestWriteActiveMode_RejectsInvalidValue(t *testing.T) { + withTempConfigDir(t) + err := WriteActiveMode("invalid-mode-value") + if err == nil { + t.Fatal("expected error for invalid mode value, got nil") + } +} + +func TestReadActiveMode_TreatsCorruptValueAsAbsent(t *testing.T) { + dir := withTempConfigDir(t) + // Manually write garbage to the active-mode file. + if err := os.WriteFile(filepath.Join(dir, params.ActiveModeFileName), []byte("garbage-mode"), 0o600); err != nil { + t.Fatalf("setup write failed: %v", err) + } + got, err := ReadActiveMode() + if err != nil { + t.Fatalf("ReadActiveMode returned unexpected error: %v", err) + } + if got != "" { + t.Errorf("expected corrupt value to be treated as absent (empty), got %q", got) + } +} + +func TestClearActiveMode_RemovesFile(t *testing.T) { + withTempConfigDir(t) + if err := WriteActiveMode(params.SessionGlobalValue); err != nil { + t.Fatalf("setup write failed: %v", err) + } + if err := ClearActiveMode(); err != nil { + t.Fatalf("ClearActiveMode failed: %v", err) + } + got, err := ReadActiveMode() + if err != nil { + t.Fatalf("ReadActiveMode after clear failed: %v", err) + } + if got != "" { + t.Errorf("expected empty after clear, got %q", got) + } +} + +func TestClearActiveMode_IdempotentWhenFileMissing(t *testing.T) { + withTempConfigDir(t) + if err := ClearActiveMode(); err != nil { + t.Errorf("expected nil error when file absent, 
got: %v", err) + } +} diff --git a/internal/wrappers/client.go b/internal/wrappers/client.go index be09d4da9..de73ba993 100644 --- a/internal/wrappers/client.go +++ b/internal/wrappers/client.go @@ -72,12 +72,14 @@ type ClientCredentialsError struct { const FailedToAuth = "Failed to authenticate - please provide an %s" const BaseAuthURLSuffix = "protocol/openid-connect/token" const BaseAuthURLPrefix = "auth/realms/organization" -const baseURLKey = "ast-base-url" +// BaseURLKey is the JWT claim that carries the AST base URL. It is present only +// on the exchanged access token (see GetURL), not on the stored refresh token. +const BaseURLKey = "ast-base-url" const audienceClaimKey = "aud" -var CachedAccessToken string -var CachedAccessTime time.Time +var cachedAccessToken string +var cachedAccessTime time.Time var Domains = make(map[string]struct{}) func retryHTTPRequest(requestFunc func() (*http.Response, error), retries int, baseDelayInMilliSec time.Duration) (*http.Response, error) { @@ -640,23 +642,50 @@ func configureClientCredentialsAndGetNewToken() (string, error) { func getClientCredentialsFromCache(tokenExpirySeconds int) string { logger.PrintIfVerbose("Checking cache for API access token.") - expired := time.Since(CachedAccessTime) > time.Duration(tokenExpirySeconds-expiryGraceSeconds)*time.Second + expired := time.Since(cachedAccessTime) > time.Duration(tokenExpirySeconds-expiryGraceSeconds)*time.Second if !expired { logger.PrintIfVerbose("Using cached API access token!") - return CachedAccessToken + return cachedAccessToken } logger.PrintIfVerbose("API access token not found in cache!") return "" } +// InvalidateAccessTokenCache forces the next GetAccessToken to re-exchange the +// stored credential instead of returning a cached access token. Used after a fresh +// `cx auth login` (which may target a different tenant) so the new ast-base-url +// claim is re-derived. Guarded by the same mutex that protects the cache writes. 
+func InvalidateAccessTokenCache() { + credentialsMutex.Lock() + defer credentialsMutex.Unlock() + cachedAccessToken = "" + cachedAccessTime = time.Time{} +} + func writeCredentialsToCache(accessToken string) { credentialsMutex.Lock() defer credentialsMutex.Unlock() logger.PrintIfVerbose("Storing API access token to cache.") viper.Set(commonParams.AstToken, accessToken) - CachedAccessToken = accessToken - CachedAccessTime = time.Now() + cachedAccessToken = accessToken + cachedAccessTime = time.Now() +} + +// SetCachedAccessTokenForTest seeds (token != "") or clears (token == "") the +// in-memory access-token cache. Exported solely so tests in other packages can +// control the cache without reaching into unexported state. Guarded by the same +// mutex that protects the cache writes. +func SetCachedAccessTokenForTest(token string) { + credentialsMutex.Lock() + defer credentialsMutex.Unlock() + if token == "" { + cachedAccessToken = "" + cachedAccessTime = time.Time{} + return + } + cachedAccessToken = token + cachedAccessTime = time.Now() } func getNewToken(credentialsPayload, authServerURI string) (string, error) { @@ -878,13 +907,19 @@ func hasRedirectStatusCode(resp *http.Response) bool { return resp.StatusCode == http.StatusTemporaryRedirect || resp.StatusCode == http.StatusMovedPermanently } -func GetAuthURI() (string, error) { +func GetRealmURL() (string, error) { var authURI string var err error override := viper.GetBool(commonParams.ApikeyOverrideFlag) apiKey := viper.GetString(commonParams.AstAPIKey) - if len(apiKey) > 0 { + // When override is set (e.g. `cx auth login`, which forces ApikeyOverrideFlag + // so the explicit --base-auth-uri/--tenant win), do NOT decode the stored API + // key. 
Decoding it here would surface a stale/malformed cx_apikey as a hard + // "failed to resolve IAM realm URL" error before the override branch below can + // build the realm from the flags — defeating the very purpose of the override + // and making login impossible until the bad key is manually cleared. + if len(apiKey) > 0 && !override { logger.PrintIfVerbose("Base Auth URI - Extract from API KEY") authURI, err = ExtractFromTokenClaims(apiKey, audienceClaimKey) if err != nil { @@ -925,7 +960,15 @@ func GetAuthURI() (string, error) { authURI = strings.Trim(authURI, "/") logger.PrintIfVerbose(fmt.Sprintf("Base Auth URI - %s ", authURI)) - return fmt.Sprintf("%s/%s", authURI, BaseAuthURLSuffix), nil + return authURI, nil +} + +func GetAuthURI() (string, error) { + realmURL, err := GetRealmURL() + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", realmURL, BaseAuthURLSuffix), nil } func GetURL(path, accessToken string) (string, error) { @@ -934,7 +977,7 @@ func GetURL(path, accessToken string) (string, error) { override := viper.GetBool(commonParams.ApikeyOverrideFlag) if accessToken != "" { logger.PrintIfVerbose("Base URI - Extract from JWT token") - cleanURL, err = ExtractFromTokenClaims(accessToken, baseURLKey) + cleanURL, err = ExtractFromTokenClaims(accessToken, BaseURLKey) if err != nil { return "", err } @@ -965,9 +1008,28 @@ func ExtractFromTokenClaims(accessToken, claim string) (string, error) { return "", errors.Errorf(APIKeyDecodeErrorFormat, err) } - if claims, ok := token.Claims.(jwt.MapClaims); ok && claims[claim] != nil { - value = strings.TrimSpace(claims[claim].(string)) - } else { + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || claims[claim] == nil { + return "", errors.Errorf(jwtError, claim) + } + + // A claim value can be a string or, per the OIDC spec (notably "aud"), an + // array of strings. Handle both without a type assertion that would panic. 
+ switch v := claims[claim].(type) { + case string: + value = strings.TrimSpace(v) + case []interface{}: + for _, item := range v { + if s, isStr := item.(string); isStr && strings.TrimSpace(s) != "" { + value = strings.TrimSpace(s) + break + } + } + default: + return "", errors.Errorf(jwtError, claim) + } + + if value == "" { return "", errors.Errorf(jwtError, claim) } diff --git a/internal/wrappers/client_test.go b/internal/wrappers/client_test.go index e75e9e9b0..e2f4fb670 100644 --- a/internal/wrappers/client_test.go +++ b/internal/wrappers/client_test.go @@ -1,6 +1,8 @@ package wrappers import ( + "encoding/base64" + "encoding/json" "errors" "fmt" "net/http" @@ -185,6 +187,55 @@ func TestGetAPIKeyPayload(t *testing.T) { } } +// TestGetRealmURL_LoginOverrideSkipsStoredAPIKey guards the `cx auth login` fix: +// when ApikeyOverrideFlag is set, GetRealmURL must build the realm from the +// explicit --base-auth-uri/--tenant flags and must NOT decode the stored +// cx_apikey. A stale/malformed stored key previously surfaced here as a hard +// "failed to resolve IAM realm URL" error, making login impossible until the bad +// key was manually cleared. 
+func TestGetRealmURL_LoginOverrideSkipsStoredAPIKey(t *testing.T) { + keys := []string{ + commonParams.ApikeyOverrideFlag, + commonParams.AstAPIKey, + commonParams.BaseAuthURIKey, + commonParams.TenantKey, + } + saved := make(map[string]interface{}, len(keys)) + for _, k := range keys { + saved[k] = viper.Get(k) + } + t.Cleanup(func() { + for _, k := range keys { + viper.Set(k, saved[k]) + } + }) + + const malformedKey = "not-a-jwt" // single segment -> ExtractFromTokenClaims fails + + t.Run("override builds realm from flags despite a malformed stored key", func(t *testing.T) { + viper.Set(commonParams.ApikeyOverrideFlag, true) + viper.Set(commonParams.AstAPIKey, malformedKey) + viper.Set(commonParams.BaseAuthURIKey, "https://eu.iam.checkmarx.net") + viper.Set(commonParams.TenantKey, "cx_seg") + + realmURL, err := GetRealmURL() + + assert.NoError(t, err) + assert.Equal(t, "https://eu.iam.checkmarx.net/auth/realms/cx_seg", realmURL) + }) + + t.Run("without override a malformed stored key still errors (unchanged)", func(t *testing.T) { + viper.Set(commonParams.ApikeyOverrideFlag, false) + viper.Set(commonParams.AstAPIKey, malformedKey) + viper.Set(commonParams.BaseAuthURIKey, "https://eu.iam.checkmarx.net") + viper.Set(commonParams.TenantKey, "cx_seg") + + _, err := GetRealmURL() + + assert.Error(t, err) + }) +} + func TestSetAgentNameAndOrigin(t *testing.T) { viper.Set(commonParams.AgentNameKey, "TestAgent") viper.Set(commonParams.OriginKey, "TestOrigin") @@ -208,6 +259,47 @@ func TestSetAgentNameAndOrigin(t *testing.T) { } } +// unsignedJWT builds a 3-segment JWT with the given claims. ExtractFromTokenClaims +// uses ParseUnverified, so the signature segment is irrelevant. +func unsignedJWT(claims map[string]interface{}) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadBytes, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(payloadBytes) + return header + "." 
+ payload + ".sig" +} + +// TestExtractFromTokenClaims covers fix #10: the "aud" (and any) claim may be a +// string or an array of strings. The array form must not panic. +func TestExtractFromTokenClaims(t *testing.T) { + const realm = "https://eu.iam.checkmarx.net/auth/realms/cx_seg" + + t.Run("string claim returned as-is", func(t *testing.T) { + token := unsignedJWT(map[string]interface{}{"aud": realm}) + got, err := ExtractFromTokenClaims(token, "aud") + assert.NoError(t, err) + assert.Equal(t, realm, got) + }) + + t.Run("array claim returns first non-empty string (no panic)", func(t *testing.T) { + token := unsignedJWT(map[string]interface{}{"aud": []interface{}{"", realm, "account"}}) + got, err := ExtractFromTokenClaims(token, "aud") + assert.NoError(t, err) + assert.Equal(t, realm, got) + }) + + t.Run("non-string, non-array claim errors instead of panicking", func(t *testing.T) { + token := unsignedJWT(map[string]interface{}{"aud": 12345}) + _, err := ExtractFromTokenClaims(token, "aud") + assert.Error(t, err) + }) + + t.Run("missing claim errors", func(t *testing.T) { + token := unsignedJWT(map[string]interface{}{"iss": realm}) + _, err := ExtractFromTokenClaims(token, "aud") + assert.Error(t, err) + }) +} + func TestRetryIAMHTTPRequest_Success(t *testing.T) { fn := func() (*http.Response, error) { return &http.Response{ diff --git a/internal/wrappers/configuration/configuration.go b/internal/wrappers/configuration/configuration.go index 796d57c0b..64a19b6d6 100644 --- a/internal/wrappers/configuration/configuration.go +++ b/internal/wrappers/configuration/configuration.go @@ -3,6 +3,7 @@ package configuration import ( "bufio" "fmt" + "io" "log" "os" "os/user" @@ -228,7 +229,8 @@ func SafeWriteSingleConfigKeyString(configFilePath, key string, value string) er return nil } -// LoadConfig loads the configuration from a file. If the file does not exist, it returns an empty map. +// LoadConfig loads the configuration from a file. 
If the file does not exist +// or is empty, it returns an empty map. func LoadConfig(path string) (map[string]interface{}, error) { config := make(map[string]interface{}) file, err := os.Open(path) @@ -244,6 +246,13 @@ func LoadConfig(path string) (map[string]interface{}, error) { decoder := yaml.NewDecoder(file) if err = decoder.Decode(&config); err != nil { + if err == io.EOF { + // An empty (zero-byte) config file is a valid "no config yet" + // state, not corruption. Treat it like a missing file and return + // an empty config so callers (e.g. cx auth login persisting a + // fresh token) can populate it instead of failing. + return config, nil + } return nil, fmt.Errorf("error decoding YAML: %w", err) } return config, nil diff --git a/internal/wrappers/oauth_pkce.go b/internal/wrappers/oauth_pkce.go new file mode 100644 index 000000000..55aebc5e7 --- /dev/null +++ b/internal/wrappers/oauth_pkce.go @@ -0,0 +1,339 @@ +package wrappers + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/checkmarx/ast-cli/internal/logger" + "github.com/pkg/errors" +) + +// pkceScopes matches the scopes requested by the Checkmarx One VS Code +// extension's OAuth flow. The ast-api / iam-api scopes are configured as +// default client scopes on the ide-integration Keycloak client, so they +// are granted automatically without being requested explicitly. 
+const pkceScopes = "openid offline_access" + +const pkceLoginTimeout = 5 * time.Minute + +type PKCETokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type PKCELoginOptions struct { + RealmURL string + ClientID string + Port int + OpenBrowser bool +} + +type oidcDiscovery struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` +} + +// LoginWithPKCE runs an OAuth 2.0 Authorization Code + PKCE flow against the +// Keycloak realm at opts.RealmURL and returns the token response. The flow +// starts a one-shot HTTP listener on 127.0.0.1, opens the user's browser to +// the authorize URL, and waits for the redirect callback. The caller is +// responsible for persisting or printing the returned tokens. +func LoginWithPKCE(ctx context.Context, opts PKCELoginOptions) (*PKCETokenResponse, error) { + if opts.RealmURL == "" { + return nil, errors.New("realm URL is required") + } + if opts.ClientID == "" { + return nil, errors.New("client-id is required") + } + + disco, err := discoverOIDC(ctx, opts.RealmURL) + if err != nil { + return nil, errors.Wrap(err, "failed to fetch OIDC discovery document") + } + + verifier, challenge, err := newPKCE() + if err != nil { + return nil, errors.Wrap(err, "failed to generate PKCE verifier") + } + state, err := randomURLSafe(16) + if err != nil { + return nil, errors.Wrap(err, "failed to generate state") + } + + listener, err := net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", opts.Port)) + if err != nil { + return nil, errors.Wrap(err, "failed to start local callback listener") + } + defer listener.Close() + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + return nil, errors.New("local listener did not bind to a TCP address") + } + // The redirect URI uses the 
'localhost' hostname and the '/checkmarx1/callback' + // path to match the pattern whitelisted on the 'ide-integration' Keycloak client + // — the same pattern used by the Checkmarx One VS Code extension. Because + // 'localhost' resolves to ::1 on IPv6-preferring systems, we ALSO bind an IPv6 + // loopback listener on the same port below (best-effort), so the browser callback + // reaches us whichever family 'localhost' resolves to. Both listeners are + // loopback-only (safe). + redirectURI := fmt.Sprintf("http://localhost:%d/checkmarx1/callback", tcpAddr.Port) + authURL := buildAuthorizeURL(disco.AuthorizationEndpoint, opts.ClientID, redirectURI, state, challenge) + + type callbackResult struct { + code string + err error + } + resultCh := make(chan callbackResult, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/checkmarx1/callback", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + // Validate the anti-CSRF state FIRST. A request that does not carry our + // exact state value is unsolicited (a stray prefetch, a probe, or a CSRF + // attempt) and is IGNORED — we do NOT resolve resultCh, so it cannot + // abort the pending login, and we do not act on its other parameters. + // The genuine callback echoes our state and is the only thing that + // completes the flow; if none arrives, the outer timeout fires. + if got := q.Get("state"); got != state { + writeBrowserMessage(w, "Authentication failed.", "State mismatch — this request was ignored. 
You can close this tab.") + logger.PrintIfVerbose("OAuth callback: ignoring request with missing/mismatched state") + return + } + if errParam := q.Get("error"); errParam != "" { + desc := q.Get("error_description") + writeBrowserMessage(w, "Authentication failed.", fmt.Sprintf("%s: %s", errParam, desc)) + resultCh <- callbackResult{err: errors.Errorf("authorization server returned error: %s — %s", errParam, desc)} + return + } + code := q.Get("code") + if code == "" { + writeBrowserMessage(w, "Authentication failed.", "Missing authorization code in callback.") + resultCh <- callbackResult{err: errors.New("missing authorization code in callback")} + return + } + writeBrowserMessage(w, "Authentication successful.", "You can close this tab and return to the terminal.") + resultCh <- callbackResult{code: code} + }) + + server := &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} + go func() { _ = server.Serve(listener) }() + // Best-effort IPv6 loopback listener on the same port, so a browser that resolves + // 'localhost' to ::1 still reaches the callback. If it fails (no IPv6 stack, or the + // port is taken on ::1), proceed IPv4-only — unchanged from the previous behavior. + if v6Listener, v6Err := net.Listen("tcp6", fmt.Sprintf("[::1]:%d", tcpAddr.Port)); v6Err == nil { + defer v6Listener.Close() + go func() { _ = server.Serve(v6Listener) }() + } else { + logger.PrintIfVerbose("OAuth callback: IPv6 loopback listener unavailable, using IPv4 only: " + v6Err.Error()) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + + // Diagnostic messages go to stderr so session mode's eval-able stdout + // (the env-var assignment line emitted by the caller after this returns) + // is not polluted. In default mode the diagnostics are still visible in + // the terminal since stderr renders to the console. 
+ fmt.Fprintf(os.Stderr, "Opening browser to: %s\n", authURL) + fmt.Fprintln(os.Stderr, "If the browser does not open, copy and paste the URL above.") + if opts.OpenBrowser { + if err := openBrowser(authURL); err != nil { + logger.PrintIfVerbose("Failed to open browser automatically: " + err.Error()) + } + } + fmt.Fprintln(os.Stderr, "Waiting for authentication...") + + var code string + select { + case res := <-resultCh: + if res.err != nil { + return nil, res.err + } + code = res.code + case <-time.After(pkceLoginTimeout): + return nil, errors.Errorf("timed out after %s waiting for authentication", pkceLoginTimeout) + case <-ctx.Done(): + return nil, ctx.Err() + } + + return exchangeCodeForToken(ctx, disco.TokenEndpoint, opts.ClientID, code, verifier, redirectURI) +} + +func discoverOIDC(ctx context.Context, realmURL string) (*oidcDiscovery, error) { + discoURL := strings.TrimRight(realmURL, "/") + "/.well-known/openid-configuration" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoURL, nil) + if err != nil { + return nil, err + } + client := GetClient(15) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return nil, errors.Errorf("realm not found at %s — check --tenant and --base-auth-uri", discoURL) + } + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discovery endpoint returned status %d", resp.StatusCode) + } + var d oidcDiscovery + if err := json.NewDecoder(resp.Body).Decode(&d); err != nil { + return nil, errors.Wrap(err, "failed to decode discovery document") + } + if d.AuthorizationEndpoint == "" || d.TokenEndpoint == "" { + return nil, errors.New("discovery document is missing authorization_endpoint or token_endpoint") + } + return &d, nil +} + +func buildAuthorizeURL(authEndpoint, clientID, redirectURI, state, challenge string) string { + q := url.Values{} + q.Set("response_type", "code") + q.Set("client_id", clientID) + 
q.Set("redirect_uri", redirectURI) + q.Set("scope", pkceScopes) + q.Set("state", state) + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + return authEndpoint + "?" + q.Encode() +} + +func exchangeCodeForToken(ctx context.Context, tokenEndpoint, clientID, code, verifier, redirectURI string) (*PKCETokenResponse, error) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", clientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + form.Set("code_verifier", verifier) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := GetClient(30) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tr PKCETokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil { + return nil, errors.Wrap(err, "failed to decode token response") + } + if tr.RefreshToken == "" { + return nil, errors.New("token response did not include a refresh_token — verify that the Keycloak client grants the 'offline_access' scope") + } + return &tr, nil +} + +// RevokeRefreshToken invalidates the given refresh token at the Keycloak realm +// via the OAuth 2.0 Token Revocation endpoint (RFC 7009). This is deliberately +// the /revoke endpoint and NOT /logout: /logout is RP-initiated logout that +// ends the entire SSO session and would invalidate every token in that +// session — including tokens we want to keep alive in other CLI session modes. 
+// /revoke targets a single token, leaving sibling tokens in the same session +// untouched, which is what strict storage independence between --session +// modes requires. +// +// Idempotent: a 400 response (token already invalid) is treated as success +// so callers can use this as best-effort cleanup during auto-revoke and +// explicit logout. +func RevokeRefreshToken(ctx context.Context, realmURL, clientID, refreshToken string) error { + endpoint := strings.TrimRight(realmURL, "/") + "/protocol/openid-connect/revoke" + form := url.Values{} + form.Set("client_id", clientID) + form.Set("token", refreshToken) + form.Set("token_type_hint", "refresh_token") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := GetClient(15).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return nil + } + if resp.StatusCode == http.StatusBadRequest { + return nil + } + body, _ := io.ReadAll(resp.Body) + return errors.Errorf("revoke request failed with status %d: %s", resp.StatusCode, string(body)) +} + +func newPKCE() (verifier, challenge string, err error) { + verifier, err = randomURLSafe(32) + if err != nil { + return "", "", err + } + sum := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} + +func randomURLSafe(byteLen int) (string, error) { + b := make([]byte, byteLen) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// openBrowser is a package-level var so tests can intercept the launch and +// simulate the user completing the OAuth flow without a real browser. 
+var openBrowser = func(targetURL string) error { + switch runtime.GOOS { + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", targetURL).Start() + case "darwin": + return exec.Command("open", targetURL).Start() + default: + return exec.Command("xdg-open", targetURL).Start() + } +} + +func writeBrowserMessage(w http.ResponseWriter, title, body string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + // Escape both fields: body may contain server-supplied error/description + // text, so it must never be reflected into the page as raw HTML. + safeTitle := html.EscapeString(title) + safeBody := html.EscapeString(body) + _, _ = fmt.Fprintf(w, `%s + +

%s

%s

`, safeTitle, safeTitle, safeBody) +} diff --git a/internal/wrappers/oauth_pkce_test.go b/internal/wrappers/oauth_pkce_test.go new file mode 100644 index 000000000..a5ead20c3 --- /dev/null +++ b/internal/wrappers/oauth_pkce_test.go @@ -0,0 +1,282 @@ +package wrappers + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestNewPKCE_ChallengeIsSHA256OfVerifier(t *testing.T) { + verifier, challenge, err := newPKCE() + if err != nil { + t.Fatalf("newPKCE returned error: %v", err) + } + if verifier == "" || challenge == "" { + t.Fatal("verifier or challenge is empty") + } + sum := sha256.Sum256([]byte(verifier)) + expected := base64.RawURLEncoding.EncodeToString(sum[:]) + if challenge != expected { + t.Errorf("challenge = %q, want %q", challenge, expected) + } +} + +func TestBuildAuthorizeURL_IncludesAllRequiredParams(t *testing.T) { + authURL := buildAuthorizeURL("https://iam.example.com/auth", "ast-app", "http://127.0.0.1:54321/callback", "state-123", "challenge-abc") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("authURL is not parseable: %v", err) + } + q := parsed.Query() + checks := map[string]string{ + "response_type": "code", + "client_id": "ast-app", + "redirect_uri": "http://127.0.0.1:54321/callback", + "state": "state-123", + "code_challenge": "challenge-abc", + "code_challenge_method": "S256", + } + for key, want := range checks { + if got := q.Get(key); got != want { + t.Errorf("query[%q] = %q, want %q", key, got, want) + } + } + if got := q.Get("scope"); !strings.Contains(got, "openid") || !strings.Contains(got, "offline_access") { + t.Errorf("scope %q must include openid and offline_access", got) + } +} + +func TestDiscoverOIDC_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return 
+ } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "authorization_endpoint": "https://iam.example.com/auth", + "token_endpoint": "https://iam.example.com/token", + }) + })) + defer srv.Close() + + d, err := discoverOIDC(context.Background(), srv.URL) + if err != nil { + t.Fatalf("discoverOIDC returned error: %v", err) + } + if d.AuthorizationEndpoint != "https://iam.example.com/auth" || d.TokenEndpoint != "https://iam.example.com/token" { + t.Errorf("unexpected endpoints: %+v", d) + } +} + +func TestDiscoverOIDC_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer srv.Close() + + _, err := discoverOIDC(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error on 404, got nil") + } + if !strings.Contains(err.Error(), "realm not found") { + t.Errorf("error %q should mention 'realm not found'", err.Error()) + } +} + +func TestExchangeCodeForToken_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if r.Form.Get("grant_type") != "authorization_code" { + http.Error(w, "bad grant_type", http.StatusBadRequest) + return + } + if r.Form.Get("code_verifier") != "the-verifier" { + http.Error(w, "bad verifier", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "fake-access", + "refresh_token": "fake-refresh", + "expires_in": 300, + "token_type": "Bearer", + }) + })) + defer srv.Close() + + tokens, err := exchangeCodeForToken(context.Background(), srv.URL, "ast-app", "the-code", "the-verifier", "http://127.0.0.1:1/callback") + if err != nil { + t.Fatalf("exchangeCodeForToken returned error: %v", err) + } + if tokens.RefreshToken != "fake-refresh" || tokens.AccessToken != "fake-access" { + t.Errorf("unexpected 
tokens: %+v", tokens) + } +} + +func TestExchangeCodeForToken_MissingRefreshToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "fake-access", + "expires_in": 300, + }) + })) + defer srv.Close() + + _, err := exchangeCodeForToken(context.Background(), srv.URL, "ast-app", "c", "v", "r") + if err == nil { + t.Fatal("expected error on missing refresh_token") + } + if !strings.Contains(err.Error(), "offline_access") { + t.Errorf("error %q should mention offline_access scope", err.Error()) + } +} + +func TestExchangeCodeForToken_KeycloakError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"invalid_grant","error_description":"code expired"}`, http.StatusBadRequest) + })) + defer srv.Close() + + _, err := exchangeCodeForToken(context.Background(), srv.URL, "ast-app", "c", "v", "r") + if err == nil { + t.Fatal("expected error on 400") + } + if !strings.Contains(err.Error(), "invalid_grant") { + t.Errorf("error should surface Keycloak's response: %q", err.Error()) + } +} + +// TestLoginWithPKCE_HappyPath drives the full flow end-to-end with a fake +// Keycloak. It hijacks the openBrowser var so the test itself plays the role +// of the browser — visiting the /authorize URL, which makes the fake Keycloak +// redirect back to the CLI's local listener with code+state. 
+func TestLoginWithPKCE_HappyPath(t *testing.T) {
+	mux := http.NewServeMux()
+	// The server is started before the handlers are registered so the
+	// handlers can embed srv.URL in the endpoints they advertise.
+	var srv *httptest.Server
+	srv = httptest.NewServer(mux)
+	defer srv.Close()
+
+	// Discovery document: advertises the fake /auth and /token endpoints.
+	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(map[string]string{
+			"authorization_endpoint": srv.URL + "/auth",
+			"token_endpoint": srv.URL + "/token",
+		})
+	})
+
+	// Authorization endpoint: instead of issuing an HTTP redirect, it calls
+	// the supplied redirect_uri itself with a fixed code and the caller's
+	// state. The callback runs in its own goroutine so this handler can
+	// respond without waiting for it.
+	mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
+		q := r.URL.Query()
+		redirectURI := q.Get("redirect_uri")
+		state := q.Get("state")
+		go func() {
+			resp, err := http.Get(redirectURI + "?code=fake-auth-code&state=" + state)
+			if err == nil && resp != nil {
+				_ = resp.Body.Close()
+			}
+		}()
+		w.WriteHeader(http.StatusOK)
+	})
+
+	// Token endpoint: only the code handed out by /auth is accepted.
+	mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
+		_ = r.ParseForm()
+		if r.Form.Get("code") != "fake-auth-code" {
+			http.Error(w, "bad code", http.StatusBadRequest)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(map[string]interface{}{
+			"access_token": "fake-access",
+			"refresh_token": "fake-refresh",
+			"expires_in": 300,
+			"token_type": "Bearer",
+		})
+	})
+
+	// Swap the package-level openBrowser hook for one that fetches the target
+	// URL in the background, then restore the original when the test ends.
+	original := openBrowser
+	openBrowser = func(target string) error {
+		go func() {
+			resp, err := http.Get(target)
+			if err == nil && resp != nil {
+				_ = resp.Body.Close()
+			}
+		}()
+		return nil
+	}
+	defer func() { openBrowser = original }()
+
+	tokens, err := LoginWithPKCE(context.Background(), PKCELoginOptions{
+		RealmURL: srv.URL,
+		ClientID: "ast-app",
+		Port: 0,
+		OpenBrowser: true,
+	})
+	if err != nil {
+		t.Fatalf("LoginWithPKCE returned error: %v", err)
+	}
+	if tokens.RefreshToken != "fake-refresh" {
+		t.Errorf("got refresh_token %q, want fake-refresh", tokens.RefreshToken)
+	}
+}
+
+// TestRevokeRefreshToken_HappyPath captures the form posted to the revoke
+// endpoint and checks the client_id, token and token_type_hint fields.
+func TestRevokeRefreshToken_HappyPath(t *testing.T) {
+	var gotForm url.Values
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Any path other than the revoke endpoint is answered with 404.
+		if r.URL.Path != "/protocol/openid-connect/revoke" {
+			http.NotFound(w, r)
+			return
+		}
+		_ = r.ParseForm()
+		gotForm = r.Form
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	err := RevokeRefreshToken(context.Background(), srv.URL, "ast-app", "the-rt")
+	if err != nil {
+		t.Fatalf("RevokeRefreshToken returned error: %v", err)
+	}
+	if gotForm.Get("client_id") != "ast-app" {
+		t.Errorf("client_id form field: got %q, want ast-app", gotForm.Get("client_id"))
+	}
+	if gotForm.Get("token") != "the-rt" {
+		t.Errorf("token form field: got %q, want the-rt", gotForm.Get("token"))
+	}
+	if gotForm.Get("token_type_hint") != "refresh_token" {
+		t.Errorf("token_type_hint form field: got %q, want refresh_token", gotForm.Get("token_type_hint"))
+	}
+}
+
+// TestRevokeRefreshToken_AlreadyInvalidIsIdempotent expects a 400 response
+// from the server to be reported as success (nil error).
+func TestRevokeRefreshToken_AlreadyInvalidIsIdempotent(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, `{"error":"invalid_grant","error_description":"Refresh token expired"}`, http.StatusBadRequest)
+	}))
+	defer srv.Close()
+
+	err := RevokeRefreshToken(context.Background(), srv.URL, "ast-app", "expired-rt")
+	if err != nil {
+		t.Errorf("400 response should be treated as already-logged-out (nil error), got: %v", err)
+	}
+}
+
+// TestRevokeRefreshToken_ServerErrorIsSurfaced expects a 500 response to
+// produce an error whose text includes the status code.
+func TestRevokeRefreshToken_ServerErrorIsSurfaced(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "boom", http.StatusInternalServerError)
+	}))
+	defer srv.Close()
+
+	err := RevokeRefreshToken(context.Background(), srv.URL, "ast-app", "the-rt")
+	if err == nil {
+		t.Fatal("expected error on 500 response, got nil")
+	}
+	if !strings.Contains(err.Error(), "500") {
+		t.Errorf("error should mention status code 500: %q", err.Error())
+	}
+}
diff --git a/internal/wrappers/session_global.go b/internal/wrappers/session_global.go new file mode 100644 index 000000000..8b84c6ef8 --- /dev/null +++ 
b/internal/wrappers/session_global.go @@ -0,0 +1,145 @@
+package wrappers
+
+import (
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/checkmarx/ast-cli/internal/params"
+	"github.com/checkmarx/ast-cli/internal/wrappers/configuration"
+	"github.com/pkg/errors"
+	"github.com/spf13/viper"
+)
+
+// sessionGlobalFilePerm restricts the global session file to owner
+// read/write. The file stores a refresh token, so any group or world access
+// would leak a credential.
+const sessionGlobalFilePerm = 0o600
+
+// SessionGlobalFilePath returns the location of the global session file: the
+// directory of the resolved yaml config file (so a custom --config-file-path
+// is honoured) joined with params.SessionGlobalFileName.
+func SessionGlobalFilePath() (string, error) {
+	yamlPath, err := configuration.GetConfigFilePath()
+	if err != nil {
+		return "", errors.Wrap(err, "failed to resolve config file path for global session file")
+	}
+	return filepath.Join(filepath.Dir(yamlPath), params.SessionGlobalFileName), nil
+}
+
+// ReadSessionGlobal returns the refresh token stored by --session global
+// mode, with surrounding whitespace trimmed. A missing file is not an error —
+// it yields ("", nil), meaning nobody has logged in via global mode. Every
+// other read failure is wrapped and returned.
+func ReadSessionGlobal() (string, error) {
+	sessionFile, pathErr := SessionGlobalFilePath()
+	if pathErr != nil {
+		return "", pathErr
+	}
+	raw, readErr := os.ReadFile(sessionFile)
+	switch {
+	case readErr == nil:
+		return strings.TrimSpace(string(raw)), nil
+	case os.IsNotExist(readErr):
+		return "", nil
+	default:
+		return "", errors.Wrap(readErr, "failed to read global session file")
+	}
+}
+
+// WriteSessionGlobal writes the refresh token to the global session file with
+// owner-only permissions. Creates the parent directory if necessary so the
+// first-ever --session global login on a fresh machine works.
+func WriteSessionGlobal(refreshToken string) error {
+	path, err := SessionGlobalFilePath()
+	if err != nil {
+		return err
+	}
+	if mkErr := os.MkdirAll(filepath.Dir(path), 0o700); mkErr != nil {
+		return errors.Wrap(mkErr, "failed to create config directory for global session file")
+	}
+	// os.WriteFile applies its permission argument only when it creates the
+	// file; an existing file is truncated with its current mode left intact.
+	// Tighten a pre-existing file first so the token is never written into a
+	// group- or world-readable file. A missing file is fine: WriteFile will
+	// create it with the right mode below.
+	if chErr := os.Chmod(path, sessionGlobalFilePerm); chErr != nil && !os.IsNotExist(chErr) {
+		return errors.Wrap(chErr, "failed to restrict permissions on global session file")
+	}
+	if writeErr := os.WriteFile(path, []byte(refreshToken), sessionGlobalFilePerm); writeErr != nil {
+		return errors.Wrap(writeErr, "failed to write global session file")
+	}
+	return nil
+}
+
+// ClearSessionGlobal removes the global session file. Returns nil if the file
+// already does not exist (logout is idempotent — running it twice is fine).
+func ClearSessionGlobal() error {
+	path, err := SessionGlobalFilePath()
+	if err != nil {
+		return err
+	}
+	if rmErr := os.Remove(path); rmErr != nil && !os.IsNotExist(rmErr) {
+		return errors.Wrap(rmErr, "failed to remove global session file")
+	}
+	return nil
+}
+
+// LoadActiveCredential makes the refresh token from the active session mode
+// available to viper, so every CLI command's existing cx_apikey lookup
+// resolves to the right credential — without any precedence chain.
+//
+// The active-mode metadata file (~/.checkmarx/active_mode) tells us where
+// the user's current credential lives:
+//
+//   - "yaml": read cx_apikey from the yaml config file; viper.Set it so it
+//     wins over any stale CX_APIKEY env var left over from a
+//     previous --session local invocation
+//   - "local": no action — env-binding (viper.BindEnv) already gives env the
+//     right precedence; we want the user's current shell env to
+//     win
+//   - "global": read the dedicated global file; viper.Set it so it wins over
+//     any stale yaml or env value
+//   - "": no active session — viper's natural precedence applies
+//     (env > yaml). Backward-compatible with users who set
+//     CX_APIKEY directly or who logged in with the previous CLI.
+//
+// Called once at startup from main, after configuration.LoadConfiguration.
+func LoadActiveCredential() {
+	activeMode, modeErr := ReadActiveMode()
+	if modeErr != nil || activeMode == "" {
+		return
+	}
+
+	// Resolve the credential for the active mode; an empty result means
+	// "leave viper alone".
+	var credential string
+	switch activeMode {
+	case params.SessionGlobalValue:
+		stored, readErr := ReadSessionGlobal()
+		if readErr != nil {
+			return
+		}
+		credential = stored
+	case params.SessionYamlValue:
+		// configuration.LoadConfiguration has already loaded yaml's
+		// cx_apikey, but a stale CX_APIKEY in this shell would shadow it
+		// through env-binding. Setting the yaml value explicitly makes
+		// yaml win.
+		credential = ReadYamlAPIKey()
+	default:
+		// Covers params.SessionLocalValue: env binding already gives the
+		// current shell's CX_APIKEY the right precedence (env > config file
+		// in viper), so there is nothing to do. In a shell that never ran
+		// the --session local login, env is empty and the command surfaces
+		// a clear "not authenticated" error.
+		return
+	}
+
+	if credential != "" {
+		viper.Set(params.AstAPIKey, credential)
+	}
+}
+
+// ReadYamlAPIKey reads cx_apikey directly from the yaml config file, bypassing
+// viper's env-first precedence. Used by LoadActiveCredential to force yaml to
+// win when the active mode is "yaml" but a stale CX_APIKEY env var exists, and
+// by the auth login/logout nuke phase to revoke whatever yaml actually holds
+// (not what viper currently resolves to, which could be a stale env var).
+func ReadYamlAPIKey() string {
+	// Any failure to locate or load the config file is reported as "no key"
+	// rather than as an error.
+	configPath, err := configuration.GetConfigFilePath()
+	if err != nil {
+		return ""
+	}
+	yamlConfig, err := configuration.LoadConfig(configPath)
+	if err != nil {
+		return ""
+	}
+	// An absent entry, or one that is not a string, also yields "".
+	if v, ok := yamlConfig[params.AstAPIKey].(string); ok {
+		return v
+	}
+	return ""
+}
diff --git a/internal/wrappers/session_global_test.go b/internal/wrappers/session_global_test.go new file mode 100644 index 000000000..688415c81 --- /dev/null +++ b/internal/wrappers/session_global_test.go @@ -0,0 +1,163 @@
+package wrappers
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/checkmarx/ast-cli/internal/params"
+	"github.com/spf13/viper"
+)
+
+// withTempConfigDir points viper at a temp directory for the duration of one
+// test, so the session_global helpers operate on a sandbox rather than the
+// real user's ~/.checkmarx. Restores prior state via t.Cleanup.
+// It returns the temp directory that now holds the config file path.
+func withTempConfigDir(t *testing.T) string {
+	t.Helper()
+	dir := t.TempDir()
+	prev := viper.GetString(params.ConfigFilePathKey)
+	viper.Set(params.ConfigFilePathKey, filepath.Join(dir, "checkmarxcli.yaml"))
+	t.Cleanup(func() {
+		viper.Set(params.ConfigFilePathKey, prev)
+	})
+	return dir
+}
+
+// TestSessionGlobalFilePath_ReturnsPathInConfigDir checks that the session
+// file path is the config directory joined with params.SessionGlobalFileName.
+func TestSessionGlobalFilePath_ReturnsPathInConfigDir(t *testing.T) {
+	dir := withTempConfigDir(t)
+	got, err := SessionGlobalFilePath()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	want := filepath.Join(dir, params.SessionGlobalFileName)
+	if got != want {
+		t.Errorf("SessionGlobalFilePath() = %q, want %q", got, want)
+	}
+}
+
+// TestReadSessionGlobal_ReturnsEmptyWhenFileMissing checks that reading with
+// no session file on disk yields an empty token and a nil error.
+func TestReadSessionGlobal_ReturnsEmptyWhenFileMissing(t *testing.T) {
+	withTempConfigDir(t)
+	got, err := ReadSessionGlobal()
+	if err != nil {
+		t.Fatalf("expected nil error when file does not exist, got: %v", err)
+	}
+	if got != "" {
+		t.Errorf("expected empty string, got %q", got)
+	}
+}
+
+// TestWriteAndReadSessionGlobal_RoundTrip checks that a written token is read
+// back unchanged.
+func TestWriteAndReadSessionGlobal_RoundTrip(t *testing.T) {
+	withTempConfigDir(t)
+	token := "eyJhbGc.test-refresh-token.xyz"
+	if err := WriteSessionGlobal(token); err != nil {
+		t.Fatalf("WriteSessionGlobal failed: %v", err)
+	}
+	got, err := ReadSessionGlobal()
+	if err != nil {
+		t.Fatalf("ReadSessionGlobal failed: %v", err)
+	}
+	if got != token {
+		t.Errorf("round-trip mismatch: wrote %q, read %q", token, got)
+	}
+}
+
+// TestReadSessionGlobal_TrimsWhitespace checks that a trailing newline in the
+// session file is stripped from the returned token.
+func TestReadSessionGlobal_TrimsWhitespace(t *testing.T) {
+	dir := withTempConfigDir(t)
+	// Write a token with trailing newline directly to disk to simulate a file
+	// edited by hand.
+	path := filepath.Join(dir, params.SessionGlobalFileName)
+	if err := os.WriteFile(path, []byte("the-token\n"), 0o600); err != nil {
+		t.Fatalf("setup write failed: %v", err)
+	}
+	got, err := ReadSessionGlobal()
+	if err != nil {
+		t.Fatalf("ReadSessionGlobal failed: %v", err)
+	}
+	if got != "the-token" {
+		t.Errorf("expected trailing whitespace trimmed, got %q", got)
+	}
+}
+
+// TestClearSessionGlobal_RemovesFile checks that a subsequent read returns an
+// empty token once the session file has been cleared.
+func TestClearSessionGlobal_RemovesFile(t *testing.T) {
+	withTempConfigDir(t)
+	if err := WriteSessionGlobal("some-token"); err != nil {
+		t.Fatalf("setup write failed: %v", err)
+	}
+	if err := ClearSessionGlobal(); err != nil {
+		t.Fatalf("ClearSessionGlobal failed: %v", err)
+	}
+	got, err := ReadSessionGlobal()
+	if err != nil {
+		t.Fatalf("ReadSessionGlobal after clear failed: %v", err)
+	}
+	if got != "" {
+		t.Errorf("expected empty after clear, got %q", got)
+	}
+}
+
+// TestClearSessionGlobal_IdempotentWhenFileMissing checks that clearing with
+// no session file present returns nil.
+func TestClearSessionGlobal_IdempotentWhenFileMissing(t *testing.T) {
+	withTempConfigDir(t)
+	// File never created; clearing it should not error.
+	if err := ClearSessionGlobal(); err != nil {
+		t.Errorf("expected nil error when file does not exist, got: %v", err)
+	}
+}
+
+// TestLoadActiveCredential_GlobalModeLoadsFile checks that, with the active
+// mode set to global, the token from the session file ends up in viper.
+func TestLoadActiveCredential_GlobalModeLoadsFile(t *testing.T) {
+	withTempConfigDir(t)
+	t.Setenv(params.AstAPIKeyEnv, "")
+	if err := WriteSessionGlobal("global-token"); err != nil {
+		t.Fatalf("setup write failed: %v", err)
+	}
+	if err := WriteActiveMode(params.SessionGlobalValue); err != nil {
+		t.Fatalf("WriteActiveMode failed: %v", err)
+	}
+	viper.Set(params.AstAPIKey, "")
+	LoadActiveCredential()
+	if got := viper.GetString(params.AstAPIKey); got != "global-token" {
+		t.Errorf("expected global mode to load token into viper, got %q", got)
+	}
+}
+
+// TestLoadActiveCredential_GlobalOverridesStaleEnv checks that the global
+// session token is what viper returns even when the API key env var is set.
+func TestLoadActiveCredential_GlobalOverridesStaleEnv(t *testing.T) {
+	withTempConfigDir(t)
+	t.Setenv(params.AstAPIKeyEnv, "stale-env-token")
+	if err := WriteSessionGlobal("global-token"); err != nil {
+		t.Fatalf("setup write failed: %v", err)
+	}
+	if err := WriteActiveMode(params.SessionGlobalValue); err != nil {
+		t.Fatalf("WriteActiveMode failed: %v", err)
+	}
+	viper.Set(params.AstAPIKey, "")
+	LoadActiveCredential()
+	if got := viper.GetString(params.AstAPIKey); got != "global-token" {
+		t.Errorf("global mode must win over stale env, got %q", got)
+	}
+}
+
+// TestLoadActiveCredential_LocalModeNoOpLetsEnvWin checks that local mode
+// leaves viper's API key as either empty or the env value.
+// NOTE(review): the viper.Set(params.AstAPIKey, "") below installs an explicit
+// override, which presumably outranks the env binding, so the assertion may
+// only ever observe "" — confirm against viper's precedence rules.
+func TestLoadActiveCredential_LocalModeNoOpLetsEnvWin(t *testing.T) {
+	withTempConfigDir(t)
+	t.Setenv(params.AstAPIKeyEnv, "local-token")
+	if err := WriteActiveMode(params.SessionLocalValue); err != nil {
+		t.Fatalf("WriteActiveMode failed: %v", err)
+	}
+	viper.Set(params.AstAPIKey, "")
+	LoadActiveCredential()
+	// We don't viper.Set for local mode — env binding does the work.
+	// Verify that we didn't overwrite anything.
+	// (We can't easily verify env-binding inside this test without
+	// going through viper, so just confirm no error and no surprise Set.)
+	if got := viper.GetString(params.AstAPIKey); got != "" && got != "local-token" {
+		t.Errorf("local mode should not viper.Set anything; got unexpected %q", got)
+	}
+}
+
+// TestLoadActiveCredential_NoActiveModeIsNoOp checks that viper's API key
+// stays empty when no active-mode file has been written.
+func TestLoadActiveCredential_NoActiveModeIsNoOp(t *testing.T) {
+	withTempConfigDir(t)
+	// No WriteActiveMode call — file is absent.
+	viper.Set(params.AstAPIKey, "")
+	LoadActiveCredential()
+	if got := viper.GetString(params.AstAPIKey); got != "" {
+		t.Errorf("expected viper.AstAPIKey to remain empty when no active mode, got %q", got)
+	}
+}
diff --git a/internal/wrappers/utils/utils.go b/internal/wrappers/utils/utils.go index 793d9ef96..1453a10b4 100644 --- a/internal/wrappers/utils/utils.go +++ b/internal/wrappers/utils/utils.go @@ -17,6 +17,7 @@ var ( var allowedOptionalKeys = map[string]bool{ "asca-location": true, + "aiProvider": true, } // CleanURL returns a cleaned url removing double slashes diff --git a/test/integration/scan_test.go b/test/integration/scan_test.go index 786ec33ee..dd07eba90 100644 --- a/test/integration/scan_test.go +++ b/test/integration/scan_test.go @@ -2164,8 +2164,7 @@ func TestCreateAsyncScan_ChangedCachedTokenAndPollingScanStatus_Success(t *testi } scanID, _ := executeCreateScan(t, args) scanWrapper := wrappers.NewHTTPScansWrapper(viper.GetString(params.ScansPathKey)) - wrappers.CachedAccessToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMiwiaXNzIjoiaHR0cHM6Ly9kZXUuaWFtLmNoZWNrbWFyeC5uZXQvYXV0aC9yZWFsbXMvZ2FsYWN0aWNhIiwiYXN0LWJhc2UtdXJsIjoiaHR0cHM6Ly9kZXUuYXN0LmNoZWNrbWFyeC5uZXQifQ.j0MMhLKBkmvJ_vz5xjvvut5UfN7OJVPqV-RwJ3NdKD4" - wrappers.CachedAccessTime = time.Now() + 
wrappers.SetCachedAccessTokenForTest("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMiwiaXNzIjoiaHR0cHM6Ly9kZXUuaWFtLmNoZWNrbWFyeC5uZXQvYXV0aC9yZWFsbXMvZ2FsYWN0aWNhIiwiYXN0LWJhc2UtdXJsIjoiaHR0cHM6Ly9kZXUuYXN0LmNoZWNrbWFyeC5uZXQifQ.j0MMhLKBkmvJ_vz5xjvvut5UfN7OJVPqV-RwJ3NdKD4") viper.Set(params.TokenExpirySecondsKey, 300) scan, _, err := scanWrapper.GetByID(scanID) asserts.Nil(t, err)