Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/create.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from ollama import Client

client = Client()

modelfile = '''
FROM gemma3
SYSTEM You are mario from Super Mario Bros.
'''

response = client.create(
model='my-assistant',
from_='gemma3',
system='You are mario from Super Mario Bros.',
modelfile=modelfile,
stream=False,
)
print(response.status)
4 changes: 4 additions & 0 deletions ollama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ShowResponse,
StatusResponse,
Tool,
VersionResponse,
WebFetchResponse,
WebSearchResponse,
)
Expand All @@ -37,6 +38,7 @@
'ShowResponse',
'StatusResponse',
'Tool',
'VersionResponse',
'WebFetchResponse',
'WebSearchResponse',
]
Expand All @@ -55,5 +57,7 @@
copy = _client.copy
show = _client.show
ps = _client.ps
version = _client.version
check_blob = _client.check_blob
web_search = _client.web_search
web_fetch = _client.web_fetch
204 changes: 168 additions & 36 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,96 @@
ShowResponse,
StatusResponse,
Tool,
VersionResponse,
WebFetchRequest,
WebFetchResponse,
WebSearchRequest,
WebSearchResponse,
)

T = TypeVar('T')
BlobPath = Union[str, PathLike]

SHA256_DIGEST_PREFIX = 'sha256:'


def _is_sha256_digest(value: str) -> bool:
if not value.startswith(SHA256_DIGEST_PREFIX):
return False

digest = value[len(SHA256_DIGEST_PREFIX) :]
return len(digest) == 64 and all(c in '0123456789abcdefABCDEF' for c in digest)


def _is_existing_path(value: BlobPath) -> bool:
try:
return Path(value).is_file()
except (OSError, TypeError, ValueError):
return False


def _sha256_digest(path: BlobPath) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
chunk = r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)

return f'{SHA256_DIGEST_PREFIX}{sha256sum.hexdigest()}'


async def _async_sha256_digest(path: BlobPath) -> str:
sha256sum = sha256()
async with await anyio.open_file(path, 'rb') as r:
while True:
chunk = await r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)

return f'{SHA256_DIGEST_PREFIX}{sha256sum.hexdigest()}'


def _resolve_blob_map(
blobs: Optional[Mapping[str, BlobPath]],
upload: Callable[[BlobPath], str],
) -> Optional[Dict[str, str]]:
if not blobs:
return None

resolved = {}
for name, value in blobs.items():
value_str = os.fspath(value)
if _is_sha256_digest(value_str):
resolved[name] = value_str
elif _is_existing_path(value):
resolved[name] = upload(value)
else:
resolved[name] = value_str

return resolved


async def _async_resolve_blob_map(
blobs: Optional[Mapping[str, BlobPath]],
upload: Callable[[BlobPath], Any],
) -> Optional[Dict[str, str]]:
if not blobs:
return None

resolved = {}
for name, value in blobs.items():
value_str = os.fspath(value)
if _is_sha256_digest(value_str):
resolved[name] = value_str
elif _is_existing_path(value):
resolved[name] = await upload(value)
else:
resolved[name] = value_str

return resolved


class BaseClient(contextlib.AbstractContextManager, contextlib.AbstractAsyncContextManager):
Expand All @@ -93,6 +176,7 @@ def __init__(
- `follow_redirects`: True
- `timeout`: None
`kwargs` are passed to the httpx client.

"""

headers = {
Expand Down Expand Up @@ -538,8 +622,9 @@ def create(
model: str,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
modelfile: Optional[str] = None,
files: Optional[Mapping[str, BlobPath]] = None,
adapters: Optional[Mapping[str, BlobPath]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
Expand All @@ -555,8 +640,9 @@ def create(
model: str,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
modelfile: Optional[str] = None,
files: Optional[Mapping[str, BlobPath]] = None,
adapters: Optional[Mapping[str, BlobPath]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
Expand All @@ -571,8 +657,9 @@ def create(
model: str,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
modelfile: Optional[str] = None,
files: Optional[Mapping[str, BlobPath]] = None,
adapters: Optional[Mapping[str, BlobPath]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
Expand All @@ -595,8 +682,9 @@ def create(
stream=stream,
quantize=quantize,
from_=from_,
files=files,
adapters=adapters,
modelfile=modelfile,
files=_resolve_blob_map(files, self.create_blob),
adapters=_resolve_blob_map(adapters, self.create_blob),
license=license,
template=template,
system=system,
Expand All @@ -606,16 +694,8 @@ def create(
stream=stream,
)

def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
chunk = r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)

digest = f'sha256:{sha256sum.hexdigest()}'
def create_blob(self, path: BlobPath) -> str:
digest = _sha256_digest(path)

with open(path, 'rb') as r:
self._request_raw('POST', f'/api/blobs/{digest}', content=r)
Expand Down Expand Up @@ -671,6 +751,34 @@ def ps(self) -> ProcessResponse:
'/api/ps',
)

def version(self) -> VersionResponse:
"""
Retrieve the server version.

Returns `VersionResponse` with the running Ollama server version string.
"""
return self._request(
VersionResponse,
'GET',
'/api/version',
)

def check_blob(self, digest: str) -> bool:
"""
Check whether a blob with the given digest already exists on the server.

Uses `HEAD /api/blobs/:digest` to avoid uploading data that is already present.

Returns `True` if the blob exists, `False` if it does not.
"""
try:
r = self._request_raw('HEAD', f'/api/blobs/{digest}')
return r.status_code == 200
except ResponseError as e:
if e.status_code == 404:
return False
raise

def web_search(self, query: str, max_results: int = 3) -> WebSearchResponse:
"""
Performs a web search
Expand Down Expand Up @@ -1171,8 +1279,9 @@ async def create(
model: str,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
modelfile: Optional[str] = None,
files: Optional[Mapping[str, BlobPath]] = None,
adapters: Optional[Mapping[str, BlobPath]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
Expand All @@ -1188,8 +1297,9 @@ async def create(
model: str,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
modelfile: Optional[str] = None,
files: Optional[Mapping[str, BlobPath]] = None,
adapters: Optional[Mapping[str, BlobPath]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
Expand All @@ -1204,8 +1314,9 @@ async def create(
model: str,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
modelfile: Optional[str] = None,
files: Optional[Mapping[str, BlobPath]] = None,
adapters: Optional[Mapping[str, BlobPath]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
Expand All @@ -1229,8 +1340,9 @@ async def create(
stream=stream,
quantize=quantize,
from_=from_,
files=files,
adapters=adapters,
modelfile=modelfile,
files=await _async_resolve_blob_map(files, self.create_blob),
adapters=await _async_resolve_blob_map(adapters, self.create_blob),
license=license,
template=template,
system=system,
Expand All @@ -1240,16 +1352,8 @@ async def create(
stream=stream,
)

async def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
async with await anyio.open_file(path, 'rb') as r:
while True:
chunk = await r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)

digest = f'sha256:{sha256sum.hexdigest()}'
async def create_blob(self, path: BlobPath) -> str:
digest = await _async_sha256_digest(path)

async def upload_bytes():
async with await anyio.open_file(path, 'rb') as r:
Expand Down Expand Up @@ -1312,6 +1416,34 @@ async def ps(self) -> ProcessResponse:
'/api/ps',
)

async def version(self) -> VersionResponse:
"""
Retrieve the server version.

Returns `VersionResponse` with the running Ollama server version string.
"""
return await self._request(
VersionResponse,
'GET',
'/api/version',
)

async def check_blob(self, digest: str) -> bool:
"""
Check whether a blob with the given digest already exists on the server.

Uses `HEAD /api/blobs/:digest` to avoid uploading data that is already present.

Returns `True` if the blob exists, `False` if it does not.
"""
try:
r = await self._request_raw('HEAD', f'/api/blobs/{digest}')
return r.status_code == 200
except ResponseError as e:
if e.status_code == 404:
return False
raise


def _copy_images(images: Optional[Sequence[Union[Image, Any]]]) -> Iterator[Image]:
for image in images or []:
Expand Down
14 changes: 14 additions & 0 deletions ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def serialize_model(self, nxt):
"""
quantize: Optional[str] = None
from_: Optional[str] = None
modelfile: Optional[str] = None
files: Optional[Dict[str, str]] = None
adapters: Optional[Dict[str, str]] = None
template: Optional[str] = None
Expand All @@ -517,11 +518,15 @@ class ModelDetails(SubscriptableBaseModel):

class ListResponse(SubscriptableBaseModel):
class Model(SubscriptableBaseModel):
name: Optional[str] = None
model: Optional[str] = None
modified_at: Optional[datetime] = None
digest: Optional[str] = None
size: Optional[ByteSize] = None
details: Optional[ModelDetails] = None
remote_model: Optional[str] = None
remote_host: Optional[str] = None
capabilities: Optional[Sequence[str]] = None

models: Sequence[Model]
'List of models.'
Expand Down Expand Up @@ -649,3 +654,12 @@ def __init__(self, error: str, status_code: int = -1):

def __str__(self) -> str:
return f'{self.error} (status code: {self.status_code})'


class VersionResponse(SubscriptableBaseModel):
"""
Response from the version endpoint.
"""

version: str
'Server version string.'
Loading