39 changes: 39 additions & 0 deletions docs/examples/melp/lazy.py
@@ -0,0 +1,39 @@
import asyncio
from mellea.stdlib.base import (
    SimpleContext,
    Context,
    CBlock,
    ModelOutputThunk,
    SimpleComponent,
)
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend

backend = OllamaModelBackend("granite4:latest")


async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
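    """Ask the model for x+y; the returned ModelOutputThunk is not computed yet."""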
    sc = SimpleComponent(
        instruction="What is x+y? Respond with the number only.", x=x, y=y
    )
    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
    return mot


async def main(backend: Backend, ctx: Context):
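    """Build a chain of 100 lazy model additions; values are forced only when printed."""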
    fibs = []
    for i in range(100):
        if i == 0 or i == 1:
            fibs.append(CBlock(f"{i + 1}"))
        else:
            fibs.append(await fib(backend, ctx, fibs[i - 1], fibs[i - 2]))

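    # Printing forces computation: awaiting a thunk transitively forces the
    # thunks it depends on.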
    for x in fibs:
        match x:
            case ModelOutputThunk():
                print(await x.avalue())
            case CBlock():
                print(x.value)


asyncio.run(main(backend, SimpleContext()))
44 changes: 44 additions & 0 deletions docs/examples/melp/lazy_fib.py
@@ -0,0 +1,44 @@
import asyncio
from mellea.stdlib.base import (
    SimpleContext,
    Context,
    CBlock,
    ModelOutputThunk,
    SimpleComponent,
)
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend

backend = OllamaModelBackend("granite4:latest")


async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
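    """Ask the model for x+y; the returned ModelOutputThunk is not computed yet."""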
    sc = SimpleComponent(
        instruction="What is x+y? Respond with the number only.", x=x, y=y
    )
    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
    return mot


async def fib_main(backend: Backend, ctx: Context):
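    """Build a 20-element lazy fib chain and force only the final value."""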
    fibs = []
    for i in range(20):
        if i == 0 or i == 1:
            fibs.append(CBlock(f"{i}"))
        else:
            mot = await fib(backend, ctx, fibs[i - 1], fibs[i - 2])
            fibs.append(mot)

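    # Awaiting only the last thunk is enough: its SimpleComponent holds the
    # earlier thunks as inputs, so forcing it forces the whole chain.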
    print(await fibs[-1].avalue())
    # for x in fibs:
    #     match x:
    #         case ModelOutputThunk():
    #             n = await x.avalue()
    #             print(n)
    #         case CBlock():
    #             print(x.value)


asyncio.run(fib_main(backend, SimpleContext()))
66 changes: 66 additions & 0 deletions docs/examples/melp/lazy_fib_sample.py
@@ -0,0 +1,66 @@
import asyncio
from mellea.stdlib.base import (
    SimpleContext,
    Context,
    CBlock,
    ModelOutputThunk,
    SimpleComponent,
)
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend

backend = OllamaModelBackend("granite4:latest")


async def _fib_sample(
    backend: Backend, ctx: Context, x: CBlock, y: CBlock
) -> ModelOutputThunk | None:
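    """Draw one sample for x+y and validate it; returns None if the response is not an integer."""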
    sc = SimpleComponent(
        instruction="What is x+y? Respond with the number only.", x=x, y=y
    )
    answer_mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())

    # This await is the fundamental point: validating a sample means computation
    # must occur. We need to be able to read this off at computation-graph
    # construction time.
    value = await answer_mot.avalue()

    try:
        int(value)
        return answer_mot
    except ValueError:
        # The model's response is not a bare integer; reject this sample.
        return None


async def fib_sampling_version(
    backend: Backend, ctx: Context, x: CBlock, y: CBlock
) -> ModelOutputThunk | None:
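    """Rejection-sample x+y: retry up to 5 times until a sample parses as an integer."""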
    for _ in range(5):
        sample = await _fib_sample(backend, ctx, x, y)
        if sample is not None:
            return sample
    return None


async def fib_sampling_version_main(backend: Backend, ctx: Context):
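    """Build the fib chain with the sampling version; elements that never validate end up as None."""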
    fibs = []
    for i in range(20):
        if i == 0 or i == 1:
            fibs.append(CBlock(f"{i}"))
        else:
            mot = await fib_sampling_version(backend, ctx, fibs[i - 1], fibs[i - 2])
            fibs.append(mot)

    for x in fibs:
        match x:
            case ModelOutputThunk():
                n = await x.avalue()
                print(n)
            case CBlock():
                print(x.value)
            case None:
                # Sampling failed for this element after all retries.
                print("<no valid sample>")


asyncio.run(fib_sampling_version_main(backend, SimpleContext()))
38 changes: 38 additions & 0 deletions docs/examples/melp/simple_example.py
@@ -0,0 +1,38 @@
import asyncio
from mellea.stdlib.base import Context, CBlock, SimpleContext, ModelOutputThunk
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend


async def main(backend: Backend, ctx: Context):
"""
In this example, we show how executing multiple MOTs in parallel should work.
"""
    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"

    poem_thunks = []
    for state_name in m_states:
        mot, ctx = await backend.generate_from_context(
            CBlock(f"Write a poem about {state_name}"), ctx
        )
        poem_thunks.append(mot)

    # Notice that what we have now is a list of ModelOutputThunks, none of which are computed.
    for poem_thunk in poem_thunks:
        assert isinstance(poem_thunk, ModelOutputThunk)
        print(f"Computed: {poem_thunk.is_computed()}")

    # Let's run all of these in parallel.
    await asyncio.gather(*[c.avalue() for c in poem_thunks])

    # All of the thunks are now computed.
    for poem_thunk in poem_thunks:
        print(f"Computed: {poem_thunk.is_computed()}")

    # And let's print out the final results.
    for poem_thunk in poem_thunks:
        print(poem_thunk.value)


backend = OllamaModelBackend(model_id="granite4:latest")
asyncio.run(main(backend, SimpleContext()))
44 changes: 44 additions & 0 deletions docs/examples/melp/states.py
@@ -0,0 +1,44 @@
from mellea.stdlib.base import SimpleContext, Context, CBlock, SimpleComponent
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend
import asyncio


async def main(backend: Backend, ctx: Context):
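    """Lazily generate per-state populations, then aggregate each group with a
    SimpleComponent whose inputs are the still-uncomputed population thunks."""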
    a_states = "Alaska,Arizona,Arkansas".split(",")
    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"

    a_state_pops = dict()
    for state in a_states:
        a_state_pops[state], _ = await backend.generate_from_context(
            CBlock(f"What is the population of {state}? Respond with an integer only."),
            SimpleContext(),
        )
    a_total_pop = SimpleComponent(
        instruction=CBlock(
            "What is the total population of these states? Respond with an integer only."
        ),
        **a_state_pops,
    )
    a_state_total, _ = await backend.generate_from_context(a_total_pop, SimpleContext())

    m_state_pops = dict()
    for state in m_states:
        m_state_pops[state], _ = await backend.generate_from_context(
            CBlock(f"What is the population of {state}? Respond with an integer only."),
            SimpleContext(),
        )
    m_total_pop = SimpleComponent(
        instruction=CBlock(
            "What is the total population of these states? Respond with an integer only."
        ),
        **m_state_pops,
    )
    m_state_total, _ = await backend.generate_from_context(m_total_pop, SimpleContext())

    print(await a_state_total.avalue())
    print(await m_state_total.avalue())


backend = OllamaModelBackend(model_id="granite4:latest")
asyncio.run(main(backend, SimpleContext()))
9 changes: 9 additions & 0 deletions docs/rewrite/session_deepdive/1.py
@@ -0,0 +1,9 @@
import mellea.stdlib.functional as mfuncs
from mellea.stdlib.base import SimpleContext
from mellea.backends.ollama import OllamaModelBackend

response, next_context = mfuncs.chat("What is 1+1?",
                                     context=SimpleContext(),
                                     backend=OllamaModelBackend("granite4:latest"))

print(response.content)
9 changes: 9 additions & 0 deletions docs/rewrite/session_deepdive/2.py
@@ -0,0 +1,9 @@
import mellea.stdlib.functional as mfuncs
from mellea.stdlib.base import SimpleContext, CBlock
from mellea.backends.ollama import OllamaModelBackend

response, next_context = mfuncs.act(CBlock("What is 1+1?"),
                                    context=SimpleContext(),
                                    backend=OllamaModelBackend("granite4:latest"))

print(response.value)
15 changes: 15 additions & 0 deletions docs/rewrite/session_deepdive/3.py
@@ -0,0 +1,15 @@
import mellea.stdlib.functional as mfuncs
from mellea.stdlib.base import SimpleContext, CBlock, Context
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends import Backend
import asyncio


async def main(backend: Backend, ctx: Context):
    response, next_context = await mfuncs.aact(CBlock("What is 1+1?"),
                                               context=ctx,
                                               backend=backend)

    print(response.value)

asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
19 changes: 19 additions & 0 deletions docs/rewrite/session_deepdive/4.py
@@ -0,0 +1,19 @@
from mellea.stdlib.base import SimpleContext, CBlock, Context
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends import Backend
import asyncio


async def main(backend: Backend, ctx: Context):
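    """Generate lazily: the returned thunk stays uncomputed until awaited."""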
    response, next_context = await backend.generate_from_context(
        CBlock("What is 1+1?"),
        ctx=ctx  # TODO: we should rationalize ctx and context across mfuncs and base/backend.
    )

    print(f"Currently computed: {response.is_computed()}")
    print(await response.avalue())
    print(f"Currently computed: {response.is_computed()}")


asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
25 changes: 25 additions & 0 deletions docs/rewrite/session_deepdive/5.py
@@ -0,0 +1,25 @@
from mellea.stdlib.base import SimpleContext, CBlock, Context, SimpleComponent
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends import Backend
import asyncio


async def main(backend: Backend, ctx: Context):
    x, _ = await backend.generate_from_context(CBlock("What is 1+1?"), ctx=ctx)

    y, _ = await backend.generate_from_context(CBlock("What is 2+2?"), ctx=ctx)

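    # x and y are still uncomputed thunks; the SimpleComponent below composes
    # them lazily, deferring all three generations until response is forced.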
    response, _ = await backend.generate_from_context(
        SimpleComponent(instruction="What is x+y?", x=x, y=y),
        ctx=ctx  # TODO: we should rationalize ctx and context across mfuncs and base/backend.
    )

print(f"x currently computed: {x.is_computed()}")
print(f"y currently computed: {y.is_computed()}")
print(f"response currently computed: {response.is_computed()}")
print(await response.avalue())
print(f"response currently computed: {response.is_computed()}")


asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
43 changes: 43 additions & 0 deletions mellea/backends/__init__.py
@@ -3,12 +3,15 @@
from __future__ import annotations

import abc
import asyncio
import itertools
from typing import TypeVar

import pydantic

from mellea.backends.model_ids import ModelIdentifier
from mellea.backends.types import ModelOption
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk

BaseModelSubclass = TypeVar(
@@ -76,3 +79,43 @@ async def generate_from_raw(
            model_options: Any model options to upsert into the defaults for this call.
            tool_calls: Always set to false unless supported by backend.
        """

    async def do_generate_walk(
        self, action: CBlock | Component | ModelOutputThunk
    ) -> None:
        """Awaits every uncomputed ModelOutputThunk reachable from the given action."""
        _to_compute = list(generate_walk(action))
        coroutines = [x.avalue() for x in _to_compute]
        # The following log message might get noisy. Feel free to remove if so.
        if len(_to_compute) > 0:
            FancyLogger.get_logger().info(
                f"do_generate_walk awaited on {len(_to_compute)} uncomputed mots."
            )
        await asyncio.gather(*coroutines)

    async def do_generate_walks(
        self, actions: list[CBlock | Component | ModelOutputThunk]
    ) -> None:
        """Awaits every uncomputed ModelOutputThunk reachable from any of the given actions."""
        _to_compute = []
        for action in actions:
            _to_compute.extend(generate_walk(action))
        coroutines = [x.avalue() for x in _to_compute]
        # The following log message might get noisy. Feel free to remove if so.
        if len(_to_compute) > 0:
            FancyLogger.get_logger().info(
                f"do_generate_walks awaited on {len(_to_compute)} uncomputed mots."
            )
        await asyncio.gather(*coroutines)


def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]:
    """Returns the generation walk ordering for a span: every uncomputed ModelOutputThunk reachable from `c`."""
    match c:
        case ModelOutputThunk() if not c.is_computed():
            return [c]
        case CBlock():
            return []
        case Component():
            parts_walk = [generate_walk(p) for p in c.parts()]
            return list(itertools.chain.from_iterable(parts_walk))  # aka flatten
        case _:
            # Fallback so the function is total; anything else needs no generation.
            return []