|
43 | 43 | from pydantic_ai.models import Model |
44 | 44 | from pydantic_ai.models.bedrock import BedrockConverseModel |
45 | 45 | from pydantic_ai.models.google import GoogleModel |
46 | | - from pydantic_ai.models.openai import OpenAIResponsesModel |
| 46 | + from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel |
47 | 47 | from pydantic_ai.providers import Provider |
48 | 48 | from pydantic_ai.providers.anthropic import ( |
49 | 49 | AnthropicProvider as PydanticAnthropic, |
@@ -504,8 +504,8 @@ class CustomProvider(OpenAIClientMixin, PydanticProvider["Provider[Any]"]): |
504 | 504 |
|
505 | 505 | Note: |
506 | 506 | We need to use the specific provider and model classes, because Pydantic AI has tuned them to send & return messages correctly. |
507 | | - We can also use `Agent("provider:model_name")` to avoid finding the provider and model classes ourselves. However, this does not let |
508 | | - us create custom providers and models. They rely on env vars to be set. |
| 507 | + We can also use `Agent("provider:model_name")` to avoid finding the provider ourselves. However, this does not let |
| 508 | + us create custom providers. They rely on env vars to be set. |
509 | 509 | """ |
510 | 510 |
|
511 | 511 | def __init__( |
@@ -597,6 +597,7 @@ def _create_custom_provider( |
597 | 597 | """ |
598 | 598 |
|
599 | 599 | provider_name = provider_class.__name__ |
| 600 | + LOGGER.debug(f"Creating custom provider: {provider_name}") |
600 | 601 |
|
601 | 602 | if provider_name == "CerebrasProvider": |
602 | 603 | from pydantic_ai.providers.cerebras import CerebrasProvider |
@@ -643,133 +644,55 @@ def _create_custom_provider( |
643 | 644 | client = self.get_openai_client(config) |
644 | 645 | return PydanticOpenAI(openai_client=client) |
645 | 646 |
|
646 | | - def create_model(self, max_tokens: int) -> Model: |
647 | | - """Create a model based on provider compatibility. |
648 | | -
|
649 | | - - OpenAIResponsesCompatibleProvider -> OpenAIResponsesModel |
650 | | - - OpenAIChatCompatibleProvider -> OpenAIChatModel |
651 | | - - Other providers -> infer_model to get appropriate model class |
652 | | - """ |
653 | | - # Prefer using the Responses API if available for perf and features |
654 | | - # TODO: Does not work at the moment, so we just use the OpenAIChatModel instead |
655 | | - # if self._supports_responses_api(): |
656 | | - # from pydantic_ai.models.openai import ( |
657 | | - # OpenAIResponsesModel, |
658 | | - # OpenAIResponsesModelSettings, |
659 | | - # ) |
660 | | - |
661 | | - # LOGGER.debug( |
662 | | - # f"Using OpenAIResponsesModel for {self._provider_name}" |
663 | | - # ) |
664 | | - # return OpenAIResponsesModel( |
665 | | - # model_name=self.model, |
666 | | - # provider=self.provider, |
667 | | - # settings=OpenAIResponsesModelSettings(max_tokens=max_tokens), |
668 | | - # ) |
669 | | - |
670 | | - if self._is_openai_compatible(): |
671 | | - from pydantic_ai.models.openai import ( |
672 | | - OpenAIChatModel, |
673 | | - OpenAIChatModelSettings, |
674 | | - ) |
| 647 | + def create_model(self, max_tokens: int) -> OpenAIChatModel: |
| 648 | + """Default to OpenAIChatModel""" |
675 | 649 |
|
676 | | - LOGGER.debug(f"Using OpenAIChatModel for {self._provider_name}") |
677 | | - return OpenAIChatModel( |
678 | | - model_name=self.model, |
679 | | - provider=self.provider, |
680 | | - settings=OpenAIChatModelSettings(max_tokens=max_tokens), |
681 | | - ) |
682 | | - |
683 | | - return self._create_custom_model(max_tokens) |
| 650 | + from pydantic_ai.models.openai import ( |
| 651 | + OpenAIChatModel, |
| 652 | + OpenAIChatModelSettings, |
| 653 | + ) |
684 | 654 |
|
685 | | - def _create_custom_model(self, max_tokens: int) -> Model: |
686 | | - """Create a custom model based on the provider class. These providers are not OpenAI-compatible.""" |
| 655 | + return OpenAIChatModel( |
| 656 | + model_name=self.model, |
| 657 | + provider=self.provider, |
| 658 | + settings=OpenAIChatModelSettings(max_tokens=max_tokens), |
| 659 | + ) |
687 | 660 |
|
| 661 | + def create_agent( |
| 662 | + self, |
| 663 | + max_tokens: int, |
| 664 | + tools: list[ToolDefinition], |
| 665 | + system_prompt: str, |
| 666 | + ) -> Agent[None, DeferredToolRequests | str]: |
| 667 | + """Create a Pydantic AI agent""" |
| 668 | + from pydantic_ai import Agent, UserError |
688 | 669 | from pydantic_ai.models import infer_model |
689 | 670 | from pydantic_ai.settings import ModelSettings |
690 | 671 |
|
691 | | - model_string = f"{self._provider_name}:{self.model}" |
692 | 672 | try: |
693 | | - # Don't return inferred model at first, because we want to use our own provider |
694 | | - inferred = infer_model(model_string) |
695 | | - model_class = type(inferred) |
696 | | - LOGGER.debug(f"Inferred model class: {model_class.__name__}") |
697 | | - except Exception as e: |
698 | | - from pydantic_ai.models.openai import ( |
699 | | - OpenAIChatModel, |
700 | | - OpenAIChatModelSettings, |
| 673 | + model = infer_model( |
| 674 | + f"{self._provider_name}:{self.model}", |
| 675 | + provider_factory=lambda _: self.provider, |
701 | 676 | ) |
702 | | - |
| 677 | + except UserError: |
703 | 678 | LOGGER.warning( |
704 | | - f"Could not infer model for {model_string}: {e}. Using OpenAIChatModel." |
705 | | - ) |
706 | | - return OpenAIChatModel( |
707 | | - model_name=self.model, |
708 | | - provider=self.provider, |
709 | | - settings=OpenAIChatModelSettings(max_tokens=max_tokens), |
| 679 | + f"Model {self.model} not found. Falling back to OpenAIChatModel." |
710 | 680 | ) |
711 | | - |
712 | | - # Import on-demand as top-level imports will require the package to be installed |
713 | | - model_name = model_class.__name__ |
714 | | - |
715 | | - if model_name == "CerebrasModel": |
716 | | - from pydantic_ai.models.cerebras import ( |
717 | | - CerebrasModel, |
718 | | - CerebrasModelSettings, |
719 | | - ) |
720 | | - |
721 | | - return CerebrasModel( |
722 | | - model_name=self.model, |
723 | | - provider=self.provider, |
724 | | - settings=CerebrasModelSettings(max_tokens=max_tokens), |
725 | | - ) |
726 | | - if model_name == "CohereModel": |
727 | | - from pydantic_ai.models.cohere import ( |
728 | | - CohereModel, |
729 | | - CohereModelSettings, |
730 | | - ) |
731 | | - |
732 | | - return CohereModel( |
733 | | - model_name=self.model, |
734 | | - provider=self.provider, |
735 | | - settings=CohereModelSettings(max_tokens=max_tokens), |
736 | | - ) |
737 | | - if model_name == "GroqModel": |
738 | | - from pydantic_ai.models.groq import GroqModel, GroqModelSettings |
739 | | - |
740 | | - return GroqModel( |
741 | | - model_name=self.model, |
742 | | - provider=self.provider, |
743 | | - settings=GroqModelSettings(max_tokens=max_tokens), |
744 | | - ) |
745 | | - if model_name == "HuggingFaceModel": |
746 | | - from pydantic_ai.models.huggingface import ( |
747 | | - HuggingFaceModel, |
748 | | - HuggingFaceModelSettings, |
749 | | - ) |
750 | | - |
751 | | - return HuggingFaceModel( |
752 | | - model_name=self.model, |
753 | | - provider=self.provider, |
754 | | - settings=HuggingFaceModelSettings(max_tokens=max_tokens), |
755 | | - ) |
756 | | - if model_name == "MistralModel": |
757 | | - from pydantic_ai.models.mistral import ( |
758 | | - MistralModel, |
759 | | - MistralModelSettings, |
760 | | - ) |
761 | | - |
762 | | - return MistralModel( |
763 | | - model_name=self.model, |
764 | | - provider=self.provider, |
765 | | - settings=MistralModelSettings(max_tokens=max_tokens), |
| 681 | + model = self.create_model(max_tokens) |
| 682 | + except Exception as e: |
| 683 | + LOGGER.error( |
| 684 | + f"Error creating model: {e}. Falling back to OpenAIChatModel." |
766 | 685 | ) |
| 686 | + model = self.create_model(max_tokens) |
767 | 687 |
|
768 | | - # Unknown model class - return the inferred model directly |
769 | | - LOGGER.warning( |
770 | | - f"Unknown model: {self._provider_name}:{self.model}. Create a GitHub issue to request support." |
| 688 | + toolset, output_type = self._get_toolsets_and_output_type(tools) |
| 689 | + return Agent( |
| 690 | + model, |
| 691 | + model_settings=ModelSettings(max_tokens=max_tokens), |
| 692 | + toolsets=[toolset] if tools else None, |
| 693 | + instructions=system_prompt, |
| 694 | + output_type=output_type, |
771 | 695 | ) |
772 | | - return model_class(settings=ModelSettings(max_tokens=max_tokens)) |
773 | 696 |
|
774 | 697 |
|
775 | 698 | class AnthropicProvider(PydanticProvider["PydanticAnthropic"]): |
|
0 commit comments