diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/data/db/AppDatabase.kt b/agent/app/src/main/java/com/example/appfunctions/agent/data/db/AppDatabase.kt index f6b703a..ae383e6 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/data/db/AppDatabase.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/data/db/AppDatabase.kt @@ -21,7 +21,7 @@ import com.example.appfunctions.agent.data.db.dao.ChatDao import com.example.appfunctions.agent.data.db.entities.MessageEntity import com.example.appfunctions.agent.data.db.entities.ThreadEntity -@Database(entities = [ThreadEntity::class, MessageEntity::class], version = 1, exportSchema = false) +@Database(entities = [ThreadEntity::class, MessageEntity::class], version = 2, exportSchema = false) abstract class AppDatabase : RoomDatabase() { abstract fun chatDao(): ChatDao } diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/data/db/entities/MessageEntity.kt b/agent/app/src/main/java/com/example/appfunctions/agent/data/db/entities/MessageEntity.kt index 7c252f4..075700b 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/data/db/entities/MessageEntity.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/data/db/entities/MessageEntity.kt @@ -45,6 +45,7 @@ data class MessageEntity( * only non-null if Assistant returned PendingIntent tool response. */ val pendingIntentId: String? = null, + val targetPackageName: String? = null, ) enum class MessageRole { diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/di/DataModule.kt b/agent/app/src/main/java/com/example/appfunctions/agent/di/DataModule.kt index 52d4fd0..78ddcc7 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/di/DataModule.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/di/DataModule.kt @@ -85,6 +85,7 @@ abstract class DataModule { AppDatabase::class.java, "app_database", ) + .fallbackToDestructiveMigration() .build() } diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/domain/AgentOrchestrator.kt b/agent/app/src/main/java/com/example/appfunctions/agent/domain/AgentOrchestrator.kt index cee4777..b48931d 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/domain/AgentOrchestrator.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/domain/AgentOrchestrator.kt @@ -95,7 +95,6 @@ class AgentOrchestrator try { val provider = thread.llmModel.providerName - val modelName = thread.llmModel.modelName val apiKey = getApiKey(provider) if (apiKey == null) { completeMessageWithError( @@ -108,39 +107,21 @@ class AgentOrchestrator } val disconnectedApps = settingsRepository.disconnectedApps.first() - val tools = - getAppFunctionsUseCase().first().values.flatten().filter { metadata -> - metadata.isEnabled && metadata.packageName !in disconnectedApps - } - var previousInteractionId = thread.latestInteractionId - val currentInput = message.textContent - var currentToolOutputs = emptyList() - var continueLoop = true - - val llmProvider = llmProviderFactory.getProvider(provider) - - while (continueLoop) { - val llmInput = prepareLlmInput(currentToolOutputs, currentInput) - - currentToolOutputs = emptyList() - val response = - llmProvider.generateResponse( - previousInteractionId = previousInteractionId, - input = llmInput, - tools = tools, - apiKey = apiKey, - modelName = modelName, - ) - when (val handleResult = handleLlmResponse(response, message, tools)) { - is HandleResult.Continue -> { - currentToolOutputs = handleResult.toolOutputs - previousInteractionId = handleResult.interactionId - } - is HandleResult.Stop -> { - continueLoop = false - } - } - } + val allTools = getAppFunctionsUseCase().first().values.flatten() + + val targetPackageName = message.targetPackageName + val queryText = message.textContent + + val tools = filterTools(allTools, disconnectedApps, targetPackageName) + + runInteractionLoop( + message = message, + thread = thread, + apiKey = apiKey, + tools = tools, + initialInput = queryText, + ) + _status.value = AgentStatus.Idle } catch (e: Exception) { Log.e("AgentOrchestrator", "Error processing message", e) @@ -153,6 +134,60 @@ class AgentOrchestrator } } + private fun filterTools( + allTools: List, + disconnectedApps: Set, + targetPackageName: String?, + ): List { + return allTools.filter { metadata -> + metadata.isEnabled && + metadata.packageName !in disconnectedApps && + (targetPackageName == null || metadata.packageName == targetPackageName) + } + } + + private suspend fun runInteractionLoop( + message: MessageEntity, + thread: ThreadEntity, + apiKey: String, + tools: List, + initialInput: String, + ) { + val provider = thread.llmModel.providerName + val modelName = thread.llmModel.modelName + val llmProvider = llmProviderFactory.getProvider(provider) + + var previousInteractionId = thread.latestInteractionId + var currentToolOutputs = emptyList() + var continueLoop = true + var currentInput = initialInput + + while (continueLoop) { + val llmInput = prepareLlmInput(currentToolOutputs, currentInput) + + currentToolOutputs = emptyList() + val response = + llmProvider.generateResponse( + previousInteractionId = previousInteractionId, + input = llmInput, + tools = tools, + apiKey = apiKey, + modelName = modelName, + ) + + when (val handleResult = handleLlmResponse(response, message, tools)) { + is HandleResult.Continue -> { + currentToolOutputs = handleResult.toolOutputs + previousInteractionId = handleResult.interactionId + } + + is HandleResult.Stop -> { + continueLoop = false + } + } + } + } + private suspend fun getApiKey(provider: LlmProviderName): String? { return when (provider) { LlmProviderName.GEMINI -> settingsRepository.geminiApiKey.first() @@ -222,6 +257,7 @@ class AgentOrchestrator } HandleResult.Continue(toolResult.toolOutputs, response.interactionId) } + is ExecuteToolCallsResult.PendingIntentAction -> { savePendingIntentUseCase( toolResult.pendingIntentId, @@ -236,6 +272,7 @@ class AgentOrchestrator ) HandleResult.Stop } + is ExecuteToolCallsResult.Error -> { HandleResult.Stop } @@ -252,6 +289,7 @@ class AgentOrchestrator HandleResult.Stop } } + is LlmResponse.Error -> { Log.e("AgentOrchestrator", "LLM Error: ${response.errorMessage}") completeMessageWithError(message.messageId, message.threadId, response.errorMessage) @@ -321,6 +359,7 @@ class AgentOrchestrator ), ) } + is ExecuteAppFunctionResult.PendingIntentAction -> { val pendingIntentId = java.util.UUID.randomUUID().toString() return ExecuteToolCallsResult.PendingIntentAction( @@ -328,6 +367,7 @@ class AgentOrchestrator executionResult.pendingIntent, ) } + is ExecuteAppFunctionResult.Error -> throw IllegalStateException( "Tool execution failed for ${toolCall.functionId}: ${executionResult.exception.message}", diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/domain/chat/SendMessageUseCase.kt b/agent/app/src/main/java/com/example/appfunctions/agent/domain/chat/SendMessageUseCase.kt index 1a17d25..91c3518 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/domain/chat/SendMessageUseCase.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/domain/chat/SendMessageUseCase.kt @@ -42,6 +42,7 @@ class SendMessageUseCase textContent: String, processingStatus: MessageProcessingStatus, pendingIntentId: String? = null, + targetPackageName: String? = null, ) { val message = MessageEntity( @@ -52,6 +53,7 @@ class SendMessageUseCase timestamp = System.currentTimeMillis(), processingStatus = processingStatus, pendingIntentId = pendingIntentId, + targetPackageName = targetPackageName, ) chatRepository.sendMessage(message) } diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoScreen.kt b/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoScreen.kt index 046f594..853e73d 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoScreen.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoScreen.kt @@ -41,6 +41,8 @@ import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.InlineTextContent +import androidx.compose.foundation.text.appendInlineContent import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.filled.Send import androidx.compose.material.icons.filled.Add @@ -48,6 +50,8 @@ import androidx.compose.material.icons.filled.ArrowDropDown import androidx.compose.material.icons.filled.Menu import androidx.compose.material.icons.filled.Warning import androidx.compose.material3.ButtonDefaults +import androidx.compose.material3.Card +import androidx.compose.material3.CardDefaults import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.DrawerState import androidx.compose.material3.DrawerValue @@ -85,10 +89,31 @@ import androidx.compose.ui.input.key.onPreviewKeyEvent import androidx.compose.ui.input.key.type import androidx.compose.ui.platform.LocalConfiguration import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.AnnotatedString +import androidx.compose.ui.text.Placeholder +import androidx.compose.ui.text.PlaceholderVerticalAlign +import androidx.compose.ui.text.SpanStyle +import androidx.compose.ui.text.TextRange +import androidx.compose.ui.text.buildAnnotatedString +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.OffsetMapping +import androidx.compose.ui.text.input.TextFieldValue +import androidx.compose.ui.text.input.TransformedText +import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.text.rememberTextMeasurer import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.text.withStyle +import androidx.compose.ui.unit.IntOffset +import androidx.compose.ui.unit.IntRect +import androidx.compose.ui.unit.IntSize +import androidx.compose.ui.unit.LayoutDirection import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Popup +import androidx.compose.ui.window.PopupPositionProvider +import androidx.compose.ui.window.PopupProperties import androidx.core.graphics.drawable.toBitmap import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.compose.collectAsStateWithLifecycle @@ -99,6 +124,7 @@ import com.example.appfunctions.agent.data.db.entities.MessageProcessingStatus import com.example.appfunctions.agent.data.db.entities.MessageRole import com.example.appfunctions.agent.data.db.entities.ThreadEntity import com.example.appfunctions.agent.domain.AgentStatus +import com.example.appfunctions.agent.domain.appfunction.AppInfo import com.example.appfunctions.agent.ui.screens.debugging.LazyExposedDropdownMenu import com.mikepenz.markdown.m3.Markdown import kotlinx.coroutines.CoroutineScope @@ -116,7 +142,7 @@ fun AgentDemoScreen(viewModel: AgentDemoViewModel = hiltViewModel()) { fun AgentDemoContent( uiState: AgentUiState, onEvent: (AgentUiEvent) -> Unit, - initialSidePanelVisible: Boolean = false, + initialSidePanelVisible: Boolean = false ) { val context = LocalContext.current val packageManager = context.packageManager @@ -136,6 +162,7 @@ fun AgentDemoContent( is AgentUiState.Loading -> { AgentDemoLoadingScreen() } + is AgentUiState.Loaded -> { AgentDemoLoadedScreen( uiState = uiState, @@ -144,7 +171,7 @@ fun AgentDemoContent( drawerState = drawerState, scope = scope, packageManager = packageManager, - initialSidePanelVisible = initialSidePanelVisible, + initialSidePanelVisible = initialSidePanelVisible ) } } @@ -160,7 +187,7 @@ fun AgentDemoContent( drawerState = drawerState, drawerContent = { ModalDrawerSheet( - drawerContainerColor = MaterialTheme.colorScheme.surface, + drawerContainerColor = MaterialTheme.colorScheme.surface ) { ChatHistorySidePanel( threads = threads, @@ -168,10 +195,10 @@ fun AgentDemoContent( onEvent = { event -> onEvent(event) scope.launch { drawerState.close() } - }, + } ) } - }, + } ) { content() } @@ -194,10 +221,18 @@ fun AgentDemoLoadedScreen( drawerState: DrawerState, scope: CoroutineScope, packageManager: PackageManager, - initialSidePanelVisible: Boolean = false, + initialSidePanelVisible: Boolean = false ) { - var messageText by remember { mutableStateOf("") } + var messageText by remember { mutableStateOf(TextFieldValue("")) } var isSidePanelVisible by remember { mutableStateOf(initialSidePanelVisible) } + var selectedAppPackageName by remember { mutableStateOf(null) } + + val chipBgColor = MaterialTheme.colorScheme.primaryContainer + val chipTextColor = MaterialTheme.colorScheme.onPrimaryContainer + val visualTransformation = + remember(uiState.installedApps, chipTextColor) { + InlineAppScopingVisualTransformation(uiState.installedApps, chipTextColor) + } Scaffold( modifier = Modifier.fillMaxSize(), @@ -205,10 +240,13 @@ fun AgentDemoLoadedScreen( topBar = { Row( modifier = Modifier.padding(horizontal = 8.dp, vertical = 16.dp), - verticalAlignment = Alignment.CenterVertically, + verticalAlignment = Alignment.CenterVertically ) { ModelDropdown( - modifier = Modifier.weight(1f).padding(horizontal = 8.dp), + modifier = + Modifier + .weight(1f) + .padding(horizontal = 8.dp), currentThread = uiState.currentThread, onModelSelected = { onEvent(AgentUiEvent.OnModelSelected(it)) }, onMenuClick = { @@ -217,38 +255,39 @@ fun AgentDemoLoadedScreen( } else { scope.launch { drawerState.open() } } - }, + } ) IconButton( onClick = { onEvent(AgentUiEvent.OnCreateThread(uiState.currentThread.llmModel)) }, - modifier = Modifier.padding(horizontal = 8.dp), + modifier = Modifier.padding(horizontal = 8.dp) ) { Icon(imageVector = Icons.Default.Add, contentDescription = "Create Thread") } } - }, + } ) { paddingValues -> Row( modifier = - Modifier.fillMaxSize() - .imePadding() - .padding( - top = paddingValues.calculateTopPadding(), - ), + Modifier + .fillMaxSize() + .imePadding() + .padding( + top = paddingValues.calculateTopPadding() + ) ) { // Side Panel (only for wide screens) if (isWideScreen) { AnimatedVisibility( visible = isSidePanelVisible, enter = slideInHorizontally() + expandHorizontally(), - exit = slideOutHorizontally() + shrinkHorizontally(), + exit = slideOutHorizontally() + shrinkHorizontally() ) { ChatHistorySidePanel( threads = uiState.threads, currentThread = uiState.currentThread, - onEvent = onEvent, + onEvent = onEvent ) } } @@ -256,14 +295,20 @@ fun AgentDemoLoadedScreen( // Main Chat Area Column( modifier = - Modifier.weight(1f).fillMaxHeight().padding(start = 16.dp, end = 16.dp), - verticalArrangement = Arrangement.SpaceBetween, + Modifier + .weight(1f) + .fillMaxHeight() + .padding(start = 16.dp, end = 16.dp), + verticalArrangement = Arrangement.SpaceBetween ) { // Messages List LazyColumn( modifier = - Modifier.weight(1f).fillMaxWidth().clip(RoundedCornerShape(16.dp)), - reverseLayout = true, + Modifier + .weight(1f) + .fillMaxWidth() + .clip(RoundedCornerShape(16.dp)), + reverseLayout = true ) { // Status item at the bottom (above input) if not // idle @@ -271,73 +316,196 @@ fun AgentDemoLoadedScreen( item { StatusIndicator( status = uiState.status, - packageManager = packageManager, + packageManager = packageManager ) } } items( items = uiState.messages.reversed(), - key = { message -> message.messageId }, + key = { message -> message.messageId } ) { message -> MessageBubble( message = message, isValidAction = - message.pendingIntentId in uiState.activePendingActionIds, - onConfirmAction = { onEvent(AgentUiEvent.OnConfirmAction(it)) }, + message.pendingIntentId in uiState.activePendingActionIds, + installedApps = uiState.installedApps, + onConfirmAction = { onEvent(AgentUiEvent.OnConfirmAction(it)) } ) } } val sendMessage = { - if (messageText.isNotBlank() && uiState.status == AgentStatus.Idle) { - onEvent(AgentUiEvent.OnSendMessage(messageText)) - messageText = "" + val textStr = messageText.text + if (textStr.isNotBlank() && uiState.status == AgentStatus.Idle) { + onEvent(AgentUiEvent.OnSendMessage(textStr, selectedAppPackageName)) + messageText = TextFieldValue("") + selectedAppPackageName = null } } + val textStr = messageText.text + val lastAtIndex = textStr.lastIndexOf('@') + val showAutocomplete = lastAtIndex >= 0 && + (lastAtIndex == 0 || textStr[lastAtIndex - 1].isWhitespace()) && + selectedAppPackageName == null + val autocompleteQuery = + if (showAutocomplete) { + textStr.substring(lastAtIndex + 1) + } else { + "" + } + val filteredApps = + remember(autocompleteQuery, uiState.installedApps) { + if (autocompleteQuery.isEmpty()) { + uiState.installedApps + } else { + uiState.installedApps.filter { + it.label.contains(autocompleteQuery, ignoreCase = true) + } + } + } + + val density = LocalDensity.current + val popupPositionProvider = + remember(density) { + object : PopupPositionProvider { + override fun calculatePosition( + anchorBounds: IntRect, + windowSize: IntSize, + layoutDirection: LayoutDirection, + popupContentSize: IntSize + ): IntOffset { + val gap = with(density) { 2.dp.roundToPx() } + return IntOffset( + x = anchorBounds.left, + y = anchorBounds.top - popupContentSize.height - gap + ) + } + } + } + + val appMentionRegex = + remember(uiState.installedApps) { + if (uiState.installedApps.isNotEmpty()) { + val appLabelsPattern = + uiState.installedApps.joinToString("|") { Regex.escape(it.label) } + Regex("@($appLabelsPattern)\\b", RegexOption.IGNORE_CASE) + } else { + null + } + } + // Input area - OutlinedTextField( - value = messageText, - onValueChange = { messageText = it }, - modifier = - Modifier.fillMaxWidth().padding(vertical = 16.dp).onPreviewKeyEvent { - keyEvent -> - if ( - (keyEvent.key == Key.Enter || keyEvent.key == Key.NumPadEnter) && - keyEvent.type == KeyEventType.KeyDown - ) { - sendMessage() - true - } else { - false + Box(modifier = Modifier.fillMaxWidth()) { + OutlinedTextField( + value = messageText, + onValueChange = { newValue -> + messageText = newValue + val currentText = newValue.text + if (selectedAppPackageName != null && appMentionRegex != null) { + if (!appMentionRegex.containsMatchIn(currentText)) { + selectedAppPackageName = null + } } }, - enabled = uiState.status == AgentStatus.Idle, - shape = CircleShape, - placeholder = { Text(stringResource(R.string.agent_demo_ask_agent)) }, - colors = + modifier = + Modifier + .fillMaxWidth() + .padding(vertical = 16.dp) + .onPreviewKeyEvent { keyEvent -> + if ( + (keyEvent.key == Key.Enter || keyEvent.key == Key.NumPadEnter) && + keyEvent.type == KeyEventType.KeyDown + ) { + sendMessage() + true + } else { + false + } + }, + enabled = uiState.status == AgentStatus.Idle, + shape = CircleShape, + placeholder = { Text(stringResource(R.string.agent_demo_ask_agent)) }, + visualTransformation = visualTransformation, + colors = OutlinedTextFieldDefaults.colors( unfocusedContainerColor = MaterialTheme.colorScheme.surfaceBright, focusedContainerColor = MaterialTheme.colorScheme.surfaceBright, unfocusedBorderColor = Color.Transparent, - focusedBorderColor = Color.Transparent, + focusedBorderColor = Color.Transparent ), - trailingIcon = { - IconButton( - onClick = sendMessage, - enabled = - messageText.isNotBlank() && - uiState.status == AgentStatus.Idle, + trailingIcon = { + IconButton( + onClick = sendMessage, + enabled = + messageText.text.isNotBlank() && + uiState.status == AgentStatus.Idle + ) { + Icon( + imageVector = Icons.AutoMirrored.Filled.Send, + contentDescription = + stringResource(R.string.agent_demo_send) + ) + } + } + ) + + if (showAutocomplete && filteredApps.isNotEmpty()) { + Popup( + popupPositionProvider = popupPositionProvider, + onDismissRequest = {}, + properties = PopupProperties(focusable = false) ) { - Icon( - imageVector = Icons.AutoMirrored.Filled.Send, - contentDescription = - stringResource(R.string.agent_demo_send), - ) + Card( + modifier = Modifier.fillMaxWidth(0.9f), + elevation = CardDefaults.cardElevation(defaultElevation = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceBright + ), + shape = MaterialTheme.shapes.medium + ) { + Column(modifier = Modifier.padding(vertical = 4.dp)) { + filteredApps.take(5).forEach { app -> + DropdownMenuItem( + text = { Text(app.label) }, + onClick = { + val currentText = messageText.text + val selectionStart = messageText.selection.start + val textBeforeCursor = currentText.take( + selectionStart + ) + val textAfterCursor = currentText.drop( + selectionStart + ) + val mentionIndex = textBeforeCursor.lastIndexOf('@') + if (mentionIndex >= 0) { + val textBeforeMention = + textBeforeCursor.substring( + 0, + mentionIndex + ) + val newText = + "$textBeforeMention@${app.label} $textAfterCursor" + val newCursorPosition = + mentionIndex + app.label.length + 2 + messageText = + TextFieldValue( + text = newText, + selection = TextRange( + newCursorPosition + ) + ) + selectedAppPackageName = app.packageName + } + } + ) + } + } + } } - }, - ) + } + } } } } @@ -349,19 +517,19 @@ fun ModelDropdown( modifier: Modifier = Modifier, currentThread: ThreadEntity?, onModelSelected: (LlmModel) -> Unit, - onMenuClick: () -> Unit, + onMenuClick: () -> Unit ) { var expanded by remember { mutableStateOf(false) } ExposedDropdownMenuBox( modifier = modifier, expanded = expanded, - onExpandedChange = { expanded = !expanded }, + onExpandedChange = { expanded = !expanded } ) { Surface( modifier = Modifier.padding(bottom = 8.dp), shadowElevation = 2.dp, shape = CircleShape, - color = MaterialTheme.colorScheme.surfaceBright, + color = MaterialTheme.colorScheme.surfaceBright ) { val text = currentThread?.llmModel?.modelName @@ -374,34 +542,39 @@ fun ModelDropdown( } Row( - modifier = Modifier.fillMaxWidth().height(56.dp).padding(start = 4.dp, end = 16.dp), - verticalAlignment = Alignment.CenterVertically, + modifier = + Modifier + .fillMaxWidth() + .height(56.dp) + .padding(start = 4.dp, end = 16.dp), + verticalAlignment = Alignment.CenterVertically ) { IconButton(onClick = onMenuClick) { Icon(imageVector = Icons.Default.Menu, contentDescription = "Menu") } Row( modifier = - Modifier.weight(1f) - .fillMaxHeight() - .menuAnchor( - ExposedDropdownMenuAnchorType.PrimaryEditable, - enabled = true, - ), - verticalAlignment = Alignment.CenterVertically, + Modifier + .weight(1f) + .fillMaxHeight() + .menuAnchor( + ExposedDropdownMenuAnchorType.PrimaryEditable, + enabled = true + ), + verticalAlignment = Alignment.CenterVertically ) { Column(modifier = Modifier.weight(1f)) { Text( text = stringResource(R.string.agent_demo_title), style = MaterialTheme.typography.labelMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, + color = MaterialTheme.colorScheme.onSurfaceVariant ) Text( text = text, style = MaterialTheme.typography.bodyMedium, color = textColor, maxLines = 1, - overflow = TextOverflow.Ellipsis, + overflow = TextOverflow.Ellipsis ) } Icon(imageVector = Icons.Default.ArrowDropDown, contentDescription = null) @@ -414,21 +587,21 @@ fun ModelDropdown( onDismissRequest = { expanded = false }, modifier = Modifier.exposedDropdownSize(), containerColor = MaterialTheme.colorScheme.surfaceBright, - shape = RoundedCornerShape(28.dp), + shape = RoundedCornerShape(28.dp) ) { item { Text( "--- Gemini ---", color = MaterialTheme.colorScheme.secondary, modifier = Modifier.padding(horizontal = 24.dp, vertical = 8.dp), - style = MaterialTheme.typography.labelLarge, + style = MaterialTheme.typography.labelLarge ) } val models = listOf( LlmModel.GEMINI_3_1_PRO_PREVIEW, LlmModel.GEMINI_3_FLASH_PREVIEW, - LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW, + LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW ) items(models) { model -> DropdownMenuItem( @@ -437,7 +610,7 @@ fun ModelDropdown( onModelSelected(model) expanded = false }, - contentPadding = PaddingValues(horizontal = 24.dp, vertical = 8.dp), + contentPadding = PaddingValues(horizontal = 24.dp, vertical = 8.dp) ) } } @@ -448,7 +621,8 @@ fun ModelDropdown( fun MessageBubble( message: MessageEntity, isValidAction: Boolean, - onConfirmAction: (String) -> Unit, + installedApps: List, + onConfirmAction: (String) -> Unit ) { val alignment = if (message.role == MessageRole.USER) Alignment.End else Alignment.Start val isError = message.processingStatus == MessageProcessingStatus.FAILED @@ -466,13 +640,16 @@ fun MessageBubble( } Column( - modifier = Modifier.fillMaxWidth().padding(vertical = 4.dp, horizontal = 2.dp), - horizontalAlignment = alignment, + modifier = + Modifier + .fillMaxWidth() + .padding(vertical = 4.dp, horizontal = 2.dp), + horizontalAlignment = alignment ) { Surface( shape = MaterialTheme.shapes.large, color = backgroundColor, - shadowElevation = if (message.role == MessageRole.ASSISTANT) 1.dp else 0.dp, + shadowElevation = if (message.role == MessageRole.ASSISTANT) 1.dp else 0.dp ) { Column(modifier = Modifier.padding(12.dp)) { Row(verticalAlignment = Alignment.CenterVertically) { @@ -480,7 +657,7 @@ fun MessageBubble( Icon( imageVector = Icons.Filled.Warning, contentDescription = stringResource(R.string.debugging_error), - tint = textColor, + tint = textColor ) Spacer(modifier = Modifier.width(8.dp)) } @@ -495,10 +672,89 @@ fun MessageBubble( if (message.role != MessageRole.USER) { Markdown(content = contentText) } else { + val chipBgColor = MaterialTheme.colorScheme.primary + val chipTextColor = MaterialTheme.colorScheme.onPrimary + val formattedText = + remember(contentText, installedApps) { + formatMessageText(contentText, installedApps) + } + val textMeasurer = rememberTextMeasurer() + val typographyStyle = MaterialTheme.typography.bodyLarge + val density = LocalDensity.current + + val inlineContentMap = + remember( + contentText, + installedApps, + chipBgColor, + chipTextColor, + density + ) { + val map = mutableMapOf() + if (installedApps.isNotEmpty() && contentText.contains("@")) { + val appLabelsPattern = installedApps.joinToString( + "|" + ) { Regex.escape(it.label) } + val regex = + Regex("@($appLabelsPattern)\\b", RegexOption.IGNORE_CASE) + regex.findAll(contentText).forEachIndexed { index, match -> + val id = "chip_$index" + val appName = match.value + val measured = + textMeasurer.measure( + text = appName, + style = typographyStyle.copy( + fontWeight = FontWeight.Bold + ) + ) + val widthSp = + with( + density + ) { (measured.size.width + 8.dp.roundToPx()).toSp() } + val heightSp = + with( + density + ) { (measured.size.height + 2.dp.roundToPx()).toSp() } + + map[id] = + InlineTextContent( + Placeholder( + width = widthSp, + height = heightSp, + placeholderVerticalAlign = PlaceholderVerticalAlign.TextCenter + ) + ) { + Surface( + shape = androidx.compose.foundation.shape.RoundedCornerShape( + 6.dp + ), + color = chipBgColor + ) { + Box(contentAlignment = Alignment.Center) { + Text( + text = appName, + color = chipTextColor, + style = typographyStyle.copy( + fontWeight = FontWeight.Bold + ), + modifier = Modifier.padding( + horizontal = 4.dp, + vertical = 1.dp + ) + ) + } + } + } + } + } + map + } + Text( - text = contentText, + text = formattedText, + inlineContent = inlineContentMap, color = textColor, - style = MaterialTheme.typography.bodyLarge, + style = typographyStyle ) } } @@ -510,17 +766,17 @@ fun MessageBubble( enabled = isValidAction, shape = CircleShape, colors = - ButtonDefaults.buttonColors( - containerColor = MaterialTheme.colorScheme.primary, - contentColor = MaterialTheme.colorScheme.onPrimary, - ), + ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.primary, + contentColor = MaterialTheme.colorScheme.onPrimary + ) ) { Text( if (isValidAction) { stringResource(R.string.agent_demo_confirm_action) } else { stringResource(R.string.agent_demo_action_expired) - }, + } ) } } @@ -530,24 +786,25 @@ fun MessageBubble( } @Composable -fun StatusIndicator( - status: AgentStatus, - packageManager: PackageManager, -) { +fun StatusIndicator(status: AgentStatus, packageManager: PackageManager) { when (status) { AgentStatus.Thinking -> { Row( - modifier = Modifier.fillMaxWidth().padding(vertical = 8.dp), - verticalAlignment = Alignment.CenterVertically, + modifier = + Modifier + .fillMaxWidth() + .padding(vertical = 8.dp), + verticalAlignment = Alignment.CenterVertically ) { CircularProgressIndicator(modifier = Modifier.size(24.dp)) Spacer(modifier = Modifier.width(8.dp)) Text( stringResource(R.string.agent_demo_thinking), - style = MaterialTheme.typography.bodyMedium, + style = MaterialTheme.typography.bodyMedium ) } } + is AgentStatus.InvokingTool -> { val appName = try { @@ -564,20 +821,23 @@ fun StatusIndicator( } Surface( - modifier = Modifier.fillMaxWidth().padding(vertical = 8.dp), + modifier = + Modifier + .fillMaxWidth() + .padding(vertical = 8.dp), shape = MaterialTheme.shapes.large, color = MaterialTheme.colorScheme.surfaceBright, - shadowElevation = 2.dp, + shadowElevation = 2.dp ) { Row( modifier = Modifier.padding(12.dp), - verticalAlignment = Alignment.CenterVertically, + verticalAlignment = Alignment.CenterVertically ) { appIcon?.let { Image( bitmap = it.toBitmap().asImageBitmap(), contentDescription = null, - modifier = Modifier.size(40.dp), + modifier = Modifier.size(40.dp) ) Spacer(modifier = Modifier.width(12.dp)) } @@ -588,13 +848,14 @@ fun StatusIndicator( Spacer(modifier = Modifier.width(8.dp)) Text( stringResource(R.string.agent_demo_connecting), - style = MaterialTheme.typography.bodyMedium, + style = MaterialTheme.typography.bodyMedium ) } } } } } + AgentStatus.Idle -> { // Nothing to show } @@ -606,20 +867,26 @@ fun ChatHistorySidePanel( threads: List, currentThread: ThreadEntity?, onEvent: (AgentUiEvent) -> Unit, - modifier: Modifier = Modifier, + modifier: Modifier = Modifier ) { - Column(modifier = modifier.width(280.dp).fillMaxHeight().padding(16.dp)) { + Column( + modifier = + modifier + .width(280.dp) + .fillMaxHeight() + .padding(16.dp) + ) { Text( text = stringResource(R.string.agent_demo_chat_history), style = MaterialTheme.typography.titleLarge, color = MaterialTheme.colorScheme.onSurface, - modifier = Modifier.padding(bottom = 16.dp), + modifier = Modifier.padding(bottom = 16.dp) ) LazyColumn(modifier = Modifier.fillMaxSize()) { items( items = threads, - key = { thread -> thread.threadId }, + key = { thread -> thread.threadId } ) { thread -> val isSelected = thread.threadId == currentThread?.threadId val backgroundColor = @@ -637,23 +904,26 @@ fun ChatHistorySidePanel( Surface( modifier = - Modifier.fillMaxWidth().padding(vertical = 4.dp).clickable { + Modifier + .fillMaxWidth() + .padding(vertical = 4.dp) + .clickable { onEvent(AgentUiEvent.OnThreadSelected(thread.threadId)) }, shape = MaterialTheme.shapes.medium, color = backgroundColor, - contentColor = textColor, + contentColor = textColor ) { Column(modifier = Modifier.padding(12.dp)) { Text( text = thread.llmModel.modelName, style = MaterialTheme.typography.bodyMedium, - color = textColor, + color = textColor ) Text( text = "ID: ${thread.threadId.take(8)}", style = MaterialTheme.typography.bodySmall, - color = textColor.copy(alpha = 0.7f), + color = textColor.copy(alpha = 0.7f) ) } } @@ -661,3 +931,73 @@ fun ChatHistorySidePanel( } } } + +class InlineAppScopingVisualTransformation( + private val installedApps: List, + private val chipTextColor: Color +) : VisualTransformation { + private val regex: Regex? = + if (installedApps.isNotEmpty()) { + val appLabelsPattern = installedApps.joinToString("|") { Regex.escape(it.label) } + Regex("@($appLabelsPattern)\\b", RegexOption.IGNORE_CASE) + } else { + null + } + + override fun filter(text: AnnotatedString): TransformedText { + val rawText = text.text + val currentRegex = regex + if (currentRegex == null || !rawText.contains("@")) { + return TransformedText(text, OffsetMapping.Identity) + } + + val matches = currentRegex.findAll(rawText) + + val annotatedString = + buildAnnotatedString { + var lastIndex = 0 + matches.forEach { match -> + append(rawText.substring(lastIndex, match.range.first)) + pushStringAnnotation(tag = "mention", annotation = match.value) + withStyle( + SpanStyle( + color = chipTextColor, + fontWeight = FontWeight.Bold + ) + ) { + append(match.value) + } + pop() + lastIndex = match.range.last + 1 + } + if (lastIndex < rawText.length) { + append(rawText.substring(lastIndex)) + } + } + return TransformedText(annotatedString, OffsetMapping.Identity) + } +} + +fun formatMessageText(text: String, installedApps: List): AnnotatedString { + if (installedApps.isEmpty() || !text.contains("@")) { + return AnnotatedString(text) + } + val appLabelsPattern = installedApps.joinToString("|") { Regex.escape(it.label) } + val regex = Regex("@($appLabelsPattern)\\b", RegexOption.IGNORE_CASE) + val matches = regex.findAll(text) + + return buildAnnotatedString { + var lastIndex = 0 + matches.forEachIndexed { index, match -> + val precedingText = text.substring(lastIndex, match.range.first) + if (precedingText.isNotEmpty()) { + append(precedingText) + } + appendInlineContent(id = "chip_$index", alternateText = match.value) + lastIndex = match.range.last + 1 + } + if (lastIndex < text.length) { + append(text.substring(lastIndex)) + } + } +} diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModel.kt b/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModel.kt index f20e972..0e7a34a 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModel.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModel.kt @@ -27,6 +27,9 @@ import com.example.appfunctions.agent.data.db.entities.MessageRole import com.example.appfunctions.agent.data.db.entities.ThreadEntity import com.example.appfunctions.agent.domain.AgentOrchestrator import com.example.appfunctions.agent.domain.AgentStatus +import com.example.appfunctions.agent.domain.appfunction.AppInfo +import com.example.appfunctions.agent.domain.appfunction.GetAppFunctionsUseCase +import com.example.appfunctions.agent.domain.appfunction.GetInstalledAppsUseCase import com.example.appfunctions.agent.domain.chat.GetChatHistoryUseCase import com.example.appfunctions.agent.domain.chat.ManageThreadsUseCase import com.example.appfunctions.agent.domain.chat.SendMessageUseCase @@ -34,6 +37,8 @@ import com.example.appfunctions.agent.domain.pendingintent.ConsumePendingIntentU import com.example.appfunctions.agent.domain.pendingintent.LaunchPendingIntentUseCase import com.example.appfunctions.agent.domain.pendingintent.ObserveActivePendingIntentsUseCase import dagger.hilt.android.lifecycle.HiltViewModel +import javax.inject.Inject +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow @@ -42,158 +47,178 @@ import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch -import javax.inject.Inject +import kotlinx.coroutines.withContext @HiltViewModel class AgentDemoViewModel - @Inject - constructor( - private val savedStateHandle: SavedStateHandle, - private val getChatHistoryUseCase: GetChatHistoryUseCase, - private val manageThreadsUseCase: ManageThreadsUseCase, - private val sendMessageUseCase: SendMessageUseCase, - private val agentOrchestrator: AgentOrchestrator, - private val settingsRepository: SettingsRepository, - private val observeActivePendingIntentsUseCase: ObserveActivePendingIntentsUseCase, - private val launchPendingIntentUseCase: LaunchPendingIntentUseCase, - private val consumePendingIntentUseCase: ConsumePendingIntentUseCase, - ) : ViewModel() { - private val _uiState = MutableStateFlow(AgentUiState.Loading) - val uiState: StateFlow = _uiState.asStateFlow() +@Inject +constructor( + private val savedStateHandle: SavedStateHandle, + private val getChatHistoryUseCase: GetChatHistoryUseCase, + private val manageThreadsUseCase: ManageThreadsUseCase, + private val sendMessageUseCase: SendMessageUseCase, + private val agentOrchestrator: AgentOrchestrator, + private val settingsRepository: SettingsRepository, + private val observeActivePendingIntentsUseCase: ObserveActivePendingIntentsUseCase, + private val launchPendingIntentUseCase: LaunchPendingIntentUseCase, + private val consumePendingIntentUseCase: ConsumePendingIntentUseCase, + private val getInstalledAppsUseCase: GetInstalledAppsUseCase, + private val getAppFunctionsUseCase: GetAppFunctionsUseCase +) : ViewModel() { + private val _uiState = MutableStateFlow(AgentUiState.Loading) + val uiState: StateFlow = _uiState.asStateFlow() - private var observationJob: Job? = null + private val _installedApps = MutableStateFlow>(emptyList()) - init { - viewModelScope.launch { - observeActivePendingIntentsUseCase().collect { activePendingActionIds -> - val currentState = _uiState.value - if (currentState is AgentUiState.Loaded) { - _uiState.value = - currentState.copy(activePendingActionIds = activePendingActionIds) - } + private var observationJob: Job? = null + + init { + viewModelScope.launch { + getAppFunctionsUseCase().collect { toolsMap -> + val allTools = toolsMap.values.flatten() + val packagesWithTools = allTools.map { it.packageName }.toSet() + val filteredApps = withContext(Dispatchers.IO) { + getInstalledAppsUseCase().filter { it.packageName in packagesWithTools } } + _installedApps.value = filteredApps } + } - viewModelScope.launch { - val threads = manageThreadsUseCase.getThreads().first() - if (threads.isEmpty()) { - createAndSelectThread(LlmModel.DEFAULT) + viewModelScope.launch { + observeActivePendingIntentsUseCase().collect { activePendingActionIds -> + val currentState = _uiState.value + if (currentState is AgentUiState.Loaded) { + _uiState.value = + currentState.copy(activePendingActionIds = activePendingActionIds) } } + } - viewModelScope.launch { - combine( - manageThreadsUseCase.getThreads(), - settingsRepository.selectedProvider, - agentOrchestrator.status, - savedStateHandle.getStateFlow(MainActivity.ARG_THREAD_ID, null), - ) { - threads, - provider, - status, - targetThreadId, - -> - ThreadConfig(threads, provider, status, targetThreadId) - } - .collectLatest { (threads, provider, status, targetThreadId) -> - val currentThread = - threads.find { it.threadId == targetThreadId } ?: threads.firstOrNull() + viewModelScope.launch { + val threads = manageThreadsUseCase.getThreads().first() + if (threads.isEmpty()) { + createAndSelectThread(LlmModel.DEFAULT) + } + } + + viewModelScope.launch { + combine( + manageThreadsUseCase.getThreads(), + settingsRepository.selectedProvider, + agentOrchestrator.status, + savedStateHandle.getStateFlow(MainActivity.ARG_THREAD_ID, null), + _installedApps + ) { + threads, + provider, + status, + targetThreadId, + apps + -> + ThreadConfig(threads, provider, status, targetThreadId, apps) + } + .collectLatest { (threads, provider, status, targetThreadId, apps) -> + val currentThread = + threads.find { it.threadId == targetThreadId } ?: threads.firstOrNull() - val previousThreadId = - (_uiState.value as? AgentUiState.Loaded)?.currentThread?.threadId + val previousThreadId = + (_uiState.value as? AgentUiState.Loaded)?.currentThread?.threadId - if (currentThread == null) { - observationJob?.cancel() - observationJob = null - _uiState.value = AgentUiState.Loading - } else { - val currentLoadedState = _uiState.value as? AgentUiState.Loaded - _uiState.value = - AgentUiState.Loaded( - currentThread = currentThread, - messages = currentLoadedState?.messages ?: emptyList(), - status = status, - threads = threads, - activePendingActionIds = - currentLoadedState?.activePendingActionIds ?: emptySet(), - ) + if (currentThread == null) { + observationJob?.cancel() + observationJob = null + _uiState.value = AgentUiState.Loading + } else { + val currentLoadedState = _uiState.value as? AgentUiState.Loaded + _uiState.value = + AgentUiState.Loaded( + currentThread = currentThread, + messages = currentLoadedState?.messages ?: emptyList(), + status = status, + threads = threads, + activePendingActionIds = + currentLoadedState?.activePendingActionIds ?: emptySet(), + installedApps = apps + ) - // Start observing messages for the current thread if not already doing so - if (observationJob == null || previousThreadId != currentThread.threadId) { - observationJob?.cancel() - observationJob = - viewModelScope.launch { - launch { - getChatHistoryUseCase(currentThread.threadId).collect { - messages -> - val currentState = _uiState.value - if (currentState is AgentUiState.Loaded) { - _uiState.value = - currentState.copy(messages = messages) - } + // Start observing messages for the current thread if not already doing so + if (observationJob == null || previousThreadId != currentThread.threadId) { + observationJob?.cancel() + observationJob = + viewModelScope.launch { + launch { + getChatHistoryUseCase(currentThread.threadId).collect { + messages -> + val currentState = _uiState.value + if (currentState is AgentUiState.Loaded) { + _uiState.value = + currentState.copy(messages = messages) } } - launch { - agentOrchestrator.observeAndProcessMessages( - currentThread.threadId, - ) - } } - } + launch { + agentOrchestrator.observeAndProcessMessages( + currentThread.threadId + ) + } + } } } - } + } } + } - fun onEvent(event: AgentUiEvent) { - val currentState = _uiState.value - when (event) { - is AgentUiEvent.OnSendMessage -> { - if (currentState is AgentUiState.Loaded) { - viewModelScope.launch { - sendMessageUseCase( - threadId = currentState.currentThread.threadId, - role = MessageRole.USER, - textContent = event.text, - processingStatus = MessageProcessingStatus.PENDING_AGENT_RESPONSE, - ) - } + fun onEvent(event: AgentUiEvent) { + val currentState = _uiState.value + when (event) { + is AgentUiEvent.OnSendMessage -> { + if (currentState is AgentUiState.Loaded) { + viewModelScope.launch { + sendMessageUseCase( + threadId = currentState.currentThread.threadId, + role = MessageRole.USER, + textContent = event.text, + processingStatus = MessageProcessingStatus.PENDING_AGENT_RESPONSE, + targetPackageName = event.targetPackageName + ) } } - is AgentUiEvent.OnModelSelected -> { - if (currentState is AgentUiState.Loaded) { - viewModelScope.launch { - manageThreadsUseCase.updateThreadModel( - currentState.currentThread.threadId, - event.model, - ) - } + } + is AgentUiEvent.OnModelSelected -> { + if (currentState is AgentUiState.Loaded) { + viewModelScope.launch { + manageThreadsUseCase.updateThreadModel( + currentState.currentThread.threadId, + event.model + ) } } - is AgentUiEvent.OnCreateThread -> { - viewModelScope.launch { createAndSelectThread(event.model) } - } - is AgentUiEvent.OnThreadSelected -> { - savedStateHandle[MainActivity.ARG_THREAD_ID] = event.threadId - } - is AgentUiEvent.OnConfirmAction -> { - val pendingIntent = consumePendingIntentUseCase(event.pendingIntentId) - if (pendingIntent != null) { - launchPendingIntentUseCase(pendingIntent) - } + } + is AgentUiEvent.OnCreateThread -> { + viewModelScope.launch { createAndSelectThread(event.model) } + } + is AgentUiEvent.OnThreadSelected -> { + savedStateHandle[MainActivity.ARG_THREAD_ID] = event.threadId + } + is AgentUiEvent.OnConfirmAction -> { + val pendingIntent = consumePendingIntentUseCase(event.pendingIntentId) + if (pendingIntent != null) { + launchPendingIntentUseCase(pendingIntent) } } } + } - private suspend fun createAndSelectThread(llmModel: LlmModel) { - val threadId = manageThreadsUseCase.createThread(llmModel) - savedStateHandle[MainActivity.ARG_THREAD_ID] = threadId - } + private suspend fun createAndSelectThread(llmModel: LlmModel) { + val threadId = manageThreadsUseCase.createThread(llmModel) + savedStateHandle[MainActivity.ARG_THREAD_ID] = threadId } +} private data class ThreadConfig( val threads: List, val provider: LlmProviderName, val status: AgentStatus, val targetThreadId: String?, + val installedApps: List ) diff --git a/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentUiState.kt b/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentUiState.kt index 0be61ea..4798f79 100644 --- a/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentUiState.kt +++ b/agent/app/src/main/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentUiState.kt @@ -19,6 +19,7 @@ import com.example.appfunctions.agent.data.LlmModel import com.example.appfunctions.agent.data.db.entities.MessageEntity import com.example.appfunctions.agent.data.db.entities.ThreadEntity import com.example.appfunctions.agent.domain.AgentStatus +import com.example.appfunctions.agent.domain.appfunction.AppInfo /** Represents the UI state for the Agent Demo screen. */ sealed class AgentUiState { @@ -30,12 +31,16 @@ sealed class AgentUiState { val status: AgentStatus = AgentStatus.Idle, val threads: List = emptyList(), val activePendingActionIds: Set = emptySet(), + val installedApps: List = emptyList(), ) : AgentUiState() } /** Represents UI events for the Agent Demo screen. */ sealed class AgentUiEvent { - data class OnSendMessage(val text: String) : AgentUiEvent() + data class OnSendMessage( + val text: String, + val targetPackageName: String? = null, + ) : AgentUiEvent() data class OnModelSelected(val model: LlmModel) : AgentUiEvent() diff --git a/agent/app/src/test/java/com/example/appfunctions/agent/domain/AgentOrchestratorTest.kt b/agent/app/src/test/java/com/example/appfunctions/agent/domain/AgentOrchestratorTest.kt index 724192a..592d28d 100644 --- a/agent/app/src/test/java/com/example/appfunctions/agent/domain/AgentOrchestratorTest.kt +++ b/agent/app/src/test/java/com/example/appfunctions/agent/domain/AgentOrchestratorTest.kt @@ -15,6 +15,8 @@ */ package com.example.appfunctions.agent.domain +import androidx.appfunctions.metadata.AppFunctionMetadata +import androidx.appfunctions.metadata.AppFunctionPackageMetadata import com.example.appfunctions.agent.data.LlmModel import com.example.appfunctions.agent.data.LlmProviderName import com.example.appfunctions.agent.data.SettingsRepository @@ -35,6 +37,7 @@ import com.example.appfunctions.agent.domain.chat.UpdateThreadUseCase import com.example.appfunctions.agent.domain.pendingintent.SavePendingIntentUseCase import io.mockk.coEvery import io.mockk.coVerify +import io.mockk.every import io.mockk.mockk import kotlinx.coroutines.delay import kotlinx.coroutines.flow.flow @@ -84,30 +87,10 @@ class AgentOrchestratorTest { fun `observeAndProcessMessages fails when API key is missing`() = runTest { val threadId = "thread_1" - val message = - MessageEntity( - messageId = "msg_1", - threadId = threadId, - role = MessageRole.USER, - textContent = "Hello", - timestamp = System.currentTimeMillis(), - processingStatus = MessageProcessingStatus.PENDING_AGENT_RESPONSE, - ) - val thread = - ThreadEntity( - threadId = threadId, - createdAt = System.currentTimeMillis(), - llmModel = LlmModel.GEMINI_3_FLASH_PREVIEW, - latestInteractionId = null, - ) - coEvery { observePendingMessagesUseCase(threadId) } returns - flow { - delay(10) - emit(message) - } - coEvery { manageThreadsUseCase.getThread(threadId) } returns flowOf(thread) - coEvery { settingsRepository.geminiApiKey } returns flowOf(null) - coEvery { settingsRepository.disconnectedApps } returns flowOf(emptySet()) + val message = createUserMessage(threadId, "Hello", messageId = "msg_1") + val thread = createThread(threadId) + + setupDefaultMocks(threadId, message, thread, apiKey = null) agentOrchestrator.observeAndProcessMessages(threadId) @@ -133,33 +116,11 @@ class AgentOrchestratorTest { fun `observeAndProcessMessages fails when LLM returns error`() = runTest { val threadId = "thread_1" - val message = - MessageEntity( - messageId = "msg_1", - threadId = threadId, - role = MessageRole.USER, - textContent = "Hello", - timestamp = System.currentTimeMillis(), - processingStatus = MessageProcessingStatus.PENDING_AGENT_RESPONSE, - ) - val thread = - ThreadEntity( - threadId = threadId, - createdAt = System.currentTimeMillis(), - llmModel = LlmModel.GEMINI_3_FLASH_PREVIEW, - latestInteractionId = null, - ) + val message = createUserMessage(threadId, "Hello", messageId = "msg_1") + val thread = createThread(threadId) val llmProvider = mockk() - coEvery { observePendingMessagesUseCase(threadId) } returns - flow { - delay(10) - emit(message) - } - coEvery { manageThreadsUseCase.getThread(threadId) } returns flowOf(thread) - coEvery { settingsRepository.geminiApiKey } returns flowOf("dummy_key") - coEvery { settingsRepository.disconnectedApps } returns flowOf(emptySet()) - coEvery { llmProviderFactory.getProvider(LlmProviderName.GEMINI) } returns llmProvider + setupDefaultMocks(threadId, message, thread, llmProvider = llmProvider) coEvery { getAppFunctionsUseCase() } returns flowOf(emptyMap()) val errorMsg = "LLM failed" @@ -190,38 +151,19 @@ class AgentOrchestratorTest { fun `observeAndProcessMessages succeeds when LLM returns text`() = runTest { val threadId = "thread_1" - val message = - MessageEntity( - messageId = "msg_1", - threadId = threadId, - role = MessageRole.USER, - textContent = "Hello", - timestamp = System.currentTimeMillis(), - processingStatus = MessageProcessingStatus.PENDING_AGENT_RESPONSE, - ) - val thread = - ThreadEntity( - threadId = threadId, - createdAt = System.currentTimeMillis(), - llmModel = LlmModel.GEMINI_3_FLASH_PREVIEW, - latestInteractionId = null, - ) + val message = createUserMessage(threadId, "Hello", messageId = "msg_1") + val thread = createThread(threadId) val llmProvider = mockk() - coEvery { observePendingMessagesUseCase(threadId) } returns - flow { - delay(10) - emit(message) - } - coEvery { manageThreadsUseCase.getThread(threadId) } returns flowOf(thread) - coEvery { settingsRepository.geminiApiKey } returns flowOf("dummy_key") - coEvery { settingsRepository.disconnectedApps } returns flowOf(emptySet()) - coEvery { llmProviderFactory.getProvider(LlmProviderName.GEMINI) } returns llmProvider + setupDefaultMocks(threadId, message, thread, llmProvider = llmProvider) coEvery { getAppFunctionsUseCase() } returns flowOf(emptyMap()) val responseText = "Hi there" coEvery { llmProvider.generateResponse(any(), any(), any(), any(), any()) } returns - LlmResponse.Success("interaction_123", listOf(LlmResponsePart.Text(responseText))) + LlmResponse.Success( + "interaction_123", + listOf(LlmResponsePart.Text(responseText)), + ) agentOrchestrator.observeAndProcessMessages(threadId) @@ -243,4 +185,133 @@ class AgentOrchestratorTest { ) } } + + @Test + fun `observeAndProcessMessages scopes tools when targetPackageName is set`() = + runTest { + val threadId = "thread_1" + val message = + createUserMessage( + threadId = threadId, + textContent = "run geo code address for n1c4ag", + targetPackageName = "com.google.android.appfunctiontestingagent", + ) + val thread = createThread(threadId) + val llmProvider = mockk() + + val tool1 = createMockTool("com.google.android.appfunctiontestingagent", "run_geo_code") + val tool2 = createMockTool("com.google.android.digitalwellbeing", "digital_well_being_tool") + mockAppFunctions(listOf(tool1, tool2)) + + setupDefaultMocks(threadId, message, thread, llmProvider = llmProvider) + + coEvery { + llmProvider.generateResponse(any(), any(), any(), any(), any()) + } returns LlmResponse.Success("interaction_id", listOf(LlmResponsePart.Text("Success"))) + + agentOrchestrator.observeAndProcessMessages(threadId) + + coVerify { + llmProvider.generateResponse( + previousInteractionId = null, + input = eq(LlmInput.UserMessage("run geo code address for n1c4ag")), + tools = listOf(tool1), + apiKey = "dummy_key", + modelName = any(), + ) + } + } + + @Test + fun `observeAndProcessMessages does not scope tools when targetPackageName is null`() = + runTest { + val threadId = "thread_1" + val message = createUserMessage(threadId, "run geo code address for n1c4ag") + val thread = createThread(threadId) + val llmProvider = mockk() + + val tool1 = createMockTool("com.google.android.appfunctiontestingagent", "run_geo_code") + val tool2 = createMockTool("com.google.android.digitalwellbeing", "digital_well_being_tool") + mockAppFunctions(listOf(tool1, tool2)) + + setupDefaultMocks(threadId, message, thread, llmProvider = llmProvider) + + coEvery { + llmProvider.generateResponse(any(), any(), any(), any(), any()) + } returns LlmResponse.Success("interaction_id", listOf(LlmResponsePart.Text("Success"))) + + agentOrchestrator.observeAndProcessMessages(threadId) + + coVerify { + llmProvider.generateResponse( + previousInteractionId = null, + input = eq(LlmInput.UserMessage("run geo code address for n1c4ag")), + tools = listOf(tool1, tool2), + apiKey = "dummy_key", + modelName = any(), + ) + } + } + + private fun createUserMessage( + threadId: String, + textContent: String, + messageId: String = "message_1", + targetPackageName: String? = null, + ) = MessageEntity( + messageId = messageId, + threadId = threadId, + role = MessageRole.USER, + textContent = textContent, + timestamp = System.currentTimeMillis(), + processingStatus = MessageProcessingStatus.PENDING_AGENT_RESPONSE, + targetPackageName = targetPackageName, + ) + + private fun createThread( + threadId: String, + llmModel: LlmModel = LlmModel.GEMINI_3_FLASH_PREVIEW, + latestInteractionId: String? = null, + ) = ThreadEntity( + threadId = threadId, + createdAt = System.currentTimeMillis(), + llmModel = llmModel, + latestInteractionId = latestInteractionId, + ) + + private fun createMockTool( + packageName: String, + id: String, + isEnabled: Boolean = true, + ): AppFunctionMetadata { + val tool = mockk() + every { tool.packageName } returns packageName + every { tool.id } returns id + every { tool.isEnabled } returns isEnabled + return tool + } + + private fun mockAppFunctions(tools: List) { + val packageMetadata = mockk(relaxed = true) + coEvery { getAppFunctionsUseCase() } returns flowOf(mapOf(packageMetadata to tools)) + } + + private fun setupDefaultMocks( + threadId: String, + message: MessageEntity, + thread: ThreadEntity, + apiKey: String? = "dummy_key", + disconnectedApps: Set = emptySet(), + llmProvider: LlmProvider = mockk(), + ) { + coEvery { observePendingMessagesUseCase(threadId) } returns + flow { + delay(10) + emit(message) + } + coEvery { manageThreadsUseCase.getThread(threadId) } returns flowOf(thread) + coEvery { settingsRepository.geminiApiKey } returns flowOf(apiKey) + coEvery { settingsRepository.disconnectedApps } returns flowOf(disconnectedApps) + coEvery { llmProviderFactory.getProvider(LlmProviderName.GEMINI) } returns llmProvider + } } diff --git a/agent/app/src/test/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModelTest.kt b/agent/app/src/test/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModelTest.kt index 6834f56..27a4805 100644 --- a/agent/app/src/test/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModelTest.kt +++ b/agent/app/src/test/java/com/example/appfunctions/agent/ui/screens/agentdemo/AgentDemoViewModelTest.kt @@ -25,6 +25,8 @@ import com.example.appfunctions.agent.data.db.entities.MessageRole import com.example.appfunctions.agent.data.db.entities.ThreadEntity import com.example.appfunctions.agent.domain.AgentOrchestrator import com.example.appfunctions.agent.domain.AgentStatus +import com.example.appfunctions.agent.domain.appfunction.GetAppFunctionsUseCase +import com.example.appfunctions.agent.domain.appfunction.GetInstalledAppsUseCase import com.example.appfunctions.agent.domain.chat.GetChatHistoryUseCase import com.example.appfunctions.agent.domain.chat.ManageThreadsUseCase import com.example.appfunctions.agent.domain.chat.SendMessageUseCase @@ -38,6 +40,7 @@ import io.mockk.mockk import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.test.UnconfinedTestDispatcher import kotlinx.coroutines.test.resetMain import kotlinx.coroutines.test.runTest @@ -60,6 +63,8 @@ class AgentDemoViewModelTest { private lateinit var observeActivePendingIntentsUseCase: ObserveActivePendingIntentsUseCase private lateinit var launchPendingIntentUseCase: LaunchPendingIntentUseCase private lateinit var consumePendingIntentUseCase: ConsumePendingIntentUseCase + private lateinit var getInstalledAppsUseCase: GetInstalledAppsUseCase + private lateinit var getAppFunctionsUseCase: GetAppFunctionsUseCase private lateinit var viewModel: AgentDemoViewModel @@ -82,6 +87,11 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase = mockk() launchPendingIntentUseCase = mockk() consumePendingIntentUseCase = mockk(relaxed = true) + getInstalledAppsUseCase = mockk() + getAppFunctionsUseCase = mockk() + + every { getInstalledAppsUseCase() } returns emptyList() + every { getAppFunctionsUseCase() } returns flowOf(emptyMap()) every { manageThreadsUseCase.getThreads() } returns threadsFlow every { settingsRepository.selectedProvider } returns selectedProviderFlow @@ -124,6 +134,8 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase, launchPendingIntentUseCase, consumePendingIntentUseCase, + getInstalledAppsUseCase, + getAppFunctionsUseCase, ) viewModel.onEvent(AgentUiEvent.OnModelSelected(LlmModel.GEMINI_3_1_PRO_PREVIEW)) @@ -150,6 +162,8 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase, launchPendingIntentUseCase, consumePendingIntentUseCase, + getInstalledAppsUseCase, + getAppFunctionsUseCase, ) coVerify { manageThreadsUseCase.createThread(LlmModel.DEFAULT) } @@ -179,10 +193,15 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase, launchPendingIntentUseCase, consumePendingIntentUseCase, + getInstalledAppsUseCase, + getAppFunctionsUseCase, ) coVerify(exactly = 0) { manageThreadsUseCase.createThread(any()) } - assertEquals(existingThread, (viewModel.uiState.value as AgentUiState.Loaded).currentThread) + assertEquals( + existingThread, + (viewModel.uiState.value as AgentUiState.Loaded).currentThread, + ) } @Test @@ -209,6 +228,8 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase, launchPendingIntentUseCase, consumePendingIntentUseCase, + getInstalledAppsUseCase, + getAppFunctionsUseCase, ) coEvery { sendMessageUseCase(any(), any(), any(), any()) } returns Unit @@ -249,6 +270,8 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase, launchPendingIntentUseCase, consumePendingIntentUseCase, + getInstalledAppsUseCase, + getAppFunctionsUseCase, ) val message = @@ -296,6 +319,8 @@ class AgentDemoViewModelTest { observeActivePendingIntentsUseCase, launchPendingIntentUseCase, consumePendingIntentUseCase, + getInstalledAppsUseCase, + getAppFunctionsUseCase, ) assertEquals(newThread, (viewModel.uiState.value as AgentUiState.Loaded).currentThread)