Skip to content

Commit b3faf30

Browse files
committed
beginning of refactor for Closes #83.
1 parent 1f2a933 commit b3faf30

File tree

14 files changed

+414
-442
lines changed

14 files changed

+414
-442
lines changed

ell-studio/src/components/DependencyGraphPane.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {DependencyGraph} from './depgraph/DependencyGraph';
66
// When changing pages we need to rerender this component (or create a new graph)
77
const DependencyGraphPane = ({ lmp, uses }) => {
88
const lmps = [lmp, ...uses];
9-
9+
console.log(uses)
1010
return (
1111
<DependencyGraph lmps={lmps} />
1212
);

ell-studio/src/components/depgraph/graphUtils.js

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,10 @@ export function getInitialGraph(lmps, traces) {
209209
const deadNodes = lmps
210210
.filter((x) => !!x)
211211
.flatMap((lmp) =>
212-
(lmp.uses || []).filter(use => !lmpIds.has(use)).map(use => ({
213-
id: `${use}`,
212+
(lmp.uses || []).filter(use => !lmpIds.has(use.lmp_id)).map(use => ({
213+
id: `${use.lmp_id}`,
214214
type: "lmp",
215-
data: { label: `Unknown LMP (${use})`, lmp: { lmp_id: use, name: `Out of Date LMP (${use})`, version_number: -2 } },
215+
data: { label: `Outdated LMP ${use.name}`, lmp: { lmp_id: use.lmp_id, name: `Outdated LMP (${use.name})`, version_number: use.version_number } },
216216
position: { x: 0, y: 0 },
217217
style: { opacity: 0.5 }, // Make dead nodes visually distinct
218218
}))
@@ -228,9 +228,9 @@ export function getInitialGraph(lmps, traces) {
228228
return (
229229
lmp?.uses?.map((use) => {
230230
return {
231-
id: `uses-${lmp.lmp_id}-${use}`,
231+
id: `uses-${lmp.lmp_id}-${use.lmp_id}`,
232232
target: `${lmp.lmp_id}`,
233-
source: `${use}`,
233+
source: `${use.lmp_id}`,
234234
animated: false,
235235
type: "default",
236236
};

ell-studio/src/components/source/LMPSourceView.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ const LMPSourceView = ({ lmp, showDependenciesInitial = false, selectedInvocatio
4646
const trimmedDependencies = dependencies.trim();
4747
const dependencyLines = trimmedDependencies ? trimmedDependencies.split('\n').length : 0;
4848
const sourceLines = source.split('\n').length;
49-
const dependentLMPs = uses.length;
49+
5050

5151
const boundedVariableHooks = useMemo(() => {
5252
const mutableBVWrapper = ({ children, key, content }) => (

ell-studio/src/hooks/useBackend.js

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,21 +89,6 @@ export const useInvocation = (id) => {
8989
}
9090

9191

92-
export const useMultipleLMPs = (usesIds) => {
93-
const multipleLMPs = useQueries({
94-
queries: (usesIds || []).map(use => ({
95-
queryKey: ['lmp', use],
96-
queryFn: async () => {
97-
const useResponse = await axios.get(`${API_BASE_URL}/api/lmp/${use}`);
98-
return useResponse.data;
99-
},
100-
enabled: !!use,
101-
})),
102-
});
103-
const isLoading = multipleLMPs.some(query => query.isLoading);
104-
const data = multipleLMPs.map(query => query.data);
105-
return { isLoading, data };
106-
};
10792

10893
export const useLatestLMPs = (page = 0, pageSize = 100) => {
10994
return useQuery({

ell-studio/src/pages/LMP.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import {
88
import {
99
useLMPs,
1010
useInvocationsFromLMP,
11-
useMultipleLMPs,
1211
useInvocation,
1312
} from "../hooks/useBackend";
1413
import InvocationsTable from "../components/invocations/InvocationsTable";
@@ -55,7 +54,8 @@ function LMP() {
5554
pageSize,
5655
true // dangerous hierarchical query that will not scale to unique invocations
5756
);
58-
const { data: uses } = useMultipleLMPs(lmp?.uses);
57+
const uses = lmp?.uses;
58+
5959

6060
const [activeTab, setActiveTab] = useState("runs");
6161
const [selectedTrace, setSelectedTrace] = useState(null);

src/ell/decorators/track.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import threading
3-
from ell.types import SerializedLStr, utc_now
3+
from ell.types import SerializedLStr, utc_now, SerializedLMP, Invocation, InvocationTrace
44
import ell.util.closure
55
from ell.configurator import config
66
from ell.lstr import lstr
@@ -94,9 +94,8 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
9494
state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
9595

9696
cache_store = func_to_track.__wrapper__.__ell_use_cache__
97-
cached_invocations = cache_store.get_invocations(lmp_filters=dict(lmp_id=func_to_track.__ell_hash__), filters=dict(
98-
state_cache_key=state_cache_key
99-
))
97+
cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key)
98+
10099

101100
if len(cached_invocations) > 0:
102101
# TODO THis is bad?
@@ -154,7 +153,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
154153
return wrapper
155154

156155
def _serialize_lmp(func, name, fn_closure, is_lmp, lm_kwargs):
157-
lmps = config._store.get_lmps(name=name)
156+
lmps = config._store.get_versions_by_fqn(fqn=name)
158157
version = 0
159158
already_in_store = any(lmp['lmp_id'] == func.__ell_hash__ for lmp in lmps)
160159

@@ -170,24 +169,25 @@ def _serialize_lmp(func, name, fn_closure, is_lmp, lm_kwargs):
170169
else:
171170
commit = None
172171

173-
config._store.write_lmp(
172+
serialized_lmp = SerializedLMP(
174173
lmp_id=func.__ell_hash__,
175174
name=name,
176175
created_at=utc_now(),
177176
source=fn_closure[0],
178177
dependencies=fn_closure[1],
179178
commit_message=commit,
180-
global_vars=get_immutable_vars(func.__ell_closure__[2]),
181-
free_vars=get_immutable_vars(func.__ell_closure__[3]),
182-
is_lmp=is_lmp,
179+
initial_global_vars=get_immutable_vars(fn_closure[2]),
180+
initial_free_vars=get_immutable_vars(fn_closure[3]),
181+
is_lm=is_lmp,
183182
lm_kwargs=lm_kwargs if lm_kwargs else None,
184183
version_number=version,
185-
uses=func.__ell_uses__,
186184
)
187185

186+
config._store.write_lmp(serialized_lmp, func.__ell_uses__)
187+
188188
def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion_tokens,
189189
state_cache_key, invocation_kwargs, cleaned_invocation_params, consumes, result, parent_invocation_id):
190-
config._store.write_invocation(
190+
invocation = Invocation(
191191
id=invocation_id,
192192
lmp_id=func.__ell_hash__,
193193
created_at=utc_now(),
@@ -198,12 +198,27 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion
198198
completion_tokens=completion_tokens,
199199
state_cache_key=state_cache_key,
200200
invocation_kwargs=invocation_kwargs,
201-
**cleaned_invocation_params,
202-
consumes=consumes,
203-
result=result,
204-
parent_invocation_id=parent_invocation_id
201+
args=cleaned_invocation_params.get('args', []),
202+
kwargs=cleaned_invocation_params.get('kwargs', {}),
203+
used_by_id=parent_invocation_id
205204
)
206205

206+
results = []
207+
if isinstance(result, lstr):
208+
results = [result]
209+
elif isinstance(result, list):
210+
results = result
211+
else:
212+
raise TypeError("Result must be either lstr or List[lstr]")
213+
214+
serialized_results = [
215+
SerializedLStr(
216+
content=str(res),
217+
logits=res.logits
218+
) for res in results
219+
]
220+
221+
config._store.write_invocation(invocation, serialized_results, consumes)
207222

208223
def compute_state_cache_key(ipstr, fn_closure):
209224
_global_free_vars_str = f"{json.dumps(get_immutable_vars(fn_closure[2]), sort_keys=True, default=repr)}"
@@ -270,6 +285,4 @@ def process_lstr(obj):
270285
# TODO: This is a hack fix it.
271286
# XXX: Unify this with above so that we don't have to do this.
272287
# XXX: I really think there is some standard var explorer we can leverage from from ipython or someshit.
273-
return json.loads(jstr), jstr, consumes
274-
275-
288+
return json.loads(jstr), jstr, consumes

src/ell/store.py

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime
44
from typing import Any, Optional, Dict, List, Set, Union
55
from ell.lstr import lstr
6-
from ell.types import InvocableLM
6+
from ell.types import InvocableLM, SerializedLMP, Invocation, SerializedLStr
77

88

99
class Store(ABC):
@@ -12,111 +12,106 @@ class Store(ABC):
1212
"""
1313

1414
@abstractmethod
15-
def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str], is_lmp: bool, lm_kwargs: str,
16-
version_number: int,
17-
uses: Dict[str, Any],
18-
commit_message: Optional[str] = None,
19-
created_at: Optional[datetime]=None) -> Optional[Any]:
15+
def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
2016
"""
2117
Write an LMP (Language Model Package) to the storage.
2218
23-
:param lmp_id: Unique identifier for the LMP.
24-
:param name: Name of the LMP.
25-
:param source: Source code or reference for the LMP.
26-
:param dependencies: List of dependencies for the LMP.
27-
:param is_lmp: Boolean indicating if it is an LMP.
28-
:param lm_kwargs: Additional keyword arguments for the LMP.
19+
:param serialized_lmp: SerializedLMP object containing all LMP details.
2920
:param uses: Dictionary of LMPs used by this LMP.
30-
:param created_at: Optional timestamp of when the LMP was created.
3121
:return: Optional return value.
3222
"""
3323
pass
3424

3525
@abstractmethod
36-
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
37-
created_at: Optional[datetime], consumes: Set[str], prompt_tokens: Optional[int] = None,
38-
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
39-
state_cache_key: Optional[str] = None,
40-
cost_estimate: Optional[float] = None) -> Optional[Any]:
26+
def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
4127
"""
4228
Write an invocation of an LMP to the storage.
4329
44-
:param id: Unique identifier for the invocation.
45-
:param lmp_id: Unique identifier for the LMP.
46-
:param args: Arguments used in the invocation.
47-
:param kwargs: Keyword arguments used in the invocation.
48-
:param result: Result of the invocation.
49-
:param invocation_kwargs: Additional keyword arguments for the invocation.
50-
:param created_at: Optional timestamp of when the invocation was created.
30+
:param invocation: Invocation object containing all invocation details.
31+
:param results: List of SerializedLStr objects representing the results.
5132
:param consumes: Set of invocation IDs consumed by this invocation.
52-
:param prompt_tokens: Optional number of prompt tokens used.
53-
:param completion_tokens: Optional number of completion tokens used.
54-
:param latency_ms: Optional latency in milliseconds.
55-
:param cost_estimate: Optional estimated cost of the invocation.
5633
:return: Optional return value.
5734
"""
5835
pass
5936

6037
@abstractmethod
61-
def get_lmps(self, **filters: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
38+
def get_cached_invocations(self, lmp_id :str, state_cache_key :str) -> List[Invocation]:
6239
"""
63-
Retrieve LMPs from the storage.
64-
65-
:param filters: Optional dictionary of filters to apply.
66-
:return: List of LMPs.
40+
Get cached invocations for a given LMP and state cache key.
6741
"""
6842
pass
6943

7044
@abstractmethod
71-
def get_invocations(self, lmp_id: str, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
45+
def get_versions_by_fqn(self, fqn :str) -> List[SerializedLMP]:
7246
"""
73-
Retrieve invocations of an LMP from the storage.
74-
75-
:param lmp_id: Unique identifier for the LMP.
76-
:param filters: Optional dictionary of filters to apply.
77-
:return: List of invocations.
47+
Get all versions of an LMP by its fully qualified name.
7848
"""
7949
pass
8050

8151
# @abstractmethod
82-
# def search_lmps(self, query: str) -> List[Dict[str, Any]]:
52+
# def get_lmps(self, skip: int = 0, limit: int = 10, subquery=None, **filters: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
8353
# """
84-
# Search for LMPs in the storage.
54+
# Retrieve LMPs from the storage.
55+
56+
# :param skip: Number of records to skip.
57+
# :param limit: Maximum number of records to return.
58+
# :param subquery: Optional subquery for filtering.
59+
# :param filters: Optional dictionary of filters to apply.
60+
# :return: List of LMPs.
61+
# """
62+
# pass
8563

86-
# :param query: Search query string.
87-
# :return: List of LMPs matching the query.
64+
# @abstractmethod
65+
# def get_invocations(self, lmp_filters: Dict[str, Any], skip: int = 0, limit: int = 10, filters: Optional[Dict[str, Any]] = None, hierarchical: bool = False) -> List[Dict[str, Any]]:
66+
# """
67+
# Retrieve invocations of an LMP from the storage.
68+
69+
# :param lmp_filters: Filters to apply on the LMP level.
70+
# :param skip: Number of records to skip.
71+
# :param limit: Maximum number of records to return.
72+
# :param filters: Optional dictionary of filters to apply on the invocation level.
73+
# :param hierarchical: Whether to include hierarchical information.
74+
# :return: List of invocations.
8875
# """
8976
# pass
9077

9178
# @abstractmethod
92-
# def search_invocations(self, query: str) -> List[Dict[str, Any]]:
79+
# def get_latest_lmps(self, skip: int = 0, limit: int = 10) -> List[Dict[str, Any]]:
9380
# """
94-
# Search for invocations in the storage.
81+
# Retrieve the latest versions of all LMPs from the storage.
9582

96-
# :param query: Search query string.
97-
# :return: List of invocations matching the query.
83+
# :param skip: Number of records to skip.
84+
# :param limit: Maximum number of records to return.
85+
# :return: List of the latest LMPs.
9886
# """
9987
# pass
10088

89+
# @abstractmethod
90+
# def get_traces(self) -> List[Dict[str, Any]]:
91+
# """
92+
# Retrieve all traces from the storage.
10193

102-
@abstractmethod
103-
def get_latest_lmps(self) -> List[Dict[str, Any]]:
104-
"""
105-
Retrieve the latest versions of all LMPs from the storage.
94+
# :return: List of traces.
95+
# """
96+
# pass
10697

107-
:return: List of the latest LMPs.
108-
"""
109-
pass
98+
# @abstractmethod
99+
# def get_all_traces_leading_to(self, invocation_id: str) -> List[Dict[str, Any]]:
100+
# """
101+
# Retrieve all traces leading to a specific invocation.
110102

103+
# :param invocation_id: ID of the invocation to trace.
104+
# :return: List of traces leading to the specified invocation.
105+
# """
106+
# pass
111107

112108
@contextmanager
113109
def freeze(self, *lmps: InvocableLM):
114110
"""
115111
A context manager for caching operations using a particular store.
116112
117113
Args:
118-
key (Optional[str]): The cache key. If None, a default key will be generated.
119-
condition (Optional[Callable[..., bool]]): A function that determines whether to cache or not.
114+
*lmps: InvocableLM objects to freeze.
120115
121116
Yields:
122117
None

0 commit comments

Comments
 (0)