diff --git a/CLAUDE.md b/CLAUDE.md index 3a6fec8c..410f7d5f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -125,13 +125,12 @@ models.register("my-model", custom_model_instance) ### Agent OS Abstraction -`AgentOs` provides an abstraction layer for OS-level operations: +`ComputerAgentOS` provides an abstraction layer for OS-level operations: ``` -AgentOs (Abstract Interface) - ├── AskUiControllerClient (gRPC to AskUI Agent OS - primary) +ComputerAgentOS (Abstract Interface) + ├── MultiComputerTargetAgentOS (gRPC to AskUI Agent OS - primary) ├── PlaywrightAgentOs (Web browser automation) - └── AndroidAgentOs (Android ADB) ``` ### Locator System @@ -175,7 +174,7 @@ Tools are auto-discovered and can be dynamically loaded via MCP configurations. - `src/askui/prompts/` - System prompts for different models ### Tools & OS -- `src/askui/tools/agent_os.py` - Abstract `AgentOs` interface +- `src/askui/tools/agent_os.py` - Abstract `ComputerAgentOS` interface - `src/askui/tools/askui/` - gRPC client for AskUI Agent OS - `src/askui/tools/android/` - Android-specific tools - `src/askui/tools/playwright/` - Web automation tools @@ -247,7 +246,7 @@ When writing or updating documentation in `docs/`: ## Important Patterns ### Composition over Inheritance -- `AgentToolbox` wraps `AgentOs` implementations +- `AgentToolbox` wraps `ComputerAgentOS` implementations - `ModelRouter` composes multiple model providers - `CompositeReporter` aggregates multiple reporters @@ -261,7 +260,7 @@ When writing or updating documentation in `docs/`: - Retry strategies with exponential backoff ### Adapter Pattern -- `AgentOs` abstraction bridges OS implementations (gRPC, Playwright, ADB) +- `ComputerAgentOS` abstraction bridges OS implementations (gRPC, Playwright, ADB) - `ModelFacade` adapts models to `ActModel`/`GetModel`/`LocateModel` interfaces ### Dependency Injection @@ -299,13 +298,13 @@ When writing or updating documentation in `docs/`: ### Adding Custom Tools 1. Implement `Tool` protocol in `models/shared/tools.py` 2. Register in appropriate MCP server (`api/mcp_servers/{type}.py`) -3. Use `@auto_inject_agent_os` for AgentOs dependency +3. Use `@auto_inject_agent_os` for ComputerAgentOS dependency 4. Follow Pydantic schema validation ### Adding New Agent Types 1. Inherit from `Agent` 2. Implement required abstract methods -3. Provide appropriate `AgentOs` implementation +3. Provide appropriate `ComputerAgentOS` implementation 4. Register in agent factory if needed ## Performance & Caching diff --git a/README.md b/README.md index e591d500..241a15cd 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,7 @@ Ready to build your first agent? Check out our documentation: 10. **[Extracting Data](docs/10_extracting_data.md)** - Extracting structured data from screenshots and files 11. **[Callbacks](docs/11_callbacks.md)** - Inject custom logic into the control loop 12. **[Secrets](docs/12_secrets.md)** - Let agents use sensitive values without exposing them to the LLM +13. **[Multiple Target Computers](docs/13_multi_target_computers.md)** - Drive several machines from one `ComputerAgent` **Official documentation:** [docs.askui.com](https://docs.askui.com) diff --git a/docs/02_using_agents.md b/docs/02_using_agents.md index a75e531b..c2c8745c 100644 --- a/docs/02_using_agents.md +++ b/docs/02_using_agents.md @@ -13,7 +13,9 @@ with ComputerAgent() as agent: agent.act("Open the mail app and summarize all unread emails") ``` -**Default tools:** `screenshot`, `mouse_click`, `mouse_move`, `mouse_scroll`, `mouse_hold_down`, `mouse_release`, `type`, `keyboard_tap`, `keyboard_pressed`, `keyboard_release`, `get_mouse_position`, `get_system_info`, `list_displays`, `retrieve_active_display`, `set_active_display` +**Default tools:** `screenshot`, `mouse_click`, `mouse_move`, `mouse_scroll`, `mouse_hold_down`, `mouse_release`, `type`, `keyboard_tap`, `keyboard_pressed`, `keyboard_release`, `get_mouse_position`, `get_system_info`, `list_displays`, `retrieve_active_display`, `set_active_display`, `list_agent_os_target_computers`, `switch_agent_os_target_computer`, `get_current_computer_target_id` + +A single `ComputerAgent` can also drive multiple machines (local and remote) at once. See [Multiple Target Computers](13_multi_target_computers.md). ## AndroidAgent diff --git a/docs/07_tools.md b/docs/07_tools.md index 737c10f6..8f26424e 100644 --- a/docs/07_tools.md +++ b/docs/07_tools.md @@ -68,7 +68,7 @@ Work with any agent type, no special dependencies required. #### Computer Tools (`computer/`) -Require `AgentOs` and work with `ComputerAgent` for desktop automation. +Require `ComputerAgentOS` and work with `ComputerAgent` for desktop automation. **Examples:** - `ComputerSaveScreenshotTool(base_dir)` - Save screenshots to disk @@ -314,3 +314,122 @@ with ComputerAgent() as agent: tools=[GreetingTool()], ) ``` + +### Restricting a tool to one device type (computer or android) + +`GreetingTool` above subclasses `Tool` because it is pure logic and never touches a device. A tool that needs to drive a device should instead subclass one of the device-specific base classes: + +- `ComputerBaseTool` — gives the tool a typed `self.agent_os` (a `ComputerAgentOS`) and restricts it to **computer/desktop** targets. +- `AndroidBaseTool` — gives the tool a typed `self.agent_os` (an `AndroidAgentOs`) and restricts it to **Android** targets. + +Both are importable from `askui.models.shared`. + +#### How the restriction works + +Every tool carries a list of `required_tags`, and every agent OS carries a list of `tags`. When `act()` starts, the SDK binds each tool to the **first registered agent OS whose `tags` contain all of the tool's `required_tags`**. The base classes set this up for you: + +| Base class | `required_tags` | +|------------|-----------------| +| `Tool` | `[]` — binds to any agent OS (or none) | +| `ComputerBaseTool` | `["computer"]` | +| `AndroidBaseTool` | `["android"]` | + +The agent OS implementations are tagged accordingly: desktop ones report `"computer"` and Android ones report `"android"` (the coordinate-scaling facades additionally add `"scaled_agent_os"`). So a `ComputerBaseTool` can never be bound to an Android device, and vice versa. + +You can also pass extra `required_tags` to narrow further, e.g. `super().__init__(..., required_tags=["scaled_agent_os"])` to require the scaling facade specifically. + +#### Example: a computer-only tool + +```python +from askui.models.shared import ComputerBaseTool + + +class ComputerScreenSizeTool(ComputerBaseTool): + """Reports the pixel size of the active computer screen. + + Subclassing `ComputerBaseTool` tags this tool as `"computer"`, so it is + only ever bound to a computer (desktop) agent OS — never to an Android + device. `self.agent_os` is therefore a `ComputerAgentOS`. + """ + + def __init__(self) -> None: + super().__init__( + name="get_screen_size", + description="Return the width and height in pixels of the active computer screen.", + input_schema={"type": "object", "properties": {}}, + ) + + def __call__(self) -> str: + screenshot = self.agent_os.screenshot() + return f"{screenshot.width}x{screenshot.height}" +``` + +#### Where this matters + +The restriction is enforced whenever more than one agent OS is registered for a single `act()` call — most notably with [`MultiDeviceAgent`](02_using_agents.md#multideviceagent), which registers both a computer and an Android agent OS: + +```python +from askui import MultiDeviceAgent + +with MultiDeviceAgent(android_device_sn="emulator-5554") as agent: + agent.act( + "Read the screen size on the computer, then take a screenshot on the phone", + # ComputerScreenSizeTool is given only to the computer agent OS; + # an AndroidBaseTool would be given only to the Android device. + tools=[ComputerScreenSizeTool()], + ) +``` + +#### Pinning a tool to a specific machine (auto-switch) + +The tag-based restriction is by device *type* (computer vs Android), not by an individual target machine. When you drive [multiple computer targets](13_multi_target_computers.md) from one agent, every `ComputerBaseTool` shares the same computer agent OS and runs against whichever target is currently *active*. + +To bind a tool to one specific machine, have it **auto-switch** to that target inside `__call__`. `self.agent_os.temporary_select(computer_id)` activates the given target for the duration of the block and restores the previously active target on exit (even if the body raises), so the tool always acts on its machine without disturbing the rest of the run: + +```python +from askui.models.shared import ComputerBaseTool + + +class ScreenSizeOfMachineTool(ComputerBaseTool): + """Reports the screen size of one specific computer target. + + The tool is bound to a `computer_id` and auto-switches to that target for + the duration of the call, regardless of which target is currently active. + """ + + def __init__(self, computer_id: str) -> None: + super().__init__( + name="get_screen_size_of_machine", + description="Return the screen size of the machine this tool is bound to.", + input_schema={"type": "object", "properties": {}}, + ) + self._computer_id = computer_id + + def __call__(self) -> str: + with self.agent_os.temporary_select(self._computer_id): + screenshot = self.agent_os.screenshot() + return f"{screenshot.width}x{screenshot.height}" +``` + +```python +from askui import ComputerAgent +from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget + +with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="local-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Remote box", + computer_id="remote-box", + ), + ], +) as agent: + # This tool always measures "remote-box", even though "local-box" is active. + agent.act( + "Report the screen size of the remote machine", + tools=[ScreenSizeOfMachineTool(computer_id="remote-box")], + ) +``` + +> **Note:** the `computer_id` you pass must match one registered via `agent_os_target_computers` — `temporary_select` raises if no such target exists. diff --git a/docs/13_multi_target_computers.md b/docs/13_multi_target_computers.md new file mode 100644 index 00000000..2cde6faa --- /dev/null +++ b/docs/13_multi_target_computers.md @@ -0,0 +1,104 @@ +# Multiple Target Computers + +A single `ComputerAgent` can drive **one or more machines** through the `agent_os_target_computers` argument. Each entry is an Agent OS *target computer* identified by a stable `computer_id`. This lets one agent (and one `act()` run) coordinate work across several machines — for example, research something on one computer and write up the findings on another. + +## Target types + +| Target | What it does | +|--------|--------------| +| `LocalComputerTarget` | Manages an Agent OS controller subprocess on **this** machine. At most one per agent. | +| `RemoteComputerTarget` | Points at an Agent OS controller already running on **another** machine, reachable over gRPC. No process management — the controller must already be running. | + +Both are importable from `askui.tools.askui`: + +```python +from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget +``` + +## The active target + +At any moment exactly **one** target is *active* and receives all explicit calls (`click`, `type`, `keyboard`, ...). The **first** entry in `agent_os_target_computers` is the initial active target. + +```python +from askui import ComputerAgent +from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget + +with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="local-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Remote box with a text editor open", + computer_id="remote-box", + ), + ], +) as agent: + # "local-box" is active by default (first in the list). + agent.click("Submit button") + + # Permanently switch the active target. + agent.tools.os.switch_agent_os_target_computer("remote-box") + agent.type("Typed on the remote box") + + # Temporarily switch for a block, then restore the previous target on exit. + with agent.tools.os.temporary_select("local-box"): + agent.act("Open the settings menu") + # "remote-box" is active again here. +``` + +Connections to all registered targets stay open across switches — switching only changes which connection future actions are routed to. + +## Letting `act()` orchestrate across machines + +The `act()` model is given three extra tools so it can move between machines on its own: + +- `list_agent_os_target_computers` — discover the available targets and their `computer_id`s. +- `switch_agent_os_target_computer` — make a target active. +- `get_current_computer_target_id` — check which target is active. + +Give each target a clear `description` so the model knows what each machine is for: + +```python +with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="research-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Writer box with a text editor open", + computer_id="writer-box", + ), + ], +) as agent: + agent.act( + "On research-box, open a browser, google 'askui', and read the top " + "results to gather key facts about what AskUI is, what it does, and " + "notable features. Then switch to writer-box and write a Markdown " + "document titled 'AskUI Findings' summarizing those facts as a " + "bulleted list in the open text editor." + ) +``` + +## Runtime helpers + +These are available on `agent.tools.os`: + +| Method | Purpose | +|--------|---------| +| `switch_agent_os_target_computer(computer_id)` | Make a target active and keep it active. Returns the now-active `ComputerTarget`. | +| `temporary_select(computer_id)` | Context manager that activates a target for a `with` block and restores the previously active one on exit (even if the block raises). | +| `get_current_computer_target_id()` | Return the `computer_id` of the active target. | +| `describe_agent_os_target_computers()` | Return a readable description of every registered target. | +| `add_agent_os_target_computer(target)` | Register an additional target at runtime (auto-connects if the agent is already connected). | +| `reset_agent_os_target_computers([...])` | Disconnect and replace the registered target list. | + +## Constraints + +- At least one target must be registered. +- At most one `LocalComputerTarget` per agent (any number of `RemoteComputerTarget`s is allowed). +- All `computer_id`s must be unique, and all remote `address`es must be unique. +- If `computer_id` is omitted, it defaults to the target's auto-generated `session_guid`. +- When `agent_os_target_computers` is provided, the top-level `display` argument is ignored — set `display` on the individual targets instead. + +## Full example + +See [`examples/multi_target_computers.py`](../examples/multi_target_computers.py) for a runnable script covering both explicit switching and model-orchestrated workflows. diff --git a/examples/computer_only_tool.py b/examples/computer_only_tool.py new file mode 100644 index 00000000..61b75029 --- /dev/null +++ b/examples/computer_only_tool.py @@ -0,0 +1,139 @@ +"""Example: a custom tool restricted to one device type (computer only). + +A tool that needs to drive a device should subclass one of the device-specific +base classes instead of `Tool`: + +- `ComputerBaseTool` - typed `self.agent_os` (a `ComputerAgentOS`); the tool is + tagged `"computer"` and can only be bound to a computer/desktop target. +- `AndroidBaseTool` - typed `self.agent_os` (an `AndroidAgentOs`); the tool is + tagged `"android"` and can only be bound to an Android target. + +When `act()` runs, the SDK binds each tool to the first registered agent OS +whose tags contain all of the tool's `required_tags`. `ComputerBaseTool` sets +`required_tags=["computer"]`, so the tool below is never handed to an Android +device. This matters most with `MultiDeviceAgent`, which registers both a +computer and an Android agent OS in the same `act()` call. + +Required environment variables (see .env): +- ASKUI_WORKSPACE_ID, ASKUI_TOKEN - for the default AskUI model stack +""" + +import logging + +from askui import ComputerAgent +from askui.models.shared import ComputerBaseTool +from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget + +logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s] %(asctime)s %(pathname)s:%(lineno)d | %(message)s", +) +logger = logging.getLogger(__name__) + + +class ComputerScreenSizeTool(ComputerBaseTool): + """Reports the pixel size of the active computer screen. + + Subclassing `ComputerBaseTool` tags this tool as `"computer"`, so it is only + ever bound to a computer (desktop) agent OS - never to an Android device. + `self.agent_os` is therefore a `ComputerAgentOS`. + """ + + def __init__(self) -> None: + super().__init__( + name="get_screen_size", + description=( + "Return the width and height in pixels of the active computer screen." + ), + input_schema={"type": "object", "properties": {}}, + ) + + def __call__(self) -> str: + screenshot = self.agent_os.screenshot() + return f"{screenshot.width}x{screenshot.height}" + + +class ScreenSizeOfMachineTool(ComputerBaseTool): + """Reports the screen size of one specific computer target (auto-switch). + + The tag system only restricts a tool to a device *type* (computer vs + Android), not to an individual machine. To bind a tool to one specific + target, have it auto-switch to that target inside `__call__`: + `self.agent_os.temporary_select(computer_id)` activates the given target for + the duration of the block and restores the previously active one on exit + (even if the body raises). So the tool always acts on its machine without + disturbing the rest of the run. + """ + + def __init__(self, computer_id: str) -> None: + super().__init__( + name="get_screen_size_of_machine", + description=( + "Return the screen size of the machine this tool is bound to." + ), + input_schema={"type": "object", "properties": {}}, + ) + self._computer_id = computer_id + + def __call__(self) -> str: + with self.agent_os.temporary_select(self._computer_id): + screenshot = self.agent_os.screenshot() + return f"{screenshot.width}x{screenshot.height}" + + +def computer_only_tool_with_computer_agent() -> None: + """Use the computer-scoped tool with a plain `ComputerAgent`.""" + with ComputerAgent() as agent: + agent.act( + "Report the current screen size using the get_screen_size tool", + tools=[ComputerScreenSizeTool()], + ) + + +def computer_only_tool_with_multi_device_agent() -> None: + """Show that the tool is routed only to the computer in a multi-device run. + + Requires the `android` dependency (`pip install askui[android]`) and a + connected Android device/emulator. + """ + from askui import MultiDeviceAgent + + with MultiDeviceAgent(android_device_sn="emulator-5554") as agent: + agent.act( + "Read the screen size on the computer, then take a screenshot on " + "the phone", + # ComputerScreenSizeTool is given only to the computer agent OS. + # An AndroidBaseTool subclass would be given only to the device. + tools=[ComputerScreenSizeTool()], + ) + + +def tool_pinned_to_a_specific_machine() -> None: + """Bind a tool to one specific target machine via auto-switch. + + `ScreenSizeOfMachineTool` always measures "remote-box", even though + "local-box" is the active target. The remote example expects an Agent OS + controller reachable at the configured address; adjust it to your setup. + """ + with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="local-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Remote box", + computer_id="remote-box", + ), + ], + ) as agent: + agent.act( + "Report the screen size of the remote machine", + tools=[ScreenSizeOfMachineTool(computer_id="remote-box")], + ) + + +if __name__ == "__main__": + computer_only_tool_with_computer_agent() + # computer_only_tool_with_multi_device_agent() + # tool_pinned_to_a_specific_machine() + + logger.info("Done!") diff --git a/examples/multi_target_computers.py b/examples/multi_target_computers.py new file mode 100644 index 00000000..d685705f --- /dev/null +++ b/examples/multi_target_computers.py @@ -0,0 +1,83 @@ +"""Example demonstrating how to drive multiple target computers with one agent. + +A single `ComputerAgent` can control one or more machines through the +`agent_os_target_computers` argument. Each entry is an Agent OS *target +computer* identified by a stable `computer_id`: + +- `LocalComputerTarget` - manages an Agent OS controller subprocess on this + machine (at most one per agent). +- `RemoteComputerTarget` - points at an Agent OS controller already running on + another machine, reachable over gRPC. + +At any moment exactly one target is *active* and receives all explicit calls +(`click`, `type`, `keyboard`, ...). The first target in the list is the initial +active one. You can change the active target at runtime in three ways: + +1. `agent.tools.os.switch_agent_os_target_computer(computer_id)` - switch and + keep the new target active. +2. `with agent.tools.os.temporary_select(computer_id): ...` - switch for the + duration of a block, then restore the previously active target on exit. +3. Let `act()` orchestrate on its own - the model has `list_agent_os_target_computers`, + `switch_agent_os_target_computer`, and `get_current_computer_target_id` tools. + +Required environment variables (see .env): +- ASKUI_WORKSPACE_ID, ASKUI_TOKEN - for the default AskUI model stack +""" + +import logging + +from askui import ComputerAgent +from askui.reporting import SimpleHtmlReporter +from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget + +logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s] %(asctime)s %(pathname)s:%(lineno)d | %(message)s", +) +logger = logging.getLogger(__name__) + + +def explicit_switching() -> None: + """Route explicit calls to specific machines via `switch`/`temporary_select`.""" + with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="local-box"), + RemoteComputerTarget( + address="10.0.24.11:26000", + description="Remote box with a text editor open", + computer_id="remote-box", + ), + ], + reporters=[SimpleHtmlReporter()], + ) as agent: + agent.act("Take a screenshot on each machine that you are connected to") + + +def model_orchestrated() -> None: + """Let `act()` decide when to switch between machines on its own.""" + with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="research-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Writer box with a text editor open", + computer_id="writer-box", + ), + ], + ) as agent: + agent.act( + "On research-box, open a browser, google 'askui', and read the top " + "results to gather key facts about what AskUI is, what it does, and " + "notable features. Then switch to writer-box and write a Markdown " + "document titled 'AskUI Findings' summarizing those facts as a " + "bulleted list in the open text editor." + ) + + +if __name__ == "__main__": + # Pick the scenario to run. The remote examples expect an Agent OS + # controller reachable at the configured address; adjust it to your setup. + explicit_switching() + # model_orchestrated() + + logger.info("Done!") diff --git a/mypy.ini b/mypy.ini index cfb75eb0..7a8a99d0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -16,6 +16,8 @@ plugins = pydantic.mypy,sqlalchemy.ext.mypy.plugin exclude = (?x)( ^src/askui/models/ui_tars_ep/ui_tars_api\.py$ | ^src/askui/tools/askui/askui_ui_controller_grpc/.*$ + | ^venv/.*$ + | ^\.venv/.*$ ) mypy_path = src:tests explicit_package_bases = true diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 1327ef2e..8e1a9737 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -46,6 +46,7 @@ from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry from .tools import ModifierKey, PcKey +from .tools.askui import LocalComputerTarget, RemoteComputerTarget from .utils.image_utils import ImageSource from .utils.source_utils import InputSource @@ -70,6 +71,8 @@ logging.getLogger(__name__).addHandler(logging.NullHandler()) __all__ = [ + "RemoteComputerTarget", + "LocalComputerTarget", "Agent", "AutomationError", "ComputerAgent", diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 45f39174..6a1915c6 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -27,7 +27,7 @@ from askui.models.shared.truncation_strategies import TruncationStrategy from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt from askui.telemetry.otel import OtelSettings, setup_opentelemetry_tracing -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS from askui.tools.android.agent_os import AndroidAgentOs from askui.tools.caching_tools import ( InspectCacheMetadata, @@ -58,7 +58,7 @@ def __init__( reporter: Reporter | None = None, retry: Retry | None = None, tools: list[Tool] | None = None, - agent_os: AgentOs | AndroidAgentOs | None = None, + agent_os: ComputerAgentOS | AndroidAgentOs | None = None, settings: AgentSettings | None = None, callbacks: list[ConversationCallback] | None = None, truncation_strategy: TruncationStrategy | None = None, diff --git a/src/askui/computer_agent.py b/src/askui/computer_agent.py index 4d9ec380..55fbb545 100644 --- a/src/askui/computer_agent.py +++ b/src/askui/computer_agent.py @@ -18,11 +18,13 @@ create_computer_agent_prompt, ) from askui.tools.computer import ( + ComputerGetCurrentComputerTargetIdTool, ComputerGetMousePositionTool, ComputerGetSystemInfoTool, ComputerKeyboardPressedTool, ComputerKeyboardReleaseTool, ComputerKeyboardTapTool, + ComputerListAgentOsTargetComputersTool, ComputerListDisplaysTool, ComputerMouseClickTool, ComputerMouseHoldDownTool, @@ -32,6 +34,7 @@ ComputerRetrieveActiveDisplayTool, ComputerScreenshotTool, ComputerSetActiveDisplayTool, + ComputerSwitchAgentOsTargetComputerTool, ComputerTypeTool, ) from askui.tools.exception_tool import ExceptionTool @@ -39,7 +42,7 @@ from .reporting import CompositeReporter, Reporter from .retry import Retry from .tools import AgentToolbox, ComputerAgentOsFacade, ModifierKey, PcKey -from .tools.askui import AskUiControllerClient +from .tools.askui import ComputerTarget, MultiComputerTargetAgentOS logger = logging.getLogger(__name__) @@ -51,10 +54,30 @@ class ComputerAgent(Agent): This agent can perform various UI interactions like clicking, typing, scrolling, and more. It uses computer vision models to locate UI elements and execute actions on them. + A single `ComputerAgent` can drive **one or more machines** through the + `agent_os_target_computers` argument. Each entry is an Agent OS target + computer (local subprocess or remote gRPC endpoint) identified by a stable + `computer_id`. At any moment one target is *active* and receives all + explicit calls (`click`, `type`, `keyboard`, ...). The active target can be + changed at runtime via + `agent.tools.os.switch_agent_os_target_computer(computer_id)` or scoped to a + block using `agent.tools.os.temporary_select(computer_id)`. The `act()` + model is also given list/switch/get-current tools so it can orchestrate + work across machines on its own (e.g. read something on one computer and + re-enter it on another). + Args: - display (int, optional): The display number to use for screen interactions. Defaults to `1`. + display (int, optional): The display number to use for screen interactions on the default local target. Ignored when `agent_os_target_computers` is provided. Defaults to `1`. reporters (list[Reporter] | None, optional): List of reporter instances for logging and reporting. If `None`, an empty list is used. - tools (AgentToolbox | None, optional): Custom toolbox instance. If `None`, a default one will be created with `AskUiControllerClient`. + agent_os_target_computers (list[ComputerTarget] | None, optional): + Target computers the agent can route actions to. May mix one + `LocalComputerTarget` (managing a controller subprocess on this + machine) with any number of `RemoteComputerTarget`s pointing at + controllers already running on other machines. Constraints: at + least one target, at most one local, and remote `address`es plus + all `computer_id`s must be unique. The first entry becomes the + initial active target. Defaults to a single local target bound to + `display`. settings (AgentSettings | None, optional): Provider-based model settings. If `None`, uses the default AskUI model stack. retry (Retry, optional): The retry instance to use for retrying failed actions. Defaults to `ConfigurableRetry` with exponential backoff. Currently only supported for `locate()` method. act_tools (list[Tool] | None, optional): Additional tools to make available for @@ -69,6 +92,8 @@ class ComputerAgent(Agent): to the model; on-screen secrets cannot currently be hidden. Example: + Single local machine (the default): + ```python from askui import ComputerAgent @@ -78,6 +103,36 @@ class ComputerAgent(Agent): agent.act("Open settings menu") ``` + Example: + Research on one machine and write up the findings on another. The + first target in the list is the active one; `temporary_select` + re-routes a block of explicit calls and restores the previous + active target on exit. + + ```python + from askui import ComputerAgent + from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget + + with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="research-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Writer box with a text editor open", + computer_id="writer-box", + ), + ], + ) as agent: + agent.act( + "On research-box, open a browser, google 'askui', and read " + "the top results to gather key facts about what AskUI is, " + "what it does, and notable features. Then switch to " + "writer-box and write a Markdown document titled " + "'AskUI Findings' summarizing those facts as a bulleted " + "list in the open text editor." + ) + ``` + Example (optional tools for `act()`): Register tools from `askui.tools.store` (or your own `Tool` implementations) either on the agent so they apply to all `act()` calls, or only for one call. @@ -102,11 +157,11 @@ class ComputerAgent(Agent): @telemetry.record_call( exclude={ "reporters", - "tools", "settings", "act_tools", "callbacks", "truncation_strategy", + "agent_os_target_computers", "secrets", } ) @@ -115,7 +170,7 @@ def __init__( self, display: Annotated[int, Field(ge=1)] = 1, reporters: list[Reporter] | None = None, - tools: AgentToolbox | None = None, + agent_os_target_computers: list[ComputerTarget] | None = None, settings: AgentSettings | None = None, retry: Retry | None = None, act_tools: list[Tool] | None = None, @@ -124,10 +179,11 @@ def __init__( secrets: list[Secret] | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) - self.tools = tools or AgentToolbox( - agent_os=AskUiControllerClient( + self.tools = AgentToolbox( + agent_os=MultiComputerTargetAgentOS( display=display, reporter=reporter, + agent_os_target_computers=agent_os_target_computers, ) ) super().__init__( @@ -515,8 +571,8 @@ def cli( with ComputerAgent() as agent: # Use for Windows - agent.cli(r'start "" "C:\Program Files\VideoLAN\VLC\vlc.exe"') # Start in VLC non-blocking - agent.cli(r'"C:\Program Files\VideoLAN\VLC\vlc.exe"') # Start in VLC blocking + agent.cli(r'start "" "C:\\Program Files\\VideoLAN\\VLC\\vlc.exe"') # Start in VLC non-blocking + agent.cli(r'"C:\\Program Files\\VideoLAN\\VLC\\vlc.exe"') # Start in VLC blocking # Mac agent.cli("open -a chrome") # Open Chrome non-blocking for mac @@ -556,6 +612,9 @@ def get_default_tools() -> list[Tool]: ComputerListDisplaysTool(), ComputerRetrieveActiveDisplayTool(), ComputerSetActiveDisplayTool(), + ComputerListAgentOsTargetComputersTool(), + ComputerSwitchAgentOsTargetComputerTool(), + ComputerGetCurrentComputerTargetIdTool(), ] diff --git a/src/askui/models/shared/android_base_tool.py b/src/askui/models/shared/android_base_tool.py index 5fc1c90b..fe4942bf 100644 --- a/src/askui/models/shared/android_base_tool.py +++ b/src/askui/models/shared/android_base_tool.py @@ -2,7 +2,7 @@ from askui.models.shared.tool_tags import ToolTags from askui.models.shared.tools import ToolWithAgentOS -from askui.tools import AgentOs +from askui.tools import ComputerAgentOS from askui.tools.agent_os_type_error import AgentOsTypeError from askui.tools.android.agent_os import AndroidAgentOs @@ -41,11 +41,11 @@ def agent_os(self) -> AndroidAgentOs: return agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: + def agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: """Set the agent OS. Args: - agent_os (AgentOs | AndroidAgentOs): The agent OS instance to set. + agent_os (ComputerAgentOS | AndroidAgentOs): The agent OS instance to set. Raises: TypeError: If the agent OS is not an AndroidAgentOs instance. diff --git a/src/askui/models/shared/computer_base_tool.py b/src/askui/models/shared/computer_base_tool.py index 0b6f13be..10a45d90 100644 --- a/src/askui/models/shared/computer_base_tool.py +++ b/src/askui/models/shared/computer_base_tool.py @@ -2,17 +2,17 @@ from askui.models.shared.tool_tags import ToolTags from askui.models.shared.tools import ToolWithAgentOS -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS from askui.tools.agent_os_type_error import AgentOsTypeError from askui.tools.android.agent_os import AndroidAgentOs class ComputerBaseTool(ToolWithAgentOS): - """Tool base class that has an AgentOs available.""" + """Tool base class that has a ComputerAgentOS available.""" def __init__( self, - agent_os: AgentOs | None = None, + agent_os: ComputerAgentOS | None = None, required_tags: list[str] | None = None, **kwargs: Any, ) -> None: @@ -23,33 +23,34 @@ def __init__( ) @property - def agent_os(self) -> AgentOs: + def agent_os(self) -> ComputerAgentOS: """Get the agent OS. Returns: - AgentOs: The agent OS instance. + ComputerAgentOS: The agent OS instance. """ agent_os = super().agent_os - if not isinstance(agent_os, AgentOs): + if not isinstance(agent_os, ComputerAgentOS): raise AgentOsTypeError( - expected_type=AgentOs, + expected_type=ComputerAgentOS, actual_type=type(agent_os), ) return agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: + def agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: """Set the agent OS facade. Args: - agent_os (AgentOs | AndroidAgentOs): The agent OS facade instance to set. + agent_os (ComputerAgentOS | AndroidAgentOs): The agent OS facade + instance to set. Raises: - TypeError: If the agent OS is not an AgentOs instance. + TypeError: If the agent OS is not a ComputerAgentOS instance. """ - if not isinstance(agent_os, AgentOs): + if not isinstance(agent_os, ComputerAgentOS): raise AgentOsTypeError( - expected_type=AgentOs, + expected_type=ComputerAgentOS, actual_type=type(agent_os), ) self._agent_os = agent_os diff --git a/src/askui/models/shared/playwright_base_tool.py b/src/askui/models/shared/playwright_base_tool.py index 1415c99a..da2772a0 100644 --- a/src/askui/models/shared/playwright_base_tool.py +++ b/src/askui/models/shared/playwright_base_tool.py @@ -2,18 +2,18 @@ from askui.models.shared.tool_tags import ToolTags from askui.models.shared.tools import ToolWithAgentOS -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS from askui.tools.agent_os_type_error import AgentOsTypeError from askui.tools.android.agent_os import AndroidAgentOs from askui.tools.playwright.agent_os import PlaywrightAgentOs class PlaywrightBaseTool(ToolWithAgentOS): - """Tool base class that has an the Playwright AgentOs available.""" + """Tool base class that has a Playwright ComputerAgentOS available.""" def __init__( self, - agent_os: AgentOs | None = None, + agent_os: ComputerAgentOS | None = None, required_tags: list[str] | None = None, **kwargs: Any, ) -> None: @@ -39,12 +39,14 @@ def agent_os(self) -> PlaywrightAgentOs: return agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs | PlaywrightAgentOs) -> None: + def agent_os( + self, agent_os: ComputerAgentOS | AndroidAgentOs | PlaywrightAgentOs + ) -> None: """Set the agent OS. Args: - agent_os (AgentOs | AndroidAgentOs | PlaywrightAgentOs): The agent OS - instance to set. + agent_os (ComputerAgentOS | AndroidAgentOs | PlaywrightAgentOs): The + agent OS instance to set. Raises: TypeError: If the agent OS is not an `PlaywrightAgentOs` instance. diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 72ecbd5a..1ec3ca4a 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -32,7 +32,7 @@ ToolUseBlockParam, ) from askui.models.shared.secrets import SecretVault -from askui.tools import AgentOs +from askui.tools import ComputerAgentOS from askui.tools.android.agent_os import AndroidAgentOs from askui.utils.image_utils import ImageSource, base64_to_image @@ -350,23 +350,23 @@ def __call__(self, *args: Any, **kwargs: Any) -> ToolCallResult: class ToolWithAgentOS(Tool): - """Tool base class that has an AgentOs available.""" + """Tool base class that has a ComputerAgentOS available.""" def __init__( self, required_tags: list[str], - agent_os: AgentOs | AndroidAgentOs | None = None, + agent_os: ComputerAgentOS | AndroidAgentOs | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs, required_tags=required_tags) - self._agent_os: AgentOs | AndroidAgentOs | None = agent_os + self._agent_os: ComputerAgentOS | AndroidAgentOs | None = agent_os @property - def agent_os(self) -> AgentOs | AndroidAgentOs: - """Get the agent OS. + def agent_os(self) -> ComputerAgentOS | AndroidAgentOs: + """Get the AgentOS. Returns: - AgentOs | AndroidAgentOs: The agent OS instance. + ComputerAgentOS | AndroidAgentOs: The AgentOS instance. """ if self._agent_os is None: msg = ( @@ -378,11 +378,11 @@ def agent_os(self) -> AgentOs | AndroidAgentOs: return self._agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: + def agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: self._agent_os = agent_os def is_agent_os_initialized(self) -> bool: - """Check if the agent OS is initialized.""" + """Check if the AgentOS is initialized.""" return self._agent_os is not None @@ -461,12 +461,12 @@ def __init__( tools: list[Tool] | None = None, mcp_client: McpClientProtocol | None = None, include: set[str] | None = None, - agent_os_list: list[AgentOs | AndroidAgentOs] | None = None, + agent_os_list: list[ComputerAgentOS | AndroidAgentOs] | None = None, secret_vault: SecretVault | None = None, ) -> None: self._mcp_client = mcp_client self._include = include - self._agent_os_list: list[AgentOs | AndroidAgentOs] = [] + self._agent_os_list: list[ComputerAgentOS | AndroidAgentOs] = [] self._tools: list[Tool] = tools or [] self._secret_vault: SecretVault = secret_vault or SecretVault() if agent_os_list: @@ -482,11 +482,11 @@ def secret_vault(self) -> SecretVault: def secret_vault(self, secret_vault: SecretVault) -> None: self._secret_vault = secret_vault - def add_agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: - """Add an agent OS to the collection. + def add_agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: + """Add an AgentOS to the collection. Args: - agent_os (AgentOs | AndroidAgentOs): The agent OS instance to add. + agent_os (ComputerAgentOS | AndroidAgentOs): The AgentOS instance to add. """ self._agent_os_list.append(agent_os) @@ -546,12 +546,23 @@ def reset_tools(self, tools: list[Tool] | None = None) -> None: """Reset the tools in the collection with new tools.""" self._tools = tools or [] - def get_agent_os_by_tags(self, tags: list[str]) -> AgentOs | AndroidAgentOs: - """Get an agent OS by tags.""" + def get_agent_os_by_tags( + self, required_tags: list[str] + ) -> ComputerAgentOS | AndroidAgentOs: + """ + Find the first registered AgentOS whose tags are a superset of + `required_tags`. + + Every tag in `required_tags` must appear in the AgentOS's tags; the + AgentOS may declare additional tags beyond those. + + Raises: + ValueError: when no registered AgentOS satisfies the required tags. + """ for agent_os in self._agent_os_list: - if all(tag in agent_os.tags for tag in tags): + if all(required in agent_os.tags for required in required_tags): return agent_os - msg = f"Agent OS with tags [{', '.join(tags)}] not found" + msg = f"No AgentOS satisfies required tags [{', '.join(required_tags)}]" raise ValueError(msg) def _initialize_tools(self) -> None: diff --git a/src/askui/tools/__init__.py b/src/askui/tools/__init__.py index ecd5bf24..c0f0dcc4 100644 --- a/src/askui/tools/__init__.py +++ b/src/askui/tools/__init__.py @@ -1,10 +1,11 @@ -from .agent_os import AgentOs, Coordinate, ModifierKey, PcKey +from .agent_os import AgentOs, ComputerAgentOS, Coordinate, ModifierKey, PcKey from .askui.askui_controller import RenderObjectStyle from .computer_agent_os_facade import ComputerAgentOsFacade from .toolbox import AgentToolbox __all__ = [ "AgentOs", + "ComputerAgentOS", "AgentToolbox", "ModifierKey", "PcKey", diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index af9cc96d..e50a523a 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -1,12 +1,17 @@ from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Literal from PIL import Image from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags if TYPE_CHECKING: + from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + ) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -213,7 +218,7 @@ def __str__(self) -> str: InputEvent = ClickEvent -class AgentOs(ABC): +class ComputerAgentOS(ABC): """ Abstract base class for Agent OS. Cannot be instantiated directly. @@ -682,6 +687,55 @@ def set_window_in_focus(self, process_id: int, window_id: int) -> None: """ raise NotImplementedError + def add_agent_os_target_computer( + self, agent_os_target_computer: "ComputerTarget" + ) -> "ComputerTarget": + """Register an additional target computer. Auto-connects if connected.""" + raise NotImplementedError + + def reset_agent_os_target_computers( + self, + agent_os_target_computers: "list[ComputerTarget] | None" = None, + ) -> None: + """Disconnect (if connected) and replace the target computer list.""" + raise NotImplementedError + + def describe_agent_os_target_computers(self) -> list[str]: + """Return the `repr()` string of every registered target computer.""" + raise NotImplementedError + + def get_current_computer_target_id(self, report: bool = True) -> str: + """Return the `computer_id` of the currently active target computer.""" + raise NotImplementedError + + def switch_agent_os_target_computer(self, computer_id: str) -> "ComputerTarget": + """Switch the active target computer by its `computer_id`.""" + raise NotImplementedError + + def temporary_select(self, computer_id: str) -> AbstractContextManager[Self]: + """ + Temporarily switch the active target computer for the duration of a `with` + block, then restore the previously-active target on exit (even if the + block raises). + + Args: + computer_id (str): Computer id of the target to activate inside the + block. + + Returns: + AbstractContextManager[Self]: Context manager that yields this + `ComputerAgentOS` with the selected target active. + + Example: + ```python + with agent_os.temporary_select('Remote-Machine') as remote_machine: + img = remote_machine.screenshot() + img.save("remote_machine.png") + # previous active target restored here + ``` + """ + raise NotImplementedError + def get_file_names(self, absolute_directory_path: str) -> list[str]: """ List file names in an absolute directory on the automation target @@ -724,3 +778,13 @@ def remove_virtual_displays(self) -> None: NotImplementedError: If the implementation does not support this operation. """ raise NotImplementedError + + +AgentOs = ComputerAgentOS +"""Deprecated alias for `ComputerAgentOS`, kept for backward compatibility. + +`AgentOs` was renamed to `ComputerAgentOS` to reflect that it is the +computer-specific Agent OS interface (mouse, keyboard, displays, ...) rather +than a universal abstraction across all device types. Prefer `ComputerAgentOS` +in new code. +""" diff --git a/src/askui/tools/android/agent_os.py b/src/askui/tools/android/agent_os.py index 3a5a8285..d7fe7e04 100644 --- a/src/askui/tools/android/agent_os.py +++ b/src/askui/tools/android/agent_os.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from typing import List, Literal from PIL import Image +from typing_extensions import Self from askui.tools.android.uiautomator_hierarchy import UIElementCollection @@ -502,3 +504,26 @@ def get_ui_elements(self) -> UIElementCollection: Gets the UI elements. """ raise NotImplementedError + + def temporary_select(self, device_sn: str) -> AbstractContextManager[Self]: + """ + Temporarily switch the active device for the duration of a `with` block, + then restore the previously-active device on exit (even if the block + raises). + + Args: + device_sn (str): Serial number of the device to activate inside the + block. + + Returns: + AbstractContextManager[Self]: Context manager that yields this + `AndroidAgentOs` with `device_sn` active. + + Example: + ```python + with android_agent_os.temporary_select('table_phone') as table_phone: + table_phone.tap(100, 200) + # previous active device restored here + ``` + """ + raise NotImplementedError diff --git a/src/askui/tools/android/agent_os_facade.py b/src/askui/tools/android/agent_os_facade.py index e94e42be..998940f7 100644 --- a/src/askui/tools/android/agent_os_facade.py +++ b/src/askui/tools/android/agent_os_facade.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import contextmanager from typing import TYPE_CHECKING from askui.models.shared.tool_tags import ToolTags @@ -7,7 +8,10 @@ from askui.tools.coordinate_scaler import CoordinateScaler if TYPE_CHECKING: + from collections.abc import Iterator + from PIL import Image + from typing_extensions import Self from askui.models.shared.coordinate_space import VlmCoordinateSpace from askui.models.shared.image_scaler import ImageScaler @@ -120,6 +124,15 @@ def set_device_by_serial_number(self, device_sn: str) -> None: self._agent_os.set_device_by_serial_number(device_sn) self._scaler.real_screen_resolution = None + @contextmanager + def temporary_select(self, device_sn: str) -> Iterator[Self]: + with self._agent_os.temporary_select(device_sn): + self._scaler.real_screen_resolution = None + try: + yield self + finally: + self._scaler.real_screen_resolution = None + def get_connected_devices_serial_numbers(self) -> list[str]: return self._agent_os.get_connected_devices_serial_numbers() diff --git a/src/askui/tools/android/ppadb_agent_os.py b/src/askui/tools/android/ppadb_agent_os.py index 9ffa7452..517ed4e1 100644 --- a/src/askui/tools/android/ppadb_agent_os.py +++ b/src/askui/tools/android/ppadb_agent_os.py @@ -2,12 +2,15 @@ import re import shlex import string +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path from typing import List, Optional, get_args from PIL import Image from ppadb.client import Client as AdbClient from ppadb.device import Device as AndroidDevice +from typing_extensions import Self from askui.reporting import NULL_REPORTER, Reporter from askui.tools.android.agent_os import ( @@ -202,6 +205,24 @@ def set_device_by_serial_number(self, device_sn: str) -> None: msg = f"Device name {device_sn} not found" raise AndroidAgentOsError(msg) + @contextmanager + def temporary_select(self, device_sn: str) -> Iterator[Self]: + previous_sn = self._device.serial if self._device is not None else None + self._reporter.add_message( + self._REPORTER_ROLE_NAME, + f"temporary_select({device_sn!r}) [previous={previous_sn!r}]", + ) + self.set_device_by_serial_number(device_sn) + try: + yield self + finally: + if previous_sn is not None and previous_sn != device_sn: + self.set_device_by_serial_number(previous_sn) + self._reporter.add_message( + self._REPORTER_ROLE_NAME, + f"temporary_select({device_sn!r}) -> restored", + ) + def _screenshot_without_reporting(self) -> Image.Image: device: AndroidDevice = self._get_selected_device() self._check_if_display_is_selected() diff --git a/src/askui/tools/askui/__init__.py b/src/askui/tools/askui/__init__.py index 5d46a982..db94e66d 100644 --- a/src/askui/tools/askui/__init__.py +++ b/src/askui/tools/askui/__init__.py @@ -1,6 +1,19 @@ -from .askui_controller import AskUiControllerClient, AskUiControllerServer +from .agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, + RemoteComputerTarget, +) +from .askui_controller import MultiComputerTargetAgentOS +from .computer_target_connection import ComputerTargetConnection +from .computer_target_pool import ( + ComputerTargetPool, +) __all__ = [ - "AskUiControllerClient", - "AskUiControllerServer", + "ComputerTarget", + "ComputerTargetConnection", + "ComputerTargetPool", + "MultiComputerTargetAgentOS", + "LocalComputerTarget", + "RemoteComputerTarget", ] diff --git a/src/askui/tools/askui/agent_os_target_computer.py b/src/askui/tools/askui/agent_os_target_computer.py new file mode 100644 index 00000000..03cc596f --- /dev/null +++ b/src/askui/tools/askui/agent_os_target_computer.py @@ -0,0 +1,378 @@ +import logging +import pathlib +import subprocess +import sys +import time +import uuid +from urllib.parse import urlparse + +from typing_extensions import override + +from askui.tools.askui.askui_controller_settings import AskUiControllerSettings +from askui.tools.askui.computer_target_connection import ComputerTargetConnection +from askui.tools.askui.exceptions import AskUiControllerError +from askui.tools.utils import process_exists, wait_for_port + +logger = logging.getLogger(__name__) + + +class ComputerTarget: + """ + Base class describing a computer target (a machine running the AskUI Agent + OS) that a `MultiComputerTargetAgentOS` client can connect to. + + A computer target runs the server-side counterpart of the `ComputerAgentOS` + client abstraction: it exposes a gRPC API for OS-level operations + (screenshot, mouse, keyboard, ...) and is identified by a unique session + GUID. Each computer target also tracks which display it is currently + operating against. + + Args: + address (str): gRPC address of the target computer + (e.g. ``"localhost:23000"``). + description (str): Human-readable description. + display (int, optional): Display ID selected for this target computer. + Defaults to `1`. + computer_id (str | None, optional): Stable, human-friendly identifier for + the target computer. Used by `ComputerTargetPool` lookup + helpers. Must be unique across registered target computers. Defaults + to the target computer's `session_guid`. + """ + + def __init__( + self, + address: str, + description: str, + display: int = 1, + computer_id: str | None = None, + ) -> None: + self._session_guid = "{" + str(uuid.uuid4()) + "}" + self._address = address + self._description = description + self._display = display + self._computer_id = ( + computer_id if computer_id is not None else self._session_guid + ) + self._connection: ComputerTargetConnection | None = None + + @property + def session_guid(self) -> str: + """Unique session GUID assigned to this target computer.""" + return self._session_guid + + @property + def computer_id(self) -> str: + """ + Stable identifier for this target computer. Defaults to `session_guid` + when no custom id was supplied at construction time. + """ + return self._computer_id + + @property + def address(self) -> str: + """gRPC address of the target computer.""" + return self._address + + @property + def description(self) -> str: + """Description of this target computer.""" + return self._description + + @property + def display(self) -> int: + """Display ID currently selected for this target computer.""" + return self._display + + @display.setter + def display(self, value: int) -> None: + self._display = value + + @property + def is_local(self) -> bool: + """Whether this target computer represents a locally-managed process.""" + return False + + @property + def is_connected(self) -> bool: + """Whether an open gRPC connection to this target computer exists.""" + return self._connection is not None + + @property + def connection(self) -> ComputerTargetConnection: + """ + The open gRPC connection to this target computer. + + Raises: + AskUiControllerError: If this target computer is not connected (i.e. + `connect()` has not been called). + """ + if self._connection is None: + error_msg = ( + f"Agent OS target computer {self._description!r} " + f"(computer_id={self._computer_id!r}, address={self._address}) " + "is not connected. Call `MultiComputerTargetAgentOS.connect()` " + "first." + ) + raise AskUiControllerError(error_msg) + return self._connection + + def connect(self) -> None: + """ + Open the gRPC connection to this target computer. Idempotent: returns + silently if already connected. Delegates the gRPC specifics to + `ComputerTargetConnection.open()`. + """ + if self._connection is None: + self._connection = ComputerTargetConnection.open(self) + + def disconnect(self) -> None: + """ + Close the gRPC connection to this target computer. No-op if not + connected. Delegates the gRPC teardown to + `ComputerTargetConnection.close()`. + """ + conn = self._connection + if conn is None: + return + self._connection = None + conn.close(self) + + def start(self, clean_up: bool = False) -> None: + """Start the underlying controller process. No-op for non-local targets.""" + + def stop(self, force: bool = False) -> None: + """Stop the underlying controller process. No-op for non-local targets.""" + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"computer_id={self._computer_id!r}, " + f"description={self._description!r}, " + f"display={self._display!r})" + ) + + +class LocalComputerTarget(ComputerTarget): + """ + Local computer target: manages an AskUI Remote Device Controller + subprocess on this machine. + + Args: + settings (AskUiControllerSettings | None, optional): Process-level settings + (executable path, args). Defaults to a fresh `AskUiControllerSettings`. + address (str, optional): gRPC address. Defaults to ``"localhost:23000"``. + is_service (bool, optional): When `True`, `start()` does not launch the + controller binary because it is managed externally (e.g. AskUI Core + Service on Windows). Defaults to `False`. + discover_service (bool, optional): On Windows, probe for a running + ``askuicoreservice`` and, if found, switch the address to port + ``26000`` and set `is_service` to `True`. Defaults to `True`. + description (str, optional) + display (int, optional): Display ID selected for this target computer. + Defaults to `1`. + """ + + _ASKUI_CORE_SERVICE_NAME = "AskuiCoreService" + _ASKUI_CORE_SERVICE_PORT = 26000 + + def __init__( + self, + description: str = "Local computer target", + settings: AskUiControllerSettings | None = None, + address: str = "localhost:23000", + discover_service: bool = True, + display: int = 1, + computer_id: str | None = None, + ) -> None: + super().__init__( + address=address, + description=description, + display=display, + computer_id=computer_id, + ) + self._is_service = False + self._settings = settings or AskUiControllerSettings() + self._process: subprocess.Popen[bytes] | None = None + if discover_service: + self._discover_service(address) + + @property + @override + def is_local(self) -> bool: + return True + + @property + def is_service(self) -> bool: + """Whether the controller process is managed externally (skip `start()`).""" + return self._is_service + + @staticmethod + def _is_askui_core_service_running() -> bool: + """Return `True` when the `AskuiCoreService` Windows service is RUNNING.""" + if sys.platform == "win32": + try: + result = subprocess.run( + [ + "sc", + "query", + LocalComputerTarget._ASKUI_CORE_SERVICE_NAME, + ], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + if result.returncode != 0: + return False + except (OSError, subprocess.SubprocessError) as e: + error_msg = ( + "Failed to query " + f"{LocalComputerTarget._ASKUI_CORE_SERVICE_NAME} service: {e}" + ) + logger.debug(error_msg) + return False + return "RUNNING" in result.stdout.upper() + return False + + def _discover_service(self, address: str) -> None: + if LocalComputerTarget._is_askui_core_service_running(): + service_msg = ( + f"Detected running {self._ASKUI_CORE_SERVICE_NAME}; using port " + f"{self._ASKUI_CORE_SERVICE_PORT} (controller managed by service)" + ) + logger.info(service_msg) + address = LocalComputerTarget.replace_port( + address, self._ASKUI_CORE_SERVICE_PORT + ) + self._is_service = True + + @staticmethod + def replace_port(address: str, port: int) -> str: + addr = address if "://" in address else "//" + address + parsed = urlparse(addr) + host = parsed.hostname or "localhost" + return f"{host}:{port}" + + def _parse_port(self) -> int: + addr = self._address if "://" in self._address else "//" + self._address + parsed = urlparse(addr) + if parsed.port is None: + error_msg = ( + f"Could not parse port from address {self._address!r}. " + "Expected format 'host:port' (e.g. 'localhost:23000')." + ) + raise ValueError(error_msg) + return parsed.port + + def _start_process( + self, + path: pathlib.Path, + args: str | None = None, + ) -> None: + commands = [str(path)] + if args: + commands.extend(args.split()) + if not logger.isEnabledFor(logging.DEBUG): + self._process = subprocess.Popen( + commands, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + else: + self._process = subprocess.Popen(commands) + wait_for_port(self._parse_port()) + + @override + def start(self, clean_up: bool = False) -> None: + """ + Start the controller process unless this target uses a service-managed + binary. + + Args: + clean_up (bool, optional): Whether to clean up existing processes + (only on Windows) before starting. Defaults to `False`. + """ + if self._is_service: + logger.debug( + "Skipping local controller start; process is managed by service" + ) + return + if ( + sys.platform == "win32" + and clean_up + and process_exists("AskuiRemoteDeviceController.exe") + ): + self.clean_up() + logger.debug( + "Starting AskUI Remote Device Controller", + extra={"path": str(self._settings.controller_path)}, + ) + self._start_process( + self._settings.controller_path, self._settings.controller_args + ) + time.sleep(0.5) + + def clean_up(self) -> None: + subprocess.run("taskkill.exe /IM AskUI*") + time.sleep(0.1) + + @override + def stop(self, force: bool = False) -> None: + """ + Stop the controller process. + + Args: + force (bool, optional): Whether to forcefully terminate the process. + Defaults to `False`. + """ + if self._process is None: + return + + try: + if force: + self._process.kill() + if sys.platform == "win32": + self.clean_up() + else: + self._process.terminate() + except Exception: # noqa: BLE001 - We want to catch all other exceptions here + logger.exception("Error stopping local controller process") + finally: + self._process = None + + +class RemoteComputerTarget(ComputerTarget): + """ + Remote computer target: the client connects to an already-running + controller on another machine. + + No process management is performed; `start()` and `stop()` are no-ops. + + Args: + address (str): gRPC address of the remote target computer (required). + description (str): Human-readable description. + display (int, optional): Display ID selected for this target computer. + Defaults to `1`. + computer_id (str | None, optional): Stable, human-friendly identifier for + the target computer. Defaults to the target computer's + `session_guid`. + """ + + def __init__( + self, + address: str, + description: str, + display: int = 1, + computer_id: str | None = None, + ) -> None: + super().__init__( + address=address, + description=description, + display=display, + computer_id=computer_id, + ) + + +__all__ = [ + "ComputerTarget", + "LocalComputerTarget", + "RemoteComputerTarget", +] diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 4e2f8c4f..be657fe7 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -1,1419 +1,1504 @@ -import base64 -import logging -import pathlib -import subprocess -import sys -import time -import types -import uuid -from typing import Literal, Type - -import grpc -from google.protobuf.json_format import MessageToDict -from PIL import Image -from typing_extensions import Self, override - -from askui.container import telemetry -from askui.reporting import NULL_REPORTER, Reporter -from askui.tools.agent_os import ( - AgentOs, - Coordinate, - Display, - DisplaysListResponse, - ModifierKey, - PcKey, -) -from askui.tools.askui.askui_controller_client_settings import ( - AskUiControllerClientSettings, -) -from askui.tools.askui.askui_controller_settings import AskUiControllerSettings -from askui.tools.askui.askui_ui_controller_grpc.desktop_agent_os_error import ( - DesktopAgentOsError, -) -from askui.tools.askui.askui_ui_controller_grpc.generated import ( - Controller_V1_pb2 as controller_v1_pbs, -) -from askui.tools.askui.askui_ui_controller_grpc.generated import ( - Controller_V1_pb2_grpc as controller_v1, -) -from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Request_2501 import ( # noqa: E501 - AddRenderObjectCommand, - AskUIAgentOSSendRequestSchema, - ClearRenderObjectsCommand, - Command, - DeleteRenderObjectCommand, - GetActiveProcessCommand, - GetActiveWindowCommand, - GetFileCommand, - GetFileNamesCommand, - GetMousePositionCommand, - GetSystemInfoCommand, - Guid, - Header, - Length, - Location, - Message, - Parameter3, - RemoveVirtualDisplaysCommand, - RenderImage, - RenderObjectId, - RenderObjectStyle, - RenderText, - SetActiveProcessCommand, - SetActiveWindowCommand, - SetMousePositionCommand, - UpdateRenderObjectCommand, -) -from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Response_2501 import ( # noqa: E501 - AskUIAgentOSSendResponseSchema, - GetActiveProcessResponse, - GetActiveProcessResponseModel, - GetActiveWindowResponse, - GetActiveWindowResponseModel, - GetFileNamesResponse, - GetFileResponse, - GetSystemInfoResponse, - GetSystemInfoResponseModel, -) -from askui.utils.annotated_image import AnnotatedImage -from askui.utils.image_utils import base64_to_image - -from ..utils import process_exists, wait_for_port -from .exceptions import ( - AskUiControllerError, - AskUiControllerInvalidCommandError, - AskUiControllerOperationTimeoutError, -) - -logger = logging.getLogger(__name__) - - -class AskUiControllerServer: - """ - Concrete implementation of `ControllerServer` for managing the AskUI Remote Device - Controller process. - Handles process discovery, startup, and shutdown for the native controller binary. - - Args: - settings (AskUiControllerSettings | None, optional): Settings for the AskUI. - """ - - def __init__(self, settings: AskUiControllerSettings | None = None) -> None: - self._process: subprocess.Popen[bytes] | None = None - self._settings = settings or AskUiControllerSettings() - - def _start_process( - self, - path: pathlib.Path, - args: str | None = None, - ) -> None: - commands = [str(path)] - if args: - commands.extend(args.split()) - if not logger.isEnabledFor(logging.DEBUG): - self._process = subprocess.Popen( - commands, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) - else: - self._process = subprocess.Popen(commands) - wait_for_port(23000) - - def start(self, clean_up: bool = False) -> None: - """ - Start the controller process. - - Args: - clean_up (bool, optional): Whether to clean up existing processes - (only on Windows) before starting. Defaults to `False`. - """ - if ( - sys.platform == "win32" - and clean_up - and process_exists("AskuiRemoteDeviceController.exe") - ): - self.clean_up() - logger.debug( - "Starting AskUI Remote Device Controller", - extra={"path": str(self._settings.controller_path)}, - ) - self._start_process( - self._settings.controller_path, self._settings.controller_args - ) - time.sleep(0.5) - - def clean_up(self) -> None: - subprocess.run("taskkill.exe /IM AskUI*") - time.sleep(0.1) - - def stop(self, force: bool = False) -> None: - """ - Stop the controller process. - - Args: - force (bool, optional): Whether to forcefully terminate the process. - Defaults to `False`. - """ - if self._process is None: - return # Nothing to stop - - try: - if force: - self._process.kill() - if sys.platform == "win32": - self.clean_up() - else: - self._process.terminate() - except Exception: # noqa: BLE001 - We want to catch all other exceptions here - logger.exception("Controller error") - finally: - self._process = None - - -class AskUiControllerClient(AgentOs): - """ - Implementation of `AgentOs` that communicates with the AskUI Remote Device - Controller via gRPC. - - Args: - reporter (Reporter): Reporter used for reporting with the `"AgentOs"`. - display (int, optional): Display number to use. Defaults to `1`. - controller_server (AskUiControllerServer | None, optional): Custom controller - server. Defaults to `ControllerServer`. - """ - - @telemetry.record_call(exclude={"reporter", "controller_server"}) - def __init__( - self, - reporter: Reporter = NULL_REPORTER, - display: int = 1, - controller_server: AskUiControllerServer | None = None, - settings: AskUiControllerClientSettings | None = None, - ) -> None: - self._stub: controller_v1.ControllerAPIStub | None = None - self._channel: grpc.Channel | None = None - self._session_info: controller_v1_pbs.SessionInfo | None = None - self._pre_action_wait = 0 - self._post_action_wait = 0.05 - self._max_retries = 10 - self._display = display - self._reporter = reporter - self._controller_server = controller_server or AskUiControllerServer() - self._session_guid = "{" + str(uuid.uuid4()) + "}" - self._settings = settings or AskUiControllerClientSettings() - - @telemetry.record_call() - @override - def connect(self) -> None: - """ - Establishes a connection to the AskUI Remote Device Controller. - - This method starts the controller server, establishes a gRPC channel, - creates a session, and sets up the initial display. - """ - if self._settings.server_autostart: - self._controller_server.start() - self._channel = grpc.insecure_channel( - self._settings.server_address, - options=[ - ("grpc.max_send_message_length", 2**30), - ("grpc.max_receive_message_length", 2**30), - ("grpc.default_deadline", 300000), - ], - ) - self._stub = controller_v1.ControllerAPIStub(self._channel) - self._start_session() - self._start_execution() - self.set_display(self._display) - if self._settings.clean_virtual_displays: - logger.info( - "clean_virtual_displays is enabled. Removing all virtual displays ... " - ) - self.remove_virtual_displays() - logger.info("Virtual displays removed.") - - def _get_stub(self) -> controller_v1.ControllerAPIStub: - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized. Call `connect()` first." - ) - return self._stub - - def _run_recorder_action( - self, - acion_class_id: controller_v1_pbs.ActionClassID, - action_parameters: controller_v1_pbs.ActionParameters, - ) -> controller_v1_pbs.Response_RunRecordedAction: - time.sleep(self._pre_action_wait) - response: controller_v1_pbs.Response_RunRecordedAction = ( - self._get_stub().RunRecordedAction( - controller_v1_pbs.Request_RunRecordedAction( - sessionInfo=self._session_info, - actionClassID=acion_class_id, - actionParameters=action_parameters, - ) - ) - ) - - time.sleep((response.requiredMilliseconds / 1000)) - num_retries = 0 - for _ in range(self._max_retries): - poll_response: controller_v1_pbs.Response_Poll = self._get_stub().Poll( - controller_v1_pbs.Request_Poll( - sessionInfo=self._session_info, - pollEventID=controller_v1_pbs.PollEventID.PollEventID_ActionFinished, - ) - ) - if ( - poll_response.pollEventParameters.actionFinished.actionID - == response.actionID - ): - break - time.sleep(self._post_action_wait) - num_retries += 1 - if num_retries == self._max_retries - 1: - raise AskUiControllerOperationTimeoutError - return response - - @telemetry.record_call() - @override - def disconnect(self) -> None: - """ - Terminates the connection to the AskUI Remote Device Controller. - - This method stops the execution, ends the session, closes the gRPC channel, - and stops the controller server. - """ - try: - self._stop_execution() - self._stop_session() - if self._channel is not None: - self._channel.close() - self._controller_server.stop() - except Exception as e: # noqa: BLE001 - # We want to catch all other exceptions here and not re-raise them - msg = ( - "Error while disconnecting from the AskUI Remote Device Controller" - f" Error: {e}" - ) - logger.exception(msg) - - @telemetry.record_call() - def __enter__(self) -> Self: - """ - Context manager entry point that establishes the connection. - - Returns: - Self: The instance of AskUiControllerClient. - """ - self.connect() - return self - - @telemetry.record_call(exclude={"exc_value", "traceback"}) - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_value: BaseException | None, - traceback: types.TracebackType | None, - ) -> None: - """ - Context manager exit point that disconnects the client. - - Args: - exc_type: The exception type if an exception was raised. - exc_value: The exception value if an exception was raised. - traceback: The traceback if an exception was raised. - """ - self.disconnect() - - def _start_session(self) -> None: - response = self._get_stub().StartSession( - controller_v1_pbs.Request_StartSession( - sessionGUID=self._session_guid, immediateExecution=True - ) - ) - self._session_info = response.sessionInfo - - def _stop_session(self) -> None: - self._get_stub().EndSession( - controller_v1_pbs.Request_EndSession(sessionInfo=self._session_info) - ) - - def _start_execution(self) -> None: - self._get_stub().StartExecution( - controller_v1_pbs.Request_StartExecution(sessionInfo=self._session_info) - ) - - def _stop_execution(self) -> None: - self._get_stub().StopExecution( - controller_v1_pbs.Request_StopExecution(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - @override - def screenshot(self, report: bool = True, unscaled: bool = False) -> Image.Image: - """ - Take a screenshot of the current screen. - - Args: - report (bool, optional): Whether to include the screenshot in reporting. - Defaults to `True`. - unscaled (bool, optional): Accepted for interface compatibility. This - client always returns the native screen resolution, so it has no - effect. Defaults to `False`. - - Returns: - Image.Image: A PIL Image object containing the screenshot. - - """ - screenResponse = self._get_stub().CaptureScreen( - controller_v1_pbs.Request_CaptureScreen( - sessionInfo=self._session_info, - captureParameters=controller_v1_pbs.CaptureParameters( - displayID=self._display - ), - ) - ) - r, g, b, _ = Image.frombytes( - "RGBA", - (screenResponse.bitmap.width, screenResponse.bitmap.height), - screenResponse.bitmap.data, - ).split() - image = Image.merge("RGB", (b, g, r)) - if report: - self._reporter.add_message("AgentOS", "screenshot()", image) - return image - - @telemetry.record_call() - @override - def mouse_move(self, x: int, y: int, duration: int = 500) -> None: - """ - Moves the mouse cursor to specified screen coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to move to. - y (int): The vertical coordinate (in pixels) to move to. - duration (int): The duration (in ms) the movement should take. - """ - self._reporter.add_message( - "AgentOS", - f"mouse_move({x}, {y}, duration={duration})", - AnnotatedImage(lambda: self.screenshot(report=False), point_list=[(x, y)]), - ) - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseMove, - action_parameters=controller_v1_pbs.ActionParameters( - mouseMove=controller_v1_pbs.ActionParameters_MouseMove( - position=controller_v1_pbs.Coordinate2(x=x, y=y), - milliseconds=duration, - ) - ), - ) - - @telemetry.record_call(exclude={"text"}) - @override - def type(self, text: str, typing_speed: int = 50) -> None: - """ - Type text at current cursor position as if entered on a keyboard. - - Args: - text (str): The text to type. - typing_speed (int, optional): The speed of typing in characters per second. - Defaults to `50`. - """ - self._reporter.add_message("AgentOS", f'type("{text}", {typing_speed})') - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardType_UnicodeText, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardTypeUnicodeText=controller_v1_pbs.ActionParameters_KeyboardType_UnicodeText( - text=text.encode("utf-16-le"), - typingSpeed=typing_speed, - typingSpeedValue=controller_v1_pbs.TypingSpeedValue.TypingSpeedValue_CharactersPerSecond, - ) - ), - ) - - @telemetry.record_call() - @override - def click( - self, button: Literal["left", "middle", "right"] = "left", count: int = 1 - ) -> None: - """ - Click a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - click. Defaults to `"left"`. - count (int, optional): Number of times to click. Defaults to `1`. - """ - self._reporter.add_message("AgentOS", f'click("{button}", {count})') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonPressAndRelease=controller_v1_pbs.ActionParameters_MouseButton_PressAndRelease( - mouseButton=mouse_button, count=count - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: - """ - Press and hold a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - press. Defaults to `"left"`. - """ - self._reporter.add_message("AgentOS", f'mouse_down("{button}")') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Press, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonPress=controller_v1_pbs.ActionParameters_MouseButton_Press( - mouseButton=mouse_button - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: - """ - Release a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - release. Defaults to `"left"`. - """ - self._reporter.add_message("AgentOS", f'mouse_up("{button}")') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Release, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonRelease=controller_v1_pbs.ActionParameters_MouseButton_Release( - mouseButton=mouse_button - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_scroll(self, dx: int, dy: int) -> None: - """ - Scroll the mouse wheel. - - Args: - dx (int): The horizontal scroll amount. Positive values scroll right, - negative values scroll left. - dy (int): The vertical scroll amount. Positive values scroll down, - negative values scroll up. - """ - self._reporter.add_message("AgentOS", f"mouse_scroll({dx}, {dy})") - if dx != 0: - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, - action_parameters=controller_v1_pbs.ActionParameters( - mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( - direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Horizontal, - deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, - delta=dx, - milliseconds=50, - ) - ), - ) - if dy != 0: - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, - action_parameters=controller_v1_pbs.ActionParameters( - mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( - direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Vertical, - deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, - delta=dy, - milliseconds=50, - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_pressed( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None - ) -> None: - """ - Press and hold a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to press. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - press along with the main key. Defaults to `None`. - """ - self._reporter.add_message( - "AgentOS", f'keyboard_pressed("{key}", {modifier_keys})' - ) - if modifier_keys is None: - modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Press, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPress=controller_v1_pbs.ActionParameters_KeyboardKey_Press( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_release( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None - ) -> None: - """ - Release a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to release. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - release along with the main key. Defaults to `None`. - """ - self._reporter.add_message( - "AgentOS", f'keyboard_release("{key}", {modifier_keys})' - ) - if modifier_keys is None: - modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Release, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyRelease=controller_v1_pbs.ActionParameters_KeyboardKey_Release( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_tap( - self, - key: PcKey | ModifierKey, - modifier_keys: list[ModifierKey] | None = None, - count: int = 1, - ) -> None: - """ - Press and immediately release a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to tap. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - press along with the main key. Defaults to `None`. - count (int, optional): The number of times to tap the key. Defaults to `1`. - """ - self._reporter.add_message( - "AgentOS", - f'keyboard_tap("{key}", {modifier_keys}, {count})', - ) - if modifier_keys is None: - modifier_keys = [] - for _ in range(count): - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def set_display(self, display: int = 1) -> None: - """ - Set the active display. - - Args: - display (int, optional): The display ID to set as active. - This can be either a real display ID or a virtual display ID. - Defaults to `1`. - """ - self._get_stub().SetActiveDisplay( - controller_v1_pbs.Request_SetActiveDisplay(displayID=display) - ) - self._display = display - self._reporter.add_message("AgentOS", f"set_display({display})") - - @telemetry.record_call(exclude={"command"}) - @override - def run_command(self, command: str, timeout_ms: int = 30000) -> None: - """ - Execute a shell command. - - Args: - command (str): The command to execute. - timeout_ms (int, optional): The timeout for command - execution in milliseconds. Defaults to `30000` (30 seconds). - """ - self._reporter.add_message("AgentOS", f'run_command("{command}", {timeout_ms})') - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_RunCommand, - action_parameters=controller_v1_pbs.ActionParameters( - runcommand=controller_v1_pbs.ActionParameters_RunCommand( - command=command, timeoutInMilliseconds=timeout_ms - ) - ), - ) - - @telemetry.record_call() - @override - def retrieve_active_display(self) -> Display: - """ - Retrieve the currently active display/screen. - - Returns: - Display: The currently active display/screen. - """ - self._reporter.add_message("AgentOS", "retrieve_active_display()") - displays_list_response = self.list_displays() - for display in displays_list_response.data: - if display.id == self._display: - self._reporter.add_message( - "AgentOS", f"retrieve_active_display() -> {display}" - ) - return display - error_msg = f"Display {self._display} not found" - raise ValueError(error_msg) - - @telemetry.record_call() - @override - def list_displays( - self, - ) -> DisplaysListResponse: - """ - List all available Displays from the controller. - It includes both real and virtual displays - without describing the type of display (virtual or real). - - Returns: - DisplaysListResponse - """ - - self._reporter.add_message("AgentOS", "list_displays()") - - response: controller_v1_pbs.Response_GetDisplayInformation = ( - self._get_stub().GetDisplayInformation(controller_v1_pbs.Request_Void()) - ) - - response_dict = MessageToDict( - response, - preserving_proto_field_name=True, - ) - - displays = DisplaysListResponse.model_validate(response_dict) - - self._reporter.add_message("AgentOS", f"list_displays() ->{str(displays)}") - - return displays - - @telemetry.record_call() - def get_process_list( - self, get_extended_info: bool = False - ) -> controller_v1_pbs.Response_GetProcessList: - """ - Get a list of running processes. - - Args: - get_extended_info (bool, optional): Whether to include - extended process information. - Defaults to `False`. - - Returns: - controller_v1_pbs.Response_GetProcessList: Process list response containing: - - processes: List of ProcessInfo objects - """ - - self._reporter.add_message("AgentOS", f"get_process_list({get_extended_info})") - - response: controller_v1_pbs.Response_GetProcessList = ( - self._get_stub().GetProcessList( - controller_v1_pbs.Request_GetProcessList( - getExtendedInfo=get_extended_info - ) - ) - ) - self._reporter.add_message( - "AgentOS", f"get_process_list({get_extended_info}) -> {response}" - ) - - return response - - @telemetry.record_call() - def get_window_list( - self, process_id: int - ) -> controller_v1_pbs.Response_GetWindowList: - """ - Get a list of windows for a specific process. - - Args: - process_id (int): The ID of the process to get windows for. - - Returns: - controller_v1_pbs.Response_GetWindowList: Window list response containing: - - windows: List of WindowInfo objects with ID and name - """ - - self._reporter.add_message("AgentOS", f"get_window_list({process_id})") - - response: controller_v1_pbs.Response_GetWindowList = ( - self._get_stub().GetWindowList( - controller_v1_pbs.Request_GetWindowList(processID=process_id) - ) - ) - - self._reporter.add_message( - "AgentOS", f"get_window_list({process_id}) -> {response}" - ) - - return response - - @telemetry.record_call() - def get_automation_target_list( - self, - ) -> controller_v1_pbs.Response_GetAutomationTargetList: - """ - Get a list of available automation targets. - - Returns: - controller_v1_pbs.Response_GetAutomationTargetList: - Automation target list response: - - targets: List of AutomationTarget objects - """ - - self._reporter.add_message("AgentOS", "get_automation_target_list()") - - response: controller_v1_pbs.Response_GetAutomationTargetList = ( - self._get_stub().GetAutomationTargetList(controller_v1_pbs.Request_Void()) - ) - self._reporter.add_message( - "AgentOS", f"get_automation_target_list() -> {response}" - ) - - return response - - @telemetry.record_call() - def set_mouse_delay(self, delay_ms: int) -> None: - """ - Configure mouse action delay. - - Args: - delay_ms (int): The delay in milliseconds to set for mouse actions. - """ - - self._reporter.add_message("AgentOS", f"set_mouse_delay({delay_ms})") - - self._get_stub().SetMouseDelay( - controller_v1_pbs.Request_SetMouseDelay( - sessionInfo=self._session_info, delayInMilliseconds=delay_ms - ) - ) - - @telemetry.record_call() - def set_keyboard_delay(self, delay_ms: int) -> None: - """ - Configure keyboard action delay. - - Args: - delay_ms (int): The delay in milliseconds to set for keyboard actions. - """ - - self._reporter.add_message("AgentOS", f"set_keyboard_delay({delay_ms})") - - self._get_stub().SetKeyboardDelay( - controller_v1_pbs.Request_SetKeyboardDelay( - sessionInfo=self._session_info, delayInMilliseconds=delay_ms - ) - ) - - @telemetry.record_call() - def set_active_window(self, process_id: int, window_id: int) -> int: - """ - Set the active window for automation. - Adds the window as a virtual display and returns the display ID. - It raises an error if display length is not increased after adding the window. - - Args: - process_id (int): The ID of the process that owns the window. - window_id (int): The ID of the window to set as active. - - returns: - int: The new Display ID. - Raises: - AskUiControllerError: - If display length is not increased after adding the window. - """ - - self._reporter.add_message( - "AgentOS", f"set_active_window({process_id}, {window_id})" - ) - - display_length_before_adding_window = len(self.list_displays().data) - - self._get_stub().SetActiveWindow( - controller_v1_pbs.Request_SetActiveWindow( - processID=process_id, windowID=window_id - ) - ) - new_display_length = len(self.list_displays().data) - if new_display_length <= display_length_before_adding_window: - msg = f"Failed to set active window {window_id} for process {process_id}" - raise AskUiControllerError(msg) - self._reporter.add_message( - "AgentOS", - f"set_active_window({process_id}, {window_id}) -> {new_display_length}", - ) - return new_display_length - - @telemetry.record_call() - def set_active_automation_target(self, target_id: int) -> None: - """ - Set the active automation target. - - Args: - target_id (int): The ID of the automation target to set as active. - """ - - self._reporter.add_message( - "AgentOS", f"set_active_automation_target({target_id})" - ) - - self._get_stub().SetActiveAutomationTarget( - controller_v1_pbs.Request_SetActiveAutomationTarget(ID=target_id) - ) - - @telemetry.record_call() - def schedule_batched_action( - self, - action_class_id: controller_v1_pbs.ActionClassID, - action_parameters: controller_v1_pbs.ActionParameters, - ) -> controller_v1_pbs.Response_ScheduleBatchedAction: - """ - Schedule an action for batch execution. - - Args: - action_class_id (controller_v1_pbs.ActionClassID): The class ID - of the action to schedule. - action_parameters (controller_v1_pbs.ActionParameters): - Parameters for the action. - - Returns: - controller_v1_pbs.Response_ScheduleBatchedAction: Response containing - the scheduled action ID. - """ - - self._reporter.add_message( - "AgentOS", - f"schedule_batched_action({action_class_id}, {action_parameters})", - ) - - response: controller_v1_pbs.Response_ScheduleBatchedAction = ( - self._get_stub().ScheduleBatchedAction( - controller_v1_pbs.Request_ScheduleBatchedAction( - sessionInfo=self._session_info, - actionClassID=action_class_id, - actionParameters=action_parameters, - ) - ) - ) - - return response - - @telemetry.record_call() - def start_batch_run(self) -> None: - """ - Start executing batched actions. - """ - - self._reporter.add_message("AgentOS", "start_batch_run()") - - self._get_stub().StartBatchRun( - controller_v1_pbs.Request_StartBatchRun(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - def stop_batch_run(self) -> None: - """ - Stop executing batched actions. - """ - - self._reporter.add_message("AgentOS", "stop_batch_run()") - - self._get_stub().StopBatchRun( - controller_v1_pbs.Request_StopBatchRun(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: - """ - Get the count of recorded or batched actions. - - Returns: - controller_v1_pbs.Response_GetActionCount: Response - containing the action count. - """ - - response: controller_v1_pbs.Response_GetActionCount = ( - self._get_stub().GetActionCount( - controller_v1_pbs.Request_GetActionCount(sessionInfo=self._session_info) - ) - ) - self._reporter.add_message("AgentOS", f"get_action_count() -> {response}") - return response - - @telemetry.record_call() - def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: - """ - Get a specific action by its index. - - Args: - action_index (int): The index of the action to retrieve. - - Returns: - controller_v1_pbs.Response_GetAction: Action information containing: - - actionID: The action ID - - actionClassID: The action class ID - - actionParameters: The action parameters - """ - - self._reporter.add_message("AgentOS", f"get_action({action_index})") - - response: controller_v1_pbs.Response_GetAction = self._get_stub().GetAction( - controller_v1_pbs.Request_GetAction( - sessionInfo=self._session_info, actionIndex=action_index - ) - ) - - return response - - @telemetry.record_call() - def remove_action(self, action_id: int) -> None: - """ - Remove a specific action by its ID. - - Args: - action_id (int): The ID of the action to remove. - """ - - self._reporter.add_message("AgentOS", f"remove_action({action_id})") - - self._get_stub().RemoveAction( - controller_v1_pbs.Request_RemoveAction( - sessionInfo=self._session_info, actionID=action_id - ) - ) - - @telemetry.record_call() - def remove_all_actions(self) -> None: - """ - Clear all recorded or batched actions. - """ - - self._reporter.add_message("AgentOS", "remove_all_actions()") - - self._get_stub().RemoveAllActions( - controller_v1_pbs.Request_RemoveAllActions(sessionInfo=self._session_info) - ) - - def _send_command(self, command: Command) -> AskUIAgentOSSendResponseSchema: - """ - Send a general command to the controller. - - Args: - command (Command): The command to send to the controller. - - Returns: - AskUIAgentOSSendResponseSchema: Response containing - the message from the controller. - - Raises: - AskUiControllerInvalidCommandError: If the command fails schema validation - on the server side. - """ - - header = Header(authentication=Guid(root=self._session_guid)) - message = Message(header=header, command=command) - - request = AskUIAgentOSSendRequestSchema(message=message) - - request_str = request.model_dump_json(exclude_none=True, by_alias=True) - - try: - response: controller_v1_pbs.Response_Send = self._get_stub().Send( - controller_v1_pbs.Request_Send(message=request_str) - ) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - details = e.details() or None - raise AskUiControllerInvalidCommandError(details) from e - raise - - return AskUIAgentOSSendResponseSchema.model_validate_json(response.message) - - @telemetry.record_call() - def get_mouse_position(self) -> Coordinate: - """ - Get the mouse cursor position - - Returns: - Coordinate: Response containing the result of the mouse position change. - """ - self._reporter.add_message("AgentOS", "get_mouse_position()") - res = self._send_command(GetMousePositionCommand()) - coordinate = Coordinate( - x=res.message.command.response.position.x.root, # type: ignore[union-attr] - y=res.message.command.response.position.y.root, # type: ignore[union-attr] - ) - self._reporter.add_message("AgentOS", f"get_mouse_position() -> {coordinate}") - return coordinate - - @telemetry.record_call() - def set_mouse_position(self, x: int, y: int) -> None: - """ - Set the mouse cursor position to specific coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to set the cursor to. - y (int): The vertical coordinate (in pixels) to set the cursor to. - """ - location = Location(x=Length(root=x), y=Length(root=y)) - command = SetMousePositionCommand(parameters=[location]) - self._reporter.add_message("AgentOS", f"set_mouse_position({x},{y})") - self._send_command(command) - - @telemetry.record_call() - def render_quad(self, style: RenderObjectStyle) -> int: - """ - Render a quad object to the display. - - Args: - style (RenderObjectStyle): The style properties for the quad. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_quad({style})") - command = AddRenderObjectCommand(parameters=["Quad", style]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: - """ - Render a line object to the display. - - Args: - style (RenderObjectStyle): The style properties for the line. - points (list[Coordinates]): The points defining the line. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_line({style}, {points})") - command = AddRenderObjectCommand(parameters=["Line", style, points]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call(exclude={"image_data"}) - def render_image(self, style: RenderObjectStyle, image_data: str) -> int: - """ - Render an image object to the display. - - Args: - style (RenderObjectStyle): The style properties for the image. - image_data (str): The base64-encoded image data. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_image({style}, [image_data])") - image = RenderImage(root=image_data) - command = AddRenderObjectCommand(parameters=["Image", style, image]) - res = self._send_command(command) - - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def render_text(self, style: RenderObjectStyle, content: str) -> int: - """ - Render a text object to the display. - - Args: - style (RenderObjectStyle): The style properties for the text. - content (str): The text content to display. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_text({style}, {content})") - text = RenderText(root=content) - command = AddRenderObjectCommand(parameters=["Text", style, text]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def update_render_object(self, object_id: int, style: RenderObjectStyle) -> None: - """ - Update styling properties of an existing render object. - - Args: - object_id (float): The ID of the render object to update. - style (RenderObjectStyle): The new style properties. - - Returns: - int: Object ID. - """ - self._reporter.add_message( - "AgentOS", f"update_render_object({object_id}, {style})" - ) - render_object_id = RenderObjectId(root=object_id) - command = UpdateRenderObjectCommand(parameters=[render_object_id, style]) - self._send_command(command) - - @telemetry.record_call() - def delete_render_object(self, object_id: int) -> None: - """ - Delete an existing render object from the display. - - Args: - object_id (RenderObjectId): The ID of the render object to delete. - """ - self._reporter.add_message("AgentOS", f"delete_render_object({object_id})") - render_object_id = RenderObjectId(root=object_id) - command = DeleteRenderObjectCommand(parameters=[render_object_id]) - self._send_command(command) - - @telemetry.record_call() - def clear_render_objects(self) -> None: - """ - Clear all render objects from the display. - """ - self._reporter.add_message("AgentOS", "clear_render_objects()") - command = ClearRenderObjectsCommand() - self._send_command(command) - - def get_system_info(self) -> GetSystemInfoResponseModel: - """ - Get the system information. - - Returns: - SystemInfo: The system information. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_system_info()") - command = GetSystemInfoCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetSystemInfoResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_system_info() -> {res.response}") - return res.response - - def get_active_process(self) -> GetActiveProcessResponseModel: - """ - Get the active process. - - Returns: - GetActiveProcessResponseModel: The active process. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_active_process()") - command = GetActiveProcessCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetActiveProcessResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_active_process() -> {res.response}") - return res.response - - def set_active_process(self, process_id: int) -> None: - """ - Set the active process. - - Args: - process_id (int): The ID of the process to set as active. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", f"set_active_process({process_id})") - _process_id = Parameter3(root=process_id) - command = SetActiveProcessCommand(parameters=[_process_id]) - self._send_command(command) - - def get_active_window(self) -> GetActiveWindowResponseModel: - """ - Gets the window id and name in addition to the process id - and name of the currently active window (in focus). - - - Returns: - GetActiveWindowResponseModel: The active window. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_active_window()") - command = GetActiveWindowCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetActiveWindowResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_active_window() -> {res.response}") - return res.response - - def set_window_in_focus(self, process_id: int, window_id: int) -> None: - """ - Sets the window with the specified windowId of the process - with the specified processId active, - which brings it to the front and gives it focus. - - Args: - process_id (int): The ID of the process that owns the window. - window_id (int): The ID of the window to set as active. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message( - "AgentOS", f"set_window_in_focus({process_id}, {window_id})" - ) - _process_id = Parameter3(root=process_id) - _window_id = Parameter3(root=window_id) - command = SetActiveWindowCommand(parameters=[_process_id, _window_id]) - self._send_command(command) - - def get_file_names(self, absolute_directory_path: str) -> list[str]: - """ - Get the file names in the given absolute directory on the device under - automation. - - Args: - absolute_directory_path (str): The absolute directory path to list - file names from. - - Returns: - list[str]: The file names returned by the controller. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message( - "AgentOS", f"get_file_names({absolute_directory_path})" - ) - command = GetFileNamesCommand(parameters=[absolute_directory_path]) - res = self._send_command(command).message.command - if not isinstance(res, GetFileNamesResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - if res.error is not None: - raise DesktopAgentOsError(res.error) - if res.response is None: - message = f"{type(res).__name__} is missing both error and response" - raise DesktopAgentOsError(message) - self._reporter.add_message( - "AgentOS", f"get_file_names({absolute_directory_path}) -> {res.response}" - ) - return res.response.fileNames - - def get_file(self, path: str) -> Image.Image | str: - """ - Get the contents of a file at the given path on the device under - automation. - - The controller returns the file as a Base64-encoded string, which is - decoded and returned as `PIL.Image.Image` when the bytes can be opened - as an image (PNG, JPEG, BMP, GIF, WebP, TIFF, ...), or as `str` when - they decode cleanly as UTF-8 text. - - Args: - path (str): The file path to read on the device under automation. - - Returns: - Image.Image | str: The decoded file contents. - - Raises: - DesktopAgentOsError: If the file cannot be read or the response is invalid. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", f"get_file({path})") - command = GetFileCommand(parameters=[path]) - res = self._send_command(command).message.command - if not isinstance(res, GetFileResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - if res.error is not None: - raise DesktopAgentOsError(res.error) - if res.response is None: - message = f"{type(res).__name__} is missing both error and response" - raise DesktopAgentOsError(message) - decoded = self._decode_file_payload(res.response.file.content) - if isinstance(decoded, Image.Image): - detail = f"image ({decoded.format}, {decoded.size[0]}x{decoded.size[1]})" - self._reporter.add_message( - "AgentOS", f"get_file({path}) -> {detail}", decoded - ) - return decoded - - detail = f"text ({len(decoded)} chars)" - self._reporter.add_message("AgentOS", f"get_file({path}) -> {detail}") - return decoded - - def remove_virtual_displays(self) -> None: - """ - Remove all virtual displays from the controller, leaving only real - displays active. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "remove_virtual_displays()") - command = RemoveVirtualDisplaysCommand() - self._send_command(command) - self._reporter.add_message("AgentOS", "remove_virtual_displays() -> done") - - @staticmethod - def _decode_file_payload(base64_data: str) -> Image.Image | str: - try: - return base64_to_image(base64_data) - except ValueError: - pass - data = base64.b64decode(base64_data, validate=True) - if b"\x00" not in data: - try: - return data.decode("utf-8") - except UnicodeDecodeError: - pass - message = "File contents are neither a supported image nor UTF-8 text" - raise DesktopAgentOsError(message) +import base64 +import time +import types +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Literal, Type + +import grpc +from google.protobuf.json_format import MessageToDict +from PIL import Image +from typing_extensions import Self, override + +from askui.container import telemetry +from askui.reporting import NULL_REPORTER, Reporter +from askui.tools.agent_os import ( + ComputerAgentOS, + Coordinate, + Display, + DisplaysListResponse, + ModifierKey, + PcKey, +) +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, +) +from askui.tools.askui.askui_ui_controller_grpc.desktop_agent_os_error import ( + DesktopAgentOsError, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2 as controller_v1_pbs, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2_grpc as controller_v1, +) +from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Request_2501 import ( # noqa: E501 + AddRenderObjectCommand, + AskUIAgentOSSendRequestSchema, + ClearRenderObjectsCommand, + Command, + DeleteRenderObjectCommand, + GetActiveProcessCommand, + GetActiveWindowCommand, + GetFileCommand, + GetFileNamesCommand, + GetMousePositionCommand, + GetSystemInfoCommand, + Guid, + Header, + Length, + Location, + Message, + Parameter3, + RemoveVirtualDisplaysCommand, + RenderImage, + RenderObjectId, + RenderObjectStyle, + RenderText, + SetActiveProcessCommand, + SetActiveWindowCommand, + SetMousePositionCommand, + UpdateRenderObjectCommand, +) +from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Response_2501 import ( # noqa: E501 + AskUIAgentOSSendResponseSchema, + GetActiveProcessResponse, + GetActiveProcessResponseModel, + GetActiveWindowResponse, + GetActiveWindowResponseModel, + GetFileNamesResponse, + GetFileResponse, + GetSystemInfoResponse, + GetSystemInfoResponseModel, +) +from askui.tools.askui.computer_target_pool import ( + ComputerTargetPool, +) +from askui.utils.annotated_image import AnnotatedImage +from askui.utils.image_utils import base64_to_image + +from .exceptions import ( + AskUiControllerError, + AskUiControllerInvalidCommandError, + AskUiControllerOperationTimeoutError, +) + + +class MultiComputerTargetAgentOS(ComputerAgentOS): + """ + Implementation of `ComputerAgentOS` that communicates with one or more + computer targets (AskUI Remote Device Controller processes) via gRPC. + + A client is configured with a non-empty list of `agent_os_target_computers` + (at most one local, the rest remote with unique addresses). `connect()` opens + a gRPC channel and session for *every* registered target. Exactly one target + is *active* at a time; agent-os actions are routed to its connection. + `disconnect()` closes every open connection and stops only those local + processes that were started by this client (i.e. `is_local` and not + `is_service` at connect time). + + Use `add_agent_os_target_computer` to register additional targets (which + auto-connect if the client is currently connected), + `switch_agent_os_target_computer` to change the active one, + `describe_agent_os_target_computers` to inspect the registered targets, and + `reset_agent_os_target_computers` to clear or replace the list. + + Args: + reporter (Reporter): Reporter used for reporting with the `"AgentOS"`. + display (int, optional): Display number to use. Defaults to `1`. + agent_os_target_computers (list[ComputerTarget] | None, optional): + Computer targets to register. Must be non-empty if provided, contain + at most one local target, and have unique addresses across remote + targets. If `None` (default), a single `LocalComputerTarget` + with default settings is registered. + """ + + _REPORTER_SOURCE = "AgentOS" + + @telemetry.record_call(exclude={"reporter", "agent_os_target_computers"}) + def __init__( + self, + reporter: Reporter = NULL_REPORTER, + display: int = 1, + agent_os_target_computers: list[ComputerTarget] | None = None, + ) -> None: + if not agent_os_target_computers: + agent_os_target_computers = [LocalComputerTarget(display=display)] + + self._pre_action_wait = 0 + self._post_action_wait = 0.05 + self._max_retries = 10 + self._reporter = reporter + self._manager = ComputerTargetPool( + agent_os_target_computers=agent_os_target_computers + ) + + @property + def agent_os_target_computer_manager(self) -> ComputerTargetPool: + """The underlying target-computer manager.""" + return self._manager + + @property + def is_connected(self) -> bool: + """`True` when at least one target-computer connection is open.""" + return self._manager.is_connected + + def _require_active_agent_os_target_computer(self) -> ComputerTarget: + return self._manager.require_active() + + @property + def _session_info(self) -> controller_v1_pbs.SessionInfo: + return self._manager.active_connection().session_info + + @telemetry.record_call(exclude={"agent_os_target_computer"}) + @override + def add_agent_os_target_computer( + self, agent_os_target_computer: ComputerTarget + ) -> ComputerTarget: + """ + Register an already-constructed target computer. Auto-connects if the + client is currently connected. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f"add_agent_os_target_computer({agent_os_target_computer!r})", + ) + self._manager.add(agent_os_target_computer) + return agent_os_target_computer + + @telemetry.record_call(exclude={"agent_os_target_computers"}) + @override + def reset_agent_os_target_computers( + self, + agent_os_target_computers: list[ComputerTarget] | None = None, + ) -> None: + """ + Disconnect (if connected) and replace the target computer list. + + Args: + agent_os_target_computers (list[ComputerTarget] | None, optional): + New list of target computers to register after the reset. If + `None`, the list is left empty and a subsequent `connect()` will + fail until at least one target has been registered again. Same + validation rules as the constructor (at most one local, unique + remote addresses). + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f"reset_agent_os_target_computers({agent_os_target_computers!r})", + ) + was_connected = self.is_connected + if was_connected: + self.disconnect() + self._manager.reset() + if agent_os_target_computers is not None: + for agent_os_target_computer in agent_os_target_computers: + self._manager.add(agent_os_target_computer) + if was_connected: + self.connect() + + @telemetry.record_call() + @override + def describe_agent_os_target_computers(self) -> list[str]: + """Return the `repr()` string of every registered target computer.""" + self._reporter.add_message( + self._REPORTER_SOURCE, "describe_agent_os_target_computers()" + ) + agent_os_target_computer_reprs = self._manager.describe() + self._reporter.add_message( + self._REPORTER_SOURCE, + "describe_agent_os_target_computers() -> " + f"{agent_os_target_computer_reprs!r}", + ) + return agent_os_target_computer_reprs + + @telemetry.record_call() + @override + def get_current_computer_target_id(self, report: bool = True) -> str: + """Return the `computer_id` of the currently active Agent OS target computer.""" + if report: + self._reporter.add_message( + self._REPORTER_SOURCE, "get_current_computer_target_id()" + ) + computer_id = self._require_active_agent_os_target_computer().computer_id + if report: + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_current_computer_target_id() -> {computer_id!r}", + ) + return computer_id + + @telemetry.record_call() + @override + def switch_agent_os_target_computer(self, computer_id: str) -> ComputerTarget: + """ + Switch the active target computer by its `computer_id` (the user-supplied + identifier; defaults to the target's `session_guid` when none was supplied + at construction time). + + Connections to all registered targets stay open across switches; this just + changes which connection routes future agent-os actions. If the target was + added after `connect()` and isn't connected yet, it is connected on switch. + + Args: + computer_id (str): The computer id of the target to switch to. + + Returns: + ComputerTarget: The newly active target computer. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"switch_agent_os_target_computer({computer_id!r})" + ) + agent_os_target_computer = self._manager.switch(computer_id) + self._reporter.add_message( + self._REPORTER_SOURCE, + ( + f"switch_agent_os_target_computer({computer_id!r}) -> " + f"{agent_os_target_computer!r}" + ), + ) + return agent_os_target_computer + + @contextmanager + @override + def temporary_select(self, computer_id: str) -> Iterator[Self]: + previous = self._manager.active + self._reporter.add_message( + self._REPORTER_SOURCE, + f"temporary_select({computer_id!r}) [previous={previous!r}]", + ) + self.switch_agent_os_target_computer(computer_id) + try: + yield self + finally: + if previous is not None and previous.computer_id != computer_id: + self.switch_agent_os_target_computer(previous.computer_id) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"temporary_select({computer_id!r}) -> restored", + ) + + @telemetry.record_call() + @override + def connect(self) -> None: + """ + Open a gRPC channel and session to every registered target computer via + the underlying `ComputerTargetPool`. + """ + self._manager.connect() + + def _get_stub(self) -> controller_v1.ControllerAPIStub: + return self._manager.active_connection().stub + + def _run_recorder_action( + self, + acion_class_id: controller_v1_pbs.ActionClassID, + action_parameters: controller_v1_pbs.ActionParameters, + ) -> controller_v1_pbs.Response_RunRecordedAction: + time.sleep(self._pre_action_wait) + response: controller_v1_pbs.Response_RunRecordedAction = ( + self._get_stub().RunRecordedAction( + controller_v1_pbs.Request_RunRecordedAction( + sessionInfo=self._session_info, + actionClassID=acion_class_id, + actionParameters=action_parameters, + ) + ) + ) + + time.sleep((response.requiredMilliseconds / 1000)) + num_retries = 0 + for _ in range(self._max_retries): + poll_response: controller_v1_pbs.Response_Poll = self._get_stub().Poll( + controller_v1_pbs.Request_Poll( + sessionInfo=self._session_info, + pollEventID=controller_v1_pbs.PollEventID.PollEventID_ActionFinished, + ) + ) + if ( + poll_response.pollEventParameters.actionFinished.actionID + == response.actionID + ): + break + time.sleep(self._post_action_wait) + num_retries += 1 + if num_retries == self._max_retries - 1: + agent_os_target_computer = self._require_active_agent_os_target_computer() + timeout_seconds = self._max_retries * self._post_action_wait + timeout_msg = ( + f"Action did not finish on target computer " + f"{agent_os_target_computer.description!r} " + f"(session_guid={agent_os_target_computer.session_guid}) within " + f"{timeout_seconds:.2f}s ({self._max_retries} polls of " + f"{self._post_action_wait:.2f}s). " + f"Action class id: {acion_class_id}." + ) + raise AskUiControllerOperationTimeoutError( + message=timeout_msg, timeout_seconds=timeout_seconds + ) + return response + + @telemetry.record_call() + @override + def disconnect(self) -> None: + """ + Close every open target-computer connection via the underlying + `ComputerTargetPool`. + """ + self._manager.disconnect() + + @telemetry.record_call() + def __enter__(self) -> Self: + """ + Context manager entry point that establishes the connection. + + Returns: + Self: The instance of MultiComputerTargetAgentOS. + """ + self.connect() + return self + + @telemetry.record_call(exclude={"exc_value", "traceback"}) + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + """ + Context manager exit point that disconnects the client. + + Args: + exc_type: The exception type if an exception was raised. + exc_value: The exception value if an exception was raised. + traceback: The traceback if an exception was raised. + """ + self.disconnect() + + @telemetry.record_call() + @override + def screenshot(self, report: bool = True, unscaled: bool = False) -> Image.Image: + """ + Take a screenshot of the current screen. + + Args: + report (bool, optional): Whether to include the screenshot in reporting. + Defaults to `True`. + unscaled (bool, optional): Accepted for interface compatibility. This + client always returns the native screen resolution, so it has no + effect. Defaults to `False`. + + Returns: + Image.Image: A PIL Image object containing the screenshot. + + """ + screenResponse = self._get_stub().CaptureScreen( + controller_v1_pbs.Request_CaptureScreen( + sessionInfo=self._session_info, + captureParameters=controller_v1_pbs.CaptureParameters( + displayID=self._require_active_agent_os_target_computer().display + ), + ) + ) + r, g, b, _ = Image.frombytes( + "RGBA", + (screenResponse.bitmap.width, screenResponse.bitmap.height), + screenResponse.bitmap.data, + ).split() + image = Image.merge("RGB", (b, g, r)) + if report: + self._reporter.add_message(self._REPORTER_SOURCE, "screenshot()", image) + return image + + @telemetry.record_call() + @override + def mouse_move(self, x: int, y: int, duration: int = 500) -> None: + """ + Moves the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + duration (int): The duration (in ms) the movement should take. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f"mouse_move({x}, {y}, duration={duration})", + AnnotatedImage(lambda: self.screenshot(report=False), point_list=[(x, y)]), + ) + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseMove, + action_parameters=controller_v1_pbs.ActionParameters( + mouseMove=controller_v1_pbs.ActionParameters_MouseMove( + position=controller_v1_pbs.Coordinate2(x=x, y=y), + milliseconds=duration, + ) + ), + ) + + @telemetry.record_call(exclude={"text"}) + @override + def type(self, text: str, typing_speed: int = 50) -> None: + """ + Type text at current cursor position as if entered on a keyboard. + + Args: + text (str): The text to type. + typing_speed (int, optional): The speed of typing in characters per second. + Defaults to `50`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'type("{text}", {typing_speed})' + ) + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardType_UnicodeText, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardTypeUnicodeText=controller_v1_pbs.ActionParameters_KeyboardType_UnicodeText( + text=text.encode("utf-16-le"), + typingSpeed=typing_speed, + typingSpeedValue=controller_v1_pbs.TypingSpeedValue.TypingSpeedValue_CharactersPerSecond, + ) + ), + ) + + @telemetry.record_call() + @override + def click( + self, button: Literal["left", "middle", "right"] = "left", count: int = 1 + ) -> None: + """ + Click a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f'click("{button}", {count})') + mouse_button = None + match button: + case "left": + mouse_button = controller_v1_pbs.MouseButton_Left + case "middle": + mouse_button = controller_v1_pbs.MouseButton_Middle + case "right": + mouse_button = controller_v1_pbs.MouseButton_Right + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_PressAndRelease, + action_parameters=controller_v1_pbs.ActionParameters( + mouseButtonPressAndRelease=controller_v1_pbs.ActionParameters_MouseButton_PressAndRelease( + mouseButton=mouse_button, count=count + ) + ), + ) + + @telemetry.record_call() + @override + def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Press and hold a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + press. Defaults to `"left"`. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f'mouse_down("{button}")') + mouse_button = None + match button: + case "left": + mouse_button = controller_v1_pbs.MouseButton_Left + case "middle": + mouse_button = controller_v1_pbs.MouseButton_Middle + case "right": + mouse_button = controller_v1_pbs.MouseButton_Right + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Press, + action_parameters=controller_v1_pbs.ActionParameters( + mouseButtonPress=controller_v1_pbs.ActionParameters_MouseButton_Press( + mouseButton=mouse_button + ) + ), + ) + + @telemetry.record_call() + @override + def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Release a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + release. Defaults to `"left"`. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f'mouse_up("{button}")') + mouse_button = None + match button: + case "left": + mouse_button = controller_v1_pbs.MouseButton_Left + case "middle": + mouse_button = controller_v1_pbs.MouseButton_Middle + case "right": + mouse_button = controller_v1_pbs.MouseButton_Right + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Release, + action_parameters=controller_v1_pbs.ActionParameters( + mouseButtonRelease=controller_v1_pbs.ActionParameters_MouseButton_Release( + mouseButton=mouse_button + ) + ), + ) + + @telemetry.record_call() + @override + def mouse_scroll(self, dx: int, dy: int) -> None: + """ + Scroll the mouse wheel. + + Args: + dx (int): The horizontal scroll amount. Positive values scroll right, + negative values scroll left. + dy (int): The vertical scroll amount. Positive values scroll down, + negative values scroll up. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"mouse_scroll({dx}, {dy})") + if dx != 0: + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, + action_parameters=controller_v1_pbs.ActionParameters( + mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( + direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Horizontal, + deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, + delta=dx, + milliseconds=50, + ) + ), + ) + if dy != 0: + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, + action_parameters=controller_v1_pbs.ActionParameters( + mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( + direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Vertical, + deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, + delta=dy, + milliseconds=50, + ) + ), + ) + + @telemetry.record_call() + @override + def keyboard_pressed( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Press and hold a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to press. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'keyboard_pressed("{key}", {modifier_keys})' + ) + if modifier_keys is None: + modifier_keys = [] + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Press, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyPress=controller_v1_pbs.ActionParameters_KeyboardKey_Press( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) + + @telemetry.record_call() + @override + def keyboard_release( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to release. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + release along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'keyboard_release("{key}", {modifier_keys})' + ) + if modifier_keys is None: + modifier_keys = [] + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Release, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyRelease=controller_v1_pbs.ActionParameters_KeyboardKey_Release( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) + + @telemetry.record_call() + @override + def keyboard_tap( + self, + key: PcKey | ModifierKey, + modifier_keys: list[ModifierKey] | None = None, + count: int = 1, + ) -> None: + """ + Press and immediately release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to tap. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + count (int, optional): The number of times to tap the key. Defaults to `1`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f'keyboard_tap("{key}", {modifier_keys}, {count})', + ) + if modifier_keys is None: + modifier_keys = [] + for _ in range(count): + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) + + @telemetry.record_call() + @override + def set_display(self, display: int = 1) -> None: + """ + Set the active display. + + Args: + display (int, optional): The display ID to set as active. + This can be either a real display ID or a virtual display ID. + Defaults to `1`. + """ + self._get_stub().SetActiveDisplay( + controller_v1_pbs.Request_SetActiveDisplay(displayID=display) + ) + self._require_active_agent_os_target_computer().display = display + self._reporter.add_message(self._REPORTER_SOURCE, f"set_display({display})") + + @telemetry.record_call(exclude={"command"}) + @override + def run_command(self, command: str, timeout_ms: int = 30000) -> None: + """ + Execute a shell command. + + Args: + command (str): The command to execute. + timeout_ms (int, optional): The timeout for command + execution in milliseconds. Defaults to `30000` (30 seconds). + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'run_command("{command}", {timeout_ms})' + ) + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_RunCommand, + action_parameters=controller_v1_pbs.ActionParameters( + runcommand=controller_v1_pbs.ActionParameters_RunCommand( + command=command, timeoutInMilliseconds=timeout_ms + ) + ), + ) + + @telemetry.record_call() + @override + def retrieve_active_display(self) -> Display: + """ + Retrieve the currently active display/screen. + + Returns: + Display: The currently active display/screen. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "retrieve_active_display()") + agent_os_target_computer = self._require_active_agent_os_target_computer() + active_display_id = agent_os_target_computer.display + displays_list_response = self.list_displays() + for display in displays_list_response.data: + if display.id == active_display_id: + self._reporter.add_message( + self._REPORTER_SOURCE, f"retrieve_active_display() -> {display}" + ) + return display + available_ids = ( + ", ".join(str(d.id) for d in displays_list_response.data) or "none" + ) + error_msg = ( + f"Display {active_display_id} not found on target computer " + f"{agent_os_target_computer.description!r} " + f"(session_guid={agent_os_target_computer.session_guid}). " + f"Available display ids: {available_ids}. " + "Call `set_display()` with a valid id, or `list_displays()` to inspect." + ) + raise ValueError(error_msg) + + @telemetry.record_call() + @override + def list_displays( + self, + ) -> DisplaysListResponse: + """ + List all available Displays from the controller. + It includes both real and virtual displays + without describing the type of display (virtual or real). + + Returns: + DisplaysListResponse + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "list_displays()") + + response: controller_v1_pbs.Response_GetDisplayInformation = ( + self._get_stub().GetDisplayInformation(controller_v1_pbs.Request_Void()) + ) + + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + + displays = DisplaysListResponse.model_validate(response_dict) + + self._reporter.add_message( + self._REPORTER_SOURCE, f"list_displays() ->{str(displays)}" + ) + + return displays + + @telemetry.record_call() + def get_process_list( + self, get_extended_info: bool = False + ) -> controller_v1_pbs.Response_GetProcessList: + """ + Get a list of running processes. + + Args: + get_extended_info (bool, optional): Whether to include + extended process information. + Defaults to `False`. + + Returns: + controller_v1_pbs.Response_GetProcessList: Process list response containing: + - processes: List of ProcessInfo objects + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_process_list({get_extended_info})" + ) + + response: controller_v1_pbs.Response_GetProcessList = ( + self._get_stub().GetProcessList( + controller_v1_pbs.Request_GetProcessList( + getExtendedInfo=get_extended_info + ) + ) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_process_list({get_extended_info}) -> {response}", + ) + + return response + + @telemetry.record_call() + def get_window_list( + self, process_id: int + ) -> controller_v1_pbs.Response_GetWindowList: + """ + Get a list of windows for a specific process. + + Args: + process_id (int): The ID of the process to get windows for. + + Returns: + controller_v1_pbs.Response_GetWindowList: Window list response containing: + - windows: List of WindowInfo objects with ID and name + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_window_list({process_id})" + ) + + response: controller_v1_pbs.Response_GetWindowList = ( + self._get_stub().GetWindowList( + controller_v1_pbs.Request_GetWindowList(processID=process_id) + ) + ) + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_window_list({process_id}) -> {response}" + ) + + return response + + @telemetry.record_call() + def get_automation_target_list( + self, + ) -> controller_v1_pbs.Response_GetAutomationTargetList: + """ + Get a list of available automation targets. + + Returns: + controller_v1_pbs.Response_GetAutomationTargetList: + Automation target list response: + - targets: List of AutomationTarget objects + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, "get_automation_target_list()" + ) + + response: controller_v1_pbs.Response_GetAutomationTargetList = ( + self._get_stub().GetAutomationTargetList(controller_v1_pbs.Request_Void()) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_automation_target_list() -> {response}" + ) + + return response + + @telemetry.record_call() + def set_mouse_delay(self, delay_ms: int) -> None: + """ + Configure mouse action delay. + + Args: + delay_ms (int): The delay in milliseconds to set for mouse actions. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_mouse_delay({delay_ms})" + ) + + self._get_stub().SetMouseDelay( + controller_v1_pbs.Request_SetMouseDelay( + sessionInfo=self._session_info, delayInMilliseconds=delay_ms + ) + ) + + @telemetry.record_call() + def set_keyboard_delay(self, delay_ms: int) -> None: + """ + Configure keyboard action delay. + + Args: + delay_ms (int): The delay in milliseconds to set for keyboard actions. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_keyboard_delay({delay_ms})" + ) + + self._get_stub().SetKeyboardDelay( + controller_v1_pbs.Request_SetKeyboardDelay( + sessionInfo=self._session_info, delayInMilliseconds=delay_ms + ) + ) + + @telemetry.record_call() + def set_active_window(self, process_id: int, window_id: int) -> int: + """ + Set the active window for automation. + Adds the window as a virtual display and returns the display ID. + It raises an error if display length is not increased after adding the window. + + Args: + process_id (int): The ID of the process that owns the window. + window_id (int): The ID of the window to set as active. + + returns: + int: The new Display ID. + Raises: + AskUiControllerError: + If display length is not increased after adding the window. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_window({process_id}, {window_id})" + ) + + display_length_before_adding_window = len(self.list_displays().data) + + self._get_stub().SetActiveWindow( + controller_v1_pbs.Request_SetActiveWindow( + processID=process_id, windowID=window_id + ) + ) + new_display_length = len(self.list_displays().data) + if new_display_length <= display_length_before_adding_window: + msg = ( + f"Failed to add window {window_id} of process {process_id} as a " + f"virtual display: display count did not increase " + f"({display_length_before_adding_window} -> {new_display_length}). " + "Verify the process and window ids exist and are valid for the " + "active target computer." + ) + raise AskUiControllerError(msg) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"set_active_window({process_id}, {window_id}) -> {new_display_length}", + ) + return new_display_length + + @telemetry.record_call() + def set_active_automation_target(self, target_id: int) -> None: + """ + Set the active automation target. + + Args: + target_id (int): The ID of the automation target to set as active. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_automation_target({target_id})" + ) + + self._get_stub().SetActiveAutomationTarget( + controller_v1_pbs.Request_SetActiveAutomationTarget(ID=target_id) + ) + + @telemetry.record_call() + def schedule_batched_action( + self, + action_class_id: controller_v1_pbs.ActionClassID, + action_parameters: controller_v1_pbs.ActionParameters, + ) -> controller_v1_pbs.Response_ScheduleBatchedAction: + """ + Schedule an action for batch execution. + + Args: + action_class_id (controller_v1_pbs.ActionClassID): The class ID + of the action to schedule. + action_parameters (controller_v1_pbs.ActionParameters): + Parameters for the action. + + Returns: + controller_v1_pbs.Response_ScheduleBatchedAction: Response containing + the scheduled action ID. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, + f"schedule_batched_action({action_class_id}, {action_parameters})", + ) + + response: controller_v1_pbs.Response_ScheduleBatchedAction = ( + self._get_stub().ScheduleBatchedAction( + controller_v1_pbs.Request_ScheduleBatchedAction( + sessionInfo=self._session_info, + actionClassID=action_class_id, + actionParameters=action_parameters, + ) + ) + ) + + return response + + @telemetry.record_call() + def start_batch_run(self) -> None: + """ + Start executing batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "start_batch_run()") + + self._get_stub().StartBatchRun( + controller_v1_pbs.Request_StartBatchRun(sessionInfo=self._session_info) + ) + + @telemetry.record_call() + def stop_batch_run(self) -> None: + """ + Stop executing batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "stop_batch_run()") + + self._get_stub().StopBatchRun( + controller_v1_pbs.Request_StopBatchRun(sessionInfo=self._session_info) + ) + + @telemetry.record_call() + def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: + """ + Get the count of recorded or batched actions. + + Returns: + controller_v1_pbs.Response_GetActionCount: Response + containing the action count. + """ + + response: controller_v1_pbs.Response_GetActionCount = ( + self._get_stub().GetActionCount( + controller_v1_pbs.Request_GetActionCount(sessionInfo=self._session_info) + ) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_action_count() -> {response}" + ) + return response + + @telemetry.record_call() + def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: + """ + Get a specific action by its index. + + Args: + action_index (int): The index of the action to retrieve. + + Returns: + controller_v1_pbs.Response_GetAction: Action information containing: + - actionID: The action ID + - actionClassID: The action class ID + - actionParameters: The action parameters + """ + + self._reporter.add_message(self._REPORTER_SOURCE, f"get_action({action_index})") + + response: controller_v1_pbs.Response_GetAction = self._get_stub().GetAction( + controller_v1_pbs.Request_GetAction( + sessionInfo=self._session_info, actionIndex=action_index + ) + ) + + return response + + @telemetry.record_call() + def remove_action(self, action_id: int) -> None: + """ + Remove a specific action by its ID. + + Args: + action_id (int): The ID of the action to remove. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, f"remove_action({action_id})") + + self._get_stub().RemoveAction( + controller_v1_pbs.Request_RemoveAction( + sessionInfo=self._session_info, actionID=action_id + ) + ) + + @telemetry.record_call() + def remove_all_actions(self) -> None: + """ + Clear all recorded or batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "remove_all_actions()") + + self._get_stub().RemoveAllActions( + controller_v1_pbs.Request_RemoveAllActions(sessionInfo=self._session_info) + ) + + def _send_command(self, command: Command) -> AskUIAgentOSSendResponseSchema: + """ + Send a general command to the controller. + + Args: + command (Command): The command to send to the controller. + + Returns: + AskUIAgentOSSendResponseSchema: Response containing + the message from the controller. + + Raises: + AskUiControllerInvalidCommandError: If the command fails schema validation + on the target computer side. + """ + + agent_os_target_computer = self._require_active_agent_os_target_computer() + header = Header(authentication=Guid(root=agent_os_target_computer.session_guid)) + message = Message(header=header, command=command) + + request = AskUIAgentOSSendRequestSchema(message=message) + + request_str = request.model_dump_json(exclude_none=True, by_alias=True) + + try: + response: controller_v1_pbs.Response_Send = self._get_stub().Send( + controller_v1_pbs.Request_Send(message=request_str) + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.INVALID_ARGUMENT: + details = e.details() or None + raise AskUiControllerInvalidCommandError(details) from e + raise + + return AskUIAgentOSSendResponseSchema.model_validate_json(response.message) + + @telemetry.record_call() + def get_mouse_position(self) -> Coordinate: + """ + Get the mouse cursor position + + Returns: + Coordinate: Response containing the result of the mouse position change. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_mouse_position()") + res = self._send_command(GetMousePositionCommand()) + coordinate = Coordinate( + x=res.message.command.response.position.x.root, # type: ignore[union-attr] + y=res.message.command.response.position.y.root, # type: ignore[union-attr] + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_mouse_position() -> {coordinate}" + ) + return coordinate + + @telemetry.record_call() + def set_mouse_position(self, x: int, y: int) -> None: + """ + Set the mouse cursor position to specific coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to set the cursor to. + y (int): The vertical coordinate (in pixels) to set the cursor to. + """ + location = Location(x=Length(root=x), y=Length(root=y)) + command = SetMousePositionCommand(parameters=[location]) + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_mouse_position({x},{y})" + ) + self._send_command(command) + + @telemetry.record_call() + def render_quad(self, style: RenderObjectStyle) -> int: + """ + Render a quad object to the display. + + Args: + style (RenderObjectStyle): The style properties for the quad. + + Returns: + int: Object ID. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"render_quad({style})") + command = AddRenderObjectCommand(parameters=["Quad", style]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: + """ + Render a line object to the display. + + Args: + style (RenderObjectStyle): The style properties for the line. + points (list[Coordinates]): The points defining the line. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_line({style}, {points})" + ) + command = AddRenderObjectCommand(parameters=["Line", style, points]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call(exclude={"image_data"}) + def render_image(self, style: RenderObjectStyle, image_data: str) -> int: + """ + Render an image object to the display. + + Args: + style (RenderObjectStyle): The style properties for the image. + image_data (str): The base64-encoded image data. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_image({style}, [image_data])" + ) + image = RenderImage(root=image_data) + command = AddRenderObjectCommand(parameters=["Image", style, image]) + res = self._send_command(command) + + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def render_text(self, style: RenderObjectStyle, content: str) -> int: + """ + Render a text object to the display. + + Args: + style (RenderObjectStyle): The style properties for the text. + content (str): The text content to display. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_text({style}, {content})" + ) + text = RenderText(root=content) + command = AddRenderObjectCommand(parameters=["Text", style, text]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def update_render_object(self, object_id: int, style: RenderObjectStyle) -> None: + """ + Update styling properties of an existing render object. + + Args: + object_id (float): The ID of the render object to update. + style (RenderObjectStyle): The new style properties. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"update_render_object({object_id}, {style})" + ) + render_object_id = RenderObjectId(root=object_id) + command = UpdateRenderObjectCommand(parameters=[render_object_id, style]) + self._send_command(command) + + @telemetry.record_call() + def delete_render_object(self, object_id: int) -> None: + """ + Delete an existing render object from the display. + + Args: + object_id (RenderObjectId): The ID of the render object to delete. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"delete_render_object({object_id})" + ) + render_object_id = RenderObjectId(root=object_id) + command = DeleteRenderObjectCommand(parameters=[render_object_id]) + self._send_command(command) + + @telemetry.record_call() + def clear_render_objects(self) -> None: + """ + Clear all render objects from the display. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "clear_render_objects()") + command = ClearRenderObjectsCommand() + self._send_command(command) + + def get_system_info(self) -> GetSystemInfoResponseModel: + """ + Get the system information. + + Returns: + SystemInfo: The system information. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_system_info()") + command = GetSystemInfoCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetSystemInfoResponse): + message = ( + f"get_system_info: expected GetSystemInfoResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_system_info() -> {res.response}" + ) + return res.response + + def get_active_process(self) -> GetActiveProcessResponseModel: + """ + Get the active process. + + Returns: + GetActiveProcessResponseModel: The active process. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_active_process()") + command = GetActiveProcessCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetActiveProcessResponse): + message = ( + f"get_active_process: expected GetActiveProcessResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_active_process() -> {res.response}" + ) + return res.response + + def set_active_process(self, process_id: int) -> None: + """ + Set the active process. + + Args: + process_id (int): The ID of the process to set as active. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_process({process_id})" + ) + _process_id = Parameter3(root=process_id) + command = SetActiveProcessCommand(parameters=[_process_id]) + self._send_command(command) + + def get_active_window(self) -> GetActiveWindowResponseModel: + """ + Gets the window id and name in addition to the process id + and name of the currently active window (in focus). + + + Returns: + GetActiveWindowResponseModel: The active window. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_active_window()") + command = GetActiveWindowCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetActiveWindowResponse): + message = ( + f"get_active_window: expected GetActiveWindowResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_active_window() -> {res.response}" + ) + return res.response + + def set_window_in_focus(self, process_id: int, window_id: int) -> None: + """ + Sets the window with the specified windowId of the process + with the specified processId active, + which brings it to the front and gives it focus. + + Args: + process_id (int): The ID of the process that owns the window. + window_id (int): The ID of the window to set as active. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_window_in_focus({process_id}, {window_id})" + ) + _process_id = Parameter3(root=process_id) + _window_id = Parameter3(root=window_id) + command = SetActiveWindowCommand(parameters=[_process_id, _window_id]) + self._send_command(command) + + @telemetry.record_call() + @override + def get_file_names(self, absolute_directory_path: str) -> list[str]: + """ + Get the file names in the given absolute directory on the device under + automation. + + Args: + absolute_directory_path (str): The absolute directory path to list + file names from. + + Returns: + list[str]: The file names returned by the controller. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_file_names({absolute_directory_path})" + ) + command = GetFileNamesCommand(parameters=[absolute_directory_path]) + res = self._send_command(command).message.command + if not isinstance(res, GetFileNamesResponse): + message = f"unexpected response type: {res}" + raise DesktopAgentOsError(message) + if res.error is not None: + raise DesktopAgentOsError(res.error) + if res.response is None: + message = f"{type(res).__name__} is missing both error and response" + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_file_names({absolute_directory_path}) -> {res.response}", + ) + return res.response.fileNames + + @telemetry.record_call() + @override + def get_file(self, path: str) -> Image.Image | str: + """ + Get the contents of a file at the given path on the device under + automation. + + The controller returns the file as a Base64-encoded string, which is + decoded and returned as `PIL.Image.Image` when the bytes can be opened + as an image (PNG, JPEG, BMP, GIF, WebP, TIFF, ...), or as `str` when + they decode cleanly as UTF-8 text. + + Args: + path (str): The file path to read on the device under automation. + + Returns: + Image.Image | str: The decoded file contents. + + Raises: + DesktopAgentOsError: If the file cannot be read or the response is invalid. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"get_file({path})") + command = GetFileCommand(parameters=[path]) + res = self._send_command(command).message.command + if not isinstance(res, GetFileResponse): + message = f"unexpected response type: {res}" + raise DesktopAgentOsError(message) + if res.error is not None: + raise DesktopAgentOsError(res.error) + if res.response is None: + message = f"{type(res).__name__} is missing both error and response" + raise DesktopAgentOsError(message) + decoded = self._decode_file_payload(res.response.file.content) + if isinstance(decoded, Image.Image): + detail = f"image ({decoded.format}, {decoded.size[0]}x{decoded.size[1]})" + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_file({path}) -> {detail}", decoded + ) + return decoded + + detail = f"text ({len(decoded)} chars)" + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_file({path}) -> {detail}" + ) + return decoded + + @telemetry.record_call() + @override + def remove_virtual_displays(self) -> None: + """ + Remove all virtual displays from the controller, leaving only real + displays active. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "remove_virtual_displays()") + command = RemoveVirtualDisplaysCommand() + self._send_command(command) + self._reporter.add_message( + self._REPORTER_SOURCE, "remove_virtual_displays() -> done" + ) + + @staticmethod + def _decode_file_payload(base64_data: str) -> Image.Image | str: + try: + return base64_to_image(base64_data) + except ValueError: + pass + data = base64.b64decode(base64_data, validate=True) + if b"\x00" not in data: + try: + return data.decode("utf-8") + except UnicodeDecodeError: + pass + message = "File contents are neither a supported image nor UTF-8 text" + raise DesktopAgentOsError(message) + + +AskUiControllerClient = MultiComputerTargetAgentOS diff --git a/src/askui/tools/askui/askui_controller_client_settings.py b/src/askui/tools/askui/askui_controller_client_settings.py deleted file mode 100644 index 6e53b747..00000000 --- a/src/askui/tools/askui/askui_controller_client_settings.py +++ /dev/null @@ -1,34 +0,0 @@ -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class AskUiControllerClientSettings(BaseSettings): - """ - Settings for the AskUI Remote Device Controller client. - """ - - model_config = SettingsConfigDict( - env_prefix="ASKUI_CONTROLLER_CLIENT_", - ) - - server_address: str = Field( - default="localhost:23000", - description="Address of the AskUI Remote Device Controller server.", - ) - - server_autostart: bool = Field( - default=True, - description="Whether to automatically start the AskUI Remote Device" - "Controller server. Defaults to True.", - ) - - clean_virtual_displays: bool = Field( - default=False, - description=( - "Whether to clean virtual displays after the controller is started." - "Default: False" - ), - ) - - -__all__ = ["AskUiControllerClientSettings"] diff --git a/src/askui/tools/askui/computer_target_connection.py b/src/askui/tools/askui/computer_target_connection.py new file mode 100644 index 00000000..54b8e225 --- /dev/null +++ b/src/askui/tools/askui/computer_target_connection.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import grpc + +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2 as controller_v1_pbs, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2_grpc as controller_v1, +) +from askui.tools.askui.exceptions import AskUiControllerError + +if TYPE_CHECKING: + from askui.tools.askui.agent_os_target_computer import ComputerTarget + +logger = logging.getLogger(__name__) + + +@dataclass +class ComputerTargetConnection: + """ + The live gRPC connection to a `ComputerTarget`: the open channel, the + controller stub bound to it, and the session opened on the target computer. + + Holds only the live connection handles; the `ComputerTarget` it belongs to + is passed in when opening or closing. Encapsulates all gRPC specifics so + that `ComputerTarget` and `ComputerTargetPool` stay free of channel / stub / + session details. + + Args: + channel (grpc.Channel): The open gRPC channel. + stub (ControllerAPIStub): The controller API stub bound to `channel`. + session_info (SessionInfo): The session opened on the target computer. + """ + + channel: grpc.Channel + stub: controller_v1.ControllerAPIStub + session_info: controller_v1_pbs.SessionInfo + + @classmethod + def open(cls, target: ComputerTarget) -> ComputerTargetConnection: + """ + Open a gRPC channel and session to `target`. + + Starts the target's local controller process first (a no-op for remote + and service-managed targets), opens an insecure gRPC channel, starts a + session, starts execution, and sets the configured display. + + On failure during session setup, the channel is closed and any started + process is stopped before re-raising. + """ + target.start() + channel = grpc.insecure_channel( + target.address, + options=[ + ("grpc.max_send_message_length", 2**30), + ("grpc.max_receive_message_length", 2**30), + ("grpc.default_deadline", 300000), + ], + ) + stub = controller_v1.ControllerAPIStub(channel) + try: + session_response: controller_v1_pbs.Response_StartSession = ( + stub.StartSession( + controller_v1_pbs.Request_StartSession( + sessionGUID=target.session_guid, + immediateExecution=True, + ) + ) + ) + session_info = session_response.sessionInfo + stub.StartExecution( + controller_v1_pbs.Request_StartExecution(sessionInfo=session_info) + ) + stub.SetActiveDisplay( + controller_v1_pbs.Request_SetActiveDisplay(displayID=target.display) + ) + except Exception as e: + try: + channel.close() + finally: + target.stop() + error_msg = ( + f"Failed to connect to Agent OS target computer " + f"{target.description!r} " + f"(computer_id={target.computer_id!r}, " + f"session_guid={target.session_guid}, " + f"display={target.display}, " + f"address={target.address}): {e}" + ) + raise AskUiControllerError(error_msg) from e + return cls(channel=channel, stub=stub, session_info=session_info) + + def close(self, target: ComputerTarget) -> None: + """ + Close this connection to `target`. + + Stops execution, ends the session, closes the gRPC channel, and stops + the target's local controller process (a no-op unless this client + started one). Errors are logged but never raised, so a partial failure + still releases the rest of the connection. + """ + computer_id = target.computer_id + try: + self.stub.StopExecution( + controller_v1_pbs.Request_StopExecution(sessionInfo=self.session_info) + ) + self.stub.EndSession( + controller_v1_pbs.Request_EndSession(sessionInfo=self.session_info) + ) + except Exception: # noqa: BLE001 + logger.exception( + "Error stopping execution/session for controller %s", computer_id + ) + try: + self.channel.close() + except Exception: # noqa: BLE001 + logger.exception("Error closing channel for controller %s", computer_id) + try: + target.stop() + except Exception: # noqa: BLE001 + logger.exception( + "Error stopping client-started controller process for %s", computer_id + ) + + +__all__ = ["ComputerTargetConnection"] diff --git a/src/askui/tools/askui/computer_target_pool.py b/src/askui/tools/askui/computer_target_pool.py new file mode 100644 index 00000000..7c5ce23d --- /dev/null +++ b/src/askui/tools/askui/computer_target_pool.py @@ -0,0 +1,281 @@ +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, +) +from askui.tools.askui.computer_target_connection import ComputerTargetConnection +from askui.tools.askui.exceptions import AskUiControllerError + + +class ComputerTargetPool: + """ + Manages a collection of `ComputerTarget` instances and tracks the currently + active one. Each target owns its own gRPC connection + (`ComputerTarget.connection`); the pool only orchestrates connecting / + disconnecting them and selecting the active one. + + Responsibilities: + - Register / unregister `ComputerTarget` instances with uniqueness + constraints (at most one local, unique computer ids / session GUIDs, + unique remote addresses). + - Drive `connect()` / `disconnect()` on registered targets (individually + or all at once). + - Track which registered target is currently active and expose its + connection needed to route agent-os actions to it. + + The first target added becomes active by default. Use `switch` to change + which target is active. `connect` opens connections to every registered + target; subsequently `add` / `switch` auto-connect any + newly-introduced target whenever the manager already holds at least one + open connection. + + Targets are addressed exclusively by their `computer_id`. + + Args: + agent_os_target_computers (list[ComputerTarget] | None, optional): + Initial targets to register. + """ + + def __init__( + self, + agent_os_target_computers: list[ComputerTarget] | None = None, + ) -> None: + # Single store. Python dicts preserve insertion order, so this also + # defines `list()` order and the first-added-is-active semantics. Each + # target owns its own connection, so no separate connection store is + # needed here. + self._by_computer_id: dict[str, ComputerTarget] = {} + self._active_computer_id: str | None = None + if agent_os_target_computers: + for target in agent_os_target_computers: + self.add(target) + + @property + def is_connected(self) -> bool: + """`True` when at least one registered target has an open connection.""" + return any(t.is_connected for t in self._by_computer_id.values()) + + def add(self, target: ComputerTarget) -> ComputerTarget: + """ + Register an Agent OS target computer. Auto-connects when the manager + already has at least one open connection. + + Args: + target (ComputerTarget): The target computer to register. + + Returns: + ComputerTarget: The registered target. + + Raises: + ValueError: If another local target is already registered, the same + session GUID or computer id is already registered, or another + remote target with the same address is already registered. + """ + self._validate_addable(target) + self._by_computer_id[target.computer_id] = target + if self._active_computer_id is None: + self._active_computer_id = target.computer_id + if self.is_connected: + self.connect_target(target) + return target + + def reset(self) -> None: + """Disconnect every open connection and remove all registered targets.""" + self.disconnect() + self._by_computer_id.clear() + self._active_computer_id = None + + def remove(self, computer_id: str) -> None: + """ + Remove a registered target by its `computer_id`. If the target was + connected, its connection is closed first. + + Args: + computer_id (str): The computer id of the target to remove. + + Raises: + KeyError: If no target with the given computer id is registered. + """ + self._require(computer_id) + self.disconnect_target(computer_id) + del self._by_computer_id[computer_id] + if self._active_computer_id == computer_id: + self._active_computer_id = next(iter(self._by_computer_id), None) + + def describe(self) -> list[str]: + """ + Return the `repr()` of every registered target, in registration order. + """ + return [repr(target) for target in self._by_computer_id.values()] + + def get(self, computer_id: str) -> ComputerTarget: + """ + Return the registered target with the given `computer_id`. + + Raises: + KeyError: If no target with the given computer id is registered. + """ + return self._require(computer_id) + + def switch(self, computer_id: str) -> ComputerTarget: + """ + Set the active target by its `computer_id`. Auto-connects the new + active target when the manager already has at least one open connection + but this target is not yet connected. + + Args: + computer_id (str): The computer id of the target to activate. + + Returns: + ComputerTarget: The newly active target. + + Raises: + KeyError: If no target with the given computer id is registered. + """ + target = self._require(computer_id) + self._active_computer_id = computer_id + if self.is_connected and not target.is_connected: + self.connect_target(target) + return target + + @property + def active(self) -> ComputerTarget | None: + """The currently active target, or `None` if no targets are registered.""" + if self._active_computer_id is None: + return None + return self._by_computer_id.get(self._active_computer_id) + + def require_active(self) -> ComputerTarget: + """ + Return the currently active target. + + Raises: + AskUiControllerError: If no target is currently active. + """ + target = self.active + if target is None: + error_msg = ( + "No active Agent OS target computer. Register one via " + "`MultiComputerTargetAgentOS.add_agent_os_target_computer()`, or " + "pass `agent_os_target_computers` to the " + "`MultiComputerTargetAgentOS` constructor." + ) + raise AskUiControllerError(error_msg) + return target + + def active_connection(self) -> ComputerTargetConnection: + """ + Return the gRPC connection for the currently active target. + + Raises: + AskUiControllerError: If no target is currently active or the active + target has no open connection (i.e. `connect()` has not been + called). + """ + return self.require_active().connection + + def connect(self) -> None: + """ + Open the connection to every registered Agent OS target via + `ComputerTarget.connect()`. Targets already connected are skipped, so + calling `connect()` twice is safe. + + Raises: + AskUiControllerError: If no targets are registered. + + On failure mid-loop, all targets connected so far are rolled back via + `disconnect()` before re-raising. + """ + if not self._by_computer_id: + error_msg = ( + "Cannot connect: no Agent OS target computers registered. Provide " + "at least one via the `MultiComputerTargetAgentOS` constructor's " + "`agent_os_target_computers` argument, or call " + "`add_agent_os_target_computer()` before `connect()`." + ) + raise AskUiControllerError(error_msg) + try: + for target in self._by_computer_id.values(): + self.connect_target(target) + except Exception: + self.disconnect() + raise + + def connect_target(self, target: ComputerTarget) -> None: + """ + Open the connection to a single registered Agent OS target. Idempotent: + returns silently if the target is already connected. Delegates to + `ComputerTarget.connect()`. + """ + target.connect() + + def disconnect(self) -> None: + """ + Close every open Agent OS target connection. Errors on one connection + are logged but do not abort the loop - a partial failure still releases + the others. + """ + for target in self._by_computer_id.values(): + target.disconnect() + + def disconnect_target(self, computer_id: str) -> None: + """ + Close a single open Agent OS target connection identified by its + `computer_id`. No-op if no such connection is open or no such target is + registered. Delegates to `ComputerTarget.disconnect()`. + """ + target = self._by_computer_id.get(computer_id) + if target is not None: + target.disconnect() + + def __len__(self) -> int: + return len(self._by_computer_id) + + def __contains__(self, computer_id: object) -> bool: + return isinstance(computer_id, str) and computer_id in self._by_computer_id + + def _validate_addable(self, target: ComputerTarget) -> None: + if target.is_local: + existing_local = next( + (t for t in self._by_computer_id.values() if t.is_local), None + ) + if existing_local is not None: + error_msg = ( + "Cannot register a second local Agent OS target computer. At " + "most one local target is supported. Existing local target: " + f"{existing_local.description!r} " + f"(computer_id={existing_local.computer_id!r}). " + "Remove it first via `remove(computer_id)`." + ) + raise ValueError(error_msg) + if target.computer_id in self._by_computer_id: + error_msg = ( + "An Agent OS target computer with " + f"computer_id={target.computer_id!r} is already registered. " + "Each target must have a unique computer_id." + ) + raise ValueError(error_msg) + if not target.is_local and any( + (not t.is_local) and t.address == target.address + for t in self._by_computer_id.values() + ): + error_msg = ( + f"A remote Agent OS target computer with address " + f"{target.address!r} is already registered. Each remote target " + "must have a unique address." + ) + raise ValueError(error_msg) + + def _require(self, computer_id: str) -> ComputerTarget: + target = self._by_computer_id.get(computer_id) + if target is not None: + return target + registered = ", ".join(repr(cid) for cid in self._by_computer_id) or "none" + error_msg = ( + f"No Agent OS target computer with computer_id={computer_id!r} is " + f"registered. Registered computer ids: {registered}. Use " + "`describe_agent_os_target_computers()` to inspect the registered " + "targets." + ) + raise KeyError(error_msg) + + +__all__ = ["ComputerTargetPool"] diff --git a/src/askui/tools/askui/exceptions.py b/src/askui/tools/askui/exceptions.py index 1398ff2b..ecfc2c16 100644 --- a/src/askui/tools/askui/exceptions.py +++ b/src/askui/tools/askui/exceptions.py @@ -1,8 +1,9 @@ class AskUiControllerError(Exception): """Base exception for AskUI controller errors. - This exception is raised when there is an error in the AskUI controller (client), - which handles the communication with the AskUI controller (server). + This exception is raised when there is an error in the AskUI controller + client, which handles the communication with the AskUI controller process + running on the target computer. Args: message (str): The error message. @@ -42,7 +43,11 @@ class AskUiControllerOperationTimeoutError(AskUiControllerError): """ def __init__( - self, message: str = "Action not yet done", timeout_seconds: float | None = None + self, + message: str = ( + "Controller action did not finish within the expected time window." + ), + timeout_seconds: float | None = None, ): super().__init__(message) self.timeout_seconds = timeout_seconds @@ -52,21 +57,23 @@ class AskUiControllerInvalidCommandError(AskUiControllerError): """Exception raised when a command sent to the controller is invalid. This exception is raised when a command fails schema validation on the - controller server side, typically due to malformed command structure or + target computer side, typically due to malformed command structure or invalid parameters. Args: - details (str | None): Optional additional error details from the server. + details (str | None): Optional additional error details from the target + computer. """ def __init__(self, details: str | None = None): error_msg = ( - "AgentOS: Command validation failed" - " This error may be resolved by updating the AskUI" - " controller to the latest version." + "AgentOS: command validation failed on the target computer. " + "This is typically caused by a malformed command or a version " + "mismatch; updating the AskUI controller to the latest version " + "may resolve it." ) if details: - error_msg += f"\n{details}" + error_msg += f"\nController details: {details}" super().__init__(error_msg) self.details = details diff --git a/src/askui/tools/computer/__init__.py b/src/askui/tools/computer/__init__.py index 0410151e..b146a31e 100644 --- a/src/askui/tools/computer/__init__.py +++ b/src/askui/tools/computer/__init__.py @@ -1,10 +1,12 @@ from .connect_tool import ComputerConnectTool from .disconnect_tool import ComputerDisconnectTool +from .get_current_computer_target_id_tool import ComputerGetCurrentComputerTargetIdTool from .get_mouse_position_tool import ComputerGetMousePositionTool from .get_system_info_tool import ComputerGetSystemInfoTool from .keyboard_pressed_tool import ComputerKeyboardPressedTool from .keyboard_release_tool import ComputerKeyboardReleaseTool from .keyboard_tap_tool import ComputerKeyboardTapTool +from .list_agent_os_target_computers_tool import ComputerListAgentOsTargetComputersTool from .list_displays_tool import ComputerListDisplaysTool from .mouse_click_tool import ComputerMouseClickTool from .mouse_hold_down_tool import ComputerMouseHoldDownTool @@ -14,12 +16,16 @@ from .retrieve_active_display_tool import ComputerRetrieveActiveDisplayTool from .screenshot_tool import ComputerScreenshotTool from .set_active_display_tool import ComputerSetActiveDisplayTool +from .switch_agent_os_target_computer_tool import ( + ComputerSwitchAgentOsTargetComputerTool, +) from .type_tool import ComputerTypeTool __all__ = [ "ComputerGetSystemInfoTool", "ComputerConnectTool", "ComputerDisconnectTool", + "ComputerGetCurrentComputerTargetIdTool", "ComputerGetMousePositionTool", "ComputerKeyboardPressedTool", "ComputerKeyboardReleaseTool", @@ -32,6 +38,8 @@ "ComputerScreenshotTool", "ComputerTypeTool", "ComputerListDisplaysTool", + "ComputerListAgentOsTargetComputersTool", "ComputerRetrieveActiveDisplayTool", "ComputerSetActiveDisplayTool", + "ComputerSwitchAgentOsTargetComputerTool", ] diff --git a/src/askui/tools/computer/connect_tool.py b/src/askui/tools/computer/connect_tool.py index 7e0e35f4..e4ece900 100644 --- a/src/askui/tools/computer/connect_tool.py +++ b/src/askui/tools/computer/connect_tool.py @@ -1,11 +1,11 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerConnectTool(ComputerBaseTool): """Computer Connect Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="connect", description=( diff --git a/src/askui/tools/computer/disconnect_tool.py b/src/askui/tools/computer/disconnect_tool.py index 6f3cea25..88a0fe86 100644 --- a/src/askui/tools/computer/disconnect_tool.py +++ b/src/askui/tools/computer/disconnect_tool.py @@ -1,11 +1,11 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerDisconnectTool(ComputerBaseTool): """Computer Disconnect Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="disconnect", description=( diff --git a/src/askui/tools/computer/get_current_computer_target_id_tool.py b/src/askui/tools/computer/get_current_computer_target_id_tool.py new file mode 100644 index 00000000..74ac248f --- /dev/null +++ b/src/askui/tools/computer/get_current_computer_target_id_tool.py @@ -0,0 +1,18 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import ComputerAgentOS + + +class ComputerGetCurrentComputerTargetIdTool(ComputerBaseTool): + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: + super().__init__( + name="get_current_computer_target_id", + description=""" + Return the `computer_id` of the currently active Agent OS target + computer that agent-os actions are routed to. + """, + agent_os=agent_os, + ) + self.is_cacheable = True + + def __call__(self) -> str: + return self.agent_os.get_current_computer_target_id() diff --git a/src/askui/tools/computer/get_mouse_position_tool.py b/src/askui/tools/computer/get_mouse_position_tool.py index 059822a5..09729790 100644 --- a/src/askui/tools/computer/get_mouse_position_tool.py +++ b/src/askui/tools/computer/get_mouse_position_tool.py @@ -8,12 +8,20 @@ class ComputerGetMousePositionTool(ComputerBaseTool): def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: super().__init__( name="get_mouse_position", - description="Get the current mouse position.", + description=( + "Get the current mouse position on the currently active Agent OS " + "target computer. The result is prefixed with the active target " + "computer's id." + ), agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) self.is_cacheable = True def __call__(self) -> str: + target_id = self.agent_os.get_current_computer_target_id(report=False) cursor_position = self.agent_os.get_mouse_position() - return f"Mouse is at position ({cursor_position.x}, {cursor_position.y})." + return ( + f"[Computer '{target_id}']: Mouse is at position " + f"({cursor_position.x}, {cursor_position.y})." + ) diff --git a/src/askui/tools/computer/get_system_info_tool.py b/src/askui/tools/computer/get_system_info_tool.py index 7f68c07d..c82c0008 100644 --- a/src/askui/tools/computer/get_system_info_tool.py +++ b/src/askui/tools/computer/get_system_info_tool.py @@ -1,11 +1,12 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerGetSystemInfoTool(ComputerBaseTool): """ - Get the system information. - This tool returns the system information as a JSON object. + Get the system information of the currently active Agent OS target computer. + This tool returns the system information as a JSON object prefixed with the + active target computer's id. The JSON object contains the following fields: - platform: The operating system platform. - label: The operating system label. @@ -13,12 +14,14 @@ class ComputerGetSystemInfoTool(ComputerBaseTool): - architecture: The operating system architecture. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="get_system_info_tool", description=""" - Get the system information. - This tool returns the system information as a JSON object. + Get the system information of the currently active Agent OS target + computer. This tool returns the system information as a JSON object + prefixed with the active target computer's id so it is clear which + computer the info belongs to. The JSON object contains the following fields: - platform: The operating system platform. - label: The operating system label. @@ -29,4 +32,6 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: ) def __call__(self) -> str: - return str(self.agent_os.get_system_info().model_dump_json()) + target_id = self.agent_os.get_current_computer_target_id(report=False) + system_info_json = self.agent_os.get_system_info().model_dump_json() + return f"[Computer '{target_id}']: {system_info_json}" diff --git a/src/askui/tools/computer/keyboard_pressed_tool.py b/src/askui/tools/computer/keyboard_pressed_tool.py index e85fad88..8f4fdb05 100644 --- a/src/askui/tools/computer/keyboard_pressed_tool.py +++ b/src/askui/tools/computer/keyboard_pressed_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import ComputerAgentOS, ModifierKey, PcKey class ComputerKeyboardPressedTool(ComputerBaseTool): """Computer Keyboard Pressed Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="keyboard_pressed", description="Press and hold a keyboard key.", diff --git a/src/askui/tools/computer/keyboard_release_tool.py b/src/askui/tools/computer/keyboard_release_tool.py index 13603f4b..7a7aedf9 100644 --- a/src/askui/tools/computer/keyboard_release_tool.py +++ b/src/askui/tools/computer/keyboard_release_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import ComputerAgentOS, ModifierKey, PcKey class ComputerKeyboardReleaseTool(ComputerBaseTool): """Computer Keyboard Release Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="keyboard_release", description="Release a keyboard key.", diff --git a/src/askui/tools/computer/keyboard_tap_tool.py b/src/askui/tools/computer/keyboard_tap_tool.py index 62f48227..64f96956 100644 --- a/src/askui/tools/computer/keyboard_tap_tool.py +++ b/src/askui/tools/computer/keyboard_tap_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import ComputerAgentOS, ModifierKey, PcKey class ComputerKeyboardTapTool(ComputerBaseTool): """Computer Keyboard Tap Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="keyboard_tap", description="Tap (press and release) a keyboard key.", diff --git a/src/askui/tools/computer/list_agent_os_target_computers_tool.py b/src/askui/tools/computer/list_agent_os_target_computers_tool.py new file mode 100644 index 00000000..bafe9ba9 --- /dev/null +++ b/src/askui/tools/computer/list_agent_os_target_computers_tool.py @@ -0,0 +1,19 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import ComputerAgentOS + + +class ComputerListAgentOsTargetComputersTool(ComputerBaseTool): + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: + super().__init__( + name="list_agent_os_target_computers", + description=""" + List all the registered Agent OS target computers that the agent + can route actions to. Each target computer has a unique + `computer_id` that can be used to switch between them. + """, + agent_os=agent_os, + ) + + def __call__(self) -> str: + target_computer_reprs = self.agent_os.describe_agent_os_target_computers() + return "\n".join(target_computer_reprs) diff --git a/src/askui/tools/computer/list_displays_tool.py b/src/askui/tools/computer/list_displays_tool.py index 68f3c207..e500e262 100644 --- a/src/askui/tools/computer/list_displays_tool.py +++ b/src/askui/tools/computer/list_displays_tool.py @@ -1,19 +1,23 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerListDisplaysTool(ComputerBaseTool): - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="list_displays", description=""" - List all the available displays on the computer. + List all the available displays on the currently active Agent OS + target computer. The result is prefixed with the active target + computer's id so it is clear which computer the displays belong to. """, agent_os=agent_os, ) self.is_cacheable = True def __call__(self) -> str: - return self.agent_os.list_displays().model_dump_json( + target_id = self.agent_os.get_current_computer_target_id(report=False) + displays_json = self.agent_os.list_displays().model_dump_json( exclude={"data": {"__all__": {"size"}}}, ) + return f"[Computer '{target_id}']: {displays_json}" diff --git a/src/askui/tools/computer/mouse_click_tool.py b/src/askui/tools/computer/mouse_click_tool.py index 002f7902..264e27eb 100644 --- a/src/askui/tools/computer/mouse_click_tool.py +++ b/src/askui/tools/computer/mouse_click_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, MouseButton +from askui.tools.agent_os import ComputerAgentOS, MouseButton class ComputerMouseClickTool(ComputerBaseTool): """Computer Mouse Click Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="mouse_click", description="Click and release the mouse button at the current position.", diff --git a/src/askui/tools/computer/mouse_hold_down_tool.py b/src/askui/tools/computer/mouse_hold_down_tool.py index 9387b117..74f68496 100644 --- a/src/askui/tools/computer/mouse_hold_down_tool.py +++ b/src/askui/tools/computer/mouse_hold_down_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, MouseButton +from askui.tools.agent_os import ComputerAgentOS, MouseButton class ComputerMouseHoldDownTool(ComputerBaseTool): """Computer Mouse Hold Down Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="mouse_hold_down", description="Hold down the mouse button at the current position.", diff --git a/src/askui/tools/computer/mouse_release_tool.py b/src/askui/tools/computer/mouse_release_tool.py index b8227d9c..39651f22 100644 --- a/src/askui/tools/computer/mouse_release_tool.py +++ b/src/askui/tools/computer/mouse_release_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, MouseButton +from askui.tools.agent_os import ComputerAgentOS, MouseButton class ComputerMouseReleaseTool(ComputerBaseTool): """Computer Mouse Release Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="mouse_release", description="Release the mouse button at the current position.", diff --git a/src/askui/tools/computer/retrieve_active_display_tool.py b/src/askui/tools/computer/retrieve_active_display_tool.py index 7eef6cfd..853785d7 100644 --- a/src/askui/tools/computer/retrieve_active_display_tool.py +++ b/src/askui/tools/computer/retrieve_active_display_tool.py @@ -1,20 +1,24 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerRetrieveActiveDisplayTool(ComputerBaseTool): - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="retrieve_active_display", description=""" - Retrieve the currently active display on the computer. - The display is used to take screenshots and perform actions. + Retrieve the currently active display on the currently active Agent + OS target computer. The display is used to take screenshots and + perform actions. The result is prefixed with the active target + computer's id so it is clear which computer the display belongs to. """, agent_os=agent_os, ) self.is_cacheable = True def __call__(self) -> str: - return str( - self.agent_os.retrieve_active_display().model_dump_json(exclude={"size"}) + target_id = self.agent_os.get_current_computer_target_id(report=False) + display_json = self.agent_os.retrieve_active_display().model_dump_json( + exclude={"size"} ) + return f"[Computer '{target_id}']: {display_json}" diff --git a/src/askui/tools/computer/screenshot_tool.py b/src/askui/tools/computer/screenshot_tool.py index fcf46553..0928d389 100644 --- a/src/askui/tools/computer/screenshot_tool.py +++ b/src/askui/tools/computer/screenshot_tool.py @@ -10,12 +10,21 @@ class ComputerScreenshotTool(ComputerBaseTool): def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: super().__init__( name="screenshot", - description="Take a screenshot of the current screen.", + description=( + "Take a screenshot of the current screen on the currently active " + "Agent OS target computer. The accompanying message is prefixed " + "with the active target computer's id so it is clear which " + "computer the screenshot was taken on." + ), agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) self.is_cacheable = True def __call__(self) -> tuple[str, Image.Image]: + target_id = self.agent_os.get_current_computer_target_id(report=False) screenshot = self.agent_os.screenshot() - return "Screenshot was taken.", screenshot + return ( + f"[Computer '{target_id}']: Screenshot was taken.", + screenshot, + ) diff --git a/src/askui/tools/computer/set_active_display_tool.py b/src/askui/tools/computer/set_active_display_tool.py index 94719dec..bec7ba89 100644 --- a/src/askui/tools/computer/set_active_display_tool.py +++ b/src/askui/tools/computer/set_active_display_tool.py @@ -1,9 +1,9 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerSetActiveDisplayTool(ComputerBaseTool): - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="set_active_display", description=""" diff --git a/src/askui/tools/computer/switch_agent_os_target_computer_tool.py b/src/askui/tools/computer/switch_agent_os_target_computer_tool.py new file mode 100644 index 00000000..ded871ef --- /dev/null +++ b/src/askui/tools/computer/switch_agent_os_target_computer_tool.py @@ -0,0 +1,28 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import ComputerAgentOS + + +class ComputerSwitchAgentOsTargetComputerTool(ComputerBaseTool): + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: + super().__init__( + name="switch_agent_os_target_computer", + description=""" + Switch the active Agent OS target computer by its `computer_id`. + Future agent-os actions are routed to the newly selected target + computer. Use `list_agent_os_target_computers` to discover the + available computer ids. + """, + input_schema={ + "type": "object", + "properties": { + "computer_id": { + "type": "string", + }, + }, + "required": ["computer_id"], + }, + agent_os=agent_os, + ) + + def __call__(self, computer_id: str) -> str: + return repr(self.agent_os.switch_agent_os_target_computer(computer_id)) diff --git a/src/askui/tools/computer/type_tool.py b/src/askui/tools/computer/type_tool.py index ace3a612..8f232dbc 100644 --- a/src/askui/tools/computer/type_tool.py +++ b/src/askui/tools/computer/type_tool.py @@ -1,11 +1,11 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerTypeTool(ComputerBaseTool): """Computer Type Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="type", description=( diff --git a/src/askui/tools/computer_agent_os_facade.py b/src/askui/tools/computer_agent_os_facade.py index 57c7efa4..a7dce4fe 100644 --- a/src/askui/tools/computer_agent_os_facade.py +++ b/src/askui/tools/computer_agent_os_facade.py @@ -1,12 +1,15 @@ +from collections.abc import Iterator +from contextlib import contextmanager from typing import TYPE_CHECKING from PIL import Image +from typing_extensions import Self from askui.models.shared.coordinate_space import VlmCoordinateSpace from askui.models.shared.image_scaler import ImageScaler from askui.models.shared.tool_tags import ToolTags from askui.tools.agent_os import ( - AgentOs, + ComputerAgentOS, Coordinate, Display, DisplaysListResponse, @@ -19,6 +22,9 @@ from askui.tools.coordinate_scaler import CoordinateScaler if TYPE_CHECKING: + from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + ) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -29,8 +35,8 @@ ) -class ComputerAgentOsFacade(AgentOs): - """Facade for `AgentOs` that adds coordinate scaling. +class ComputerAgentOsFacade(ComputerAgentOS): + """Facade for `ComputerAgentOS` that adds coordinate scaling. Screenshots are scaled using the provider's image scaler so that the AI model sees an optimally sized image. Coordinate-based inputs @@ -38,14 +44,14 @@ class ComputerAgentOsFacade(AgentOs): to the underlying agent OS. Args: - agent_os (`AgentOs`): The real agent OS to wrap. + agent_os (`ComputerAgentOS`): The real agent OS to wrap. coordinate_space (`VlmCoordinateSpace`): Coordinate grid the model uses. image_scaler (`ImageScaler`): Callable to preprocess screenshots. """ def __init__( self, - agent_os: AgentOs, + agent_os: ComputerAgentOS, coordinate_space: VlmCoordinateSpace, image_scaler: ImageScaler, ) -> None: @@ -324,6 +330,39 @@ def set_window_in_focus(self, process_id: int, window_id: int) -> None: """ self._agent_os.set_window_in_focus(process_id, window_id) + def add_agent_os_target_computer( + self, agent_os_target_computer: "ComputerTarget" + ) -> "ComputerTarget": + return self._agent_os.add_agent_os_target_computer(agent_os_target_computer) + + def reset_agent_os_target_computers( + self, + agent_os_target_computers: "list[ComputerTarget] | None" = None, + ) -> None: + self._agent_os.reset_agent_os_target_computers(agent_os_target_computers) + + def describe_agent_os_target_computers(self) -> list[str]: + return self._agent_os.describe_agent_os_target_computers() + + def get_current_computer_target_id(self, report: bool = True) -> str: + return self._agent_os.get_current_computer_target_id(report=report) + + def switch_agent_os_target_computer(self, computer_id: str) -> "ComputerTarget": + agent_os_target_computer = self._agent_os.switch_agent_os_target_computer( + computer_id + ) + self._scaler.real_screen_resolution = None + return agent_os_target_computer + + @contextmanager + def temporary_select(self, computer_id: str) -> Iterator[Self]: + with self._agent_os.temporary_select(computer_id): + self._scaler.real_screen_resolution = None + try: + yield self + finally: + self._scaler.real_screen_resolution = None + def get_file_names(self, absolute_directory_path: str) -> list[str]: """ List file names in an absolute directory on the automation target. diff --git a/src/askui/tools/playwright/agent_os.py b/src/askui/tools/playwright/agent_os.py index db199f70..31fe581f 100644 --- a/src/askui/tools/playwright/agent_os.py +++ b/src/askui/tools/playwright/agent_os.py @@ -21,7 +21,14 @@ from askui.reporting import NULL_REPORTER, Reporter from askui.utils.annotated_image import AnnotatedImage -from ..agent_os import AgentOs, Display, DisplaySize, InputEvent, ModifierKey, PcKey +from ..agent_os import ( + ComputerAgentOS, + Display, + DisplaySize, + InputEvent, + ModifierKey, + PcKey, +) def _to_unique_path(path: Path) -> Path: @@ -47,8 +54,8 @@ def _to_unique_path(path: Path) -> Path: counter += 1 -class PlaywrightAgentOs(AgentOs): - """Playwright-based implementation of `AgentOs`. +class PlaywrightAgentOs(ComputerAgentOS): + """Playwright-based implementation of `ComputerAgentOS`. This implementation uses Playwright's Python SDK to control browser automation and simulate user interactions. It provides mouse control, keyboard input, diff --git a/src/askui/tools/store/__init__.py b/src/askui/tools/store/__init__.py index 2a05056d..eba8bb31 100644 --- a/src/askui/tools/store/__init__.py +++ b/src/askui/tools/store/__init__.py @@ -3,7 +3,7 @@ Tools are organized by category: - `android`: Tools specific to Android agents (require AndroidAgentOs) - `computer`: Tools specific to Computer/Desktop agents (require ComputerAgentOsFacade) -- `universal`: Tools that work with any agent type (don't require AgentOs) +- `universal`: Tools that work with any agent type (don't require ComputerAgentOS) Example: ```python diff --git a/src/askui/tools/store/computer/__init__.py b/src/askui/tools/store/computer/__init__.py index fb7f5427..f2fc0ca2 100644 --- a/src/askui/tools/store/computer/__init__.py +++ b/src/askui/tools/store/computer/__init__.py @@ -1,6 +1,6 @@ """Computer-specific tools. -These tools require AgentOs (or ComputerAgentOsFacade) and are designed +These tools require ComputerAgentOS (or ComputerAgentOsFacade) and are designed for use with VisionAgent. """ diff --git a/src/askui/tools/store/computer/experimental/get_file.py b/src/askui/tools/store/computer/experimental/get_file.py index b7bf5c93..a7e47010 100644 --- a/src/askui/tools/store/computer/experimental/get_file.py +++ b/src/askui/tools/store/computer/experimental/get_file.py @@ -1,7 +1,7 @@ from PIL import Image from askui.models.shared import ComputerBaseTool, ToolTags -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerGetFileTool(ComputerBaseTool): @@ -24,7 +24,7 @@ class ComputerGetFileTool(ComputerBaseTool): ``` """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="get_file_tool", description=( diff --git a/src/askui/tools/store/computer/experimental/get_file_names.py b/src/askui/tools/store/computer/experimental/get_file_names.py index 5002b0eb..643820fb 100644 --- a/src/askui/tools/store/computer/experimental/get_file_names.py +++ b/src/askui/tools/store/computer/experimental/get_file_names.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerGetFileNamesTool(ComputerBaseTool): @@ -23,7 +23,7 @@ class ComputerGetFileNamesTool(ComputerBaseTool): ``` """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="get_file_names_tool", description=( diff --git a/src/askui/tools/store/computer/experimental/remove_virtual_displays.py b/src/askui/tools/store/computer/experimental/remove_virtual_displays.py index 1a7b2000..1a952b69 100644 --- a/src/askui/tools/store/computer/experimental/remove_virtual_displays.py +++ b/src/askui/tools/store/computer/experimental/remove_virtual_displays.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerRemoveVirtualDisplaysTool(ComputerBaseTool): @@ -24,7 +24,7 @@ class ComputerRemoveVirtualDisplaysTool(ComputerBaseTool): ``` """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="remove_virtual_displays_tool", description=( diff --git a/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py b/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py index 7750b333..58e6b73e 100644 --- a/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py +++ b/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerAddWindowAsVirtualDisplayTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerAddWindowAsVirtualDisplayTool(ComputerBaseTool): for UI automation tasks. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="add_window_as_virtual_display_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/list_process.py b/src/askui/tools/store/computer/experimental/window_management/list_process.py index 775f5a8c..c3141370 100644 --- a/src/askui/tools/store/computer/experimental/window_management/list_process.py +++ b/src/askui/tools/store/computer/experimental/window_management/list_process.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerListProcessTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerListProcessTool(ComputerBaseTool): applications and their process IDs. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="list_process_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py b/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py index f114ec87..580445f9 100644 --- a/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py +++ b/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerListProcessWindowsTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerListProcessWindowsTool(ComputerBaseTool): list_process_tool. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="list_process_windows_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py b/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py index 2e27550f..7e7fff84 100644 --- a/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py +++ b/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerSetProcessInFocusTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerSetProcessInFocusTool(ComputerBaseTool): operating system or the process determine which window should be focused. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="set_process_in_focus_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py b/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py index e597a78c..41573f5a 100644 --- a/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py +++ b/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerSetWindowInFocusTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerSetWindowInFocusTool(ComputerBaseTool): before performing automation tasks. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="set_window_in_focus_tool", description=""" diff --git a/src/askui/tools/toolbox.py b/src/askui/tools/toolbox.py index 3f954fe4..a7362810 100644 --- a/src/askui/tools/toolbox.py +++ b/src/askui/tools/toolbox.py @@ -3,7 +3,7 @@ import httpx import pyperclip -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class AgentToolbox: @@ -13,16 +13,18 @@ class AgentToolbox: Provides access to OS-level actions, clipboard, web browser, HTTP client etc. Args: - agent_os (AgentOs): The OS interface implementation to use for agent actions. + agent_os (ComputerAgentOS): The OS interface implementation to use for + agent actions. Attributes: webbrowser: Python's built-in `webbrowser` module for opening URLs. clipboard: `pyperclip` module for clipboard access. - agent_os (AgentOs): The OS interface for mouse, keyboard, and screen actions. + agent_os (ComputerAgentOS): The OS interface for mouse, keyboard, and + screen actions. httpx: HTTPX client for HTTP requests. """ - def __init__(self, agent_os: AgentOs): + def __init__(self, agent_os: ComputerAgentOS): self.webbrowser = webbrowser self.clipboard = pyperclip self.os = agent_os diff --git a/tests/conftest.py b/tests/conftest.py index 5eb112db..f2a792af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,7 @@ from PIL import Image from pytest_mock import MockerFixture -from askui.tools.agent_os import AgentOs, Display, DisplaySize -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS, Display, DisplaySize @pytest.fixture @@ -84,22 +83,28 @@ def path_fixtures_github_com__icon(path_fixtures_images: pathlib.Path) -> pathli @pytest.fixture -def agent_os_mock(mocker: MockerFixture) -> AgentOs: +def agent_os_mock(mocker: MockerFixture) -> ComputerAgentOS: """Fixture providing a mock agent os.""" - mock = mocker.MagicMock(spec=AgentOs) + mock = mocker.MagicMock(spec=ComputerAgentOS) mock.retrieve_active_display.return_value = Display( id=1, name="Display 1", size=DisplaySize(width=100, height=100), ) mock.screenshot.return_value = Image.new("RGB", (100, 100), color="white") - return cast("AgentOs", mock) + return cast("ComputerAgentOS", mock) @pytest.fixture -def agent_toolbox_mock(agent_os_mock: AgentOs) -> AgentToolbox: - """Fixture providing a mock agent toolbox.""" - return AgentToolbox(agent_os=agent_os_mock) +def agent_os_mock_patch( + mocker: MockerFixture, agent_os_mock: ComputerAgentOS +) -> ComputerAgentOS: + """Patches `MultiComputerTargetAgentOS` so `ComputerAgent` uses `agent_os_mock`.""" + mocker.patch( + "askui.computer_agent.MultiComputerTargetAgentOS", + return_value=agent_os_mock, + ) + return agent_os_mock @pytest.fixture(autouse=True) diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 19bdbaa6..b502949e 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -27,7 +27,7 @@ from askui.models.shared.settings import LocateSettings from askui.models.types.geometry import PointList from askui.reporting import Reporter, SimpleHtmlReporter -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.image_utils import ImageSource @@ -98,7 +98,7 @@ def combo_locate_model(path_fixtures: pathlib.Path) -> LocateModel: @pytest.fixture def agent_with_pta_model( pta_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -106,7 +106,6 @@ def agent_with_pta_model( detection_provider=_LocateModelDetectionProvider(pta_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -114,7 +113,7 @@ def agent_with_pta_model( @pytest.fixture def agent_with_ocr_model( ocr_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -122,7 +121,6 @@ def agent_with_ocr_model( detection_provider=_LocateModelDetectionProvider(ocr_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -130,7 +128,7 @@ def agent_with_ocr_model( @pytest.fixture def agent_with_ai_element_model( ai_element_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -138,7 +136,6 @@ def agent_with_ai_element_model( detection_provider=_LocateModelDetectionProvider(ai_element_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -146,7 +143,7 @@ def agent_with_ai_element_model( @pytest.fixture def agent_with_combo_model( combo_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -154,19 +151,17 @@ def agent_with_combo_model( detection_provider=_LocateModelDetectionProvider(combo_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @pytest.fixture def vision_agent( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: """Fixture providing a ComputerAgent instance.""" with ComputerAgent( reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index bae0d4e8..b34aa67e 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -19,7 +19,7 @@ from askui.models.shared.settings import GetSettings from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.source_utils import Source @@ -97,7 +97,7 @@ class BrowserContextResponse(ResponseSchemaBase): ) def test_get( vision_agent: ComputerAgent, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel | None, @@ -112,7 +112,6 @@ def test_get( settings=AgentSettings( image_qa_provider=_GetModelImageQAProvider(get_model) ), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: url = agent.get( @@ -142,14 +141,13 @@ def test_get( ], ) def test_get_with_pdf_with_gemini_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_pdf: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -180,7 +178,7 @@ def test_get_with_pdf_with_gemini_model( ], ) def test_get_with_pdf_too_large( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_pdf: pathlib.Path, @@ -189,7 +187,6 @@ def test_get_with_pdf_too_large( mocker.patch("askui.models.askui.get_model.MAX_FILE_SIZE_BYTES", 1) with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: with pytest.raises(ValueError, match="PDF file size exceeds the limit"): @@ -232,14 +229,13 @@ def test_get_with_pdf_too_large_with_default_model( ], ) def test_get_with_xlsx_with_gemini_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_excel: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -279,14 +275,13 @@ class SalaryResponse(ResponseSchemaBase): ], ) def test_get_with_xlsx_with_gemini_model_with_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_excel: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -325,7 +320,7 @@ def test_get_with_docs_with_default_model( def test_get_with_fallback_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, ) -> None: @@ -338,7 +333,6 @@ def test_get_with_fallback_model( image_qa_provider=_GetModelImageQAProvider(askui_get_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: url = agent.get( "What is the current url shown in the url bar?", @@ -393,7 +387,7 @@ def test_get_with_response_schema_with_default_value( ) def test_get_with_response_schema( vision_agent: ComputerAgent, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel | None, @@ -409,7 +403,6 @@ def test_get_with_response_schema( settings=AgentSettings( image_qa_provider=_GetModelImageQAProvider(get_model) ), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -434,14 +427,13 @@ def test_get_with_response_schema( ], ) def test_get_with_nested_and_inherited_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -473,14 +465,13 @@ class LinkedListNode(ResponseSchemaBase): ], ) def test_get_with_recursive_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: with pytest.raises( @@ -507,14 +498,13 @@ def test_get_with_recursive_response_schema( ], ) def test_get_with_string_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -545,14 +535,13 @@ def test_get_with_string_schema( ], ) def test_get_with_boolean_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -577,14 +566,13 @@ def test_get_with_boolean_schema( ], ) def test_get_with_integer_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -609,14 +597,13 @@ def test_get_with_integer_schema( ], ) def test_get_with_float_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -641,14 +628,13 @@ def test_get_with_float_schema( ], ) def test_get_returns_str_when_no_schema_specified( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -675,14 +661,13 @@ class Basis(ResponseSchemaBase): ], ) def test_get_with_basis_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -715,14 +700,13 @@ class BasisWithNestedRootModel(ResponseSchemaBase): ], ) def test_get_with_nested_root_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -774,7 +758,7 @@ class PageDom(ResponseSchemaBase): ], ) def test_get_with_deeply_nested_response_schema_with_model_that_does_not_support_recursion( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, @@ -786,7 +770,6 @@ def test_get_with_deeply_nested_response_schema_with_model_that_does_not_support """ with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( diff --git a/tests/e2e/test_telemetry.py b/tests/e2e/test_telemetry.py index 25b9202a..70539277 100644 --- a/tests/e2e/test_telemetry.py +++ b/tests/e2e/test_telemetry.py @@ -5,13 +5,13 @@ from askui import locators as loc from askui.container import telemetry from askui.telemetry.processors import Segment, SegmentSettings -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS @pytest.mark.timeout(60) def test_telemetry_with_nonexistent_domain_should_not_block( github_login_screenshot: Image.Image, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> None: telemetry.set_processors( [ @@ -23,6 +23,6 @@ def test_telemetry_with_nonexistent_domain_should_not_block( ) ] ) - with ComputerAgent(tools=agent_toolbox_mock) as agent: + with ComputerAgent() as agent: agent.locate(loc.Text(), screenshot=github_login_screenshot) assert True diff --git a/tests/e2e/tools/askui/test_askui_controller.py b/tests/e2e/tools/askui/test_askui_controller.py index bca9e591..b0cff359 100644 --- a/tests/e2e/tools/askui/test_askui_controller.py +++ b/tests/e2e/tools/askui/test_askui_controller.py @@ -7,33 +7,33 @@ from askui.reporting import CompositeReporter from askui.tools.agent_os import Coordinate +from askui.tools.askui import LocalComputerTarget from askui.tools.askui.askui_controller import ( - AskUiControllerClient, - AskUiControllerServer, + MultiComputerTargetAgentOS, RenderObjectStyle, ) from askui.tools.askui.askui_controller_settings import AskUiControllerSettings @pytest.fixture -def controller_server() -> AskUiControllerServer: - return AskUiControllerServer( +def agent_os_target_computer() -> LocalComputerTarget: + return LocalComputerTarget( settings=AskUiControllerSettings(controller_args="--showOverlay true") ) @pytest.fixture def controller_client( - controller_server: AskUiControllerServer, -) -> AskUiControllerClient: - return AskUiControllerClient( + agent_os_target_computer: LocalComputerTarget, +) -> MultiComputerTargetAgentOS: + return MultiComputerTargetAgentOS( reporter=CompositeReporter(), display=1, - controller_server=controller_server, + agent_os_target_computers=[agent_os_target_computer], ) -def test_actions(controller_client: AskUiControllerClient) -> None: +def test_actions(controller_client: MultiComputerTargetAgentOS) -> None: with controller_client: controller_client.screenshot() controller_client.mouse_move(0, 0) @@ -42,14 +42,15 @@ def test_actions(controller_client: AskUiControllerClient) -> None: @pytest.mark.parametrize("button", ["left", "right", "middle"]) def test_click_all_buttons( - controller_client: AskUiControllerClient, button: Literal["left", "middle", "right"] + controller_client: MultiComputerTargetAgentOS, + button: Literal["left", "middle", "right"], ) -> None: """Test clicking each mouse button""" with controller_client: controller_client.click(button=button) -def test_mouse_multiple_clicks(controller_client: AskUiControllerClient) -> None: +def test_mouse_multiple_clicks(controller_client: MultiComputerTargetAgentOS) -> None: """Test click count parameter""" with controller_client: controller_client.click(count=3) @@ -57,7 +58,8 @@ def test_mouse_multiple_clicks(controller_client: AskUiControllerClient) -> None @pytest.mark.parametrize("button", ["left", "right", "middle"]) def test_mouse_press_hold_release( - controller_client: AskUiControllerClient, button: Literal["left", "middle", "right"] + controller_client: MultiComputerTargetAgentOS, + button: Literal["left", "middle", "right"], ) -> None: """Test mouse_down() and mouse_up() operations""" with controller_client: @@ -67,14 +69,14 @@ def test_mouse_press_hold_release( @pytest.mark.parametrize("x,y", [(0, 0), (100, 100), (500, 300)]) def test_mouse_move_coordinates( - controller_client: AskUiControllerClient, x: int, y: int + controller_client: MultiComputerTargetAgentOS, x: int, y: int ) -> None: """Test mouse movement to various coordinates""" with controller_client: controller_client.mouse_move(x, y) -def test_mouse_scroll_directions(controller_client: AskUiControllerClient) -> None: +def test_mouse_scroll_directions(controller_client: MultiComputerTargetAgentOS) -> None: """Test horizontal and vertical scrolling""" with controller_client: controller_client.mouse_scroll(0, 5) # Vertical scroll @@ -82,54 +84,58 @@ def test_mouse_scroll_directions(controller_client: AskUiControllerClient) -> No controller_client.mouse_scroll(3, -2) # Combined scroll -def test_type_text_basic(controller_client: AskUiControllerClient) -> None: +def test_type_text_basic(controller_client: MultiComputerTargetAgentOS) -> None: """Test typing simple text""" with controller_client: controller_client.type("Hello World") -def test_type_text_with_speed(controller_client: AskUiControllerClient) -> None: +def test_type_text_with_speed(controller_client: MultiComputerTargetAgentOS) -> None: """Test typing with custom speed""" with controller_client: controller_client.type("Fast typing", typing_speed=100) controller_client.type("Slow typing", typing_speed=10) -def test_keyboard_tap_with_modifiers(controller_client: AskUiControllerClient) -> None: +def test_keyboard_tap_with_modifiers( + controller_client: MultiComputerTargetAgentOS, +) -> None: """Test key combination like Ctrl+C""" with controller_client: controller_client.keyboard_tap("c", modifier_keys=["command"]) controller_client.keyboard_tap("v", modifier_keys=["command"]) -def test_keyboard_tap_multiple(controller_client: AskUiControllerClient) -> None: +def test_keyboard_tap_multiple(controller_client: MultiComputerTargetAgentOS) -> None: """Test multiple key taps""" with controller_client: controller_client.keyboard_tap("escape", count=3) -def test_keyboard_press_hold_release(controller_client: AskUiControllerClient) -> None: +def test_keyboard_press_hold_release( + controller_client: MultiComputerTargetAgentOS, +) -> None: """Test keyboard_pressed() and keyboard_release()""" with controller_client: controller_client.keyboard_pressed("escape") controller_client.keyboard_release("escape") -def test_screenshot_basic(controller_client: AskUiControllerClient) -> None: +def test_screenshot_basic(controller_client: MultiComputerTargetAgentOS) -> None: """Test taking screenshots with different report settings""" with controller_client: image_with_report = controller_client.screenshot() assert isinstance(image_with_report, Image.Image) -def test_get_display_information(controller_client: AskUiControllerClient) -> None: +def test_get_display_information(controller_client: MultiComputerTargetAgentOS) -> None: """Test retrieving display information""" with controller_client: display_info = controller_client.list_displays() assert display_info is not None -def test_get_process_list(controller_client: AskUiControllerClient) -> None: +def test_get_process_list(controller_client: MultiComputerTargetAgentOS) -> None: """Test retrieving running processes""" with controller_client: processes = controller_client.get_process_list() @@ -139,38 +145,40 @@ def test_get_process_list(controller_client: AskUiControllerClient) -> None: assert processes_extended is not None -def test_get_automation_target_list(controller_client: AskUiControllerClient) -> None: +def test_get_automation_target_list( + controller_client: MultiComputerTargetAgentOS, +) -> None: """Test retrieving automation targets""" with controller_client: targets = controller_client.get_automation_target_list() assert targets is not None -def test_set_display(controller_client: AskUiControllerClient) -> None: +def test_set_display(controller_client: MultiComputerTargetAgentOS) -> None: """Test changing active display""" with controller_client: controller_client.set_display(1) -def test_set_mouse_delay(controller_client: AskUiControllerClient) -> None: +def test_set_mouse_delay(controller_client: MultiComputerTargetAgentOS) -> None: """Test configuring mouse action delay""" with controller_client: controller_client.set_mouse_delay(100) -def test_set_keyboard_delay(controller_client: AskUiControllerClient) -> None: +def test_set_keyboard_delay(controller_client: MultiComputerTargetAgentOS) -> None: """Test configuring keyboard action delay""" with controller_client: controller_client.set_keyboard_delay(50) -def test_run_command(controller_client: AskUiControllerClient) -> None: +def test_run_command(controller_client: MultiComputerTargetAgentOS) -> None: """Test executing shell commands""" with controller_client: controller_client.run_command("echo test", 0) -def test_get_action_count(controller_client: AskUiControllerClient) -> None: +def test_get_action_count(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting count of batched actions""" with controller_client: count = controller_client.get_action_count() @@ -179,7 +187,7 @@ def test_get_action_count(controller_client: AskUiControllerClient) -> None: def test_operations_before_connect() -> None: """Test calling methods before connect() raises appropriate errors""" - client = AskUiControllerClient(reporter=CompositeReporter(), display=1) + client = MultiComputerTargetAgentOS(reporter=CompositeReporter(), display=1) with pytest.raises( AssertionError, match="Stub is not initialized. Call `connect()` first." @@ -187,19 +195,19 @@ def test_operations_before_connect() -> None: client.screenshot() -def test_invalid_coordinates(controller_client: AskUiControllerClient) -> None: +def test_invalid_coordinates(controller_client: MultiComputerTargetAgentOS) -> None: """Test mouse operations with potentially problematic coordinates""" with controller_client: controller_client.mouse_move(-1, -1) controller_client.mouse_move(9999, 9999) -def test_set_mouse_position(controller_client: AskUiControllerClient) -> None: +def test_set_mouse_position(controller_client: MultiComputerTargetAgentOS) -> None: with controller_client: controller_client.set_mouse_position(100, 100) -def test_get_mouse_position(controller_client: AskUiControllerClient) -> None: +def test_get_mouse_position(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting current mouse coordinates""" with controller_client: position = controller_client.get_mouse_position() @@ -208,7 +216,7 @@ def test_get_mouse_position(controller_client: AskUiControllerClient) -> None: assert hasattr(position, "y") -def test_render_quad(controller_client: AskUiControllerClient) -> None: +def test_render_quad(controller_client: MultiComputerTargetAgentOS) -> None: """Test adding a quad render object to the display""" with controller_client: style = RenderObjectStyle( @@ -225,7 +233,7 @@ def test_render_quad(controller_client: AskUiControllerClient) -> None: assert response is not None -def test_render_line(controller_client: AskUiControllerClient) -> None: +def test_render_line(controller_client: MultiComputerTargetAgentOS) -> None: """Test rendering a line object to the display""" with controller_client: style = RenderObjectStyle( @@ -240,7 +248,7 @@ def test_render_line(controller_client: AskUiControllerClient) -> None: def test_render_image( - controller_client: AskUiControllerClient, + controller_client: MultiComputerTargetAgentOS, askui_logo_bmp: Image.Image, ) -> None: """Test rendering an image object to the display""" @@ -262,7 +270,7 @@ def test_render_image( assert response is not None -def test_render_text(controller_client: AskUiControllerClient) -> None: +def test_render_text(controller_client: MultiComputerTargetAgentOS) -> None: """Test rendering a text object to the display""" with controller_client: style = RenderObjectStyle( @@ -279,7 +287,7 @@ def test_render_text(controller_client: AskUiControllerClient) -> None: assert response is not None -def test_update_render_object(controller_client: AskUiControllerClient) -> None: +def test_update_render_object(controller_client: MultiComputerTargetAgentOS) -> None: """Test updating an existing render object""" with controller_client: style = RenderObjectStyle( @@ -306,7 +314,7 @@ def test_update_render_object(controller_client: AskUiControllerClient) -> None: controller_client.update_render_object(object_id, update_style) -def test_update_text_object(controller_client: AskUiControllerClient) -> None: +def test_update_text_object(controller_client: MultiComputerTargetAgentOS) -> None: """Test updating an existing render object""" with controller_client: style = RenderObjectStyle( @@ -334,7 +342,7 @@ def test_update_text_object(controller_client: AskUiControllerClient) -> None: controller_client.update_render_object(object_id, update_style) -def test_delete_render_object(controller_client: AskUiControllerClient) -> None: +def test_delete_render_object(controller_client: MultiComputerTargetAgentOS) -> None: """Test deleting an existing render object""" with controller_client: style = RenderObjectStyle( @@ -350,7 +358,7 @@ def test_delete_render_object(controller_client: AskUiControllerClient) -> None: controller_client.delete_render_object(quad_id) -def test_clear_render_objects(controller_client: AskUiControllerClient) -> None: +def test_clear_render_objects(controller_client: MultiComputerTargetAgentOS) -> None: """Test clearing all render objects""" with controller_client: style1 = RenderObjectStyle( @@ -374,7 +382,7 @@ def test_clear_render_objects(controller_client: AskUiControllerClient) -> None: controller_client.clear_render_objects() -def test_get_system_info(controller_client: AskUiControllerClient) -> None: +def test_get_system_info(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting system information""" with controller_client: system_info = controller_client.get_system_info() @@ -385,7 +393,7 @@ def test_get_system_info(controller_client: AskUiControllerClient) -> None: assert system_info.architecture is not None -def test_get_active_process(controller_client: AskUiControllerClient) -> None: +def test_get_active_process(controller_client: MultiComputerTargetAgentOS) -> None: with controller_client: active_process = controller_client.get_active_process() @@ -395,7 +403,7 @@ def test_get_active_process(controller_client: AskUiControllerClient) -> None: assert active_process.process.id is not None -def test_set_active_process(controller_client: AskUiControllerClient) -> None: +def test_set_active_process(controller_client: MultiComputerTargetAgentOS) -> None: """Test setting the active process""" with controller_client: controller_client.set_active_process(1062) @@ -404,7 +412,7 @@ def test_set_active_process(controller_client: AskUiControllerClient) -> None: assert active_process.process is not None -def test_get_active_window(controller_client: AskUiControllerClient) -> None: +def test_get_active_window(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting the active window""" with controller_client: active_window = controller_client.get_active_window() diff --git a/tests/integration/agent/test_retry.py b/tests/integration/agent/test_retry.py index 8f08d51a..bd1d453e 100644 --- a/tests/integration/agent/test_retry.py +++ b/tests/integration/agent/test_retry.py @@ -10,7 +10,7 @@ from askui.models.exceptions import ElementNotFoundError, ModelNotFoundError from askui.models.shared.settings import LocateSettings from askui.models.types.geometry import PointList -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.image_utils import ImageSource @@ -58,21 +58,21 @@ def always_failing_provider() -> FailingDetectionProvider: @pytest.fixture def agent_with_retry( - failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + failing_provider: FailingDetectionProvider, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=failing_provider), - tools=agent_toolbox_mock, ) @pytest.fixture def agent_with_retry_on_multiple_exceptions( - failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + failing_provider: FailingDetectionProvider, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=failing_provider), - tools=agent_toolbox_mock, retry=ConfigurableRetry( on_exception_types=( ElementNotFoundError, @@ -88,11 +88,11 @@ def agent_with_retry_on_multiple_exceptions( @pytest.fixture def agent_always_fail( - always_failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + always_failing_provider: FailingDetectionProvider, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=always_failing_provider), - tools=agent_toolbox_mock, retry=ConfigurableRetry( on_exception_types=(ElementNotFoundError,), strategy="Fixed", diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 996f610a..0bb8a266 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -26,7 +26,7 @@ from askui.models.shared.prompts import SystemPrompt from askui.models.shared.settings import GetSettings, LocateSettings from askui.models.shared.tools import ToolCollection -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.image_utils import ImageSource from askui.utils.source_utils import Source @@ -148,12 +148,11 @@ def detection_provider(self) -> SimpleDetectionProvider: def test_inject_and_use_custom_vlm_provider( self, vlm_provider: SimpleVlmProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting and using a custom VLM provider.""" with ComputerAgent( settings=AgentSettings(vlm_provider=vlm_provider), - tools=agent_toolbox_mock, ) as agent: agent.act("test goal") @@ -175,12 +174,11 @@ def test_inject_and_use_custom_vlm_provider( def test_inject_and_use_custom_image_qa_provider( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting and using a custom image Q&A provider.""" with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query") @@ -190,13 +188,12 @@ def test_inject_and_use_custom_image_qa_provider( def test_inject_and_use_custom_image_qa_provider_with_pdf( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 path_fixtures_dummy_pdf: pathlib.Path, ) -> None: """Test injecting and using a custom image Q&A provider with a PDF.""" with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query", source=path_fixtures_dummy_pdf) @@ -206,12 +203,11 @@ def test_inject_and_use_custom_image_qa_provider_with_pdf( def test_inject_and_use_custom_detection_provider( self, detection_provider: SimpleDetectionProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting and using a custom detection provider.""" with ComputerAgent( settings=AgentSettings(detection_provider=detection_provider), - tools=agent_toolbox_mock, ) as agent: agent.click("test element") @@ -222,7 +218,7 @@ def test_inject_all_custom_providers( vlm_provider: SimpleVlmProvider, image_qa_provider: SimpleImageQAProvider, detection_provider: SimpleDetectionProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting all custom providers at once.""" with ComputerAgent( @@ -231,7 +227,6 @@ def test_inject_all_custom_providers( image_qa_provider=image_qa_provider, detection_provider=detection_provider, ), - tools=agent_toolbox_mock, ) as agent: agent.act("test goal") result = agent.get("test query") @@ -258,7 +253,7 @@ def test_inject_all_custom_providers( def test_use_response_schema_with_custom_image_qa_provider( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test using a response schema with a custom image Q&A provider.""" response = SimpleResponseSchema(value="test value") @@ -266,7 +261,6 @@ def test_use_response_schema_with_custom_image_qa_provider( with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query", response_schema=SimpleResponseSchema) @@ -276,8 +270,8 @@ def test_use_response_schema_with_custom_image_qa_provider( def test_defaults_to_built_in_providers_when_not_provided( self, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test agent uses built-in defaults when custom ones not provided.""" - with ComputerAgent(tools=agent_toolbox_mock) as agent: + with ComputerAgent() as agent: assert agent is not None diff --git a/tests/unit/tools/askui/test_agent_os_target_computer.py b/tests/unit/tools/askui/test_agent_os_target_computer.py new file mode 100644 index 00000000..86bffd7c --- /dev/null +++ b/tests/unit/tools/askui/test_agent_os_target_computer.py @@ -0,0 +1,131 @@ +from typing import Callable + +import pytest + +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, + RemoteComputerTarget, +) + + +class TestReplacePort: + def test_replaces_port_on_bare_authority(self) -> None: + assert ( + LocalComputerTarget.replace_port("example.com:1234", 23000) + == "example.com:23000" + ) + + def test_replaces_port_on_url_with_scheme(self) -> None: + assert ( + LocalComputerTarget.replace_port("http://example.com:1234", 23000) + == "example.com:23000" + ) + + def test_falls_back_to_localhost_when_host_missing(self) -> None: + # A bare ":1234" has no hostname, so the helper falls back to "localhost". + assert LocalComputerTarget.replace_port(":1234", 23000) == "localhost:23000" + + +class TestAgentOsTargetComputer: + def test_session_guid_unique_per_instance(self) -> None: + a = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + b = RemoteComputerTarget(address="5.6.7.8:23000", description="b") + assert a.session_guid != b.session_guid + + def test_computer_id_defaults_to_session_guid(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + assert s.computer_id == s.session_guid + + def test_explicit_computer_id_is_preserved(self) -> None: + s = RemoteComputerTarget( + address="1.2.3.4:23000", description="a", computer_id="laptop" + ) + assert s.computer_id == "laptop" + assert s.session_guid != "laptop" + + def test_display_defaults_to_one_and_is_settable(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + assert s.display == 1 + s.display = 3 + assert s.display == 3 + + def test_explicit_display_is_preserved(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a", display=2) + assert s.display == 2 + + def test_repr_contains_identity_fields(self) -> None: + s = RemoteComputerTarget( + address="1.2.3.4:23000", + description="my rig", + display=2, + computer_id="rig", + ) + r = repr(s) + assert "RemoteComputerTarget" in r + assert "computer_id='rig'" in r + assert "description='my rig'" in r + assert "display=2" in r + + def test_base_class_is_not_local(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + assert s.is_local is False + + def test_start_and_stop_are_no_ops_on_remote(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + s.start() + s.stop() + + +class TestLocalAgentOsTargetComputer: + def test_is_local(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.is_local is True + + def test_default_description(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.description == "Local computer target" + + def test_default_address(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.address == "localhost:23000" + + def test_is_service_default_false(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.is_service is False + + def test_explicit_computer_id(self) -> None: + s = LocalComputerTarget(discover_service=False, computer_id="my-laptop") + assert s.computer_id == "my-laptop" + + def test_parse_port_rejects_bad_address(self) -> None: + s = LocalComputerTarget(discover_service=False, address="no-port-here") + with pytest.raises(ValueError, match="Could not parse port"): + s._parse_port() # noqa: SLF001 - intentional unit test against helper + + def test_parse_port_extracts_port(self) -> None: + s = LocalComputerTarget(discover_service=False, address="localhost:24567") + assert s._parse_port() == 24567 # noqa: SLF001 + + +class TestSubclassesPassThroughDisplayAndId: + @pytest.mark.parametrize( + "factory", + [ + lambda: LocalComputerTarget( + discover_service=False, display=4, computer_id="local" + ), + lambda: RemoteComputerTarget( + address="1.2.3.4:23000", + description="r", + display=4, + computer_id="remote", + ), + ], + ) + def test_display_and_computer_id_round_trip( + self, factory: Callable[[], ComputerTarget] + ) -> None: + s: ComputerTarget = factory() + assert s.display == 4 + assert s.computer_id in {"local", "remote"} diff --git a/tests/unit/tools/askui/test_askui_controller_client.py b/tests/unit/tools/askui/test_askui_controller_client.py new file mode 100644 index 00000000..4c007f5a --- /dev/null +++ b/tests/unit/tools/askui/test_askui_controller_client.py @@ -0,0 +1,216 @@ +""" +Unit tests for `MultiComputerTargetAgentOS`'s multi-target registration / routing +logic. These tests intentionally avoid exercising the gRPC code path (which +needs a real controller binary). They cover the in-memory bookkeeping done by +the client and its `ComputerTargetPool`. +""" + +import pytest + +from askui.tools.askui.agent_os_target_computer import ( + LocalComputerTarget, + RemoteComputerTarget, +) +from askui.tools.askui.askui_controller import MultiComputerTargetAgentOS +from askui.tools.askui.computer_target_pool import ( + ComputerTargetPool, +) +from askui.tools.askui.exceptions import AskUiControllerError + + +def _make_local( + description: str = "local", computer_id: str | None = None, display: int = 1 +) -> LocalComputerTarget: + return LocalComputerTarget( + description=description, + discover_service=False, + computer_id=computer_id, + display=display, + ) + + +def _make_remote( + address: str = "1.2.3.4:23000", + description: str = "remote", + computer_id: str | None = None, + display: int = 1, +) -> RemoteComputerTarget: + return RemoteComputerTarget( + address=address, + description=description, + computer_id=computer_id, + display=display, + ) + + +class TestConstruction: + def test_default_registers_single_local_target(self) -> None: + client = MultiComputerTargetAgentOS() + manager = client.agent_os_target_computer_manager + assert len(manager) == 1 + assert isinstance(manager.active, LocalComputerTarget) + + def test_default_propagates_display_to_default_local_target(self) -> None: + client = MultiComputerTargetAgentOS(display=3) + active = client.agent_os_target_computer_manager.active + assert active is not None + assert active.display == 3 + + def test_accepts_explicit_targets(self) -> None: + a = _make_local(computer_id="local") + b = _make_remote(computer_id="remote") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + assert client.agent_os_target_computer_manager.describe() == [ + repr(a), + repr(b), + ] + assert client.agent_os_target_computer_manager.active is a + + def test_explicit_targets_keep_their_own_display(self) -> None: + """Constructor's display arg only seeds the auto-created default target.""" + a = _make_local(computer_id="local", display=2) + b = _make_remote(computer_id="remote", display=3) + client = MultiComputerTargetAgentOS(display=5, agent_os_target_computers=[a, b]) + assert client.agent_os_target_computer_manager.get("local").display == 2 + assert client.agent_os_target_computer_manager.get("remote").display == 3 + + def test_is_connected_false_before_connect(self) -> None: + client = MultiComputerTargetAgentOS(agent_os_target_computers=[_make_remote()]) + assert client.is_connected is False + + +class TestActiveTarget: + def test_get_current_returns_first_registered_id(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + assert client.get_current_computer_target_id() == "a" + + def test_get_current_with_empty_manager_raises(self) -> None: + client = MultiComputerTargetAgentOS(agent_os_target_computers=[_make_remote()]) + client.agent_os_target_computer_manager.reset() + with pytest.raises( + AskUiControllerError, match="No active Agent OS target computer" + ): + client.get_current_computer_target_id(report=False) + + +class TestSwitchAgentOsTargetComputer: + def test_switch_changes_active_when_disconnected(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + client.switch_agent_os_target_computer("b") + assert client.agent_os_target_computer_manager.active is b + + def test_switch_unknown_computer_id_raises_keyerror(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_local(computer_id="a")] + ) + with pytest.raises(KeyError, match="missing"): + client.switch_agent_os_target_computer("missing") + + def test_switch_returns_the_new_active_target(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + result = client.switch_agent_os_target_computer("b") + assert result is b + + def test_per_target_display_preserved_across_switch(self) -> None: + a = _make_local(computer_id="a", display=1) + b = _make_remote(computer_id="b", display=4) + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + client.switch_agent_os_target_computer("b") + active_b = client.agent_os_target_computer_manager.active + assert active_b is not None + assert active_b.display == 4 + client.switch_agent_os_target_computer("a") + active_a = client.agent_os_target_computer_manager.active + assert active_a is not None + assert active_a.display == 1 + + +class TestDescribeAndReset: + def test_describe_returns_registered_target_summaries(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + assert client.describe_agent_os_target_computers() == [repr(a), repr(b)] + + def test_reset_with_no_args_leaves_manager_empty(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_remote(computer_id="r")] + ) + client.reset_agent_os_target_computers() + assert client.describe_agent_os_target_computers() == [] + + def test_reset_with_new_list_replaces_registrations(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_remote(computer_id="old")] + ) + new_agent_os_target_computer = _make_remote( + address="9.9.9.9:23000", computer_id="new" + ) + client.reset_agent_os_target_computers([new_agent_os_target_computer]) + assert client.describe_agent_os_target_computers() == [ + repr(new_agent_os_target_computer) + ] + assert ( + client.agent_os_target_computer_manager.active + is new_agent_os_target_computer + ) + + +class TestAddAgentOsTargetComputerWhileDisconnected: + def test_add_already_constructed_target(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_local(computer_id="l")] + ) + extra = _make_remote(address="2.2.2.2:23000", computer_id="r") + result = client.add_agent_os_target_computer(extra) + assert result is extra + assert repr(extra) in client.describe_agent_os_target_computers() + + +class TestTemporarySelect: + def test_temporary_select_restores_previous_active(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + manager = client.agent_os_target_computer_manager + before = manager.active + assert before is a + with client.temporary_select("b"): + inside = manager.active + assert inside is b + after = manager.active + assert after is a + + def test_temporary_select_restores_previous_even_on_exception(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + error_message = "boom" + with ( + pytest.raises(RuntimeError, match=error_message), + client.temporary_select("b"), + ): + assert client.agent_os_target_computer_manager.active is b + raise RuntimeError(error_message) + assert client.agent_os_target_computer_manager.active is a + + def test_temporary_select_same_id_is_a_noop_around_yield(self) -> None: + a = _make_local(computer_id="a") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a]) + with client.temporary_select("a"): + assert client.agent_os_target_computer_manager.active is a + assert client.agent_os_target_computer_manager.active is a + + +class TestUsesAgentOsTargetComputerManager: + def test_underlying_manager_is_an_agent_os_target_computer_manager(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_local(computer_id="l")] + ) + assert isinstance(client.agent_os_target_computer_manager, ComputerTargetPool) diff --git a/tests/unit/tools/askui/test_askui_controller_client_settings.py b/tests/unit/tools/askui/test_askui_controller_client_settings.py deleted file mode 100644 index 3a086453..00000000 --- a/tests/unit/tools/askui/test_askui_controller_client_settings.py +++ /dev/null @@ -1,73 +0,0 @@ -from unittest.mock import patch - -import pytest -from pydantic import ValidationError - -from askui.tools.askui.askui_controller_client_settings import ( - AskUiControllerClientSettings, -) - - -class TestAskUiControllerClientSettings: - """Test suite for AskUiControllerClientSettings.""" - - def test_defaults(self) -> None: - """Defaults are applied when no environment variables are set.""" - with patch.dict("os.environ", {}, clear=True): - settings = AskUiControllerClientSettings() - assert settings.server_address == "localhost:23000" - assert settings.server_autostart is True - - def test_server_address_from_env(self) -> None: - """ - `ASKUI_CONTROLLER_CLIENT_SERVER_ADDRESS` overrides default for `server_address`. - """ - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_ADDRESS": "127.0.0.1:24000"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_address == "127.0.0.1:24000" - - def test_server_autostart_from_env_false(self) -> None: - """`ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART` parses boolean from env.""" - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "False"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_autostart is False - - def test_server_autostart_from_env_true(self) -> None: - """Boolean true value is parsed correctly from environment variable.""" - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "true"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_autostart is True - - def test_server_address_from_constructor(self) -> None: - """`server_address` is set correctly from constructor.""" - settings = AskUiControllerClientSettings(server_address="127.0.0.1:24000") - assert settings.server_address == "127.0.0.1:24000" - - def test_server_autostart_from_constructor(self) -> None: - """`server_autostart` is set correctly from constructor.""" - settings = AskUiControllerClientSettings(server_autostart=False) - assert settings.server_autostart is False - - def test_autostart_from_env_with_invalid_value(self) -> None: - """ - Test that ValidationError is raised when environment variable is invalid. - """ - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "invalid"}, - clear=True, - ): - with pytest.raises(ValidationError): - AskUiControllerClientSettings() diff --git a/tests/unit/tools/askui/test_computer_target_pool.py b/tests/unit/tools/askui/test_computer_target_pool.py new file mode 100644 index 00000000..65f33c25 --- /dev/null +++ b/tests/unit/tools/askui/test_computer_target_pool.py @@ -0,0 +1,203 @@ +from collections.abc import Callable + +import pytest + +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, + RemoteComputerTarget, +) +from askui.tools.askui.computer_target_pool import ( + ComputerTargetPool, +) + + +def _make_remote( + address: str = "1.2.3.4:23000", + description: str = "remote", + computer_id: str | None = None, +) -> RemoteComputerTarget: + return RemoteComputerTarget( + address=address, description=description, computer_id=computer_id + ) + + +def _make_local(computer_id: str | None = None) -> LocalComputerTarget: + return LocalComputerTarget(discover_service=False, computer_id=computer_id) + + +@pytest.fixture(params=["local", "remote"]) +def make_target( + request: pytest.FixtureRequest, +) -> Callable[..., ComputerTarget]: + """Build a single target of the parametrized kind so a test runs once per kind. + + Use for tests that register exactly one target and where the local/remote + distinction is irrelevant to the behavior under test. + """ + + def _make( + computer_id: str | None = None, + address: str = "1.2.3.4:23000", + ) -> ComputerTarget: + if request.param == "local": + return _make_local(computer_id=computer_id) + return _make_remote(address=address, computer_id=computer_id) + + return _make + + +class TestConstruction: + def test_empty_constructor_yields_empty_manager(self) -> None: + m = ComputerTargetPool() + assert m.describe() == [] + assert m.active is None + assert len(m) == 0 + + def test_constructor_registers_initial_targets_in_order(self) -> None: + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m = ComputerTargetPool(agent_os_target_computers=[a, b]) + assert m.describe() == [repr(a), repr(b)] + # First registered becomes active. + assert m.active is a + + def test_first_added_becomes_active( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + a = make_target(computer_id="a") + m.add(a) + assert m.active is a + + +class TestAddConstraints: + def test_rejects_second_local_target(self) -> None: + m = ComputerTargetPool() + m.add(_make_local(computer_id="first")) + with pytest.raises(ValueError, match="second local Agent OS target computer"): + m.add(_make_local(computer_id="second")) + + def test_rejects_duplicate_computer_id(self) -> None: + m = ComputerTargetPool() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="rig")) + with pytest.raises(ValueError, match="computer_id='rig'"): + m.add(_make_remote(address="2.2.2.2:23000", computer_id="rig")) + + def test_rejects_duplicate_remote_address(self) -> None: + m = ComputerTargetPool() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="a")) + with pytest.raises( + ValueError, + match="remote Agent OS target computer with address '1.1.1.1:23000'", + ): + m.add(_make_remote(address="1.1.1.1:23000", computer_id="b")) + + def test_allows_local_plus_remote_with_same_address(self) -> None: + m = ComputerTargetPool() + m.add(_make_local(computer_id="local")) + # Local target's default address is 'localhost:23000' but the local/remote + # address-uniqueness rule only applies between remote targets. + m.add( + _make_remote( + address="localhost:23000", description="remote", computer_id="remote" + ) + ) + assert len(m) == 2 + + +class TestGetAndSwitch: + def test_get_returns_target_by_computer_id( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + a = make_target(address="1.1.1.1:23000", computer_id="a") + m.add(a) + assert m.get("a") is a + + def test_get_raises_keyerror_with_registered_ids( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(address="1.1.1.1:23000", computer_id="a")) + with pytest.raises(KeyError) as exc_info: + m.get("missing") + message = str(exc_info.value) + assert "missing" in message + assert "'a'" in message # registered id surfaced + + def test_switch_changes_active(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + assert m.active is a + m.switch("b") + assert m.active is b + + def test_switch_unknown_id_raises_keyerror( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(computer_id="a")) + with pytest.raises(KeyError, match="missing"): + m.switch("missing") + + +class TestRemove: + def test_remove_drops_target(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + m.remove("a") + assert m.describe() == [repr(b)] + + def test_remove_active_falls_back_to_first_remaining(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + assert m.active is a + m.remove("a") + assert m.active is b + + def test_remove_last_clears_active( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(computer_id="a")) + m.remove("a") + assert m.active is None + assert len(m) == 0 + + def test_remove_inactive_keeps_active_unchanged(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + m.remove("b") + assert m.active is a + + def test_remove_unknown_raises_keyerror( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(computer_id="a")) + with pytest.raises(KeyError): + m.remove("missing") + + +class TestReset: + def test_reset_clears_all(self) -> None: + m = ComputerTargetPool() + m.add(_make_remote(computer_id="a")) + m.add(_make_remote(address="2.2.2.2:23000", computer_id="b")) + m.reset() + assert m.describe() == [] + assert m.active is None + assert len(m) == 0 diff --git a/tests/unit/tools/computer/__init__.py b/tests/unit/tools/computer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/tools/computer/test_agent_os_target_computer_tools.py b/tests/unit/tools/computer/test_agent_os_target_computer_tools.py new file mode 100644 index 00000000..c3c6132a --- /dev/null +++ b/tests/unit/tools/computer/test_agent_os_target_computer_tools.py @@ -0,0 +1,80 @@ +from unittest.mock import MagicMock + +import pytest + +from askui.tools.agent_os import ComputerAgentOS +from askui.tools.askui.agent_os_target_computer import RemoteComputerTarget +from askui.tools.computer import ( + ComputerGetCurrentComputerTargetIdTool, + ComputerListAgentOsTargetComputersTool, + ComputerSwitchAgentOsTargetComputerTool, +) + + +@pytest.fixture +def fake_agent_os() -> MagicMock: + """A MagicMock that passes `isinstance(x, ComputerAgentOS)` checks.""" + return MagicMock(spec=ComputerAgentOS) + + +class TestComputerListAgentOsTargetComputersTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerListAgentOsTargetComputersTool(agent_os=fake_agent_os) + assert tool.base_name == "list_agent_os_target_computers" + + def test_returns_newline_joined_reprs(self, fake_agent_os: MagicMock) -> None: + a = RemoteComputerTarget( + address="1.1.1.1:23000", description="a", computer_id="a" + ) + b = RemoteComputerTarget( + address="2.2.2.2:23000", description="b", computer_id="b" + ) + fake_agent_os.describe_agent_os_target_computers.return_value = [ + repr(a), + repr(b), + ] + tool = ComputerListAgentOsTargetComputersTool(agent_os=fake_agent_os) + out = tool() + assert out == f"{a!r}\n{b!r}" + + def test_empty_list_yields_empty_string(self, fake_agent_os: MagicMock) -> None: + fake_agent_os.describe_agent_os_target_computers.return_value = [] + tool = ComputerListAgentOsTargetComputersTool(agent_os=fake_agent_os) + assert tool() == "" + + +class TestComputerSwitchAgentOsTargetComputerTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerSwitchAgentOsTargetComputerTool(agent_os=fake_agent_os) + assert tool.base_name == "switch_agent_os_target_computer" + + def test_input_schema_requires_computer_id(self, fake_agent_os: MagicMock) -> None: + tool = ComputerSwitchAgentOsTargetComputerTool(agent_os=fake_agent_os) + schema = tool.input_schema + assert "computer_id" in schema["properties"] + assert schema["required"] == ["computer_id"] + + def test_call_delegates_to_switch_agent_os_target_computer( + self, fake_agent_os: MagicMock + ) -> None: + switched = RemoteComputerTarget( + address="1.1.1.1:23000", description="new", computer_id="new" + ) + fake_agent_os.switch_agent_os_target_computer.return_value = switched + tool = ComputerSwitchAgentOsTargetComputerTool(agent_os=fake_agent_os) + out = tool(computer_id="new") + fake_agent_os.switch_agent_os_target_computer.assert_called_once_with("new") + assert out == repr(switched) + + +class TestComputerGetCurrentComputerTargetIdTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerGetCurrentComputerTargetIdTool(agent_os=fake_agent_os) + assert tool.base_name == "get_current_computer_target_id" + + def test_call_returns_current_computer_id(self, fake_agent_os: MagicMock) -> None: + fake_agent_os.get_current_computer_target_id.return_value = "a" + tool = ComputerGetCurrentComputerTargetIdTool(agent_os=fake_agent_os) + out = tool() + fake_agent_os.get_current_computer_target_id.assert_called_once_with() + assert out == "a"