From 1271598fda157f1510a8f013f7a0e7ba3e8fe8e6 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 08:59:49 -0700 Subject: [PATCH 01/29] Captain: fix SetArgs, sort of --- internal/captain/command.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/internal/captain/command.go b/internal/captain/command.go index 16a25bc252..aa69f8601a 100644 --- a/internal/captain/command.go +++ b/internal/captain/command.go @@ -219,9 +219,12 @@ func (c *Command) ShortDescription() string { func (c *Command) Execute(args []string) error { defer profile.Measure("cobra:Execute", time.Now()) c.logArgs(args) - c.cobra.SetArgs(args) - err := c.cobra.Execute() - c.cobra.SetArgs(nil) + // Cobra always executes the root command, so we need to set the args for the root command + // This makes running command.Execute() super error-prone if the args don't match the command + // We should probably get rid of Cobra over issues like this + c.cobra.Root().SetArgs(args) + err := c.cobra.Root().Execute() + c.cobra.Root().SetArgs(nil) rationalizeError(&err) return setupSensibleErrors(err, args) } From e8d1d28116f5349c36ddba4d721c6d53df03e48b Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:00:21 -0700 Subject: [PATCH 02/29] Don't include root command (state) in NameRecursive --- internal/captain/command.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/internal/captain/command.go b/internal/captain/command.go index aa69f8601a..3e8550a798 100644 --- a/internal/captain/command.go +++ b/internal/captain/command.go @@ -275,11 +275,14 @@ func (c *Command) Name() string { } func (c *Command) NameRecursive() string { - child := c + parent := c name := []string{} - for child != nil { - name = append([]string{child.Name()}, name...) - child = child.parent + for parent != nil { + name = append([]string{parent.Name()}, name...) + parent = parent.parent + if parent.parent == nil { + break // Don't include the root command in the name + } } return strings.Join(name, " ") } From 94c8042661a36cd2889b976197735f35300f8300 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:00:34 -0700 Subject: [PATCH 03/29] Added BaseCommand and AllChildren to Captain --- internal/captain/command.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/internal/captain/command.go b/internal/captain/command.go index 3e8550a798..ca3aa5fc19 100644 --- a/internal/captain/command.go +++ b/internal/captain/command.go @@ -24,6 +24,7 @@ import ( configMediator "github.com/ActiveState/cli/internal/mediators/config" "github.com/ActiveState/cli/internal/multilog" "github.com/ActiveState/cli/internal/osutils" + "github.com/ActiveState/cli/internal/osutils/stacktrace" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/profile" "github.com/ActiveState/cli/internal/rollbar" @@ -287,6 +288,14 @@ func (c *Command) NameRecursive() string { return strings.Join(name, " ") } +func (c *Command) BaseCommand() *Command { + base := c + for base.parent != nil && base.parent.parent != nil { + base = base.parent + } + return base +} + func (c *Command) NamePadding() int { return c.cobra.NamePadding() } @@ -463,6 +472,15 @@ func (c *Command) Children() []*Command { return commands } +func (c *Command) AllChildren() []*Command { + commands := []*Command{} + for _, child := range c.Children() { + commands = append(commands, child) + commands = append(commands, child.AllChildren()...) + } + return commands +} + func (c *Command) AvailableChildren() []*Command { commands := []*Command{} for _, child := range c.Children() { From 48850be5cfe11365d6adabd04b034ecebea1d2b4 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:02:32 -0700 Subject: [PATCH 04/29] Remove dependency on HOME and TEMPDIR env vars --- .../analytics/client/sync/reporters/test.go | 6 +- internal/config/instance.go | 9 +- internal/installation/storage/storage.go | 91 +++++----- .../installation/storage/storage_darwin.go | 13 ++ internal/installation/storage/storage_test.go | 7 +- .../installation/storage/storage_windows.go | 22 +++ internal/installation/storage/storage_xdg.go | 25 +++ internal/logging/defaults.go | 8 +- internal/osutils/user/user.go | 15 +- internal/svcctl/svcctl.go | 23 +-- internal/updater/updater.go | 5 +- scripts/ci/parallelize/parallelize.go | 5 +- test/integration/performance_svc_int_test.go | 4 +- vendor/github.com/shibukawa/configdir/LICENSE | 21 --- .../github.com/shibukawa/configdir/README.rst | 111 ------------ .../github.com/shibukawa/configdir/config.go | 160 ------------------ .../shibukawa/configdir/config_darwin.go | 8 - .../shibukawa/configdir/config_windows.go | 8 - .../shibukawa/configdir/config_xdg.go | 34 ---- vendor/modules.txt | 3 - 20 files changed, 134 insertions(+), 444 deletions(-) create mode 100644 internal/installation/storage/storage_darwin.go create mode 100644 internal/installation/storage/storage_windows.go create mode 100644 internal/installation/storage/storage_xdg.go delete mode 100644 vendor/github.com/shibukawa/configdir/LICENSE delete mode 100644 vendor/github.com/shibukawa/configdir/README.rst delete mode 100644 vendor/github.com/shibukawa/configdir/config.go delete mode 100644 vendor/github.com/shibukawa/configdir/config_darwin.go delete mode 100644 vendor/github.com/shibukawa/configdir/config_windows.go delete mode 100644 vendor/github.com/shibukawa/configdir/config_xdg.go diff --git a/internal/analytics/client/sync/reporters/test.go b/internal/analytics/client/sync/reporters/test.go index 95ff0c426a..7fe70846b6 100644 --- a/internal/analytics/client/sync/reporters/test.go +++ b/internal/analytics/client/sync/reporters/test.go @@ -18,10 +18,8 @@ type TestReporter struct { const TestReportFilename = "analytics.log" func TestReportFilepath() string { - appdata, err := storage.AppDataPath() - if err != nil { - logging.Warning("Could not acquire appdata path, using cwd instead. Error received: %s", errs.JoinMessage(err)) - } + appdata := storage.AppDataPath() + logging.Warning("Appdata path: %s", appdata) return filepath.Join(appdata, TestReportFilename) } diff --git a/internal/config/instance.go b/internal/config/instance.go index fd6d68fe48..d1c1a47e24 100644 --- a/internal/config/instance.go +++ b/internal/config/instance.go @@ -41,14 +41,10 @@ func NewCustom(localPath string, thread *singlethread.Thread, closeThread bool) i.thread = thread i.closeThread = closeThread - var err error if localPath != "" { - i.appDataDir, err = storage.AppDataPathWithParent(localPath) + i.appDataDir = storage.AppDataPathWithParent(localPath) } else { - i.appDataDir, err = storage.AppDataPath() - } - if err != nil { - return nil, errs.Wrap(err, "Could not detect appdata dir") + i.appDataDir = storage.AppDataPath() } // Ensure appdata dir exists, because the sqlite driver sure doesn't @@ -61,6 +57,7 @@ func NewCustom(localPath string, thread *singlethread.Thread, closeThread bool) path := filepath.Join(i.appDataDir, C.InternalConfigFileName) + var err error t := time.Now() i.db, err = sql.Open("sqlite", path) if err != nil { diff --git a/internal/installation/storage/storage.go b/internal/installation/storage/storage.go index a8659cbed1..f511f8d7c9 100644 --- a/internal/installation/storage/storage.go +++ b/internal/installation/storage/storage.go @@ -11,12 +11,29 @@ import ( "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/osutils/user" "github.com/google/uuid" - "github.com/shibukawa/configdir" ) -func AppDataPath() (string, error) { - configDirs := configdir.New(constants.InternalConfigNamespace, fmt.Sprintf("%s-%s", constants.LibraryName, constants.ChannelName)) +var homeDir string + +func init() { + var err error + homeDir, err = user.HomeDir() + if err != nil { + panic(fmt.Sprintf("Could not get home dir, you can fix this by ensuring the $HOME environment variable is set. Error: %v", err)) + } +} + + +func relativeAppDataPath() string { + return filepath.Join(constants.InternalConfigNamespace, fmt.Sprintf("%s-%s", constants.LibraryName, constants.ChannelName)) +} + +func relativeCachePath() string { + return constants.InternalConfigNamespace +} + +func AppDataPath() string { localPath, envSet := os.LookupEnv(constants.ConfigEnvVarName) if envSet { return AppDataPathWithParent(localPath) @@ -27,35 +44,10 @@ func AppDataPath() (string, error) { // panic as this only happening in tests panic(err) } - return localPath, nil + return localPath } - // Account for HOME dir not being set, meaning querying global folders will fail - // This is a workaround for docker envs that don't usually have $HOME set - _, envSet = os.LookupEnv("HOME") - if !envSet && runtime.GOOS != "windows" { - homeDir, err := user.HomeDir() - if err != nil { - if !condition.InUnitTest() { - return "", fmt.Errorf("Could not get user home directory: %w", err) - } - // Use temp dir if we're in a test (we don't want to write to our src directory) - var err error - localPath, err = os.MkdirTemp("", "cli-config-test") - if err != nil { - return "", fmt.Errorf("could not create temp dir: %w", err) - } - return AppDataPathWithParent(localPath) - } - os.Setenv("HOME", homeDir) - } - - dir := configDirs.QueryFolders(configdir.Global)[0].Path - if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return "", fmt.Errorf("could not create appdata dir: %s", dir) - } - - return dir, nil + return AppDataPathWithParent(BaseAppDataPath()) } var _appDataPathInTest string @@ -67,11 +59,11 @@ func appDataPathInTest() (string, error) { localPath, err := os.MkdirTemp("", "cli-config") if err != nil { - return "", fmt.Errorf("Could not create temp dir: %w", err) + return "", fmt.Errorf("could not create temp dir: %w", err) } err = os.RemoveAll(localPath) if err != nil { - return "", fmt.Errorf("Could not remove generated config dir for tests: %w", err) + return "", fmt.Errorf("could not remove generated config dir for tests: %w", err) } _appDataPathInTest = localPath @@ -79,16 +71,15 @@ func appDataPathInTest() (string, error) { return localPath, nil } -func AppDataPathWithParent(parentDir string) (string, error) { - configDirs := configdir.New(constants.InternalConfigNamespace, fmt.Sprintf("%s-%s", constants.LibraryName, constants.ChannelName)) - configDirs.LocalPath = parentDir - dir := configDirs.QueryFolders(configdir.Local)[0].Path - +func AppDataPathWithParent(parentDir string) string { + dir := filepath.Join(parentDir, relativeAppDataPath()) if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return "", fmt.Errorf("could not create appdata dir: %s", dir) + // Can't use logging here because it would cause a circular dependency + // This would only happen if the user has corrupt permissions on their home dir + os.Stderr.WriteString(fmt.Sprintf("Could not create appdata dir: %s", dir)) } - return dir, nil + return dir } // CachePath returns the path at which our cache is stored @@ -108,17 +99,15 @@ func CachePath() string { cachePath = filepath.Join(drive, "temp", prefix+uuid.New().String()[0:8]) } } - } else if path := os.Getenv(constants.CacheEnvVarName); path != "" { - cachePath = path - } else { - cachePath = configdir.New(constants.InternalConfigNamespace, "").QueryCacheFolder().Path - if runtime.GOOS == "windows" { - // Explicitly append "cache" dir as the cachedir on Windows is the same as the local appdata dir (conflicts with config) - cachePath = filepath.Join(cachePath, "cache") - } + return cachePath + + } + + if path := os.Getenv(constants.CacheEnvVarName); path != "" { + return path } - return cachePath + return filepath.Join(BaseCachePath(), relativeCachePath()) } func GlobalBinDir() string { @@ -127,11 +116,7 @@ func GlobalBinDir() string { // InstallSource returns the installation source of the State Tool func InstallSource() (string, error) { - path, err := AppDataPath() - if err != nil { - return "", fmt.Errorf("Could not detect AppDataPath: %w", err) - } - + path := AppDataPath() installFilePath := filepath.Join(path, constants.InstallSourceFile) installFileData, err := os.ReadFile(installFilePath) if err != nil { diff --git a/internal/installation/storage/storage_darwin.go b/internal/installation/storage/storage_darwin.go new file mode 100644 index 0000000000..2dde6e1630 --- /dev/null +++ b/internal/installation/storage/storage_darwin.go @@ -0,0 +1,13 @@ +package storage + +import ( + "path/filepath" +) + +func BaseAppDataPath() string { + return filepath.Join(homeDir, "Library", "Application Support") +} + +func BaseCachePath() string { + return filepath.Join(homeDir, "Library", "Caches") +} diff --git a/internal/installation/storage/storage_test.go b/internal/installation/storage/storage_test.go index d0887b59cf..6ecef193ad 100644 --- a/internal/installation/storage/storage_test.go +++ b/internal/installation/storage/storage_test.go @@ -4,13 +4,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func Test_AppDataPath(t *testing.T) { - path1, err := AppDataPath() - require.NoError(t, err) - path2, err := AppDataPath() - require.NoError(t, err) + path1 := AppDataPath() + path2 := AppDataPath() assert.Equal(t, path1, path2) } diff --git a/internal/installation/storage/storage_windows.go b/internal/installation/storage/storage_windows.go new file mode 100644 index 0000000000..4a4986440a --- /dev/null +++ b/internal/installation/storage/storage_windows.go @@ -0,0 +1,22 @@ +package storage + +import ( + "os" + "path/filepath" +) + +func BaseAppDataPath() string { + if appData := os.Getenv("APPDATA"); appData != "" { + return appData + } + + return filepath.Join(homeDir, "AppData", "Roaming") +} + +func BaseCachePath() string { + if cache := os.Getenv("LOCALAPPDATA"); cache != "" { + return cache + } + + return filepath.Join(homeDir, "AppData", "Local", "cache") +} diff --git a/internal/installation/storage/storage_xdg.go b/internal/installation/storage/storage_xdg.go new file mode 100644 index 0000000000..1efefd0649 --- /dev/null +++ b/internal/installation/storage/storage_xdg.go @@ -0,0 +1,25 @@ +//go:build !windows && !darwin +// +build !windows,!darwin + +package storage + +import ( + "os" + "path/filepath" +) + +func BaseAppDataPath() string { + if os.Getenv("XDG_CONFIG_HOME") != "" { + return os.Getenv("XDG_CONFIG_HOME") + } + + return filepath.Join(homeDir, ".config") +} + +func BaseCachePath() string { + if os.Getenv("XDG_CACHE_HOME") != "" { + return os.Getenv("XDG_CACHE_HOME") + } + + return filepath.Join(homeDir, ".cache") +} diff --git a/internal/logging/defaults.go b/internal/logging/defaults.go index 68a5c897bf..395e41c446 100644 --- a/internal/logging/defaults.go +++ b/internal/logging/defaults.go @@ -83,13 +83,7 @@ func init() { defer func() { handlePanics(recover()) }() // Set up datadir - var err error - datadir, err = storage.AppDataPath() - if err != nil { - log.SetOutput(os.Stderr) - Error("Could not detect AppData dir: %v", err) - return - } + datadir = storage.AppDataPath() // Set up handler timestamp = time.Now().UnixNano() diff --git a/internal/osutils/user/user.go b/internal/osutils/user/user.go index 1b485d0fab..e8238c3063 100644 --- a/internal/osutils/user/user.go +++ b/internal/osutils/user/user.go @@ -2,8 +2,10 @@ package user import ( "os" + "os/user" "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/errs" ) // HomeDirNotFoundError is an error that implements the ErrorLocalier and ErrorInput interfaces @@ -40,9 +42,16 @@ func HomeDir() (string, error) { if dir := os.Getenv(constants.HomeEnvVarName); dir != "" { return dir, nil } - dir, err := os.UserHomeDir() - if err != nil { - return "", &HomeDirNotFoundError{err} + + u, err := user.Current() + if err == nil { + return u.HomeDir, nil + } + + // If we can't get the current user, try to get the home dir from the os + dir, err2 := os.UserHomeDir() + if err2 != nil { + return "", &HomeDirNotFoundError{errs.Pack(err, err2)} } return dir, nil } diff --git a/internal/svcctl/svcctl.go b/internal/svcctl/svcctl.go index 58933826f9..0887749d85 100644 --- a/internal/svcctl/svcctl.go +++ b/internal/svcctl/svcctl.go @@ -7,21 +7,19 @@ package svcctl import ( "context" "errors" - "fmt" "io" "os" - "path/filepath" "time" "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/fileutils" "github.com/ActiveState/cli/internal/installation" + "github.com/ActiveState/cli/internal/installation/storage" "github.com/ActiveState/cli/internal/ipc" "github.com/ActiveState/cli/internal/locale" "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/osutils" - "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/profile" ) @@ -45,25 +43,30 @@ type IPCommunicator interface { SockPath() *ipc.SockPath } +type Outputer interface { + Notice(interface{}) +} + func NewIPCSockPathFromGlobals() *ipc.SockPath { - subdir := fmt.Sprintf("%s-%s", constants.CommandName, "ipc") - rootDir := filepath.Join(os.TempDir(), subdir) + rootDir := storage.AppDataPath() if os.Getenv(constants.ServiceSockDir) != "" { rootDir = os.Getenv(constants.ServiceSockDir) } - return &ipc.SockPath{ + sp := &ipc.SockPath{ RootDir: rootDir, AppName: constants.CommandName, AppChannel: constants.ChannelName, } + + return sp } func NewDefaultIPCClient() *ipc.Client { return ipc.NewClient(NewIPCSockPathFromGlobals()) } -func EnsureExecStartedAndLocateHTTP(ipComm IPCommunicator, exec, argText string, out output.Outputer) (addr string, err error) { +func EnsureExecStartedAndLocateHTTP(ipComm IPCommunicator, exec, argText string, out Outputer) (addr string, err error) { defer profile.Measure("svcctl:EnsureExecStartedAndLocateHTTP", time.Now()) addr, err = LocateHTTP(ipComm) @@ -91,7 +94,7 @@ func EnsureExecStartedAndLocateHTTP(ipComm IPCommunicator, exec, argText string, return addr, nil } -func EnsureStartedAndLocateHTTP(argText string, out output.Outputer) (addr string, err error) { +func EnsureStartedAndLocateHTTP(argText string, out Outputer) (addr string, err error) { svcExec, err := installation.ServiceExec() if err != nil { return "", locale.WrapError(err, "err_service_exec") @@ -146,7 +149,7 @@ func StopServer(ipComm IPCommunicator) error { return nil } -func startAndWait(ctx context.Context, ipComm IPCommunicator, exec, argText string, out output.Outputer) error { +func startAndWait(ctx context.Context, ipComm IPCommunicator, exec, argText string, out Outputer) error { defer profile.Measure("svcmanager:Start", time.Now()) if !fileutils.FileExists(exec) { @@ -174,7 +177,7 @@ var ( waitTimeoutL10nKey = "svcctl_wait_timeout" ) -func waitUp(ctx context.Context, ipComm IPCommunicator, out output.Outputer, debugInfo *debugData) error { +func waitUp(ctx context.Context, ipComm IPCommunicator, out Outputer, debugInfo *debugData) error { debugInfo.startWait() defer debugInfo.stopWait() diff --git a/internal/updater/updater.go b/internal/updater/updater.go index bd2a661cb0..3560ae1557 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -207,10 +207,7 @@ func (u *UpdateInstaller) InstallBlocking(installTargetPath string, args ...stri return errs.Wrap(err, "Could not check if State Tool was installed as admin") } - appdata, err := storage.AppDataPath() - if err != nil { - return errs.Wrap(err, "Could not detect appdata path") - } + appdata := storage.AppDataPath() // Protect against multiple updates happening simultaneously lockFile := filepath.Join(appdata, "install.lock") diff --git a/scripts/ci/parallelize/parallelize.go b/scripts/ci/parallelize/parallelize.go index f963523645..fe135e2b63 100644 --- a/scripts/ci/parallelize/parallelize.go +++ b/scripts/ci/parallelize/parallelize.go @@ -68,10 +68,7 @@ func run() error { } func jobDir() string { - path, err := storage.AppDataPath() - if err != nil { - panic(err) - } + path := storage.AppDataPath() path = filepath.Join(path, "jobs") if err := fileutils.MkdirUnlessExists(path); err != nil { diff --git a/test/integration/performance_svc_int_test.go b/test/integration/performance_svc_int_test.go index 5d05fa7615..46f0ecf20d 100644 --- a/test/integration/performance_svc_int_test.go +++ b/test/integration/performance_svc_int_test.go @@ -41,9 +41,7 @@ func (suite *PerformanceIntegrationTestSuite) TestSvcPerformance() { // This integration test is a bit special because it bypasses the spawning logic // so in order to get the right log files when debugging we manually provide the config dir - var err error - ts.Dirs.Config, err = storage.AppDataPath() - suite.Require().NoError(err) + ts.Dirs.Config = storage.AppDataPath() ipcClient := svcctl.NewDefaultIPCClient() var svcPort string diff --git a/vendor/github.com/shibukawa/configdir/LICENSE b/vendor/github.com/shibukawa/configdir/LICENSE deleted file mode 100644 index b20af456a1..0000000000 --- a/vendor/github.com/shibukawa/configdir/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 shibukawa - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/shibukawa/configdir/README.rst b/vendor/github.com/shibukawa/configdir/README.rst deleted file mode 100644 index 99906697da..0000000000 --- a/vendor/github.com/shibukawa/configdir/README.rst +++ /dev/null @@ -1,111 +0,0 @@ -configdir for Golang -===================== - -Multi platform library of configuration directory for Golang. - -This library helps to get regular directories for configuration files or cache files that matches target operationg system's convention. - -It assumes the following folders are standard paths of each environment: - -.. list-table:: - :header-rows: 1 - - - * - * Windows: - * Linux/BSDs: - * MacOSX: - - * System level configuration folder - * ``%PROGRAMDATA%`` (``C:\\ProgramData``) - * ``${XDG_CONFIG_DIRS}`` (``/etc/xdg``) - * ``/Library/Application Support`` - - * User level configuration folder - * ``%APPDATA%`` (``C:\\Users\\\\AppData\\Roaming``) - * ``${XDG_CONFIG_HOME}`` (``${HOME}/.config``) - * ``${HOME}/Library/Application Support`` - - * User wide cache folder - * ``%LOCALAPPDATA%`` ``(C:\\Users\\\\AppData\\Local)`` - * ``${XDG_CACHE_HOME}`` (``${HOME}/.cache``) - * ``${HOME}/Library/Caches`` - -Examples ------------- - -Getting Configuration -~~~~~~~~~~~~~~~~~~~~~~~~ - -``configdir.ConfigDir.QueryFolderContainsFile()`` searches files in the following order: - -* Local path (if you add the path via LocalPath parameter) -* User level configuration folder(e.g. ``$HOME/.config///setting.json`` in Linux) -* System level configuration folder(e.g. ``/etc/xdg///setting.json`` in Linux) - -``configdir.Config`` provides some convenient methods(``ReadFile``, ``WriteFile`` and so on). - -.. code-block:: go - - var config Config - - configDirs := configdir.New("vendor-name", "application-name") - // optional: local path has the highest priority - configDirs.LocalPath, _ = filepath.Abs(".") - folder := configDirs.QueryFolderContainsFile("setting.json") - if folder != nil { - data, _ := folder.ReadFile("setting.json") - json.Unmarshal(data, &config) - } else { - config = DefaultConfig - } - -Write Configuration -~~~~~~~~~~~~~~~~~~~~~~ - -When storing configuration, get configuration folder by using ``configdir.ConfigDir.QueryFolders()`` method. - -.. code-block:: go - - configDirs := configdir.New("vendor-name", "application-name") - - var config Config - data, _ := json.Marshal(&config) - - // Stores to local folder - folders := configDirs.QueryFolders(configdir.Local) - folders[0].WriteFile("setting.json", data) - - // Stores to user folder - folders = configDirs.QueryFolders(configdir.Global) - folders[0].WriteFile("setting.json", data) - - // Stores to system folder - folders = configDirs.QueryFolders(configdir.System) - folders[0].WriteFile("setting.json", data) - -Getting Cache Folder -~~~~~~~~~~~~~~~~~~~~~~ - -It is similar to the above example, but returns cache folder. - -.. code-block:: go - - configDirs := configdir.New("vendor-name", "application-name") - cache := configDirs.QueryCacheFolder() - - resp, err := http.Get("http://examples.com/sdk.zip") - if err != nil { - log.Fatal(err) - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - - cache.WriteFile("sdk.zip", body) - -Document ------------- - -https://godoc.org/github.com/shibukawa/configdir - -License ------------- - -MIT - diff --git a/vendor/github.com/shibukawa/configdir/config.go b/vendor/github.com/shibukawa/configdir/config.go deleted file mode 100644 index 8a20e54b59..0000000000 --- a/vendor/github.com/shibukawa/configdir/config.go +++ /dev/null @@ -1,160 +0,0 @@ -// configdir provides access to configuration folder in each platforms. -// -// System wide configuration folders: -// -// - Windows: %PROGRAMDATA% (C:\ProgramData) -// - Linux/BSDs: ${XDG_CONFIG_DIRS} (/etc/xdg) -// - MacOSX: "/Library/Application Support" -// -// User wide configuration folders: -// -// - Windows: %APPDATA% (C:\Users\\AppData\Roaming) -// - Linux/BSDs: ${XDG_CONFIG_HOME} (${HOME}/.config) -// - MacOSX: "${HOME}/Library/Application Support" -// -// User wide cache folders: -// -// - Windows: %LOCALAPPDATA% (C:\Users\\AppData\Local) -// - Linux/BSDs: ${XDG_CACHE_HOME} (${HOME}/.cache) -// - MacOSX: "${HOME}/Library/Caches" -// -// configdir returns paths inside the above folders. - -package configdir - -import ( - "io/ioutil" - "os" - "path/filepath" -) - -type ConfigType int - -const ( - System ConfigType = iota - Global - All - Existing - Local - Cache -) - -// Config represents each folder -type Config struct { - Path string - Type ConfigType -} - -func (c Config) Open(fileName string) (*os.File, error) { - return os.Open(filepath.Join(c.Path, fileName)) -} - -func (c Config) Create(fileName string) (*os.File, error) { - err := c.CreateParentDir(fileName) - if err != nil { - return nil, err - } - return os.Create(filepath.Join(c.Path, fileName)) -} - -func (c Config) ReadFile(fileName string) ([]byte, error) { - return ioutil.ReadFile(filepath.Join(c.Path, fileName)) -} - -// CreateParentDir creates the parent directory of fileName inside c. fileName -// is a relative path inside c, containing zero or more path separators. -func (c Config) CreateParentDir(fileName string) error { - return os.MkdirAll(filepath.Dir(filepath.Join(c.Path, fileName)), 0755) -} - -func (c Config) WriteFile(fileName string, data []byte) error { - err := c.CreateParentDir(fileName) - if err != nil { - return err - } - return ioutil.WriteFile(filepath.Join(c.Path, fileName), data, 0644) -} - -func (c Config) MkdirAll() error { - return os.MkdirAll(c.Path, 0755) -} - -func (c Config) Exists(fileName string) bool { - _, err := os.Stat(filepath.Join(c.Path, fileName)) - return !os.IsNotExist(err) -} - -// ConfigDir keeps setting for querying folders. -type ConfigDir struct { - VendorName string - ApplicationName string - LocalPath string -} - -func New(vendorName, applicationName string) ConfigDir { - return ConfigDir{ - VendorName: vendorName, - ApplicationName: applicationName, - } -} - -func (c ConfigDir) joinPath(root string) string { - if c.VendorName != "" && hasVendorName { - return filepath.Join(root, c.VendorName, c.ApplicationName) - } - return filepath.Join(root, c.ApplicationName) -} - -func (c ConfigDir) QueryFolders(configType ConfigType) []*Config { - if configType == Cache { - return []*Config{c.QueryCacheFolder()} - } - var result []*Config - if c.LocalPath != "" && configType != System && configType != Global { - result = append(result, &Config{ - Path: c.LocalPath, - Type: Local, - }) - } - if configType != System && configType != Local { - result = append(result, &Config{ - Path: c.joinPath(globalSettingFolder), - Type: Global, - }) - } - if configType != Global && configType != Local { - for _, root := range systemSettingFolders { - result = append(result, &Config{ - Path: c.joinPath(root), - Type: System, - }) - } - } - if configType != Existing { - return result - } - var existing []*Config - for _, entry := range result { - if _, err := os.Stat(entry.Path); !os.IsNotExist(err) { - existing = append(existing, entry) - } - } - return existing -} - -func (c ConfigDir) QueryFolderContainsFile(fileName string) *Config { - configs := c.QueryFolders(Existing) - for _, config := range configs { - if _, err := os.Stat(filepath.Join(config.Path, fileName)); !os.IsNotExist(err) { - return config - } - } - return nil -} - -func (c ConfigDir) QueryCacheFolder() *Config { - return &Config{ - Path: c.joinPath(cacheFolder), - Type: Cache, - } -} diff --git a/vendor/github.com/shibukawa/configdir/config_darwin.go b/vendor/github.com/shibukawa/configdir/config_darwin.go deleted file mode 100644 index d668507a7e..0000000000 --- a/vendor/github.com/shibukawa/configdir/config_darwin.go +++ /dev/null @@ -1,8 +0,0 @@ -package configdir - -import "os" - -var hasVendorName = true -var systemSettingFolders = []string{"/Library/Application Support"} -var globalSettingFolder = os.Getenv("HOME") + "/Library/Application Support" -var cacheFolder = os.Getenv("HOME") + "/Library/Caches" diff --git a/vendor/github.com/shibukawa/configdir/config_windows.go b/vendor/github.com/shibukawa/configdir/config_windows.go deleted file mode 100644 index 0984821778..0000000000 --- a/vendor/github.com/shibukawa/configdir/config_windows.go +++ /dev/null @@ -1,8 +0,0 @@ -package configdir - -import "os" - -var hasVendorName = true -var systemSettingFolders = []string{os.Getenv("PROGRAMDATA")} -var globalSettingFolder = os.Getenv("APPDATA") -var cacheFolder = os.Getenv("LOCALAPPDATA") diff --git a/vendor/github.com/shibukawa/configdir/config_xdg.go b/vendor/github.com/shibukawa/configdir/config_xdg.go deleted file mode 100644 index 026ca68a0b..0000000000 --- a/vendor/github.com/shibukawa/configdir/config_xdg.go +++ /dev/null @@ -1,34 +0,0 @@ -// +build !windows,!darwin - -package configdir - -import ( - "os" - "path/filepath" - "strings" -) - -// https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html - -var hasVendorName = true -var systemSettingFolders []string -var globalSettingFolder string -var cacheFolder string - -func init() { - if os.Getenv("XDG_CONFIG_HOME") != "" { - globalSettingFolder = os.Getenv("XDG_CONFIG_HOME") - } else { - globalSettingFolder = filepath.Join(os.Getenv("HOME"), ".config") - } - if os.Getenv("XDG_CONFIG_DIRS") != "" { - systemSettingFolders = strings.Split(os.Getenv("XDG_CONFIG_DIRS"), ":") - } else { - systemSettingFolders = []string{"/etc/xdg"} - } - if os.Getenv("XDG_CACHE_HOME") != "" { - cacheFolder = os.Getenv("XDG_CACHE_HOME") - } else { - cacheFolder = filepath.Join(os.Getenv("HOME"), ".cache") - } -} diff --git a/vendor/modules.txt b/vendor/modules.txt index 0d38daabe6..d2c03b10e9 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -578,9 +578,6 @@ github.com/rollbar/rollbar-go # github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 ## explicit; go 1.13 github.com/sergi/go-diff/diffmatchpatch -# github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 -## explicit -github.com/shibukawa/configdir # github.com/shirou/gopsutil/v3 v3.24.5 ## explicit; go 1.18 github.com/shirou/gopsutil/v3/common From 7114a625f5d33fa4a74fcaf0f28929e10dc4c3dc Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:03:01 -0700 Subject: [PATCH 05/29] Support debugging via go-build --- internal/installation/appinfo.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/installation/appinfo.go b/internal/installation/appinfo.go index 2a18a42e93..0099891ad4 100644 --- a/internal/installation/appinfo.go +++ b/internal/installation/appinfo.go @@ -46,7 +46,9 @@ func newExecFromDir(baseDir string, exec executableType) (string, error) { // Work around dlv and goland debugger giving an unexpected executable path if !condition.BuiltViaCI() && len(os.Args) > 1 && - (strings.Contains(os.Args[0], "__debug_bin") || strings.Contains(filepath.ToSlash(os.Args[0]), "GoLand/___")) { + (strings.Contains(os.Args[0], "__debug_bin") || + strings.Contains(filepath.ToSlash(os.Args[0]), "GoLand/___") || + strings.Contains(os.Args[0], "go-build")) { rootPath := filepath.Clean(environment.GetRootPathUnsafe()) path = filepath.Join(rootPath, "build") } From 0d405d88176f27f68e91a8df20fbf5ac7562f384 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:03:18 -0700 Subject: [PATCH 06/29] Fix nil panic --- internal/runners/cve/cve.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/runners/cve/cve.go b/internal/runners/cve/cve.go index 96c24a24f4..dac49d9602 100644 --- a/internal/runners/cve/cve.go +++ b/internal/runners/cve/cve.go @@ -9,6 +9,7 @@ import ( "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/locale" + "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/output/renderers" "github.com/ActiveState/cli/internal/primer" @@ -60,6 +61,12 @@ type cveOutput struct { } func (r *Cve) Run(params *Params) error { + defer func() { + if rc := recover(); rc != nil { + logging.Error("Recovered from panic: %v", rc) + fmt.Printf("Recovered from panic: %v\n", rc) + } + }() if !params.Namespace.IsValid() && r.proj == nil { return rationalize.ErrNoProject } @@ -71,7 +78,7 @@ func (r *Cve) Run(params *Params) error { ) } - vulnerabilities, err := r.fetchVulnerabilities(*params.Namespace) + vulnerabilities, err := r.fetchVulnerabilities(params.Namespace) if err != nil { var errProjectNotFound *model.ErrProjectNotFound if errors.As(err, &errProjectNotFound) { @@ -101,7 +108,7 @@ func (r *Cve) Run(params *Params) error { return nil } -func (r *Cve) fetchVulnerabilities(namespaceOverride project.Namespaced) (*medmodel.CommitVulnerabilities, error) { +func (r *Cve) fetchVulnerabilities(namespaceOverride *project.Namespaced) (*medmodel.CommitVulnerabilities, error) { if namespaceOverride.IsValid() && namespaceOverride.CommitID == nil { resp, err := model.FetchProjectVulnerabilities(r.auth, namespaceOverride.Owner, namespaceOverride.Project) if err != nil { From d614458ab5be80a07602d3c46ed086e17d1ff6c3 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:03:33 -0700 Subject: [PATCH 07/29] Fix simple output oddly formatted --- internal/runners/cve/cve.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/runners/cve/cve.go b/internal/runners/cve/cve.go index dac49d9602..a46da01c26 100644 --- a/internal/runners/cve/cve.go +++ b/internal/runners/cve/cve.go @@ -142,9 +142,6 @@ type SeverityCountOutput struct { } func (rd *cveOutput) MarshalOutput(format output.Format) interface{} { - if format != output.PlainFormatName { - return rd.data - } ri := &CveInfo{ fmt.Sprintf("[ACTIONABLE]%s[/RESET]", rd.data.Project), rd.data.CommitID, From e35097bf2dbb351cdd4ff30f2a88643fefb5fd8c Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:03:46 -0700 Subject: [PATCH 08/29] Drop useless debug entry --- pkg/platform/model/svc.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/platform/model/svc.go b/pkg/platform/model/svc.go index d6e4af9ae3..43e7bcebc3 100644 --- a/pkg/platform/model/svc.go +++ b/pkg/platform/model/svc.go @@ -181,7 +181,6 @@ func (m *SvcModel) GetProcessesInUse(ctx context.Context, execDir string) ([]*gr // Note we respond with mono_models.JWT here for compatibility and to minimize the changeset at time of implementation. // We can revisit this in the future. func (m *SvcModel) GetJWT(ctx context.Context) (*mono_models.JWT, error) { - logging.Debug("Checking for GetJWT") defer profile.Measure("svc:GetJWT", time.Now()) r := request.NewJWTRequest() From 3a78c3c3bbdb666e578b04f1a42d038c1b6be248 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:03:58 -0700 Subject: [PATCH 09/29] Work around vscode syntax parsing issue --- pkg/projectfile/projectfile.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/projectfile/projectfile.go b/pkg/projectfile/projectfile.go index e3ebcc5c62..b4b105b29b 100644 --- a/pkg/projectfile/projectfile.go +++ b/pkg/projectfile/projectfile.go @@ -1129,7 +1129,7 @@ func AddLockInfo(projectFilePath, branch, version string) error { projectRegex := regexp.MustCompile(fmt.Sprintf("(?m:(^project:\\s*%s))", ProjectURLRe)) lockString := fmt.Sprintf("%s@%s", branch, version) - lockUpdate := []byte(fmt.Sprintf("${1}\nlock: %s", lockString)) + lockUpdate := []byte(fmt.Sprintf(`${1}\nlock: %s`, lockString)) data, err = os.ReadFile(projectFilePath) if err != nil { From ed26dc9b0510deb72884834c3b61dcc35958d217 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:04:14 -0700 Subject: [PATCH 10/29] Add state-mcp --- .github/workflows/build.yml | 16 +- activestate.yaml | 14 + cmd/state-installer/cmd.go | 6 +- cmd/state-mcp/lookupcve.go | 38 + cmd/state-mcp/lookupcve_test.go | 43 + cmd/state-mcp/main.go | 449 +++++++++ cmd/state-mcp/server_test.go | 23 + cmd/state/donotshipme/donotshipme.go | 17 + go.mod | 3 +- go.sum | 6 +- vendor/github.com/mark3labs/mcp-go/LICENSE | 21 + .../mark3labs/mcp-go/client/client.go | 84 ++ .../github.com/mark3labs/mcp-go/client/sse.go | 588 ++++++++++++ .../mark3labs/mcp-go/client/stdio.go | 457 ++++++++++ .../mark3labs/mcp-go/client/types.go | 8 + .../mark3labs/mcp-go/mcp/prompts.go | 163 ++++ .../mark3labs/mcp-go/mcp/resources.go | 105 +++ .../github.com/mark3labs/mcp-go/mcp/tools.go | 466 ++++++++++ .../github.com/mark3labs/mcp-go/mcp/types.go | 860 ++++++++++++++++++ .../github.com/mark3labs/mcp-go/mcp/utils.go | 596 ++++++++++++ .../mark3labs/mcp-go/server/hooks.go | 461 ++++++++++ .../mcp-go/server/request_handler.go | 279 ++++++ .../mark3labs/mcp-go/server/server.go | 768 ++++++++++++++++ .../github.com/mark3labs/mcp-go/server/sse.go | 433 +++++++++ .../mark3labs/mcp-go/server/stdio.go | 283 ++++++ .../yosida95/uritemplate/v3/LICENSE | 25 + .../yosida95/uritemplate/v3/README.rst | 46 + .../yosida95/uritemplate/v3/compile.go | 224 +++++ .../yosida95/uritemplate/v3/equals.go | 53 ++ .../yosida95/uritemplate/v3/error.go | 16 + .../yosida95/uritemplate/v3/escape.go | 190 ++++ .../yosida95/uritemplate/v3/expression.go | 173 ++++ .../yosida95/uritemplate/v3/machine.go | 23 + .../yosida95/uritemplate/v3/match.go | 213 +++++ .../yosida95/uritemplate/v3/parse.go | 277 ++++++ .../yosida95/uritemplate/v3/prog.go | 130 +++ .../yosida95/uritemplate/v3/uritemplate.go | 116 +++ .../yosida95/uritemplate/v3/value.go | 216 +++++ vendor/modules.txt | 8 + 39 files changed, 7886 insertions(+), 11 deletions(-) create mode 100644 cmd/state-mcp/lookupcve.go create mode 100644 cmd/state-mcp/lookupcve_test.go create mode 100644 cmd/state-mcp/main.go create mode 100644 cmd/state-mcp/server_test.go create mode 100644 cmd/state/donotshipme/donotshipme.go create mode 100644 vendor/github.com/mark3labs/mcp-go/LICENSE create mode 100644 vendor/github.com/mark3labs/mcp-go/client/client.go create mode 100644 vendor/github.com/mark3labs/mcp-go/client/sse.go create mode 100644 vendor/github.com/mark3labs/mcp-go/client/stdio.go create mode 100644 vendor/github.com/mark3labs/mcp-go/client/types.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/prompts.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/resources.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/tools.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/types.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/utils.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/hooks.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/request_handler.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/server.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/sse.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/stdio.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/LICENSE create mode 100644 vendor/github.com/yosida95/uritemplate/v3/README.rst create mode 100644 vendor/github.com/yosida95/uritemplate/v3/compile.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/equals.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/error.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/escape.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/expression.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/machine.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/match.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/parse.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/prog.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/uritemplate.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/value.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 85e57314f8..142505f11f 100755 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -171,6 +171,10 @@ jobs: "ID": "Build-Executor", "Args": ["state", "run", "build-exec"] } + { + "ID": "Build-MCP", + "Args": ["state", "run", "build-mcp"] + } ] EOF )" @@ -220,10 +224,15 @@ jobs: shell: bash run: parallelize results Build-Install-Scripts - - # === "Build: Executor" === - name: "Build: Executor" + - # === "Build: Executor" === + name: "Build: Executor" + shell: bash + run: parallelize results Build-Executor + + - # === "Build: MCP" === + name: "Build: MCP" shell: bash - run: parallelize results Build-Executor + run: parallelize results Build-MCP - # === Prepare Windows Cert === name: Prepare Windows Cert @@ -245,6 +254,7 @@ jobs: signtool.exe sign -d "ActiveState State Service" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-svc.exe signtool.exe sign -d "ActiveState State Installer" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-installer.exe signtool.exe sign -d "ActiveState State Tool Remote Installer" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-remote-installer.exe + signtool.exe sign -d "ActiveState State MCP" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-mcp.exe env: CODE_SIGNING_PASSWD: ${{ secrets.CODE_SIGNING_PASSWD }} diff --git a/activestate.yaml b/activestate.yaml index fac0bab6f8..d41910a6be 100644 --- a/activestate.yaml +++ b/activestate.yaml @@ -10,6 +10,8 @@ constants: value: ./cmd/state-installer - name: EXECUTOR_PKGS value: ./cmd/state-exec + - name: MCP_PKGS + value: ./cmd/state-mcp - name: BUILD_TARGET_PREFIX_DIR value: ./build - name: BUILD_TARGET @@ -29,6 +31,8 @@ constants: value: state-installer - name: BUILD_REMOTE_INSTALLER_TARGET value: state-remote-installer + - name: BUILD_MCP_TARGET + value: state-mcp - name: INTEGRATION_TEST_REGEX value: 'integration\|automation' - name: SET_ENV @@ -131,6 +135,14 @@ scripts: $constants.SET_ENV go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_EXEC_TARGET $constants.CLI_BUILDFLAGS $constants.EXECUTOR_PKGS + - name: build-mcp + description: Builds the State MCP application + language: bash + standalone: true + value: | + set -e + $constants.SET_ENV + go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_MCP_TARGET $constants.CLI_BUILDFLAGS $constants.MCP_PKGS - name: build-all description: Builds all our tools language: bash @@ -147,6 +159,8 @@ scripts: $scripts.build-svc.path() echo "Building State Executor" $scripts.build-exec.path() + echo "Building State MCP" + $scripts.build-mcp.path() - name: build-installer language: bash standalone: true diff --git a/cmd/state-installer/cmd.go b/cmd/state-installer/cmd.go index 89a5450319..8bf68fbd73 100644 --- a/cmd/state-installer/cmd.go +++ b/cmd/state-installer/cmd.go @@ -475,11 +475,7 @@ func storeInstallSource(installSource string) { installSource = "state-installer" } - appData, err := storage.AppDataPath() - if err != nil { - multilog.Error("Could not store install source due to AppDataPath error: %s", errs.JoinMessage(err)) - return - } + appData := storage.AppDataPath() if err := fileutils.WriteFile(filepath.Join(appData, constants.InstallSourceFile), []byte(installSource)); err != nil { multilog.Error("Could not store install source due to WriteFile error: %s", errs.JoinMessage(err)) } diff --git a/cmd/state-mcp/lookupcve.go b/cmd/state-mcp/lookupcve.go new file mode 100644 index 0000000000..2c113b3821 --- /dev/null +++ b/cmd/state-mcp/lookupcve.go @@ -0,0 +1,38 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/ActiveState/cli/internal/chanutils/workerpool" + "github.com/ActiveState/cli/internal/errs" +) + +func LookupCve(cveIds ...string) (map[string]interface{}, error) { + results := map[string]interface{}{} + // https://api.osv.dev/v1/vulns/OSV-2020-111 + wp := workerpool.New(5) + for _, cveId := range cveIds { + wp.Submit(func() error { + resp, err := http.Get(fmt.Sprintf("https://api.osv.dev/v1/vulns/%s", cveId)) + if err != nil { + return err + } + defer resp.Body.Close() + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return err + } + results[cveId] = result + return nil + }) + } + + err := wp.Wait() + if err != nil { + return nil, errs.Wrap(err, "Failed to wait for workerpool") + } + + return results, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/lookupcve_test.go b/cmd/state-mcp/lookupcve_test.go new file mode 100644 index 0000000000..e643f9672b --- /dev/null +++ b/cmd/state-mcp/lookupcve_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLookupCve(t *testing.T) { + // Table-driven test cases + tests := []struct { + name string + cveIds []string + }{ + { + name: "Single CVE", + cveIds: []string{"CVE-2021-44228"}, + }, + { + name: "Multiple CVEs", + cveIds: []string{"CVE-2021-44228", "CVE-2022-22965"}, + }, + { + name: "Non-existent CVE", + cveIds: []string{"CVE-DOES-NOT-EXIST"}, + }, + { + name: "Empty Input", + cveIds: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results, err := LookupCve(tt.cveIds...) + require.NoError(t, err) + require.NotNil(t, results) + for _, cveId := range tt.cveIds { + require.Contains(t, results, cveId) + } + }) + } +} \ No newline at end of file diff --git a/cmd/state-mcp/main.go b/cmd/state-mcp/main.go new file mode 100644 index 0000000000..fd092a7757 --- /dev/null +++ b/cmd/state-mcp/main.go @@ -0,0 +1,449 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "os" + "strings" + "time" + + "github.com/ActiveState/cli/cmd/state/donotshipme" + "github.com/ActiveState/cli/internal/config" + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/constraints" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/events" + "github.com/ActiveState/cli/internal/installation" + "github.com/ActiveState/cli/internal/ipc" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/multilog" + "github.com/ActiveState/cli/internal/output" + "github.com/ActiveState/cli/internal/primer" + "github.com/ActiveState/cli/internal/runners/cve" + "github.com/ActiveState/cli/internal/runners/manifest" + "github.com/ActiveState/cli/internal/runners/projects" + "github.com/ActiveState/cli/internal/sliceutils" + "github.com/ActiveState/cli/internal/subshell" + "github.com/ActiveState/cli/internal/svcctl" + "github.com/ActiveState/cli/pkg/platform/authentication" + "github.com/ActiveState/cli/pkg/platform/model" + "github.com/ActiveState/cli/pkg/project" + "github.com/ActiveState/cli/pkg/projectfile" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + defer func() { + logging.Debug("Exiting") + if r := recover(); r != nil { + logging.Error("Recovered from panic: %v", r) + fmt.Printf("Recovered from panic: %v\n", r) + os.Exit(1) + } + }() + defer func() { + if err := events.WaitForEvents(5*time.Second, logging.Close); err != nil { + logging.Warning("Failed waiting for events: %v", err) + } + }() + + mcpHandler := registerServer() + + // Parse command line flags + rawFlag := flag.Bool("raw", false, "Expose all State Tool commands as tools; this will lead to issues and is not optimized for AI use") + flag.Parse() + if *rawFlag { + close := registerRawTools(mcpHandler) + defer close() + } else { + registerCuratedTools(mcpHandler) + } + + // Start the stdio server + logging.Info("Starting MCP server") + if err := server.ServeStdio(mcpHandler.mcpServer); err != nil { + logging.Error("Server error: %v\n", err) + } +} + +func registerServer() *mcpServerHandler { + ipcClient, svcPort, err := connectToSvc() + if err != nil { + panic(errs.JoinMessage(err)) + } + + // Create MCP server + s := server.NewMCPServer( + constants.CommandName, + constants.VersionNumber, + ) + + mcpHandler := &mcpServerHandler{ + mcpServer: s, + ipcClient: ipcClient, + svcPort: svcPort, + } + + return mcpHandler +} + +func registerRawTools(mcpHandler *mcpServerHandler) func() error { + byt := &bytes.Buffer{} + prime, close, err := mcpHandler.newPrimer("", byt) + if err != nil { + panic(err) + } + + require := func(b bool) mcp.PropertyOption { + if b { + return mcp.Required() + } + return func(map[string]interface{}) {} + } + + tree := donotshipme.CmdTree(prime) + for _, command := range tree.Command().AllChildren() { + // Best effort to filter out interactive commands + if sliceutils.Contains([]string{"activate", "shell"}, command.NameRecursive()) { + continue + } + + opts := []mcp.ToolOption{ + mcp.WithDescription(command.Description()), + } + + // Require project directory for most commands. This is currently not encoded into the command tree + if !sliceutils.Contains([]string{"projects", "auth"}, command.BaseCommand().Name()) { + opts = append(opts, mcp.WithString( + "project_directory", + require(true), + mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), + )) + } + + for _, arg := range command.Arguments() { + opts = append(opts, mcp.WithString(arg.Name, + require(arg.Required), + mcp.Description(arg.Description), + )) + } + for _, flag := range command.Flags() { + opts = append(opts, mcp.WithString(flag.Name, + mcp.Description(flag.Description), + )) + } + mcpHandler.addTool( + mcp.NewTool(strings.Join(strings.Split(command.NameRecursive(), " "), "_"), opts...), + func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + byt.Truncate(0) + if projectDir, ok := request.Params.Arguments["project_directory"]; ok { + pj, err := project.FromPath(projectDir.(string)) + if err != nil { + return nil, errs.Wrap(err, "Failed to create project") + } + prime.SetProject(pj) + } + args := strings.Split(command.NameRecursive(), " ") + for _, arg := range command.Arguments() { + v, ok := request.Params.Arguments[arg.Name] + if !ok { + break + } + args = append(args, v.(string)) + } + for _, flag := range command.Flags() { + v, ok := request.Params.Arguments[flag.Name] + if !ok { + break + } + args = append(args, fmt.Sprintf("--%s=%s", flag.Name, v.(string))) + } + logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args==nil) + err := command.Execute(args) + if err != nil { + return nil, errs.Wrap(err, "Failed to execute command") + } + return mcp.NewToolResultText(byt.String()), nil + }, + ) + } + + return close +} + +func registerCuratedTools(mcpHandler *mcpServerHandler) { + projectDirParam := mcp.WithString("project_directory", + mcp.Required(), + mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), + ) + + mcpHandler.addTool(mcp.NewTool("list_projects", + mcp.WithDescription("List all ActiveState projects checked out on the local machine"), + ), mcpHandler.listProjectsHandler) + + mcpHandler.addTool(mcp.NewTool("view_manifest", + mcp.WithDescription("Show the manifest (packages and dependencies) for a locally checked out ActiveState platform project"), + projectDirParam, + ), mcpHandler.manifestHandler) + + mcpHandler.addTool(mcp.NewTool("view_cves", + mcp.WithDescription("Show the CVEs for a locally checked out ActiveState platform project"), + projectDirParam, + ), mcpHandler.cveHandler) + + mcpHandler.addTool(mcp.NewTool("lookup_cve", + mcp.WithDescription("Lookup one or more CVEs by their ID"), + mcp.WithString("cve_ids", + mcp.Required(), + mcp.Description("The IDs of the CVEs to lookup, comma separated"), + ), + ), mcpHandler.lookupCveHandler) +} + +type mcpServerHandler struct { + mcpServer *server.MCPServer + ipcClient *ipc.Client + svcPort string +} + +func (t *mcpServerHandler) addResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { + t.mcpServer.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + defer func() { + if r := recover(); r != nil { + logging.Error("Recovered from resource handler panic: %v", r) + fmt.Printf("Recovered from resource handler panic: %v\n", r) + } + }() + logging.Debug("Received resource request: %s", resource.Name) + r, err := handler(ctx, request) + if err != nil { + logging.Error("%s: Error handling resource request: %v", resource.Name, err) + return nil, errs.Wrap(err, "Failed to handle resource request") + } + return r, nil + }) +} + +func (t *mcpServerHandler) addTool(tool mcp.Tool, handler server.ToolHandlerFunc) { + t.mcpServer.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + defer func() { + if r := recover(); r != nil { + logging.Error("Recovered from tool handler panic: %v", r) + fmt.Printf("Recovered from tool handler panic: %v\n", r) + } + }() + logging.Debug("Received tool request: %s", tool.Name) + r, err := handler(ctx, request) + logging.Debug("Received tool response from %s", tool.Name) + if err != nil { + logging.Error("%s: Error handling tool request: %v", tool.Name, errs.JoinMessage(err)) + // Format all errors as a single string, so the client gets the full context + return nil, fmt.Errorf("%s: %s", tool.Name, errs.JoinMessage(err)) + } + return r, nil + }) +} + +func (t *mcpServerHandler) listProjectsHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + var byt bytes.Buffer + prime, close, err := t.newPrimer("", &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + runner := projects.NewProjects(prime) + params := projects.NewParams() + err = runner.Run(params) + if err != nil { + return nil, errs.Wrap(err, "Failed to run projects") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +func (t *mcpServerHandler) listProjectsResourceHandler(ctx context.Context, request mcp.ReadResourceRequest) (r []mcp.ResourceContents, rerr error) { + var byt bytes.Buffer + prime, close, err := t.newPrimer("", &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + runner := projects.NewProjects(prime) + params := projects.NewParams() + err = runner.Run(params) + if err != nil { + return nil, errs.Wrap(err, "Failed to run projects") + } + + r = append(r, mcp.TextResourceContents{Text: byt.String()}) + return r, nil +} + +func (t *mcpServerHandler) manifestHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + pjPath := request.Params.Arguments["project_directory"].(string) + + var byt bytes.Buffer + prime, close, err := t.newPrimer(pjPath, &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + m := manifest.NewManifest(prime) + err = m.Run(manifest.Params{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run manifest") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +func (t *mcpServerHandler) cveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + pjPath := request.Params.Arguments["project_directory"].(string) + + var byt bytes.Buffer + prime, close, err := t.newPrimer(pjPath, &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + c := cve.NewCve(prime) + err = c.Run(&cve.Params{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run manifest") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +func (t *mcpServerHandler) lookupCveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + cveId := request.Params.Arguments["cve_ids"].(string) + cveIds := strings.Split(cveId, ",") + + results, err := LookupCve(cveIds...) + if err != nil { + return nil, errs.Wrap(err, "Failed to lookup CVEs") + } + + byt, err := json.Marshal(results) + if err != nil { + return nil, errs.Wrap(err, "Failed to marshal results") + } + + return mcp.NewToolResultText(string(byt)), nil +} + +type stdOutput struct{} + +func (s *stdOutput) Notice(msg interface{}) { + logging.Info(fmt.Sprintf("%v", msg)) +} + +func connectToSvc() (*ipc.Client, string, error) { + svcExec, err := installation.ServiceExec() + if err != nil { + return nil, "", errs.Wrap(err, "Could not get service info") + } + + ipcClient := svcctl.NewDefaultIPCClient() + argText := strings.Join(os.Args, " ") + svcPort, err := svcctl.EnsureExecStartedAndLocateHTTP(ipcClient, svcExec, argText, &stdOutput{}) + if err != nil { + return nil, "", errs.Wrap(err, "Failed to start state-svc at state tool invocation") + } + + return ipcClient, svcPort, nil +} + +func (t *mcpServerHandler) newPrimer(projectDir string, o *bytes.Buffer) (*primer.Values, func() error, error) { + closers := []func() error{} + closer := func() error { + for _, c := range closers { + if err := c(); err != nil { + return err + } + } + return nil + } + + cfg, err := config.New() + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create config") + } + closers = append(closers, cfg.Close) + + auth := authentication.New(cfg) + closers = append(closers, auth.Close) + + out, err := output.New(string(output.SimpleFormatName), &output.Config{ + OutWriter: o, + ErrWriter: o, + Colored: false, + Interactive: false, + ShellName: "", + }) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create output") + } + + var pj *project.Project + if projectDir != "" { + pjf, err := projectfile.FromPath(projectDir) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create projectfile") + } + pj, err = project.New(pjf, out) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create project") + } + } + + // Set up conditional, which accesses a lot of primer data + sshell := subshell.New(cfg) + + conditional := constraints.NewPrimeConditional(auth, pj, sshell.Shell()) + project.RegisterConditional(conditional) + if err := project.RegisterExpander("mixin", project.NewMixin(auth).Expander); err != nil { + logging.Debug("Could not register mixin expander: %v", err) + } + + svcmodel := model.NewSvcModel(t.svcPort) + + if auth.AvailableAPIToken() != "" { + jwt, err := svcmodel.GetJWT(context.Background()) + if err != nil { + multilog.Critical("Could not get JWT: %v", errs.JoinMessage(err)) + } + if err != nil || jwt == nil { + // Could not authenticate; user got logged out + auth.Logout() + } else { + auth.UpdateSession(jwt) + } + } + + return primer.New(pj, out, auth, sshell, conditional, cfg, t.ipcClient, svcmodel), closer, nil +} diff --git a/cmd/state-mcp/server_test.go b/cmd/state-mcp/server_test.go new file mode 100644 index 0000000000..9fbd210212 --- /dev/null +++ b/cmd/state-mcp/server_test.go @@ -0,0 +1,23 @@ +package main + +import ( + "context" + "encoding/json" + "testing" +) + +func TestServer(t *testing.T) { + mcpHandler := registerServer() + registerRawTools(mcpHandler) + + msg := mcpHandler.mcpServer.HandleMessage(context.Background(), json.RawMessage(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "projects", + "arguments": {} + } + }`)) + t.Fatalf("%+v", msg) +} diff --git a/cmd/state/donotshipme/donotshipme.go b/cmd/state/donotshipme/donotshipme.go new file mode 100644 index 0000000000..b021fc9ca5 --- /dev/null +++ b/cmd/state/donotshipme/donotshipme.go @@ -0,0 +1,17 @@ +package donotshipme + +import ( + "github.com/ActiveState/cli/cmd/state/internal/cmdtree" + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/primer" +) + +func init() { + if constants.ChannelName == "release" { + panic("This file is for experimentation only, it should not be shipped as is. CmdTree is internal to the State command and should remain that way or be refactored.") + } +} + +func CmdTree(prime *primer.Values, args ...string) *cmdtree.CmdTree { + return cmdtree.New(prime, args...) +} \ No newline at end of file diff --git a/go.mod b/go.mod index d83833829c..d5f3fd3da6 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,6 @@ require ( github.com/phayes/permbits v0.0.0-20190108233746-1efae4548023 github.com/posener/wstest v0.0.0-20180216222922-04b166ca0bf1 github.com/rollbar/rollbar-go v1.1.0 - github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 github.com/shirou/gopsutil/v3 v3.24.5 github.com/skratchdot/open-golang v0.0.0-20190104022628-a2dfa6d0dab6 github.com/spf13/cast v1.3.0 @@ -76,6 +75,7 @@ require ( github.com/go-git/go-git/v5 v5.13.1 github.com/gowebpki/jcs v1.0.1 github.com/klauspost/compress v1.11.4 + github.com/mark3labs/mcp-go v0.18.0 github.com/mholt/archiver/v3 v3.5.1 github.com/zijiren233/yaml-comment v0.2.1 ) @@ -109,6 +109,7 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/skeema/knownhosts v1.3.0 // indirect github.com/sosodev/duration v1.3.1 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/sync v0.11.0 // indirect ) diff --git a/go.sum b/go.sum index 1661bd5015..30abed107f 100644 --- a/go.sum +++ b/go.sum @@ -472,6 +472,8 @@ github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7 github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.18.0 h1:YuhgIVjNlTG2ZOwmrkORWyPTp0dz1opPEqvsPtySXao= +github.com/mark3labs/mcp-go v0.18.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/maruel/natural v1.1.0 h1:2z1NgP/Vae+gYrtC0VuvrTJ6U35OuyUqDdfluLqMWuQ= @@ -602,8 +604,6 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= -github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 h1:Xuk8ma/ibJ1fOy4Ee11vHhUFHQNpHhrBneOCNHVXS5w= -github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0/go.mod h1:7AwjWCpdPhkSmNAgUv5C7EJ4AbmjEB3r047r3DXWu3Y= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= @@ -684,6 +684,8 @@ github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc/go.mod h1:Jhud4/sHM github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/vendor/github.com/mark3labs/mcp-go/LICENSE b/vendor/github.com/mark3labs/mcp-go/LICENSE new file mode 100644 index 0000000000..3d48435454 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Anthropic, PBC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/mark3labs/mcp-go/client/client.go b/vendor/github.com/mark3labs/mcp-go/client/client.go new file mode 100644 index 0000000000..1d3cb1051e --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/client.go @@ -0,0 +1,84 @@ +// Package client provides MCP (Model Control Protocol) client implementations. +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPClient represents an MCP client interface +type MCPClient interface { + // Initialize sends the initial connection request to the server + Initialize( + ctx context.Context, + request mcp.InitializeRequest, + ) (*mcp.InitializeResult, error) + + // Ping checks if the server is alive + Ping(ctx context.Context) error + + // ListResources requests a list of available resources from the server + ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResourceTemplates requests a list of available resource templates from the server + ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ReadResource reads a specific resource from the server + ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, + ) (*mcp.ReadResourceResult, error) + + // Subscribe requests notifications for changes to a specific resource + Subscribe(ctx context.Context, request mcp.SubscribeRequest) error + + // Unsubscribe cancels notifications for a specific resource + Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error + + // ListPrompts requests a list of available prompts from the server + ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // GetPrompt retrieves a specific prompt from the server + GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, + ) (*mcp.GetPromptResult, error) + + // ListTools requests a list of available tools from the server + ListTools( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // CallTool invokes a specific tool on the server + CallTool( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) + + // SetLevel sets the logging level for the server + SetLevel(ctx context.Context, request mcp.SetLevelRequest) error + + // Complete requests completion options for a given argument + Complete( + ctx context.Context, + request mcp.CompleteRequest, + ) (*mcp.CompleteResult, error) + + // Close client connection and cleanup resources + Close() error + + // OnNotification registers a handler for notifications + OnNotification(handler func(notification mcp.JSONRPCNotification)) +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/sse.go b/vendor/github.com/mark3labs/mcp-go/client/sse.go new file mode 100644 index 0000000000..cf4a1028e0 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/sse.go @@ -0,0 +1,588 @@ +package client + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). +// It maintains a persistent HTTP connection to receive server-pushed events +// while sending requests over regular HTTP POST calls. The client handles +// automatic reconnection and message routing between requests and responses. +type SSEMCPClient struct { + baseURL *url.URL + endpoint *url.URL + httpClient *http.Client + requestID atomic.Int64 + responses map[int64]chan RPCResponse + mu sync.RWMutex + done chan struct{} + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + endpointChan chan struct{} + capabilities mcp.ServerCapabilities + headers map[string]string + sseReadTimeout time.Duration +} + +type ClientOption func(*SSEMCPClient) + +func WithHeaders(headers map[string]string) ClientOption { + return func(sc *SSEMCPClient) { + sc.headers = headers + } +} + +func WithSSEReadTimeout(timeout time.Duration) ClientOption { + return func(sc *SSEMCPClient) { + sc.sseReadTimeout = timeout + } +} + +// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. +// Returns an error if the URL is invalid. +func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &SSEMCPClient{ + baseURL: parsedURL, + httpClient: &http.Client{}, + responses: make(map[int64]chan RPCResponse), + done: make(chan struct{}), + endpointChan: make(chan struct{}), + sseReadTimeout: 30 * time.Second, + headers: make(map[string]string), + } + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the SSE connection to the server and waits for the endpoint information. +// Returns an error if the connection fails or times out waiting for the endpoint. +func (c *SSEMCPClient) Start(ctx context.Context) error { + + req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) + + if err != nil { + + return fmt.Errorf("failed to create request: %w", err) + + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to SSE stream: %w", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + go c.readSSE(resp.Body) + + // Wait for the endpoint to be received + + select { + case <-c.endpointChan: + // Endpoint received, proceed + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for endpoint") + case <-time.After(30 * time.Second): // Add a timeout + return fmt.Errorf("timeout waiting for endpoint") + } + + return nil +} + +// readSSE continuously reads the SSE stream and processes events. +// It runs until the connection is closed or an error occurs. +func (c *SSEMCPClient) readSSE(reader io.ReadCloser) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + return + default: + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + c.handleSSEEvent(event, data) + } + break + } + select { + case <-c.done: + return + default: + fmt.Printf("SSE stream error: %v\n", err) + return + } + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + c.handleSSEEvent(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + } +} + +// handleSSEEvent processes SSE events based on their type. +// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. +func (c *SSEMCPClient) handleSSEEvent(event, data string) { + switch event { + case "endpoint": + endpoint, err := c.baseURL.Parse(data) + if err != nil { + fmt.Printf("Error parsing endpoint URL: %v\n", err) + return + } + if endpoint.Host != c.baseURL.Host { + fmt.Printf("Endpoint origin does not match connection origin\n") + return + } + c.endpoint = endpoint + close(c.endpointChan) + + case "message": + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { + fmt.Printf("Error unmarshaling message: %v\n", err) + return + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + return + } + c.notifyMu.RLock() + for _, handler := range c.notifications { + handler(notification) + } + c.notifyMu.RUnlock() + return + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + if baseMessage.Error != nil { + ch <- RPCResponse{ + Error: &baseMessage.Error.Message, + } + } else { + ch <- RPCResponse{ + Response: &baseMessage.Result, + } + } + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *SSEMCPClient) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *SSEMCPClient) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + if c.endpoint == nil { + return nil, fmt.Errorf("endpoint not received") + } + + id := c.requestID.Add(1) + + request := mcp.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Request: mcp.Request{ + Method: method, + }, + Params: params, + } + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + responseChan := make(chan RPCResponse, 1) + c.mu.Lock() + c.responses[id] = responseChan + c.mu.Unlock() + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(requestBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf( + "request failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, id) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + if response.Error != nil { + return nil, errors.New(*response.Error) + } + return response.Response, nil + } +} + +func (c *SSEMCPClient) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // Ensure we send a params object with all required fields + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return nil, fmt.Errorf( + "failed to marshal initialized notification: %w", + err, + ) + } + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(notificationBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + resp.Body.Close() + + c.initialized = true + return &result, nil +} + +func (c *SSEMCPClient) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +func (c *SSEMCPClient) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + response, err := c.sendRequest(ctx, "resources/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListResourcesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/templates/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourceTemplatesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *SSEMCPClient) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *SSEMCPClient) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *SSEMCPClient) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + response, err := c.sendRequest(ctx, "prompts/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListPromptsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *SSEMCPClient) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + response, err := c.sendRequest(ctx, "tools/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListToolsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *SSEMCPClient) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *SSEMCPClient) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +// Helper methods + +// GetEndpoint returns the current endpoint URL for the SSE connection. +func (c *SSEMCPClient) GetEndpoint() *url.URL { + return c.endpoint +} + +// Close shuts down the SSE client connection and cleans up any pending responses. +// Returns an error if the shutdown process fails. +func (c *SSEMCPClient) Close() error { + select { + case <-c.done: + return nil // Already closed + default: + close(c.done) + } + + // Clean up any pending responses + c.mu.Lock() + for _, ch := range c.responses { + close(ch) + } + c.responses = make(map[int64]chan RPCResponse) + c.mu.Unlock() + + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/stdio.go new file mode 100644 index 0000000000..8e0845dca6 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/stdio.go @@ -0,0 +1,457 @@ +package client + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "sync" + "sync/atomic" + + "github.com/mark3labs/mcp-go/mcp" +) + +// StdioMCPClient implements the MCPClient interface using stdio communication. +// It launches a subprocess and communicates with it via standard input/output streams +// using JSON-RPC messages. The client handles message routing between requests and +// responses, and supports asynchronous notifications. +type StdioMCPClient struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + requestID atomic.Int64 + responses map[int64]chan RPCResponse + mu sync.RWMutex + done chan struct{} + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + capabilities mcp.ServerCapabilities +} + +// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +func NewStdioMCPClient( + command string, + env []string, + args ...string, +) (*StdioMCPClient, error) { + cmd := exec.Command(command, args...) + + mergedEnv := os.Environ() + mergedEnv = append(mergedEnv, env...) + + cmd.Env = mergedEnv + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stderr pipe: %w", err) + } + + client := &StdioMCPClient{ + cmd: cmd, + stdin: stdin, + stderr: stderr, + stdout: bufio.NewReader(stdout), + responses: make(map[int64]chan RPCResponse), + done: make(chan struct{}), + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start command: %w", err) + } + + // Start reading responses in a goroutine and wait for it to be ready + ready := make(chan struct{}) + go func() { + close(ready) + client.readResponses() + }() + <-ready + + return client, nil +} + +// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. +// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +func (c *StdioMCPClient) Close() error { + close(c.done) + if err := c.stdin.Close(); err != nil { + return fmt.Errorf("failed to close stdin: %w", err) + } + if err := c.stderr.Close(); err != nil { + return fmt.Errorf("failed to close stderr: %w", err) + } + return c.cmd.Wait() +} + +// Stderr returns a reader for the stderr output of the subprocess. +// This can be used to capture error messages or logs from the subprocess. +func (c *StdioMCPClient) Stderr() io.Reader { + return c.stderr +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *StdioMCPClient) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// readResponses continuously reads and processes responses from the server's stdout. +// It handles both responses to requests and notifications, routing them appropriately. +// Runs until the done channel is closed or an error occurs reading from stdout. +func (c *StdioMCPClient) readResponses() { + for { + select { + case <-c.done: + return + default: + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF { + fmt.Printf("Error reading response: %v\n", err) + } + return + } + + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { + continue + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(line), ¬ification); err != nil { + continue + } + c.notifyMu.RLock() + for _, handler := range c.notifications { + handler(notification) + } + c.notifyMu.RUnlock() + continue + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + if baseMessage.Error != nil { + ch <- RPCResponse{ + Error: &baseMessage.Error.Message, + } + } else { + ch <- RPCResponse{ + Response: &baseMessage.Result, + } + } + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } + } +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// It creates a unique request ID, sends the request over stdin, and waits for +// the corresponding response or context cancellation. +// Returns the raw JSON response message or an error if the request fails. +func (c *StdioMCPClient) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + id := c.requestID.Add(1) + + // Create the complete request structure + request := mcp.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Request: mcp.Request{ + Method: method, + }, + Params: params, + } + + responseChan := make(chan RPCResponse, 1) + c.mu.Lock() + c.responses[id] = responseChan + c.mu.Unlock() + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := c.stdin.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, id) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + if response.Error != nil { + return nil, errors.New(*response.Error) + } + return response.Response, nil + } +} + +func (c *StdioMCPClient) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +func (c *StdioMCPClient) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // This structure ensures Capabilities is always included in JSON + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return nil, fmt.Errorf( + "failed to marshal initialized notification: %w", + err, + ) + } + notificationBytes = append(notificationBytes, '\n') + + if _, err := c.stdin.Write(notificationBytes); err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + + c.initialized = true + return &result, nil +} + +func (c *StdioMCPClient) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp. + ListResourcesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourcesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp. + ListResourceTemplatesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/templates/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourceTemplatesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, + error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *StdioMCPClient) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *StdioMCPClient) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *StdioMCPClient) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + response, err := c.sendRequest(ctx, "prompts/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListPromptsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *StdioMCPClient) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + response, err := c.sendRequest(ctx, "tools/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListToolsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *StdioMCPClient) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *StdioMCPClient) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/types.go b/vendor/github.com/mark3labs/mcp-go/client/types.go new file mode 100644 index 0000000000..4402bd0240 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/types.go @@ -0,0 +1,8 @@ +package client + +import "encoding/json" + +type RPCResponse struct { + Error *string + Response *json.RawMessage +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go new file mode 100644 index 0000000000..bc12a72976 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go @@ -0,0 +1,163 @@ +package mcp + +/* Prompts */ + +// ListPromptsRequest is sent from the client to request a list of prompts and +// prompt templates the server has. +type ListPromptsRequest struct { + PaginatedRequest +} + +// ListPromptsResult is the server's response to a prompts/list request from +// the client. +type ListPromptsResult struct { + PaginatedResult + Prompts []Prompt `json:"prompts"` +} + +// GetPromptRequest is used by the client to get a prompt provided by the +// server. +type GetPromptRequest struct { + Request + Params struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` + } `json:"params"` +} + +// GetPromptResult is the server's response to a prompts/get request from the +// client. +type GetPromptResult struct { + Result + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` +} + +// Prompt represents a prompt or prompt template that the server offers. +// If Arguments is non-nil and non-empty, this indicates the prompt is a template +// that requires argument values to be provided when calling prompts/get. +// If Arguments is nil or empty, this is a static prompt that takes no arguments. +type Prompt struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // An optional description of what this prompt provides + Description string `json:"description,omitempty"` + // A list of arguments to use for templating the prompt. + // The presence of arguments indicates this is a template prompt. + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// PromptArgument describes an argument that a prompt template can accept. +// When a prompt includes arguments, clients must provide values for all +// required arguments when making a prompts/get request. +type PromptArgument struct { + // The name of the argument. + Name string `json:"name"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + // If true, clients must include this argument when calling prompts/get. + Required bool `json:"required,omitempty"` +} + +// Role represents the sender or recipient of messages and data in a +// conversation. +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + +// PromptMessage describes a message returned as part of a prompt. +// +// This is similar to `SamplingMessage`, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Role Role `json:"role"` + Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource +} + +// PromptListChangedNotification is an optional notification from the server +// to the client, informing it that the list of prompts it offers has changed. This +// may be issued by servers without any previous subscription from the client. +type PromptListChangedNotification struct { + Notification +} + +// PromptOption is a function that configures a Prompt. +// It provides a flexible way to set various properties of a Prompt using the functional options pattern. +type PromptOption func(*Prompt) + +// ArgumentOption is a function that configures a PromptArgument. +// It allows for flexible configuration of prompt arguments using the functional options pattern. +type ArgumentOption func(*PromptArgument) + +// +// Core Prompt Functions +// + +// NewPrompt creates a new Prompt with the given name and options. +// The prompt will be configured based on the provided options. +// Options are applied in order, allowing for flexible prompt configuration. +func NewPrompt(name string, opts ...PromptOption) Prompt { + prompt := Prompt{ + Name: name, + } + + for _, opt := range opts { + opt(&prompt) + } + + return prompt +} + +// WithPromptDescription adds a description to the Prompt. +// The description should provide a clear, human-readable explanation of what the prompt does. +func WithPromptDescription(description string) PromptOption { + return func(p *Prompt) { + p.Description = description + } +} + +// WithArgument adds an argument to the prompt's argument list. +// The argument will be configured based on the provided options. +func WithArgument(name string, opts ...ArgumentOption) PromptOption { + return func(p *Prompt) { + arg := PromptArgument{ + Name: name, + } + + for _, opt := range opts { + opt(&arg) + } + + if p.Arguments == nil { + p.Arguments = make([]PromptArgument, 0) + } + p.Arguments = append(p.Arguments, arg) + } +} + +// +// Argument Options +// + +// ArgumentDescription adds a description to a prompt argument. +// The description should explain the purpose and expected values of the argument. +func ArgumentDescription(desc string) ArgumentOption { + return func(arg *PromptArgument) { + arg.Description = desc + } +} + +// RequiredArgument marks an argument as required in the prompt. +// Required arguments must be provided when getting the prompt. +func RequiredArgument() ArgumentOption { + return func(arg *PromptArgument) { + arg.Required = true + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go new file mode 100644 index 0000000000..51cdd25dd3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go @@ -0,0 +1,105 @@ +package mcp + +import "github.com/yosida95/uritemplate/v3" + +// ResourceOption is a function that configures a Resource. +// It provides a flexible way to set various properties of a Resource using the functional options pattern. +type ResourceOption func(*Resource) + +// NewResource creates a new Resource with the given URI, name and options. +// The resource will be configured based on the provided options. +// Options are applied in order, allowing for flexible resource configuration. +func NewResource(uri string, name string, opts ...ResourceOption) Resource { + resource := Resource{ + URI: uri, + Name: name, + } + + for _, opt := range opts { + opt(&resource) + } + + return resource +} + +// WithResourceDescription adds a description to the Resource. +// The description should provide a clear, human-readable explanation of what the resource represents. +func WithResourceDescription(description string) ResourceOption { + return func(r *Resource) { + r.Description = description + } +} + +// WithMIMEType sets the MIME type for the Resource. +// This should indicate the format of the resource's contents. +func WithMIMEType(mimeType string) ResourceOption { + return func(r *Resource) { + r.MIMEType = mimeType + } +} + +// WithAnnotations adds annotations to the Resource. +// Annotations can provide additional metadata about the resource's intended use. +func WithAnnotations(audience []Role, priority float64) ResourceOption { + return func(r *Resource) { + if r.Annotations == nil { + r.Annotations = &struct { + Audience []Role `json:"audience,omitempty"` + Priority float64 `json:"priority,omitempty"` + }{} + } + r.Annotations.Audience = audience + r.Annotations.Priority = priority + } +} + +// ResourceTemplateOption is a function that configures a ResourceTemplate. +// It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. +type ResourceTemplateOption func(*ResourceTemplate) + +// NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. +// The template will be configured based on the provided options. +// Options are applied in order, allowing for flexible template configuration. +func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { + template := ResourceTemplate{ + URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, + Name: name, + } + + for _, opt := range opts { + opt(&template) + } + + return template +} + +// WithTemplateDescription adds a description to the ResourceTemplate. +// The description should provide a clear, human-readable explanation of what resources this template represents. +func WithTemplateDescription(description string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.Description = description + } +} + +// WithTemplateMIMEType sets the MIME type for the ResourceTemplate. +// This should only be set if all resources matching this template will have the same type. +func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.MIMEType = mimeType + } +} + +// WithTemplateAnnotations adds annotations to the ResourceTemplate. +// Annotations can provide additional metadata about the template's intended use. +func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { + return func(t *ResourceTemplate) { + if t.Annotations == nil { + t.Annotations = &struct { + Audience []Role `json:"audience,omitempty"` + Priority float64 `json:"priority,omitempty"` + }{} + } + t.Annotations.Audience = audience + t.Annotations.Priority = priority + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go new file mode 100644 index 0000000000..c4c1b1dec0 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -0,0 +1,466 @@ +package mcp + +import ( + "encoding/json" + "errors" + "fmt" +) + +var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") + +// ListToolsRequest is sent from the client to request a list of tools the +// server has. +type ListToolsRequest struct { + PaginatedRequest +} + +// ListToolsResult is the server's response to a tools/list request from the +// client. +type ListToolsResult struct { + PaginatedResult + Tools []Tool `json:"tools"` +} + +// CallToolResult is the server's response to a tool call. +// +// Any errors that originate from the tool SHOULD be reported inside the result +// object, with `isError` set to true, _not_ as an MCP protocol-level error +// response. Otherwise, the LLM would not be able to see that an error occurred +// and self-correct. +// +// However, any errors in _finding_ the tool, an error indicating that the +// server does not support tool calls, or any other exceptional conditions, +// should be reported as an MCP error response. +type CallToolResult struct { + Result + Content []Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource + // Whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + IsError bool `json:"isError,omitempty"` +} + +// CallToolRequest is used by the client to invoke a tool provided by the server. +type CallToolRequest struct { + Request + Params struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + } `json:"params"` +} + +// ToolListChangedNotification is an optional notification from the server to +// the client, informing it that the list of tools it offers has changed. This may +// be issued by servers without any previous subscription from the client. +type ToolListChangedNotification struct { + Notification +} + +// Tool represents the definition for a tool the client can call. +type Tool struct { + // The name of the tool. + Name string `json:"name"` + // A human-readable description of the tool. + Description string `json:"description,omitempty"` + // A JSON Schema object defining the expected parameters for the tool. + InputSchema ToolInputSchema `json:"inputSchema"` + // Alternative to InputSchema - allows arbitrary JSON Schema to be provided + RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling +} + +// MarshalJSON implements the json.Marshaler interface for Tool. +// It handles marshaling either InputSchema or RawInputSchema based on which is set. +func (t Tool) MarshalJSON() ([]byte, error) { + // Create a map to build the JSON structure + m := make(map[string]interface{}, 3) + + // Add the name and description + m["name"] = t.Name + if t.Description != "" { + m["description"] = t.Description + } + + // Determine which schema to use + if t.RawInputSchema != nil { + if t.InputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both InputSchema and RawInputSchema set: %w", t.Name, errToolSchemaConflict) + } + m["inputSchema"] = t.RawInputSchema + } else { + // Use the structured InputSchema + m["inputSchema"] = t.InputSchema + } + + return json.Marshal(m) +} + +type ToolInputSchema struct { + Type string `json:"type"` + Properties map[string]interface{} `json:"properties"` + Required []string `json:"required,omitempty"` +} + +// ToolOption is a function that configures a Tool. +// It provides a flexible way to set various properties of a Tool using the functional options pattern. +type ToolOption func(*Tool) + +// PropertyOption is a function that configures a property in a Tool's input schema. +// It allows for flexible configuration of JSON Schema properties using the functional options pattern. +type PropertyOption func(map[string]interface{}) + +// +// Core Tool Functions +// + +// NewTool creates a new Tool with the given name and options. +// The tool will have an object-type input schema with configurable properties. +// Options are applied in order, allowing for flexible tool configuration. +func NewTool(name string, opts ...ToolOption) Tool { + tool := Tool{ + Name: name, + InputSchema: ToolInputSchema{ + Type: "object", + Properties: make(map[string]interface{}), + Required: nil, // Will be omitted from JSON if empty + }, + } + + for _, opt := range opts { + opt(&tool) + } + + return tool +} + +// NewToolWithRawSchema creates a new Tool with the given name and a raw JSON +// Schema. This allows for arbitrary JSON Schema to be used for the tool's input +// schema. +// +// NOTE a [Tool] built in such a way is incompatible with the [ToolOption] and +// runtime errors will result from supplying a [ToolOption] to a [Tool] built +// with this function. +func NewToolWithRawSchema(name, description string, schema json.RawMessage) Tool { + tool := Tool{ + Name: name, + Description: description, + RawInputSchema: schema, + } + + return tool +} + +// WithDescription adds a description to the Tool. +// The description should provide a clear, human-readable explanation of what the tool does. +func WithDescription(description string) ToolOption { + return func(t *Tool) { + t.Description = description + } +} + +// +// Common Property Options +// + +// Description adds a description to a property in the JSON Schema. +// The description should explain the purpose and expected values of the property. +func Description(desc string) PropertyOption { + return func(schema map[string]interface{}) { + schema["description"] = desc + } +} + +// Required marks a property as required in the tool's input schema. +// Required properties must be provided when using the tool. +func Required() PropertyOption { + return func(schema map[string]interface{}) { + schema["required"] = true + } +} + +// Title adds a display-friendly title to a property in the JSON Schema. +// This title can be used by UI components to show a more readable property name. +func Title(title string) PropertyOption { + return func(schema map[string]interface{}) { + schema["title"] = title + } +} + +// +// String Property Options +// + +// DefaultString sets the default value for a string property. +// This value will be used if the property is not explicitly provided. +func DefaultString(value string) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + +// Enum specifies a list of allowed values for a string property. +// The property value must be one of the specified enum values. +func Enum(values ...string) PropertyOption { + return func(schema map[string]interface{}) { + schema["enum"] = values + } +} + +// MaxLength sets the maximum length for a string property. +// The string value must not exceed this length. +func MaxLength(max int) PropertyOption { + return func(schema map[string]interface{}) { + schema["maxLength"] = max + } +} + +// MinLength sets the minimum length for a string property. +// The string value must be at least this length. +func MinLength(min int) PropertyOption { + return func(schema map[string]interface{}) { + schema["minLength"] = min + } +} + +// Pattern sets a regex pattern that a string property must match. +// The string value must conform to the specified regular expression. +func Pattern(pattern string) PropertyOption { + return func(schema map[string]interface{}) { + schema["pattern"] = pattern + } +} + +// +// Number Property Options +// + +// DefaultNumber sets the default value for a number property. +// This value will be used if the property is not explicitly provided. +func DefaultNumber(value float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + +// Max sets the maximum value for a number property. +// The number value must not exceed this maximum. +func Max(max float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["maximum"] = max + } +} + +// Min sets the minimum value for a number property. +// The number value must not be less than this minimum. +func Min(min float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["minimum"] = min + } +} + +// MultipleOf specifies that a number must be a multiple of the given value. +// The number value must be divisible by this value. +func MultipleOf(value float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["multipleOf"] = value + } +} + +// +// Boolean Property Options +// + +// DefaultBool sets the default value for a boolean property. +// This value will be used if the property is not explicitly provided. +func DefaultBool(value bool) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + +// +// Property Type Helpers +// + +// WithBoolean adds a boolean property to the tool schema. +// It accepts property options to configure the boolean property's behavior and constraints. +func WithBoolean(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "boolean", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithNumber adds a number property to the tool schema. +// It accepts property options to configure the number property's behavior and constraints. +func WithNumber(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "number", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithString adds a string property to the tool schema. +// It accepts property options to configure the string property's behavior and constraints. +func WithString(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "string", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithObject adds an object property to the tool schema. +// It accepts property options to configure the object property's behavior and constraints. +func WithObject(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithArray adds an array property to the tool schema. +// It accepts property options to configure the array property's behavior and constraints. +func WithArray(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "array", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// Properties defines the properties for an object schema +func Properties(props map[string]interface{}) PropertyOption { + return func(schema map[string]interface{}) { + schema["properties"] = props + } +} + +// AdditionalProperties specifies whether additional properties are allowed in the object +// or defines a schema for additional properties +func AdditionalProperties(schema interface{}) PropertyOption { + return func(schemaMap map[string]interface{}) { + schemaMap["additionalProperties"] = schema + } +} + +// MinProperties sets the minimum number of properties for an object +func MinProperties(min int) PropertyOption { + return func(schema map[string]interface{}) { + schema["minProperties"] = min + } +} + +// MaxProperties sets the maximum number of properties for an object +func MaxProperties(max int) PropertyOption { + return func(schema map[string]interface{}) { + schema["maxProperties"] = max + } +} + +// PropertyNames defines a schema for property names in an object +func PropertyNames(schema map[string]interface{}) PropertyOption { + return func(schemaMap map[string]interface{}) { + schemaMap["propertyNames"] = schema + } +} + +// Items defines the schema for array items +func Items(schema interface{}) PropertyOption { + return func(schemaMap map[string]interface{}) { + schemaMap["items"] = schema + } +} + +// MinItems sets the minimum number of items for an array +func MinItems(min int) PropertyOption { + return func(schema map[string]interface{}) { + schema["minItems"] = min + } +} + +// MaxItems sets the maximum number of items for an array +func MaxItems(max int) PropertyOption { + return func(schema map[string]interface{}) { + schema["maxItems"] = max + } +} + +// UniqueItems specifies whether array items must be unique +func UniqueItems(unique bool) PropertyOption { + return func(schema map[string]interface{}) { + schema["uniqueItems"] = unique + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go new file mode 100644 index 0000000000..a3ad8174e6 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -0,0 +1,860 @@ +// Package mcp defines the core types and interfaces for the Model Control Protocol (MCP). +// MCP is a protocol for communication between LLM-powered applications and their supporting services. +package mcp + +import ( + "encoding/json" + + "github.com/yosida95/uritemplate/v3" +) + +type MCPMethod string + +const ( + // Initiates connection and negotiates protocol capabilities. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + MethodInitialize MCPMethod = "initialize" + + // Verifies connection liveness between client and server. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + MethodPing MCPMethod = "ping" + + // Lists all available server resources. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesList MCPMethod = "resources/list" + + // Provides URI templates for constructing resource URIs. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesTemplatesList MCPMethod = "resources/templates/list" + + // Retrieves content of a specific resource by URI. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesRead MCPMethod = "resources/read" + + // Lists all available prompt templates. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsList MCPMethod = "prompts/list" + + // Retrieves a specific prompt template with filled parameters. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsGet MCPMethod = "prompts/get" + + // Lists all available executable tools. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsList MCPMethod = "tools/list" + + // Invokes a specific tool with provided parameters. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsCall MCPMethod = "tools/call" +) + +type URITemplate struct { + *uritemplate.Template +} + +func (t *URITemplate) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Template.Raw()) +} + +func (t *URITemplate) UnmarshalJSON(data []byte) error { + var raw string + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + template, err := uritemplate.New(raw) + if err != nil { + return err + } + t.Template = template + return nil +} + +/* JSON-RPC types */ + +// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError +type JSONRPCMessage interface{} + +// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. +const LATEST_PROTOCOL_VERSION = "2024-11-05" + +// JSONRPC_VERSION is the version of JSON-RPC used by MCP. +const JSONRPC_VERSION = "2.0" + +// ProgressToken is used to associate progress notifications with the original request. +type ProgressToken interface{} + +// Cursor is an opaque token used to represent a cursor for pagination. +type Cursor string + +type Request struct { + Method string `json:"method"` + Params struct { + Meta *struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + } `json:"params,omitempty"` +} + +type Params map[string]interface{} + +type Notification struct { + Method string `json:"method"` + Params NotificationParams `json:"params,omitempty"` +} + +type NotificationParams struct { + // This parameter name is reserved by MCP to allow clients and + // servers to attach additional metadata to their notifications. + Meta map[string]interface{} `json:"_meta,omitempty"` + + // Additional fields can be added to this map + AdditionalFields map[string]interface{} `json:"-"` +} + +// MarshalJSON implements custom JSON marshaling +func (p NotificationParams) MarshalJSON() ([]byte, error) { + // Create a map to hold all fields + m := make(map[string]interface{}) + + // Add Meta if it exists + if p.Meta != nil { + m["_meta"] = p.Meta + } + + // Add all additional fields + for k, v := range p.AdditionalFields { + // Ensure we don't override the _meta field + if k != "_meta" { + m[k] = v + } + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling +func (p *NotificationParams) UnmarshalJSON(data []byte) error { + // Create a map to hold all fields + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + + // Initialize maps if they're nil + if p.Meta == nil { + p.Meta = make(map[string]interface{}) + } + if p.AdditionalFields == nil { + p.AdditionalFields = make(map[string]interface{}) + } + + // Process all fields + for k, v := range m { + if k == "_meta" { + // Handle Meta field + if meta, ok := v.(map[string]interface{}); ok { + p.Meta = meta + } + } else { + // Handle additional fields + p.AdditionalFields[k] = v + } + } + + return nil +} + +type Result struct { + // This result property is reserved by the protocol to allow clients and + // servers to attach additional metadata to their responses. + Meta map[string]interface{} `json:"_meta,omitempty"` +} + +// RequestId is a uniquely identifying ID for a request in JSON-RPC. +// It can be any JSON-serializable value, typically a number or string. +type RequestId interface{} + +// JSONRPCRequest represents a request that expects a response. +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Params interface{} `json:"params,omitempty"` + Request +} + +// JSONRPCNotification represents a notification which does not expect a response. +type JSONRPCNotification struct { + JSONRPC string `json:"jsonrpc"` + Notification +} + +// JSONRPCResponse represents a successful (non-error) response to a request. +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Result interface{} `json:"result"` +} + +// JSONRPCError represents a non-successful (error) response to a request. +type JSONRPCError struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Error struct { + // The error type that occurred. + Code int `json:"code"` + // A short description of the error. The message SHOULD be limited + // to a concise single sentence. + Message string `json:"message"` + // Additional information about the error. The value of this member + // is defined by the sender (e.g. detailed error information, nested errors etc.). + Data interface{} `json:"data,omitempty"` + } `json:"error"` +} + +// Standard JSON-RPC error codes +const ( + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 +) + +/* Empty result */ + +// EmptyResult represents a response that indicates success but carries no data. +type EmptyResult Result + +/* Cancellation */ + +// CancelledNotification can be sent by either side to indicate that it is +// cancelling a previously-issued request. +// +// The request SHOULD still be in-flight, but due to communication latency, it +// is always possible that this notification MAY arrive after the request has +// already finished. +// +// This notification indicates that the result will be unused, so any +// associated processing SHOULD cease. +// +// A client MUST NOT attempt to cancel its `initialize` request. +type CancelledNotification struct { + Notification + Params struct { + // The ID of the request to cancel. + // + // This MUST correspond to the ID of a request previously issued + // in the same direction. + RequestId RequestId `json:"requestId"` + + // An optional string describing the reason for the cancellation. This MAY + // be logged or presented to the user. + Reason string `json:"reason,omitempty"` + } `json:"params"` +} + +/* Initialization */ + +// InitializeRequest is sent from the client to the server when it first +// connects, asking it to begin initialization. +type InitializeRequest struct { + Request + Params struct { + // The latest version of the Model Context Protocol that the client supports. + // The client MAY decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo Implementation `json:"clientInfo"` + } `json:"params"` +} + +// InitializeResult is sent after receiving an initialize request from the +// client. +type InitializeResult struct { + Result + // The version of the Model Context Protocol that the server wants to use. + // This may not match the version that the client requested. If the client cannot + // support this version, it MUST disconnect. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo Implementation `json:"serverInfo"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of + // available tools, resources, etc. It can be thought of like a "hint" to the model. + // For example, this information MAY be added to the system prompt. + Instructions string `json:"instructions,omitempty"` +} + +// InitializedNotification is sent from the client to the server after +// initialization has finished. +type InitializedNotification struct { + Notification +} + +// ClientCapabilities represents capabilities a client may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// client can define its own, additional capabilities. +type ClientCapabilities struct { + // Experimental, non-standard capabilities that the client supports. + Experimental map[string]interface{} `json:"experimental,omitempty"` + // Present if the client supports listing roots. + Roots *struct { + // Whether the client supports notifications for changes to the roots list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"roots,omitempty"` + // Present if the client supports sampling from an LLM. + Sampling *struct{} `json:"sampling,omitempty"` +} + +// ServerCapabilities represents capabilities that a server may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// server can define its own, additional capabilities. +type ServerCapabilities struct { + // Experimental, non-standard capabilities that the server supports. + Experimental map[string]interface{} `json:"experimental,omitempty"` + // Present if the server supports sending log messages to the client. + Logging *struct{} `json:"logging,omitempty"` + // Present if the server offers any prompt templates. + Prompts *struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"prompts,omitempty"` + // Present if the server offers any resources to read. + Resources *struct { + // Whether this server supports subscribing to resource updates. + Subscribe bool `json:"subscribe,omitempty"` + // Whether this server supports notifications for changes to the resource + // list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"resources,omitempty"` + // Present if the server offers any tools to call. + Tools *struct { + // Whether this server supports notifications for changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"tools,omitempty"` +} + +// Implementation describes the name and version of an MCP implementation. +type Implementation struct { + Name string `json:"name"` + Version string `json:"version"` +} + +/* Ping */ + +// PingRequest represents a ping, issued by either the server or the client, +// to check that the other party is still alive. The receiver must promptly respond, +// or else may be disconnected. +type PingRequest struct { + Request +} + +/* Progress notifications */ + +// ProgressNotification is an out-of-band notification used to inform the +// receiver of a progress update for a long-running request. +type ProgressNotification struct { + Notification + Params struct { + // The progress token which was given in the initial request, used to + // associate this notification with the request that is proceeding. + ProgressToken ProgressToken `json:"progressToken"` + // The progress thus far. This should increase every time progress is made, + // even if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + Total float64 `json:"total,omitempty"` + } `json:"params"` +} + +/* Pagination */ + +type PaginatedRequest struct { + Request + Params struct { + // An opaque token representing the current pagination position. + // If provided, the server should return results starting after this cursor. + Cursor Cursor `json:"cursor,omitempty"` + } `json:"params,omitempty"` +} + +type PaginatedResult struct { + Result + // An opaque token representing the pagination position after the last + // returned result. + // If present, there may be more results available. + NextCursor Cursor `json:"nextCursor,omitempty"` +} + +/* Resources */ + +// ListResourcesRequest is sent from the client to request a list of resources +// the server has. +type ListResourcesRequest struct { + PaginatedRequest +} + +// ListResourcesResult is the server's response to a resources/list request +// from the client. +type ListResourcesResult struct { + PaginatedResult + Resources []Resource `json:"resources"` +} + +// ListResourceTemplatesRequest is sent from the client to request a list of +// resource templates the server has. +type ListResourceTemplatesRequest struct { + PaginatedRequest +} + +// ListResourceTemplatesResult is the server's response to a +// resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + PaginatedResult + ResourceTemplates []ResourceTemplate `json:"resourceTemplates"` +} + +// ReadResourceRequest is sent from the client to the server, to read a +// specific resource URI. +type ReadResourceRequest struct { + Request + Params struct { + // The URI of the resource to read. The URI can use any protocol; it is up + // to the server how to interpret it. + URI string `json:"uri"` + // Arguments to pass to the resource handler + Arguments map[string]interface{} `json:"arguments,omitempty"` + } `json:"params"` +} + +// ReadResourceResult is the server's response to a resources/read request +// from the client. +type ReadResourceResult struct { + Result + Contents []ResourceContents `json:"contents"` // Can be TextResourceContents or BlobResourceContents +} + +// ResourceListChangedNotification is an optional notification from the server +// to the client, informing it that the list of resources it can read from has +// changed. This may be issued by servers without any previous subscription from +// the client. +type ResourceListChangedNotification struct { + Notification +} + +// SubscribeRequest is sent from the client to request resources/updated +// notifications from the server whenever a particular resource changes. +type SubscribeRequest struct { + Request + Params struct { + // The URI of the resource to subscribe to. The URI can use any protocol; it + // is up to the server how to interpret it. + URI string `json:"uri"` + } `json:"params"` +} + +// UnsubscribeRequest is sent from the client to request cancellation of +// resources/updated notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeRequest struct { + Request + Params struct { + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` + } `json:"params"` +} + +// ResourceUpdatedNotification is a notification from the server to the client, +// informing it that a resource has changed and may need to be read again. This +// should only be sent if the client previously sent a resources/subscribe request. +type ResourceUpdatedNotification struct { + Notification + Params struct { + // The URI of the resource that has been updated. This might be a sub- + // resource of the one that the client actually subscribed to. + URI string `json:"uri"` + } `json:"params"` +} + +// Resource represents a known resource that the server is capable of reading. +type Resource struct { + Annotated + // The URI of this resource. + URI string `json:"uri"` + // A human-readable name for this resource. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` +} + +// ResourceTemplate represents a template description for resources available +// on the server. +type ResourceTemplate struct { + Annotated + // A URI template (according to RFC 6570) that can be used to construct + // resource URIs. + URITemplate *URITemplate `json:"uriTemplate"` + // A human-readable name for the type of resource this template refers to. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this template is for. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type for all resources that match this template. This should only + // be included if all resources matching this template have the same type. + MIMEType string `json:"mimeType,omitempty"` +} + +// ResourceContents represents the contents of a specific resource or sub- +// resource. +type ResourceContents interface { + isResourceContents() +} + +type TextResourceContents struct { + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // The text of the item. This must only be set if the item can actually be + // represented as text (not binary data). + Text string `json:"text"` +} + +func (TextResourceContents) isResourceContents() {} + +type BlobResourceContents struct { + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // A base64-encoded string representing the binary data of the item. + Blob string `json:"blob"` +} + +func (BlobResourceContents) isResourceContents() {} + +/* Logging */ + +// SetLevelRequest is a request from the client to the server, to enable or +// adjust logging. +type SetLevelRequest struct { + Request + Params struct { + // The level of logging that the client wants to receive from the server. + // The server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/logging/message. + Level LoggingLevel `json:"level"` + } `json:"params"` +} + +// LoggingMessageNotification is a notification of a log message passed from +// server to client. If no logging/setLevel request has been sent from the client, +// the server MAY decide which messages to send automatically. +type LoggingMessageNotification struct { + Notification + Params struct { + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data interface{} `json:"data"` + } `json:"params"` +} + +// LoggingLevel represents the severity of a log message. +// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +const ( + LoggingLevelDebug LoggingLevel = "debug" + LoggingLevelInfo LoggingLevel = "info" + LoggingLevelNotice LoggingLevel = "notice" + LoggingLevelWarning LoggingLevel = "warning" + LoggingLevelError LoggingLevel = "error" + LoggingLevelCritical LoggingLevel = "critical" + LoggingLevelAlert LoggingLevel = "alert" + LoggingLevelEmergency LoggingLevel = "emergency" +) + +/* Sampling */ + +// CreateMessageRequest is a request from the server to sample an LLM via the +// client. The client has full discretion over which model to select. The client +// should also inform the user before beginning sampling, to allow them to inspect +// the request (human in the loop) and decide whether to approve it. +type CreateMessageRequest struct { + Request + Params struct { + Messages []SamplingMessage `json:"messages"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"maxTokens"` + StopSequences []string `json:"stopSequences,omitempty"` + Metadata interface{} `json:"metadata,omitempty"` + } `json:"params"` +} + +// CreateMessageResult is the client's response to a sampling/create_message +// request from the server. The client should inform the user before returning the +// sampled message, to allow them to inspect the response (human in the loop) and +// decide whether to allow the server to see it. +type CreateMessageResult struct { + Result + SamplingMessage + // The name of the model that generated the message. + Model string `json:"model"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingMessage describes a message issued to or received from an LLM API. +type SamplingMessage struct { + Role Role `json:"role"` + Content interface{} `json:"content"` // Can be TextContent or ImageContent +} + +// Annotated is the base for objects that include optional annotations for the +// client. The client can use annotations to inform how objects are used or +// displayed +type Annotated struct { + Annotations *struct { + // Describes who the intended customer of this object or data is. + // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., `["user", "assistant"]`). + Audience []Role `json:"audience,omitempty"` + + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that + // the data is entirely optional. + Priority float64 `json:"priority,omitempty"` + } `json:"annotations,omitempty"` +} + +type Content interface { + isContent() +} + +// TextContent represents text provided to or from an LLM. +// It must have Type set to "text". +type TextContent struct { + Annotated + Type string `json:"type"` // Must be "text" + // The text content of the message. + Text string `json:"text"` +} + +func (TextContent) isContent() {} + +// ImageContent represents an image provided to or from an LLM. +// It must have Type set to "image". +type ImageContent struct { + Annotated + Type string `json:"type"` // Must be "image" + // The base64-encoded image data. + Data string `json:"data"` + // The MIME type of the image. Different providers may support different image types. + MIMEType string `json:"mimeType"` +} + +func (ImageContent) isContent() {} + +// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. +// +// It is up to the client how best to render embedded resources for the +// benefit of the LLM and/or the user. +type EmbeddedResource struct { + Annotated + Type string `json:"type"` + Resource ResourceContents `json:"resource"` +} + +func (EmbeddedResource) isContent() {} + +// ModelPreferences represents the server's preferences for model selection, +// requested of the client during sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" modelis +// rarely straightforward. Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client MAY ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client MUST evaluate them in order + // (such that the first match is taken). + // + // The client SHOULD prioritize these hints over the numeric priorities, but + // MAY still use the priorities to select from ambiguous matches. + Hints []ModelHint `json:"hints,omitempty"` + + // How much to prioritize cost when selecting a model. A value of 0 means cost + // is not important, while a value of 1 means cost is the most important + // factor. + CostPriority float64 `json:"costPriority,omitempty"` + + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` + + // How much to prioritize intelligence and capabilities when selecting a + // model. A value of 0 means intelligence is not important, while a value of 1 + // means intelligence is the most important factor. + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` +} + +// ModelHint represents hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client SHOULD treat this as a substring of a model name; for example: + // - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` + // - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. + // - `claude` should match any Claude model + // + // The client MAY also map the string to a different provider's model name or + // a different model family, as long as it fills a similar niche; for example: + // - `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +/* Autocomplete */ + +// CompleteRequest is a request from the client to the server, to ask for completion options. +type CompleteRequest struct { + Request + Params struct { + Ref interface{} `json:"ref"` // Can be PromptReference or ResourceReference + Argument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` + } `json:"argument"` + } `json:"params"` +} + +// CompleteResult is the server's response to a completion/complete request +type CompleteResult struct { + Result + Completion struct { + // An array of completion values. Must not exceed 100 items. + Values []string `json:"values"` + // The total number of completion options available. This can exceed the + // number of values actually sent in the response. + Total int `json:"total,omitempty"` + // Indicates whether there are additional completion options beyond those + // provided in the current response, even if the exact total is unknown. + HasMore bool `json:"hasMore,omitempty"` + } `json:"completion"` +} + +// ResourceReference is a reference to a resource or resource template definition. +type ResourceReference struct { + Type string `json:"type"` + // The URI or URI template of the resource. + URI string `json:"uri"` +} + +// PromptReference identifies a prompt. +type PromptReference struct { + Type string `json:"type"` + // The name of the prompt or prompt template + Name string `json:"name"` +} + +/* Roots */ + +// ListRootsRequest is sent from the server to request a list of root URIs from the client. Roots allow +// servers to ask for specific directories or files to operate on. A common example +// for roots is providing a set of repositories or directories a server should operate +// on. +// +// This request is typically used when the server needs to understand the file system +// structure or access specific locations that the client has permission to read from. +type ListRootsRequest struct { + Request +} + +// ListRootsResult is the client's response to a roots/list request from the server. +// This result contains an array of Root objects, each representing a root directory +// or file that the server can operate on. +type ListRootsResult struct { + Result + Roots []Root `json:"roots"` +} + +// Root represents a root directory or file that the server can operate on. +type Root struct { + // The URI identifying the root. This *must* start with file:// for now. + // This restriction may be relaxed in future versions of the protocol to allow + // other URI schemes. + URI string `json:"uri"` + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` +} + +// RootsListChangedNotification is a notification from the client to the +// server, informing it that the list of roots has changed. +// This notification should be sent whenever the client adds, removes, or modifies any root. +// The server should then request an updated list of roots using the ListRootsRequest. +type RootsListChangedNotification struct { + Notification +} + +/* Client messages */ +// ClientRequest represents any request that can be sent from client to server. +type ClientRequest interface{} + +// ClientNotification represents any notification that can be sent from client to server. +type ClientNotification interface{} + +// ClientResult represents any result that can be sent from client to server. +type ClientResult interface{} + +/* Server messages */ +// ServerRequest represents any request that can be sent from server to client. +type ServerRequest interface{} + +// ServerNotification represents any notification that can be sent from server to client. +type ServerNotification interface{} + +// ServerResult represents any result that can be sent from server to client. +type ServerResult interface{} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go new file mode 100644 index 0000000000..236164cbd8 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -0,0 +1,596 @@ +package mcp + +import ( + "encoding/json" + "fmt" +) + +// ClientRequest types +var _ ClientRequest = &PingRequest{} +var _ ClientRequest = &InitializeRequest{} +var _ ClientRequest = &CompleteRequest{} +var _ ClientRequest = &SetLevelRequest{} +var _ ClientRequest = &GetPromptRequest{} +var _ ClientRequest = &ListPromptsRequest{} +var _ ClientRequest = &ListResourcesRequest{} +var _ ClientRequest = &ReadResourceRequest{} +var _ ClientRequest = &SubscribeRequest{} +var _ ClientRequest = &UnsubscribeRequest{} +var _ ClientRequest = &CallToolRequest{} +var _ ClientRequest = &ListToolsRequest{} + +// ClientNotification types +var _ ClientNotification = &CancelledNotification{} +var _ ClientNotification = &ProgressNotification{} +var _ ClientNotification = &InitializedNotification{} +var _ ClientNotification = &RootsListChangedNotification{} + +// ClientResult types +var _ ClientResult = &EmptyResult{} +var _ ClientResult = &CreateMessageResult{} +var _ ClientResult = &ListRootsResult{} + +// ServerRequest types +var _ ServerRequest = &PingRequest{} +var _ ServerRequest = &CreateMessageRequest{} +var _ ServerRequest = &ListRootsRequest{} + +// ServerNotification types +var _ ServerNotification = &CancelledNotification{} +var _ ServerNotification = &ProgressNotification{} +var _ ServerNotification = &LoggingMessageNotification{} +var _ ServerNotification = &ResourceUpdatedNotification{} +var _ ServerNotification = &ResourceListChangedNotification{} +var _ ServerNotification = &ToolListChangedNotification{} +var _ ServerNotification = &PromptListChangedNotification{} + +// ServerResult types +var _ ServerResult = &EmptyResult{} +var _ ServerResult = &InitializeResult{} +var _ ServerResult = &CompleteResult{} +var _ ServerResult = &GetPromptResult{} +var _ ServerResult = &ListPromptsResult{} +var _ ServerResult = &ListResourcesResult{} +var _ ServerResult = &ReadResourceResult{} +var _ ServerResult = &CallToolResult{} +var _ ServerResult = &ListToolsResult{} + +// Helper functions for type assertions + +// asType attempts to cast the given interface to the given type +func asType[T any](content interface{}) (*T, bool) { + tc, ok := content.(T) + if !ok { + return nil, false + } + return &tc, true +} + +// AsTextContent attempts to cast the given interface to TextContent +func AsTextContent(content interface{}) (*TextContent, bool) { + return asType[TextContent](content) +} + +// AsImageContent attempts to cast the given interface to ImageContent +func AsImageContent(content interface{}) (*ImageContent, bool) { + return asType[ImageContent](content) +} + +// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource +func AsEmbeddedResource(content interface{}) (*EmbeddedResource, bool) { + return asType[EmbeddedResource](content) +} + +// AsTextResourceContents attempts to cast the given interface to TextResourceContents +func AsTextResourceContents(content interface{}) (*TextResourceContents, bool) { + return asType[TextResourceContents](content) +} + +// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents +func AsBlobResourceContents(content interface{}) (*BlobResourceContents, bool) { + return asType[BlobResourceContents](content) +} + +// Helper function for JSON-RPC + +// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result +func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message +func NewJSONRPCError( + id RequestId, + code int, + message string, + data interface{}, +) JSONRPCError { + return JSONRPCError{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` + }{ + Code: code, + Message: message, + Data: data, + }, + } +} + +// Helper function for creating a progress notification +func NewProgressNotification( + token ProgressToken, + progress float64, + total *float64, +) ProgressNotification { + notification := ProgressNotification{ + Notification: Notification{ + Method: "notifications/progress", + }, + Params: struct { + ProgressToken ProgressToken `json:"progressToken"` + Progress float64 `json:"progress"` + Total float64 `json:"total,omitempty"` + }{ + ProgressToken: token, + Progress: progress, + }, + } + if total != nil { + notification.Params.Total = *total + } + return notification +} + +// Helper function for creating a logging message notification +func NewLoggingMessageNotification( + level LoggingLevel, + logger string, + data interface{}, +) LoggingMessageNotification { + return LoggingMessageNotification{ + Notification: Notification{ + Method: "notifications/message", + }, + Params: struct { + Level LoggingLevel `json:"level"` + Logger string `json:"logger,omitempty"` + Data interface{} `json:"data"` + }{ + Level: level, + Logger: logger, + Data: data, + }, + } +} + +// Helper function to create a new PromptMessage +func NewPromptMessage(role Role, content Content) PromptMessage { + return PromptMessage{ + Role: role, + Content: content, + } +} + +// Helper function to create a new TextContent +func NewTextContent(text string) TextContent { + return TextContent{ + Type: "text", + Text: text, + } +} + +// Helper function to create a new ImageContent +func NewImageContent(data, mimeType string) ImageContent { + return ImageContent{ + Type: "image", + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new EmbeddedResource +func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { + return EmbeddedResource{ + Type: "resource", + Resource: resource, + } +} + +// NewToolResultText creates a new CallToolResult with a text content +func NewToolResultText(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + }, + } +} + +// NewToolResultImage creates a new CallToolResult with both text and image content +func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + ImageContent{ + Type: "image", + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + +// NewToolResultResource creates a new CallToolResult with an embedded resource +func NewToolResultResource( + text string, + resource ResourceContents, +) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + EmbeddedResource{ + Type: "resource", + Resource: resource, + }, + }, + } +} + +// NewToolResultError creates a new CallToolResult with an error message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultError(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + }, + IsError: true, + } +} + +// NewListResourcesResult creates a new ListResourcesResult +func NewListResourcesResult( + resources []Resource, + nextCursor Cursor, +) *ListResourcesResult { + return &ListResourcesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Resources: resources, + } +} + +// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult +func NewListResourceTemplatesResult( + templates []ResourceTemplate, + nextCursor Cursor, +) *ListResourceTemplatesResult { + return &ListResourceTemplatesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + ResourceTemplates: templates, + } +} + +// NewReadResourceResult creates a new ReadResourceResult with text content +func NewReadResourceResult(text string) *ReadResourceResult { + return &ReadResourceResult{ + Contents: []ResourceContents{ + TextResourceContents{ + Text: text, + }, + }, + } +} + +// NewListPromptsResult creates a new ListPromptsResult +func NewListPromptsResult( + prompts []Prompt, + nextCursor Cursor, +) *ListPromptsResult { + return &ListPromptsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Prompts: prompts, + } +} + +// NewGetPromptResult creates a new GetPromptResult +func NewGetPromptResult( + description string, + messages []PromptMessage, +) *GetPromptResult { + return &GetPromptResult{ + Description: description, + Messages: messages, + } +} + +// NewListToolsResult creates a new ListToolsResult +func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult { + return &ListToolsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Tools: tools, + } +} + +// NewInitializeResult creates a new InitializeResult +func NewInitializeResult( + protocolVersion string, + capabilities ServerCapabilities, + serverInfo Implementation, + instructions string, +) *InitializeResult { + return &InitializeResult{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ServerInfo: serverInfo, + Instructions: instructions, + } +} + +// Helper for formatting numbers in tool results +func FormatNumberResult(value float64) *CallToolResult { + return NewToolResultText(fmt.Sprintf("%.2f", value)) +} + +func ExtractString(data map[string]any, key string) string { + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + return str + } + } + return "" +} + +func ExtractMap(data map[string]any, key string) map[string]any { + if value, ok := data[key]; ok { + if m, ok := value.(map[string]any); ok { + return m + } + } + return nil +} + +func ParseContent(contentMap map[string]any) (Content, error) { + contentType := ExtractString(contentMap, "type") + + switch contentType { + case "text": + text := ExtractString(contentMap, "text") + if text == "" { + return nil, fmt.Errorf("text is missing") + } + return NewTextContent(text), nil + + case "image": + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("image data or mimeType is missing") + } + return NewImageContent(data, mimeType), nil + + case "resource": + resourceMap := ExtractMap(contentMap, "resource") + if resourceMap == nil { + return nil, fmt.Errorf("resource is missing") + } + + resourceContents, err := ParseResourceContents(resourceMap) + if err != nil { + return nil, err + } + + return NewEmbeddedResource(resourceContents), nil + } + + return nil, fmt.Errorf("unsupported content type: %s", contentType) +} + +func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + result := GetPromptResult{} + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + description, ok := jsonContent["description"] + if ok { + if descriptionStr, ok := description.(string); ok { + result.Description = descriptionStr + } + } + + messages, ok := jsonContent["messages"] + if ok { + messagesArr, ok := messages.([]any) + if !ok { + return nil, fmt.Errorf("messages is not an array") + } + + for _, message := range messagesArr { + messageMap, ok := message.(map[string]any) + if !ok { + return nil, fmt.Errorf("message is not an object") + } + + // Extract role + roleStr := ExtractString(messageMap, "role") + if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) { + return nil, fmt.Errorf("unsupported role: %s", roleStr) + } + + // Extract content + contentMap, ok := messageMap["content"].(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + // Append processed message + result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content)) + + } + } + + return &result, nil +} + +func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result CallToolResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + isError, ok := jsonContent["isError"] + if ok { + if isErrorBool, ok := isError.(bool); ok { + result.IsError = isErrorBool + } + } + + contents, ok := jsonContent["content"] + if !ok { + return nil, fmt.Errorf("content is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("content is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + result.Content = append(result.Content, content) + } + + return &result, nil +} + +func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) { + uri := ExtractString(contentMap, "uri") + if uri == "" { + return nil, fmt.Errorf("resource uri is missing") + } + + mimeType := ExtractString(contentMap, "mimeType") + + if text := ExtractString(contentMap, "text"); text != "" { + return TextResourceContents{ + URI: uri, + MIMEType: mimeType, + Text: text, + }, nil + } + + if blob := ExtractString(contentMap, "blob"); blob != "" { + return BlobResourceContents{ + URI: uri, + MIMEType: mimeType, + Blob: blob, + }, nil + } + + return nil, fmt.Errorf("unsupported resource type") +} + +func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result ReadResourceResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + contents, ok := jsonContent["contents"] + if !ok { + return nil, fmt.Errorf("contents is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("contents is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseResourceContents(contentMap) + if err != nil { + return nil, err + } + + result.Contents = append(result.Contents, content) + } + + return &result, nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/hooks.go b/vendor/github.com/mark3labs/mcp-go/server/hooks.go new file mode 100644 index 0000000000..ce976a6cdb --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/hooks.go @@ -0,0 +1,461 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/hooks.go.tmpl +package server + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. +type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) + +// BeforeAnyHookFunc is a function that is called after the request is +// parsed but before the method is called. +type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) + +// OnSuccessHookFunc is a hook that will be called after the request +// successfully generates a result, but before the result is sent to the client. +type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) + +// OnErrorHookFunc is a hook that will be called when an error occurs, +// either during the request parsing or the method execution. +// +// Example usage: +// ``` +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // Check for specific error types using errors.Is +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported errors +// log.Printf("Capability not supported: %v", err) +// } +// +// // Use errors.As to get specific error types +// var parseErr = &UnparseableMessageError{} +// if errors.As(err, &parseErr) { +// // Access specific methods/fields of the error type +// log.Printf("Failed to parse message for method %s: %v", +// parseErr.GetMethod(), parseErr.Unwrap()) +// // Access the raw message that failed to parse +// rawMsg := parseErr.GetMessage() +// } +// +// // Check for specific resource/prompt/tool errors +// switch { +// case errors.Is(err, ErrResourceNotFound): +// log.Printf("Resource not found: %v", err) +// case errors.Is(err, ErrPromptNotFound): +// log.Printf("Prompt not found: %v", err) +// case errors.Is(err, ErrToolNotFound): +// log.Printf("Tool not found: %v", err) +// } +// }) +type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) + +type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest) +type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) + +type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) +type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) + +type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) +type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) + +type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) +type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) + +type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest) +type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) + +type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest) +type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) + +type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest) +type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) + +type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest) +type OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) + +type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest) +type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) + +type Hooks struct { + OnRegisterSession []OnRegisterSessionHookFunc + OnBeforeAny []BeforeAnyHookFunc + OnSuccess []OnSuccessHookFunc + OnError []OnErrorHookFunc + OnBeforeInitialize []OnBeforeInitializeFunc + OnAfterInitialize []OnAfterInitializeFunc + OnBeforePing []OnBeforePingFunc + OnAfterPing []OnAfterPingFunc + OnBeforeListResources []OnBeforeListResourcesFunc + OnAfterListResources []OnAfterListResourcesFunc + OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc + OnAfterListResourceTemplates []OnAfterListResourceTemplatesFunc + OnBeforeReadResource []OnBeforeReadResourceFunc + OnAfterReadResource []OnAfterReadResourceFunc + OnBeforeListPrompts []OnBeforeListPromptsFunc + OnAfterListPrompts []OnAfterListPromptsFunc + OnBeforeGetPrompt []OnBeforeGetPromptFunc + OnAfterGetPrompt []OnAfterGetPromptFunc + OnBeforeListTools []OnBeforeListToolsFunc + OnAfterListTools []OnAfterListToolsFunc + OnBeforeCallTool []OnBeforeCallToolFunc + OnAfterCallTool []OnAfterCallToolFunc +} + +func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { + c.OnBeforeAny = append(c.OnBeforeAny, hook) +} + +func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { + c.OnSuccess = append(c.OnSuccess, hook) +} + +// AddOnError registers a hook function that will be called when an error occurs. +// The error parameter contains the actual error object, which can be interrogated +// using Go's error handling patterns like errors.Is and errors.As. +// +// Example: +// ``` +// // Create a channel to receive errors for testing +// errChan := make(chan error, 1) +// +// // Register hook to capture and inspect errors +// hooks := &Hooks{} +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // For capability-related errors +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported +// errChan <- err +// return +// } +// +// // For parsing errors +// var parseErr = &UnparseableMessageError{} +// if errors.As(err, &parseErr) { +// // Handle unparseable message errors +// fmt.Printf("Failed to parse %s request: %v\n", +// parseErr.GetMethod(), parseErr.Unwrap()) +// errChan <- parseErr +// return +// } +// +// // For resource/prompt/tool not found errors +// if errors.Is(err, ErrResourceNotFound) || +// errors.Is(err, ErrPromptNotFound) || +// errors.Is(err, ErrToolNotFound) { +// // Handle not found errors +// errChan <- err +// return +// } +// +// // For other errors +// errChan <- err +// }) +// +// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) +// ``` +func (c *Hooks) AddOnError(hook OnErrorHookFunc) { + c.OnError = append(c.OnError, hook) +} + +func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { + if c == nil { + return + } + for _, hook := range c.OnBeforeAny { + hook(ctx, id, method, message) + } +} + +func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { + if c == nil { + return + } + for _, hook := range c.OnSuccess { + hook(ctx, id, method, message, result) + } +} + +// onError calls all registered error hooks with the error object. +// The err parameter contains the actual error that occurred, which implements +// the standard error interface and may be a wrapped error or custom error type. +// +// This allows consumer code to use Go's error handling patterns: +// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors +// - errors.As(err, &customErr) to extract custom error types +// +// Common error types include: +// - ErrUnsupported: When a capability is not enabled +// - UnparseableMessageError: When request parsing fails +// - ErrResourceNotFound: When a resource is not found +// - ErrPromptNotFound: When a prompt is not found +// - ErrToolNotFound: When a tool is not found +func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + if c == nil { + return + } + for _, hook := range c.OnError { + hook(ctx, id, method, message, err) + } +} + +func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { + c.OnRegisterSession = append(c.OnRegisterSession, hook) +} + +func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnRegisterSession { + hook(ctx, session) + } +} +func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { + c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) +} + +func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) { + c.OnAfterInitialize = append(c.OnAfterInitialize, hook) +} + +func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) { + c.beforeAny(ctx, id, mcp.MethodInitialize, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeInitialize { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + c.onSuccess(ctx, id, mcp.MethodInitialize, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterInitialize { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) { + c.OnBeforePing = append(c.OnBeforePing, hook) +} + +func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) { + c.OnAfterPing = append(c.OnAfterPing, hook) +} + +func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) { + c.beforeAny(ctx, id, mcp.MethodPing, message) + if c == nil { + return + } + for _, hook := range c.OnBeforePing { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodPing, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterPing { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { + c.OnBeforeListResources = append(c.OnBeforeListResources, hook) +} + +func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) { + c.OnAfterListResources = append(c.OnAfterListResources, hook) +} + +func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResources { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResources { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) { + c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook) +} + +func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) { + c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook) +} + +func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResourceTemplates { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResourceTemplates { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) { + c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook) +} + +func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) { + c.OnAfterReadResource = append(c.OnAfterReadResource, hook) +} + +func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesRead, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeReadResource { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterReadResource { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) { + c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook) +} + +func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) { + c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook) +} + +func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListPrompts { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListPrompts { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) { + c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook) +} + +func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) { + c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook) +} + +func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsGet, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeGetPrompt { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterGetPrompt { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) { + c.OnBeforeListTools = append(c.OnBeforeListTools, hook) +} + +func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) { + c.OnAfterListTools = append(c.OnAfterListTools, hook) +} + +func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListTools { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { + c.onSuccess(ctx, id, mcp.MethodToolsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListTools { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) { + c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook) +} + +func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) { + c.OnAfterCallTool = append(c.OnAfterCallTool, hook) +} + +func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsCall, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeCallTool { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + c.onSuccess(ctx, id, mcp.MethodToolsCall, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterCallTool { + hook(ctx, id, message, result) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go new file mode 100644 index 0000000000..946ca7abd3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go @@ -0,0 +1,279 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/request_handler.go.tmpl +package server + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response +func (s *MCPServer) HandleMessage( + ctx context.Context, + message json.RawMessage, +) mcp.JSONRPCMessage { + // Add server to context + ctx = context.WithValue(ctx, serverKey{}, s) + var err *requestError + + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + ID any `json:"id,omitempty"` + } + + if err := json.Unmarshal(message, &baseMessage); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse message", + ) + } + + // Check for valid JSONRPC version + if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + "Invalid JSON-RPC version", + ) + } + + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal(message, ¬ification); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse notification", + ) + } + s.handleNotification(ctx, notification) + return nil // Return nil for notifications + } + + switch baseMessage.Method { + case mcp.MethodInitialize: + var request mcp.InitializeRequest + var result *mcp.InitializeResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) + result, err = s.handleInitialize(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPing: + var request mcp.PingRequest + var result *mcp.EmptyResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforePing(ctx, baseMessage.ID, &request) + result, err = s.handlePing(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterPing(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesList: + var request mcp.ListResourcesRequest + var result *mcp.ListResourcesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListResources(ctx, baseMessage.ID, &request) + result, err = s.handleListResources(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResources(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesTemplatesList: + var request mcp.ListResourceTemplatesRequest + var result *mcp.ListResourceTemplatesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) + result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesRead: + var request mcp.ReadResourceRequest + var result *mcp.ReadResourceResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) + result, err = s.handleReadResource(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsList: + var request mcp.ListPromptsRequest + var result *mcp.ListPromptsResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) + result, err = s.handleListPrompts(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsGet: + var request mcp.GetPromptRequest + var result *mcp.GetPromptResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) + result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsList: + var request mcp.ListToolsRequest + var result *mcp.ListToolsResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListTools(ctx, baseMessage.ID, &request) + result, err = s.handleListTools(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListTools(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsCall: + var request mcp.CallToolRequest + var result *mcp.CallToolResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) + result, err = s.handleToolCall(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + default: + return createErrorResponse( + baseMessage.ID, + mcp.METHOD_NOT_FOUND, + fmt.Sprintf("Method %s not found", baseMessage.Method), + ) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/server.go b/vendor/github.com/mark3labs/mcp-go/server/server.go new file mode 100644 index 0000000000..ec4fcef006 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/server.go @@ -0,0 +1,768 @@ +// Package server provides MCP (Model Control Protocol) server implementations. +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// resourceEntry holds both a resource and its handler +type resourceEntry struct { + resource mcp.Resource + handler ResourceHandlerFunc +} + +// resourceTemplateEntry holds both a template and its handler +type resourceTemplateEntry struct { + template mcp.ResourceTemplate + handler ResourceTemplateHandlerFunc +} + +// ServerOption is a function that configures an MCPServer. +type ServerOption func(*MCPServer) + +// ResourceHandlerFunc is a function that returns resource contents. +type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// ResourceTemplateHandlerFunc is a function that returns a resource template. +type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// PromptHandlerFunc handles prompt requests with given arguments. +type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) + +// ToolHandlerFunc handles tool calls with given arguments. +type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// ServerTool combines a Tool with its ToolHandlerFunc. +type ServerTool struct { + Tool mcp.Tool + Handler ToolHandlerFunc +} + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// UnparseableMessageError is attached to the RequestError when json.Unmarshal +// fails on the request. +type UnparseableMessageError struct { + message json.RawMessage + method mcp.MCPMethod + err error +} + +func (e *UnparseableMessageError) Error() string { + return fmt.Sprintf("unparseable %s request: %s", e.method, e.err) +} + +func (e *UnparseableMessageError) Unwrap() error { + return e.err +} + +func (e *UnparseableMessageError) GetMessage() json.RawMessage { + return e.message +} + +func (e *UnparseableMessageError) GetMethod() mcp.MCPMethod { + return e.method +} + +// RequestError is an error that can be converted to a JSON-RPC error. +// Implements Unwrap() to allow inspecting the error chain. +type requestError struct { + id any + code int + err error +} + +func (e *requestError) Error() string { + return fmt.Sprintf("request error: %s", e.err) +} + +func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: e.id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + }{ + Code: e.code, + Message: e.err.Error(), + }, + } +} + +func (e *requestError) Unwrap() error { + return e.err +} + +var ( + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") +) + +// NotificationHandlerFunc handles incoming notifications. +type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) + +// MCPServer implements a Model Control Protocol server that can handle various types of requests +// including resources, prompts, and tools. +type MCPServer struct { + mu sync.RWMutex // Add mutex for protecting shared resources + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + sessions sync.Map + hooks *Hooks +} + +// serverKey is the context key for storing the server instance +type serverKey struct{} + +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return fmt.Errorf("session %s is already registered", sessionID) + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + sessionID string, +) { + s.sessions.Delete(sessionID) +} + +// sendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) sendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + select { + case session.NotificationChannel() <- notification: + default: + // TODO: log blocked channel in the future versions + } + } + return true + }) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return fmt.Errorf("notification channel not initialized") + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + return fmt.Errorf("notification channel full or blocked") + } +} + +// serverCapabilities defines the supported features of the MCP server +type serverCapabilities struct { + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + logging bool +} + +// resourceCapabilities defines the supported resource-related features +type resourceCapabilities struct { + subscribe bool + listChanged bool +} + +// promptCapabilities defines the supported prompt-related features +type promptCapabilities struct { + listChanged bool +} + +// toolCapabilities defines the supported tool-related features +type toolCapabilities struct { + listChanged bool +} + +// WithResourceCapabilities configures resource-related server capabilities +func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.resources = &resourceCapabilities{ + subscribe: subscribe, + listChanged: listChanged, + } + } +} + +// WithHooks allows adding hooks that will be called before or after +// either [all] requests or before / after specific request methods, or else +// prior to returning an error to the client. +func WithHooks(hooks *Hooks) ServerOption { + return func(s *MCPServer) { + s.hooks = hooks + } +} + +// WithPromptCapabilities configures prompt-related server capabilities +func WithPromptCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.prompts = &promptCapabilities{ + listChanged: listChanged, + } + } +} + +// WithToolCapabilities configures tool-related server capabilities +func WithToolCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.tools = &toolCapabilities{ + listChanged: listChanged, + } + } +} + +// WithLogging enables logging capabilities for the server +func WithLogging() ServerOption { + return func(s *MCPServer) { + s.capabilities.logging = true + } +} + +// WithInstructions sets the server instructions for the client returned in the initialize response +func WithInstructions(instructions string) ServerOption { + return func(s *MCPServer) { + s.instructions = instructions + } +} + +// NewMCPServer creates a new MCP server instance with the given name, version and options +func NewMCPServer( + name, version string, + opts ...ServerOption, +) *MCPServer { + s := &MCPServer{ + resources: make(map[string]resourceEntry), + resourceTemplates: make(map[string]resourceTemplateEntry), + prompts: make(map[string]mcp.Prompt), + promptHandlers: make(map[string]PromptHandlerFunc), + tools: make(map[string]ServerTool), + name: name, + version: version, + notificationHandlers: make(map[string]NotificationHandlerFunc), + capabilities: serverCapabilities{ + tools: nil, + resources: nil, + prompts: nil, + logging: false, + }, + } + + for _, opt := range opts { + opt(s) + } + + return s +} + +// AddResource registers a new resource and its handler +func (s *MCPServer) AddResource( + resource mcp.Resource, + handler ResourceHandlerFunc, +) { + if s.capabilities.resources == nil { + s.capabilities.resources = &resourceCapabilities{} + } + s.mu.Lock() + defer s.mu.Unlock() + s.resources[resource.URI] = resourceEntry{ + resource: resource, + handler: handler, + } +} + +// AddResourceTemplate registers a new resource template and its handler +func (s *MCPServer) AddResourceTemplate( + template mcp.ResourceTemplate, + handler ResourceTemplateHandlerFunc, +) { + if s.capabilities.resources == nil { + s.capabilities.resources = &resourceCapabilities{} + } + s.mu.Lock() + defer s.mu.Unlock() + s.resourceTemplates[template.URITemplate.Raw()] = resourceTemplateEntry{ + template: template, + handler: handler, + } +} + +// AddPrompt registers a new prompt handler with the given name +func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { + if s.capabilities.prompts == nil { + s.capabilities.prompts = &promptCapabilities{} + } + s.mu.Lock() + defer s.mu.Unlock() + s.prompts[prompt.Name] = prompt + s.promptHandlers[prompt.Name] = handler +} + +// AddTool registers a new tool and its handler +func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { + s.AddTools(ServerTool{Tool: tool, Handler: handler}) +} + +// AddTools registers multiple tools at once +func (s *MCPServer) AddTools(tools ...ServerTool) { + if s.capabilities.tools == nil { + s.capabilities.tools = &toolCapabilities{} + } + s.mu.Lock() + for _, entry := range tools { + s.tools[entry.Tool.Name] = entry + } + s.mu.Unlock() + + // Send notification to all initialized sessions + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) +} + +// SetTools replaces all existing tools with the provided list +func (s *MCPServer) SetTools(tools ...ServerTool) { + s.mu.Lock() + s.tools = make(map[string]ServerTool) + s.mu.Unlock() + s.AddTools(tools...) +} + +// DeleteTools removes a tool from the server +func (s *MCPServer) DeleteTools(names ...string) { + s.mu.Lock() + for _, name := range names { + delete(s.tools, name) + } + s.mu.Unlock() + + // Send notification to all initialized sessions + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) +} + +// AddNotificationHandler registers a new handler for incoming notifications +func (s *MCPServer) AddNotificationHandler( + method string, + handler NotificationHandlerFunc, +) { + s.mu.Lock() + defer s.mu.Unlock() + s.notificationHandlers[method] = handler +} + +func (s *MCPServer) handleInitialize( + ctx context.Context, + id interface{}, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, *requestError) { + capabilities := mcp.ServerCapabilities{} + + // Only add resource capabilities if they're configured + if s.capabilities.resources != nil { + capabilities.Resources = &struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` + }{ + Subscribe: s.capabilities.resources.subscribe, + ListChanged: s.capabilities.resources.listChanged, + } + } + + // Only add prompt capabilities if they're configured + if s.capabilities.prompts != nil { + capabilities.Prompts = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.prompts.listChanged, + } + } + + // Only add tool capabilities if they're configured + if s.capabilities.tools != nil { + capabilities.Tools = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.tools.listChanged, + } + } + + if s.capabilities.logging { + capabilities.Logging = &struct{}{} + } + + result := mcp.InitializeResult{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ServerInfo: mcp.Implementation{ + Name: s.name, + Version: s.version, + }, + Capabilities: capabilities, + Instructions: s.instructions, + } + + if session := ClientSessionFromContext(ctx); session != nil { + session.Initialize() + } + return &result, nil +} + +func (s *MCPServer) handlePing( + ctx context.Context, + id interface{}, + request mcp.PingRequest, +) (*mcp.EmptyResult, *requestError) { + return &mcp.EmptyResult{}, nil +} + +func (s *MCPServer) handleListResources( + ctx context.Context, + id interface{}, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, *requestError) { + s.mu.RLock() + resources := make([]mcp.Resource, 0, len(s.resources)) + for _, entry := range s.resources { + resources = append(resources, entry.resource) + } + s.mu.RUnlock() + + result := mcp.ListResourcesResult{ + Resources: resources, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} + +func (s *MCPServer) handleListResourceTemplates( + ctx context.Context, + id interface{}, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, *requestError) { + s.mu.RLock() + templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) + for _, entry := range s.resourceTemplates { + templates = append(templates, entry.template) + } + s.mu.RUnlock() + + result := mcp.ListResourceTemplatesResult{ + ResourceTemplates: templates, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} + +func (s *MCPServer) handleReadResource( + ctx context.Context, + id interface{}, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, *requestError) { + s.mu.RLock() + // First try direct resource handlers + if entry, ok := s.resources[request.Params.URI]; ok { + handler := entry.handler + s.mu.RUnlock() + contents, err := handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + // If no direct handler found, try matching against templates + var matchedHandler ResourceTemplateHandlerFunc + var matched bool + for _, entry := range s.resourceTemplates { + template := entry.template + if matchesTemplate(request.Params.URI, template.URITemplate) { + matchedHandler = entry.handler + matched = true + matchedVars := template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]interface{}) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break + } + } + s.mu.RUnlock() + + if matched { + contents, err := matchedHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("handler not found for resource URI '%s': %w", request.Params.URI, ErrResourceNotFound), + } +} + +// matchesTemplate checks if a URI matches a URI template pattern +func matchesTemplate(uri string, template *mcp.URITemplate) bool { + return template.Regexp().MatchString(uri) +} + +func (s *MCPServer) handleListPrompts( + ctx context.Context, + id interface{}, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, *requestError) { + s.mu.RLock() + prompts := make([]mcp.Prompt, 0, len(s.prompts)) + for _, prompt := range s.prompts { + prompts = append(prompts, prompt) + } + s.mu.RUnlock() + + result := mcp.ListPromptsResult{ + Prompts: prompts, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} + +func (s *MCPServer) handleGetPrompt( + ctx context.Context, + id interface{}, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, *requestError) { + s.mu.RLock() + handler, ok := s.promptHandlers[request.Params.Name] + s.mu.RUnlock() + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound), + } + } + + result, err := handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleListTools( + ctx context.Context, + id interface{}, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, *requestError) { + s.mu.RLock() + tools := make([]mcp.Tool, 0, len(s.tools)) + + // Get all tool names for consistent ordering + toolNames := make([]string, 0, len(s.tools)) + for name := range s.tools { + toolNames = append(toolNames, name) + } + + // Sort the tool names for consistent ordering + sort.Strings(toolNames) + + // Add tools in sorted order + for _, name := range toolNames { + tools = append(tools, s.tools[name].Tool) + } + s.mu.RUnlock() + + result := mcp.ListToolsResult{ + Tools: tools, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} +func (s *MCPServer) handleToolCall( + ctx context.Context, + id interface{}, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, *requestError) { + s.mu.RLock() + tool, ok := s.tools[request.Params.Name] + s.mu.RUnlock() + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("tool '%s' not found: %w", request.Params.Name, ErrToolNotFound), + } + } + + result, err := tool.Handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) mcp.JSONRPCMessage { + s.mu.RLock() + handler, ok := s.notificationHandlers[notification.Method] + s.mu.RUnlock() + + if ok { + handler(ctx, notification) + } + return nil +} + +func createResponse(id interface{}, result interface{}) mcp.JSONRPCMessage { + return mcp.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +func createErrorResponse( + id interface{}, + code int, + message string, +) mcp.JSONRPCMessage { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` + }{ + Code: code, + Message: message, + }, + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sse.go b/vendor/github.com/mark3labs/mcp-go/server/sse.go new file mode 100644 index 0000000000..6e6a13fe78 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sse.go @@ -0,0 +1,433 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "sync/atomic" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" +) + +// sseSession represents an active SSE connection. +type sseSession struct { + writer http.ResponseWriter + flusher http.Flusher + done chan struct{} + eventQueue chan string // Channel for queuing events + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool +} + +// SSEContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context + +func (s *sseSession) SessionID() string { + return s.sessionID +} + +func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *sseSession) Initialize() { + s.initialized.Store(true) +} + +func (s *sseSession) Initialized() bool { + return s.initialized.Load() +} + +var _ ClientSession = (*sseSession)(nil) + +// SSEServer implements a Server-Sent Events (SSE) based MCP server. +// It provides real-time communication capabilities over HTTP using the SSE protocol. +type SSEServer struct { + server *MCPServer + baseURL string + basePath string + messageEndpoint string + useFullURLForMessageEndpoint bool + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc +} + +// SSEOption defines a function type for configuring SSEServer +type SSEOption func(*SSEServer) + +// WithBaseURL sets the base URL for the SSE server +func WithBaseURL(baseURL string) SSEOption { + return func(s *SSEServer) { + if baseURL != "" { + u, err := url.Parse(baseURL) + if err != nil { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + // Check if the host is empty or only contains a port + if u.Host == "" || strings.HasPrefix(u.Host, ":") { + return + } + if len(u.Query()) > 0 { + return + } + } + s.baseURL = strings.TrimSuffix(baseURL, "/") + } +} + +// Add a new option for setting base path +func WithBasePath(basePath string) SSEOption { + return func(s *SSEServer) { + // Ensure the path starts with / and doesn't end with / + if !strings.HasPrefix(basePath, "/") { + basePath = "/" + basePath + } + s.basePath = strings.TrimSuffix(basePath, "/") + } +} + +// WithMessageEndpoint sets the message endpoint path +func WithMessageEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.messageEndpoint = endpoint + } +} + +// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) +// or just the path portion for the message endpoint. Set to false when clients will concatenate +// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". +func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { + return func(s *SSEServer) { + s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint + } +} + +// WithSSEEndpoint sets the SSE endpoint path +func WithSSEEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.sseEndpoint = endpoint + } +} + +// WithHTTPServer sets the HTTP server instance +func WithHTTPServer(srv *http.Server) SSEOption { + return func(s *SSEServer) { + s.srv = srv + } +} + +// WithContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +func WithSSEContextFunc(fn SSEContextFunc) SSEOption { + return func(s *SSEServer) { + s.contextFunc = fn + } +} + +// NewSSEServer creates a new SSE server instance with the given MCP server and options. +func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { + s := &SSEServer{ + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + + return s +} + +// NewTestServer creates a test server for testing purposes +func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { + sseServer := NewSSEServer(server) + for _, opt := range opts { + opt(sseServer) + } + + testServer := httptest.NewServer(sseServer) + sseServer.baseURL = testServer.URL + return testServer +} + +// Start begins serving SSE connections on the specified address. +// It sets up HTTP handlers for SSE and message endpoints. +func (s *SSEServer) Start(addr string) error { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + + return s.srv.ListenAndServe() +} + +// Shutdown gracefully stops the SSE server, closing all active sessions +// and shutting down the HTTP server. +func (s *SSEServer) Shutdown(ctx context.Context) error { + if s.srv != nil { + s.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*sseSession); ok { + close(session.done) + } + s.sessions.Delete(key) + return true + }) + + return s.srv.Shutdown(ctx) + } + return nil +} + +// handleSSE handles incoming SSE connection requests. +// It sets up appropriate headers and creates a new session for the client. +func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + sessionID := uuid.New().String() + session := &sseSession{ + writer: w, + flusher: flusher, + done: make(chan struct{}), + eventQueue: make(chan string, 100), // Buffer for events + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + } + + s.sessions.Store(sessionID, session) + defer s.sessions.Delete(sessionID) + + if err := s.server.RegisterSession(r.Context(), session); err != nil { + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) + return + } + defer s.server.UnregisterSession(sessionID) + + // Start notification handler for this session + go func() { + for { + select { + case notification := <-session.notificationChannel: + eventData, err := json.Marshal(notification) + if err == nil { + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + return + } + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + + // Send the initial endpoint event + fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) + flusher.Flush() + + // Main event loop - this runs in the HTTP handler goroutine + for { + select { + case event := <-session.eventQueue: + // Write the event to the response + fmt.Fprint(w, event) + flusher.Flush() + case <-r.Context().Done(): + close(session.done) + return + } + } +} + +// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID +// based on the useFullURLForMessageEndpoint configuration. +func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string { + messageEndpoint := s.messageEndpoint + if s.useFullURLForMessageEndpoint { + messageEndpoint = s.CompleteMessageEndpoint() + } + return fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) +} + +// handleMessage processes incoming JSON-RPC messages from clients and sends responses +// back through both the SSE connection and HTTP response. +func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") + return + } + + sessionID := r.URL.Query().Get("sessionId") + if sessionID == "" { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") + return + } + + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") + return + } + session := sessionI.(*sseSession) + + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Parse message as raw JSON + var rawMessage json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") + return + } + + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawMessage) + + // Only send response if there is one (not for notifications) + if response != nil { + eventData, _ := json.Marshal(response) + + // Queue the event for sending via SSE + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + // Session is closed, don't try to queue + default: + // Queue is full, could log this + } + + // Send HTTP response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(response) + } else { + // For notifications, just send 202 Accepted with no body + w.WriteHeader(http.StatusAccepted) + } +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. +func (s *SSEServer) writeJSONRPCError( + w http.ResponseWriter, + id interface{}, + code int, + message string, +) { + response := createErrorResponse(id, code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(response) +} + +// SendEventToSession sends an event to a specific SSE session identified by sessionID. +// Returns an error if the session is not found or closed. +func (s *SSEServer) SendEventToSession( + sessionID string, + event interface{}, +) error { + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + return fmt.Errorf("session not found: %s", sessionID) + } + session := sessionI.(*sseSession) + + eventData, err := json.Marshal(event) + if err != nil { + return err + } + + // Queue the event for sending via SSE + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + return nil + case <-session.done: + return fmt.Errorf("session closed") + default: + return fmt.Errorf("event queue full") + } +} +func (s *SSEServer) GetUrlPath(input string) (string, error) { + parse, err := url.Parse(input) + if err != nil { + return "", fmt.Errorf("failed to parse URL %s: %w", input, err) + } + return parse.Path, nil +} + +func (s *SSEServer) CompleteSseEndpoint() string { + return s.baseURL + s.basePath + s.sseEndpoint +} +func (s *SSEServer) CompleteSsePath() string { + path, err := s.GetUrlPath(s.CompleteSseEndpoint()) + if err != nil { + return s.basePath + s.sseEndpoint + } + return path +} + +func (s *SSEServer) CompleteMessageEndpoint() string { + return s.baseURL + s.basePath + s.messageEndpoint +} +func (s *SSEServer) CompleteMessagePath() string { + path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) + if err != nil { + return s.basePath + s.messageEndpoint + } + return path +} + +// ServeHTTP implements the http.Handler interface. +func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + // Use exact path matching rather than Contains + ssePath := s.CompleteSsePath() + if ssePath != "" && path == ssePath { + s.handleSSE(w, r) + return + } + messagePath := s.CompleteMessagePath() + if messagePath != "" && path == messagePath { + s.handleMessage(w, r) + return + } + + http.NotFound(w, r) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go new file mode 100644 index 0000000000..14c1e76e9a --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -0,0 +1,283 @@ +package server + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/signal" + "sync/atomic" + "syscall" + + "github.com/mark3labs/mcp-go/mcp" +) + +// StdioContextFunc is a function that takes an existing context and returns +// a potentially modified context. +// This can be used to inject context values from environment variables, +// for example. +type StdioContextFunc func(ctx context.Context) context.Context + +// StdioServer wraps a MCPServer and handles stdio communication. +// It provides a simple way to create command-line MCP servers that +// communicate via standard input/output streams using JSON-RPC messages. +type StdioServer struct { + server *MCPServer + errLogger *log.Logger + contextFunc StdioContextFunc +} + +// StdioOption defines a function type for configuring StdioServer +type StdioOption func(*StdioServer) + +// WithErrorLogger sets the error logger for the server +func WithErrorLogger(logger *log.Logger) StdioOption { + return func(s *StdioServer) { + s.errLogger = logger + } +} + +// WithContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func WithStdioContextFunc(fn StdioContextFunc) StdioOption { + return func(s *StdioServer) { + s.contextFunc = fn + } +} + +// stdioSession is a static client session, since stdio has only one client. +type stdioSession struct { + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool +} + +func (s *stdioSession) SessionID() string { + return "stdio" +} + +func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +func (s *stdioSession) Initialize() { + s.initialized.Store(true) +} + +func (s *stdioSession) Initialized() bool { + return s.initialized.Load() +} + +var _ ClientSession = (*stdioSession)(nil) + +var stdioSessionInstance = stdioSession{ + notifications: make(chan mcp.JSONRPCNotification, 100), +} + +// NewStdioServer creates a new stdio server wrapper around an MCPServer. +// It initializes the server with a default error logger that discards all output. +func NewStdioServer(server *MCPServer) *StdioServer { + return &StdioServer{ + server: server, + errLogger: log.New( + os.Stderr, + "", + log.LstdFlags, + ), // Default to discarding logs + } +} + +// SetErrorLogger configures where error messages from the StdioServer are logged. +// The provided logger will receive all error messages generated during server operation. +func (s *StdioServer) SetErrorLogger(logger *log.Logger) { + s.errLogger = logger +} + +// SetContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { + s.contextFunc = fn +} + +// handleNotifications continuously processes notifications from the session's notification channel +// and writes them to the provided output. It runs until the context is cancelled. +// Any errors encountered while writing notifications are logged but do not stop the handler. +func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) { + for { + select { + case notification := <-stdioSessionInstance.notifications: + if err := s.writeResponse(notification, stdout); err != nil { + s.errLogger.Printf("Error writing notification: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +// processInputStream continuously reads and processes messages from the input stream. +// It handles EOF gracefully as a normal termination condition. +// The function returns when either: +// - The context is cancelled (returns context.Err()) +// - EOF is encountered (returns nil) +// - An error occurs while reading or processing messages (returns the error) +func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error { + for { + if err := ctx.Err(); err != nil { + return err + } + + line, err := s.readNextLine(ctx, reader) + if err != nil { + if err == io.EOF { + return nil + } + s.errLogger.Printf("Error reading input: %v", err) + return err + } + + if err := s.processMessage(ctx, line, stdout); err != nil { + if err == io.EOF { + return nil + } + s.errLogger.Printf("Error handling message: %v", err) + return err + } + } +} + +// readNextLine reads a single line from the input reader in a context-aware manner. +// It uses channels to make the read operation cancellable via context. +// Returns the read line and any error encountered. If the context is cancelled, +// returns an empty string and the context's error. EOF is returned when the input +// stream is closed. +func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { + readChan := make(chan string, 1) + errChan := make(chan error, 1) + defer func() { + close(readChan) + close(errChan) + }() + + go func() { + line, err := reader.ReadString('\n') + if err != nil { + errChan <- err + return + } + readChan <- line + }() + + select { + case <-ctx.Done(): + return "", ctx.Err() + case err := <-errChan: + return "", err + case line := <-readChan: + return line, nil + } +} + +// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output. +// It runs until the context is cancelled or an error occurs. +// Returns an error if there are issues with reading input or writing output. +func (s *StdioServer) Listen( + ctx context.Context, + stdin io.Reader, + stdout io.Writer, +) error { + // Set a static client context since stdio only has one client + if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { + return fmt.Errorf("register session: %w", err) + } + defer s.server.UnregisterSession(stdioSessionInstance.SessionID()) + ctx = s.server.WithContext(ctx, &stdioSessionInstance) + + // Add in any custom context. + if s.contextFunc != nil { + ctx = s.contextFunc(ctx) + } + + reader := bufio.NewReader(stdin) + + // Start notification handler + go s.handleNotifications(ctx, stdout) + return s.processInputStream(ctx, reader, stdout) +} + +// processMessage handles a single JSON-RPC message and writes the response. +// It parses the message, processes it through the wrapped MCPServer, and writes any response. +// Returns an error if there are issues with message processing or response writing. +func (s *StdioServer) processMessage( + ctx context.Context, + line string, + writer io.Writer, +) error { + // Parse the message as raw JSON + var rawMessage json.RawMessage + if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { + response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error") + return s.writeResponse(response, writer) + } + + // Handle the message using the wrapped server + response := s.server.HandleMessage(ctx, rawMessage) + + // Only write response if there is one (not for notifications) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + } + + return nil +} + +// writeResponse marshals and writes a JSON-RPC response message followed by a newline. +// Returns an error if marshaling or writing fails. +func (s *StdioServer) writeResponse( + response mcp.JSONRPCMessage, + writer io.Writer, +) error { + responseBytes, err := json.Marshal(response) + if err != nil { + return err + } + + // Write response followed by newline + if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { + return err + } + + return nil +} + +// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout. +// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT. +// Returns an error if the server encounters any issues during operation. +func ServeStdio(server *MCPServer, opts ...StdioOption) error { + s := NewStdioServer(server) + s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags)) + + for _, opt := range opts { + opt(s) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + go func() { + <-sigChan + cancel() + }() + + return s.Listen(ctx, os.Stdin, os.Stdout) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/LICENSE b/vendor/github.com/yosida95/uritemplate/v3/LICENSE new file mode 100644 index 0000000000..79e8f87572 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/LICENSE @@ -0,0 +1,25 @@ +Copyright (C) 2016, Kohei YOSHIDA . All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/yosida95/uritemplate/v3/README.rst b/vendor/github.com/yosida95/uritemplate/v3/README.rst new file mode 100644 index 0000000000..6815d0a465 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/README.rst @@ -0,0 +1,46 @@ +uritemplate +=========== + +`uritemplate`_ is a Go implementation of `URI Template`_ [RFC6570] with +full functionality of URI Template Level 4. + +uritemplate can also generate a regexp that matches expansion of the +URI Template from a URI Template. + +Getting Started +--------------- + +Installation +~~~~~~~~~~~~ + +.. code-block:: sh + + $ go get -u github.com/yosida95/uritemplate/v3 + +Documentation +~~~~~~~~~~~~~ + +The documentation is available on GoDoc_. + +Examples +-------- + +See `examples on GoDoc`_. + +License +------- + +`uritemplate`_ is distributed under the BSD 3-Clause license. +PLEASE READ ./LICENSE carefully and follow its clauses to use this software. + +Author +------ + +yosida95_ + + +.. _`URI Template`: https://tools.ietf.org/html/rfc6570 +.. _Godoc: https://godoc.org/github.com/yosida95/uritemplate +.. _`examples on GoDoc`: https://godoc.org/github.com/yosida95/uritemplate#pkg-examples +.. _yosida95: https://yosida95.com/ +.. _uritemplate: https://github.com/yosida95/uritemplate diff --git a/vendor/github.com/yosida95/uritemplate/v3/compile.go b/vendor/github.com/yosida95/uritemplate/v3/compile.go new file mode 100644 index 0000000000..bd774d15d0 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/compile.go @@ -0,0 +1,224 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode/utf8" +) + +type compiler struct { + prog *prog +} + +func (c *compiler) init() { + c.prog = &prog{} +} + +func (c *compiler) op(opcode progOpcode) uint32 { + i := len(c.prog.op) + c.prog.op = append(c.prog.op, progOp{code: opcode}) + return uint32(i) +} + +func (c *compiler) opWithRune(opcode progOpcode, r rune) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).r = r + return addr +} + +func (c *compiler) opWithRuneClass(opcode progOpcode, rc runeClass) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).rc = rc + return addr +} + +func (c *compiler) opWithAddr(opcode progOpcode, absaddr uint32) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).i = absaddr + return addr +} + +func (c *compiler) opWithAddrDelta(opcode progOpcode, delta uint32) uint32 { + return c.opWithAddr(opcode, uint32(len(c.prog.op))+delta) +} + +func (c *compiler) opWithName(opcode progOpcode, name string) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).name = name + return addr +} + +func (c *compiler) compileString(str string) { + for i := 0; i < len(str); { + // NOTE(yosida95): It is confirmed at parse time that literals + // consist of only valid-UTF8 runes. + r, size := utf8.DecodeRuneInString(str[i:]) + c.opWithRune(opRune, r) + i += size + } +} + +func (c *compiler) compileRuneClass(rc runeClass, maxlen int) { + for i := 0; i < maxlen; i++ { + if i > 0 { + c.opWithAddrDelta(opSplit, 7) + } + c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + } +} + +func (c *compiler) compileRuneClassInfinite(rc runeClass) { + start := c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithAddrDelta(opSplit, 2) // loop + c.opWithAddr(opJmp, start) // +} + +func (c *compiler) compileVarspecValue(spec varspec, expr *expression) { + var specname string + if spec.maxlen > 0 { + specname = fmt.Sprintf("%s:%d", spec.name, spec.maxlen) + } else { + specname = spec.name + } + + c.prog.numCap++ + + c.opWithName(opCapStart, specname) + + split := c.op(opSplit) + if spec.maxlen > 0 { + c.compileRuneClass(expr.allow, spec.maxlen) + } else { + c.compileRuneClassInfinite(expr.allow) + } + + capEnd := c.opWithName(opCapEnd, specname) + c.prog.op[split].i = capEnd +} + +func (c *compiler) compileVarspec(spec varspec, expr *expression) { + switch { + case expr.named && spec.explode: + split1 := c.op(opSplit) + noop := c.op(opNoop) + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + c.compileVarspecValue(spec, expr) + + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.opWithAddr(opJmp, noop) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + c.opWithAddr(opJmp, split3) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + c.prog.op[split3].i = uint32(len(c.prog.op)) + + case expr.named && !spec.explode: + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + + split3 := c.op(opSplit) + + split4 := c.op(opSplit) + c.compileVarspecValue(spec, expr) + + split5 := c.op(opSplit) + c.prog.op[split4].i = split5 + c.compileString(",") + c.opWithAddr(opJmp, split4) + + c.prog.op[split3].i = uint32(len(c.prog.op)) + c.compileString(",") + jmp1 := c.op(opJmp) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + + c.prog.op[split5].i = uint32(len(c.prog.op)) + c.prog.op[jmp1].i = uint32(len(c.prog.op)) + + case !expr.named: + start := uint32(len(c.prog.op)) + c.compileVarspecValue(spec, expr) + + split1 := c.op(opSplit) + jmp := c.op(opJmp) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + if spec.explode { + c.compileString(expr.sep) + } else { + c.opWithRune(opRune, ',') + } + c.opWithAddr(opJmp, start) + + c.prog.op[jmp].i = uint32(len(c.prog.op)) + } +} + +func (c *compiler) compileExpression(expr *expression) { + if len(expr.vars) < 1 { + return + } + + split1 := c.op(opSplit) + c.compileString(expr.first) + + for i, size := 0, len(expr.vars); i < size; i++ { + spec := expr.vars[i] + + split2 := c.op(opSplit) + if i > 0 { + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.prog.op[split3].i = uint32(len(c.prog.op)) + } + c.compileVarspec(spec, expr) + c.prog.op[split2].i = uint32(len(c.prog.op)) + } + + c.prog.op[split1].i = uint32(len(c.prog.op)) +} + +func (c *compiler) compileLiterals(lt literals) { + c.compileString(string(lt)) +} + +func (c *compiler) compile(tmpl *Template) { + c.op(opLineBegin) + for i := range tmpl.exprs { + expr := tmpl.exprs[i] + switch expr := expr.(type) { + default: + panic("unhandled expression") + case *expression: + c.compileExpression(expr) + case literals: + c.compileLiterals(expr) + } + } + c.op(opLineEnd) + c.op(opEnd) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/equals.go b/vendor/github.com/yosida95/uritemplate/v3/equals.go new file mode 100644 index 0000000000..aa59a5c030 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/equals.go @@ -0,0 +1,53 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +type CompareFlags uint8 + +const ( + CompareVarname CompareFlags = 1 << iota +) + +// Equals reports whether or not two URI Templates t1 and t2 are equivalent. +func Equals(t1 *Template, t2 *Template, flags CompareFlags) bool { + if len(t1.exprs) != len(t2.exprs) { + return false + } + for i := 0; i < len(t1.exprs); i++ { + switch t1 := t1.exprs[i].(type) { + case literals: + t2, ok := t2.exprs[i].(literals) + if !ok { + return false + } + if t1 != t2 { + return false + } + case *expression: + t2, ok := t2.exprs[i].(*expression) + if !ok { + return false + } + if t1.op != t2.op || len(t1.vars) != len(t2.vars) { + return false + } + for n := 0; n < len(t1.vars); n++ { + v1 := t1.vars[n] + v2 := t2.vars[n] + if flags&CompareVarname == CompareVarname && v1.name != v2.name { + return false + } + if v1.maxlen != v2.maxlen || v1.explode != v2.explode { + return false + } + } + default: + panic("unhandled case") + } + } + return true +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/error.go b/vendor/github.com/yosida95/uritemplate/v3/error.go new file mode 100644 index 0000000000..2fd34a8080 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/error.go @@ -0,0 +1,16 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" +) + +func errorf(pos int, format string, a ...interface{}) error { + msg := fmt.Sprintf(format, a...) + return fmt.Errorf("uritemplate:%d:%s", pos, msg) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/escape.go b/vendor/github.com/yosida95/uritemplate/v3/escape.go new file mode 100644 index 0000000000..6d27e693af --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/escape.go @@ -0,0 +1,190 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +var ( + hex = []byte("0123456789ABCDEF") + // reserved = gen-delims / sub-delims + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + // sub-delims = "!" / "$" / "&" / "’" / "(" / ")" + // / "*" / "+" / "," / ";" / "=" + rangeReserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x21, Hi: 0x21, Stride: 1}, // '!' + {Lo: 0x23, Hi: 0x24, Stride: 1}, // '#' - '$' + {Lo: 0x26, Hi: 0x2C, Stride: 1}, // '&' - ',' + {Lo: 0x2F, Hi: 0x2F, Stride: 1}, // '/' + {Lo: 0x3A, Hi: 0x3B, Stride: 1}, // ':' - ';' + {Lo: 0x3D, Hi: 0x3D, Stride: 1}, // '=' + {Lo: 0x3F, Hi: 0x40, Stride: 1}, // '?' - '@' + {Lo: 0x5B, Hi: 0x5B, Stride: 1}, // '[' + {Lo: 0x5D, Hi: 0x5D, Stride: 1}, // ']' + }, + LatinOffset: 9, + } + reReserved = `\x21\x23\x24\x26-\x2c\x2f\x3a\x3b\x3d\x3f\x40\x5b\x5d` + // ALPHA = %x41-5A / %x61-7A + // DIGIT = %x30-39 + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + rangeUnreserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x2D, Hi: 0x2E, Stride: 1}, // '-' - '.' + {Lo: 0x30, Hi: 0x39, Stride: 1}, // '0' - '9' + {Lo: 0x41, Hi: 0x5A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x5F, Hi: 0x5F, Stride: 1}, // '_' + {Lo: 0x61, Hi: 0x7A, Stride: 1}, // 'a' - 'z' + {Lo: 0x7E, Hi: 0x7E, Stride: 1}, // '~' + }, + } + reUnreserved = `\x2d\x2e\x30-\x39\x41-\x5a\x5f\x61-\x7a\x7e` +) + +type runeClass uint8 + +const ( + runeClassU runeClass = 1 << iota + runeClassR + runeClassPctE + runeClassLast + + runeClassUR = runeClassU | runeClassR +) + +var runeClassNames = []string{ + "U", + "R", + "pct-encoded", +} + +func (rc runeClass) String() string { + ret := make([]string, 0, len(runeClassNames)) + for i, j := 0, runeClass(1); j < runeClassLast; j <<= 1 { + if rc&j == j { + ret = append(ret, runeClassNames[i]) + } + i++ + } + return strings.Join(ret, "+") +} + +func pctEncode(w *strings.Builder, r rune) { + if s := r >> 24 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 16 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 8 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + default: + return false + } +} + +func pctDecode(s string) string { + size := len(s) + for i := 0; i < len(s); { + switch s[i] { + case '%': + size -= 2 + i += 3 + default: + i++ + } + } + if size == len(s) { + return s + } + + buf := make([]byte, size) + j := 0 + for i := 0; i < len(s); { + switch c := s[i]; c { + case '%': + buf[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + i += 3 + j++ + default: + buf[j] = c + i++ + j++ + } + } + return string(buf) +} + +type escapeFunc func(*strings.Builder, string) error + +func escapeLiteral(w *strings.Builder, v string) error { + w.WriteString(v) + return nil +} + +func escapeExceptU(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + if unicode.Is(rangeUnreserved, r) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} + +func escapeExceptUR(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + // TODO(yosida95): is pct-encoded triplets allowed here? + if unicode.In(r, rangeUnreserved, rangeReserved) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/expression.go b/vendor/github.com/yosida95/uritemplate/v3/expression.go new file mode 100644 index 0000000000..4858c2ddef --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/expression.go @@ -0,0 +1,173 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "regexp" + "strconv" + "strings" +) + +type template interface { + expand(*strings.Builder, Values) error + regexp(*strings.Builder) +} + +type literals string + +func (l literals) expand(b *strings.Builder, _ Values) error { + b.WriteString(string(l)) + return nil +} + +func (l literals) regexp(b *strings.Builder) { + b.WriteString("(?:") + b.WriteString(regexp.QuoteMeta(string(l))) + b.WriteByte(')') +} + +type varspec struct { + name string + maxlen int + explode bool +} + +type expression struct { + vars []varspec + op parseOp + first string + sep string + named bool + ifemp string + escape escapeFunc + allow runeClass +} + +func (e *expression) init() { + switch e.op { + case parseOpSimple: + e.sep = "," + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpPlus: + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpCrosshatch: + e.first = "#" + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpDot: + e.first = "." + e.sep = "." + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSlash: + e.first = "/" + e.sep = "/" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSemicolon: + e.first = ";" + e.sep = ";" + e.named = true + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpQuestion: + e.first = "?" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpAmpersand: + e.first = "&" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + } +} + +func (e *expression) expand(w *strings.Builder, values Values) error { + first := true + for _, varspec := range e.vars { + value := values.Get(varspec.name) + if !value.Valid() { + continue + } + + if first { + w.WriteString(e.first) + first = false + } else { + w.WriteString(e.sep) + } + + if err := value.expand(w, varspec, e); err != nil { + return err + } + + } + return nil +} + +func (e *expression) regexp(b *strings.Builder) { + if e.first != "" { + b.WriteString("(?:") // $1 + b.WriteString(regexp.QuoteMeta(e.first)) + } + b.WriteByte('(') // $2 + runeClassToRegexp(b, e.allow, e.named || e.vars[0].explode) + if len(e.vars) > 1 || e.vars[0].explode { + max := len(e.vars) - 1 + for i := 0; i < len(e.vars); i++ { + if e.vars[i].explode { + max = -1 + break + } + } + + b.WriteString("(?:") // $3 + b.WriteString(regexp.QuoteMeta(e.sep)) + runeClassToRegexp(b, e.allow, e.named || max < 0) + b.WriteByte(')') // $3 + if max > 0 { + b.WriteString("{0,") + b.WriteString(strconv.Itoa(max)) + b.WriteByte('}') + } else { + b.WriteByte('*') + } + } + b.WriteByte(')') // $2 + if e.first != "" { + b.WriteByte(')') // $1 + } + b.WriteByte('?') +} + +func runeClassToRegexp(b *strings.Builder, class runeClass, named bool) { + b.WriteString("(?:(?:[") + if class&runeClassR == 0 { + b.WriteString(`\x2c`) + if named { + b.WriteString(`\x3d`) + } + } + if class&runeClassU == runeClassU { + b.WriteString(reUnreserved) + } + if class&runeClassR == runeClassR { + b.WriteString(reReserved) + } + b.WriteString("]") + b.WriteString("|%[[:xdigit:]][[:xdigit:]]") + b.WriteString(")*)") +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/machine.go b/vendor/github.com/yosida95/uritemplate/v3/machine.go new file mode 100644 index 0000000000..7b1d0b518d --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/machine.go @@ -0,0 +1,23 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +// threadList implements https://research.swtch.com/sparse. +type threadList struct { + dense []threadEntry + sparse []uint32 +} + +type threadEntry struct { + pc uint32 + t *thread +} + +type thread struct { + op *progOp + cap map[string][]int +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/match.go b/vendor/github.com/yosida95/uritemplate/v3/match.go new file mode 100644 index 0000000000..02fe6385a3 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/match.go @@ -0,0 +1,213 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "unicode" + "unicode/utf8" +) + +type matcher struct { + prog *prog + + list1 threadList + list2 threadList + matched bool + cap map[string][]int + + input string +} + +func (m *matcher) at(pos int) (rune, int, bool) { + if l := len(m.input); pos < l { + c := m.input[pos] + if c < utf8.RuneSelf { + return rune(c), 1, pos+1 < l + } + r, size := utf8.DecodeRuneInString(m.input[pos:]) + return r, size, pos+size < l + } + return -1, 0, false +} + +func (m *matcher) add(list *threadList, pc uint32, pos int, next bool, cap map[string][]int) { + if i := list.sparse[pc]; i < uint32(len(list.dense)) && list.dense[i].pc == pc { + return + } + + n := len(list.dense) + list.dense = list.dense[:n+1] + list.sparse[pc] = uint32(n) + + e := &list.dense[n] + e.pc = pc + e.t = nil + + op := &m.prog.op[pc] + switch op.code { + default: + panic("unhandled opcode") + case opRune, opRuneClass, opEnd: + e.t = &thread{ + op: &m.prog.op[pc], + cap: make(map[string][]int, len(m.cap)), + } + for k, v := range cap { + e.t.cap[k] = make([]int, len(v)) + copy(e.t.cap[k], v) + } + case opLineBegin: + if pos == 0 { + m.add(list, pc+1, pos, next, cap) + } + case opLineEnd: + if !next { + m.add(list, pc+1, pos, next, cap) + } + case opCapStart, opCapEnd: + ocap := make(map[string][]int, len(m.cap)) + for k, v := range cap { + ocap[k] = make([]int, len(v)) + copy(ocap[k], v) + } + ocap[op.name] = append(ocap[op.name], pos) + m.add(list, pc+1, pos, next, ocap) + case opSplit: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmp: + m.add(list, op.i, pos, next, cap) + case opJmpIfNotDefined: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotFirst: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotEmpty: + m.add(list, op.i, pos, next, cap) + m.add(list, pc+1, pos, next, cap) + case opNoop: + m.add(list, pc+1, pos, next, cap) + } +} + +func (m *matcher) step(clist *threadList, nlist *threadList, r rune, pos int, nextPos int, next bool) { + debug.Printf("===== %q =====", string(r)) + for i := 0; i < len(clist.dense); i++ { + e := clist.dense[i] + if debug { + var buf bytes.Buffer + dumpProg(&buf, m.prog, e.pc) + debug.Printf("\n%s", buf.String()) + } + if e.t == nil { + continue + } + + t := e.t + op := t.op + switch op.code { + default: + panic("unhandled opcode") + case opRune: + if op.r == r { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opRuneClass: + ret := false + if !ret && op.rc&runeClassU == runeClassU { + ret = ret || unicode.Is(rangeUnreserved, r) + } + if !ret && op.rc&runeClassR == runeClassR { + ret = ret || unicode.Is(rangeReserved, r) + } + if !ret && op.rc&runeClassPctE == runeClassPctE { + ret = ret || unicode.Is(unicode.ASCII_Hex_Digit, r) + } + if ret { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opEnd: + m.matched = true + for k, v := range t.cap { + m.cap[k] = make([]int, len(v)) + copy(m.cap[k], v) + } + clist.dense = clist.dense[:0] + } + } + clist.dense = clist.dense[:0] +} + +func (m *matcher) match() bool { + pos := 0 + clist, nlist := &m.list1, &m.list2 + for { + if len(clist.dense) == 0 && m.matched { + break + } + r, width, next := m.at(pos) + if !m.matched { + m.add(clist, 0, pos, next, m.cap) + } + m.step(clist, nlist, r, pos, pos+width, next) + + if width < 1 { + break + } + pos += width + + clist, nlist = nlist, clist + } + return m.matched +} + +func (tmpl *Template) Match(expansion string) Values { + tmpl.mu.Lock() + if tmpl.prog == nil { + c := compiler{} + c.init() + c.compile(tmpl) + tmpl.prog = c.prog + } + prog := tmpl.prog + tmpl.mu.Unlock() + + n := len(prog.op) + m := matcher{ + prog: prog, + list1: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + list2: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + cap: make(map[string][]int, prog.numCap), + input: expansion, + } + if !m.match() { + return nil + } + + match := make(Values, len(m.cap)) + for name, indices := range m.cap { + v := Value{V: make([]string, len(indices)/2)} + for i := range v.V { + v.V[i] = pctDecode(expansion[indices[2*i]:indices[2*i+1]]) + } + if len(v.V) == 1 { + v.T = ValueTypeString + } else { + v.T = ValueTypeList + } + match[name] = v + } + return match +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/parse.go b/vendor/github.com/yosida95/uritemplate/v3/parse.go new file mode 100644 index 0000000000..fd38a682f1 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/parse.go @@ -0,0 +1,277 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode" + "unicode/utf8" +) + +type parseOp int + +const ( + parseOpSimple parseOp = iota + parseOpPlus + parseOpCrosshatch + parseOpDot + parseOpSlash + parseOpSemicolon + parseOpQuestion + parseOpAmpersand +) + +var ( + rangeVarchar = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0030, Hi: 0x0039, Stride: 1}, // '0' - '9' + {Lo: 0x0041, Hi: 0x005A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + }, + LatinOffset: 4, + } + rangeLiterals = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0021, Hi: 0x0021, Stride: 1}, // '!' + {Lo: 0x0023, Hi: 0x0024, Stride: 1}, // '#' - '$' + {Lo: 0x0026, Hi: 0x003B, Stride: 1}, // '&' ''' '(' - ';'. '''/27 used to be excluded but an errata is in the review process https://www.rfc-editor.org/errata/eid6937 + {Lo: 0x003D, Hi: 0x003D, Stride: 1}, // '=' + {Lo: 0x003F, Hi: 0x005B, Stride: 1}, // '?' - '[' + {Lo: 0x005D, Hi: 0x005D, Stride: 1}, // ']' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + {Lo: 0x007E, Hi: 0x007E, Stride: 1}, // '~' + {Lo: 0x00A0, Hi: 0xD7FF, Stride: 1}, // ucschar + {Lo: 0xE000, Hi: 0xF8FF, Stride: 1}, // iprivate + {Lo: 0xF900, Hi: 0xFDCF, Stride: 1}, // ucschar + {Lo: 0xFDF0, Hi: 0xFFEF, Stride: 1}, // ucschar + }, + R32: []unicode.Range32{ + {Lo: 0x00010000, Hi: 0x0001FFFD, Stride: 1}, // ucschar + {Lo: 0x00020000, Hi: 0x0002FFFD, Stride: 1}, // ucschar + {Lo: 0x00030000, Hi: 0x0003FFFD, Stride: 1}, // ucschar + {Lo: 0x00040000, Hi: 0x0004FFFD, Stride: 1}, // ucschar + {Lo: 0x00050000, Hi: 0x0005FFFD, Stride: 1}, // ucschar + {Lo: 0x00060000, Hi: 0x0006FFFD, Stride: 1}, // ucschar + {Lo: 0x00070000, Hi: 0x0007FFFD, Stride: 1}, // ucschar + {Lo: 0x00080000, Hi: 0x0008FFFD, Stride: 1}, // ucschar + {Lo: 0x00090000, Hi: 0x0009FFFD, Stride: 1}, // ucschar + {Lo: 0x000A0000, Hi: 0x000AFFFD, Stride: 1}, // ucschar + {Lo: 0x000B0000, Hi: 0x000BFFFD, Stride: 1}, // ucschar + {Lo: 0x000C0000, Hi: 0x000CFFFD, Stride: 1}, // ucschar + {Lo: 0x000D0000, Hi: 0x000DFFFD, Stride: 1}, // ucschar + {Lo: 0x000E1000, Hi: 0x000EFFFD, Stride: 1}, // ucschar + {Lo: 0x000F0000, Hi: 0x000FFFFD, Stride: 1}, // iprivate + {Lo: 0x00100000, Hi: 0x0010FFFD, Stride: 1}, // iprivate + }, + LatinOffset: 10, + } +) + +type parser struct { + r string + start int + stop int + state parseState +} + +func (p *parser) errorf(i rune, format string, a ...interface{}) error { + return fmt.Errorf("%s: %s%s", fmt.Sprintf(format, a...), p.r[0:p.stop], string(i)) +} + +func (p *parser) rune() (rune, int) { + r, size := utf8.DecodeRuneInString(p.r[p.stop:]) + if r != utf8.RuneError { + p.stop += size + } + return r, size +} + +func (p *parser) unread(r rune) { + p.stop -= utf8.RuneLen(r) +} + +type parseState int + +const ( + parseStateDefault = parseState(iota) + parseStateOperator + parseStateVarList + parseStateVarName + parseStatePrefix +) + +func (p *parser) setState(state parseState) { + p.state = state + p.start = p.stop +} + +func (p *parser) parseURITemplate() (*Template, error) { + tmpl := Template{ + raw: p.r, + exprs: []template{}, + } + + var exp *expression + for { + r, size := p.rune() + if r == utf8.RuneError { + if size == 0 { + if p.state != parseStateDefault { + return nil, p.errorf('_', "incomplete expression") + } + if p.start < p.stop { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:p.stop])) + } + return &tmpl, nil + } + return nil, p.errorf('_', "invalid UTF-8 sequence") + } + + switch p.state { + case parseStateDefault: + switch r { + case '{': + if stop := p.stop - size; stop > p.start { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:stop])) + } + exp = &expression{} + tmpl.exprs = append(tmpl.exprs, exp) + p.setState(parseStateOperator) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + default: + if !unicode.Is(rangeLiterals, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable character (hint: use %%XX encoding)") + } + } + case parseStateOperator: + switch r { + default: + p.unread(r) + exp.op = parseOpSimple + case '+': + exp.op = parseOpPlus + case '#': + exp.op = parseOpCrosshatch + case '.': + exp.op = parseOpDot + case '/': + exp.op = parseOpSlash + case ';': + exp.op = parseOpSemicolon + case '?': + exp.op = parseOpQuestion + case '&': + exp.op = parseOpAmpersand + case '=', ',', '!', '@', '|': // op-reserved + return nil, p.errorf('|', "unimplemented operator (op-reserved)") + } + p.setState(parseStateVarName) + case parseStateVarList: + switch r { + case ',': + p.setState(parseStateVarName) + case '}': + exp.init() + p.setState(parseStateDefault) + default: + p.unread(r) + return nil, p.errorf('_', "unrecognized value modifier") + } + case parseStateVarName: + switch r { + case ':', '*': + name := p.r[p.start : p.stop-size] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + explode := r == '*' + exp.vars = append(exp.vars, varspec{ + name: name, + explode: explode, + }) + if explode { + p.setState(parseStateVarList) + } else { + p.setState(parseStatePrefix) + } + case ',', '}': + p.unread(r) + name := p.r[p.start:p.stop] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + exp.vars = append(exp.vars, varspec{ + name: name, + }) + p.setState(parseStateVarList) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + case '.': + if dot := p.stop - size; dot == p.start || p.r[dot-1] == '.' { + return nil, p.errorf('|', "unacceptable variable name") + } + default: + if !unicode.Is(rangeVarchar, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable variable name") + } + } + case parseStatePrefix: + spec := &(exp.vars[len(exp.vars)-1]) + switch { + case '0' <= r && r <= '9': + spec.maxlen *= 10 + spec.maxlen += int(r - '0') + if spec.maxlen == 0 || spec.maxlen > 9999 { + return nil, p.errorf('|', "max-length must be (0, 9999]") + } + default: + p.unread(r) + if spec.maxlen == 0 { + return nil, p.errorf('_', "max-length must be (0, 9999]") + } + p.setState(parseStateVarList) + } + default: + p.unread(r) + panic(p.errorf('_', "unhandled parseState(%d)", p.state)) + } + } +} + +func isValidVarname(name string) bool { + if l := len(name); l == 0 || name[0] == '.' || name[l-1] == '.' { + return false + } + for i := 1; i < len(name)-1; i++ { + switch c := name[i]; c { + case '.': + if name[i-1] == '.' { + return false + } + } + } + return true +} + +func (p *parser) consumeTriplet() error { + if len(p.r)-p.stop < 3 || p.r[p.stop] != '%' || !ishex(p.r[p.stop+1]) || !ishex(p.r[p.stop+2]) { + return p.errorf('_', "incomplete pct-encodeed") + } + p.stop += 3 + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/prog.go b/vendor/github.com/yosida95/uritemplate/v3/prog.go new file mode 100644 index 0000000000..97af4f0eab --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/prog.go @@ -0,0 +1,130 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "strconv" +) + +type progOpcode uint16 + +const ( + // match + opRune progOpcode = iota + opRuneClass + opLineBegin + opLineEnd + // capture + opCapStart + opCapEnd + // stack + opSplit + opJmp + opJmpIfNotDefined + opJmpIfNotEmpty + opJmpIfNotFirst + // result + opEnd + // fake + opNoop + opcodeMax +) + +var opcodeNames = []string{ + // match + "opRune", + "opRuneClass", + "opLineBegin", + "opLineEnd", + // capture + "opCapStart", + "opCapEnd", + // stack + "opSplit", + "opJmp", + "opJmpIfNotDefined", + "opJmpIfNotEmpty", + "opJmpIfNotFirst", + // result + "opEnd", +} + +func (code progOpcode) String() string { + if code >= opcodeMax { + return "" + } + return opcodeNames[code] +} + +type progOp struct { + code progOpcode + r rune + rc runeClass + i uint32 + + name string +} + +func dumpProgOp(b *bytes.Buffer, op *progOp) { + b.WriteString(op.code.String()) + switch op.code { + case opRune: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(string(op.r))) + b.WriteString(")") + case opRuneClass: + b.WriteString("(") + b.WriteString(op.rc.String()) + b.WriteString(")") + case opCapStart, opCapEnd: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + case opSplit: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmp, opJmpIfNotFirst: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmpIfNotDefined, opJmpIfNotEmpty: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + } +} + +type prog struct { + op []progOp + numCap int +} + +func dumpProg(b *bytes.Buffer, prog *prog, pc uint32) { + for i := range prog.op { + op := prog.op[i] + + pos := strconv.Itoa(i) + if uint32(i) == pc { + pos = "*" + pos + } + b.WriteString(" "[len(pos):]) + b.WriteString(pos) + + b.WriteByte('\t') + dumpProgOp(b, &op) + + b.WriteByte('\n') + } +} + +func (p *prog) String() string { + b := bytes.Buffer{} + dumpProg(&b, p, 0) + return b.String() +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go new file mode 100644 index 0000000000..dbd2673753 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go @@ -0,0 +1,116 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "log" + "regexp" + "strings" + "sync" +) + +var ( + debug = debugT(false) +) + +type debugT bool + +func (t debugT) Printf(format string, v ...interface{}) { + if t { + log.Printf(format, v...) + } +} + +// Template represents a URI Template. +type Template struct { + raw string + exprs []template + + // protects the rest of fields + mu sync.Mutex + varnames []string + re *regexp.Regexp + prog *prog +} + +// New parses and constructs a new Template instance based on the template. +// New returns an error if the template cannot be recognized. +func New(template string) (*Template, error) { + return (&parser{r: template}).parseURITemplate() +} + +// MustNew panics if the template cannot be recognized. +func MustNew(template string) *Template { + ret, err := New(template) + if err != nil { + panic(err) + } + return ret +} + +// Raw returns a raw URI template passed to New in string. +func (t *Template) Raw() string { + return t.raw +} + +// Varnames returns variable names used in the template. +func (t *Template) Varnames() []string { + t.mu.Lock() + defer t.mu.Unlock() + if t.varnames != nil { + return t.varnames + } + + reg := map[string]struct{}{} + t.varnames = []string{} + for i := range t.exprs { + expr, ok := t.exprs[i].(*expression) + if !ok { + continue + } + for _, spec := range expr.vars { + if _, ok := reg[spec.name]; ok { + continue + } + reg[spec.name] = struct{}{} + t.varnames = append(t.varnames, spec.name) + } + } + + return t.varnames +} + +// Expand returns a URI reference corresponding to the template expanded using the passed variables. +func (t *Template) Expand(vars Values) (string, error) { + var w strings.Builder + for i := range t.exprs { + expr := t.exprs[i] + if err := expr.expand(&w, vars); err != nil { + return w.String(), err + } + } + return w.String(), nil +} + +// Regexp converts the template to regexp and returns compiled *regexp.Regexp. +func (t *Template) Regexp() *regexp.Regexp { + t.mu.Lock() + defer t.mu.Unlock() + if t.re != nil { + return t.re + } + + var b strings.Builder + b.WriteByte('^') + for _, expr := range t.exprs { + expr.regexp(&b) + } + b.WriteByte('$') + t.re = regexp.MustCompile(b.String()) + + return t.re +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/value.go b/vendor/github.com/yosida95/uritemplate/v3/value.go new file mode 100644 index 0000000000..0550eabdbf --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/value.go @@ -0,0 +1,216 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import "strings" + +// A varname containing pct-encoded characters is not the same variable as +// a varname with those same characters decoded. +// +// -- https://tools.ietf.org/html/rfc6570#section-2.3 +type Values map[string]Value + +func (v Values) Set(name string, value Value) { + v[name] = value +} + +func (v Values) Get(name string) Value { + if v == nil { + return Value{} + } + return v[name] +} + +type ValueType uint8 + +const ( + ValueTypeString = iota + ValueTypeList + ValueTypeKV + valueTypeLast +) + +var valueTypeNames = []string{ + "String", + "List", + "KV", +} + +func (vt ValueType) String() string { + if vt < valueTypeLast { + return valueTypeNames[vt] + } + return "" +} + +type Value struct { + T ValueType + V []string +} + +func (v Value) String() string { + if v.Valid() && v.T == ValueTypeString { + return v.V[0] + } + return "" +} + +func (v Value) List() []string { + if v.Valid() && v.T == ValueTypeList { + return v.V + } + return nil +} + +func (v Value) KV() []string { + if v.Valid() && v.T == ValueTypeKV { + return v.V + } + return nil +} + +func (v Value) Valid() bool { + switch v.T { + default: + return false + case ValueTypeString: + return len(v.V) > 0 + case ValueTypeList: + return len(v.V) > 0 + case ValueTypeKV: + return len(v.V) > 0 && len(v.V)%2 == 0 + } +} + +func (v Value) expand(w *strings.Builder, spec varspec, exp *expression) error { + switch v.T { + case ValueTypeString: + val := v.V[0] + var maxlen int + if max := len(val); spec.maxlen < 1 || spec.maxlen > max { + maxlen = max + } else { + maxlen = spec.maxlen + } + + if exp.named { + w.WriteString(spec.name) + if val == "" { + w.WriteString(exp.ifemp) + return nil + } + w.WriteByte('=') + } + return exp.escape(w, val[:maxlen]) + case ValueTypeList: + var sep string + if spec.explode { + sep = exp.sep + } else { + sep = "," + } + + var pre string + var preifemp string + if spec.explode && exp.named { + pre = spec.name + "=" + preifemp = spec.name + exp.ifemp + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + for i := range v.V { + val := v.V[i] + if i > 0 { + w.WriteString(sep) + } + if val == "" { + w.WriteString(preifemp) + continue + } + w.WriteString(pre) + + if err := exp.escape(w, val); err != nil { + return err + } + } + case ValueTypeKV: + var sep string + var kvsep string + if spec.explode { + sep = exp.sep + kvsep = "=" + } else { + sep = "," + kvsep = "," + } + + var ifemp string + var kescape escapeFunc + if spec.explode && exp.named { + ifemp = exp.ifemp + kescape = escapeLiteral + } else { + ifemp = "," + kescape = exp.escape + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + + for i := 0; i < len(v.V); i += 2 { + if i > 0 { + w.WriteString(sep) + } + if err := kescape(w, v.V[i]); err != nil { + return err + } + if v.V[i+1] == "" { + w.WriteString(ifemp) + continue + } + w.WriteString(kvsep) + + if err := exp.escape(w, v.V[i+1]); err != nil { + return err + } + } + } + return nil +} + +// String returns Value that represents string. +func String(v string) Value { + return Value{ + T: ValueTypeString, + V: []string{v}, + } +} + +// List returns Value that represents list. +func List(v ...string) Value { + return Value{ + T: ValueTypeList, + V: v, + } +} + +// KV returns Value that represents associative list. +// KV panics if len(kv) is not even. +func KV(kv ...string) Value { + if len(kv)%2 != 0 { + panic("uritemplate.go: count of the kv must be even number") + } + return Value{ + T: ValueTypeKV, + V: kv, + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index d2c03b10e9..56fdf01d91 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -465,6 +465,11 @@ github.com/lufia/plan9stats github.com/mailru/easyjson/buffer github.com/mailru/easyjson/jlexer github.com/mailru/easyjson/jwriter +# github.com/mark3labs/mcp-go v0.18.0 +## explicit; go 1.23 +github.com/mark3labs/mcp-go/client +github.com/mark3labs/mcp-go/mcp +github.com/mark3labs/mcp-go/server # github.com/maruel/natural v1.1.0 ## explicit; go 1.11 github.com/maruel/natural @@ -659,6 +664,9 @@ github.com/xanzy/ssh-agent # github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 ## explicit github.com/xi2/xz +# github.com/yosida95/uritemplate/v3 v3.0.2 +## explicit; go 1.14 +github.com/yosida95/uritemplate/v3 # github.com/yusufpapurcu/wmi v1.2.4 ## explicit; go 1.16 github.com/yusufpapurcu/wmi From 7f33790dc7ce61a1689e0dd25c2a999a2a6bd5ce Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 09:11:49 -0700 Subject: [PATCH 11/29] Fix AI deciding to indent something that didn't need it --- .github/workflows/build.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 142505f11f..5a25eea307 100755 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -224,10 +224,10 @@ jobs: shell: bash run: parallelize results Build-Install-Scripts - - # === "Build: Executor" === - name: "Build: Executor" - shell: bash - run: parallelize results Build-Executor + - # === "Build: Executor" === + name: "Build: Executor" + shell: bash + run: parallelize results Build-Executor - # === "Build: MCP" === name: "Build: MCP" From 993a88b9c1ad807d6c82d186089edba7d71283e0 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 11:22:44 -0700 Subject: [PATCH 12/29] Fix prime being stale --- cmd/state-mcp/main.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/cmd/state-mcp/main.go b/cmd/state-mcp/main.go index fd092a7757..c549e8096e 100644 --- a/cmd/state-mcp/main.go +++ b/cmd/state-mcp/main.go @@ -119,8 +119,8 @@ func registerRawTools(mcpHandler *mcpServerHandler) func() error { // Require project directory for most commands. This is currently not encoded into the command tree if !sliceutils.Contains([]string{"projects", "auth"}, command.BaseCommand().Name()) { opts = append(opts, mcp.WithString( - "project_directory", - require(true), + "project_directory", + mcp.Required(), mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), )) } @@ -138,7 +138,7 @@ func registerRawTools(mcpHandler *mcpServerHandler) func() error { } mcpHandler.addTool( mcp.NewTool(strings.Join(strings.Split(command.NameRecursive(), " "), "_"), opts...), - func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { byt.Truncate(0) if projectDir, ok := request.Params.Arguments["project_directory"]; ok { pj, err := project.FromPath(projectDir.(string)) @@ -147,6 +147,15 @@ func registerRawTools(mcpHandler *mcpServerHandler) func() error { } prime.SetProject(pj) } + // Reinitialize tree with updated primer, because currently our command can take things + // from the primer at the time of registration, and not the time of invocation. + invocationTree := donotshipme.CmdTree(prime) + for _, child := range invocationTree.Command().AllChildren() { + if child.NameRecursive() == command.NameRecursive() { + command = child + break + } + } args := strings.Split(command.NameRecursive(), " ") for _, arg := range command.Arguments() { v, ok := request.Params.Arguments[arg.Name] @@ -162,7 +171,7 @@ func registerRawTools(mcpHandler *mcpServerHandler) func() error { } args = append(args, fmt.Sprintf("--%s=%s", flag.Name, v.(string))) } - logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args==nil) + logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args == nil) err := command.Execute(args) if err != nil { return nil, errs.Wrap(err, "Failed to execute command") From 0a062a31f6433521c6c75d85255ad87b9ff2c258 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 11:22:54 -0700 Subject: [PATCH 13/29] Drop unused import --- internal/captain/command.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/captain/command.go b/internal/captain/command.go index ca3aa5fc19..13288e04e9 100644 --- a/internal/captain/command.go +++ b/internal/captain/command.go @@ -24,7 +24,6 @@ import ( configMediator "github.com/ActiveState/cli/internal/mediators/config" "github.com/ActiveState/cli/internal/multilog" "github.com/ActiveState/cli/internal/osutils" - "github.com/ActiveState/cli/internal/osutils/stacktrace" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/profile" "github.com/ActiveState/cli/internal/rollbar" From c11c5c48752a78cef12deedde9a350a61dabe53a Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 11:23:11 -0700 Subject: [PATCH 14/29] Disable tests --- cmd/state-mcp/server_test.go | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/cmd/state-mcp/server_test.go b/cmd/state-mcp/server_test.go index 9fbd210212..c43892e8e8 100644 --- a/cmd/state-mcp/server_test.go +++ b/cmd/state-mcp/server_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "testing" + + "github.com/ActiveState/cli/internal/environment" ) -func TestServer(t *testing.T) { +func TestServerProjects(t *testing.T) { + t.Skip("Intended for manual testing") mcpHandler := registerServer() registerRawTools(mcpHandler) @@ -21,3 +24,22 @@ func TestServer(t *testing.T) { }`)) t.Fatalf("%+v", msg) } + +func TestServerPackages(t *testing.T) { + t.Skip("Intended for manual testing") + mcpHandler := registerServer() + registerRawTools(mcpHandler) + + msg := mcpHandler.mcpServer.HandleMessage(context.Background(), json.RawMessage(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "packages", + "arguments": { + "project_directory": "`+environment.GetRootPathUnsafe()+`" + } + } + }`)) + t.Fatalf("%+v", msg) +} From c8dc1b33195aba917ea0e59d3c2405222e7e80f4 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 11:27:50 -0700 Subject: [PATCH 15/29] Add missing comma --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5a25eea307..8fb46dd284 100755 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -170,7 +170,7 @@ jobs: { "ID": "Build-Executor", "Args": ["state", "run", "build-exec"] - } + }, { "ID": "Build-MCP", "Args": ["state", "run", "build-mcp"] From af5769a3fd2bab06bdc655fde3a2b80a3b7e6e44 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 12:32:41 -0700 Subject: [PATCH 16/29] Update test now that we work without HOME set --- internal/osutils/user/user_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/osutils/user/user_test.go b/internal/osutils/user/user_test.go index 27c5010622..9a487750a4 100644 --- a/internal/osutils/user/user_test.go +++ b/internal/osutils/user/user_test.go @@ -39,6 +39,5 @@ func TestNoHome(t *testing.T) { defer func() { os.Setenv("USERPROFILE", osHomeDir) }() } _, err = HomeDir() - assert.Error(t, err) - assert.Contains(t, err.Error(), "HOME environment variable is unset") + assert.NoError(t, err) } From 93a552908e663276ebb16c2ffc849d2848ffb578 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 14:11:41 -0700 Subject: [PATCH 17/29] Include state-mcp in update --- internal/constants/constants.go | 3 +++ scripts/ci/payload-generator/main.go | 1 + 2 files changed, 4 insertions(+) diff --git a/internal/constants/constants.go b/internal/constants/constants.go index da5104fe4b..29319a310a 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -436,6 +436,9 @@ const InstallerName = "State Installer" // StateExecutorCmd is the name of the state executor binary const StateExecutorCmd = "state-exec" +// StateMCPCmd is the name of the state mcp binary +const StateMCPCmd = "state-mcp" + // LegacyToplevelInstallArchiveDir is the top-level directory for files in an installation archive // This constant will be removed in DX-2081. const LegacyToplevelInstallArchiveDir = "state-install" diff --git a/scripts/ci/payload-generator/main.go b/scripts/ci/payload-generator/main.go index d264c37853..840bda1328 100644 --- a/scripts/ci/payload-generator/main.go +++ b/scripts/ci/payload-generator/main.go @@ -71,6 +71,7 @@ func generatePayload(inDir, outDir, binDir, channel, version string) error { filepath.Join(inDir, constants.StateCmd+osutils.ExeExtension): binDir, filepath.Join(inDir, constants.StateSvcCmd+osutils.ExeExtension): binDir, filepath.Join(inDir, constants.StateExecutorCmd+osutils.ExeExtension): binDir, + filepath.Join(inDir, constants.StateMCPCmd+osutils.ExeExtension): binDir, } if err := copyFiles(files); err != nil { return fmt.Errorf(emsg, err) From 5e29acd98389f143aec28ce335e51fbfa79b2351 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Thu, 10 Apr 2025 14:16:25 -0700 Subject: [PATCH 18/29] Windows specific build target --- activestate.windows.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/activestate.windows.yaml b/activestate.windows.yaml index 76f5251757..1ffd7363e7 100644 --- a/activestate.windows.yaml +++ b/activestate.windows.yaml @@ -7,6 +7,8 @@ constants: value: state-exec.exe - name: BUILD_INSTALLER_TARGET value: state-installer.exe + - name: BUILD_MCP_TARGET + value: state-mcp.exe - name: SVC_BUILDFLAGS value: -ldflags="-s -w -H=windowsgui" - name: SCRIPT_EXT From c7148779e9f77ddd0b05c1cc3a0a796dde7d12ce Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 11 Apr 2025 08:54:29 -0700 Subject: [PATCH 19/29] Always print what we're building --- activestate.yaml | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/activestate.yaml b/activestate.yaml index d41910a6be..ab3a71ae8b 100644 --- a/activestate.yaml +++ b/activestate.yaml @@ -110,7 +110,9 @@ scripts: go generate popd > /dev/null fi - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_TARGET $constants.CLI_BUILDFLAGS $constants.CLI_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.CLI_BUILDFLAGS $constants.CLI_PKGS - name: build-svc language: bash standalone: true @@ -125,7 +127,9 @@ scripts: go generate popd > /dev/null fi - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_DAEMON_TARGET $constants.SVC_BUILDFLAGS $constants.DAEMON_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_DAEMON_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.SVC_BUILDFLAGS $constants.DAEMON_PKGS - name: build-exec description: Builds the State Executor application language: bash @@ -133,8 +137,9 @@ scripts: value: | set -e $constants.SET_ENV - - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_EXEC_TARGET $constants.CLI_BUILDFLAGS $constants.EXECUTOR_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_EXEC_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.CLI_BUILDFLAGS $constants.EXECUTOR_PKGS - name: build-mcp description: Builds the State MCP application language: bash @@ -142,7 +147,9 @@ scripts: value: | set -e $constants.SET_ENV - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_MCP_TARGET $constants.CLI_BUILDFLAGS $constants.MCP_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_MCP_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.CLI_BUILDFLAGS $constants.MCP_PKGS - name: build-all description: Builds all our tools language: bash From 5cc036dea65f701e94bc07f1569e7a63f2c50347 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 11 Apr 2025 08:54:37 -0700 Subject: [PATCH 20/29] Debug copy error --- scripts/ci/payload-generator/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/ci/payload-generator/main.go b/scripts/ci/payload-generator/main.go index 840bda1328..d50e1bd235 100644 --- a/scripts/ci/payload-generator/main.go +++ b/scripts/ci/payload-generator/main.go @@ -108,6 +108,7 @@ func copyFiles(files map[string]string) error { dest := filepath.Join(target, filepath.Base(src)) if err := fileutils.CopyFile(src, dest); err != nil { + fmt.Printf("Files in %s: %+v\n", filepath.Dir(src), fileutils.ListFilesUnsafe(filepath.Dir(src))) return fmt.Errorf("copy files (%s to %s): %w", src, target, err) } } From 6286ad5a1a35440638dc24f6cce9c2582a948db6 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 11 Apr 2025 09:09:18 -0700 Subject: [PATCH 21/29] Re-org by AI --- cmd/state-mcp/handlers.go | 103 ++++++++++ cmd/state-mcp/main.go | 414 +------------------------------------- cmd/state-mcp/primer.go | 90 +++++++++ cmd/state-mcp/server.go | 109 ++++++++++ cmd/state-mcp/tools.go | 139 +++++++++++++ 5 files changed, 442 insertions(+), 413 deletions(-) create mode 100644 cmd/state-mcp/handlers.go create mode 100644 cmd/state-mcp/primer.go create mode 100644 cmd/state-mcp/server.go create mode 100644 cmd/state-mcp/tools.go diff --git a/cmd/state-mcp/handlers.go b/cmd/state-mcp/handlers.go new file mode 100644 index 0000000000..cc72209b67 --- /dev/null +++ b/cmd/state-mcp/handlers.go @@ -0,0 +1,103 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "strings" + + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/runners/cve" + "github.com/ActiveState/cli/internal/runners/manifest" + "github.com/ActiveState/cli/internal/runners/projects" + "github.com/mark3labs/mcp-go/mcp" +) + +// listProjectsHandler handles the list_projects tool +func (t *mcpServerHandler) listProjectsHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + var byt bytes.Buffer + prime, close, err := t.newPrimer("", &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + runner := projects.NewProjects(prime) + params := projects.NewParams() + err = runner.Run(params) + if err != nil { + return nil, errs.Wrap(err, "Failed to run projects") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +// manifestHandler handles the view_manifest tool +func (t *mcpServerHandler) manifestHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + pjPath := request.Params.Arguments["project_directory"].(string) + + var byt bytes.Buffer + prime, close, err := t.newPrimer(pjPath, &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + m := manifest.NewManifest(prime) + err = m.Run(manifest.Params{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run manifest") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +// cveHandler handles the view_cves tool +func (t *mcpServerHandler) cveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + pjPath := request.Params.Arguments["project_directory"].(string) + + var byt bytes.Buffer + prime, close, err := t.newPrimer(pjPath, &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + c := cve.NewCve(prime) + err = c.Run(&cve.Params{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run manifest") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +// lookupCveHandler handles the lookup_cve tool +func (t *mcpServerHandler) lookupCveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + cveId := request.Params.Arguments["cve_ids"].(string) + cveIds := strings.Split(cveId, ",") + + results, err := LookupCve(cveIds...) + if err != nil { + return nil, errs.Wrap(err, "Failed to lookup CVEs") + } + + byt, err := json.Marshal(results) + if err != nil { + return nil, errs.Wrap(err, "Failed to marshal results") + } + + return mcp.NewToolResultText(string(byt)), nil +} \ No newline at end of file diff --git a/cmd/state-mcp/main.go b/cmd/state-mcp/main.go index c549e8096e..6a35a1bc24 100644 --- a/cmd/state-mcp/main.go +++ b/cmd/state-mcp/main.go @@ -1,38 +1,13 @@ package main import ( - "bytes" - "context" - "encoding/json" "flag" "fmt" "os" - "strings" "time" - "github.com/ActiveState/cli/cmd/state/donotshipme" - "github.com/ActiveState/cli/internal/config" - "github.com/ActiveState/cli/internal/constants" - "github.com/ActiveState/cli/internal/constraints" - "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/events" - "github.com/ActiveState/cli/internal/installation" - "github.com/ActiveState/cli/internal/ipc" "github.com/ActiveState/cli/internal/logging" - "github.com/ActiveState/cli/internal/multilog" - "github.com/ActiveState/cli/internal/output" - "github.com/ActiveState/cli/internal/primer" - "github.com/ActiveState/cli/internal/runners/cve" - "github.com/ActiveState/cli/internal/runners/manifest" - "github.com/ActiveState/cli/internal/runners/projects" - "github.com/ActiveState/cli/internal/sliceutils" - "github.com/ActiveState/cli/internal/subshell" - "github.com/ActiveState/cli/internal/svcctl" - "github.com/ActiveState/cli/pkg/platform/authentication" - "github.com/ActiveState/cli/pkg/platform/model" - "github.com/ActiveState/cli/pkg/project" - "github.com/ActiveState/cli/pkg/projectfile" - "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) @@ -68,391 +43,4 @@ func main() { if err := server.ServeStdio(mcpHandler.mcpServer); err != nil { logging.Error("Server error: %v\n", err) } -} - -func registerServer() *mcpServerHandler { - ipcClient, svcPort, err := connectToSvc() - if err != nil { - panic(errs.JoinMessage(err)) - } - - // Create MCP server - s := server.NewMCPServer( - constants.CommandName, - constants.VersionNumber, - ) - - mcpHandler := &mcpServerHandler{ - mcpServer: s, - ipcClient: ipcClient, - svcPort: svcPort, - } - - return mcpHandler -} - -func registerRawTools(mcpHandler *mcpServerHandler) func() error { - byt := &bytes.Buffer{} - prime, close, err := mcpHandler.newPrimer("", byt) - if err != nil { - panic(err) - } - - require := func(b bool) mcp.PropertyOption { - if b { - return mcp.Required() - } - return func(map[string]interface{}) {} - } - - tree := donotshipme.CmdTree(prime) - for _, command := range tree.Command().AllChildren() { - // Best effort to filter out interactive commands - if sliceutils.Contains([]string{"activate", "shell"}, command.NameRecursive()) { - continue - } - - opts := []mcp.ToolOption{ - mcp.WithDescription(command.Description()), - } - - // Require project directory for most commands. This is currently not encoded into the command tree - if !sliceutils.Contains([]string{"projects", "auth"}, command.BaseCommand().Name()) { - opts = append(opts, mcp.WithString( - "project_directory", - mcp.Required(), - mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), - )) - } - - for _, arg := range command.Arguments() { - opts = append(opts, mcp.WithString(arg.Name, - require(arg.Required), - mcp.Description(arg.Description), - )) - } - for _, flag := range command.Flags() { - opts = append(opts, mcp.WithString(flag.Name, - mcp.Description(flag.Description), - )) - } - mcpHandler.addTool( - mcp.NewTool(strings.Join(strings.Split(command.NameRecursive(), " "), "_"), opts...), - func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - byt.Truncate(0) - if projectDir, ok := request.Params.Arguments["project_directory"]; ok { - pj, err := project.FromPath(projectDir.(string)) - if err != nil { - return nil, errs.Wrap(err, "Failed to create project") - } - prime.SetProject(pj) - } - // Reinitialize tree with updated primer, because currently our command can take things - // from the primer at the time of registration, and not the time of invocation. - invocationTree := donotshipme.CmdTree(prime) - for _, child := range invocationTree.Command().AllChildren() { - if child.NameRecursive() == command.NameRecursive() { - command = child - break - } - } - args := strings.Split(command.NameRecursive(), " ") - for _, arg := range command.Arguments() { - v, ok := request.Params.Arguments[arg.Name] - if !ok { - break - } - args = append(args, v.(string)) - } - for _, flag := range command.Flags() { - v, ok := request.Params.Arguments[flag.Name] - if !ok { - break - } - args = append(args, fmt.Sprintf("--%s=%s", flag.Name, v.(string))) - } - logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args == nil) - err := command.Execute(args) - if err != nil { - return nil, errs.Wrap(err, "Failed to execute command") - } - return mcp.NewToolResultText(byt.String()), nil - }, - ) - } - - return close -} - -func registerCuratedTools(mcpHandler *mcpServerHandler) { - projectDirParam := mcp.WithString("project_directory", - mcp.Required(), - mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), - ) - - mcpHandler.addTool(mcp.NewTool("list_projects", - mcp.WithDescription("List all ActiveState projects checked out on the local machine"), - ), mcpHandler.listProjectsHandler) - - mcpHandler.addTool(mcp.NewTool("view_manifest", - mcp.WithDescription("Show the manifest (packages and dependencies) for a locally checked out ActiveState platform project"), - projectDirParam, - ), mcpHandler.manifestHandler) - - mcpHandler.addTool(mcp.NewTool("view_cves", - mcp.WithDescription("Show the CVEs for a locally checked out ActiveState platform project"), - projectDirParam, - ), mcpHandler.cveHandler) - - mcpHandler.addTool(mcp.NewTool("lookup_cve", - mcp.WithDescription("Lookup one or more CVEs by their ID"), - mcp.WithString("cve_ids", - mcp.Required(), - mcp.Description("The IDs of the CVEs to lookup, comma separated"), - ), - ), mcpHandler.lookupCveHandler) -} - -type mcpServerHandler struct { - mcpServer *server.MCPServer - ipcClient *ipc.Client - svcPort string -} - -func (t *mcpServerHandler) addResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { - t.mcpServer.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { - defer func() { - if r := recover(); r != nil { - logging.Error("Recovered from resource handler panic: %v", r) - fmt.Printf("Recovered from resource handler panic: %v\n", r) - } - }() - logging.Debug("Received resource request: %s", resource.Name) - r, err := handler(ctx, request) - if err != nil { - logging.Error("%s: Error handling resource request: %v", resource.Name, err) - return nil, errs.Wrap(err, "Failed to handle resource request") - } - return r, nil - }) -} - -func (t *mcpServerHandler) addTool(tool mcp.Tool, handler server.ToolHandlerFunc) { - t.mcpServer.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - defer func() { - if r := recover(); r != nil { - logging.Error("Recovered from tool handler panic: %v", r) - fmt.Printf("Recovered from tool handler panic: %v\n", r) - } - }() - logging.Debug("Received tool request: %s", tool.Name) - r, err := handler(ctx, request) - logging.Debug("Received tool response from %s", tool.Name) - if err != nil { - logging.Error("%s: Error handling tool request: %v", tool.Name, errs.JoinMessage(err)) - // Format all errors as a single string, so the client gets the full context - return nil, fmt.Errorf("%s: %s", tool.Name, errs.JoinMessage(err)) - } - return r, nil - }) -} - -func (t *mcpServerHandler) listProjectsHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - var byt bytes.Buffer - prime, close, err := t.newPrimer("", &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - runner := projects.NewProjects(prime) - params := projects.NewParams() - err = runner.Run(params) - if err != nil { - return nil, errs.Wrap(err, "Failed to run projects") - } - - return mcp.NewToolResultText(byt.String()), nil -} - -func (t *mcpServerHandler) listProjectsResourceHandler(ctx context.Context, request mcp.ReadResourceRequest) (r []mcp.ResourceContents, rerr error) { - var byt bytes.Buffer - prime, close, err := t.newPrimer("", &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - runner := projects.NewProjects(prime) - params := projects.NewParams() - err = runner.Run(params) - if err != nil { - return nil, errs.Wrap(err, "Failed to run projects") - } - - r = append(r, mcp.TextResourceContents{Text: byt.String()}) - return r, nil -} - -func (t *mcpServerHandler) manifestHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - pjPath := request.Params.Arguments["project_directory"].(string) - - var byt bytes.Buffer - prime, close, err := t.newPrimer(pjPath, &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - m := manifest.NewManifest(prime) - err = m.Run(manifest.Params{}) - if err != nil { - return nil, errs.Wrap(err, "Failed to run manifest") - } - - return mcp.NewToolResultText(byt.String()), nil -} - -func (t *mcpServerHandler) cveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - pjPath := request.Params.Arguments["project_directory"].(string) - - var byt bytes.Buffer - prime, close, err := t.newPrimer(pjPath, &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - c := cve.NewCve(prime) - err = c.Run(&cve.Params{}) - if err != nil { - return nil, errs.Wrap(err, "Failed to run manifest") - } - - return mcp.NewToolResultText(byt.String()), nil -} - -func (t *mcpServerHandler) lookupCveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - cveId := request.Params.Arguments["cve_ids"].(string) - cveIds := strings.Split(cveId, ",") - - results, err := LookupCve(cveIds...) - if err != nil { - return nil, errs.Wrap(err, "Failed to lookup CVEs") - } - - byt, err := json.Marshal(results) - if err != nil { - return nil, errs.Wrap(err, "Failed to marshal results") - } - - return mcp.NewToolResultText(string(byt)), nil -} - -type stdOutput struct{} - -func (s *stdOutput) Notice(msg interface{}) { - logging.Info(fmt.Sprintf("%v", msg)) -} - -func connectToSvc() (*ipc.Client, string, error) { - svcExec, err := installation.ServiceExec() - if err != nil { - return nil, "", errs.Wrap(err, "Could not get service info") - } - - ipcClient := svcctl.NewDefaultIPCClient() - argText := strings.Join(os.Args, " ") - svcPort, err := svcctl.EnsureExecStartedAndLocateHTTP(ipcClient, svcExec, argText, &stdOutput{}) - if err != nil { - return nil, "", errs.Wrap(err, "Failed to start state-svc at state tool invocation") - } - - return ipcClient, svcPort, nil -} - -func (t *mcpServerHandler) newPrimer(projectDir string, o *bytes.Buffer) (*primer.Values, func() error, error) { - closers := []func() error{} - closer := func() error { - for _, c := range closers { - if err := c(); err != nil { - return err - } - } - return nil - } - - cfg, err := config.New() - if err != nil { - return nil, closer, errs.Wrap(err, "Failed to create config") - } - closers = append(closers, cfg.Close) - - auth := authentication.New(cfg) - closers = append(closers, auth.Close) - - out, err := output.New(string(output.SimpleFormatName), &output.Config{ - OutWriter: o, - ErrWriter: o, - Colored: false, - Interactive: false, - ShellName: "", - }) - if err != nil { - return nil, closer, errs.Wrap(err, "Failed to create output") - } - - var pj *project.Project - if projectDir != "" { - pjf, err := projectfile.FromPath(projectDir) - if err != nil { - return nil, closer, errs.Wrap(err, "Failed to create projectfile") - } - pj, err = project.New(pjf, out) - if err != nil { - return nil, closer, errs.Wrap(err, "Failed to create project") - } - } - - // Set up conditional, which accesses a lot of primer data - sshell := subshell.New(cfg) - - conditional := constraints.NewPrimeConditional(auth, pj, sshell.Shell()) - project.RegisterConditional(conditional) - if err := project.RegisterExpander("mixin", project.NewMixin(auth).Expander); err != nil { - logging.Debug("Could not register mixin expander: %v", err) - } - - svcmodel := model.NewSvcModel(t.svcPort) - - if auth.AvailableAPIToken() != "" { - jwt, err := svcmodel.GetJWT(context.Background()) - if err != nil { - multilog.Critical("Could not get JWT: %v", errs.JoinMessage(err)) - } - if err != nil || jwt == nil { - // Could not authenticate; user got logged out - auth.Logout() - } else { - auth.UpdateSession(jwt) - } - } - - return primer.New(pj, out, auth, sshell, conditional, cfg, t.ipcClient, svcmodel), closer, nil -} +} \ No newline at end of file diff --git a/cmd/state-mcp/primer.go b/cmd/state-mcp/primer.go new file mode 100644 index 0000000000..7215cd42dd --- /dev/null +++ b/cmd/state-mcp/primer.go @@ -0,0 +1,90 @@ +package main + +import ( + "context" + "io" + + "github.com/ActiveState/cli/internal/config" + "github.com/ActiveState/cli/internal/constraints" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/multilog" + "github.com/ActiveState/cli/internal/output" + "github.com/ActiveState/cli/internal/primer" + "github.com/ActiveState/cli/internal/subshell" + "github.com/ActiveState/cli/pkg/platform/authentication" + "github.com/ActiveState/cli/pkg/platform/model" + "github.com/ActiveState/cli/pkg/project" + "github.com/ActiveState/cli/pkg/projectfile" +) + +// newPrimer creates a new primer.Values instance for use with command execution +func (t *mcpServerHandler) newPrimer(projectDir string, o io.Writer) (*primer.Values, func() error, error) { + closers := []func() error{} + closer := func() error { + for _, c := range closers { + if err := c(); err != nil { + return err + } + } + return nil + } + + cfg, err := config.New() + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create config") + } + closers = append(closers, cfg.Close) + + auth := authentication.New(cfg) + closers = append(closers, auth.Close) + + out, err := output.New(string(output.SimpleFormatName), &output.Config{ + OutWriter: o, + ErrWriter: o, + Colored: false, + Interactive: false, + ShellName: "", + }) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create output") + } + + var pj *project.Project + if projectDir != "" { + pjf, err := projectfile.FromPath(projectDir) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create projectfile") + } + pj, err = project.New(pjf, out) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create project") + } + } + + // Set up conditional, which accesses a lot of primer data + sshell := subshell.New(cfg) + + conditional := constraints.NewPrimeConditional(auth, pj, sshell.Shell()) + project.RegisterConditional(conditional) + if err := project.RegisterExpander("mixin", project.NewMixin(auth).Expander); err != nil { + logging.Debug("Could not register mixin expander: %v", err) + } + + svcmodel := model.NewSvcModel(t.svcPort) + + if auth.AvailableAPIToken() != "" { + jwt, err := svcmodel.GetJWT(context.Background()) + if err != nil { + multilog.Critical("Could not get JWT: %v", errs.JoinMessage(err)) + } + if err != nil || jwt == nil { + // Could not authenticate; user got logged out + auth.Logout() + } else { + auth.UpdateSession(jwt) + } + } + + return primer.New(pj, out, auth, sshell, conditional, cfg, t.ipcClient, svcmodel), closer, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/server.go b/cmd/state-mcp/server.go new file mode 100644 index 0000000000..7a4461c98d --- /dev/null +++ b/cmd/state-mcp/server.go @@ -0,0 +1,109 @@ +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/installation" + "github.com/ActiveState/cli/internal/ipc" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/svcctl" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// mcpServerHandler wraps the MCP server and provides methods for adding tools and resources +type mcpServerHandler struct { + mcpServer *server.MCPServer + ipcClient *ipc.Client + svcPort string +} + +// registerServer creates and configures a new MCP server +func registerServer() *mcpServerHandler { + ipcClient, svcPort, err := connectToSvc() + if err != nil { + panic(errs.JoinMessage(err)) + } + + // Create MCP server + s := server.NewMCPServer( + constants.CommandName, + constants.VersionNumber, + ) + + mcpHandler := &mcpServerHandler{ + mcpServer: s, + ipcClient: ipcClient, + svcPort: svcPort, + } + + return mcpHandler +} + +// addResource adds a resource to the MCP server with error handling and logging +func (t *mcpServerHandler) addResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { + t.mcpServer.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + defer func() { + if r := recover(); r != nil { + logging.Error("Recovered from resource handler panic: %v", r) + fmt.Printf("Recovered from resource handler panic: %v\n", r) + } + }() + logging.Debug("Received resource request: %s", resource.Name) + r, err := handler(ctx, request) + if err != nil { + logging.Error("%s: Error handling resource request: %v", resource.Name, err) + return nil, errs.Wrap(err, "Failed to handle resource request") + } + return r, nil + }) +} + +// addTool adds a tool to the MCP server with error handling and logging +func (t *mcpServerHandler) addTool(tool mcp.Tool, handler server.ToolHandlerFunc) { + t.mcpServer.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + defer func() { + if r := recover(); r != nil { + logging.Error("Recovered from tool handler panic: %v", r) + fmt.Printf("Recovered from tool handler panic: %v\n", r) + } + }() + logging.Debug("Received tool request: %s", tool.Name) + r, err := handler(ctx, request) + logging.Debug("Received tool response from %s", tool.Name) + if err != nil { + logging.Error("%s: Error handling tool request: %v", tool.Name, errs.JoinMessage(err)) + // Format all errors as a single string, so the client gets the full context + return nil, fmt.Errorf("%s: %s", tool.Name, errs.JoinMessage(err)) + } + return r, nil + }) +} + +type stdOutput struct{} + +func (s *stdOutput) Notice(msg interface{}) { + logging.Info(fmt.Sprintf("%v", msg)) +} + +// connectToSvc connects to the state service and returns an IPC client +func connectToSvc() (*ipc.Client, string, error) { + svcExec, err := installation.ServiceExec() + if err != nil { + return nil, "", errs.Wrap(err, "Could not get service info") + } + + ipcClient := svcctl.NewDefaultIPCClient() + argText := strings.Join(os.Args, " ") + svcPort, err := svcctl.EnsureExecStartedAndLocateHTTP(ipcClient, svcExec, argText, &stdOutput{}) + if err != nil { + return nil, "", errs.Wrap(err, "Failed to start state-svc at state tool invocation") + } + + return ipcClient, svcPort, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/tools.go b/cmd/state-mcp/tools.go new file mode 100644 index 0000000000..6fd56133f0 --- /dev/null +++ b/cmd/state-mcp/tools.go @@ -0,0 +1,139 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "strings" + + "github.com/ActiveState/cli/cmd/state/donotshipme" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/sliceutils" + "github.com/ActiveState/cli/pkg/project" + "github.com/mark3labs/mcp-go/mcp" +) + +// registerCuratedTools registers a curated set of tools for the AI assistant +func registerCuratedTools(mcpHandler *mcpServerHandler) { + projectDirParam := mcp.WithString("project_directory", + mcp.Required(), + mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), + ) + + mcpHandler.addTool(mcp.NewTool("list_projects", + mcp.WithDescription("List all ActiveState projects checked out on the local machine"), + ), mcpHandler.listProjectsHandler) + + mcpHandler.addTool(mcp.NewTool("view_manifest", + mcp.WithDescription("Show the manifest (packages and dependencies) for a locally checked out ActiveState platform project"), + projectDirParam, + ), mcpHandler.manifestHandler) + + mcpHandler.addTool(mcp.NewTool("view_cves", + mcp.WithDescription("Show the CVEs for a locally checked out ActiveState platform project"), + projectDirParam, + ), mcpHandler.cveHandler) + + mcpHandler.addTool(mcp.NewTool("lookup_cve", + mcp.WithDescription("Lookup one or more CVEs by their ID"), + mcp.WithString("cve_ids", + mcp.Required(), + mcp.Description("The IDs of the CVEs to lookup, comma separated"), + ), + ), mcpHandler.lookupCveHandler) +} + +// registerRawTools registers all State Tool commands as raw tools +func registerRawTools(mcpHandler *mcpServerHandler) func() error { + byt := &bytes.Buffer{} + prime, close, err := mcpHandler.newPrimer("", byt) + if err != nil { + panic(err) + } + + require := func(b bool) mcp.PropertyOption { + if b { + return mcp.Required() + } + return func(map[string]interface{}) {} + } + + tree := donotshipme.CmdTree(prime) + for _, command := range tree.Command().AllChildren() { + // Best effort to filter out interactive commands + if sliceutils.Contains([]string{"activate", "shell"}, command.NameRecursive()) { + continue + } + + opts := []mcp.ToolOption{ + mcp.WithDescription(command.Description()), + } + + // Require project directory for most commands. This is currently not encoded into the command tree + if !sliceutils.Contains([]string{"projects", "auth"}, command.BaseCommand().Name()) { + opts = append(opts, mcp.WithString( + "project_directory", + mcp.Required(), + mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), + )) + } + + for _, arg := range command.Arguments() { + opts = append(opts, mcp.WithString(arg.Name, + require(arg.Required), + mcp.Description(arg.Description), + )) + } + for _, flag := range command.Flags() { + opts = append(opts, mcp.WithString(flag.Name, + mcp.Description(flag.Description), + )) + } + mcpHandler.addTool( + mcp.NewTool(strings.Join(strings.Split(command.NameRecursive(), " "), "_"), opts...), + func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + byt.Truncate(0) + if projectDir, ok := request.Params.Arguments["project_directory"]; ok { + pj, err := project.FromPath(projectDir.(string)) + if err != nil { + return nil, errs.Wrap(err, "Failed to create project") + } + prime.SetProject(pj) + } + // Reinitialize tree with updated primer, because currently our command can take things + // from the primer at the time of registration, and not the time of invocation. + invocationTree := donotshipme.CmdTree(prime) + for _, child := range invocationTree.Command().AllChildren() { + if child.NameRecursive() == command.NameRecursive() { + command = child + break + } + } + args := strings.Split(command.NameRecursive(), " ") + for _, arg := range command.Arguments() { + v, ok := request.Params.Arguments[arg.Name] + if !ok { + break + } + args = append(args, v.(string)) + } + for _, flag := range command.Flags() { + v, ok := request.Params.Arguments[flag.Name] + if !ok { + break + } + args = append(args, fmt.Sprintf("--%s=%s", flag.Name, v.(string))) + } + logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args == nil) + err := command.Execute(args) + if err != nil { + return nil, errs.Wrap(err, "Failed to execute command") + } + return mcp.NewToolResultText(byt.String()), nil + }, + ) + } + + return close +} \ No newline at end of file From 1d205866a64e6607d47640995667d11efcdf8c58 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 11 Apr 2025 09:17:25 -0700 Subject: [PATCH 22/29] Fix window target not being used --- activestate.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/activestate.yaml b/activestate.yaml index ab3a71ae8b..6333cb5368 100644 --- a/activestate.yaml +++ b/activestate.yaml @@ -32,6 +32,7 @@ constants: - name: BUILD_REMOTE_INSTALLER_TARGET value: state-remote-installer - name: BUILD_MCP_TARGET + if: ne .OS.Name "Windows" value: state-mcp - name: INTEGRATION_TEST_REGEX value: 'integration\|automation' From 0063d6f6820667e2330f4c33318cf4a012ccdeb3 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 11 Apr 2025 09:29:29 -0700 Subject: [PATCH 23/29] Implemented flawed script runner that doesn't pass back stdout as subshell doesnt support it --- cmd/state-mcp/main.go | 13 +++++++++---- cmd/state-mcp/tools.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/cmd/state-mcp/main.go b/cmd/state-mcp/main.go index 6a35a1bc24..88f62e5785 100644 --- a/cmd/state-mcp/main.go +++ b/cmd/state-mcp/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "os" + "runtime/debug" "time" "github.com/ActiveState/cli/internal/events" @@ -16,7 +17,7 @@ func main() { logging.Debug("Exiting") if r := recover(); r != nil { logging.Error("Recovered from panic: %v", r) - fmt.Printf("Recovered from panic: %v\n", r) + fmt.Printf("Recovered from panic: %v, stack: %s\n", r, string(debug.Stack())) os.Exit(1) } }() @@ -29,12 +30,16 @@ func main() { mcpHandler := registerServer() // Parse command line flags - rawFlag := flag.Bool("raw", false, "Expose all State Tool commands as tools; this will lead to issues and is not optimized for AI use") + rawFlag := flag.String("type", "", "Type of MCP server to run; raw, curated or scripts") flag.Parse() - if *rawFlag { + switch *rawFlag { + case "raw": close := registerRawTools(mcpHandler) defer close() - } else { + case "scripts": + close := registerScriptTools(mcpHandler) + defer close() + default: registerCuratedTools(mcpHandler) } diff --git a/cmd/state-mcp/tools.go b/cmd/state-mcp/tools.go index 6fd56133f0..fb70b11f21 100644 --- a/cmd/state-mcp/tools.go +++ b/cmd/state-mcp/tools.go @@ -4,16 +4,56 @@ import ( "bytes" "context" "fmt" + "os" "strings" "github.com/ActiveState/cli/cmd/state/donotshipme" + "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/scriptrun" "github.com/ActiveState/cli/internal/sliceutils" "github.com/ActiveState/cli/pkg/project" "github.com/mark3labs/mcp-go/mcp" ) +func registerScriptTools(mcpHandler *mcpServerHandler) func() error { + byt := &bytes.Buffer{} + prime, close, err := mcpHandler.newPrimer(os.Getenv(constants.ActivatedStateEnvVarName), byt) + if err != nil { + panic(err) + } + + scripts, err := prime.Project().Scripts() + if err != nil { + panic(err) + } + + for _, script := range scripts { + mcpHandler.addTool(mcp.NewTool(script.Name(), + mcp.WithDescription(script.Description()), + ), func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + byt.Truncate(0) + + scriptrunner := scriptrun.New(prime) + if !script.Standalone() && scriptrunner.NeedsActivation() { + if err := scriptrunner.PrepareVirtualEnv(); err != nil { + return nil, errs.Wrap(err, "Failed to prepare virtual environment") + } + } + + err := scriptrunner.Run(script, []string{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run script") + } + + return mcp.NewToolResultText(byt.String()), nil + }) + } + + return close +} + // registerCuratedTools registers a curated set of tools for the AI assistant func registerCuratedTools(mcpHandler *mcpServerHandler) { projectDirParam := mcp.WithString("project_directory", From 121eb845e964648d0192b365ada50a89176e4f2e Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 11 Jul 2025 15:53:02 -0700 Subject: [PATCH 24/29] Cleanup state-mcp --- cmd/state-mcp/handlers.go | 103 -- cmd/state-mcp/internal/mcpserver/server.go | 74 + .../internal/toolregistry/categories.go | 23 + .../internal/toolregistry/registry.go | 56 + cmd/state-mcp/internal/toolregistry/tools.go | 40 + cmd/state-mcp/lookupcve.go | 38 - cmd/state-mcp/lookupcve_test.go | 43 - cmd/state-mcp/main.go | 36 +- cmd/state-mcp/primer.go | 62 +- cmd/state-mcp/server.go | 109 -- cmd/state-mcp/server_test.go | 38 +- cmd/state-mcp/tools.go | 179 --- go.mod | 4 +- go.sum | 9 +- internal/output/json.go | 19 +- internal/output/mediator.go | 4 + internal/output/output.go | 7 + internal/output/plain.go | 24 +- internal/output/writer.go | 17 + internal/runners/hello/hello_example.go | 65 +- internal/testhelpers/outputhelper/outputer.go | 1 + .../mark3labs/mcp-go/client/client.go | 84 -- .../github.com/mark3labs/mcp-go/client/sse.go | 588 -------- .../mark3labs/mcp-go/client/stdio.go | 457 ------ .../mark3labs/mcp-go/client/types.go | 8 - .../mark3labs/mcp-go/mcp/prompts.go | 21 +- .../mark3labs/mcp-go/mcp/resources.go | 10 +- .../github.com/mark3labs/mcp-go/mcp/tools.go | 718 +++++++++- .../mark3labs/mcp-go/mcp/typed_tools.go | 20 + .../github.com/mark3labs/mcp-go/mcp/types.go | 539 ++++--- .../github.com/mark3labs/mcp-go/mcp/utils.go | 253 +++- .../mark3labs/mcp-go/server/errors.go | 34 + .../mark3labs/mcp-go/server/hooks.go | 79 +- .../mcp-go/server/http_transport_options.go | 11 + .../mcp-go/server/request_handler.go | 59 +- .../mark3labs/mcp-go/server/sampling.go | 37 + .../mark3labs/mcp-go/server/server.go | 783 +++++++---- .../mark3labs/mcp-go/server/session.go | 380 +++++ .../github.com/mark3labs/mcp-go/server/sse.go | 445 +++++- .../mark3labs/mcp-go/server/stdio.go | 247 +++- .../mcp-go/server/streamable_http.go | 655 +++++++++ .../mark3labs/mcp-go/util/logger.go | 33 + vendor/github.com/spf13/cast/.editorconfig | 15 + vendor/github.com/spf13/cast/.golangci.yaml | 39 + vendor/github.com/spf13/cast/.travis.yml | 15 - vendor/github.com/spf13/cast/Makefile | 4 +- vendor/github.com/spf13/cast/README.md | 16 +- vendor/github.com/spf13/cast/alias.go | 69 + vendor/github.com/spf13/cast/basic.go | 131 ++ vendor/github.com/spf13/cast/cast.go | 227 +-- vendor/github.com/spf13/cast/caste.go | 1249 ----------------- vendor/github.com/spf13/cast/indirect.go | 37 + vendor/github.com/spf13/cast/internal/time.go | 79 ++ .../cast/internal/timeformattype_string.go | 27 + vendor/github.com/spf13/cast/map.go | 224 +++ vendor/github.com/spf13/cast/number.go | 549 ++++++++ vendor/github.com/spf13/cast/slice.go | 106 ++ vendor/github.com/spf13/cast/time.go | 116 ++ vendor/github.com/spf13/cast/zz_generated.go | 261 ++++ vendor/modules.txt | 9 +- 60 files changed, 5814 insertions(+), 3771 deletions(-) delete mode 100644 cmd/state-mcp/handlers.go create mode 100644 cmd/state-mcp/internal/mcpserver/server.go create mode 100644 cmd/state-mcp/internal/toolregistry/categories.go create mode 100644 cmd/state-mcp/internal/toolregistry/registry.go create mode 100644 cmd/state-mcp/internal/toolregistry/tools.go delete mode 100644 cmd/state-mcp/lookupcve.go delete mode 100644 cmd/state-mcp/lookupcve_test.go delete mode 100644 cmd/state-mcp/server.go delete mode 100644 cmd/state-mcp/tools.go create mode 100644 internal/output/writer.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/client/client.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/client/sse.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/client/stdio.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/client/types.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/errors.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/sampling.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/session.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/streamable_http.go create mode 100644 vendor/github.com/mark3labs/mcp-go/util/logger.go create mode 100644 vendor/github.com/spf13/cast/.editorconfig create mode 100644 vendor/github.com/spf13/cast/.golangci.yaml delete mode 100644 vendor/github.com/spf13/cast/.travis.yml create mode 100644 vendor/github.com/spf13/cast/alias.go create mode 100644 vendor/github.com/spf13/cast/basic.go delete mode 100644 vendor/github.com/spf13/cast/caste.go create mode 100644 vendor/github.com/spf13/cast/indirect.go create mode 100644 vendor/github.com/spf13/cast/internal/time.go create mode 100644 vendor/github.com/spf13/cast/internal/timeformattype_string.go create mode 100644 vendor/github.com/spf13/cast/map.go create mode 100644 vendor/github.com/spf13/cast/number.go create mode 100644 vendor/github.com/spf13/cast/slice.go create mode 100644 vendor/github.com/spf13/cast/time.go create mode 100644 vendor/github.com/spf13/cast/zz_generated.go diff --git a/cmd/state-mcp/handlers.go b/cmd/state-mcp/handlers.go deleted file mode 100644 index cc72209b67..0000000000 --- a/cmd/state-mcp/handlers.go +++ /dev/null @@ -1,103 +0,0 @@ -package main - -import ( - "bytes" - "context" - "encoding/json" - "strings" - - "github.com/ActiveState/cli/internal/errs" - "github.com/ActiveState/cli/internal/runners/cve" - "github.com/ActiveState/cli/internal/runners/manifest" - "github.com/ActiveState/cli/internal/runners/projects" - "github.com/mark3labs/mcp-go/mcp" -) - -// listProjectsHandler handles the list_projects tool -func (t *mcpServerHandler) listProjectsHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - var byt bytes.Buffer - prime, close, err := t.newPrimer("", &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - runner := projects.NewProjects(prime) - params := projects.NewParams() - err = runner.Run(params) - if err != nil { - return nil, errs.Wrap(err, "Failed to run projects") - } - - return mcp.NewToolResultText(byt.String()), nil -} - -// manifestHandler handles the view_manifest tool -func (t *mcpServerHandler) manifestHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - pjPath := request.Params.Arguments["project_directory"].(string) - - var byt bytes.Buffer - prime, close, err := t.newPrimer(pjPath, &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - m := manifest.NewManifest(prime) - err = m.Run(manifest.Params{}) - if err != nil { - return nil, errs.Wrap(err, "Failed to run manifest") - } - - return mcp.NewToolResultText(byt.String()), nil -} - -// cveHandler handles the view_cves tool -func (t *mcpServerHandler) cveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - pjPath := request.Params.Arguments["project_directory"].(string) - - var byt bytes.Buffer - prime, close, err := t.newPrimer(pjPath, &byt) - if err != nil { - return nil, errs.Wrap(err, "Failed to create primer") - } - defer func() { - if err := close(); err != nil { - rerr = errs.Pack(rerr, err) - } - }() - - c := cve.NewCve(prime) - err = c.Run(&cve.Params{}) - if err != nil { - return nil, errs.Wrap(err, "Failed to run manifest") - } - - return mcp.NewToolResultText(byt.String()), nil -} - -// lookupCveHandler handles the lookup_cve tool -func (t *mcpServerHandler) lookupCveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - cveId := request.Params.Arguments["cve_ids"].(string) - cveIds := strings.Split(cveId, ",") - - results, err := LookupCve(cveIds...) - if err != nil { - return nil, errs.Wrap(err, "Failed to lookup CVEs") - } - - byt, err := json.Marshal(results) - if err != nil { - return nil, errs.Wrap(err, "Failed to marshal results") - } - - return mcp.NewToolResultText(string(byt)), nil -} \ No newline at end of file diff --git a/cmd/state-mcp/internal/mcpserver/server.go b/cmd/state-mcp/internal/mcpserver/server.go new file mode 100644 index 0000000000..bb0a565c8a --- /dev/null +++ b/cmd/state-mcp/internal/mcpserver/server.go @@ -0,0 +1,74 @@ +package mcpserver + +import ( + "context" + "fmt" + + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/primer" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +type ToolHandlerFunc func(context.Context, *primer.Values, mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// Handler wraps the MCP server and provides methods for adding tools and resources +type Handler struct { + Server *server.MCPServer + primeGetter func() (*primer.Values, func() error, error) +} + +func New(primeGetter func() (*primer.Values, func() error, error)) *Handler { + s := server.NewMCPServer( + constants.StateMCPCmd, + constants.VersionNumber, + ) + + mcpHandler := &Handler{ + Server: s, + primeGetter: primeGetter, + } + + return mcpHandler +} + +func (m Handler) ServeStdio() error { + if err := server.ServeStdio(m.Server); err != nil { + logging.Error("Server error: %v\n", err) + } + return nil +} + +// addResource adds a resource to the MCP server with error handling and logging +func (m *Handler) AddResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { + m.Server.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + r, err := handler(ctx, request) + if err != nil { + logging.Error("%s: Error handling resource request: %v", resource.Name, err) + return nil, errs.Wrap(err, "Failed to handle resource request") + } + return r, nil + }) +} + +// addTool adds a tool to the MCP server with error handling and logging +func (m *Handler) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { + m.Server.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + p, closer, err := m.primeGetter() + if err != nil { + return nil, errs.Wrap(err, "Failed to get primer") + } + defer closer() + r, err = handler(ctx, p, request) + if err != nil { + logging.Error("%s: Error handling tool request: %v", tool.Name, errs.JoinMessage(err)) + // Format all errors as a single string, so the client gets the full context + return nil, fmt.Errorf("%s: %s", tool.Name, errs.JoinMessage(err)) + } + return r, nil + }) +} + + diff --git a/cmd/state-mcp/internal/toolregistry/categories.go b/cmd/state-mcp/internal/toolregistry/categories.go new file mode 100644 index 0000000000..0960cd1d7f --- /dev/null +++ b/cmd/state-mcp/internal/toolregistry/categories.go @@ -0,0 +1,23 @@ +package toolregistry + +type ToolCategory string + +const ( + ToolCategoryDebug ToolCategory = "debug" +) + +type ToolCategories []ToolCategory + +func (c ToolCategories) String() []string { + result := []string{} + for _, category := range c { + result = append(result, string(category)) + } + return result +} + +func Categories() ToolCategories { + return ToolCategories{ + ToolCategoryDebug, + } +} \ No newline at end of file diff --git a/cmd/state-mcp/internal/toolregistry/registry.go b/cmd/state-mcp/internal/toolregistry/registry.go new file mode 100644 index 0000000000..4592ccc2d5 --- /dev/null +++ b/cmd/state-mcp/internal/toolregistry/registry.go @@ -0,0 +1,56 @@ +package toolregistry + +import ( + "context" + "slices" + + "github.com/ActiveState/cli/internal/primer" + "github.com/mark3labs/mcp-go/mcp" +) + +type Tool struct { + mcp.Tool + Category ToolCategory + Handler func(context.Context, *primer.Values, mcp.CallToolRequest) (*mcp.CallToolResult, error) +} + +type Registry struct { + tools map[ToolCategory][]Tool +} + +func New() *Registry { + r := &Registry{ + tools: make(map[ToolCategory][]Tool), + } + + r.RegisterTool(HelloWorldTool()) + + return r +} + +func (r *Registry) RegisterTool(tool Tool) { + if _, ok := r.tools[tool.Category]; !ok { + r.tools[tool.Category] = []Tool{} + } + r.tools[tool.Category] = append(r.tools[tool.Category], tool) +} + +func (r *Registry) GetTools(requestCategories ...string) []Tool { + if len(requestCategories) == 0 { + for _, category := range Categories() { + if category == ToolCategoryDebug { + // Debug must be explicitly requested + continue + } + requestCategories = append(requestCategories, string(category)) + } + } + categories := Categories() + result := []Tool{} + for _, category := range categories { + if slices.Contains(requestCategories, string(category)) { + result = append(result, r.tools[category]...) + } + } + return result +} \ No newline at end of file diff --git a/cmd/state-mcp/internal/toolregistry/tools.go b/cmd/state-mcp/internal/toolregistry/tools.go new file mode 100644 index 0000000000..7b216bc7ae --- /dev/null +++ b/cmd/state-mcp/internal/toolregistry/tools.go @@ -0,0 +1,40 @@ +package toolregistry + +import ( + "context" + "strings" + + "github.com/ActiveState/cli/internal/primer" + "github.com/ActiveState/cli/internal/runners/hello" + "github.com/mark3labs/mcp-go/mcp" +) + +func HelloWorldTool() Tool { + return Tool{ + Category: ToolCategoryDebug, + Tool: mcp.NewTool( + "hello", + mcp.WithDescription("Hello world tool"), + mcp.WithString("name", mcp.Required(), mcp.Description("The name to say hello to")), + ), + Handler: func(ctx context.Context, p *primer.Values, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + name, err := request.RequireString("name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + runner := hello.New(p) + params := hello.NewParams() + params.Name = name + + err = runner.Run(params) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + return mcp.NewToolResultText( + strings.Join(p.Output().History().Print, "\n"), + ), nil + }, + } +} \ No newline at end of file diff --git a/cmd/state-mcp/lookupcve.go b/cmd/state-mcp/lookupcve.go deleted file mode 100644 index 2c113b3821..0000000000 --- a/cmd/state-mcp/lookupcve.go +++ /dev/null @@ -1,38 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "net/http" - - "github.com/ActiveState/cli/internal/chanutils/workerpool" - "github.com/ActiveState/cli/internal/errs" -) - -func LookupCve(cveIds ...string) (map[string]interface{}, error) { - results := map[string]interface{}{} - // https://api.osv.dev/v1/vulns/OSV-2020-111 - wp := workerpool.New(5) - for _, cveId := range cveIds { - wp.Submit(func() error { - resp, err := http.Get(fmt.Sprintf("https://api.osv.dev/v1/vulns/%s", cveId)) - if err != nil { - return err - } - defer resp.Body.Close() - var result map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return err - } - results[cveId] = result - return nil - }) - } - - err := wp.Wait() - if err != nil { - return nil, errs.Wrap(err, "Failed to wait for workerpool") - } - - return results, nil -} \ No newline at end of file diff --git a/cmd/state-mcp/lookupcve_test.go b/cmd/state-mcp/lookupcve_test.go deleted file mode 100644 index e643f9672b..0000000000 --- a/cmd/state-mcp/lookupcve_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestLookupCve(t *testing.T) { - // Table-driven test cases - tests := []struct { - name string - cveIds []string - }{ - { - name: "Single CVE", - cveIds: []string{"CVE-2021-44228"}, - }, - { - name: "Multiple CVEs", - cveIds: []string{"CVE-2021-44228", "CVE-2022-22965"}, - }, - { - name: "Non-existent CVE", - cveIds: []string{"CVE-DOES-NOT-EXIST"}, - }, - { - name: "Empty Input", - cveIds: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - results, err := LookupCve(tt.cveIds...) - require.NoError(t, err) - require.NotNil(t, results) - for _, cveId := range tt.cveIds { - require.Contains(t, results, cveId) - } - }) - } -} \ No newline at end of file diff --git a/cmd/state-mcp/main.go b/cmd/state-mcp/main.go index 88f62e5785..0931ed11e4 100644 --- a/cmd/state-mcp/main.go +++ b/cmd/state-mcp/main.go @@ -5,11 +5,13 @@ import ( "fmt" "os" "runtime/debug" + "strings" "time" + "github.com/ActiveState/cli/cmd/state-mcp/internal/mcpserver" + "github.com/ActiveState/cli/cmd/state-mcp/internal/toolregistry" "github.com/ActiveState/cli/internal/events" "github.com/ActiveState/cli/internal/logging" - "github.com/mark3labs/mcp-go/server" ) func main() { @@ -27,25 +29,27 @@ func main() { } }() - mcpHandler := registerServer() - // Parse command line flags - rawFlag := flag.String("type", "", "Type of MCP server to run; raw, curated or scripts") + rawFlag := flag.String("categories", "", "Comma separated list of categories to register tools for") flag.Parse() - switch *rawFlag { - case "raw": - close := registerRawTools(mcpHandler) - defer close() - case "scripts": - close := registerScriptTools(mcpHandler) - defer close() - default: - registerCuratedTools(mcpHandler) - } + + mcps := setupServer(strings.Split(*rawFlag, ",")...) // Start the stdio server logging.Info("Starting MCP server") - if err := server.ServeStdio(mcpHandler.mcpServer); err != nil { + if err := mcps.ServeStdio(); err != nil { logging.Error("Server error: %v\n", err) } -} \ No newline at end of file +} + +func setupServer(categories ...string) *mcpserver.Handler { + mcps := mcpserver.New(newPrimer) + + registry := toolregistry.New() + tools := registry.GetTools(categories...) + for _, tool := range tools { + mcps.AddTool(tool.Tool, tool.Handler) + } + + return mcps +} diff --git a/cmd/state-mcp/primer.go b/cmd/state-mcp/primer.go index 7215cd42dd..5791b52ab6 100644 --- a/cmd/state-mcp/primer.go +++ b/cmd/state-mcp/primer.go @@ -2,24 +2,29 @@ package main import ( "context" + "fmt" "io" + "os" + "strings" "github.com/ActiveState/cli/internal/config" "github.com/ActiveState/cli/internal/constraints" "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/installation" + "github.com/ActiveState/cli/internal/ipc" "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/multilog" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/primer" "github.com/ActiveState/cli/internal/subshell" + "github.com/ActiveState/cli/internal/svcctl" "github.com/ActiveState/cli/pkg/platform/authentication" "github.com/ActiveState/cli/pkg/platform/model" "github.com/ActiveState/cli/pkg/project" - "github.com/ActiveState/cli/pkg/projectfile" ) // newPrimer creates a new primer.Values instance for use with command execution -func (t *mcpServerHandler) newPrimer(projectDir string, o io.Writer) (*primer.Values, func() error, error) { +func newPrimer() (*primer.Values, func() error, error) { closers := []func() error{} closer := func() error { for _, c := range closers { @@ -40,8 +45,8 @@ func (t *mcpServerHandler) newPrimer(projectDir string, o io.Writer) (*primer.Va closers = append(closers, auth.Close) out, err := output.New(string(output.SimpleFormatName), &output.Config{ - OutWriter: o, - ErrWriter: o, + OutWriter: io.Discard, // We use Outputer.History() instead + ErrWriter: io.Discard, // We use Outputer.History() instead Colored: false, Interactive: false, ShellName: "", @@ -50,28 +55,21 @@ func (t *mcpServerHandler) newPrimer(projectDir string, o io.Writer) (*primer.Va return nil, closer, errs.Wrap(err, "Failed to create output") } - var pj *project.Project - if projectDir != "" { - pjf, err := projectfile.FromPath(projectDir) - if err != nil { - return nil, closer, errs.Wrap(err, "Failed to create projectfile") - } - pj, err = project.New(pjf, out) - if err != nil { - return nil, closer, errs.Wrap(err, "Failed to create project") - } - } - // Set up conditional, which accesses a lot of primer data sshell := subshell.New(cfg) - conditional := constraints.NewPrimeConditional(auth, pj, sshell.Shell()) + conditional := constraints.NewPrimeConditional(auth, nil, sshell.Shell()) project.RegisterConditional(conditional) if err := project.RegisterExpander("mixin", project.NewMixin(auth).Expander); err != nil { logging.Debug("Could not register mixin expander: %v", err) } - svcmodel := model.NewSvcModel(t.svcPort) + ipcClient, svcPort, err := connectToSvc() + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to connect to service") + } + + svcmodel := model.NewSvcModel(svcPort) if auth.AvailableAPIToken() != "" { jwt, err := svcmodel.GetJWT(context.Background()) @@ -86,5 +84,29 @@ func (t *mcpServerHandler) newPrimer(projectDir string, o io.Writer) (*primer.Va } } - return primer.New(pj, out, auth, sshell, conditional, cfg, t.ipcClient, svcmodel), closer, nil -} \ No newline at end of file + return primer.New(out, auth, sshell, conditional, cfg, ipcClient, svcmodel), closer, nil +} + + +type stdOutput struct{} + +func (s *stdOutput) Notice(msg interface{}) { + logging.Info(fmt.Sprintf("%v", msg)) +} + +// connectToSvc connects to the state service and returns an IPC client +func connectToSvc() (*ipc.Client, string, error) { + svcExec, err := installation.ServiceExec() + if err != nil { + return nil, "", errs.Wrap(err, "Could not get service info") + } + + ipcClient := svcctl.NewDefaultIPCClient() + argText := strings.Join(os.Args, " ") + svcPort, err := svcctl.EnsureExecStartedAndLocateHTTP(ipcClient, svcExec, argText, &stdOutput{}) + if err != nil { + return nil, "", errs.Wrap(err, "Failed to start state-svc at state tool invocation") + } + + return ipcClient, svcPort, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/server.go b/cmd/state-mcp/server.go deleted file mode 100644 index 7a4461c98d..0000000000 --- a/cmd/state-mcp/server.go +++ /dev/null @@ -1,109 +0,0 @@ -package main - -import ( - "context" - "fmt" - "os" - "strings" - - "github.com/ActiveState/cli/internal/constants" - "github.com/ActiveState/cli/internal/errs" - "github.com/ActiveState/cli/internal/installation" - "github.com/ActiveState/cli/internal/ipc" - "github.com/ActiveState/cli/internal/logging" - "github.com/ActiveState/cli/internal/svcctl" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" -) - -// mcpServerHandler wraps the MCP server and provides methods for adding tools and resources -type mcpServerHandler struct { - mcpServer *server.MCPServer - ipcClient *ipc.Client - svcPort string -} - -// registerServer creates and configures a new MCP server -func registerServer() *mcpServerHandler { - ipcClient, svcPort, err := connectToSvc() - if err != nil { - panic(errs.JoinMessage(err)) - } - - // Create MCP server - s := server.NewMCPServer( - constants.CommandName, - constants.VersionNumber, - ) - - mcpHandler := &mcpServerHandler{ - mcpServer: s, - ipcClient: ipcClient, - svcPort: svcPort, - } - - return mcpHandler -} - -// addResource adds a resource to the MCP server with error handling and logging -func (t *mcpServerHandler) addResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { - t.mcpServer.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { - defer func() { - if r := recover(); r != nil { - logging.Error("Recovered from resource handler panic: %v", r) - fmt.Printf("Recovered from resource handler panic: %v\n", r) - } - }() - logging.Debug("Received resource request: %s", resource.Name) - r, err := handler(ctx, request) - if err != nil { - logging.Error("%s: Error handling resource request: %v", resource.Name, err) - return nil, errs.Wrap(err, "Failed to handle resource request") - } - return r, nil - }) -} - -// addTool adds a tool to the MCP server with error handling and logging -func (t *mcpServerHandler) addTool(tool mcp.Tool, handler server.ToolHandlerFunc) { - t.mcpServer.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - defer func() { - if r := recover(); r != nil { - logging.Error("Recovered from tool handler panic: %v", r) - fmt.Printf("Recovered from tool handler panic: %v\n", r) - } - }() - logging.Debug("Received tool request: %s", tool.Name) - r, err := handler(ctx, request) - logging.Debug("Received tool response from %s", tool.Name) - if err != nil { - logging.Error("%s: Error handling tool request: %v", tool.Name, errs.JoinMessage(err)) - // Format all errors as a single string, so the client gets the full context - return nil, fmt.Errorf("%s: %s", tool.Name, errs.JoinMessage(err)) - } - return r, nil - }) -} - -type stdOutput struct{} - -func (s *stdOutput) Notice(msg interface{}) { - logging.Info(fmt.Sprintf("%v", msg)) -} - -// connectToSvc connects to the state service and returns an IPC client -func connectToSvc() (*ipc.Client, string, error) { - svcExec, err := installation.ServiceExec() - if err != nil { - return nil, "", errs.Wrap(err, "Could not get service info") - } - - ipcClient := svcctl.NewDefaultIPCClient() - argText := strings.Join(os.Args, " ") - svcPort, err := svcctl.EnsureExecStartedAndLocateHTTP(ipcClient, svcExec, argText, &stdOutput{}) - if err != nil { - return nil, "", errs.Wrap(err, "Failed to start state-svc at state tool invocation") - } - - return ipcClient, svcPort, nil -} \ No newline at end of file diff --git a/cmd/state-mcp/server_test.go b/cmd/state-mcp/server_test.go index c43892e8e8..22807773d5 100644 --- a/cmd/state-mcp/server_test.go +++ b/cmd/state-mcp/server_test.go @@ -5,39 +5,25 @@ import ( "encoding/json" "testing" - "github.com/ActiveState/cli/internal/environment" + "github.com/ActiveState/cli/cmd/state-mcp/internal/toolregistry" + "github.com/ActiveState/cli/internal/logging" ) -func TestServerProjects(t *testing.T) { - t.Skip("Intended for manual testing") - mcpHandler := registerServer() - registerRawTools(mcpHandler) - - msg := mcpHandler.mcpServer.HandleMessage(context.Background(), json.RawMessage(`{ +func TestServerHello(t *testing.T) { + t.Skip(` +Fails due to state-svc not being detected when run as regular test, +works when running with debugger. Problem for another day. +`) + logging.CurrentHandler().SetVerbose(true) + mcpHandler := setupServer(string(toolregistry.ToolCategoryDebug)) + msg := mcpHandler.Server.HandleMessage(context.Background(), json.RawMessage(`{ "jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": { - "name": "projects", - "arguments": {} - } - }`)) - t.Fatalf("%+v", msg) -} - -func TestServerPackages(t *testing.T) { - t.Skip("Intended for manual testing") - mcpHandler := registerServer() - registerRawTools(mcpHandler) - - msg := mcpHandler.mcpServer.HandleMessage(context.Background(), json.RawMessage(`{ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/call", - "params": { - "name": "packages", + "name": "hello", "arguments": { - "project_directory": "`+environment.GetRootPathUnsafe()+`" + "name": "World" } } }`)) diff --git a/cmd/state-mcp/tools.go b/cmd/state-mcp/tools.go deleted file mode 100644 index fb70b11f21..0000000000 --- a/cmd/state-mcp/tools.go +++ /dev/null @@ -1,179 +0,0 @@ -package main - -import ( - "bytes" - "context" - "fmt" - "os" - "strings" - - "github.com/ActiveState/cli/cmd/state/donotshipme" - "github.com/ActiveState/cli/internal/constants" - "github.com/ActiveState/cli/internal/errs" - "github.com/ActiveState/cli/internal/logging" - "github.com/ActiveState/cli/internal/scriptrun" - "github.com/ActiveState/cli/internal/sliceutils" - "github.com/ActiveState/cli/pkg/project" - "github.com/mark3labs/mcp-go/mcp" -) - -func registerScriptTools(mcpHandler *mcpServerHandler) func() error { - byt := &bytes.Buffer{} - prime, close, err := mcpHandler.newPrimer(os.Getenv(constants.ActivatedStateEnvVarName), byt) - if err != nil { - panic(err) - } - - scripts, err := prime.Project().Scripts() - if err != nil { - panic(err) - } - - for _, script := range scripts { - mcpHandler.addTool(mcp.NewTool(script.Name(), - mcp.WithDescription(script.Description()), - ), func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - byt.Truncate(0) - - scriptrunner := scriptrun.New(prime) - if !script.Standalone() && scriptrunner.NeedsActivation() { - if err := scriptrunner.PrepareVirtualEnv(); err != nil { - return nil, errs.Wrap(err, "Failed to prepare virtual environment") - } - } - - err := scriptrunner.Run(script, []string{}) - if err != nil { - return nil, errs.Wrap(err, "Failed to run script") - } - - return mcp.NewToolResultText(byt.String()), nil - }) - } - - return close -} - -// registerCuratedTools registers a curated set of tools for the AI assistant -func registerCuratedTools(mcpHandler *mcpServerHandler) { - projectDirParam := mcp.WithString("project_directory", - mcp.Required(), - mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), - ) - - mcpHandler.addTool(mcp.NewTool("list_projects", - mcp.WithDescription("List all ActiveState projects checked out on the local machine"), - ), mcpHandler.listProjectsHandler) - - mcpHandler.addTool(mcp.NewTool("view_manifest", - mcp.WithDescription("Show the manifest (packages and dependencies) for a locally checked out ActiveState platform project"), - projectDirParam, - ), mcpHandler.manifestHandler) - - mcpHandler.addTool(mcp.NewTool("view_cves", - mcp.WithDescription("Show the CVEs for a locally checked out ActiveState platform project"), - projectDirParam, - ), mcpHandler.cveHandler) - - mcpHandler.addTool(mcp.NewTool("lookup_cve", - mcp.WithDescription("Lookup one or more CVEs by their ID"), - mcp.WithString("cve_ids", - mcp.Required(), - mcp.Description("The IDs of the CVEs to lookup, comma separated"), - ), - ), mcpHandler.lookupCveHandler) -} - -// registerRawTools registers all State Tool commands as raw tools -func registerRawTools(mcpHandler *mcpServerHandler) func() error { - byt := &bytes.Buffer{} - prime, close, err := mcpHandler.newPrimer("", byt) - if err != nil { - panic(err) - } - - require := func(b bool) mcp.PropertyOption { - if b { - return mcp.Required() - } - return func(map[string]interface{}) {} - } - - tree := donotshipme.CmdTree(prime) - for _, command := range tree.Command().AllChildren() { - // Best effort to filter out interactive commands - if sliceutils.Contains([]string{"activate", "shell"}, command.NameRecursive()) { - continue - } - - opts := []mcp.ToolOption{ - mcp.WithDescription(command.Description()), - } - - // Require project directory for most commands. This is currently not encoded into the command tree - if !sliceutils.Contains([]string{"projects", "auth"}, command.BaseCommand().Name()) { - opts = append(opts, mcp.WithString( - "project_directory", - mcp.Required(), - mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), - )) - } - - for _, arg := range command.Arguments() { - opts = append(opts, mcp.WithString(arg.Name, - require(arg.Required), - mcp.Description(arg.Description), - )) - } - for _, flag := range command.Flags() { - opts = append(opts, mcp.WithString(flag.Name, - mcp.Description(flag.Description), - )) - } - mcpHandler.addTool( - mcp.NewTool(strings.Join(strings.Split(command.NameRecursive(), " "), "_"), opts...), - func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { - byt.Truncate(0) - if projectDir, ok := request.Params.Arguments["project_directory"]; ok { - pj, err := project.FromPath(projectDir.(string)) - if err != nil { - return nil, errs.Wrap(err, "Failed to create project") - } - prime.SetProject(pj) - } - // Reinitialize tree with updated primer, because currently our command can take things - // from the primer at the time of registration, and not the time of invocation. - invocationTree := donotshipme.CmdTree(prime) - for _, child := range invocationTree.Command().AllChildren() { - if child.NameRecursive() == command.NameRecursive() { - command = child - break - } - } - args := strings.Split(command.NameRecursive(), " ") - for _, arg := range command.Arguments() { - v, ok := request.Params.Arguments[arg.Name] - if !ok { - break - } - args = append(args, v.(string)) - } - for _, flag := range command.Flags() { - v, ok := request.Params.Arguments[flag.Name] - if !ok { - break - } - args = append(args, fmt.Sprintf("--%s=%s", flag.Name, v.(string))) - } - logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args == nil) - err := command.Execute(args) - if err != nil { - return nil, errs.Wrap(err, "Failed to execute command") - } - return mcp.NewToolResultText(byt.String()), nil - }, - ) - } - - return close -} \ No newline at end of file diff --git a/go.mod b/go.mod index d5f3fd3da6..91350c7d99 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,7 @@ require ( github.com/rollbar/rollbar-go v1.1.0 github.com/shirou/gopsutil/v3 v3.24.5 github.com/skratchdot/open-golang v0.0.0-20190104022628-a2dfa6d0dab6 - github.com/spf13/cast v1.3.0 + github.com/spf13/cast v1.9.2 github.com/spf13/cobra v1.1.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.10.0 @@ -75,7 +75,7 @@ require ( github.com/go-git/go-git/v5 v5.13.1 github.com/gowebpki/jcs v1.0.1 github.com/klauspost/compress v1.11.4 - github.com/mark3labs/mcp-go v0.18.0 + github.com/mark3labs/mcp-go v0.33.0 github.com/mholt/archiver/v3 v3.5.1 github.com/zijiren233/yaml-comment v0.2.1 ) diff --git a/go.sum b/go.sum index 30abed107f..b80d8708cb 100644 --- a/go.sum +++ b/go.sum @@ -157,6 +157,8 @@ github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/fgprof v0.9.0 h1:1Unx04fyC3gn3RMH/GuwUF1UdlulLMpJV13jr9SOHvs= github.com/felixge/fgprof v0.9.0/go.mod h1:7/HK6JFtFaARhIljgP2IV8rJLIoHDoOYoUphsnGvqxE= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gammazero/deque v0.0.0-20200721202602-07291166fe33 h1:UG4wNrJX9xSKnm/Gck5yTbxnOhpNleuE4MQRdmcGySo= @@ -472,8 +474,8 @@ github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7 github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.18.0 h1:YuhgIVjNlTG2ZOwmrkORWyPTp0dz1opPEqvsPtySXao= -github.com/mark3labs/mcp-go v0.18.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= +github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc= +github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/maruel/natural v1.1.0 h1:2z1NgP/Vae+gYrtC0VuvrTJ6U35OuyUqDdfluLqMWuQ= @@ -628,8 +630,9 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.9.2 h1:SsGfm7M8QOFtEzumm7UZrZdLLquNdzFYfIbEXntcFbE= +github.com/spf13/cast v1.9.2/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v1.1.1 h1:KfztREH0tPxJJ+geloSLaAkaPkr4ki2Er5quFV1TDo4= github.com/spf13/cobra v1.1.1/go.mod h1:WnodtKOvamDL/PwE2M4iKs8aMDBZ5Q5klgD3qfVJQMI= diff --git a/internal/output/json.go b/internal/output/json.go index ad2b8e2939..a7dbc41eda 100644 --- a/internal/output/json.go +++ b/internal/output/json.go @@ -19,11 +19,12 @@ import ( type JSON struct { cfg *Config wroteOutput bool + history *OutputHistory } // NewJSON constructs a new JSON struct func NewJSON(config *Config) (JSON, error) { - return JSON{cfg: config}, nil + return JSON{cfg: config, history: &OutputHistory{}}, nil } // Type tells callers what type of outputer we are @@ -38,7 +39,11 @@ func (f *JSON) Print(v interface{}) { return } - f.Fprint(f.cfg.OutWriter, v) + w := NewWriteProxy(f.cfg.OutWriter, func(p []byte) { + f.history.Print = append(f.history.Print, string(p)) + }) + + f.Fprint(w, v) } // Fprint allows printing to a specific writer, using all the conveniences of the output package @@ -97,7 +102,11 @@ func (f *JSON) Error(value interface{}) { } b = []byte(colorize.StripColorCodes(string(b))) - _, err = f.cfg.OutWriter.Write(b) + w := NewWriteProxy(f.cfg.OutWriter, func(p []byte) { + f.history.Error = append(f.history.Error, string(p)) + }) + + _, err = w.Write(b) if err != nil { if isPipeClosedError(err) { logging.Error("Could not write json output, error: %v", err) // do not log to rollbar @@ -107,6 +116,10 @@ func (f *JSON) Error(value interface{}) { } } +func (f *JSON) History() *OutputHistory { + return f.history +} + func isPipeClosedError(err error) bool { pipeErr := errors.Is(err, syscall.EPIPE) if runtime.GOOS == "windows" && errors.Is(err, syscall.Errno(232)) { diff --git a/internal/output/mediator.go b/internal/output/mediator.go index 4dd71a69a8..fbc0cac065 100644 --- a/internal/output/mediator.go +++ b/internal/output/mediator.go @@ -52,6 +52,10 @@ func (m *Mediator) Notice(v interface{}) { m.Outputer.Notice(v) } +func (m *Mediator) History() *OutputHistory { + return m.Outputer.History() +} + func mediatorValue(v interface{}, format Format) interface{} { if format.IsStructured() { if vt, ok := v.(StructuredMarshaller); ok { diff --git a/internal/output/output.go b/internal/output/output.go index 529f2cea3c..282cd50125 100644 --- a/internal/output/output.go +++ b/internal/output/output.go @@ -29,6 +29,12 @@ const ( var ErrNotRecognized = errs.New("Not Recognized") +type OutputHistory struct { + Print []string + Error []string + Notice []string +} + // Outputer is the initialized formatter type Outputer interface { Fprint(writer io.Writer, value interface{}) @@ -37,6 +43,7 @@ type Outputer interface { Notice(value interface{}) Type() Format Config() *Config + History() *OutputHistory } // lastCreated is here for specific legacy use cases diff --git a/internal/output/plain.go b/internal/output/plain.go index 14e8088bbb..c810dc330f 100644 --- a/internal/output/plain.go +++ b/internal/output/plain.go @@ -49,11 +49,12 @@ const ( // Struct keys are localized by sending them to the locale library as field_key (lowercase) type Plain struct { cfg *Config + history *OutputHistory`` } // NewPlain constructs a new Plain struct func NewPlain(config *Config) (Plain, error) { - return Plain{config}, nil + return Plain{cfg: config, history: &OutputHistory{}}, nil } // Type tells callers what type of outputer we are @@ -68,20 +69,29 @@ func (f *Plain) Fprint(writer io.Writer, v interface{}) { // Print will marshal and print the given value to the output writer func (f *Plain) Print(value interface{}) { - f.write(f.cfg.OutWriter, value) - f.write(f.cfg.OutWriter, "\n") + w := NewWriteProxy(f.cfg.OutWriter, func(p []byte) { + f.history.Print = append(f.history.Print, string(p)) + }) + f.write(w, value) + f.write(w, "\n") } // Error will marshal and print the given value to the error writer, it wraps it in the error format but otherwise the // only thing that identifies it as an error is the channel it writes it to func (f *Plain) Error(value interface{}) { - f.write(f.cfg.ErrWriter, fmt.Sprintf("[ERROR]%s[/RESET]\n", value)) + w := NewWriteProxy(f.cfg.ErrWriter, func(p []byte) { + f.history.Error = append(f.history.Error, string(p)) + }) + f.write(w, fmt.Sprintf("[ERROR]%s[/RESET]\n", value)) } // Notice will marshal and print the given value to the error writer, it wraps it in the notice format but otherwise the // only thing that identifies it as an error is the channel it writes it to func (f *Plain) Notice(value interface{}) { - f.write(f.cfg.ErrWriter, fmt.Sprintf("%s\n", value)) + w := NewWriteProxy(f.cfg.ErrWriter, func(p []byte) { + f.history.Notice = append(f.history.Notice, string(p)) + }) + f.write(w, fmt.Sprintf("%s\n", value)) } // Config returns the Config struct for the active instance @@ -89,6 +99,10 @@ func (f *Plain) Config() *Config { return f.cfg } +func (f *Plain) History() *OutputHistory { + return f.history +} + // write is a little helper that just takes care of marshalling the value and sending it to the requested writer func (f *Plain) write(writer io.Writer, value interface{}) { v, err := sprint(value, nil) diff --git a/internal/output/writer.go b/internal/output/writer.go new file mode 100644 index 0000000000..6613dd6605 --- /dev/null +++ b/internal/output/writer.go @@ -0,0 +1,17 @@ +package output + +import "io" + +type WriteProxy struct { + w io.Writer + onWrite func(p []byte) +} + +func NewWriteProxy(w io.Writer, onWrite func(p []byte)) *WriteProxy { + return &WriteProxy{w: w, onWrite: onWrite} +} + +func (w *WriteProxy) Write(p []byte) (n int, err error) { + w.onWrite(p) + return w.w.Write(p) +} \ No newline at end of file diff --git a/internal/runners/hello/hello_example.go b/internal/runners/hello/hello_example.go index a13ee016c3..b8cbe22bc0 100644 --- a/internal/runners/hello/hello_example.go +++ b/internal/runners/hello/hello_example.go @@ -10,16 +10,13 @@ package hello import ( "errors" + "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/locale" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/primer" "github.com/ActiveState/cli/internal/runbits/example" - "github.com/ActiveState/cli/internal/runbits/rationalize" - "github.com/ActiveState/cli/pkg/localcommit" "github.com/ActiveState/cli/pkg/platform/authentication" - "github.com/ActiveState/cli/pkg/platform/model" - "github.com/ActiveState/cli/pkg/project" ) // primeable describes the app-level dependencies that a runner will need. @@ -49,7 +46,6 @@ func NewParams() *Params { // function. type Hello struct { out output.Outputer - project *project.Project auth *authentication.Auth } @@ -58,7 +54,6 @@ type Hello struct { func New(p primeable) *Hello { return &Hello{ out: p.Output(), - project: p.Project(), auth: p.Auth(), } } @@ -77,15 +72,6 @@ func rationalizeError(err *error) { // Ensure we wrap the top-level error returned from the runner and not // the unpacked error that we are inspecting. *err = errs.WrapUserFacing(*err, locale.Tl("hello_err_no_name", "Cannot say hello because no name was provided.")) - case errors.Is(*err, rationalize.ErrNoProject): - // It's useful to offer users reasonable tips on recourses. - *err = errs.WrapUserFacing( - *err, - locale.Tl("hello_err_no_project", "Cannot say hello because you are not in a project directory."), - errs.SetTips( - locale.Tl("hello_suggest_checkout", "Try using '[ACTIONABLE]state checkout[/RESET]' first."), - ), - ) } } @@ -95,10 +81,6 @@ func (h *Hello) Run(params *Params) (rerr error) { h.out.Print(locale.Tl("hello_notice", "This command is for example use only")) - if h.project == nil { - return rationalize.ErrNoProject - } - // Reusable runner logic is contained within the runbits package. // You should only use this if you intend to share logic between // runners. Runners should NEVER invoke other runners. @@ -120,50 +102,11 @@ func (h *Hello) Run(params *Params) (rerr error) { return nil } - // Grab data from the platform. - commitMsg, err := currentCommitMessage(h.project, h.auth) - if err != nil { - err = errs.Wrap( - err, "Cannot get commit message", - ) - return errs.AddTips( - err, - locale.Tl("hello_info_suggest_ensure_commit", "Ensure project has commits"), - ) - } - h.out.Print(locale.Tl( "hello_extra_info", - "Project: {{.V0}}\nCurrent commit message: {{.V1}}", - h.project.Namespace().String(), commitMsg, + "You are on commit {{.V0}}", + constants.RevisionHashShort, )) return nil -} - -// currentCommitMessage contains the scope in which the current commit message -// is obtained. Since it is a sort of construction function that has some -// complexity, it is helpful to provide localized error context. Secluding this -// sort of logic is helpful to keep the subhandlers clean. -func currentCommitMessage(proj *project.Project, auth *authentication.Auth) (string, error) { - if proj == nil { - return "", errs.New("Cannot determine which project to use") - } - - commitId, err := localcommit.Get(proj.Dir()) - if err != nil { - return "", errs.Wrap(err, "Cannot determine which commit to use") - } - - commit, err := model.GetCommit(commitId, auth) - if err != nil { - return "", errs.Wrap(err, "Cannot get commit from server") - } - - commitMsg := locale.Tl("hello_info_warn_no_commit", "Commit description not provided.") - if commit.Message != "" { - commitMsg = commit.Message - } - - return commitMsg, nil -} +} \ No newline at end of file diff --git a/internal/testhelpers/outputhelper/outputer.go b/internal/testhelpers/outputhelper/outputer.go index 18a9b30da2..7e0bdb30c9 100644 --- a/internal/testhelpers/outputhelper/outputer.go +++ b/internal/testhelpers/outputhelper/outputer.go @@ -14,3 +14,4 @@ func (o *TestOutputer) Fprint(w io.Writer, value interface{}) {} func (o *TestOutputer) Error(value interface{}) {} func (o *TestOutputer) Notice(value interface{}) {} func (o *TestOutputer) Config() *output.Config { return nil } +func (o *TestOutputer) History() *output.OutputHistory { return nil } \ No newline at end of file diff --git a/vendor/github.com/mark3labs/mcp-go/client/client.go b/vendor/github.com/mark3labs/mcp-go/client/client.go deleted file mode 100644 index 1d3cb1051e..0000000000 --- a/vendor/github.com/mark3labs/mcp-go/client/client.go +++ /dev/null @@ -1,84 +0,0 @@ -// Package client provides MCP (Model Control Protocol) client implementations. -package client - -import ( - "context" - - "github.com/mark3labs/mcp-go/mcp" -) - -// MCPClient represents an MCP client interface -type MCPClient interface { - // Initialize sends the initial connection request to the server - Initialize( - ctx context.Context, - request mcp.InitializeRequest, - ) (*mcp.InitializeResult, error) - - // Ping checks if the server is alive - Ping(ctx context.Context) error - - // ListResources requests a list of available resources from the server - ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, - ) (*mcp.ListResourcesResult, error) - - // ListResourceTemplates requests a list of available resource templates from the server - ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, - ) (*mcp.ListResourceTemplatesResult, - error) - - // ReadResource reads a specific resource from the server - ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, - ) (*mcp.ReadResourceResult, error) - - // Subscribe requests notifications for changes to a specific resource - Subscribe(ctx context.Context, request mcp.SubscribeRequest) error - - // Unsubscribe cancels notifications for a specific resource - Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error - - // ListPrompts requests a list of available prompts from the server - ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, - ) (*mcp.ListPromptsResult, error) - - // GetPrompt retrieves a specific prompt from the server - GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, - ) (*mcp.GetPromptResult, error) - - // ListTools requests a list of available tools from the server - ListTools( - ctx context.Context, - request mcp.ListToolsRequest, - ) (*mcp.ListToolsResult, error) - - // CallTool invokes a specific tool on the server - CallTool( - ctx context.Context, - request mcp.CallToolRequest, - ) (*mcp.CallToolResult, error) - - // SetLevel sets the logging level for the server - SetLevel(ctx context.Context, request mcp.SetLevelRequest) error - - // Complete requests completion options for a given argument - Complete( - ctx context.Context, - request mcp.CompleteRequest, - ) (*mcp.CompleteResult, error) - - // Close client connection and cleanup resources - Close() error - - // OnNotification registers a handler for notifications - OnNotification(handler func(notification mcp.JSONRPCNotification)) -} diff --git a/vendor/github.com/mark3labs/mcp-go/client/sse.go b/vendor/github.com/mark3labs/mcp-go/client/sse.go deleted file mode 100644 index cf4a1028e0..0000000000 --- a/vendor/github.com/mark3labs/mcp-go/client/sse.go +++ /dev/null @@ -1,588 +0,0 @@ -package client - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/mark3labs/mcp-go/mcp" -) - -// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). -// It maintains a persistent HTTP connection to receive server-pushed events -// while sending requests over regular HTTP POST calls. The client handles -// automatic reconnection and message routing between requests and responses. -type SSEMCPClient struct { - baseURL *url.URL - endpoint *url.URL - httpClient *http.Client - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - endpointChan chan struct{} - capabilities mcp.ServerCapabilities - headers map[string]string - sseReadTimeout time.Duration -} - -type ClientOption func(*SSEMCPClient) - -func WithHeaders(headers map[string]string) ClientOption { - return func(sc *SSEMCPClient) { - sc.headers = headers - } -} - -func WithSSEReadTimeout(timeout time.Duration) ClientOption { - return func(sc *SSEMCPClient) { - sc.sseReadTimeout = timeout - } -} - -// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. -// Returns an error if the URL is invalid. -func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { - parsedURL, err := url.Parse(baseURL) - if err != nil { - return nil, fmt.Errorf("invalid URL: %w", err) - } - - smc := &SSEMCPClient{ - baseURL: parsedURL, - httpClient: &http.Client{}, - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), - endpointChan: make(chan struct{}), - sseReadTimeout: 30 * time.Second, - headers: make(map[string]string), - } - - for _, opt := range options { - opt(smc) - } - - return smc, nil -} - -// Start initiates the SSE connection to the server and waits for the endpoint information. -// Returns an error if the connection fails or times out waiting for the endpoint. -func (c *SSEMCPClient) Start(ctx context.Context) error { - - req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) - - if err != nil { - - return fmt.Errorf("failed to create request: %w", err) - - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to connect to SSE stream: %w", err) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - go c.readSSE(resp.Body) - - // Wait for the endpoint to be received - - select { - case <-c.endpointChan: - // Endpoint received, proceed - case <-ctx.Done(): - return fmt.Errorf("context cancelled while waiting for endpoint") - case <-time.After(30 * time.Second): // Add a timeout - return fmt.Errorf("timeout waiting for endpoint") - } - - return nil -} - -// readSSE continuously reads the SSE stream and processes events. -// It runs until the connection is closed or an error occurs. -func (c *SSEMCPClient) readSSE(reader io.ReadCloser) { - defer reader.Close() - - br := bufio.NewReader(reader) - var event, data string - - ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout) - defer cancel() - - for { - select { - case <-ctx.Done(): - return - default: - line, err := br.ReadString('\n') - if err != nil { - if err == io.EOF { - // Process any pending event before exit - if event != "" && data != "" { - c.handleSSEEvent(event, data) - } - break - } - select { - case <-c.done: - return - default: - fmt.Printf("SSE stream error: %v\n", err) - return - } - } - - // Remove only newline markers - line = strings.TrimRight(line, "\r\n") - if line == "" { - // Empty line means end of event - if event != "" && data != "" { - c.handleSSEEvent(event, data) - event = "" - data = "" - } - continue - } - - if strings.HasPrefix(line, "event:") { - event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) - } - } - } -} - -// handleSSEEvent processes SSE events based on their type. -// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. -func (c *SSEMCPClient) handleSSEEvent(event, data string) { - switch event { - case "endpoint": - endpoint, err := c.baseURL.Parse(data) - if err != nil { - fmt.Printf("Error parsing endpoint URL: %v\n", err) - return - } - if endpoint.Host != c.baseURL.Host { - fmt.Printf("Endpoint origin does not match connection origin\n") - return - } - c.endpoint = endpoint - close(c.endpointChan) - - case "message": - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { - fmt.Printf("Error unmarshaling message: %v\n", err) - return - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - return - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - return - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *SSEMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// Returns the raw JSON response message or an error if the request fails. -func (c *SSEMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - if c.endpoint == nil { - return nil, fmt.Errorf("endpoint not received") - } - - id := c.requestID.Add(1) - - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(requestBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - // set custom http headers - for k, v := range c.headers { - req.Header.Set(k, v) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusAccepted { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf( - "request failed with status %d: %s", - resp.StatusCode, - body, - ) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *SSEMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // Ensure we send a params object with all required fields - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(notificationBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create notification request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - resp.Body.Close() - - c.initialized = true - return &result, nil -} - -func (c *SSEMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *SSEMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - response, err := c.sendRequest(ctx, "resources/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListResourcesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/templates/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourceTemplatesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *SSEMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *SSEMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *SSEMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - response, err := c.sendRequest(ctx, "prompts/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListPromptsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *SSEMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - response, err := c.sendRequest(ctx, "tools/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListToolsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *SSEMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *SSEMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *SSEMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -// Helper methods - -// GetEndpoint returns the current endpoint URL for the SSE connection. -func (c *SSEMCPClient) GetEndpoint() *url.URL { - return c.endpoint -} - -// Close shuts down the SSE client connection and cleans up any pending responses. -// Returns an error if the shutdown process fails. -func (c *SSEMCPClient) Close() error { - select { - case <-c.done: - return nil // Already closed - default: - close(c.done) - } - - // Clean up any pending responses - c.mu.Lock() - for _, ch := range c.responses { - close(ch) - } - c.responses = make(map[int64]chan RPCResponse) - c.mu.Unlock() - - return nil -} diff --git a/vendor/github.com/mark3labs/mcp-go/client/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/stdio.go deleted file mode 100644 index 8e0845dca6..0000000000 --- a/vendor/github.com/mark3labs/mcp-go/client/stdio.go +++ /dev/null @@ -1,457 +0,0 @@ -package client - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "os/exec" - "sync" - "sync/atomic" - - "github.com/mark3labs/mcp-go/mcp" -) - -// StdioMCPClient implements the MCPClient interface using stdio communication. -// It launches a subprocess and communicates with it via standard input/output streams -// using JSON-RPC messages. The client handles message routing between requests and -// responses, and supports asynchronous notifications. -type StdioMCPClient struct { - cmd *exec.Cmd - stdin io.WriteCloser - stdout *bufio.Reader - stderr io.ReadCloser - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - capabilities mcp.ServerCapabilities -} - -// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. -// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. -// Returns an error if the subprocess cannot be started or the pipes cannot be created. -func NewStdioMCPClient( - command string, - env []string, - args ...string, -) (*StdioMCPClient, error) { - cmd := exec.Command(command, args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, env...) - - cmd.Env = mergedEnv - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdin pipe: %w", err) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) - } - - stderr, err := cmd.StderrPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe: %w", err) - } - - client := &StdioMCPClient{ - cmd: cmd, - stdin: stdin, - stderr: stderr, - stdout: bufio.NewReader(stdout), - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), - } - - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start command: %w", err) - } - - // Start reading responses in a goroutine and wait for it to be ready - ready := make(chan struct{}) - go func() { - close(ready) - client.readResponses() - }() - <-ready - - return client, nil -} - -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. -func (c *StdioMCPClient) Close() error { - close(c.done) - if err := c.stdin.Close(); err != nil { - return fmt.Errorf("failed to close stdin: %w", err) - } - if err := c.stderr.Close(); err != nil { - return fmt.Errorf("failed to close stderr: %w", err) - } - return c.cmd.Wait() -} - -// Stderr returns a reader for the stderr output of the subprocess. -// This can be used to capture error messages or logs from the subprocess. -func (c *StdioMCPClient) Stderr() io.Reader { - return c.stderr -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *StdioMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// readResponses continuously reads and processes responses from the server's stdout. -// It handles both responses to requests and notifications, routing them appropriately. -// Runs until the done channel is closed or an error occurs reading from stdout. -func (c *StdioMCPClient) readResponses() { - for { - select { - case <-c.done: - return - default: - line, err := c.stdout.ReadString('\n') - if err != nil { - if err != io.EOF { - fmt.Printf("Error reading response: %v\n", err) - } - return - } - - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { - continue - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(line), ¬ification); err != nil { - continue - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - continue - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } - } -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// It creates a unique request ID, sends the request over stdin, and waits for -// the corresponding response or context cancellation. -// Returns the raw JSON response message or an error if the request fails. -func (c *StdioMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - id := c.requestID.Add(1) - - // Create the complete request structure - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := c.stdin.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *StdioMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *StdioMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // This structure ensures Capabilities is always included in JSON - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - notificationBytes = append(notificationBytes, '\n') - - if _, err := c.stdin.Write(notificationBytes); err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - - c.initialized = true - return &result, nil -} - -func (c *StdioMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp. - ListResourcesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourcesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp. - ListResourceTemplatesResult, error) { - response, err := c.sendRequest( - ctx, - "resources/templates/list", - request.Params, - ) - if err != nil { - return nil, err - } - - var result mcp.ListResourceTemplatesResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, - error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *StdioMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *StdioMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *StdioMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - response, err := c.sendRequest(ctx, "prompts/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListPromptsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *StdioMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - response, err := c.sendRequest(ctx, "tools/list", request.Params) - if err != nil { - return nil, err - } - - var result mcp.ListToolsResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -func (c *StdioMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *StdioMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *StdioMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} diff --git a/vendor/github.com/mark3labs/mcp-go/client/types.go b/vendor/github.com/mark3labs/mcp-go/client/types.go deleted file mode 100644 index 4402bd0240..0000000000 --- a/vendor/github.com/mark3labs/mcp-go/client/types.go +++ /dev/null @@ -1,8 +0,0 @@ -package client - -import "encoding/json" - -type RPCResponse struct { - Error *string - Response *json.RawMessage -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go index bc12a72976..a63a214503 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go @@ -19,12 +19,14 @@ type ListPromptsResult struct { // server. type GetPromptRequest struct { Request - Params struct { - // The name of the prompt or prompt template. - Name string `json:"name"` - // Arguments to use for templating the prompt. - Arguments map[string]string `json:"arguments,omitempty"` - } `json:"params"` + Params GetPromptParams `json:"params"` +} + +type GetPromptParams struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` } // GetPromptResult is the server's response to a prompts/get request from the @@ -50,6 +52,11 @@ type Prompt struct { Arguments []PromptArgument `json:"arguments,omitempty"` } +// GetName returns the name of the prompt. +func (p Prompt) GetName() string { + return p.Name +} + // PromptArgument describes an argument that a prompt template can accept. // When a prompt includes arguments, clients must provide values for all // required arguments when making a prompts/get request. @@ -78,7 +85,7 @@ const ( // resources from the MCP server. type PromptMessage struct { Role Role `json:"role"` - Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource + Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource } // PromptListChangedNotification is an optional notification from the server diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go index 51cdd25dd3..07a59a3223 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go @@ -43,10 +43,7 @@ func WithMIMEType(mimeType string) ResourceOption { func WithAnnotations(audience []Role, priority float64) ResourceOption { return func(r *Resource) { if r.Annotations == nil { - r.Annotations = &struct { - Audience []Role `json:"audience,omitempty"` - Priority float64 `json:"priority,omitempty"` - }{} + r.Annotations = &Annotations{} } r.Annotations.Audience = audience r.Annotations.Priority = priority @@ -94,10 +91,7 @@ func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { return func(t *ResourceTemplate) { if t.Annotations == nil { - t.Annotations = &struct { - Audience []Role `json:"audience,omitempty"` - Priority float64 `json:"priority,omitempty"` - }{} + t.Annotations = &Annotations{} } t.Annotations.Audience = audience t.Annotations.Priority = priority diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go index c4c1b1dec0..3e3931b09c 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -4,6 +4,8 @@ import ( "encoding/json" "errors" "fmt" + "reflect" + "strconv" ) var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") @@ -33,7 +35,7 @@ type ListToolsResult struct { // should be reported as an MCP error response. type CallToolResult struct { Result - Content []Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource + Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -43,19 +45,420 @@ type CallToolResult struct { // CallToolRequest is used by the client to invoke a tool provided by the server. type CallToolRequest struct { Request - Params struct { - Name string `json:"name"` - Arguments map[string]interface{} `json:"arguments,omitempty"` - Meta *struct { - // If specified, the caller is requesting out-of-band progress - // notifications for this request (as represented by - // notifications/progress). The value of this parameter is an - // opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these - // notifications. - ProgressToken ProgressToken `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` - } `json:"params"` + Params CallToolParams `json:"params"` +} + +type CallToolParams struct { + Name string `json:"name"` + Arguments any `json:"arguments,omitempty"` + Meta *Meta `json:"_meta,omitempty"` +} + +// GetArguments returns the Arguments as map[string]any for backward compatibility +// If Arguments is not a map, it returns an empty map +func (r CallToolRequest) GetArguments() map[string]any { + if args, ok := r.Params.Arguments.(map[string]any); ok { + return args + } + return nil +} + +// GetRawArguments returns the Arguments as-is without type conversion +// This allows users to access the raw arguments in any format +func (r CallToolRequest) GetRawArguments() any { + return r.Params.Arguments +} + +// BindArguments unmarshals the Arguments into the provided struct +// This is useful for working with strongly-typed arguments +func (r CallToolRequest) BindArguments(target any) error { + if target == nil || reflect.ValueOf(target).Kind() != reflect.Ptr { + return fmt.Errorf("target must be a non-nil pointer") + } + + // Fast-path: already raw JSON + if raw, ok := r.Params.Arguments.(json.RawMessage); ok { + return json.Unmarshal(raw, target) + } + + data, err := json.Marshal(r.Params.Arguments) + if err != nil { + return fmt.Errorf("failed to marshal arguments: %w", err) + } + + return json.Unmarshal(data, target) +} + +// GetString returns a string argument by key, or the default value if not found +func (r CallToolRequest) GetString(key string, defaultValue string) string { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return defaultValue +} + +// RequireString returns a string argument by key, or an error if not found or not a string +func (r CallToolRequest) RequireString(key string) (string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str, nil + } + return "", fmt.Errorf("argument %q is not a string", key) + } + return "", fmt.Errorf("required argument %q not found", key) +} + +// GetInt returns an int argument by key, or the default value if not found +func (r CallToolRequest) GetInt(key string, defaultValue int) int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v + case float64: + return int(v) + case string: + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + } + return defaultValue +} + +// RequireInt returns an int argument by key, or an error if not found or not convertible to int +func (r CallToolRequest) RequireInt(key string) (int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v, nil + case float64: + return int(v), nil + case string: + if i, err := strconv.Atoi(v); err == nil { + return i, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to int", key) + default: + return 0, fmt.Errorf("argument %q is not an int", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetFloat returns a float64 argument by key, or the default value if not found +func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v + case int: + return float64(v) + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + } + return defaultValue +} + +// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64 +func (r CallToolRequest) RequireFloat(key string) (float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v, nil + case int: + return float64(v), nil + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to float64", key) + default: + return 0, fmt.Errorf("argument %q is not a float64", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetBool returns a bool argument by key, or the default value if not found +func (r CallToolRequest) GetBool(key string, defaultValue bool) bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b + } + case int: + return v != 0 + case float64: + return v != 0 + } + } + return defaultValue +} + +// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool +func (r CallToolRequest) RequireBool(key string) (bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v, nil + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b, nil + } + return false, fmt.Errorf("argument %q cannot be converted to bool", key) + case int: + return v != 0, nil + case float64: + return v != 0, nil + default: + return false, fmt.Errorf("argument %q is not a bool", key) + } + } + return false, fmt.Errorf("required argument %q not found", key) +} + +// GetStringSlice returns a string slice argument by key, or the default value if not found +func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v + case []any: + result := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result + } + } + return defaultValue +} + +// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice +func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v, nil + case []any: + result := make([]string, 0, len(v)) + for i, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } else { + return nil, fmt.Errorf("item %d in argument %q is not a string", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a string slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetIntSlice returns an int slice argument by key, or the default value if not found +func (r CallToolRequest) GetIntSlice(key string, defaultValue []int) []int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v + case []any: + result := make([]int, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } + } + } + return result + } + } + return defaultValue +} + +// RequireIntSlice returns an int slice argument by key, or an error if not found or not convertible to int slice +func (r CallToolRequest) RequireIntSlice(key string) ([]int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v, nil + case []any: + result := make([]int, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to int", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not an int", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not an int slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetFloatSlice returns a float64 slice argument by key, or the default value if not found +func (r CallToolRequest) GetFloatSlice(key string, defaultValue []float64) []float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v + case []any: + result := make([]float64, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } + } + } + return result + } + } + return defaultValue +} + +// RequireFloatSlice returns a float64 slice argument by key, or an error if not found or not convertible to float64 slice +func (r CallToolRequest) RequireFloatSlice(key string) ([]float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v, nil + case []any: + result := make([]float64, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to float64", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not a float64", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a float64 slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetBoolSlice returns a bool slice argument by key, or the default value if not found +func (r CallToolRequest) GetBoolSlice(key string, defaultValue []bool) []bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v + case []any: + result := make([]bool, 0, len(v)) + for _, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + } + } + return result + } + } + return defaultValue +} + +// RequireBoolSlice returns a bool slice argument by key, or an error if not found or not convertible to bool slice +func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v, nil + case []any: + result := make([]bool, 0, len(v)) + for i, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to bool", i, key) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + default: + return nil, fmt.Errorf("item %d in argument %q is not a bool", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a bool slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) } // ToolListChangedNotification is an optional notification from the server to @@ -75,13 +478,20 @@ type Tool struct { InputSchema ToolInputSchema `json:"inputSchema"` // Alternative to InputSchema - allows arbitrary JSON Schema to be provided RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // Optional properties describing tool behavior + Annotations ToolAnnotation `json:"annotations"` +} + +// GetName returns the name of the tool. +func (t Tool) GetName() string { + return t.Name } // MarshalJSON implements the json.Marshaler interface for Tool. // It handles marshaling either InputSchema or RawInputSchema based on which is set. func (t Tool) MarshalJSON() ([]byte, error) { // Create a map to build the JSON structure - m := make(map[string]interface{}, 3) + m := make(map[string]any, 3) // Add the name and description m["name"] = t.Name @@ -100,13 +510,45 @@ func (t Tool) MarshalJSON() ([]byte, error) { m["inputSchema"] = t.InputSchema } + m["annotations"] = t.Annotations + return json.Marshal(m) } type ToolInputSchema struct { - Type string `json:"type"` - Properties map[string]interface{} `json:"properties"` - Required []string `json:"required,omitempty"` + Type string `json:"type"` + Properties map[string]any `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +// MarshalJSON implements the json.Marshaler interface for ToolInputSchema. +func (tis ToolInputSchema) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + m["type"] = tis.Type + + // Marshal Properties to '{}' rather than `nil` when its length equals zero + if tis.Properties != nil { + m["properties"] = tis.Properties + } + + if len(tis.Required) > 0 { + m["required"] = tis.Required + } + + return json.Marshal(m) +} + +type ToolAnnotation struct { + // Human-readable title for the tool + Title string `json:"title,omitempty"` + // If true, the tool does not modify its environment + ReadOnlyHint *bool `json:"readOnlyHint,omitempty"` + // If true, the tool may perform destructive updates + DestructiveHint *bool `json:"destructiveHint,omitempty"` + // If true, repeated calls with same args have no additional effect + IdempotentHint *bool `json:"idempotentHint,omitempty"` + // If true, tool interacts with external entities + OpenWorldHint *bool `json:"openWorldHint,omitempty"` } // ToolOption is a function that configures a Tool. @@ -115,7 +557,7 @@ type ToolOption func(*Tool) // PropertyOption is a function that configures a property in a Tool's input schema. // It allows for flexible configuration of JSON Schema properties using the functional options pattern. -type PropertyOption func(map[string]interface{}) +type PropertyOption func(map[string]any) // // Core Tool Functions @@ -129,9 +571,16 @@ func NewTool(name string, opts ...ToolOption) Tool { Name: name, InputSchema: ToolInputSchema{ Type: "object", - Properties: make(map[string]interface{}), + Properties: make(map[string]any), Required: nil, // Will be omitted from JSON if empty }, + Annotations: ToolAnnotation{ + Title: "", + ReadOnlyHint: ToBoolPtr(false), + DestructiveHint: ToBoolPtr(true), + IdempotentHint: ToBoolPtr(false), + OpenWorldHint: ToBoolPtr(true), + }, } for _, opt := range opts { @@ -166,6 +615,53 @@ func WithDescription(description string) ToolOption { } } +// WithToolAnnotation adds optional hints about the Tool. +func WithToolAnnotation(annotation ToolAnnotation) ToolOption { + return func(t *Tool) { + t.Annotations = annotation + } +} + +// WithTitleAnnotation sets the Title field of the Tool's Annotations. +// It provides a human-readable title for the tool. +func WithTitleAnnotation(title string) ToolOption { + return func(t *Tool) { + t.Annotations.Title = title + } +} + +// WithReadOnlyHintAnnotation sets the ReadOnlyHint field of the Tool's Annotations. +// If true, it indicates the tool does not modify its environment. +func WithReadOnlyHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.ReadOnlyHint = &value + } +} + +// WithDestructiveHintAnnotation sets the DestructiveHint field of the Tool's Annotations. +// If true, it indicates the tool may perform destructive updates. +func WithDestructiveHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.DestructiveHint = &value + } +} + +// WithIdempotentHintAnnotation sets the IdempotentHint field of the Tool's Annotations. +// If true, it indicates repeated calls with the same arguments have no additional effect. +func WithIdempotentHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.IdempotentHint = &value + } +} + +// WithOpenWorldHintAnnotation sets the OpenWorldHint field of the Tool's Annotations. +// If true, it indicates the tool interacts with external entities. +func WithOpenWorldHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.OpenWorldHint = &value + } +} + // // Common Property Options // @@ -173,7 +669,7 @@ func WithDescription(description string) ToolOption { // Description adds a description to a property in the JSON Schema. // The description should explain the purpose and expected values of the property. func Description(desc string) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["description"] = desc } } @@ -181,7 +677,7 @@ func Description(desc string) PropertyOption { // Required marks a property as required in the tool's input schema. // Required properties must be provided when using the tool. func Required() PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["required"] = true } } @@ -189,7 +685,7 @@ func Required() PropertyOption { // Title adds a display-friendly title to a property in the JSON Schema. // This title can be used by UI components to show a more readable property name. func Title(title string) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["title"] = title } } @@ -201,7 +697,7 @@ func Title(title string) PropertyOption { // DefaultString sets the default value for a string property. // This value will be used if the property is not explicitly provided. func DefaultString(value string) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["default"] = value } } @@ -209,7 +705,7 @@ func DefaultString(value string) PropertyOption { // Enum specifies a list of allowed values for a string property. // The property value must be one of the specified enum values. func Enum(values ...string) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["enum"] = values } } @@ -217,7 +713,7 @@ func Enum(values ...string) PropertyOption { // MaxLength sets the maximum length for a string property. // The string value must not exceed this length. func MaxLength(max int) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["maxLength"] = max } } @@ -225,7 +721,7 @@ func MaxLength(max int) PropertyOption { // MinLength sets the minimum length for a string property. // The string value must be at least this length. func MinLength(min int) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["minLength"] = min } } @@ -233,7 +729,7 @@ func MinLength(min int) PropertyOption { // Pattern sets a regex pattern that a string property must match. // The string value must conform to the specified regular expression. func Pattern(pattern string) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["pattern"] = pattern } } @@ -245,7 +741,7 @@ func Pattern(pattern string) PropertyOption { // DefaultNumber sets the default value for a number property. // This value will be used if the property is not explicitly provided. func DefaultNumber(value float64) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["default"] = value } } @@ -253,7 +749,7 @@ func DefaultNumber(value float64) PropertyOption { // Max sets the maximum value for a number property. // The number value must not exceed this maximum. func Max(max float64) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["maximum"] = max } } @@ -261,7 +757,7 @@ func Max(max float64) PropertyOption { // Min sets the minimum value for a number property. // The number value must not be less than this minimum. func Min(min float64) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["minimum"] = min } } @@ -269,7 +765,7 @@ func Min(min float64) PropertyOption { // MultipleOf specifies that a number must be a multiple of the given value. // The number value must be divisible by this value. func MultipleOf(value float64) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["multipleOf"] = value } } @@ -281,7 +777,19 @@ func MultipleOf(value float64) PropertyOption { // DefaultBool sets the default value for a boolean property. // This value will be used if the property is not explicitly provided. func DefaultBool(value bool) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// +// Array Property Options +// + +// DefaultArray sets the default value for an array property. +// This value will be used if the property is not explicitly provided. +func DefaultArray[T any](value []T) PropertyOption { + return func(schema map[string]any) { schema["default"] = value } } @@ -294,7 +802,7 @@ func DefaultBool(value bool) PropertyOption { // It accepts property options to configure the boolean property's behavior and constraints. func WithBoolean(name string, opts ...PropertyOption) ToolOption { return func(t *Tool) { - schema := map[string]interface{}{ + schema := map[string]any{ "type": "boolean", } @@ -316,7 +824,7 @@ func WithBoolean(name string, opts ...PropertyOption) ToolOption { // It accepts property options to configure the number property's behavior and constraints. func WithNumber(name string, opts ...PropertyOption) ToolOption { return func(t *Tool) { - schema := map[string]interface{}{ + schema := map[string]any{ "type": "number", } @@ -338,7 +846,7 @@ func WithNumber(name string, opts ...PropertyOption) ToolOption { // It accepts property options to configure the string property's behavior and constraints. func WithString(name string, opts ...PropertyOption) ToolOption { return func(t *Tool) { - schema := map[string]interface{}{ + schema := map[string]any{ "type": "string", } @@ -360,9 +868,9 @@ func WithString(name string, opts ...PropertyOption) ToolOption { // It accepts property options to configure the object property's behavior and constraints. func WithObject(name string, opts ...PropertyOption) ToolOption { return func(t *Tool) { - schema := map[string]interface{}{ + schema := map[string]any{ "type": "object", - "properties": map[string]interface{}{}, + "properties": map[string]any{}, } for _, opt := range opts { @@ -383,7 +891,7 @@ func WithObject(name string, opts ...PropertyOption) ToolOption { // It accepts property options to configure the array property's behavior and constraints. func WithArray(name string, opts ...PropertyOption) ToolOption { return func(t *Tool) { - schema := map[string]interface{}{ + schema := map[string]any{ "type": "array", } @@ -402,65 +910,169 @@ func WithArray(name string, opts ...PropertyOption) ToolOption { } // Properties defines the properties for an object schema -func Properties(props map[string]interface{}) PropertyOption { - return func(schema map[string]interface{}) { +func Properties(props map[string]any) PropertyOption { + return func(schema map[string]any) { schema["properties"] = props } } // AdditionalProperties specifies whether additional properties are allowed in the object // or defines a schema for additional properties -func AdditionalProperties(schema interface{}) PropertyOption { - return func(schemaMap map[string]interface{}) { +func AdditionalProperties(schema any) PropertyOption { + return func(schemaMap map[string]any) { schemaMap["additionalProperties"] = schema } } // MinProperties sets the minimum number of properties for an object func MinProperties(min int) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["minProperties"] = min } } // MaxProperties sets the maximum number of properties for an object func MaxProperties(max int) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["maxProperties"] = max } } // PropertyNames defines a schema for property names in an object -func PropertyNames(schema map[string]interface{}) PropertyOption { - return func(schemaMap map[string]interface{}) { +func PropertyNames(schema map[string]any) PropertyOption { + return func(schemaMap map[string]any) { schemaMap["propertyNames"] = schema } } -// Items defines the schema for array items -func Items(schema interface{}) PropertyOption { - return func(schemaMap map[string]interface{}) { +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. +func Items(schema any) PropertyOption { + return func(schemaMap map[string]any) { schemaMap["items"] = schema } } // MinItems sets the minimum number of items for an array func MinItems(min int) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["minItems"] = min } } // MaxItems sets the maximum number of items for an array func MaxItems(max int) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["maxItems"] = max } } // UniqueItems specifies whether array items must be unique func UniqueItems(unique bool) PropertyOption { - return func(schema map[string]interface{}) { + return func(schema map[string]any) { schema["uniqueItems"] = unique } } + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go new file mode 100644 index 0000000000..68d8cdd1f3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go @@ -0,0 +1,20 @@ +package mcp + +import ( + "context" + "fmt" +) + +// TypedToolHandlerFunc is a function that handles a tool call with typed arguments +type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) + +// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args T + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + return handler(ctx, request, args) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go index a3ad8174e6..241b55ce9b 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/types.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -1,9 +1,12 @@ -// Package mcp defines the core types and interfaces for the Model Control Protocol (MCP). +// Package mcp defines the core types and interfaces for the Model Context Protocol (MCP). // MCP is a protocol for communication between LLM-powered applications and their supporting services. package mcp import ( "encoding/json" + "fmt" + "maps" + "strconv" "github.com/yosida95/uritemplate/v3" ) @@ -11,41 +14,59 @@ import ( type MCPMethod string const ( - // Initiates connection and negotiates protocol capabilities. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + // MethodInitialize initiates connection and negotiates protocol capabilities. + // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization MethodInitialize MCPMethod = "initialize" - // Verifies connection liveness between client and server. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + // MethodPing verifies connection liveness between client and server. + // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ MethodPing MCPMethod = "ping" - // Lists all available server resources. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // MethodResourcesList lists all available server resources. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesList MCPMethod = "resources/list" - // Provides URI templates for constructing resource URIs. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // MethodResourcesTemplatesList provides URI templates for constructing resource URIs. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesTemplatesList MCPMethod = "resources/templates/list" - // Retrieves content of a specific resource by URI. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // MethodResourcesRead retrieves content of a specific resource by URI. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesRead MCPMethod = "resources/read" - // Lists all available prompt templates. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + // MethodPromptsList lists all available prompt templates. + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ MethodPromptsList MCPMethod = "prompts/list" - // Retrieves a specific prompt template with filled parameters. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + // MethodPromptsGet retrieves a specific prompt template with filled parameters. + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ MethodPromptsGet MCPMethod = "prompts/get" - // Lists all available executable tools. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + // MethodToolsList lists all available executable tools. + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsList MCPMethod = "tools/list" - // Invokes a specific tool with provided parameters. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + // MethodToolsCall invokes a specific tool with provided parameters. + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsCall MCPMethod = "tools/call" + + // MethodSetLogLevel configures the minimum log level for client + // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging + MethodSetLogLevel MCPMethod = "logging/setLevel" + + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification + MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + + MethodNotificationResourceUpdated = "notifications/resources/updated" + + // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification + MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + + // MethodNotificationToolsListChanged notifies when the list of available tools changes. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ + MethodNotificationToolsListChanged = "notifications/tools/list_changed" ) type URITemplate struct { @@ -53,7 +74,7 @@ type URITemplate struct { } func (t *URITemplate) MarshalJSON() ([]byte, error) { - return json.Marshal(t.Template.Raw()) + return json.Marshal(t.Raw()) } func (t *URITemplate) UnmarshalJSON(data []byte) error { @@ -72,36 +93,73 @@ func (t *URITemplate) UnmarshalJSON(data []byte) error { /* JSON-RPC types */ // JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError -type JSONRPCMessage interface{} +type JSONRPCMessage any // LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. -const LATEST_PROTOCOL_VERSION = "2024-11-05" +const LATEST_PROTOCOL_VERSION = "2025-03-26" + +// ValidProtocolVersions lists all known valid MCP protocol versions. +var ValidProtocolVersions = []string{ + "2024-11-05", + LATEST_PROTOCOL_VERSION, +} // JSONRPC_VERSION is the version of JSON-RPC used by MCP. const JSONRPC_VERSION = "2.0" // ProgressToken is used to associate progress notifications with the original request. -type ProgressToken interface{} +type ProgressToken any // Cursor is an opaque token used to represent a cursor for pagination. type Cursor string +// Meta is metadata attached to a request's parameters. This can include fields +// formally defined by the protocol or other arbitrary data. +type Meta struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken + + // AdditionalFields are any fields present in the Meta that are not + // otherwise defined in the protocol. + AdditionalFields map[string]any +} + +func (m *Meta) MarshalJSON() ([]byte, error) { + raw := make(map[string]any) + if m.ProgressToken != nil { + raw["progressToken"] = m.ProgressToken + } + maps.Copy(raw, m.AdditionalFields) + + return json.Marshal(raw) +} + +func (m *Meta) UnmarshalJSON(data []byte) error { + raw := make(map[string]any) + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + m.ProgressToken = raw["progressToken"] + delete(raw, "progressToken") + m.AdditionalFields = raw + return nil +} + type Request struct { - Method string `json:"method"` - Params struct { - Meta *struct { - // If specified, the caller is requesting out-of-band progress - // notifications for this request (as represented by - // notifications/progress). The value of this parameter is an - // opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these - // notifications. - ProgressToken ProgressToken `json:"progressToken,omitempty"` - } `json:"_meta,omitempty"` - } `json:"params,omitempty"` -} - -type Params map[string]interface{} + Method string `json:"method"` + Params RequestParams `json:"params,omitempty"` +} + +type RequestParams struct { + Meta *Meta `json:"_meta,omitempty"` +} + +type Params map[string]any type Notification struct { Method string `json:"method"` @@ -111,16 +169,16 @@ type Notification struct { type NotificationParams struct { // This parameter name is reserved by MCP to allow clients and // servers to attach additional metadata to their notifications. - Meta map[string]interface{} `json:"_meta,omitempty"` + Meta map[string]any `json:"_meta,omitempty"` // Additional fields can be added to this map - AdditionalFields map[string]interface{} `json:"-"` + AdditionalFields map[string]any `json:"-"` } // MarshalJSON implements custom JSON marshaling func (p NotificationParams) MarshalJSON() ([]byte, error) { // Create a map to hold all fields - m := make(map[string]interface{}) + m := make(map[string]any) // Add Meta if it exists if p.Meta != nil { @@ -141,24 +199,24 @@ func (p NotificationParams) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements custom JSON unmarshaling func (p *NotificationParams) UnmarshalJSON(data []byte) error { // Create a map to hold all fields - var m map[string]interface{} + var m map[string]any if err := json.Unmarshal(data, &m); err != nil { return err } // Initialize maps if they're nil if p.Meta == nil { - p.Meta = make(map[string]interface{}) + p.Meta = make(map[string]any) } if p.AdditionalFields == nil { - p.AdditionalFields = make(map[string]interface{}) + p.AdditionalFields = make(map[string]any) } // Process all fields for k, v := range m { if k == "_meta" { // Handle Meta field - if meta, ok := v.(map[string]interface{}); ok { + if meta, ok := v.(map[string]any); ok { p.Meta = meta } } else { @@ -173,18 +231,86 @@ func (p *NotificationParams) UnmarshalJSON(data []byte) error { type Result struct { // This result property is reserved by the protocol to allow clients and // servers to attach additional metadata to their responses. - Meta map[string]interface{} `json:"_meta,omitempty"` + Meta map[string]any `json:"_meta,omitempty"` } // RequestId is a uniquely identifying ID for a request in JSON-RPC. // It can be any JSON-serializable value, typically a number or string. -type RequestId interface{} +type RequestId struct { + value any +} + +// NewRequestId creates a new RequestId with the given value +func NewRequestId(value any) RequestId { + return RequestId{value: value} +} + +// Value returns the underlying value of the RequestId +func (r RequestId) Value() any { + return r.value +} + +// String returns a string representation of the RequestId +func (r RequestId) String() string { + switch v := r.value.(type) { + case string: + return "string:" + v + case int64: + return "int64:" + strconv.FormatInt(v, 10) + case float64: + if v == float64(int64(v)) { + return "int64:" + strconv.FormatInt(int64(v), 10) + } + return "float64:" + strconv.FormatFloat(v, 'f', -1, 64) + case nil: + return "" + default: + return "unknown:" + fmt.Sprintf("%v", v) + } +} + +// IsNil returns true if the RequestId is nil +func (r RequestId) IsNil() bool { + return r.value == nil +} + +func (r RequestId) MarshalJSON() ([]byte, error) { + return json.Marshal(r.value) +} + +func (r *RequestId) UnmarshalJSON(data []byte) error { + + if string(data) == "null" { + r.value = nil + return nil + } + + // Try unmarshaling as string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + r.value = s + return nil + } + + // JSON numbers are unmarshaled as float64 in Go + var f float64 + if err := json.Unmarshal(data, &f); err == nil { + if f == float64(int64(f)) { + r.value = int64(f) + } else { + r.value = f + } + return nil + } + + return fmt.Errorf("invalid request id: %s", string(data)) +} // JSONRPCRequest represents a request that expects a response. type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID RequestId `json:"id"` - Params interface{} `json:"params,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Params any `json:"params,omitempty"` Request } @@ -196,9 +322,9 @@ type JSONRPCNotification struct { // JSONRPCResponse represents a successful (non-error) response to a request. type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID RequestId `json:"id"` - Result interface{} `json:"result"` + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Result any `json:"result"` } // JSONRPCError represents a non-successful (error) response to a request. @@ -213,7 +339,7 @@ type JSONRPCError struct { Message string `json:"message"` // Additional information about the error. The value of this member // is defined by the sender (e.g. detailed error information, nested errors etc.). - Data interface{} `json:"data,omitempty"` + Data any `json:"data,omitempty"` } `json:"error"` } @@ -226,6 +352,11 @@ const ( INTERNAL_ERROR = -32603 ) +// MCP error codes +const ( + RESOURCE_NOT_FOUND = -32002 +) + /* Empty result */ // EmptyResult represents a response that indicates success but carries no data. @@ -246,17 +377,19 @@ type EmptyResult Result // A client MUST NOT attempt to cancel its `initialize` request. type CancelledNotification struct { Notification - Params struct { - // The ID of the request to cancel. - // - // This MUST correspond to the ID of a request previously issued - // in the same direction. - RequestId RequestId `json:"requestId"` + Params CancelledNotificationParams `json:"params"` +} - // An optional string describing the reason for the cancellation. This MAY - // be logged or presented to the user. - Reason string `json:"reason,omitempty"` - } `json:"params"` +type CancelledNotificationParams struct { + // The ID of the request to cancel. + // + // This MUST correspond to the ID of a request previously issued + // in the same direction. + RequestId RequestId `json:"requestId"` + + // An optional string describing the reason for the cancellation. This MAY + // be logged or presented to the user. + Reason string `json:"reason,omitempty"` } /* Initialization */ @@ -265,13 +398,15 @@ type CancelledNotification struct { // connects, asking it to begin initialization. type InitializeRequest struct { Request - Params struct { - // The latest version of the Model Context Protocol that the client supports. - // The client MAY decide to support older versions as well. - ProtocolVersion string `json:"protocolVersion"` - Capabilities ClientCapabilities `json:"capabilities"` - ClientInfo Implementation `json:"clientInfo"` - } `json:"params"` + Params InitializeParams `json:"params"` +} + +type InitializeParams struct { + // The latest version of the Model Context Protocol that the client supports. + // The client MAY decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo Implementation `json:"clientInfo"` } // InitializeResult is sent after receiving an initialize request from the @@ -303,7 +438,7 @@ type InitializedNotification struct { // client can define its own, additional capabilities. type ClientCapabilities struct { // Experimental, non-standard capabilities that the client supports. - Experimental map[string]interface{} `json:"experimental,omitempty"` + Experimental map[string]any `json:"experimental,omitempty"` // Present if the client supports listing roots. Roots *struct { // Whether the client supports notifications for changes to the roots list. @@ -318,7 +453,7 @@ type ClientCapabilities struct { // server can define its own, additional capabilities. type ServerCapabilities struct { // Experimental, non-standard capabilities that the server supports. - Experimental map[string]interface{} `json:"experimental,omitempty"` + Experimental map[string]any `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. Logging *struct{} `json:"logging,omitempty"` // Present if the server offers any prompt templates. @@ -362,27 +497,34 @@ type PingRequest struct { // receiver of a progress update for a long-running request. type ProgressNotification struct { Notification - Params struct { - // The progress token which was given in the initial request, used to - // associate this notification with the request that is proceeding. - ProgressToken ProgressToken `json:"progressToken"` - // The progress thus far. This should increase every time progress is made, - // even if the total is unknown. - Progress float64 `json:"progress"` - // Total number of items to process (or total progress required), if known. - Total float64 `json:"total,omitempty"` - } `json:"params"` + Params ProgressNotificationParams `json:"params"` +} + +type ProgressNotificationParams struct { + // The progress token which was given in the initial request, used to + // associate this notification with the request that is proceeding. + ProgressToken ProgressToken `json:"progressToken"` + // The progress thus far. This should increase every time progress is made, + // even if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + Total float64 `json:"total,omitempty"` + // Message related to progress. This should provide relevant human-readable + // progress information. + Message string `json:"message,omitempty"` } /* Pagination */ type PaginatedRequest struct { Request - Params struct { - // An opaque token representing the current pagination position. - // If provided, the server should return results starting after this cursor. - Cursor Cursor `json:"cursor,omitempty"` - } `json:"params,omitempty"` + Params PaginatedParams `json:"params,omitempty"` +} + +type PaginatedParams struct { + // An opaque token representing the current pagination position. + // If provided, the server should return results starting after this cursor. + Cursor Cursor `json:"cursor,omitempty"` } type PaginatedResult struct { @@ -425,13 +567,15 @@ type ListResourceTemplatesResult struct { // specific resource URI. type ReadResourceRequest struct { Request - Params struct { - // The URI of the resource to read. The URI can use any protocol; it is up - // to the server how to interpret it. - URI string `json:"uri"` - // Arguments to pass to the resource handler - Arguments map[string]interface{} `json:"arguments,omitempty"` - } `json:"params"` + Params ReadResourceParams `json:"params"` +} + +type ReadResourceParams struct { + // The URI of the resource to read. The URI can use any protocol; it is up + // to the server how to interpret it. + URI string `json:"uri"` + // Arguments to pass to the resource handler + Arguments map[string]any `json:"arguments,omitempty"` } // ReadResourceResult is the server's response to a resources/read request @@ -453,11 +597,13 @@ type ResourceListChangedNotification struct { // notifications from the server whenever a particular resource changes. type SubscribeRequest struct { Request - Params struct { - // The URI of the resource to subscribe to. The URI can use any protocol; it - // is up to the server how to interpret it. - URI string `json:"uri"` - } `json:"params"` + Params SubscribeParams `json:"params"` +} + +type SubscribeParams struct { + // The URI of the resource to subscribe to. The URI can use any protocol; it + // is up to the server how to interpret it. + URI string `json:"uri"` } // UnsubscribeRequest is sent from the client to request cancellation of @@ -465,10 +611,12 @@ type SubscribeRequest struct { // resources/subscribe request. type UnsubscribeRequest struct { Request - Params struct { - // The URI of the resource to unsubscribe from. - URI string `json:"uri"` - } `json:"params"` + Params UnsubscribeParams `json:"params"` +} + +type UnsubscribeParams struct { + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` } // ResourceUpdatedNotification is a notification from the server to the client, @@ -476,11 +624,12 @@ type UnsubscribeRequest struct { // should only be sent if the client previously sent a resources/subscribe request. type ResourceUpdatedNotification struct { Notification - Params struct { - // The URI of the resource that has been updated. This might be a sub- - // resource of the one that the client actually subscribed to. - URI string `json:"uri"` - } `json:"params"` + Params ResourceUpdatedNotificationParams `json:"params"` +} +type ResourceUpdatedNotificationParams struct { + // The URI of the resource that has been updated. This might be a sub- + // resource of the one that the client actually subscribed to. + URI string `json:"uri"` } // Resource represents a known resource that the server is capable of reading. @@ -501,6 +650,11 @@ type Resource struct { MIMEType string `json:"mimeType,omitempty"` } +// GetName returns the name of the resource. +func (r Resource) GetName() string { + return r.Name +} + // ResourceTemplate represents a template description for resources available // on the server. type ResourceTemplate struct { @@ -522,6 +676,11 @@ type ResourceTemplate struct { MIMEType string `json:"mimeType,omitempty"` } +// GetName returns the name of the resourceTemplate. +func (rt ResourceTemplate) GetName() string { + return rt.Name +} + // ResourceContents represents the contents of a specific resource or sub- // resource. type ResourceContents interface { @@ -557,12 +716,14 @@ func (BlobResourceContents) isResourceContents() {} // adjust logging. type SetLevelRequest struct { Request - Params struct { - // The level of logging that the client wants to receive from the server. - // The server should send all logs at this level and higher (i.e., more severe) to - // the client as notifications/logging/message. - Level LoggingLevel `json:"level"` - } `json:"params"` + Params SetLevelParams `json:"params"` +} + +type SetLevelParams struct { + // The level of logging that the client wants to receive from the server. + // The server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/logging/message. + Level LoggingLevel `json:"level"` } // LoggingMessageNotification is a notification of a log message passed from @@ -570,15 +731,17 @@ type SetLevelRequest struct { // the server MAY decide which messages to send automatically. type LoggingMessageNotification struct { Notification - Params struct { - // The severity of this log message. - Level LoggingLevel `json:"level"` - // An optional name of the logger issuing this message. - Logger string `json:"logger,omitempty"` - // The data to be logged, such as a string message or an object. Any JSON - // serializable type is allowed here. - Data interface{} `json:"data"` - } `json:"params"` + Params LoggingMessageNotificationParams `json:"params"` +} + +type LoggingMessageNotificationParams struct { + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` } // LoggingLevel represents the severity of a log message. @@ -600,22 +763,29 @@ const ( /* Sampling */ +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + // CreateMessageRequest is a request from the server to sample an LLM via the // client. The client has full discretion over which model to select. The client // should also inform the user before beginning sampling, to allow them to inspect // the request (human in the loop) and decide whether to approve it. type CreateMessageRequest struct { Request - Params struct { - Messages []SamplingMessage `json:"messages"` - ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` - SystemPrompt string `json:"systemPrompt,omitempty"` - IncludeContext string `json:"includeContext,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"maxTokens"` - StopSequences []string `json:"stopSequences,omitempty"` - Metadata interface{} `json:"metadata,omitempty"` - } `json:"params"` + CreateMessageParams `json:"params"` +} + +type CreateMessageParams struct { + Messages []SamplingMessage `json:"messages"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"maxTokens"` + StopSequences []string `json:"stopSequences,omitempty"` + Metadata any `json:"metadata,omitempty"` } // CreateMessageResult is the client's response to a sampling/create_message @@ -633,28 +803,30 @@ type CreateMessageResult struct { // SamplingMessage describes a message issued to or received from an LLM API. type SamplingMessage struct { - Role Role `json:"role"` - Content interface{} `json:"content"` // Can be TextContent or ImageContent + Role Role `json:"role"` + Content any `json:"content"` // Can be TextContent, ImageContent or AudioContent +} + +type Annotations struct { + // Describes who the intended customer of this object or data is. + // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., `["user", "assistant"]`). + Audience []Role `json:"audience,omitempty"` + + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that + // the data is entirely optional. + Priority float64 `json:"priority,omitempty"` } // Annotated is the base for objects that include optional annotations for the // client. The client can use annotations to inform how objects are used or // displayed type Annotated struct { - Annotations *struct { - // Describes who the intended customer of this object or data is. - // - // It can include multiple entries to indicate content useful for multiple - // audiences (e.g., `["user", "assistant"]`). - Audience []Role `json:"audience,omitempty"` - - // Describes how important this data is for operating the server. - // - // A value of 1 means "most important," and indicates that the data is - // effectively required, while 0 means "least important," and indicates that - // the data is entirely optional. - Priority float64 `json:"priority,omitempty"` - } `json:"annotations,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` } type Content interface { @@ -685,6 +857,35 @@ type ImageContent struct { func (ImageContent) isContent() {} +// AudioContent represents the contents of audio, embedded into a prompt or tool call result. +// It must have Type set to "audio". +type AudioContent struct { + Annotated + Type string `json:"type"` // Must be "audio" + // The base64-encoded audio data. + Data string `json:"data"` + // The MIME type of the audio. Different providers may support different audio types. + MIMEType string `json:"mimeType"` +} + +func (AudioContent) isContent() {} + +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. // // It is up to the client how best to render embedded resources for the @@ -758,15 +959,17 @@ type ModelHint struct { // CompleteRequest is a request from the client to the server, to ask for completion options. type CompleteRequest struct { Request - Params struct { - Ref interface{} `json:"ref"` // Can be PromptReference or ResourceReference - Argument struct { - // The name of the argument - Name string `json:"name"` - // The value of the argument to use for completion matching. - Value string `json:"value"` - } `json:"argument"` - } `json:"params"` + Params CompleteParams `json:"params"` +} + +type CompleteParams struct { + Ref any `json:"ref"` // Can be PromptReference or ResourceReference + Argument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` + } `json:"argument"` } // CompleteResult is the server's response to a completion/complete request @@ -839,22 +1042,24 @@ type RootsListChangedNotification struct { Notification } -/* Client messages */ // ClientRequest represents any request that can be sent from client to server. -type ClientRequest interface{} +type ClientRequest any // ClientNotification represents any notification that can be sent from client to server. -type ClientNotification interface{} +type ClientNotification any // ClientResult represents any result that can be sent from client to server. -type ClientResult interface{} +type ClientResult any -/* Server messages */ // ServerRequest represents any request that can be sent from server to client. -type ServerRequest interface{} +type ServerRequest any // ServerNotification represents any notification that can be sent from server to client. -type ServerNotification interface{} +type ServerNotification any // ServerResult represents any result that can be sent from server to client. -type ServerResult interface{} +type ServerResult any + +type Named interface { + GetName() string +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go index 236164cbd8..3e652efd7e 100644 --- a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -3,6 +3,8 @@ package mcp import ( "encoding/json" "fmt" + + "github.com/spf13/cast" ) // ClientRequest types @@ -58,7 +60,7 @@ var _ ServerResult = &ListToolsResult{} // Helper functions for type assertions // asType attempts to cast the given interface to the given type -func asType[T any](content interface{}) (*T, bool) { +func asType[T any](content any) (*T, bool) { tc, ok := content.(T) if !ok { return nil, false @@ -67,27 +69,32 @@ func asType[T any](content interface{}) (*T, bool) { } // AsTextContent attempts to cast the given interface to TextContent -func AsTextContent(content interface{}) (*TextContent, bool) { +func AsTextContent(content any) (*TextContent, bool) { return asType[TextContent](content) } // AsImageContent attempts to cast the given interface to ImageContent -func AsImageContent(content interface{}) (*ImageContent, bool) { +func AsImageContent(content any) (*ImageContent, bool) { return asType[ImageContent](content) } +// AsAudioContent attempts to cast the given interface to AudioContent +func AsAudioContent(content any) (*AudioContent, bool) { + return asType[AudioContent](content) +} + // AsEmbeddedResource attempts to cast the given interface to EmbeddedResource -func AsEmbeddedResource(content interface{}) (*EmbeddedResource, bool) { +func AsEmbeddedResource(content any) (*EmbeddedResource, bool) { return asType[EmbeddedResource](content) } // AsTextResourceContents attempts to cast the given interface to TextResourceContents -func AsTextResourceContents(content interface{}) (*TextResourceContents, bool) { +func AsTextResourceContents(content any) (*TextResourceContents, bool) { return asType[TextResourceContents](content) } // AsBlobResourceContents attempts to cast the given interface to BlobResourceContents -func AsBlobResourceContents(content interface{}) (*BlobResourceContents, bool) { +func AsBlobResourceContents(content any) (*BlobResourceContents, bool) { return asType[BlobResourceContents](content) } @@ -107,15 +114,15 @@ func NewJSONRPCError( id RequestId, code int, message string, - data interface{}, + data any, ) JSONRPCError { return JSONRPCError{ JSONRPC: JSONRPC_VERSION, ID: id, Error: struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` }{ Code: code, Message: message, @@ -124,11 +131,13 @@ func NewJSONRPCError( } } +// NewProgressNotification // Helper function for creating a progress notification func NewProgressNotification( token ProgressToken, progress float64, total *float64, + message *string, ) ProgressNotification { notification := ProgressNotification{ Notification: Notification{ @@ -138,6 +147,7 @@ func NewProgressNotification( ProgressToken ProgressToken `json:"progressToken"` Progress float64 `json:"progress"` Total float64 `json:"total,omitempty"` + Message string `json:"message,omitempty"` }{ ProgressToken: token, Progress: progress, @@ -146,14 +156,18 @@ func NewProgressNotification( if total != nil { notification.Params.Total = *total } + if message != nil { + notification.Params.Message = *message + } return notification } +// NewLoggingMessageNotification // Helper function for creating a logging message notification func NewLoggingMessageNotification( level LoggingLevel, logger string, - data interface{}, + data any, ) LoggingMessageNotification { return LoggingMessageNotification{ Notification: Notification{ @@ -162,7 +176,7 @@ func NewLoggingMessageNotification( Params: struct { Level LoggingLevel `json:"level"` Logger string `json:"logger,omitempty"` - Data interface{} `json:"data"` + Data any `json:"data"` }{ Level: level, Logger: logger, @@ -171,6 +185,7 @@ func NewLoggingMessageNotification( } } +// NewPromptMessage // Helper function to create a new PromptMessage func NewPromptMessage(role Role, content Content) PromptMessage { return PromptMessage{ @@ -179,6 +194,7 @@ func NewPromptMessage(role Role, content Content) PromptMessage { } } +// NewTextContent // Helper function to create a new TextContent func NewTextContent(text string) TextContent { return TextContent{ @@ -187,6 +203,7 @@ func NewTextContent(text string) TextContent { } } +// NewImageContent // Helper function to create a new ImageContent func NewImageContent(data, mimeType string) ImageContent { return ImageContent{ @@ -196,6 +213,26 @@ func NewImageContent(data, mimeType string) ImageContent { } } +// Helper function to create a new AudioContent +func NewAudioContent(data, mimeType string) AudioContent { + return AudioContent{ + Type: "audio", + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: "resource_link", + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ @@ -233,6 +270,23 @@ func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { } } +// NewToolResultAudio creates a new CallToolResult with both text and audio content +func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + AudioContent{ + Type: "audio", + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + // NewToolResultResource creates a new CallToolResult with an embedded resource func NewToolResultResource( text string, @@ -266,6 +320,39 @@ func NewToolResultError(text string) *CallToolResult { } } +// NewToolResultErrorFromErr creates a new CallToolResult with an error message. +// If an error is provided, its details will be appended to the text message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorFromErr(text string, err error) *CallToolResult { + if err != nil { + text = fmt.Sprintf("%s: %v", text, err) + } + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + }, + IsError: true, + } +} + +// NewToolResultErrorf creates a new CallToolResult with an error message. +// The error message is formatted using the fmt package. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorf(format string, a ...any) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fmt.Sprintf(format, a...), + }, + }, + IsError: true, + } +} + // NewListResourcesResult creates a new ListResourcesResult func NewListResourcesResult( resources []Resource, @@ -352,6 +439,7 @@ func NewInitializeResult( } } +// FormatNumberResult // Helper for formatting numbers in tool results func FormatNumberResult(value float64) *CallToolResult { return NewToolResultText(fmt.Sprintf("%.2f", value)) @@ -381,9 +469,6 @@ func ParseContent(contentMap map[string]any) (Content, error) { switch contentType { case "text": text := ExtractString(contentMap, "text") - if text == "" { - return nil, fmt.Errorf("text is missing") - } return NewTextContent(text), nil case "image": @@ -394,6 +479,24 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewImageContent(data, mimeType), nil + case "audio": + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("audio data or mimeType is missing") + } + return NewAudioContent(data, mimeType), nil + + case "resource_link": + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + case "resource": resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { @@ -412,6 +515,10 @@ func ParseContent(contentMap map[string]any) (Content, error) { } func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + var jsonContent map[string]any if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { return nil, fmt.Errorf("failed to unmarshal response: %w", err) @@ -474,6 +581,10 @@ func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) } func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + var jsonContent map[string]any if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { return nil, fmt.Errorf("failed to unmarshal response: %w", err) @@ -552,6 +663,10 @@ func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) } func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + var jsonContent map[string]any if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { return nil, fmt.Errorf("failed to unmarshal response: %w", err) @@ -594,3 +709,111 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, return &result, nil } + +func ParseArgument(request CallToolRequest, key string, defaultVal any) any { + args := request.GetArguments() + if _, ok := args[key]; !ok { + return defaultVal + } else { + return args[key] + } +} + +// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +// The function uses cast.ToBool for conversion which handles various string representations +// such as "true", "yes", "1", etc. +func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool { + v := ParseArgument(request, key, defaultValue) + return cast.ToBool(v) +} + +// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt64(v) +} + +// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest. +func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt32(v) +} + +// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest. +func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt16(v) +} + +// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest. +func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt8(v) +} + +// ParseInt extracts and converts an int parameter from a CallToolRequest. +func ParseInt(request CallToolRequest, key string, defaultValue int) int { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt(v) +} + +// ParseUInt extracts and converts an uint parameter from a CallToolRequest. +func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint(v) +} + +// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest. +func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint64(v) +} + +// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest. +func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint32(v) +} + +// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest. +func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint16(v) +} + +// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest. +func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint8(v) +} + +// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest. +func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat32(v) +} + +// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest. +func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat64(v) +} + +// ParseString extracts and converts a string parameter from a CallToolRequest. +func ParseString(request CallToolRequest, key string, defaultValue string) string { + v := ParseArgument(request, key, defaultValue) + return cast.ToString(v) +} + +// ParseStringMap extracts and converts a string map parameter from a CallToolRequest. +func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any { + v := ParseArgument(request, key, defaultValue) + return cast.ToStringMap(v) +} + +// ToBoolPtr returns a pointer to the given boolean value +func ToBoolPtr(b bool) *bool { + return &b +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/errors.go b/vendor/github.com/mark3labs/mcp-go/server/errors.go new file mode 100644 index 0000000000..ecbe91e5fb --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/errors.go @@ -0,0 +1,34 @@ +package server + +import ( + "errors" + "fmt" +) + +var ( + // Common server errors + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") + + // Session-related errors + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") + + // Notification-related errors + ErrNotificationNotInitialized = errors.New("notification channel not initialized") + ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") +) + +// ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration +type ErrDynamicPathConfig struct { + Method string +} + +func (e *ErrDynamicPathConfig) Error() string { + return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. Use dynamic path logic in your router.", e.Method) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/hooks.go b/vendor/github.com/mark3labs/mcp-go/server/hooks.go index ce976a6cdb..4baa1c4e05 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/hooks.go +++ b/vendor/github.com/mark3labs/mcp-go/server/hooks.go @@ -11,6 +11,9 @@ import ( // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) + // BeforeAnyHookFunc is a function that is called after the request is // parsed but before the method is called. type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) @@ -33,7 +36,7 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m // } // // // Use errors.As to get specific error types -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { // // Access specific methods/fields of the error type // log.Printf("Failed to parse message for method %s: %v", @@ -54,12 +57,19 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m // }) type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) +// OnRequestInitializationFunc is a function that called before handle diff request method +// Should any errors arise during func execution, the service will promptly return the corresponding error message. +type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error + type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest) type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) +type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest) +type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) + type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) @@ -83,13 +93,17 @@ type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallTool type Hooks struct { OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc OnBeforeAny []BeforeAnyHookFunc OnSuccess []OnSuccessHookFunc OnError []OnErrorHookFunc + OnRequestInitialization []OnRequestInitializationFunc OnBeforeInitialize []OnBeforeInitializeFunc OnAfterInitialize []OnAfterInitializeFunc OnBeforePing []OnBeforePingFunc OnAfterPing []OnAfterPingFunc + OnBeforeSetLevel []OnBeforeSetLevelFunc + OnAfterSetLevel []OnAfterSetLevelFunc OnBeforeListResources []OnBeforeListResourcesFunc OnAfterListResources []OnAfterListResourcesFunc OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc @@ -135,9 +149,9 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { // } // // // For parsing errors -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { -// // Handle unparseable message errors +// // Handle unparsable message errors // fmt.Printf("Failed to parse %s request: %v\n", // parseErr.GetMethod(), parseErr.Unwrap()) // errChan <- parseErr @@ -191,7 +205,7 @@ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, mes // // Common error types include: // - ErrUnsupported: When a capability is not enabled -// - UnparseableMessageError: When request parsing fails +// - UnparsableMessageError: When request parsing fails // - ErrResourceNotFound: When a resource is not found // - ErrPromptNotFound: When a prompt is not found // - ErrToolNotFound: When a tool is not found @@ -216,6 +230,36 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { hook(ctx, session) } } + +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} + +func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) { + c.OnRequestInitialization = append(c.OnRequestInitialization, hook) +} + +func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error { + if c == nil { + return nil + } + for _, hook := range c.OnRequestInitialization { + err := hook(ctx, id, message) + if err != nil { + return err + } + } + return nil +} func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) } @@ -270,6 +314,33 @@ func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, hook(ctx, id, message, result) } } +func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) { + c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook) +} + +func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) { + c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook) +} + +func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) { + c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeSetLevel { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodSetLogLevel, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterSetLevel { + hook(ctx, id, message, result) + } +} func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { c.OnBeforeListResources = append(c.OnBeforeListResources, hook) } diff --git a/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go new file mode 100644 index 0000000000..4f5ad53d0d --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go @@ -0,0 +1,11 @@ +package server + +import ( + "context" + "net/http" +) + +// HTTPContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context diff --git a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go index 946ca7abd3..25f6ef14f3 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go +++ b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go @@ -23,6 +23,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -55,6 +56,21 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + + handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message) + if handleErr != nil { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + handleErr.Error(), + ) + } + switch baseMessage.Method { case mcp.MethodInitialize: var request mcp.InitializeRequest @@ -63,7 +79,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) @@ -82,7 +98,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforePing(ctx, baseMessage.ID, &request) @@ -94,6 +110,31 @@ func (s *MCPServer) HandleMessage( } s.hooks.afterPing(ctx, baseMessage.ID, &request, result) return createResponse(baseMessage.ID, *result) + case mcp.MethodSetLogLevel: + var request mcp.SetLevelRequest + var result *mcp.EmptyResult + if s.capabilities.logging == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("logging %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) + result, err = s.handleSetLevel(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) case mcp.MethodResourcesList: var request mcp.ListResourcesRequest var result *mcp.ListResourcesResult @@ -107,7 +148,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListResources(ctx, baseMessage.ID, &request) @@ -132,7 +173,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) @@ -157,7 +198,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) @@ -182,7 +223,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) @@ -207,7 +248,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) @@ -232,7 +273,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListTools(ctx, baseMessage.ID, &request) @@ -257,7 +298,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go new file mode 100644 index 0000000000..b633b24d07 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sampling.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/server.go b/vendor/github.com/mark3labs/mcp-go/server/server.go index ec4fcef006..46e6d9c571 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/server.go +++ b/vendor/github.com/mark3labs/mcp-go/server/server.go @@ -1,11 +1,12 @@ -// Package server provides MCP (Model Control Protocol) server implementations. +// Package server provides MCP (Model Context Protocol) server implementations. package server import ( "context" + "encoding/base64" "encoding/json" - "errors" "fmt" + "slices" "sort" "sync" @@ -39,56 +40,62 @@ type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) ( // ToolHandlerFunc handles tool calls with given arguments. type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) +// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. +type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc + +// ToolFilterFunc is a function that filters tools based on context, typically using session information. +type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool + // ServerTool combines a Tool with its ToolHandlerFunc. type ServerTool struct { Tool mcp.Tool Handler ToolHandlerFunc } -// ClientSession represents an active session that can be used by MCPServer to interact with client. -type ClientSession interface { - // Initialize marks session as fully initialized and ready for notifications - Initialize() - // Initialized returns if session is ready to accept notifications - Initialized() bool - // NotificationChannel provides a channel suitable for sending notifications to client. - NotificationChannel() chan<- mcp.JSONRPCNotification - // SessionID is a unique identifier used to track user session. - SessionID() string +// ServerPrompt combines a Prompt with its handler function. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler PromptHandlerFunc } -// clientSessionKey is the context key for storing current client notification channel. -type clientSessionKey struct{} +// ServerResource combines a Resource with its handler function. +type ServerResource struct { + Resource mcp.Resource + Handler ResourceHandlerFunc +} -// ClientSessionFromContext retrieves current client notification context from context. -func ClientSessionFromContext(ctx context.Context) ClientSession { - if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { - return session +// serverKey is the context key for storing the server instance +type serverKey struct{} + +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv } return nil } -// UnparseableMessageError is attached to the RequestError when json.Unmarshal +// UnparsableMessageError is attached to the RequestError when json.Unmarshal // fails on the request. -type UnparseableMessageError struct { +type UnparsableMessageError struct { message json.RawMessage method mcp.MCPMethod err error } -func (e *UnparseableMessageError) Error() string { - return fmt.Sprintf("unparseable %s request: %s", e.method, e.err) +func (e *UnparsableMessageError) Error() string { + return fmt.Sprintf("unparsable %s request: %s", e.method, e.err) } -func (e *UnparseableMessageError) Unwrap() error { +func (e *UnparsableMessageError) Unwrap() error { return e.err } -func (e *UnparseableMessageError) GetMessage() json.RawMessage { +func (e *UnparsableMessageError) GetMessage() json.RawMessage { return e.message } -func (e *UnparseableMessageError) GetMethod() mcp.MCPMethod { +func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod { return e.method } @@ -107,7 +114,7 @@ func (e *requestError) Error() string { func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, - ID: e.id, + ID: mcp.NewRequestId(e.id), Error: struct { Code int `json:"code"` Message string `json:"message"` @@ -123,126 +130,42 @@ func (e *requestError) Unwrap() error { return e.err } -var ( - ErrUnsupported = errors.New("not supported") - ErrResourceNotFound = errors.New("resource not found") - ErrPromptNotFound = errors.New("prompt not found") - ErrToolNotFound = errors.New("tool not found") -) - // NotificationHandlerFunc handles incoming notifications. type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) -// MCPServer implements a Model Control Protocol server that can handle various types of requests +// MCPServer implements a Model Context Protocol server that can handle various types of requests // including resources, prompts, and tools. type MCPServer struct { - mu sync.RWMutex // Add mutex for protecting shared resources - name string - version string - instructions string - resources map[string]resourceEntry - resourceTemplates map[string]resourceTemplateEntry - prompts map[string]mcp.Prompt - promptHandlers map[string]PromptHandlerFunc - tools map[string]ServerTool - notificationHandlers map[string]NotificationHandlerFunc - capabilities serverCapabilities - sessions sync.Map - hooks *Hooks -} - -// serverKey is the context key for storing the server instance -type serverKey struct{} - -// ServerFromContext retrieves the MCPServer instance from a context -func ServerFromContext(ctx context.Context) *MCPServer { - if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { - return srv - } - return nil -} - -// WithContext sets the current client session and returns the provided context -func (s *MCPServer) WithContext( - ctx context.Context, - session ClientSession, -) context.Context { - return context.WithValue(ctx, clientSessionKey{}, session) -} - -// RegisterSession saves session that should be notified in case if some server attributes changed. -func (s *MCPServer) RegisterSession( - ctx context.Context, - session ClientSession, -) error { - sessionID := session.SessionID() - if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { - return fmt.Errorf("session %s is already registered", sessionID) - } - s.hooks.RegisterSession(ctx, session) - return nil -} - -// UnregisterSession removes from storage session that is shut down. -func (s *MCPServer) UnregisterSession( - sessionID string, -) { - s.sessions.Delete(sessionID) + // Separate mutexes for different resource types + resourcesMu sync.RWMutex + promptsMu sync.RWMutex + toolsMu sync.RWMutex + middlewareMu sync.RWMutex + notificationHandlersMu sync.RWMutex + capabilitiesMu sync.RWMutex + toolFiltersMu sync.RWMutex + + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + toolHandlerMiddlewares []ToolHandlerMiddleware + toolFilters []ToolFilterFunc + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + paginationLimit *int + sessions sync.Map + hooks *Hooks } -// sendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) sendNotificationToAllClients( - method string, - params map[string]any, -) { - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - - s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok && session.Initialized() { - select { - case session.NotificationChannel() <- notification: - default: - // TODO: log blocked channel in the future versions - } - } - return true - }) -} - -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) SendNotificationToClient( - ctx context.Context, - method string, - params map[string]any, -) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return fmt.Errorf("notification channel not initialized") - } - - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - - select { - case session.NotificationChannel() <- notification: - return nil - default: - return fmt.Errorf("notification channel full or blocked") +// WithPaginationLimit sets the pagination limit for the server. +func WithPaginationLimit(limit int) ServerOption { + return func(s *MCPServer) { + s.paginationLimit = &limit } } @@ -251,7 +174,7 @@ type serverCapabilities struct { tools *toolCapabilities resources *resourceCapabilities prompts *promptCapabilities - logging bool + logging *bool } // resourceCapabilities defines the supported resource-related features @@ -281,6 +204,47 @@ func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { } } +// WithToolHandlerMiddleware allows adding a middleware for the +// tool handler call chain. +func WithToolHandlerMiddleware( + toolHandlerMiddleware ToolHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.middlewareMu.Lock() + s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware) + s.middlewareMu.Unlock() + } +} + +// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools +func WithToolFilter( + toolFilter ToolFilterFunc, +) ServerOption { + return func(s *MCPServer) { + s.toolFiltersMu.Lock() + s.toolFilters = append(s.toolFilters, toolFilter) + s.toolFiltersMu.Unlock() + } +} + +// WithRecovery adds a middleware that recovers from panics in tool handlers. +func WithRecovery() ServerOption { + return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s tool handler: %v", + request.Params.Name, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + // WithHooks allows adding hooks that will be called before or after // either [all] requests or before / after specific request methods, or else // prior to returning an error to the client. @@ -313,7 +277,7 @@ func WithToolCapabilities(listChanged bool) ServerOption { // WithLogging enables logging capabilities for the server func WithLogging() ServerOption { return func(s *MCPServer) { - s.capabilities.logging = true + s.capabilities.logging = mcp.ToBoolPtr(true) } } @@ -342,7 +306,7 @@ func NewMCPServer( tools: nil, resources: nil, prompts: nil, - logging: false, + logging: nil, }, } @@ -353,19 +317,46 @@ func NewMCPServer( return s } +// AddResources registers multiple resources at once +func (s *MCPServer) AddResources(resources ...ServerResource) { + s.implicitlyRegisterResourceCapabilities() + + s.resourcesMu.Lock() + for _, entry := range resources { + s.resources[entry.Resource.URI] = resourceEntry{ + resource: entry.Resource, + handler: entry.Handler, + } + } + s.resourcesMu.Unlock() + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + // AddResource registers a new resource and its handler func (s *MCPServer) AddResource( resource mcp.Resource, handler ResourceHandlerFunc, ) { - if s.capabilities.resources == nil { - s.capabilities.resources = &resourceCapabilities{} + s.AddResources(ServerResource{Resource: resource, Handler: handler}) +} + +// RemoveResource removes a resource from the server +func (s *MCPServer) RemoveResource(uri string) { + s.resourcesMu.Lock() + _, exists := s.resources[uri] + if exists { + delete(s.resources, uri) } - s.mu.Lock() - defer s.mu.Unlock() - s.resources[resource.URI] = resourceEntry{ - resource: resource, - handler: handler, + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource + if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) } } @@ -374,26 +365,63 @@ func (s *MCPServer) AddResourceTemplate( template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc, ) { - if s.capabilities.resources == nil { - s.capabilities.resources = &resourceCapabilities{} - } - s.mu.Lock() - defer s.mu.Unlock() + s.implicitlyRegisterResourceCapabilities() + + s.resourcesMu.Lock() s.resourceTemplates[template.URITemplate.Raw()] = resourceTemplateEntry{ template: template, handler: handler, } + s.resourcesMu.Unlock() + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// AddPrompts registers multiple prompts at once +func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { + s.implicitlyRegisterPromptCapabilities() + + s.promptsMu.Lock() + for _, entry := range prompts { + s.prompts[entry.Prompt.Name] = entry.Prompt + s.promptHandlers[entry.Prompt.Name] = entry.Handler + } + s.promptsMu.Unlock() + + // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } } // AddPrompt registers a new prompt handler with the given name func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { - if s.capabilities.prompts == nil { - s.capabilities.prompts = &promptCapabilities{} + s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler}) +} + +// DeletePrompts removes prompts from the server +func (s *MCPServer) DeletePrompts(names ...string) { + s.promptsMu.Lock() + var exists bool + for _, name := range names { + if _, ok := s.prompts[name]; ok { + delete(s.prompts, name) + delete(s.promptHandlers, name) + exists = true + } + } + s.promptsMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt + if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) } - s.mu.Lock() - defer s.mu.Unlock() - s.prompts[prompt.Name] = prompt - s.promptHandlers[prompt.Name] = handler } // AddTool registers a new tool and its handler @@ -401,39 +429,87 @@ func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { s.AddTools(ServerTool{Tool: tool, Handler: handler}) } +// Register tool capabilities due to a tool being added. Default to +// listChanged: true, but don't change the value if we've already explicitly +// registered tools.listChanged false. +func (s *MCPServer) implicitlyRegisterToolCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.tools != nil }, + func() { s.capabilities.tools = &toolCapabilities{listChanged: true} }, + ) +} + +func (s *MCPServer) implicitlyRegisterResourceCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.resources != nil }, + func() { s.capabilities.resources = &resourceCapabilities{} }, + ) +} + +func (s *MCPServer) implicitlyRegisterPromptCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.prompts != nil }, + func() { s.capabilities.prompts = &promptCapabilities{} }, + ) +} + +func (s *MCPServer) implicitlyRegisterCapabilities(check func() bool, register func()) { + s.capabilitiesMu.RLock() + if check() { + s.capabilitiesMu.RUnlock() + return + } + s.capabilitiesMu.RUnlock() + + s.capabilitiesMu.Lock() + if !check() { + register() + } + s.capabilitiesMu.Unlock() +} + // AddTools registers multiple tools at once func (s *MCPServer) AddTools(tools ...ServerTool) { - if s.capabilities.tools == nil { - s.capabilities.tools = &toolCapabilities{} - } - s.mu.Lock() + s.implicitlyRegisterToolCapabilities() + + s.toolsMu.Lock() for _, entry := range tools { s.tools[entry.Tool.Name] = entry } - s.mu.Unlock() + s.toolsMu.Unlock() - // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } } // SetTools replaces all existing tools with the provided list func (s *MCPServer) SetTools(tools ...ServerTool) { - s.mu.Lock() - s.tools = make(map[string]ServerTool) - s.mu.Unlock() + s.toolsMu.Lock() + s.tools = make(map[string]ServerTool, len(tools)) + s.toolsMu.Unlock() s.AddTools(tools...) } -// DeleteTools removes a tool from the server +// DeleteTools removes tools from the server func (s *MCPServer) DeleteTools(names ...string) { - s.mu.Lock() + s.toolsMu.Lock() + var exists bool for _, name := range names { - delete(s.tools, name) + if _, ok := s.tools[name]; ok { + delete(s.tools, name) + exists = true + } } - s.mu.Unlock() + s.toolsMu.Unlock() - // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } } // AddNotificationHandler registers a new handler for incoming notifications @@ -441,14 +517,14 @@ func (s *MCPServer) AddNotificationHandler( method string, handler NotificationHandlerFunc, ) { - s.mu.Lock() - defer s.mu.Unlock() + s.notificationHandlersMu.Lock() + defer s.notificationHandlersMu.Unlock() s.notificationHandlers[method] = handler } func (s *MCPServer) handleInitialize( ctx context.Context, - id interface{}, + _ any, request mcp.InitializeRequest, ) (*mcp.InitializeResult, *requestError) { capabilities := mcp.ServerCapabilities{} @@ -482,12 +558,12 @@ func (s *MCPServer) handleInitialize( } } - if s.capabilities.logging { + if s.capabilities.logging != nil && *s.capabilities.logging { capabilities.Logging = &struct{}{} } result := mcp.InitializeResult{ - ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ Name: s.name, Version: s.version, @@ -498,70 +574,194 @@ func (s *MCPServer) handleInitialize( if session := ClientSessionFromContext(ctx); session != nil { session.Initialize() + + // Store client info if the session supports it + if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { + sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + } } return &result, nil } +func (s *MCPServer) protocolVersion(clientVersion string) string { + if slices.Contains(mcp.ValidProtocolVersions, clientVersion) { + return clientVersion + } + + return mcp.LATEST_PROTOCOL_VERSION +} + func (s *MCPServer) handlePing( + _ context.Context, + _ any, + _ mcp.PingRequest, +) (*mcp.EmptyResult, *requestError) { + return &mcp.EmptyResult{}, nil +} + +func (s *MCPServer) handleSetLevel( ctx context.Context, - id interface{}, - request mcp.PingRequest, + id any, + request mcp.SetLevelRequest, ) (*mcp.EmptyResult, *requestError) { + clientSession := ClientSessionFromContext(ctx) + if clientSession == nil || !clientSession.Initialized() { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionNotInitialized, + } + } + + sessionLogging, ok := clientSession.(SessionWithLogging) + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionDoesNotSupportLogging, + } + } + + level := request.Params.Level + // Validate logging level + switch level { + case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice, + mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical, + mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency: + // Valid level + default: + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("invalid logging level '%s'", level), + } + } + + sessionLogging.SetLogLevel(level) + return &mcp.EmptyResult{}, nil } +func listByPagination[T mcp.Named]( + _ context.Context, + s *MCPServer, + cursor mcp.Cursor, + allElements []T, +) ([]T, mcp.Cursor, error) { + startPos := 0 + if cursor != "" { + c, err := base64.StdEncoding.DecodeString(string(cursor)) + if err != nil { + return nil, "", err + } + cString := string(c) + startPos = sort.Search(len(allElements), func(i int) bool { + return allElements[i].GetName() > cString + }) + } + endPos := len(allElements) + if s.paginationLimit != nil { + if len(allElements) > startPos+*s.paginationLimit { + endPos = startPos + *s.paginationLimit + } + } + elementsToReturn := allElements[startPos:endPos] + // set the next cursor + nextCursor := func() mcp.Cursor { + if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit { + nc := elementsToReturn[len(elementsToReturn)-1].GetName() + toString := base64.StdEncoding.EncodeToString([]byte(nc)) + return mcp.Cursor(toString) + } + return "" + }() + return elementsToReturn, nextCursor, nil +} + func (s *MCPServer) handleListResources( ctx context.Context, - id interface{}, + id any, request mcp.ListResourcesRequest, ) (*mcp.ListResourcesResult, *requestError) { - s.mu.RLock() + s.resourcesMu.RLock() resources := make([]mcp.Resource, 0, len(s.resources)) for _, entry := range s.resources { resources = append(resources, entry.resource) } - s.mu.RUnlock() + s.resourcesMu.RUnlock() - result := mcp.ListResourcesResult{ - Resources: resources, + // Sort the resources by name + sort.Slice(resources, func(i, j int) bool { + return resources[i].Name < resources[j].Name + }) + resourcesToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + resources, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } } - if request.Params.Cursor != "" { - result.NextCursor = "" // Handle pagination if needed + result := mcp.ListResourcesResult{ + Resources: resourcesToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, } return &result, nil } func (s *MCPServer) handleListResourceTemplates( ctx context.Context, - id interface{}, + id any, request mcp.ListResourceTemplatesRequest, ) (*mcp.ListResourceTemplatesResult, *requestError) { - s.mu.RLock() + s.resourcesMu.RLock() templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) for _, entry := range s.resourceTemplates { templates = append(templates, entry.template) } - s.mu.RUnlock() - - result := mcp.ListResourceTemplatesResult{ - ResourceTemplates: templates, + s.resourcesMu.RUnlock() + sort.Slice(templates, func(i, j int) bool { + return templates[i].Name < templates[j].Name + }) + templatesToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + templates, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } } - if request.Params.Cursor != "" { - result.NextCursor = "" // Handle pagination if needed + result := mcp.ListResourceTemplatesResult{ + ResourceTemplates: templatesToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, } return &result, nil } func (s *MCPServer) handleReadResource( ctx context.Context, - id interface{}, + id any, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { - s.mu.RLock() + s.resourcesMu.RLock() // First try direct resource handlers if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler - s.mu.RUnlock() + s.resourcesMu.RUnlock() contents, err := handler(ctx, request) if err != nil { return nil, &requestError{ @@ -583,14 +783,14 @@ func (s *MCPServer) handleReadResource( matched = true matchedVars := template.URITemplate.Match(request.Params.URI) // Convert matched variables to a map - request.Params.Arguments = make(map[string]interface{}) + request.Params.Arguments = make(map[string]any, len(matchedVars)) for name, value := range matchedVars { request.Params.Arguments[name] = value.V } break } } - s.mu.RUnlock() + s.resourcesMu.RUnlock() if matched { contents, err := matchedHandler(ctx, request) @@ -606,8 +806,12 @@ func (s *MCPServer) handleReadResource( return nil, &requestError{ id: id, - code: mcp.INVALID_PARAMS, - err: fmt.Errorf("handler not found for resource URI '%s': %w", request.Params.URI, ErrResourceNotFound), + code: mcp.RESOURCE_NOT_FOUND, + err: fmt.Errorf( + "handler not found for resource URI '%s': %w", + request.Params.URI, + ErrResourceNotFound, + ), } } @@ -618,33 +822,50 @@ func matchesTemplate(uri string, template *mcp.URITemplate) bool { func (s *MCPServer) handleListPrompts( ctx context.Context, - id interface{}, + id any, request mcp.ListPromptsRequest, ) (*mcp.ListPromptsResult, *requestError) { - s.mu.RLock() + s.promptsMu.RLock() prompts := make([]mcp.Prompt, 0, len(s.prompts)) for _, prompt := range s.prompts { prompts = append(prompts, prompt) } - s.mu.RUnlock() + s.promptsMu.RUnlock() - result := mcp.ListPromptsResult{ - Prompts: prompts, + // sort prompts by name + sort.Slice(prompts, func(i, j int) bool { + return prompts[i].Name < prompts[j].Name + }) + promptsToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + prompts, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } } - if request.Params.Cursor != "" { - result.NextCursor = "" // Handle pagination if needed + result := mcp.ListPromptsResult{ + Prompts: promptsToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, } return &result, nil } func (s *MCPServer) handleGetPrompt( ctx context.Context, - id interface{}, + id any, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, *requestError) { - s.mu.RLock() + s.promptsMu.RLock() handler, ok := s.promptHandlers[request.Params.Name] - s.mu.RUnlock() + s.promptsMu.RUnlock() if !ok { return nil, &requestError{ @@ -668,10 +889,11 @@ func (s *MCPServer) handleGetPrompt( func (s *MCPServer) handleListTools( ctx context.Context, - id interface{}, + id any, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, *requestError) { - s.mu.RLock() + // Get the base tools from the server + s.toolsMu.RLock() tools := make([]mcp.Tool, 0, len(s.tools)) // Get all tool names for consistent ordering @@ -687,24 +909,102 @@ func (s *MCPServer) handleListTools( for _, name := range toolNames { tools = append(tools, s.tools[name].Tool) } - s.mu.RUnlock() + s.toolsMu.RUnlock() - result := mcp.ListToolsResult{ - Tools: tools, + // Check if there are session-specific tools + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + // Override or add session-specific tools + // We need to create a map first to merge the tools properly + toolMap := make(map[string]mcp.Tool) + + // Add global tools first + for _, tool := range tools { + toolMap[tool.Name] = tool + } + + // Then override with session-specific tools + for name, serverTool := range sessionTools { + toolMap[name] = serverTool.Tool + } + + // Convert back to slice + tools = make([]mcp.Tool, 0, len(toolMap)) + for _, tool := range toolMap { + tools = append(tools, tool) + } + + // Sort again to maintain consistent ordering + sort.Slice(tools, func(i, j int) bool { + return tools[i].Name < tools[j].Name + }) + } + } } - if request.Params.Cursor != "" { - result.NextCursor = "" // Handle pagination if needed + + // Apply tool filters if any are defined + s.toolFiltersMu.RLock() + if len(s.toolFilters) > 0 { + for _, filter := range s.toolFilters { + tools = filter(ctx, tools) + } + } + s.toolFiltersMu.RUnlock() + + // Apply pagination + toolsToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + tools, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + + result := mcp.ListToolsResult{ + Tools: toolsToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, } return &result, nil } + func (s *MCPServer) handleToolCall( ctx context.Context, - id interface{}, + id any, request mcp.CallToolRequest, ) (*mcp.CallToolResult, *requestError) { - s.mu.RLock() - tool, ok := s.tools[request.Params.Name] - s.mu.RUnlock() + // First check session-specific tools + var tool ServerTool + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + var sessionOk bool + tool, sessionOk = sessionTools[request.Params.Name] + if sessionOk { + ok = true + } + } + } + } + + // If not found in session tools, check global tools + if !ok { + s.toolsMu.RLock() + tool, ok = s.tools[request.Params.Name] + s.toolsMu.RUnlock() + } if !ok { return nil, &requestError{ @@ -714,7 +1014,18 @@ func (s *MCPServer) handleToolCall( } } - result, err := tool.Handler(ctx, request) + finalHandler := tool.Handler + + s.middlewareMu.RLock() + mw := s.toolHandlerMiddlewares + s.middlewareMu.RUnlock() + + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + + result, err := finalHandler(ctx, request) if err != nil { return nil, &requestError{ id: id, @@ -730,9 +1041,9 @@ func (s *MCPServer) handleNotification( ctx context.Context, notification mcp.JSONRPCNotification, ) mcp.JSONRPCMessage { - s.mu.RLock() + s.notificationHandlersMu.RLock() handler, ok := s.notificationHandlers[notification.Method] - s.mu.RUnlock() + s.notificationHandlersMu.RUnlock() if ok { handler(ctx, notification) @@ -740,26 +1051,26 @@ func (s *MCPServer) handleNotification( return nil } -func createResponse(id interface{}, result interface{}) mcp.JSONRPCMessage { +func createResponse(id any, result any) mcp.JSONRPCMessage { return mcp.JSONRPCResponse{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Result: result, } } func createErrorResponse( - id interface{}, + id any, code int, message string, ) mcp.JSONRPCMessage { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Error: struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` }{ Code: code, Message: message, diff --git a/vendor/github.com/mark3labs/mcp-go/server/session.go b/vendor/github.com/mark3labs/mcp-go/server/session.go new file mode 100644 index 0000000000..a79da22cad --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/session.go @@ -0,0 +1,380 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level +type SessionWithLogging interface { + ClientSession + // SetLogLevel sets the minimum log level + SetLogLevel(level mcp.LoggingLevel) + // GetLogLevel retrieves the minimum log level + GetLogLevel() mcp.LoggingLevel +} + +// SessionWithTools is an extension of ClientSession that can store session-specific tool data +type SessionWithTools interface { + ClientSession + // GetSessionTools returns the tools specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionTools() map[string]ServerTool + // SetSessionTools sets tools specific to this session + // This method must be thread-safe for concurrent access + SetSessionTools(tools map[string]ServerTool) +} + +// SessionWithClientInfo is an extension of ClientSession that can store client info +type SessionWithClientInfo interface { + ClientSession + // GetClientInfo returns the client information for this session + GetClientInfo() mcp.Implementation + // SetClientInfo sets the client information for this session + SetClientInfo(clientInfo mcp.Implementation) +} + +// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations +type SessionWithStreamableHTTPConfig interface { + ClientSession + // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server + // sends notifications to the client + // + // The protocol specification: + // - If the server response contains any JSON-RPC notifications, it MUST either: + // - Return Content-Type: text/event-stream to initiate an SSE stream, OR + // - Return Content-Type: application/json for a single JSON object + // - The client MUST support both response types. + // + // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server + UpgradeToSSEWhenReceiveNotification() +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return ErrSessionExists + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } +} + +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + select { + case session.NotificationChannel() <- notification: + // Successfully sent notification + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + ctx := context.Background() + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + } + } + return true + }) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +// SendNotificationToSpecificClient sends a notification to a specific client by session ID +func (s *MCPServer) SendNotificationToSpecificClient( + sessionID string, + method string, + params map[string]any, +) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(sessionID, hooks) + } + return ErrNotificationChannelBlocked + } +} + +// AddSessionTool adds a tool for a specific session +func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error { + return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler}) +} + +// AddSessionTools adds tools for a specific session +func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + s.implicitlyRegisterToolCapabilities() + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools)) + + // Copy existing tools + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Add new tools + for _, tool := range tools { + newSessionTools[tool.Tool.Name] = tool + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // It only makes sense to send tool notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding tools: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// DeleteSessionTools removes tools from a specific session +func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + if sessionTools == nil { + return nil + } + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)) + + // Copy existing tools except those being deleted + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Remove specified tools + for _, name := range names { + delete(newSessionTools, name) + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // It only makes sense to send tool notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully deleted, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting tools: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sse.go b/vendor/github.com/mark3labs/mcp-go/server/sse.go index 6e6a13fe78..4169957307 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/sse.go +++ b/vendor/github.com/mark3labs/mcp-go/server/sse.go @@ -4,26 +4,32 @@ import ( "context" "encoding/json" "fmt" + "log" "net/http" "net/http/httptest" "net/url" + "path" "strings" "sync" "sync/atomic" + "time" "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" ) // sseSession represents an active SSE connection. type sseSession struct { - writer http.ResponseWriter - flusher http.Flusher done chan struct{} eventQueue chan string // Channel for queuing events sessionID string + requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool + loggingLevel atomic.Value + tools sync.Map // stores session-specific tools + clientInfo atomic.Value // stores session-specific client info } // SSEContextFunc is a function that takes an existing context and the current @@ -31,6 +37,13 @@ type sseSession struct { // content. This can be used to inject context values from headers, for example. type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context +// DynamicBasePathFunc allows the user to provide a function to generate the +// base path for a given request and sessionID. This is useful for cases where +// the base path is not known at the time of SSE server creation, such as when +// using a reverse proxy or when the base path is dynamically generated. The +// function should return the base path (e.g., "/mcp/tenant123"). +type DynamicBasePathFunc func(r *http.Request, sessionID string) string + func (s *sseSession) SessionID() string { return s.sessionID } @@ -40,6 +53,8 @@ func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { } func (s *sseSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) s.initialized.Store(true) } @@ -47,7 +62,58 @@ func (s *sseSession) Initialized() bool { return s.initialized.Load() } -var _ ClientSession = (*sseSession)(nil) +func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *sseSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *sseSession) GetSessionTools() map[string]ServerTool { + tools := make(map[string]ServerTool) + s.tools.Range(func(key, value any) bool { + if tool, ok := value.(ServerTool); ok { + tools[key.(string)] = tool + } + return true + }) + return tools +} + +func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { + // Clear existing tools + s.tools.Clear() + + // Set new tools + for name, tool := range tools { + s.tools.Store(name, tool) + } +} + +func (s *sseSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +var ( + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) + _ SessionWithClientInfo = (*sseSession)(nil) +) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. @@ -55,12 +121,19 @@ type SSEServer struct { server *MCPServer baseURL string basePath string - messageEndpoint string + appendQueryToMessageEndpoint bool useFullURLForMessageEndpoint bool + messageEndpoint string sseEndpoint string sessions sync.Map srv *http.Server contextFunc SSEContextFunc + dynamicBasePathFunc DynamicBasePathFunc + + keepAlive bool + keepAliveInterval time.Duration + + mu sync.RWMutex } // SSEOption defines a function type for configuring SSEServer @@ -89,14 +162,34 @@ func WithBaseURL(baseURL string) SSEOption { } } -// Add a new option for setting base path +// WithStaticBasePath adds a new option for setting a static base path +func WithStaticBasePath(basePath string) SSEOption { + return func(s *SSEServer) { + s.basePath = normalizeURLPath(basePath) + } +} + +// WithBasePath adds a new option for setting a static base path. +// +// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version. +// +//go:deprecated func WithBasePath(basePath string) SSEOption { + return WithStaticBasePath(basePath) +} + +// WithDynamicBasePath accepts a function for generating the base path. This is +// useful for cases where the base path is not known at the time of SSE server +// creation, such as when using a reverse proxy or when the server is mounted +// at a dynamic path. +func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { return func(s *SSEServer) { - // Ensure the path starts with / and doesn't end with / - if !strings.HasPrefix(basePath, "/") { - basePath = "/" + basePath + if fn != nil { + s.dynamicBasePathFunc = func(r *http.Request, sid string) string { + bp := fn(r, sid) + return normalizeURLPath(bp) + } } - s.basePath = strings.TrimSuffix(basePath, "/") } } @@ -107,6 +200,17 @@ func WithMessageEndpoint(endpoint string) SSEOption { } } +// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's +// query parameters to the message endpoint URL that is sent to clients during the SSE connection +// initialization. This is useful when you need to preserve query parameters from the initial +// SSE connection request and carry them over to subsequent message requests, maintaining +// context or authentication details across the communication channel. +func WithAppendQueryToMessageEndpoint() SSEOption { + return func(s *SSEServer) { + s.appendQueryToMessageEndpoint = true + } +} + // WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) // or just the path portion for the message endpoint. Set to false when clients will concatenate // the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". @@ -123,14 +227,29 @@ func WithSSEEndpoint(endpoint string) SSEOption { } } -// WithHTTPServer sets the HTTP server instance +// WithHTTPServer sets the HTTP server instance. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. func WithHTTPServer(srv *http.Server) SSEOption { return func(s *SSEServer) { s.srv = srv } } -// WithContextFunc sets a function that will be called to customise the context +func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption { + return func(s *SSEServer) { + s.keepAlive = true + s.keepAliveInterval = keepAliveInterval + } +} + +func WithKeepAlive(keepAlive bool) SSEOption { + return func(s *SSEServer) { + s.keepAlive = keepAlive + } +} + +// WithSSEContextFunc sets a function that will be called to customise the context // to the server using the incoming request. func WithSSEContextFunc(fn SSEContextFunc) SSEOption { return func(s *SSEServer) { @@ -145,6 +264,8 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { sseEndpoint: "/sse", messageEndpoint: "/message", useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, } // Apply all options @@ -157,10 +278,7 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { // NewTestServer creates a test server for testing purposes func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { - sseServer := NewSSEServer(server) - for _, opt := range opts { - opt(sseServer) - } + sseServer := NewSSEServer(server, opts...) testServer := httptest.NewServer(sseServer) sseServer.baseURL = testServer.URL @@ -170,19 +288,34 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { // Start begins serving SSE connections on the specified address. // It sets up HTTP handlers for SSE and message endpoints. func (s *SSEServer) Start(addr string) error { - s.srv = &http.Server{ - Addr: addr, - Handler: s, + s.mu.Lock() + if s.srv == nil { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + } else { + if s.srv.Addr == "" { + s.srv.Addr = addr + } else if s.srv.Addr != addr { + return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr) + } } + srv := s.srv + s.mu.Unlock() - return s.srv.ListenAndServe() + return srv.ListenAndServe() } // Shutdown gracefully stops the SSE server, closing all active sessions // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { - if s.srv != nil { - s.sessions.Range(func(key, value interface{}) bool { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { + s.sessions.Range(func(key, value any) bool { if session, ok := value.(*sseSession); ok { close(session.done) } @@ -190,7 +323,7 @@ func (s *SSEServer) Shutdown(ctx context.Context) error { return true }) - return s.srv.Shutdown(ctx) + return srv.Shutdown(ctx) } return nil } @@ -216,8 +349,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { sessionID := uuid.New().String() session := &sseSession{ - writer: w, - flusher: flusher, done: make(chan struct{}), eventQueue: make(chan string, 100), // Buffer for events sessionID: sessionID, @@ -228,10 +359,14 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { defer s.sessions.Delete(sessionID) if err := s.server.RegisterSession(r.Context(), session); err != nil { - http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) + http.Error( + w, + fmt.Sprintf("Session registration failed: %v", err), + http.StatusInternalServerError, + ) return } - defer s.server.UnregisterSession(sessionID) + defer s.server.UnregisterSession(r.Context(), sessionID) // Start notification handler for this session go func() { @@ -255,8 +390,44 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } }() + // Start keep alive : ping + if s.keepAlive { + go func() { + ticker := time.NewTicker(s.keepAliveInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(session.requestID.Add(1)), + Request: mcp.Request{ + Method: "ping", + }, + } + messageBytes, _ := json.Marshal(message) + pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) + select { + case session.eventQueue <- pingMsg: + // Message sent successfully + case <-session.done: + return + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + } + // Send the initial endpoint event - fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) + endpoint := s.GetMessageEndpointForClient(r, sessionID) + if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 { + endpoint += "&" + r.URL.RawQuery + } + fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint) flusher.Flush() // Main event loop - this runs in the HTTP handler goroutine @@ -269,22 +440,31 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(session.done) return + case <-session.done: + return } } } // GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID -// based on the useFullURLForMessageEndpoint configuration. -func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string { - messageEndpoint := s.messageEndpoint - if s.useFullURLForMessageEndpoint { - messageEndpoint = s.CompleteMessageEndpoint() +// for the given request. This is the canonical way to compute the message endpoint for a client. +// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag. +func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string { + basePath := s.basePath + if s.dynamicBasePathFunc != nil { + basePath = s.dynamicBasePathFunc(r, sessionID) + } + + endpointPath := normalizeURLPath(basePath, s.messageEndpoint) + if s.useFullURLForMessageEndpoint && s.baseURL != "" { + endpointPath = s.baseURL + endpointPath } - return fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) + + return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) } // handleMessage processes incoming JSON-RPC messages from clients and sends responses -// back through both the SSE connection and HTTP response. +// back through the SSE connection and 202 code to HTTP response. func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") @@ -296,7 +476,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") return } - sessionI, ok := s.sessions.Load(sessionID) if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") @@ -317,51 +496,71 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } - // Process message through MCPServer - response := s.server.HandleMessage(ctx, rawMessage) - - // Only send response if there is one (not for notifications) - if response != nil { - eventData, _ := json.Marshal(response) + // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. + // this is required because the http ctx will be canceled when the client disconnects + detachedCtx := context.WithoutCancel(ctx) + + // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE + w.WriteHeader(http.StatusAccepted) + + // Create a new context for handling the message that will be canceled when the message handling is done + messageCtx, cancel := context.WithCancel(detachedCtx) + + go func(ctx context.Context) { + defer cancel() + // Use the context that will be canceled when session is done + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawMessage) + // Only send response if there is one (not for notifications) + if response != nil { + var message string + if eventData, err := json.Marshal(response); err != nil { + // If there is an error marshalling the response, send a generic error response + log.Printf("failed to marshal response: %v", err) + message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n" + } else { + message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) + } - // Queue the event for sending via SSE - select { - case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): - // Event queued successfully - case <-session.done: - // Session is closed, don't try to queue - default: - // Queue is full, could log this + // Queue the event for sending via SSE + select { + case session.eventQueue <- message: + // Event queued successfully + case <-session.done: + // Session is closed, don't try to queue + default: + // Queue is full, log this situation + log.Printf("Event queue full for session %s", sessionID) + } } - - // Send HTTP response - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusAccepted) - json.NewEncoder(w).Encode(response) - } else { - // For notifications, just send 202 Accepted with no body - w.WriteHeader(http.StatusAccepted) - } + }(messageCtx) } // writeJSONRPCError writes a JSON-RPC error response with the given error details. func (s *SSEServer) writeJSONRPCError( w http.ResponseWriter, - id interface{}, + id any, code int, message string, ) { response := createErrorResponse(id, code, message) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error( + w, + fmt.Sprintf("Failed to encode response: %v", err), + http.StatusInternalServerError, + ) + return + } } // SendEventToSession sends an event to a specific SSE session identified by sessionID. // Returns an error if the session is not found or closed. func (s *SSEServer) SendEventToSession( sessionID string, - event interface{}, + event any, ) error { sessionI, ok := s.sessions.Load(sessionID) if !ok { @@ -384,6 +583,7 @@ func (s *SSEServer) SendEventToSession( return fmt.Errorf("event queue full") } } + func (s *SSEServer) GetUrlPath(input string) (string, error) { parse, err := url.Parse(input) if err != nil { @@ -392,30 +592,115 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) { return parse.Path, nil } -func (s *SSEServer) CompleteSseEndpoint() string { - return s.baseURL + s.basePath + s.sseEndpoint +func (s *SSEServer) CompleteSseEndpoint() (string, error) { + if s.dynamicBasePathFunc != nil { + return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"} + } + + path := normalizeURLPath(s.basePath, s.sseEndpoint) + return s.baseURL + path, nil } + func (s *SSEServer) CompleteSsePath() string { - path, err := s.GetUrlPath(s.CompleteSseEndpoint()) + path, err := s.CompleteSseEndpoint() if err != nil { - return s.basePath + s.sseEndpoint + return normalizeURLPath(s.basePath, s.sseEndpoint) } - return path + urlPath, err := s.GetUrlPath(path) + if err != nil { + return normalizeURLPath(s.basePath, s.sseEndpoint) + } + return urlPath } -func (s *SSEServer) CompleteMessageEndpoint() string { - return s.baseURL + s.basePath + s.messageEndpoint +func (s *SSEServer) CompleteMessageEndpoint() (string, error) { + if s.dynamicBasePathFunc != nil { + return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"} + } + path := normalizeURLPath(s.basePath, s.messageEndpoint) + return s.baseURL + path, nil } + func (s *SSEServer) CompleteMessagePath() string { - path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) + path, err := s.CompleteMessageEndpoint() if err != nil { - return s.basePath + s.messageEndpoint + return normalizeURLPath(s.basePath, s.messageEndpoint) } - return path + urlPath, err := s.GetUrlPath(path) + if err != nil { + return normalizeURLPath(s.basePath, s.messageEndpoint) + } + return urlPath +} + +// SSEHandler returns an http.Handler for the SSE endpoint. +// +// This method allows you to mount the SSE handler at any arbitrary path +// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is +// intended for advanced scenarios where you want to control the routing or +// support dynamic segments. +// +// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, +// you must use the WithDynamicBasePath option to ensure the correct base path +// is communicated to clients. +// +// Example usage: +// +// // Advanced/dynamic: +// sseServer := NewSSEServer(mcpServer, +// WithDynamicBasePath(func(r *http.Request, sessionID string) string { +// tenant := r.PathValue("tenant") +// return "/mcp/" + tenant +// }), +// WithBaseURL("http://localhost:8080") +// ) +// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) +// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) +// +// For non-dynamic cases, use ServeHTTP method instead. +func (s *SSEServer) SSEHandler() http.Handler { + return http.HandlerFunc(s.handleSSE) +} + +// MessageHandler returns an http.Handler for the message endpoint. +// +// This method allows you to mount the message handler at any arbitrary path +// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is +// intended for advanced scenarios where you want to control the routing or +// support dynamic segments. +// +// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, +// you must use the WithDynamicBasePath option to ensure the correct base path +// is communicated to clients. +// +// Example usage: +// +// // Advanced/dynamic: +// sseServer := NewSSEServer(mcpServer, +// WithDynamicBasePath(func(r *http.Request, sessionID string) string { +// tenant := r.PathValue("tenant") +// return "/mcp/" + tenant +// }), +// WithBaseURL("http://localhost:8080") +// ) +// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) +// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) +// +// For non-dynamic cases, use ServeHTTP method instead. +func (s *SSEServer) MessageHandler() http.Handler { + return http.HandlerFunc(s.handleMessage) } // ServeHTTP implements the http.Handler interface. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.dynamicBasePathFunc != nil { + http.Error( + w, + (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(), + http.StatusInternalServerError, + ) + return + } path := r.URL.Path // Use exact path matching rather than Contains ssePath := s.CompleteSsePath() @@ -431,3 +716,21 @@ func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } + +// normalizeURLPath joins path elements like path.Join but ensures the +// result always starts with a leading slash and never ends with a slash +func normalizeURLPath(elem ...string) string { + joined := path.Join(elem...) + + // Ensure leading slash + if !strings.HasPrefix(joined, "/") { + joined = "/" + joined + } + + // Remove trailing slash if not just "/" + if len(joined) > 1 && strings.HasSuffix(joined, "/") { + joined = joined[:len(joined)-1] + } + + return joined +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go index 14c1e76e9a..33ac9bb885 100644 --- a/vendor/github.com/mark3labs/mcp-go/server/stdio.go +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -40,7 +41,7 @@ func WithErrorLogger(logger *log.Logger) StdioOption { } } -// WithContextFunc sets a function that will be called to customise the context +// WithStdioContextFunc sets a function that will be called to customise the context // to the server. Note that the stdio server uses the same context for all requests, // so this function will only be called once per server instance. func WithStdioContextFunc(fn StdioContextFunc) StdioOption { @@ -51,8 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error } func (s *stdioSession) SessionID() string { @@ -64,6 +78,8 @@ func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { } func (s *stdioSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) s.initialized.Store(true) } @@ -71,10 +87,111 @@ func (s *stdioSession) Initialized() bool { return s.initialized.Load() } -var _ ClientSession = (*stdioSession)(nil) +func (s *stdioSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + +var ( + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) + _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) +) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -156,29 +273,23 @@ func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Read // returns an empty string and the context's error. EOF is returned when the input // stream is closed. func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { - readChan := make(chan string, 1) - errChan := make(chan error, 1) - defer func() { - close(readChan) - close(errChan) - }() + type result struct { + line string + err error + } + + resultCh := make(chan result, 1) go func() { line, err := reader.ReadString('\n') - if err != nil { - errChan <- err - return - } - readChan <- line + resultCh <- result{line: line, err: err} }() select { case <-ctx.Done(): - return "", ctx.Err() - case err := <-errChan: - return "", err - case line := <-readChan: - return line, nil + return "", nil + case res := <-resultCh: + return res.line, res.err } } @@ -194,9 +305,12 @@ func (s *StdioServer) Listen( if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { return fmt.Errorf("register session: %w", err) } - defer s.server.UnregisterSession(stdioSessionInstance.SessionID()) + defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + // Add in any custom context. if s.contextFunc != nil { ctx = s.contextFunc(ctx) @@ -217,6 +331,11 @@ func (s *StdioServer) processMessage( line string, writer io.Writer, ) error { + // If line is empty, likely due to ctx cancellation + if len(line) == 0 { + return nil + } + // Parse the message as raw JSON var rawMessage json.RawMessage if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { @@ -224,7 +343,29 @@ func (s *StdioServer) processMessage( return s.writeResponse(response, writer) } - // Handle the message using the wrapped server + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Process tool calls concurrently to avoid blocking on sampling requests + go func() { + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + }() + return nil + } + + // Handle other messages synchronously response := s.server.HandleMessage(ctx, rawMessage) // Only write response if there is one (not for notifications) @@ -237,6 +378,65 @@ func (s *StdioServer) processMessage( return nil } +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( @@ -261,7 +461,6 @@ func (s *StdioServer) writeResponse( // Returns an error if the server encounters any issues during operation. func ServeStdio(server *MCPServer, opts ...StdioOption) error { s := NewStdioServer(server) - s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags)) for _, opt := range opts { opt(s) diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go new file mode 100644 index 0000000000..1312c9753a --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go @@ -0,0 +1,655 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" +) + +// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer +type StreamableHTTPOption func(*StreamableHTTPServer) + +// WithEndpointPath sets the endpoint path for the server. +// The default is "/mcp". +// It's only works for `Start` method. When used as a http.Handler, it has no effect. +func WithEndpointPath(endpointPath string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one + normalizedPath := "/" + strings.Trim(endpointPath, "/") + s.endpointPath = normalizedPath + } +} + +// WithStateLess sets the server to stateless mode. +// If true, the server will manage no session information. Every request will be treated +// as a new session. No session id returned to the client. +// The default is false. +// +// Notice: This is a convenience method. It's identical to set WithSessionIdManager option +// to StatelessSessionIdManager. +func WithStateLess(stateLess bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } + } +} + +// WithSessionIdManager sets a custom session id generator for the server. +// By default, the server will use SimpleStatefulSessionIdGenerator, which generates +// session ids with uuid, and it's insecure. +// Notice: it will override the WithStateLess option. +func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.sessionIdManager = manager + } +} + +// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the +// server will send a heartbeat to the client through the GET connection, to keep +// the connection alive from being closed by the network infrastructure (e.g. +// gateways). If the client does not establish a GET connection, it has no +// effect. The default is not to send heartbeats. +func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.listenHeartbeatInterval = interval + } +} + +// WithHTTPContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +// This can be used to inject context values from headers, for example. +func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.contextFunc = fn + } +} + +// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. +func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.httpServer = srv + } +} + +// WithLogger sets the logger for the server +func WithLogger(logger util.Logger) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.logger = logger + } +} + +// StreamableHTTPServer implements a Streamable-http based MCP server. +// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http +// +// Usage: +// +// server := NewStreamableHTTPServer(mcpServer) +// server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default +// +// or the server itself can be used as a http.Handler, which is convenient to +// integrate with existing http servers, or advanced usage: +// +// handler := NewStreamableHTTPServer(mcpServer) +// http.Handle("/streamable-http", handler) +// http.ListenAndServe(":8080", nil) +// +// Notice: +// Except for the GET handlers(listening), the POST handlers(request/notification) will +// not trigger the session registration. So the methods like `SendNotificationToSpecificClient` +// or `hooks.onRegisterSession` will not be triggered for POST messages. +// +// The current implementation does not support the following features from the specification: +// - Batching of requests/notifications/responses in arrays. +// - Stream Resumability +type StreamableHTTPServer struct { + server *MCPServer + sessionTools *sessionToolsStore + sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + + httpServer *http.Server + mu sync.RWMutex + + endpointPath string + contextFunc HTTPContextFunc + sessionIdManager SessionIdManager + listenHeartbeatInterval time.Duration + logger util.Logger +} + +// NewStreamableHTTPServer creates a new streamable-http server instance +func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { + s := &StreamableHTTPServer{ + server: server, + sessionTools: newSessionToolsStore(), + endpointPath: "/mcp", + sessionIdManager: &InsecureStatefulSessionIdManager{}, + logger: util.DefaultLogger(), + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + return s +} + +// ServeHTTP implements the http.Handler interface. +func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + s.handlePost(w, r) + case http.MethodGet: + s.handleGet(w, r) + case http.MethodDelete: + s.handleDelete(w, r) + default: + http.NotFound(w, r) + } +} + +// Start begins serving the http server on the specified address and path +// (endpointPath). like: +// +// s.Start(":8080") +func (s *StreamableHTTPServer) Start(addr string) error { + s.mu.Lock() + if s.httpServer == nil { + mux := http.NewServeMux() + mux.Handle(s.endpointPath, s) + s.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + } else { + if s.httpServer.Addr == "" { + s.httpServer.Addr = addr + } else if s.httpServer.Addr != addr { + return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr) + } + } + srv := s.httpServer + s.mu.Unlock() + + return srv.ListenAndServe() +} + +// Shutdown gracefully stops the server, closing all active sessions +// and shutting down the HTTP server. +func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { + + // shutdown the server if needed (may use as a http.Handler) + s.mu.RLock() + srv := s.httpServer + s.mu.RUnlock() + if srv != nil { + return srv.Shutdown(ctx) + } + return nil +} + +// --- internal methods --- + +const ( + headerKeySessionID = "Mcp-Session-Id" +) + +func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { + // post request carry request/notification message + + // Check content type + contentType := r.Header.Get("Content-Type") + if contentType != "application/json" { + http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest) + return + } + + // Check the request body is valid json, meanwhile, get the request Method + rawData, err := io.ReadAll(r.Body) + if err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err)) + return + } + var baseMessage struct { + Method mcp.MCPMethod `json:"method"` + } + if err := json.Unmarshal(rawData, &baseMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") + return + } + isInitializeRequest := baseMessage.Method == mcp.MethodInitialize + + // Prepare the session for the mcp server + // The session is ephemeral. Its life is the same as the request. It's only created + // for interaction with the mcp server. + var sessionID string + if isInitializeRequest { + // generate a new one for initialize request + sessionID = s.sessionIdManager.Generate() + } else { + // Get session ID from header. + // Stateful servers need the client to carry the session ID. + sessionID = r.Header.Get(headerKeySessionID) + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return + } + } + + session := newStreamableHttpSession(sessionID, s.sessionTools) + + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // handle potential notifications + mu := sync.Mutex{} + upgradedHeader := false + done := make(chan struct{}) + + go func() { + for { + select { + case nt := <-session.notificationChannel: + func() { + mu.Lock() + defer mu.Unlock() + // if the done chan is closed, as the request is terminated, just return + select { + case <-done: + return + default: + } + defer func() { + flusher, ok := w.(http.Flusher) + if ok { + flusher.Flush() + } + }() + + // if there's notifications, upgradedHeader to SSE response + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) + upgradedHeader = true + } + err := writeSSEEvent(w, nt) + if err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } + }() + case <-done: + return + case <-ctx.Done(): + return + } + } + }() + + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawData) + if response == nil { + // For notifications, just send 202 Accepted with no body + w.WriteHeader(http.StatusAccepted) + return + } + + // Write response + mu.Lock() + defer mu.Unlock() + // close the done chan before unlock + defer close(done) + if ctx.Err() != nil { + return + } + // If client-server communication already upgraded to SSE stream + if session.upgradeToSSE.Load() { + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) + upgradedHeader = true + } + if err := writeSSEEvent(w, response); err != nil { + s.logger.Errorf("Failed to write final SSE response event: %v", err) + } + } else { + w.Header().Set("Content-Type", "application/json") + if isInitializeRequest && sessionID != "" { + // send the session ID back to the client + w.Header().Set(headerKeySessionID, sessionID) + } + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(response) + if err != nil { + s.logger.Errorf("Failed to write response: %v", err) + } + } +} + +func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { + // get request is for listening to notifications + // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + + sessionID := r.Header.Get(headerKeySessionID) + // the specification didn't say we should validate the session id + + if sessionID == "" { + // It's a stateless server, + // but the MCP server requires a unique ID for registering, so we use a random one + sessionID = uuid.New().String() + } + + session := newStreamableHttpSession(sessionID, s.sessionTools) + if err := s.server.RegisterSession(r.Context(), session); err != nil { + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + + // Set the client context before handling the message + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + flusher.Flush() + + // Start notification handler for this session + done := make(chan struct{}) + defer close(done) + writeChan := make(chan any, 16) + + go func() { + for { + select { + case nt := <-session.notificationChannel: + select { + case writeChan <- &nt: + case <-done: + return + } + case <-done: + return + } + } + }() + + if s.listenHeartbeatInterval > 0 { + // heartbeat to keep the connection alive + go func() { + ticker := time.NewTicker(s.listenHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(s.nextRequestID(sessionID)), + Request: mcp.Request{ + Method: "ping", + }, + } + select { + case writeChan <- message: + case <-done: + return + } + case <-done: + return + } + } + }() + } + + // Keep the connection open until the client disconnects + // + // There's will a Available() check when handler ends, and it maybe race with Flush(), + // so we use a separate channel to send the data, inteading of flushing directly in other goroutine. + for { + select { + case data := <-writeChan: + if data == nil { + continue + } + if err := writeSSEEvent(w, data); err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + // delete request terminate the session + sessionID := r.Header.Get(headerKeySessionID) + notAllowed, err := s.sessionIdManager.Terminate(sessionID) + if err != nil { + http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError) + return + } + if notAllowed { + http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed) + return + } + + // remove the session relateddata from the sessionToolsStore + s.sessionTools.delete(sessionID) + + // remove current session's requstID information + s.sessionRequestIDs.Delete(sessionID) + + w.WriteHeader(http.StatusOK) +} + +func writeSSEEvent(w io.Writer, data any) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + _, err = fmt.Fprintf(w, "event: message\ndata: %s\n\n", jsonData) + if err != nil { + return fmt.Errorf("failed to write SSE event: %w", err) + } + return nil +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. +func (s *StreamableHTTPServer) writeJSONRPCError( + w http.ResponseWriter, + id any, + code int, + message string, +) { + response := createErrorResponse(id, code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + err := json.NewEncoder(w).Encode(response) + if err != nil { + s.logger.Errorf("Failed to write JSONRPCError: %v", err) + } +} + +// nextRequestID gets the next incrementing requestID for the current session +func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { + actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64)) + counter := actual.(*atomic.Int64) + return counter.Add(1) +} + +// --- session --- + +type sessionToolsStore struct { + mu sync.RWMutex + tools map[string]map[string]ServerTool // sessionID -> toolName -> tool +} + +func newSessionToolsStore() *sessionToolsStore { + return &sessionToolsStore{ + tools: make(map[string]map[string]ServerTool), + } +} + +func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.tools[sessionID] +} + +func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) { + s.mu.Lock() + defer s.mu.Unlock() + s.tools[sessionID] = tools +} + +func (s *sessionToolsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tools, sessionID) +} + +// streamableHttpSession is a session for streamable-http transport +// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. +// When in GET handlers(listening), it's a real session, and will be registered in the MCP server. +type streamableHttpSession struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification // server -> client notifications + tools *sessionToolsStore + upgradeToSSE atomic.Bool +} + +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { + return &streamableHttpSession{ + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + } +} + +func (s *streamableHttpSession) SessionID() string { + return s.sessionID +} + +func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *streamableHttpSession) Initialize() { + // do nothing + // the session is ephemeral, no real initialized action needed +} + +func (s *streamableHttpSession) Initialized() bool { + // the session is ephemeral, no real initialized action needed + return true +} + +var _ ClientSession = (*streamableHttpSession)(nil) + +func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { + return s.tools.get(s.sessionID) +} + +func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { + s.tools.set(s.sessionID, tools) +} + +var _ SessionWithTools = (*streamableHttpSession)(nil) + +func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { + s.upgradeToSSE.Store(true) +} + +var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) + +// --- session id manager --- + +type SessionIdManager interface { + Generate() string + // Validate checks if a session ID is valid and not terminated. + // Returns isTerminated=true if the ID is valid but belongs to a terminated session. + // Returns err!=nil if the ID format is invalid or lookup failed. + Validate(sessionID string) (isTerminated bool, err error) + // Terminate marks a session ID as terminated. + // Returns isNotAllowed=true if the server policy prevents client termination. + // Returns err!=nil if the ID is invalid or termination failed. + Terminate(sessionID string) (isNotAllowed bool, err error) +} + +// StatelessSessionIdManager does nothing, which means it has no session management, which is stateless. +type StatelessSessionIdManager struct{} + +func (s *StatelessSessionIdManager) Generate() string { + return "" +} +func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // In stateless mode, ignore session IDs completely - don't validate or reject them + return false, nil +} +func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + return false, nil +} + +// InsecureStatefulSessionIdManager generate id with uuid +// It won't validate the id indeed, so it could be fake. +// For more secure session id, use a more complex generator, like a JWT. +type InsecureStatefulSessionIdManager struct{} + +const idPrefix = "mcp-session-" + +func (s *InsecureStatefulSessionIdManager) Generate() string { + return idPrefix + uuid.New().String() +} +func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // validate the session id is a valid uuid + if !strings.HasPrefix(sessionID, idPrefix) { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + return false, nil +} +func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + return false, nil +} + +// NewTestStreamableHTTPServer creates a test server for testing purposes +func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server { + sseServer := NewStreamableHTTPServer(server, opts...) + testServer := httptest.NewServer(sseServer) + return testServer +} diff --git a/vendor/github.com/mark3labs/mcp-go/util/logger.go b/vendor/github.com/mark3labs/mcp-go/util/logger.go new file mode 100644 index 0000000000..8d7555ce35 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/util/logger.go @@ -0,0 +1,33 @@ +package util + +import ( + "log" +) + +// Logger defines a minimal logging interface +type Logger interface { + Infof(format string, v ...any) + Errorf(format string, v ...any) +} + +// --- Standard Library Logger Wrapper --- + +// DefaultStdLogger implements Logger using the standard library's log.Logger. +func DefaultLogger() Logger { + return &stdLogger{ + logger: log.Default(), + } +} + +// stdLogger wraps the standard library's log.Logger. +type stdLogger struct { + logger *log.Logger +} + +func (l *stdLogger) Infof(format string, v ...any) { + l.logger.Printf("INFO: "+format, v...) +} + +func (l *stdLogger) Errorf(format string, v ...any) { + l.logger.Printf("ERROR: "+format, v...) +} diff --git a/vendor/github.com/spf13/cast/.editorconfig b/vendor/github.com/spf13/cast/.editorconfig new file mode 100644 index 0000000000..a85749f190 --- /dev/null +++ b/vendor/github.com/spf13/cast/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true + +[*.go] +indent_style = tab + +[{*.yml,*.yaml}] +indent_size = 2 diff --git a/vendor/github.com/spf13/cast/.golangci.yaml b/vendor/github.com/spf13/cast/.golangci.yaml new file mode 100644 index 0000000000..e00fd47aa2 --- /dev/null +++ b/vendor/github.com/spf13/cast/.golangci.yaml @@ -0,0 +1,39 @@ +version: "2" + +run: + timeout: 10m + +linters: + enable: + - errcheck + - govet + - ineffassign + - misspell + - nolintlint + # - revive + - unused + + disable: + - staticcheck + + settings: + misspell: + locale: US + nolintlint: + allow-unused: false # report any unused nolint directives + require-specific: false # don't require nolint directives to be specific about which linter is being skipped + +formatters: + enable: + - gci + - gofmt + # - gofumpt + - goimports + # - golines + + settings: + gci: + sections: + - standard + - default + - localmodule diff --git a/vendor/github.com/spf13/cast/.travis.yml b/vendor/github.com/spf13/cast/.travis.yml deleted file mode 100644 index 6420d1c27f..0000000000 --- a/vendor/github.com/spf13/cast/.travis.yml +++ /dev/null @@ -1,15 +0,0 @@ -language: go -env: - - GO111MODULE=on -sudo: required -go: - - "1.11.x" - - tip -os: - - linux -matrix: - allow_failures: - - go: tip - fast_finish: true -script: - - make check diff --git a/vendor/github.com/spf13/cast/Makefile b/vendor/github.com/spf13/cast/Makefile index 7ccf8930b5..f01a5dbb6e 100644 --- a/vendor/github.com/spf13/cast/Makefile +++ b/vendor/github.com/spf13/cast/Makefile @@ -1,4 +1,4 @@ -# A Self-Documenting Makefile: http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html +GOVERSION := $(shell go version | cut -d ' ' -f 3 | cut -d '.' -f 2) .PHONY: check fmt lint test test-race vet test-cover-html help .DEFAULT_GOAL := help @@ -12,11 +12,13 @@ test-race: ## Run tests with race detector go test -race ./... fmt: ## Run gofmt linter +ifeq "$(GOVERSION)" "12" @for d in `go list` ; do \ if [ "`gofmt -l -s $$GOPATH/src/$$d | tee /dev/stderr`" ]; then \ echo "^ improperly formatted go files" && echo && exit 1; \ fi \ done +endif lint: ## Run golint linter @for d in `go list` ; do \ diff --git a/vendor/github.com/spf13/cast/README.md b/vendor/github.com/spf13/cast/README.md index e6939397dd..c58eccb3fd 100644 --- a/vendor/github.com/spf13/cast/README.md +++ b/vendor/github.com/spf13/cast/README.md @@ -1,8 +1,9 @@ -cast -==== -[![GoDoc](https://godoc.org/github.com/spf13/cast?status.svg)](https://godoc.org/github.com/spf13/cast) -[![Build Status](https://api.travis-ci.org/spf13/cast.svg?branch=master)](https://travis-ci.org/spf13/cast) -[![Go Report Card](https://goreportcard.com/badge/github.com/spf13/cast)](https://goreportcard.com/report/github.com/spf13/cast) +# cast + +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/spf13/cast/ci.yaml?style=flat-square)](https://github.com/spf13/cast/actions/workflows/ci.yaml) +[![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/mod/github.com/spf13/cast) +![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/spf13/cast?style=flat-square&color=61CFDD) +[![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/spf13/cast/badge?style=flat-square)](https://deps.dev/go/github.com%252Fspf13%252Fcast) Easy and safe casting from one type to another in Go @@ -17,7 +18,7 @@ interface into a bool, etc. Cast does this intelligently when an obvious conversion is possible. It doesn’t make any attempts to guess what you meant, for example you can only convert a string to an int when it is a string representation of an int such as “8”. Cast was developed for use in -[Hugo](http://hugo.spf13.com), a website engine which uses YAML, TOML or JSON +[Hugo](https://gohugo.io), a website engine which uses YAML, TOML or JSON for meta data. ## Why use Cast? @@ -73,3 +74,6 @@ the code for a complete set. cast.ToInt(eight) // 8 cast.ToInt(nil) // 0 +## License + +The project is licensed under the [MIT License](LICENSE). diff --git a/vendor/github.com/spf13/cast/alias.go b/vendor/github.com/spf13/cast/alias.go new file mode 100644 index 0000000000..855d60005d --- /dev/null +++ b/vendor/github.com/spf13/cast/alias.go @@ -0,0 +1,69 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. +package cast + +import ( + "reflect" + "slices" +) + +var kindNames = []string{ + reflect.String: "string", + reflect.Bool: "bool", + reflect.Int: "int", + reflect.Int8: "int8", + reflect.Int16: "int16", + reflect.Int32: "int32", + reflect.Int64: "int64", + reflect.Uint: "uint", + reflect.Uint8: "uint8", + reflect.Uint16: "uint16", + reflect.Uint32: "uint32", + reflect.Uint64: "uint64", + reflect.Float32: "float32", + reflect.Float64: "float64", +} + +var kinds = map[reflect.Kind]func(reflect.Value) any{ + reflect.String: func(v reflect.Value) any { return v.String() }, + reflect.Bool: func(v reflect.Value) any { return v.Bool() }, + reflect.Int: func(v reflect.Value) any { return int(v.Int()) }, + reflect.Int8: func(v reflect.Value) any { return int8(v.Int()) }, + reflect.Int16: func(v reflect.Value) any { return int16(v.Int()) }, + reflect.Int32: func(v reflect.Value) any { return int32(v.Int()) }, + reflect.Int64: func(v reflect.Value) any { return v.Int() }, + reflect.Uint: func(v reflect.Value) any { return uint(v.Uint()) }, + reflect.Uint8: func(v reflect.Value) any { return uint8(v.Uint()) }, + reflect.Uint16: func(v reflect.Value) any { return uint16(v.Uint()) }, + reflect.Uint32: func(v reflect.Value) any { return uint32(v.Uint()) }, + reflect.Uint64: func(v reflect.Value) any { return v.Uint() }, + reflect.Float32: func(v reflect.Value) any { return float32(v.Float()) }, + reflect.Float64: func(v reflect.Value) any { return v.Float() }, +} + +// resolveAlias attempts to resolve a named type to its underlying basic type (if possible). +// +// Pointers are expected to be indirected by this point. +func resolveAlias(i any) (any, bool) { + if i == nil { + return nil, false + } + + t := reflect.TypeOf(i) + + // Not a named type + if t.Name() == "" || slices.Contains(kindNames, t.Name()) { + return i, false + } + + resolve, ok := kinds[t.Kind()] + if !ok { // Not a supported kind + return i, false + } + + v := reflect.ValueOf(i) + + return resolve(v), true +} diff --git a/vendor/github.com/spf13/cast/basic.go b/vendor/github.com/spf13/cast/basic.go new file mode 100644 index 0000000000..fa330e207a --- /dev/null +++ b/vendor/github.com/spf13/cast/basic.go @@ -0,0 +1,131 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "encoding/json" + "fmt" + "html/template" + "strconv" + "time" +) + +// ToBoolE casts any value to a bool type. +func ToBoolE(i any) (bool, error) { + i, _ = indirect(i) + + switch b := i.(type) { + case bool: + return b, nil + case nil: + return false, nil + case int: + return b != 0, nil + case int8: + return b != 0, nil + case int16: + return b != 0, nil + case int32: + return b != 0, nil + case int64: + return b != 0, nil + case uint: + return b != 0, nil + case uint8: + return b != 0, nil + case uint16: + return b != 0, nil + case uint32: + return b != 0, nil + case uint64: + return b != 0, nil + case float32: + return b != 0, nil + case float64: + return b != 0, nil + case time.Duration: + return b != 0, nil + case string: + return strconv.ParseBool(b) + case json.Number: + v, err := ToInt64E(b) + if err == nil { + return v != 0, nil + } + + return false, fmt.Errorf(errorMsg, i, i, false) + default: + if i, ok := resolveAlias(i); ok { + return ToBoolE(i) + } + + return false, fmt.Errorf(errorMsg, i, i, false) + } +} + +// ToStringE casts any value to a string type. +func ToStringE(i any) (string, error) { + switch s := i.(type) { + case string: + return s, nil + case bool: + return strconv.FormatBool(s), nil + case float64: + return strconv.FormatFloat(s, 'f', -1, 64), nil + case float32: + return strconv.FormatFloat(float64(s), 'f', -1, 32), nil + case int: + return strconv.Itoa(s), nil + case int8: + return strconv.FormatInt(int64(s), 10), nil + case int16: + return strconv.FormatInt(int64(s), 10), nil + case int32: + return strconv.FormatInt(int64(s), 10), nil + case int64: + return strconv.FormatInt(s, 10), nil + case uint: + return strconv.FormatUint(uint64(s), 10), nil + case uint8: + return strconv.FormatUint(uint64(s), 10), nil + case uint16: + return strconv.FormatUint(uint64(s), 10), nil + case uint32: + return strconv.FormatUint(uint64(s), 10), nil + case uint64: + return strconv.FormatUint(s, 10), nil + case json.Number: + return s.String(), nil + case []byte: + return string(s), nil + case template.HTML: + return string(s), nil + case template.URL: + return string(s), nil + case template.JS: + return string(s), nil + case template.CSS: + return string(s), nil + case template.HTMLAttr: + return string(s), nil + case nil: + return "", nil + case fmt.Stringer: + return s.String(), nil + case error: + return s.Error(), nil + default: + if i, ok := indirect(i); ok { + return ToStringE(i) + } + + if i, ok := resolveAlias(i); ok { + return ToStringE(i) + } + + return "", fmt.Errorf(errorMsg, i, i, "") + } +} diff --git a/vendor/github.com/spf13/cast/cast.go b/vendor/github.com/spf13/cast/cast.go index 9fba638d46..8d85539b35 100644 --- a/vendor/github.com/spf13/cast/cast.go +++ b/vendor/github.com/spf13/cast/cast.go @@ -8,164 +8,77 @@ package cast import "time" -// ToBool casts an interface to a bool type. -func ToBool(i interface{}) bool { - v, _ := ToBoolE(i) - return v -} - -// ToTime casts an interface to a time.Time type. -func ToTime(i interface{}) time.Time { - v, _ := ToTimeE(i) - return v -} - -// ToDuration casts an interface to a time.Duration type. -func ToDuration(i interface{}) time.Duration { - v, _ := ToDurationE(i) - return v -} - -// ToFloat64 casts an interface to a float64 type. -func ToFloat64(i interface{}) float64 { - v, _ := ToFloat64E(i) - return v -} - -// ToFloat32 casts an interface to a float32 type. -func ToFloat32(i interface{}) float32 { - v, _ := ToFloat32E(i) - return v -} - -// ToInt64 casts an interface to an int64 type. -func ToInt64(i interface{}) int64 { - v, _ := ToInt64E(i) - return v -} - -// ToInt32 casts an interface to an int32 type. -func ToInt32(i interface{}) int32 { - v, _ := ToInt32E(i) - return v -} +const errorMsg = "unable to cast %#v of type %T to %T" +const errorMsgWith = "unable to cast %#v of type %T to %T: %w" -// ToInt16 casts an interface to an int16 type. -func ToInt16(i interface{}) int16 { - v, _ := ToInt16E(i) - return v -} - -// ToInt8 casts an interface to an int8 type. -func ToInt8(i interface{}) int8 { - v, _ := ToInt8E(i) - return v -} - -// ToInt casts an interface to an int type. -func ToInt(i interface{}) int { - v, _ := ToIntE(i) - return v -} - -// ToUint casts an interface to a uint type. -func ToUint(i interface{}) uint { - v, _ := ToUintE(i) - return v -} - -// ToUint64 casts an interface to a uint64 type. -func ToUint64(i interface{}) uint64 { - v, _ := ToUint64E(i) - return v -} - -// ToUint32 casts an interface to a uint32 type. -func ToUint32(i interface{}) uint32 { - v, _ := ToUint32E(i) - return v -} - -// ToUint16 casts an interface to a uint16 type. -func ToUint16(i interface{}) uint16 { - v, _ := ToUint16E(i) - return v -} - -// ToUint8 casts an interface to a uint8 type. -func ToUint8(i interface{}) uint8 { - v, _ := ToUint8E(i) - return v -} - -// ToString casts an interface to a string type. -func ToString(i interface{}) string { - v, _ := ToStringE(i) - return v -} - -// ToStringMapString casts an interface to a map[string]string type. -func ToStringMapString(i interface{}) map[string]string { - v, _ := ToStringMapStringE(i) - return v -} - -// ToStringMapStringSlice casts an interface to a map[string][]string type. -func ToStringMapStringSlice(i interface{}) map[string][]string { - v, _ := ToStringMapStringSliceE(i) - return v -} - -// ToStringMapBool casts an interface to a map[string]bool type. -func ToStringMapBool(i interface{}) map[string]bool { - v, _ := ToStringMapBoolE(i) - return v -} - -// ToStringMapInt casts an interface to a map[string]int type. -func ToStringMapInt(i interface{}) map[string]int { - v, _ := ToStringMapIntE(i) - return v -} - -// ToStringMapInt64 casts an interface to a map[string]int64 type. -func ToStringMapInt64(i interface{}) map[string]int64 { - v, _ := ToStringMapInt64E(i) - return v -} - -// ToStringMap casts an interface to a map[string]interface{} type. -func ToStringMap(i interface{}) map[string]interface{} { - v, _ := ToStringMapE(i) - return v -} - -// ToSlice casts an interface to a []interface{} type. -func ToSlice(i interface{}) []interface{} { - v, _ := ToSliceE(i) - return v -} - -// ToBoolSlice casts an interface to a []bool type. -func ToBoolSlice(i interface{}) []bool { - v, _ := ToBoolSliceE(i) - return v -} - -// ToStringSlice casts an interface to a []string type. -func ToStringSlice(i interface{}) []string { - v, _ := ToStringSliceE(i) - return v -} - -// ToIntSlice casts an interface to a []int type. -func ToIntSlice(i interface{}) []int { - v, _ := ToIntSliceE(i) - return v -} +// Basic is a type parameter constraint for functions accepting basic types. +// +// It represents the supported basic types this package can cast to. +type Basic interface { + string | bool | Number | time.Time | time.Duration +} + +// ToE casts any value to a [Basic] type. +func ToE[T Basic](i any) (T, error) { + var t T + + var v any + var err error + + switch any(t).(type) { + case string: + v, err = ToStringE(i) + case bool: + v, err = ToBoolE(i) + case int: + v, err = toNumberE[int](i, parseInt[int]) + case int8: + v, err = toNumberE[int8](i, parseInt[int8]) + case int16: + v, err = toNumberE[int16](i, parseInt[int16]) + case int32: + v, err = toNumberE[int32](i, parseInt[int32]) + case int64: + v, err = toNumberE[int64](i, parseInt[int64]) + case uint: + v, err = toUnsignedNumberE[uint](i, parseUint[uint]) + case uint8: + v, err = toUnsignedNumberE[uint8](i, parseUint[uint8]) + case uint16: + v, err = toUnsignedNumberE[uint16](i, parseUint[uint16]) + case uint32: + v, err = toUnsignedNumberE[uint32](i, parseUint[uint32]) + case uint64: + v, err = toUnsignedNumberE[uint64](i, parseUint[uint64]) + case float32: + v, err = toNumberE[float32](i, parseFloat[float32]) + case float64: + v, err = toNumberE[float64](i, parseFloat[float64]) + case time.Time: + v, err = ToTimeE(i) + case time.Duration: + v, err = ToDurationE(i) + } + + if err != nil { + return t, err + } + + return v.(T), nil +} + +// Must is a helper that wraps a call to a cast function and panics if the error is non-nil. +func Must[T any](i any, err error) T { + if err != nil { + panic(err) + } + + return i.(T) +} + +// To casts any value to a [Basic] type. +func To[T Basic](i any) T { + v, _ := ToE[T](i) -// ToDurationSlice casts an interface to a []time.Duration type. -func ToDurationSlice(i interface{}) []time.Duration { - v, _ := ToDurationSliceE(i) return v } diff --git a/vendor/github.com/spf13/cast/caste.go b/vendor/github.com/spf13/cast/caste.go deleted file mode 100644 index a4859fb0af..0000000000 --- a/vendor/github.com/spf13/cast/caste.go +++ /dev/null @@ -1,1249 +0,0 @@ -// Copyright © 2014 Steve Francia . -// -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package cast - -import ( - "encoding/json" - "errors" - "fmt" - "html/template" - "reflect" - "strconv" - "strings" - "time" -) - -var errNegativeNotAllowed = errors.New("unable to cast negative value") - -// ToTimeE casts an interface to a time.Time type. -func ToTimeE(i interface{}) (tim time.Time, err error) { - i = indirect(i) - - switch v := i.(type) { - case time.Time: - return v, nil - case string: - return StringToDate(v) - case int: - return time.Unix(int64(v), 0), nil - case int64: - return time.Unix(v, 0), nil - case int32: - return time.Unix(int64(v), 0), nil - case uint: - return time.Unix(int64(v), 0), nil - case uint64: - return time.Unix(int64(v), 0), nil - case uint32: - return time.Unix(int64(v), 0), nil - default: - return time.Time{}, fmt.Errorf("unable to cast %#v of type %T to Time", i, i) - } -} - -// ToDurationE casts an interface to a time.Duration type. -func ToDurationE(i interface{}) (d time.Duration, err error) { - i = indirect(i) - - switch s := i.(type) { - case time.Duration: - return s, nil - case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: - d = time.Duration(ToInt64(s)) - return - case float32, float64: - d = time.Duration(ToFloat64(s)) - return - case string: - if strings.ContainsAny(s, "nsuµmh") { - d, err = time.ParseDuration(s) - } else { - d, err = time.ParseDuration(s + "ns") - } - return - default: - err = fmt.Errorf("unable to cast %#v of type %T to Duration", i, i) - return - } -} - -// ToBoolE casts an interface to a bool type. -func ToBoolE(i interface{}) (bool, error) { - i = indirect(i) - - switch b := i.(type) { - case bool: - return b, nil - case nil: - return false, nil - case int: - if i.(int) != 0 { - return true, nil - } - return false, nil - case string: - return strconv.ParseBool(i.(string)) - default: - return false, fmt.Errorf("unable to cast %#v of type %T to bool", i, i) - } -} - -// ToFloat64E casts an interface to a float64 type. -func ToFloat64E(i interface{}) (float64, error) { - i = indirect(i) - - switch s := i.(type) { - case float64: - return s, nil - case float32: - return float64(s), nil - case int: - return float64(s), nil - case int64: - return float64(s), nil - case int32: - return float64(s), nil - case int16: - return float64(s), nil - case int8: - return float64(s), nil - case uint: - return float64(s), nil - case uint64: - return float64(s), nil - case uint32: - return float64(s), nil - case uint16: - return float64(s), nil - case uint8: - return float64(s), nil - case string: - v, err := strconv.ParseFloat(s, 64) - if err == nil { - return v, nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i) - } -} - -// ToFloat32E casts an interface to a float32 type. -func ToFloat32E(i interface{}) (float32, error) { - i = indirect(i) - - switch s := i.(type) { - case float64: - return float32(s), nil - case float32: - return s, nil - case int: - return float32(s), nil - case int64: - return float32(s), nil - case int32: - return float32(s), nil - case int16: - return float32(s), nil - case int8: - return float32(s), nil - case uint: - return float32(s), nil - case uint64: - return float32(s), nil - case uint32: - return float32(s), nil - case uint16: - return float32(s), nil - case uint8: - return float32(s), nil - case string: - v, err := strconv.ParseFloat(s, 32) - if err == nil { - return float32(v), nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i) - } -} - -// ToInt64E casts an interface to an int64 type. -func ToInt64E(i interface{}) (int64, error) { - i = indirect(i) - - switch s := i.(type) { - case int: - return int64(s), nil - case int64: - return s, nil - case int32: - return int64(s), nil - case int16: - return int64(s), nil - case int8: - return int64(s), nil - case uint: - return int64(s), nil - case uint64: - return int64(s), nil - case uint32: - return int64(s), nil - case uint16: - return int64(s), nil - case uint8: - return int64(s), nil - case float64: - return int64(s), nil - case float32: - return int64(s), nil - case string: - v, err := strconv.ParseInt(s, 0, 0) - if err == nil { - return v, nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i) - } -} - -// ToInt32E casts an interface to an int32 type. -func ToInt32E(i interface{}) (int32, error) { - i = indirect(i) - - switch s := i.(type) { - case int: - return int32(s), nil - case int64: - return int32(s), nil - case int32: - return s, nil - case int16: - return int32(s), nil - case int8: - return int32(s), nil - case uint: - return int32(s), nil - case uint64: - return int32(s), nil - case uint32: - return int32(s), nil - case uint16: - return int32(s), nil - case uint8: - return int32(s), nil - case float64: - return int32(s), nil - case float32: - return int32(s), nil - case string: - v, err := strconv.ParseInt(s, 0, 0) - if err == nil { - return int32(v), nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to int32", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to int32", i, i) - } -} - -// ToInt16E casts an interface to an int16 type. -func ToInt16E(i interface{}) (int16, error) { - i = indirect(i) - - switch s := i.(type) { - case int: - return int16(s), nil - case int64: - return int16(s), nil - case int32: - return int16(s), nil - case int16: - return s, nil - case int8: - return int16(s), nil - case uint: - return int16(s), nil - case uint64: - return int16(s), nil - case uint32: - return int16(s), nil - case uint16: - return int16(s), nil - case uint8: - return int16(s), nil - case float64: - return int16(s), nil - case float32: - return int16(s), nil - case string: - v, err := strconv.ParseInt(s, 0, 0) - if err == nil { - return int16(v), nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to int16", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to int16", i, i) - } -} - -// ToInt8E casts an interface to an int8 type. -func ToInt8E(i interface{}) (int8, error) { - i = indirect(i) - - switch s := i.(type) { - case int: - return int8(s), nil - case int64: - return int8(s), nil - case int32: - return int8(s), nil - case int16: - return int8(s), nil - case int8: - return s, nil - case uint: - return int8(s), nil - case uint64: - return int8(s), nil - case uint32: - return int8(s), nil - case uint16: - return int8(s), nil - case uint8: - return int8(s), nil - case float64: - return int8(s), nil - case float32: - return int8(s), nil - case string: - v, err := strconv.ParseInt(s, 0, 0) - if err == nil { - return int8(v), nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to int8", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to int8", i, i) - } -} - -// ToIntE casts an interface to an int type. -func ToIntE(i interface{}) (int, error) { - i = indirect(i) - - switch s := i.(type) { - case int: - return s, nil - case int64: - return int(s), nil - case int32: - return int(s), nil - case int16: - return int(s), nil - case int8: - return int(s), nil - case uint: - return int(s), nil - case uint64: - return int(s), nil - case uint32: - return int(s), nil - case uint16: - return int(s), nil - case uint8: - return int(s), nil - case float64: - return int(s), nil - case float32: - return int(s), nil - case string: - v, err := strconv.ParseInt(s, 0, 0) - if err == nil { - return int(v), nil - } - return 0, fmt.Errorf("unable to cast %#v of type %T to int", i, i) - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to int", i, i) - } -} - -// ToUintE casts an interface to a uint type. -func ToUintE(i interface{}) (uint, error) { - i = indirect(i) - - switch s := i.(type) { - case string: - v, err := strconv.ParseUint(s, 0, 0) - if err == nil { - return uint(v), nil - } - return 0, fmt.Errorf("unable to cast %#v to uint: %s", i, err) - case int: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case int64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case int32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case int16: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case int8: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case uint: - return s, nil - case uint64: - return uint(s), nil - case uint32: - return uint(s), nil - case uint16: - return uint(s), nil - case uint8: - return uint(s), nil - case float64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case float32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint(s), nil - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to uint", i, i) - } -} - -// ToUint64E casts an interface to a uint64 type. -func ToUint64E(i interface{}) (uint64, error) { - i = indirect(i) - - switch s := i.(type) { - case string: - v, err := strconv.ParseUint(s, 0, 64) - if err == nil { - return v, nil - } - return 0, fmt.Errorf("unable to cast %#v to uint64: %s", i, err) - case int: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case int64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case int32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case int16: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case int8: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case uint: - return uint64(s), nil - case uint64: - return s, nil - case uint32: - return uint64(s), nil - case uint16: - return uint64(s), nil - case uint8: - return uint64(s), nil - case float32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case float64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint64(s), nil - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to uint64", i, i) - } -} - -// ToUint32E casts an interface to a uint32 type. -func ToUint32E(i interface{}) (uint32, error) { - i = indirect(i) - - switch s := i.(type) { - case string: - v, err := strconv.ParseUint(s, 0, 32) - if err == nil { - return uint32(v), nil - } - return 0, fmt.Errorf("unable to cast %#v to uint32: %s", i, err) - case int: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case int64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case int32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case int16: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case int8: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case uint: - return uint32(s), nil - case uint64: - return uint32(s), nil - case uint32: - return s, nil - case uint16: - return uint32(s), nil - case uint8: - return uint32(s), nil - case float64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case float32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint32(s), nil - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to uint32", i, i) - } -} - -// ToUint16E casts an interface to a uint16 type. -func ToUint16E(i interface{}) (uint16, error) { - i = indirect(i) - - switch s := i.(type) { - case string: - v, err := strconv.ParseUint(s, 0, 16) - if err == nil { - return uint16(v), nil - } - return 0, fmt.Errorf("unable to cast %#v to uint16: %s", i, err) - case int: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case int64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case int32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case int16: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case int8: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case uint: - return uint16(s), nil - case uint64: - return uint16(s), nil - case uint32: - return uint16(s), nil - case uint16: - return s, nil - case uint8: - return uint16(s), nil - case float64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case float32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint16(s), nil - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to uint16", i, i) - } -} - -// ToUint8E casts an interface to a uint type. -func ToUint8E(i interface{}) (uint8, error) { - i = indirect(i) - - switch s := i.(type) { - case string: - v, err := strconv.ParseUint(s, 0, 8) - if err == nil { - return uint8(v), nil - } - return 0, fmt.Errorf("unable to cast %#v to uint8: %s", i, err) - case int: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case int64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case int32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case int16: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case int8: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case uint: - return uint8(s), nil - case uint64: - return uint8(s), nil - case uint32: - return uint8(s), nil - case uint16: - return uint8(s), nil - case uint8: - return s, nil - case float64: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case float32: - if s < 0 { - return 0, errNegativeNotAllowed - } - return uint8(s), nil - case bool: - if s { - return 1, nil - } - return 0, nil - case nil: - return 0, nil - default: - return 0, fmt.Errorf("unable to cast %#v of type %T to uint8", i, i) - } -} - -// From html/template/content.go -// Copyright 2011 The Go Authors. All rights reserved. -// indirect returns the value, after dereferencing as many times -// as necessary to reach the base type (or nil). -func indirect(a interface{}) interface{} { - if a == nil { - return nil - } - if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr { - // Avoid creating a reflect.Value if it's not a pointer. - return a - } - v := reflect.ValueOf(a) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - return v.Interface() -} - -// From html/template/content.go -// Copyright 2011 The Go Authors. All rights reserved. -// indirectToStringerOrError returns the value, after dereferencing as many times -// as necessary to reach the base type (or nil) or an implementation of fmt.Stringer -// or error, -func indirectToStringerOrError(a interface{}) interface{} { - if a == nil { - return nil - } - - var errorType = reflect.TypeOf((*error)(nil)).Elem() - var fmtStringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem() - - v := reflect.ValueOf(a) - for !v.Type().Implements(fmtStringerType) && !v.Type().Implements(errorType) && v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - return v.Interface() -} - -// ToStringE casts an interface to a string type. -func ToStringE(i interface{}) (string, error) { - i = indirectToStringerOrError(i) - - switch s := i.(type) { - case string: - return s, nil - case bool: - return strconv.FormatBool(s), nil - case float64: - return strconv.FormatFloat(s, 'f', -1, 64), nil - case float32: - return strconv.FormatFloat(float64(s), 'f', -1, 32), nil - case int: - return strconv.Itoa(s), nil - case int64: - return strconv.FormatInt(s, 10), nil - case int32: - return strconv.Itoa(int(s)), nil - case int16: - return strconv.FormatInt(int64(s), 10), nil - case int8: - return strconv.FormatInt(int64(s), 10), nil - case uint: - return strconv.FormatInt(int64(s), 10), nil - case uint64: - return strconv.FormatInt(int64(s), 10), nil - case uint32: - return strconv.FormatInt(int64(s), 10), nil - case uint16: - return strconv.FormatInt(int64(s), 10), nil - case uint8: - return strconv.FormatInt(int64(s), 10), nil - case []byte: - return string(s), nil - case template.HTML: - return string(s), nil - case template.URL: - return string(s), nil - case template.JS: - return string(s), nil - case template.CSS: - return string(s), nil - case template.HTMLAttr: - return string(s), nil - case nil: - return "", nil - case fmt.Stringer: - return s.String(), nil - case error: - return s.Error(), nil - default: - return "", fmt.Errorf("unable to cast %#v of type %T to string", i, i) - } -} - -// ToStringMapStringE casts an interface to a map[string]string type. -func ToStringMapStringE(i interface{}) (map[string]string, error) { - var m = map[string]string{} - - switch v := i.(type) { - case map[string]string: - return v, nil - case map[string]interface{}: - for k, val := range v { - m[ToString(k)] = ToString(val) - } - return m, nil - case map[interface{}]string: - for k, val := range v { - m[ToString(k)] = ToString(val) - } - return m, nil - case map[interface{}]interface{}: - for k, val := range v { - m[ToString(k)] = ToString(val) - } - return m, nil - case string: - err := jsonStringToObject(v, &m) - return m, err - default: - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]string", i, i) - } -} - -// ToStringMapStringSliceE casts an interface to a map[string][]string type. -func ToStringMapStringSliceE(i interface{}) (map[string][]string, error) { - var m = map[string][]string{} - - switch v := i.(type) { - case map[string][]string: - return v, nil - case map[string][]interface{}: - for k, val := range v { - m[ToString(k)] = ToStringSlice(val) - } - return m, nil - case map[string]string: - for k, val := range v { - m[ToString(k)] = []string{val} - } - case map[string]interface{}: - for k, val := range v { - switch vt := val.(type) { - case []interface{}: - m[ToString(k)] = ToStringSlice(vt) - case []string: - m[ToString(k)] = vt - default: - m[ToString(k)] = []string{ToString(val)} - } - } - return m, nil - case map[interface{}][]string: - for k, val := range v { - m[ToString(k)] = ToStringSlice(val) - } - return m, nil - case map[interface{}]string: - for k, val := range v { - m[ToString(k)] = ToStringSlice(val) - } - return m, nil - case map[interface{}][]interface{}: - for k, val := range v { - m[ToString(k)] = ToStringSlice(val) - } - return m, nil - case map[interface{}]interface{}: - for k, val := range v { - key, err := ToStringE(k) - if err != nil { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i) - } - value, err := ToStringSliceE(val) - if err != nil { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i) - } - m[key] = value - } - case string: - err := jsonStringToObject(v, &m) - return m, err - default: - return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i) - } - return m, nil -} - -// ToStringMapBoolE casts an interface to a map[string]bool type. -func ToStringMapBoolE(i interface{}) (map[string]bool, error) { - var m = map[string]bool{} - - switch v := i.(type) { - case map[interface{}]interface{}: - for k, val := range v { - m[ToString(k)] = ToBool(val) - } - return m, nil - case map[string]interface{}: - for k, val := range v { - m[ToString(k)] = ToBool(val) - } - return m, nil - case map[string]bool: - return v, nil - case string: - err := jsonStringToObject(v, &m) - return m, err - default: - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]bool", i, i) - } -} - -// ToStringMapE casts an interface to a map[string]interface{} type. -func ToStringMapE(i interface{}) (map[string]interface{}, error) { - var m = map[string]interface{}{} - - switch v := i.(type) { - case map[interface{}]interface{}: - for k, val := range v { - m[ToString(k)] = val - } - return m, nil - case map[string]interface{}: - return v, nil - case string: - err := jsonStringToObject(v, &m) - return m, err - default: - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]interface{}", i, i) - } -} - -// ToStringMapIntE casts an interface to a map[string]int{} type. -func ToStringMapIntE(i interface{}) (map[string]int, error) { - var m = map[string]int{} - if i == nil { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i) - } - - switch v := i.(type) { - case map[interface{}]interface{}: - for k, val := range v { - m[ToString(k)] = ToInt(val) - } - return m, nil - case map[string]interface{}: - for k, val := range v { - m[k] = ToInt(val) - } - return m, nil - case map[string]int: - return v, nil - case string: - err := jsonStringToObject(v, &m) - return m, err - } - - if reflect.TypeOf(i).Kind() != reflect.Map { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i) - } - - mVal := reflect.ValueOf(m) - v := reflect.ValueOf(i) - for _, keyVal := range v.MapKeys() { - val, err := ToIntE(v.MapIndex(keyVal).Interface()) - if err != nil { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i) - } - mVal.SetMapIndex(keyVal, reflect.ValueOf(val)) - } - return m, nil -} - -// ToStringMapInt64E casts an interface to a map[string]int64{} type. -func ToStringMapInt64E(i interface{}) (map[string]int64, error) { - var m = map[string]int64{} - if i == nil { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i) - } - - switch v := i.(type) { - case map[interface{}]interface{}: - for k, val := range v { - m[ToString(k)] = ToInt64(val) - } - return m, nil - case map[string]interface{}: - for k, val := range v { - m[k] = ToInt64(val) - } - return m, nil - case map[string]int64: - return v, nil - case string: - err := jsonStringToObject(v, &m) - return m, err - } - - if reflect.TypeOf(i).Kind() != reflect.Map { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i) - } - mVal := reflect.ValueOf(m) - v := reflect.ValueOf(i) - for _, keyVal := range v.MapKeys() { - val, err := ToInt64E(v.MapIndex(keyVal).Interface()) - if err != nil { - return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i) - } - mVal.SetMapIndex(keyVal, reflect.ValueOf(val)) - } - return m, nil -} - -// ToSliceE casts an interface to a []interface{} type. -func ToSliceE(i interface{}) ([]interface{}, error) { - var s []interface{} - - switch v := i.(type) { - case []interface{}: - return append(s, v...), nil - case []map[string]interface{}: - for _, u := range v { - s = append(s, u) - } - return s, nil - default: - return s, fmt.Errorf("unable to cast %#v of type %T to []interface{}", i, i) - } -} - -// ToBoolSliceE casts an interface to a []bool type. -func ToBoolSliceE(i interface{}) ([]bool, error) { - if i == nil { - return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i) - } - - switch v := i.(type) { - case []bool: - return v, nil - } - - kind := reflect.TypeOf(i).Kind() - switch kind { - case reflect.Slice, reflect.Array: - s := reflect.ValueOf(i) - a := make([]bool, s.Len()) - for j := 0; j < s.Len(); j++ { - val, err := ToBoolE(s.Index(j).Interface()) - if err != nil { - return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i) - } - a[j] = val - } - return a, nil - default: - return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i) - } -} - -// ToStringSliceE casts an interface to a []string type. -func ToStringSliceE(i interface{}) ([]string, error) { - var a []string - - switch v := i.(type) { - case []interface{}: - for _, u := range v { - a = append(a, ToString(u)) - } - return a, nil - case []string: - return v, nil - case string: - return strings.Fields(v), nil - case interface{}: - str, err := ToStringE(v) - if err != nil { - return a, fmt.Errorf("unable to cast %#v of type %T to []string", i, i) - } - return []string{str}, nil - default: - return a, fmt.Errorf("unable to cast %#v of type %T to []string", i, i) - } -} - -// ToIntSliceE casts an interface to a []int type. -func ToIntSliceE(i interface{}) ([]int, error) { - if i == nil { - return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) - } - - switch v := i.(type) { - case []int: - return v, nil - } - - kind := reflect.TypeOf(i).Kind() - switch kind { - case reflect.Slice, reflect.Array: - s := reflect.ValueOf(i) - a := make([]int, s.Len()) - for j := 0; j < s.Len(); j++ { - val, err := ToIntE(s.Index(j).Interface()) - if err != nil { - return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) - } - a[j] = val - } - return a, nil - default: - return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) - } -} - -// ToDurationSliceE casts an interface to a []time.Duration type. -func ToDurationSliceE(i interface{}) ([]time.Duration, error) { - if i == nil { - return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i) - } - - switch v := i.(type) { - case []time.Duration: - return v, nil - } - - kind := reflect.TypeOf(i).Kind() - switch kind { - case reflect.Slice, reflect.Array: - s := reflect.ValueOf(i) - a := make([]time.Duration, s.Len()) - for j := 0; j < s.Len(); j++ { - val, err := ToDurationE(s.Index(j).Interface()) - if err != nil { - return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i) - } - a[j] = val - } - return a, nil - default: - return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i) - } -} - -// StringToDate attempts to parse a string into a time.Time type using a -// predefined list of formats. If no suitable format is found, an error is -// returned. -func StringToDate(s string) (time.Time, error) { - return parseDateWith(s, []string{ - time.RFC3339, - "2006-01-02T15:04:05", // iso8601 without timezone - time.RFC1123Z, - time.RFC1123, - time.RFC822Z, - time.RFC822, - time.RFC850, - time.ANSIC, - time.UnixDate, - time.RubyDate, - "2006-01-02 15:04:05.999999999 -0700 MST", // Time.String() - "2006-01-02", - "02 Jan 2006", - "2006-01-02T15:04:05-0700", // RFC3339 without timezone hh:mm colon - "2006-01-02 15:04:05 -07:00", - "2006-01-02 15:04:05 -0700", - "2006-01-02 15:04:05Z07:00", // RFC3339 without T - "2006-01-02 15:04:05Z0700", // RFC3339 without T or timezone hh:mm colon - "2006-01-02 15:04:05", - time.Kitchen, - time.Stamp, - time.StampMilli, - time.StampMicro, - time.StampNano, - }) -} - -func parseDateWith(s string, dates []string) (d time.Time, e error) { - for _, dateType := range dates { - if d, e = time.Parse(dateType, s); e == nil { - return - } - } - return d, fmt.Errorf("unable to parse date: %s", s) -} - -// jsonStringToObject attempts to unmarshall a string as JSON into -// the object passed as pointer. -func jsonStringToObject(s string, v interface{}) error { - data := []byte(s) - return json.Unmarshal(data, v) -} diff --git a/vendor/github.com/spf13/cast/indirect.go b/vendor/github.com/spf13/cast/indirect.go new file mode 100644 index 0000000000..093345f737 --- /dev/null +++ b/vendor/github.com/spf13/cast/indirect.go @@ -0,0 +1,37 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "reflect" +) + +// From html/template/content.go +// Copyright 2011 The Go Authors. All rights reserved. +// indirect returns the value, after dereferencing as many times +// as necessary to reach the base type (or nil). +func indirect(i any) (any, bool) { + if i == nil { + return nil, false + } + + if t := reflect.TypeOf(i); t.Kind() != reflect.Ptr { + // Avoid creating a reflect.Value if it's not a pointer. + return i, false + } + + v := reflect.ValueOf(i) + + for v.Kind() == reflect.Ptr || (v.Kind() == reflect.Interface && v.Elem().Kind() == reflect.Ptr) { + if v.IsNil() { + return nil, true + } + + v = v.Elem() + } + + return v.Interface(), true +} diff --git a/vendor/github.com/spf13/cast/internal/time.go b/vendor/github.com/spf13/cast/internal/time.go new file mode 100644 index 0000000000..906e9aece3 --- /dev/null +++ b/vendor/github.com/spf13/cast/internal/time.go @@ -0,0 +1,79 @@ +package internal + +import ( + "fmt" + "time" +) + +//go:generate stringer -type=TimeFormatType + +type TimeFormatType int + +const ( + TimeFormatNoTimezone TimeFormatType = iota + TimeFormatNamedTimezone + TimeFormatNumericTimezone + TimeFormatNumericAndNamedTimezone + TimeFormatTimeOnly +) + +type TimeFormat struct { + Format string + Typ TimeFormatType +} + +func (f TimeFormat) HasTimezone() bool { + // We don't include the formats with only named timezones, see + // https://github.com/golang/go/issues/19694#issuecomment-289103522 + return f.Typ >= TimeFormatNumericTimezone && f.Typ <= TimeFormatNumericAndNamedTimezone +} + +var TimeFormats = []TimeFormat{ + // Keep common formats at the top. + {"2006-01-02", TimeFormatNoTimezone}, + {time.RFC3339, TimeFormatNumericTimezone}, + {"2006-01-02T15:04:05", TimeFormatNoTimezone}, // iso8601 without timezone + {time.RFC1123Z, TimeFormatNumericTimezone}, + {time.RFC1123, TimeFormatNamedTimezone}, + {time.RFC822Z, TimeFormatNumericTimezone}, + {time.RFC822, TimeFormatNamedTimezone}, + {time.RFC850, TimeFormatNamedTimezone}, + {"2006-01-02 15:04:05.999999999 -0700 MST", TimeFormatNumericAndNamedTimezone}, // Time.String() + {"2006-01-02T15:04:05-0700", TimeFormatNumericTimezone}, // RFC3339 without timezone hh:mm colon + {"2006-01-02 15:04:05Z0700", TimeFormatNumericTimezone}, // RFC3339 without T or timezone hh:mm colon + {"2006-01-02 15:04:05", TimeFormatNoTimezone}, + {time.ANSIC, TimeFormatNoTimezone}, + {time.UnixDate, TimeFormatNamedTimezone}, + {time.RubyDate, TimeFormatNumericTimezone}, + {"2006-01-02 15:04:05Z07:00", TimeFormatNumericTimezone}, + {"02 Jan 2006", TimeFormatNoTimezone}, + {"2006-01-02 15:04:05 -07:00", TimeFormatNumericTimezone}, + {"2006-01-02 15:04:05 -0700", TimeFormatNumericTimezone}, + {time.Kitchen, TimeFormatTimeOnly}, + {time.Stamp, TimeFormatTimeOnly}, + {time.StampMilli, TimeFormatTimeOnly}, + {time.StampMicro, TimeFormatTimeOnly}, + {time.StampNano, TimeFormatTimeOnly}, +} + +func ParseDateWith(s string, location *time.Location, formats []TimeFormat) (d time.Time, e error) { + for _, format := range formats { + if d, e = time.Parse(format.Format, s); e == nil { + + // Some time formats have a zone name, but no offset, so it gets + // put in that zone name (not the default one passed in to us), but + // without that zone's offset. So set the location manually. + if format.Typ <= TimeFormatNamedTimezone { + if location == nil { + location = time.Local + } + year, month, day := d.Date() + hour, min, sec := d.Clock() + d = time.Date(year, month, day, hour, min, sec, d.Nanosecond(), location) + } + + return + } + } + return d, fmt.Errorf("unable to parse date: %s", s) +} diff --git a/vendor/github.com/spf13/cast/internal/timeformattype_string.go b/vendor/github.com/spf13/cast/internal/timeformattype_string.go new file mode 100644 index 0000000000..60a29a862b --- /dev/null +++ b/vendor/github.com/spf13/cast/internal/timeformattype_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=TimeFormatType"; DO NOT EDIT. + +package internal + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[TimeFormatNoTimezone-0] + _ = x[TimeFormatNamedTimezone-1] + _ = x[TimeFormatNumericTimezone-2] + _ = x[TimeFormatNumericAndNamedTimezone-3] + _ = x[TimeFormatTimeOnly-4] +} + +const _TimeFormatType_name = "TimeFormatNoTimezoneTimeFormatNamedTimezoneTimeFormatNumericTimezoneTimeFormatNumericAndNamedTimezoneTimeFormatTimeOnly" + +var _TimeFormatType_index = [...]uint8{0, 20, 43, 68, 101, 119} + +func (i TimeFormatType) String() string { + if i < 0 || i >= TimeFormatType(len(_TimeFormatType_index)-1) { + return "TimeFormatType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _TimeFormatType_name[_TimeFormatType_index[i]:_TimeFormatType_index[i+1]] +} diff --git a/vendor/github.com/spf13/cast/map.go b/vendor/github.com/spf13/cast/map.go new file mode 100644 index 0000000000..7d6beb56cc --- /dev/null +++ b/vendor/github.com/spf13/cast/map.go @@ -0,0 +1,224 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "encoding/json" + "fmt" + "reflect" +) + +func toMapE[K comparable, V any](i any, keyFn func(any) K, valFn func(any) V) (map[K]V, error) { + m := map[K]V{} + + if i == nil { + return nil, fmt.Errorf(errorMsg, i, i, m) + } + + switch v := i.(type) { + case map[K]V: + return v, nil + + case map[K]any: + for k, val := range v { + m[k] = valFn(val) + } + + return m, nil + + case map[any]V: + for k, val := range v { + m[keyFn(k)] = val + } + + return m, nil + + case map[any]any: + for k, val := range v { + m[keyFn(k)] = valFn(val) + } + + return m, nil + + case string: + err := jsonStringToObject(v, &m) + if err != nil { + return nil, err + } + + return m, nil + + default: + return nil, fmt.Errorf(errorMsg, i, i, m) + } +} + +func toStringMapE[T any](i any, fn func(any) T) (map[string]T, error) { + return toMapE(i, ToString, fn) +} + +// ToStringMapStringE casts any value to a map[string]string type. +func ToStringMapStringE(i any) (map[string]string, error) { + return toStringMapE(i, ToString) +} + +// ToStringMapStringSliceE casts any value to a map[string][]string type. +func ToStringMapStringSliceE(i any) (map[string][]string, error) { + m := map[string][]string{} + + switch v := i.(type) { + case map[string][]string: + return v, nil + case map[string][]any: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[string]string: + for k, val := range v { + m[ToString(k)] = []string{val} + } + case map[string]any: + for k, val := range v { + switch vt := val.(type) { + case []any: + m[ToString(k)] = ToStringSlice(vt) + case []string: + m[ToString(k)] = vt + default: + m[ToString(k)] = []string{ToString(val)} + } + } + return m, nil + case map[any][]string: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[any]string: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[any][]any: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[any]any: + for k, val := range v { + key, err := ToStringE(k) + if err != nil { + return nil, fmt.Errorf(errorMsg, i, i, m) + } + value, err := ToStringSliceE(val) + if err != nil { + return nil, fmt.Errorf(errorMsg, i, i, m) + } + m[key] = value + } + case string: + err := jsonStringToObject(v, &m) + if err != nil { + return nil, err + } + + return m, nil + default: + return nil, fmt.Errorf(errorMsg, i, i, m) + } + + return m, nil +} + +// ToStringMapBoolE casts any value to a map[string]bool type. +func ToStringMapBoolE(i any) (map[string]bool, error) { + return toStringMapE(i, ToBool) +} + +// ToStringMapE casts any value to a map[string]any type. +func ToStringMapE(i any) (map[string]any, error) { + fn := func(i any) any { return i } + + return toStringMapE(i, fn) +} + +func toStringMapIntE[T int | int64](i any, fn func(any) T, fnE func(any) (T, error)) (map[string]T, error) { + m := map[string]T{} + + if i == nil { + return nil, fmt.Errorf(errorMsg, i, i, m) + } + + switch v := i.(type) { + case map[string]T: + return v, nil + + case map[string]any: + for k, val := range v { + m[k] = fn(val) + } + + return m, nil + + case map[any]T: + for k, val := range v { + m[ToString(k)] = val + } + + return m, nil + + case map[any]any: + for k, val := range v { + m[ToString(k)] = fn(val) + } + + return m, nil + + case string: + err := jsonStringToObject(v, &m) + if err != nil { + return nil, err + } + + return m, nil + } + + if reflect.TypeOf(i).Kind() != reflect.Map { + return nil, fmt.Errorf(errorMsg, i, i, m) + } + + mVal := reflect.ValueOf(m) + v := reflect.ValueOf(i) + + for _, keyVal := range v.MapKeys() { + val, err := fnE(v.MapIndex(keyVal).Interface()) + if err != nil { + return m, fmt.Errorf(errorMsg, i, i, m) + } + + mVal.SetMapIndex(keyVal, reflect.ValueOf(val)) + } + + return m, nil +} + +// ToStringMapIntE casts any value to a map[string]int type. +func ToStringMapIntE(i any) (map[string]int, error) { + return toStringMapIntE(i, ToInt, ToIntE) +} + +// ToStringMapInt64E casts any value to a map[string]int64 type. +func ToStringMapInt64E(i any) (map[string]int64, error) { + return toStringMapIntE(i, ToInt64, ToInt64E) +} + +// jsonStringToObject attempts to unmarshall a string as JSON into +// the object passed as pointer. +func jsonStringToObject(s string, v any) error { + data := []byte(s) + return json.Unmarshal(data, v) +} diff --git a/vendor/github.com/spf13/cast/number.go b/vendor/github.com/spf13/cast/number.go new file mode 100644 index 0000000000..a58dc4d1ed --- /dev/null +++ b/vendor/github.com/spf13/cast/number.go @@ -0,0 +1,549 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" +) + +var errNegativeNotAllowed = errors.New("unable to cast negative value") + +type float64EProvider interface { + Float64() (float64, error) +} + +type float64Provider interface { + Float64() float64 +} + +// Number is a type parameter constraint for functions accepting number types. +// +// It represents the supported number types this package can cast to. +type Number interface { + int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64 +} + +type integer interface { + int | int8 | int16 | int32 | int64 +} + +type unsigned interface { + uint | uint8 | uint16 | uint32 | uint64 +} + +type float interface { + float32 | float64 +} + +// ToNumberE casts any value to a [Number] type. +func ToNumberE[T Number](i any) (T, error) { + var t T + + switch any(t).(type) { + case int: + return toNumberE[T](i, parseNumber[T]) + case int8: + return toNumberE[T](i, parseNumber[T]) + case int16: + return toNumberE[T](i, parseNumber[T]) + case int32: + return toNumberE[T](i, parseNumber[T]) + case int64: + return toNumberE[T](i, parseNumber[T]) + case uint: + return toUnsignedNumberE[T](i, parseNumber[T]) + case uint8: + return toUnsignedNumberE[T](i, parseNumber[T]) + case uint16: + return toUnsignedNumberE[T](i, parseNumber[T]) + case uint32: + return toUnsignedNumberE[T](i, parseNumber[T]) + case uint64: + return toUnsignedNumberE[T](i, parseNumber[T]) + case float32: + return toNumberE[T](i, parseNumber[T]) + case float64: + return toNumberE[T](i, parseNumber[T]) + default: + return 0, fmt.Errorf("unknown number type: %T", t) + } +} + +// ToNumber casts any value to a [Number] type. +func ToNumber[T Number](i any) T { + v, _ := ToNumberE[T](i) + + return v +} + +// toNumber's semantics differ from other "to" functions. +// It returns false as the second parameter if the conversion fails. +// This is to signal other callers that they should proceed with their own conversions. +func toNumber[T Number](i any) (T, bool) { + i, _ = indirect(i) + + switch s := i.(type) { + case T: + return s, true + case int: + return T(s), true + case int8: + return T(s), true + case int16: + return T(s), true + case int32: + return T(s), true + case int64: + return T(s), true + case uint: + return T(s), true + case uint8: + return T(s), true + case uint16: + return T(s), true + case uint32: + return T(s), true + case uint64: + return T(s), true + case float32: + return T(s), true + case float64: + return T(s), true + case bool: + if s { + return 1, true + } + + return 0, true + case nil: + return 0, true + case time.Weekday: + return T(s), true + case time.Month: + return T(s), true + } + + return 0, false +} + +func toNumberE[T Number](i any, parseFn func(string) (T, error)) (T, error) { + n, ok := toNumber[T](i) + if ok { + return n, nil + } + + i, _ = indirect(i) + + switch s := i.(type) { + case string: + if s == "" { + return 0, nil + } + + v, err := parseFn(s) + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) + } + + return v, nil + case json.Number: + if s == "" { + return 0, nil + } + + v, err := parseFn(string(s)) + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) + } + + return v, nil + case float64EProvider: + if _, ok := any(n).(float64); !ok { + return 0, fmt.Errorf(errorMsg, i, i, n) + } + + v, err := s.Float64() + if err != nil { + return 0, fmt.Errorf(errorMsg, i, i, n) + } + + return T(v), nil + case float64Provider: + if _, ok := any(n).(float64); !ok { + return 0, fmt.Errorf(errorMsg, i, i, n) + } + + return T(s.Float64()), nil + default: + if i, ok := resolveAlias(i); ok { + return toNumberE(i, parseFn) + } + + return 0, fmt.Errorf(errorMsg, i, i, n) + } +} + +func toUnsignedNumber[T Number](i any) (T, bool, bool) { + i, _ = indirect(i) + + switch s := i.(type) { + case T: + return s, true, true + case int: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case int8: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case int16: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case int32: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case int64: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case uint: + return T(s), true, true + case uint8: + return T(s), true, true + case uint16: + return T(s), true, true + case uint32: + return T(s), true, true + case uint64: + return T(s), true, true + case float32: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case float64: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case bool: + if s { + return 1, true, true + } + + return 0, true, true + case nil: + return 0, true, true + case time.Weekday: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + case time.Month: + if s < 0 { + return 0, false, false + } + + return T(s), true, true + } + + return 0, true, false +} + +func toUnsignedNumberE[T Number](i any, parseFn func(string) (T, error)) (T, error) { + n, valid, ok := toUnsignedNumber[T](i) + if ok { + return n, nil + } + + i, _ = indirect(i) + + if !valid { + return 0, errNegativeNotAllowed + } + + switch s := i.(type) { + case string: + if s == "" { + return 0, nil + } + + v, err := parseFn(s) + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) + } + + return v, nil + case json.Number: + if s == "" { + return 0, nil + } + + v, err := parseFn(string(s)) + if err != nil { + return 0, fmt.Errorf(errorMsgWith, i, i, n, err) + } + + return v, nil + case float64EProvider: + if _, ok := any(n).(float64); !ok { + return 0, fmt.Errorf(errorMsg, i, i, n) + } + + v, err := s.Float64() + if err != nil { + return 0, fmt.Errorf(errorMsg, i, i, n) + } + + if v < 0 { + return 0, errNegativeNotAllowed + } + + return T(v), nil + case float64Provider: + if _, ok := any(n).(float64); !ok { + return 0, fmt.Errorf(errorMsg, i, i, n) + } + + v := s.Float64() + + if v < 0 { + return 0, errNegativeNotAllowed + } + + return T(v), nil + default: + if i, ok := resolveAlias(i); ok { + return toUnsignedNumberE(i, parseFn) + } + + return 0, fmt.Errorf(errorMsg, i, i, n) + } +} + +func parseNumber[T Number](s string) (T, error) { + var t T + + switch any(t).(type) { + case int: + v, err := parseInt[int](s) + + return T(v), err + case int8: + v, err := parseInt[int8](s) + + return T(v), err + case int16: + v, err := parseInt[int16](s) + + return T(v), err + case int32: + v, err := parseInt[int32](s) + + return T(v), err + case int64: + v, err := parseInt[int64](s) + + return T(v), err + case uint: + v, err := parseUint[uint](s) + + return T(v), err + case uint8: + v, err := parseUint[uint8](s) + + return T(v), err + case uint16: + v, err := parseUint[uint16](s) + + return T(v), err + case uint32: + v, err := parseUint[uint32](s) + + return T(v), err + case uint64: + v, err := parseUint[uint64](s) + + return T(v), err + case float32: + v, err := strconv.ParseFloat(s, 32) + + return T(v), err + case float64: + v, err := strconv.ParseFloat(s, 64) + + return T(v), err + + default: + return 0, fmt.Errorf("unknown number type: %T", t) + } +} + +func parseInt[T integer](s string) (T, error) { + v, err := strconv.ParseInt(trimDecimal(s), 0, 0) + if err != nil { + return 0, err + } + + return T(v), nil +} + +func parseUint[T unsigned](s string) (T, error) { + v, err := strconv.ParseUint(strings.TrimLeft(trimDecimal(s), "+"), 0, 0) + if err != nil { + return 0, err + } + + return T(v), nil +} + +func parseFloat[T float](s string) (T, error) { + var t T + + var v any + var err error + + switch any(t).(type) { + case float32: + n, e := strconv.ParseFloat(s, 32) + + v = float32(n) + err = e + case float64: + n, e := strconv.ParseFloat(s, 64) + + v = float64(n) + err = e + } + + return v.(T), err +} + +// ToFloat64E casts an interface to a float64 type. +func ToFloat64E(i any) (float64, error) { + return toNumberE[float64](i, parseFloat[float64]) +} + +// ToFloat32E casts an interface to a float32 type. +func ToFloat32E(i any) (float32, error) { + return toNumberE[float32](i, parseFloat[float32]) +} + +// ToInt64E casts an interface to an int64 type. +func ToInt64E(i any) (int64, error) { + return toNumberE[int64](i, parseInt[int64]) +} + +// ToInt32E casts an interface to an int32 type. +func ToInt32E(i any) (int32, error) { + return toNumberE[int32](i, parseInt[int32]) +} + +// ToInt16E casts an interface to an int16 type. +func ToInt16E(i any) (int16, error) { + return toNumberE[int16](i, parseInt[int16]) +} + +// ToInt8E casts an interface to an int8 type. +func ToInt8E(i any) (int8, error) { + return toNumberE[int8](i, parseInt[int8]) +} + +// ToIntE casts an interface to an int type. +func ToIntE(i any) (int, error) { + return toNumberE[int](i, parseInt[int]) +} + +// ToUintE casts an interface to a uint type. +func ToUintE(i any) (uint, error) { + return toUnsignedNumberE[uint](i, parseUint[uint]) +} + +// ToUint64E casts an interface to a uint64 type. +func ToUint64E(i any) (uint64, error) { + return toUnsignedNumberE[uint64](i, parseUint[uint64]) +} + +// ToUint32E casts an interface to a uint32 type. +func ToUint32E(i any) (uint32, error) { + return toUnsignedNumberE[uint32](i, parseUint[uint32]) +} + +// ToUint16E casts an interface to a uint16 type. +func ToUint16E(i any) (uint16, error) { + return toUnsignedNumberE[uint16](i, parseUint[uint16]) +} + +// ToUint8E casts an interface to a uint type. +func ToUint8E(i any) (uint8, error) { + return toUnsignedNumberE[uint8](i, parseUint[uint8]) +} + +func trimZeroDecimal(s string) string { + var foundZero bool + for i := len(s); i > 0; i-- { + switch s[i-1] { + case '.': + if foundZero { + return s[:i-1] + } + case '0': + foundZero = true + default: + return s + } + } + return s +} + +var stringNumberRe = regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`) + +// see [BenchmarkDecimal] for details about the implementation +func trimDecimal(s string) string { + if !strings.Contains(s, ".") { + return s + } + + matches := stringNumberRe.FindStringSubmatch(s) + if matches != nil { + // matches[1] is the captured integer part with sign + s = matches[1] + + // handle special cases + switch s { + case "-", "+": + s += "0" + case "": + s = "0" + } + + return s + } + + return s +} diff --git a/vendor/github.com/spf13/cast/slice.go b/vendor/github.com/spf13/cast/slice.go new file mode 100644 index 0000000000..e6a8328c60 --- /dev/null +++ b/vendor/github.com/spf13/cast/slice.go @@ -0,0 +1,106 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "fmt" + "reflect" + "strings" +) + +// ToSliceE casts any value to a []any type. +func ToSliceE(i any) ([]any, error) { + i, _ = indirect(i) + + var s []any + + switch v := i.(type) { + case []any: + // TODO: use slices.Clone + return append(s, v...), nil + case []map[string]any: + for _, u := range v { + s = append(s, u) + } + + return s, nil + default: + return s, fmt.Errorf(errorMsg, i, i, s) + } +} + +func toSliceE[T Basic](i any) ([]T, error) { + v, ok, err := toSliceEOk[T](i) + if err != nil { + return nil, err + } + + if !ok { + return nil, fmt.Errorf(errorMsg, i, i, []T{}) + } + + return v, nil +} + +func toSliceEOk[T Basic](i any) ([]T, bool, error) { + i, _ = indirect(i) + if i == nil { + return nil, true, fmt.Errorf(errorMsg, i, i, []T{}) + } + + switch v := i.(type) { + case []T: + // TODO: clone slice + return v, true, nil + } + + kind := reflect.TypeOf(i).Kind() + switch kind { + case reflect.Slice, reflect.Array: + s := reflect.ValueOf(i) + a := make([]T, s.Len()) + + for j := 0; j < s.Len(); j++ { + val, err := ToE[T](s.Index(j).Interface()) + if err != nil { + return nil, true, fmt.Errorf(errorMsg, i, i, []T{}) + } + + a[j] = val + } + + return a, true, nil + default: + return nil, false, nil + } +} + +// ToStringSliceE casts any value to a []string type. +func ToStringSliceE(i any) ([]string, error) { + if a, ok, err := toSliceEOk[string](i); ok { + if err != nil { + return nil, err + } + + return a, nil + } + + var a []string + + switch v := i.(type) { + case string: + return strings.Fields(v), nil + case any: + str, err := ToStringE(v) + if err != nil { + return nil, fmt.Errorf(errorMsg, i, i, a) + } + + return []string{str}, nil + default: + return nil, fmt.Errorf(errorMsg, i, i, a) + } +} diff --git a/vendor/github.com/spf13/cast/time.go b/vendor/github.com/spf13/cast/time.go new file mode 100644 index 0000000000..744cd5accd --- /dev/null +++ b/vendor/github.com/spf13/cast/time.go @@ -0,0 +1,116 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/spf13/cast/internal" +) + +// ToTimeE any value to a [time.Time] type. +func ToTimeE(i any) (time.Time, error) { + return ToTimeInDefaultLocationE(i, time.UTC) +} + +// ToTimeInDefaultLocationE casts an empty interface to [time.Time], +// interpreting inputs without a timezone to be in the given location, +// or the local timezone if nil. +func ToTimeInDefaultLocationE(i any, location *time.Location) (tim time.Time, err error) { + i, _ = indirect(i) + + switch v := i.(type) { + case time.Time: + return v, nil + case string: + return StringToDateInDefaultLocation(v, location) + case json.Number: + // Originally this used ToInt64E, but adding string float conversion broke ToTime. + // the behavior of ToTime would have changed if we continued using it. + // For now, using json.Number's own Int64 method should be good enough to preserve backwards compatibility. + v = json.Number(trimZeroDecimal(string(v))) + s, err1 := v.Int64() + if err1 != nil { + return time.Time{}, fmt.Errorf(errorMsg, i, i, time.Time{}) + } + return time.Unix(s, 0), nil + case int: + return time.Unix(int64(v), 0), nil + case int32: + return time.Unix(int64(v), 0), nil + case int64: + return time.Unix(v, 0), nil + case uint: + return time.Unix(int64(v), 0), nil + case uint32: + return time.Unix(int64(v), 0), nil + case uint64: + return time.Unix(int64(v), 0), nil + case nil: + return time.Time{}, nil + default: + return time.Time{}, fmt.Errorf(errorMsg, i, i, time.Time{}) + } +} + +// ToDurationE casts any value to a [time.Duration] type. +func ToDurationE(i any) (time.Duration, error) { + i, _ = indirect(i) + + switch s := i.(type) { + case time.Duration: + return s, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + v, err := ToInt64E(s) + if err != nil { + // TODO: once there is better error handling, this should be easier + return 0, errors.New(strings.ReplaceAll(err.Error(), " int64", "time.Duration")) + } + + return time.Duration(v), nil + case float32, float64, float64EProvider, float64Provider: + v, err := ToFloat64E(s) + if err != nil { + // TODO: once there is better error handling, this should be easier + return 0, errors.New(strings.ReplaceAll(err.Error(), " float64", "time.Duration")) + } + + return time.Duration(v), nil + case string: + if !strings.ContainsAny(s, "nsuµmh") { + return time.ParseDuration(s + "ns") + } + + return time.ParseDuration(s) + case nil: + return time.Duration(0), nil + default: + if i, ok := resolveAlias(i); ok { + return ToDurationE(i) + } + + return 0, fmt.Errorf(errorMsg, i, i, time.Duration(0)) + } +} + +// StringToDate attempts to parse a string into a [time.Time] type using a +// predefined list of formats. +// +// If no suitable format is found, an error is returned. +func StringToDate(s string) (time.Time, error) { + return internal.ParseDateWith(s, time.UTC, internal.TimeFormats) +} + +// StringToDateInDefaultLocation casts an empty interface to a [time.Time], +// interpreting inputs without a timezone to be in the given location, +// or the local timezone if nil. +func StringToDateInDefaultLocation(s string, location *time.Location) (time.Time, error) { + return internal.ParseDateWith(s, location, internal.TimeFormats) +} diff --git a/vendor/github.com/spf13/cast/zz_generated.go b/vendor/github.com/spf13/cast/zz_generated.go new file mode 100644 index 0000000000..ce3ec0f78f --- /dev/null +++ b/vendor/github.com/spf13/cast/zz_generated.go @@ -0,0 +1,261 @@ +// Code generated by cast generator. DO NOT EDIT. + +package cast + +import "time" + +// ToBool casts any value to a(n) bool type. +func ToBool(i any) bool { + v, _ := ToBoolE(i) + return v +} + +// ToString casts any value to a(n) string type. +func ToString(i any) string { + v, _ := ToStringE(i) + return v +} + +// ToTime casts any value to a(n) time.Time type. +func ToTime(i any) time.Time { + v, _ := ToTimeE(i) + return v +} + +// ToTimeInDefaultLocation casts any value to a(n) time.Time type. +func ToTimeInDefaultLocation(i any, location *time.Location) time.Time { + v, _ := ToTimeInDefaultLocationE(i, location) + return v +} + +// ToDuration casts any value to a(n) time.Duration type. +func ToDuration(i any) time.Duration { + v, _ := ToDurationE(i) + return v +} + +// ToInt casts any value to a(n) int type. +func ToInt(i any) int { + v, _ := ToIntE(i) + return v +} + +// ToInt8 casts any value to a(n) int8 type. +func ToInt8(i any) int8 { + v, _ := ToInt8E(i) + return v +} + +// ToInt16 casts any value to a(n) int16 type. +func ToInt16(i any) int16 { + v, _ := ToInt16E(i) + return v +} + +// ToInt32 casts any value to a(n) int32 type. +func ToInt32(i any) int32 { + v, _ := ToInt32E(i) + return v +} + +// ToInt64 casts any value to a(n) int64 type. +func ToInt64(i any) int64 { + v, _ := ToInt64E(i) + return v +} + +// ToUint casts any value to a(n) uint type. +func ToUint(i any) uint { + v, _ := ToUintE(i) + return v +} + +// ToUint8 casts any value to a(n) uint8 type. +func ToUint8(i any) uint8 { + v, _ := ToUint8E(i) + return v +} + +// ToUint16 casts any value to a(n) uint16 type. +func ToUint16(i any) uint16 { + v, _ := ToUint16E(i) + return v +} + +// ToUint32 casts any value to a(n) uint32 type. +func ToUint32(i any) uint32 { + v, _ := ToUint32E(i) + return v +} + +// ToUint64 casts any value to a(n) uint64 type. +func ToUint64(i any) uint64 { + v, _ := ToUint64E(i) + return v +} + +// ToFloat32 casts any value to a(n) float32 type. +func ToFloat32(i any) float32 { + v, _ := ToFloat32E(i) + return v +} + +// ToFloat64 casts any value to a(n) float64 type. +func ToFloat64(i any) float64 { + v, _ := ToFloat64E(i) + return v +} + +// ToStringMapString casts any value to a(n) map[string]string type. +func ToStringMapString(i any) map[string]string { + v, _ := ToStringMapStringE(i) + return v +} + +// ToStringMapStringSlice casts any value to a(n) map[string][]string type. +func ToStringMapStringSlice(i any) map[string][]string { + v, _ := ToStringMapStringSliceE(i) + return v +} + +// ToStringMapBool casts any value to a(n) map[string]bool type. +func ToStringMapBool(i any) map[string]bool { + v, _ := ToStringMapBoolE(i) + return v +} + +// ToStringMapInt casts any value to a(n) map[string]int type. +func ToStringMapInt(i any) map[string]int { + v, _ := ToStringMapIntE(i) + return v +} + +// ToStringMapInt64 casts any value to a(n) map[string]int64 type. +func ToStringMapInt64(i any) map[string]int64 { + v, _ := ToStringMapInt64E(i) + return v +} + +// ToStringMap casts any value to a(n) map[string]any type. +func ToStringMap(i any) map[string]any { + v, _ := ToStringMapE(i) + return v +} + +// ToSlice casts any value to a(n) []any type. +func ToSlice(i any) []any { + v, _ := ToSliceE(i) + return v +} + +// ToBoolSlice casts any value to a(n) []bool type. +func ToBoolSlice(i any) []bool { + v, _ := ToBoolSliceE(i) + return v +} + +// ToStringSlice casts any value to a(n) []string type. +func ToStringSlice(i any) []string { + v, _ := ToStringSliceE(i) + return v +} + +// ToIntSlice casts any value to a(n) []int type. +func ToIntSlice(i any) []int { + v, _ := ToIntSliceE(i) + return v +} + +// ToInt64Slice casts any value to a(n) []int64 type. +func ToInt64Slice(i any) []int64 { + v, _ := ToInt64SliceE(i) + return v +} + +// ToUintSlice casts any value to a(n) []uint type. +func ToUintSlice(i any) []uint { + v, _ := ToUintSliceE(i) + return v +} + +// ToFloat64Slice casts any value to a(n) []float64 type. +func ToFloat64Slice(i any) []float64 { + v, _ := ToFloat64SliceE(i) + return v +} + +// ToDurationSlice casts any value to a(n) []time.Duration type. +func ToDurationSlice(i any) []time.Duration { + v, _ := ToDurationSliceE(i) + return v +} + +// ToBoolSliceE casts any value to a(n) []bool type. +func ToBoolSliceE(i any) ([]bool, error) { + return toSliceE[bool](i) +} + +// ToDurationSliceE casts any value to a(n) []time.Duration type. +func ToDurationSliceE(i any) ([]time.Duration, error) { + return toSliceE[time.Duration](i) +} + +// ToIntSliceE casts any value to a(n) []int type. +func ToIntSliceE(i any) ([]int, error) { + return toSliceE[int](i) +} + +// ToInt8SliceE casts any value to a(n) []int8 type. +func ToInt8SliceE(i any) ([]int8, error) { + return toSliceE[int8](i) +} + +// ToInt16SliceE casts any value to a(n) []int16 type. +func ToInt16SliceE(i any) ([]int16, error) { + return toSliceE[int16](i) +} + +// ToInt32SliceE casts any value to a(n) []int32 type. +func ToInt32SliceE(i any) ([]int32, error) { + return toSliceE[int32](i) +} + +// ToInt64SliceE casts any value to a(n) []int64 type. +func ToInt64SliceE(i any) ([]int64, error) { + return toSliceE[int64](i) +} + +// ToUintSliceE casts any value to a(n) []uint type. +func ToUintSliceE(i any) ([]uint, error) { + return toSliceE[uint](i) +} + +// ToUint8SliceE casts any value to a(n) []uint8 type. +func ToUint8SliceE(i any) ([]uint8, error) { + return toSliceE[uint8](i) +} + +// ToUint16SliceE casts any value to a(n) []uint16 type. +func ToUint16SliceE(i any) ([]uint16, error) { + return toSliceE[uint16](i) +} + +// ToUint32SliceE casts any value to a(n) []uint32 type. +func ToUint32SliceE(i any) ([]uint32, error) { + return toSliceE[uint32](i) +} + +// ToUint64SliceE casts any value to a(n) []uint64 type. +func ToUint64SliceE(i any) ([]uint64, error) { + return toSliceE[uint64](i) +} + +// ToFloat32SliceE casts any value to a(n) []float32 type. +func ToFloat32SliceE(i any) ([]float32, error) { + return toSliceE[float32](i) +} + +// ToFloat64SliceE casts any value to a(n) []float64 type. +func ToFloat64SliceE(i any) ([]float64, error) { + return toSliceE[float64](i) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 56fdf01d91..5b947ffdf3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -465,11 +465,11 @@ github.com/lufia/plan9stats github.com/mailru/easyjson/buffer github.com/mailru/easyjson/jlexer github.com/mailru/easyjson/jwriter -# github.com/mark3labs/mcp-go v0.18.0 +# github.com/mark3labs/mcp-go v0.33.0 ## explicit; go 1.23 -github.com/mark3labs/mcp-go/client github.com/mark3labs/mcp-go/mcp github.com/mark3labs/mcp-go/server +github.com/mark3labs/mcp-go/util # github.com/maruel/natural v1.1.0 ## explicit; go 1.11 github.com/maruel/natural @@ -603,9 +603,10 @@ github.com/skratchdot/open-golang/open # github.com/sosodev/duration v1.3.1 ## explicit; go 1.17 github.com/sosodev/duration -# github.com/spf13/cast v1.3.0 -## explicit +# github.com/spf13/cast v1.9.2 +## explicit; go 1.21.0 github.com/spf13/cast +github.com/spf13/cast/internal # github.com/spf13/cobra v1.1.1 ## explicit; go 1.12 github.com/spf13/cobra From 397c278ded37cf9d3a098cc5c9f5858f50324fac Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Tue, 15 Jul 2025 09:47:48 -0700 Subject: [PATCH 25/29] Handle cast library returning nil values if source is nil Not sure why this is suddenly breaking. The cast library was updated, but the code paths for the errors I'm seeing were not.. --- internal/config/instance.go | 18 +++++++++++++++--- pkg/projectfile/projectfile.go | 9 +++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/internal/config/instance.go b/internal/config/instance.go index d1c1a47e24..05547ee13a 100644 --- a/internal/config/instance.go +++ b/internal/config/instance.go @@ -219,7 +219,11 @@ func (i *Instance) AllKeys() []string { // GetStringMapStringSlice retrieves a map of string slices for a given key func (i *Instance) GetStringMapStringSlice(key string) map[string][]string { - return cast.ToStringMapStringSlice(i.Get(key)) + v := cast.ToStringMapStringSlice(i.Get(key)) + if v == nil { + return map[string][]string{} + } + return v } // GetBool retrieves a boolean value for a given key @@ -229,7 +233,11 @@ func (i *Instance) GetBool(key string) bool { // GetStringSlice retrieves a slice of strings for a given key func (i *Instance) GetStringSlice(key string) []string { - return cast.ToStringSlice(i.Get(key)) + v := cast.ToStringSlice(i.Get(key)) + if v == nil { + return []string{} + } + return v } // GetTime retrieves a time instance for a given key @@ -239,7 +247,11 @@ func (i *Instance) GetTime(key string) time.Time { // GetStringMap retrieves a map of strings to values for a given key func (i *Instance) GetStringMap(key string) map[string]interface{} { - return cast.ToStringMap(i.Get(key)) + v := cast.ToStringMap(i.Get(key)) + if v == nil { + return map[string]interface{}{} + } + return v } // ConfigPath returns the path at which our configuration is stored diff --git a/pkg/projectfile/projectfile.go b/pkg/projectfile/projectfile.go index b4b105b29b..3879d816e2 100644 --- a/pkg/projectfile/projectfile.go +++ b/pkg/projectfile/projectfile.go @@ -1263,6 +1263,9 @@ func addDeprecatedProjectMappings(cfg ConfigGetter) { if err != nil && v != nil { // don't report if error due to nil input multilog.Log(logging.ErrorNoStacktrace, rollbar.Error)("Projects data in config is abnormal (type: %T)", v) } + if projects == nil { + projects = map[string][]string{} + } keys := funk.FilterString(cfg.AllKeys(), func(v string) bool { return strings.HasPrefix(v, "project_") @@ -1316,6 +1319,9 @@ func StoreProjectMapping(cfg ConfigGetter, namespace, projectPath string) { if err != nil && v != nil { // don't report if error due to nil input multilog.Log(logging.ErrorNoStacktrace, rollbar.Error)("Projects data in config is abnormal (type: %T)", v) } + if projects == nil { + projects = make(map[string][]string) + } projectPath, err = fileutils.ResolveUniquePath(projectPath) if err != nil { @@ -1370,6 +1376,9 @@ func CleanProjectMapping(cfg ConfigGetter) { if err != nil && v != nil { // don't report if error due to nil input multilog.Log(logging.ErrorNoStacktrace, rollbar.Error)("Projects data in config is abnormal (type: %T)", v) } + if projects == nil { + projects = make(map[string][]string) + } seen := make(map[string]struct{}) From 6d41be774b04835c78aa501004e9d69e1d9f9d17 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Tue, 15 Jul 2025 09:47:59 -0700 Subject: [PATCH 26/29] Drop debugging code --- internal/runners/cve/cve.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/internal/runners/cve/cve.go b/internal/runners/cve/cve.go index a46da01c26..03739140cd 100644 --- a/internal/runners/cve/cve.go +++ b/internal/runners/cve/cve.go @@ -9,7 +9,6 @@ import ( "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/locale" - "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/output/renderers" "github.com/ActiveState/cli/internal/primer" @@ -61,12 +60,6 @@ type cveOutput struct { } func (r *Cve) Run(params *Params) error { - defer func() { - if rc := recover(); rc != nil { - logging.Error("Recovered from panic: %v", rc) - fmt.Printf("Recovered from panic: %v\n", rc) - } - }() if !params.Namespace.IsValid() && r.proj == nil { return rationalize.ErrNoProject } From 38e51d5c0c027cd9a18a3ec730f1f1607e4e06d4 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Tue, 15 Jul 2025 13:30:56 -0700 Subject: [PATCH 27/29] Fix test not constructing outputter properly --- internal/output/plain_test.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/internal/output/plain_test.go b/internal/output/plain_test.go index 836e8c0947..4f94c813c7 100644 --- a/internal/output/plain_test.go +++ b/internal/output/plain_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func nilStr(s string) *string { @@ -207,7 +208,7 @@ func TestPlain_Print(t *testing.T) { }, }, " field_header1 A \n" + - " field_header2 B \n" + + " field_hader2 B \n" + " field_header3 C \n", "", }, @@ -238,12 +239,13 @@ func TestPlain_Print(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := &Plain{&Config{ + f, err := NewPlain(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }} + }) + require.NoError(t, err) f.Print(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") @@ -274,12 +276,13 @@ func TestPlain_Notice(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := &Plain{&Config{ + f, err := NewPlain(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }} + }) + require.NoError(t, err) f.Notice(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") @@ -310,12 +313,13 @@ func TestPlain_Error(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := &Plain{&Config{ + f, err := NewPlain(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }} + }) + require.NoError(t, err) f.Error(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") From c97f5d8cfaaf5c502b4e59c35da06eee9af3d5ac Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Tue, 15 Jul 2025 14:31:07 -0700 Subject: [PATCH 28/29] Fix tests --- internal/output/json_test.go | 16 ++++++++++------ internal/output/plain_test.go | 2 +- internal/output/simple_test.go | 6 ++++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/internal/output/json_test.go b/internal/output/json_test.go index 3c6e1deef1..2f42dbe290 100644 --- a/internal/output/json_test.go +++ b/internal/output/json_test.go @@ -7,6 +7,7 @@ import ( "github.com/ActiveState/cli/internal/locale" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestJSON_Print(t *testing.T) { @@ -65,12 +66,13 @@ func TestJSON_Print(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := &JSON{cfg: &Config{ + f, err := NewJSON(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }} + }) + require.NoError(t, err) f.Print(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") @@ -101,12 +103,13 @@ func TestJSON_Notice(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := &JSON{cfg: &Config{ + f, err := NewJSON(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }} + }) + require.NoError(t, err) f.Notice(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") @@ -155,12 +158,13 @@ func TestJSON_Error(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := &JSON{cfg: &Config{ + f, err := NewJSON(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }} + }) + require.NoError(t, err) f.Error(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") diff --git a/internal/output/plain_test.go b/internal/output/plain_test.go index 4f94c813c7..c80ac212ed 100644 --- a/internal/output/plain_test.go +++ b/internal/output/plain_test.go @@ -208,7 +208,7 @@ func TestPlain_Print(t *testing.T) { }, }, " field_header1 A \n" + - " field_hader2 B \n" + + " field_header2 B \n" + " field_header3 C \n", "", }, diff --git a/internal/output/simple_test.go b/internal/output/simple_test.go index 143d084ebf..48fa099fe2 100644 --- a/internal/output/simple_test.go +++ b/internal/output/simple_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSimple_Notice(t *testing.T) { @@ -29,12 +30,13 @@ func TestSimple_Notice(t *testing.T) { outWriter := &bytes.Buffer{} errWriter := &bytes.Buffer{} - f := Simple{Plain{&Config{ + f, err := NewSimple(&Config{ OutWriter: outWriter, ErrWriter: errWriter, Colored: false, Interactive: false, - }}} + }) + require.NoError(t, err) f.Notice(tt.args.value) assert.Equal(t, tt.expectedOut, outWriter.String(), "Output did not match") From 8ffadb2376e804bd449ebeb28f5dbe18f890eaa4 Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Wed, 16 Jul 2025 10:42:55 -0700 Subject: [PATCH 29/29] Drop donotshipme package; we can in fact ship this now --- cmd/state/donotshipme/donotshipme.go | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 cmd/state/donotshipme/donotshipme.go diff --git a/cmd/state/donotshipme/donotshipme.go b/cmd/state/donotshipme/donotshipme.go deleted file mode 100644 index b021fc9ca5..0000000000 --- a/cmd/state/donotshipme/donotshipme.go +++ /dev/null @@ -1,17 +0,0 @@ -package donotshipme - -import ( - "github.com/ActiveState/cli/cmd/state/internal/cmdtree" - "github.com/ActiveState/cli/internal/constants" - "github.com/ActiveState/cli/internal/primer" -) - -func init() { - if constants.ChannelName == "release" { - panic("This file is for experimentation only, it should not be shipped as is. CmdTree is internal to the State command and should remain that way or be refactored.") - } -} - -func CmdTree(prime *primer.Values, args ...string) *cmdtree.CmdTree { - return cmdtree.New(prime, args...) -} \ No newline at end of file