Skip to content

Commit ed34f1b

Browse files
committed
use simpler infer_model
1 parent 5dc8016 commit ed34f1b

File tree

1 file changed

+40
-117
lines changed

1 file changed

+40
-117
lines changed

marimo/_server/ai/providers.py

Lines changed: 40 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from pydantic_ai.models import Model
4444
from pydantic_ai.models.bedrock import BedrockConverseModel
4545
from pydantic_ai.models.google import GoogleModel
46-
from pydantic_ai.models.openai import OpenAIResponsesModel
46+
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel
4747
from pydantic_ai.providers import Provider
4848
from pydantic_ai.providers.anthropic import (
4949
AnthropicProvider as PydanticAnthropic,
@@ -504,8 +504,8 @@ class CustomProvider(OpenAIClientMixin, PydanticProvider["Provider[Any]"]):
504504
505505
Note:
506506
We need to use the specific provider and model classes, because Pydantic AI has tuned them to send & return messages correctly.
507-
We can also use `Agent("provider:model_name")` to avoid finding the provider and model classes ourselves. However, this does not let
508-
us create custom providers and models. They rely on env vars to be set.
507+
We can also use `Agent("provider:model_name")` to avoid finding the provider ourselves. However, this does not let
508+
us create custom providers, because the providers inferred from the string rely on env vars being set.
509509
"""
510510

511511
def __init__(
@@ -597,6 +597,7 @@ def _create_custom_provider(
597597
"""
598598

599599
provider_name = provider_class.__name__
600+
LOGGER.debug(f"Creating custom provider: {provider_name}")
600601

601602
if provider_name == "CerebrasProvider":
602603
from pydantic_ai.providers.cerebras import CerebrasProvider
@@ -643,133 +644,55 @@ def _create_custom_provider(
643644
client = self.get_openai_client(config)
644645
return PydanticOpenAI(openai_client=client)
645646

646-
def create_model(self, max_tokens: int) -> Model:
647-
"""Create a model based on provider compatibility.
648-
649-
- OpenAIResponsesCompatibleProvider -> OpenAIResponsesModel
650-
- OpenAIChatCompatibleProvider -> OpenAIChatModel
651-
- Other providers -> infer_model to get appropriate model class
652-
"""
653-
# Prefer using the Responses API if available for perf and features
654-
# TODO: Does not work at the moment, so we just use the OpenAIChatModel instead
655-
# if self._supports_responses_api():
656-
# from pydantic_ai.models.openai import (
657-
# OpenAIResponsesModel,
658-
# OpenAIResponsesModelSettings,
659-
# )
660-
661-
# LOGGER.debug(
662-
# f"Using OpenAIResponsesModel for {self._provider_name}"
663-
# )
664-
# return OpenAIResponsesModel(
665-
# model_name=self.model,
666-
# provider=self.provider,
667-
# settings=OpenAIResponsesModelSettings(max_tokens=max_tokens),
668-
# )
669-
670-
if self._is_openai_compatible():
671-
from pydantic_ai.models.openai import (
672-
OpenAIChatModel,
673-
OpenAIChatModelSettings,
674-
)
647+
def create_model(self, max_tokens: int) -> OpenAIChatModel:
648+
"""Default to OpenAIChatModel"""
675649

676-
LOGGER.debug(f"Using OpenAIChatModel for {self._provider_name}")
677-
return OpenAIChatModel(
678-
model_name=self.model,
679-
provider=self.provider,
680-
settings=OpenAIChatModelSettings(max_tokens=max_tokens),
681-
)
682-
683-
return self._create_custom_model(max_tokens)
650+
from pydantic_ai.models.openai import (
651+
OpenAIChatModel,
652+
OpenAIChatModelSettings,
653+
)
684654

685-
def _create_custom_model(self, max_tokens: int) -> Model:
686-
"""Create a custom model based on the provider class. These providers are not OpenAI-compatible."""
655+
return OpenAIChatModel(
656+
model_name=self.model,
657+
provider=self.provider,
658+
settings=OpenAIChatModelSettings(max_tokens=max_tokens),
659+
)
687660

661+
def create_agent(
662+
self,
663+
max_tokens: int,
664+
tools: list[ToolDefinition],
665+
system_prompt: str,
666+
) -> Agent[None, DeferredToolRequests | str]:
667+
"""Create a Pydantic AI agent"""
668+
from pydantic_ai import Agent, UserError
688669
from pydantic_ai.models import infer_model
689670
from pydantic_ai.settings import ModelSettings
690671

691-
model_string = f"{self._provider_name}:{self.model}"
692672
try:
693-
# Don't return inferred model at first, because we want to use our own provider
694-
inferred = infer_model(model_string)
695-
model_class = type(inferred)
696-
LOGGER.debug(f"Inferred model class: {model_class.__name__}")
697-
except Exception as e:
698-
from pydantic_ai.models.openai import (
699-
OpenAIChatModel,
700-
OpenAIChatModelSettings,
673+
model = infer_model(
674+
f"{self._provider_name}:{self.model}",
675+
provider_factory=lambda _: self.provider,
701676
)
702-
677+
except UserError:
703678
LOGGER.warning(
704-
f"Could not infer model for {model_string}: {e}. Using OpenAIChatModel."
705-
)
706-
return OpenAIChatModel(
707-
model_name=self.model,
708-
provider=self.provider,
709-
settings=OpenAIChatModelSettings(max_tokens=max_tokens),
679+
f"Model {self.model} not found. Falling back to OpenAIChatModel."
710680
)
711-
712-
# Import on-demand as top-level imports will require the package to be installed
713-
model_name = model_class.__name__
714-
715-
if model_name == "CerebrasModel":
716-
from pydantic_ai.models.cerebras import (
717-
CerebrasModel,
718-
CerebrasModelSettings,
719-
)
720-
721-
return CerebrasModel(
722-
model_name=self.model,
723-
provider=self.provider,
724-
settings=CerebrasModelSettings(max_tokens=max_tokens),
725-
)
726-
if model_name == "CohereModel":
727-
from pydantic_ai.models.cohere import (
728-
CohereModel,
729-
CohereModelSettings,
730-
)
731-
732-
return CohereModel(
733-
model_name=self.model,
734-
provider=self.provider,
735-
settings=CohereModelSettings(max_tokens=max_tokens),
736-
)
737-
if model_name == "GroqModel":
738-
from pydantic_ai.models.groq import GroqModel, GroqModelSettings
739-
740-
return GroqModel(
741-
model_name=self.model,
742-
provider=self.provider,
743-
settings=GroqModelSettings(max_tokens=max_tokens),
744-
)
745-
if model_name == "HuggingFaceModel":
746-
from pydantic_ai.models.huggingface import (
747-
HuggingFaceModel,
748-
HuggingFaceModelSettings,
749-
)
750-
751-
return HuggingFaceModel(
752-
model_name=self.model,
753-
provider=self.provider,
754-
settings=HuggingFaceModelSettings(max_tokens=max_tokens),
755-
)
756-
if model_name == "MistralModel":
757-
from pydantic_ai.models.mistral import (
758-
MistralModel,
759-
MistralModelSettings,
760-
)
761-
762-
return MistralModel(
763-
model_name=self.model,
764-
provider=self.provider,
765-
settings=MistralModelSettings(max_tokens=max_tokens),
681+
model = self.create_model(max_tokens)
682+
except Exception as e:
683+
LOGGER.error(
684+
f"Error creating model: {e}. Falling back to OpenAIChatModel."
766685
)
686+
model = self.create_model(max_tokens)
767687

768-
# Unknown model class - return the inferred model directly
769-
LOGGER.warning(
770-
f"Unknown model: {self._provider_name}:{self.model}. Create a GitHub issue to request support."
688+
toolset, output_type = self._get_toolsets_and_output_type(tools)
689+
return Agent(
690+
model,
691+
model_settings=ModelSettings(max_tokens=max_tokens),
692+
toolsets=[toolset] if tools else None,
693+
instructions=system_prompt,
694+
output_type=output_type,
771695
)
772-
return model_class(settings=ModelSettings(max_tokens=max_tokens))
773696

774697

775698
class AnthropicProvider(PydanticProvider["PydanticAnthropic"]):

0 commit comments

Comments
 (0)