"""IBM watsonx.ai embeddings wrapper."""
import logging
from typing import Any, cast
from ibm_watsonx_ai import APIClient # type: ignore[import-untyped]
from ibm_watsonx_ai.foundation_models.embeddings import ( # type: ignore[import-untyped]
Embeddings,
)
from ibm_watsonx_ai.gateway import Gateway # type: ignore[import-untyped]
from langchain_core.embeddings import Embeddings as LangChainEmbeddings
from langchain_core.utils.utils import secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_ibm.utils import (
async_gateway_error_handler,
extract_params,
gateway_error_handler,
resolve_watsonx_credentials,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
# [docs]  (Sphinx "view source" link residue from HTML extraction; not part of the module)
class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
    """`IBM watsonx.ai` embedding model integration.

    ???+ info "Setup"

        To use, you should have `langchain_ibm` python package installed,
        and the environment variable `WATSONX_APIKEY` set with your API key, or pass
        it as a named parameter `apikey` to the constructor.

        ```bash
        pip install -U langchain-ibm

        # or using uv
        uv add langchain-ibm
        ```

        ```bash
        export WATSONX_APIKEY="your-api-key"
        ```

    ??? info "Instantiate"

        ```python
        from langchain_ibm import WatsonxEmbeddings

        embeddings = WatsonxEmbeddings(
            model_id="ibm/granite-embedding-278m-multilingual",
            url="https://us-south.ml.cloud.ibm.com",
            project_id="*****",
            # apikey="*****"
        )
        ```

    ??? info "Embed single text"

        ```python
        input_text = "hello"
        vector = embeddings.embed_query(input_text)
        print(vector[:3])
        ```

        ```python
        [-0.0020519258, 0.0147288125, -0.0090887165]
        ```

    ??? info "Embed multiple texts"

        ```python
        vectors = embeddings.embed_documents(["hello", "goodbye"])
        # Showing only the first 3 coordinates
        print(len(vectors))
        print(vectors[0][:3])
        ```

        ```python
        2
        [-0.0020519265, 0.01472881, -0.009088721]
        ```

    ??? info "Async"

        ```python
        vector = await embeddings.aembed_query(input_text)
        print(vector[:3])

        # multiple:
        # await embeddings.aembed_documents(input_texts)
        ```

        ```python
        [-0.0020519258, 0.0147288125, -0.0090887165]
        ```
    """

    model_id: str | None = None
    """ID of the embedding model to use.

    Exactly one of `model_id` and `model` must be set, unless a ready-made
    `watsonx_embed` object is supplied.
    """

    model: str | None = None
    """
    Name or alias of the foundation model to use.
    When using IBM's watsonx.ai Model Gateway (public preview), you can specify any
    supported third-party model—OpenAI, Anthropic, NVIDIA, Cerebras, or IBM's own
    Granite series—via a single, OpenAI-compatible interface. Models must be explicitly
    provisioned (opt-in) through the Gateway to ensure secure, vendor-agnostic access
    and easy switch-over without reconfiguration.

    For more details on configuration and usage, see [IBM watsonx Model Gateway docs](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-model-gateway.html?context=wx&audience=wdp)
    """

    project_id: str | None = None
    """ID of the Watson Studio project."""

    space_id: str | None = None
    """ID of the Watson Studio space."""

    url: SecretStr = Field(
        alias="url",
        default_factory=secret_from_env("WATSONX_URL", default=None),  # type: ignore[assignment]
    )
    """URL to the Watson Machine Learning or CPD instance."""

    apikey: SecretStr | None = Field(
        alias="apikey",
        default_factory=secret_from_env("WATSONX_APIKEY", default=None),
    )
    """API key to the Watson Machine Learning or CPD instance."""

    token: SecretStr | None = Field(
        alias="token",
        default_factory=secret_from_env("WATSONX_TOKEN", default=None),
    )
    """Token to the CPD instance."""

    password: SecretStr | None = Field(
        alias="password",
        default_factory=secret_from_env("WATSONX_PASSWORD", default=None),
    )
    """Password to the CPD instance."""

    username: SecretStr | None = Field(
        alias="username",
        default_factory=secret_from_env("WATSONX_USERNAME", default=None),
    )
    """Username to the CPD instance."""

    instance_id: SecretStr | None = Field(
        alias="instance_id",
        default_factory=secret_from_env("WATSONX_INSTANCE_ID", default=None),
    )
    """Instance_id of the CPD instance."""

    version: SecretStr | None = None
    """Version of the CPD instance."""

    params: dict | None = None
    """Model parameters to use during request generation."""

    verify: str | bool | None = None
    """You can pass one of following as verify:

    * the path to a CA_BUNDLE file
    * the path of directory with certificates of trusted CAs
    * True - default path to truststore will be taken
    * False - no verification will be made
    """

    watsonx_embed: Embeddings = Field(default=None)  #: :meta private:

    watsonx_embed_gateway: Gateway = Field(
        default=None,
        exclude=True,
    )  #: :meta private:

    watsonx_client: APIClient | None = Field(default=None)  #: :meta private:

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
        protected_namespaces=(),
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate the configuration and build the underlying watsonx object.

        Three configurations are handled, in this order:

        1. A ready-made `watsonx_embed` object was passed: `model_id`,
           `project_id`, `space_id` and `params` are copied from it.
        2. A `watsonx_client` (`APIClient`) was passed: it is used to build
           either a `Gateway` (when `model` is set) or an `Embeddings`
           object (when `model_id` is set).
        3. Otherwise credentials are resolved from the individual credential
           fields and used to build the `Gateway` or `Embeddings` object.

        Returns:
            The validated instance itself.

        Raises:
            NotImplementedError: If `watsonx_embed_gateway` was passed to the
                constructor.
            ValueError: If, in configurations 2 and 3, not exactly one of
                `model` and `model_id` is set.
        """
        if self.watsonx_embed_gateway is not None:
            # The message must be a plain string: a trailing comma inside the
            # parentheses would turn it into a 1-tuple and garble the error.
            error_msg = (
                "Passing the 'watsonx_embed_gateway' parameter to the "
                "WatsonxEmbeddings constructor is not supported yet."
            )
            raise NotImplementedError(error_msg)

        if isinstance(self.watsonx_embed, Embeddings):
            # Mirror the settings of the supplied object so the wrapper's
            # fields describe what will actually be used.
            self.model_id = self.watsonx_embed.model_id
            self.project_id = self.watsonx_embed._client.default_project_id  # noqa: SLF001
            self.space_id = self.watsonx_embed._client.default_space_id  # noqa: SLF001
            self.params = self.watsonx_embed.params
            return self

        # Both remaining configurations need exactly one model selector.
        if sum(map(bool, (self.model, self.model_id))) != 1:
            error_msg = (
                "The parameters 'model' and 'model_id' are mutually exclusive. "
                "Please specify exactly one of these parameters when "
                "initializing WatsonxEmbeddings."
            )
            raise ValueError(error_msg)

        if isinstance(self.watsonx_client, APIClient):
            if self.model is not None:
                self.watsonx_embed_gateway = Gateway(
                    api_client=self.watsonx_client,
                    verify=self.verify,
                )
            else:
                self.watsonx_embed = Embeddings(
                    model_id=self.model_id,
                    params=self.params,
                    api_client=self.watsonx_client,
                    project_id=self.project_id,
                    space_id=self.space_id,
                    verify=self.verify,
                )
            return self

        credentials = resolve_watsonx_credentials(
            url=self.url,
            apikey=self.apikey,
            token=self.token,
            password=self.password,
            username=self.username,
            instance_id=self.instance_id,
            version=self.version,
            verify=self.verify,
        )
        if self.model is not None:
            self.watsonx_embed_gateway = Gateway(
                credentials=credentials,
                verify=self.verify,
            )
        else:
            # `verify` is handed to `resolve_watsonx_credentials` above and is
            # not passed to `Embeddings` again here.
            self.watsonx_embed = Embeddings(
                model_id=self.model_id,
                params=self.params,
                credentials=credentials,
                project_id=self.project_id,
                space_id=self.space_id,
            )
        return self

    @gateway_error_handler
    def _call_model_gateway(
        self,
        *,
        model: str,
        texts: list[str],
        **params: Any,
    ) -> Any:
        """Request embeddings for `texts` from the Model Gateway.

        Args:
            model: Name or alias of the model, forwarded as `model`.
            texts: Texts to embed, forwarded as `input`.
            **params: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever `Gateway.embeddings.create` returns.
        """
        return self.watsonx_embed_gateway.embeddings.create(
            model=model,
            input=texts,
            **params,
        )

    @async_gateway_error_handler
    async def _acall_model_gateway(
        self,
        *,
        model: str,
        texts: list[str],
        **params: Any,
    ) -> Any:
        """Asynchronously request embeddings for `texts` from the Model Gateway.

        Args:
            model: Name or alias of the model, forwarded as `model`.
            texts: Texts to embed, forwarded as `input`.
            **params: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever awaiting `Gateway.embeddings.acreate` returns.
        """
        return await self.watsonx_embed_gateway.embeddings.acreate(
            model=model,
            input=texts,
            **params,
        )

    def embed_documents(self, texts: list[str], **kwargs: Any) -> list[list[float]]:
        """Embed search docs.

        Args:
            texts: Texts to embed.
            **kwargs: Extra keyword arguments. Request parameters are derived
                from them and from `self.params` by `extract_params`.

        Returns:
            One embedding vector per input text.
        """
        params = extract_params(kwargs, self.params)

        if self.watsonx_embed_gateway is not None:
            # Gateway path: parameters are passed as flat keyword arguments;
            # on a key clash the extracted `params` win over raw `kwargs`.
            call_kwargs = {**kwargs, **params}
            embed_response = self._call_model_gateway(
                model=self.model,
                texts=texts,
                **call_kwargs,
            )
            return [embedding["embedding"] for embedding in embed_response["data"]]

        # watsonx.ai path: parameters travel in a single `params` argument.
        return cast(
            "list[list[float]]",
            self.watsonx_embed.embed_documents(
                texts=texts,
                **(kwargs | {"params": params}),
            ),
        )

    async def aembed_documents(
        self,
        texts: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        """Asynchronous Embed search docs.

        Args:
            texts: Texts to embed.
            **kwargs: Extra keyword arguments. Request parameters are derived
                from them and from `self.params` by `extract_params`.

        Returns:
            One embedding vector per input text.
        """
        params = extract_params(kwargs, self.params)

        if self.watsonx_embed_gateway is not None:
            # Gateway path: parameters are passed as flat keyword arguments;
            # on a key clash the extracted `params` win over raw `kwargs`.
            call_kwargs = {**kwargs, **params}
            embed_response = await self._acall_model_gateway(
                model=self.model,
                texts=texts,
                **call_kwargs,
            )
            return [embedding["embedding"] for embedding in embed_response["data"]]

        # watsonx.ai path: parameters travel in a single `params` argument.
        return cast(
            "list[list[float]]",
            await self.watsonx_embed.aembed_documents(
                texts=texts,
                **(kwargs | {"params": params}),
            ),
        )

    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Embed query text.

        Args:
            text: The text to embed.
            **kwargs: Forwarded to `embed_documents`.

        Returns:
            The embedding vector for `text`.
        """
        return self.embed_documents([text], **kwargs)[0]

    async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Asynchronous Embed query text.

        Args:
            text: The text to embed.
            **kwargs: Forwarded to `aembed_documents`.

        Returns:
            The embedding vector for `text`.
        """
        embeddings = await self.aembed_documents([text], **kwargs)
        return embeddings[0]