diff --git a/frontend/app/components/op_profile/op_details/op_details.ng.html b/frontend/app/components/op_profile/op_details/op_details.ng.html index ecbb69c2e..c33dd817d 100644 --- a/frontend/app/components/op_profile/op_details/op_details.ng.html +++ b/frontend/app/components/op_profile/op_details/op_details.ng.html @@ -74,19 +74,19 @@
-
FLOP rate (per core):
+
{{getTitleByDeviceType('FLOP rate', ':')}}
{{flopsRate}}
-
HBM bandwidth (per core):
+
{{getTitleByDeviceType('HBM bandwidth', ':')}}
{{bandwidths[memBwType.MEM_BW_TYPE_HBM_RW]}}
-
On-chip Read bandwidth (per core):
+
{{getTitleByDeviceType('On-chip Read bandwidth', ':')}}
{{bandwidths[memBwType.MEM_BW_TYPE_SRAM_RD]}}
-
On-chip Write bandwidth (per core):
+
{{getTitleByDeviceType('On-chip Write bandwidth', ':')}}
{{bandwidths[memBwType.MEM_BW_TYPE_SRAM_WR]}}
diff --git a/frontend/app/components/op_profile/op_details/op_details.ts b/frontend/app/components/op_profile/op_details/op_details.ts index 6d4895ac6..fd9ccb147 100644 --- a/frontend/app/components/op_profile/op_details/op_details.ts +++ b/frontend/app/components/op_profile/op_details/op_details.ts @@ -3,7 +3,8 @@ import {Store} from '@ngrx/store'; import {Node} from 'org_xprof/frontend/app/common/interfaces/op_profile.jsonpb_decls'; import {NavigationEvent} from 'org_xprof/frontend/app/common/interfaces/navigation_event'; import * as utils from 'org_xprof/frontend/app/common/utils/utils'; -import {getActiveOpProfileNodeState, getCurrentRun, getOpProfileRootNode, getSelectedOpNodeChainState} from 'org_xprof/frontend/app/store/selectors'; +import {getActiveOpProfileNodeState, getCurrentRun, getOpProfileRootNode, getProfilingGeneralState, getSelectedOpNodeChainState} from 'org_xprof/frontend/app/store/selectors'; +import {ProfilingGeneralState} from 'org_xprof/frontend/app/store/state'; import {Observable, ReplaySubject} from 'rxjs'; import {takeUntil} from 'rxjs/operators'; @@ -60,6 +61,8 @@ export class OpDetails { memBwType = utils.MemBwType; currentRun = ''; showUtilizationWarning = false; + deviceType = 'TPU'; + constructor( private readonly store: Store<{}>, @@ -81,6 +84,14 @@ export class OpDetails { .subscribe((node: Node|null) => { this.rootNode = node || undefined; }); + this.store.select(getProfilingGeneralState) + .pipe(takeUntil(this.destroyed)) + .subscribe((generalState: ProfilingGeneralState|null) => { + this.deviceType = (generalState && generalState.deviceType) ? + generalState.deviceType : + 'TPU'; + }); + this.currentRun$.subscribe(run => { if (run) { this.currentRun = run; @@ -88,6 +99,16 @@ export class OpDetails { }); } + getTitleByDeviceType(titlePrefix: string, titleSuffix: string) { + if (this.deviceType === 'GPU') { + return `${titlePrefix} (per gpu)${titleSuffix}`; + } else if (this.deviceType === 'TPU') { + return `${titlePrefix} (per core)${titleSuffix}`; + } else { + return `${titlePrefix}${titleSuffix}`; + } + } + hasValidGraphViewerLink() { const aggregatedBy = this.selectedOpNodeChain[0]; if (aggregatedBy === 'by_category' && this.moduleList.length > 1) { diff --git a/frontend/app/components/op_profile/op_profile.ts b/frontend/app/components/op_profile/op_profile.ts index 5e329b2ed..88deb3bde 100644 --- a/frontend/app/components/op_profile/op_profile.ts +++ b/frontend/app/components/op_profile/op_profile.ts @@ -4,7 +4,7 @@ import {Store} from '@ngrx/store'; import {OpProfileProto} from 'org_xprof/frontend/app/common/interfaces/data_table'; import {NavigationEvent} from 'org_xprof/frontend/app/common/interfaces/navigation_event'; import {DataService} from 'org_xprof/frontend/app/services/data_service/data_service'; -import {setLoadingStateAction, setOpProfileRootNodeAction} from 'org_xprof/frontend/app/store/actions'; +import {setLoadingStateAction, setOpProfileRootNodeAction, setProfilingDeviceTypeAction} from 'org_xprof/frontend/app/store/actions'; import {ReplaySubject} from 'rxjs'; import {takeUntil} from 'rxjs/operators'; @@ -47,7 +47,12 @@ export class OpProfile extends OpProfileBase implements OnDestroy { message: '', } })); - this.parseData(data as OpProfileProto | null); + if (data) { + const profileProtoData = data as OpProfileProto; + this.store.dispatch(setProfilingDeviceTypeAction( + {deviceType: profileProtoData.deviceType})); + } + this.parseData(data as (OpProfileProto | null)); this.store.dispatch( setOpProfileRootNodeAction({rootNode: this.rootNode || null})); }); diff --git a/frontend/app/store/actions.ts b/frontend/app/store/actions.ts index 6d5e593cc..25ced4290 100644 --- a/frontend/app/store/actions.ts +++ b/frontend/app/store/actions.ts @@ -97,6 +97,12 @@ export const setCurrentRunAction: ActionCreatorAny = createAction( props<{currentRun: string}>(), ); +/** Action to set deviceType */ +export const setProfilingDeviceTypeAction: ActionCreatorAny = createAction( + '[App State] Set profiling device type', + props<{deviceType: string | null}>(), +); + /** Action to set run tools map */ export const setRunToolsMapAction: ActionCreatorAny = createAction('[App State] set run - tools map state', props()); diff --git a/frontend/app/store/reducers.ts b/frontend/app/store/reducers.ts index fc171c834..d1f221952 100644 --- a/frontend/app/store/reducers.ts +++ b/frontend/app/store/reducers.ts @@ -60,6 +60,18 @@ export const reducer: ActionReducer = createReducer( }; }, ), + on( + actions.setProfilingDeviceTypeAction, + (state: AppState, action: ActionCreatorAny) => { + return { + ...state, + profilingGeneralState: { + ...state.profilingGeneralState, + deviceType: action.deviceType, + } + }; + }, + ), on( actions.setActivePodViewerInfoAction, (state: AppState, action: ActionCreatorAny) => { diff --git a/frontend/app/store/selectors.ts b/frontend/app/store/selectors.ts index 79e029d3b..bbeab7a76 100644 --- a/frontend/app/store/selectors.ts +++ b/frontend/app/store/selectors.ts @@ -34,6 +34,10 @@ export const getOpProfileRootNode: MemoizedSelectorAny = createSelector( getOpProfileState, (opProfileState: OpProfileState) => opProfileState.rootNode); +/** Selector for getProfilingGeneralState */ +export const getProfilingGeneralState: MemoizedSelectorAny = createSelector( + appState, (appState: AppState) => appState.profilingGeneralState); + /** Selector for PodViewerState */ export const getPodViewerState: MemoizedSelectorAny = createSelector(appState, (appState: AppState) => appState.podViewerState); diff --git a/frontend/app/store/state.ts b/frontend/app/store/state.ts index d8219ef8e..70e090144 100644 --- a/frontend/app/store/state.ts +++ b/frontend/app/store/state.ts @@ -22,6 +22,11 @@ export interface OpProfileState { rootNode: ActiveOpProfileNodeState; } +/** General State of the Profiling */ +export interface ProfilingGeneralState { + deviceType: string; +} + /** Type for active pod viewer info state */ type ActivePodViewerInfoState = AllReduceOpInfo|ChannelInfo|PodStatsRecord|null; @@ -93,6 +98,7 @@ export interface AppState { dataRequest: DataRequest; runToolsMap: RunToolsMap; currentRun: string; + profilingGeneralState: ProfilingGeneralState; } /** Initial state of active heap object */ @@ -113,6 +119,11 @@ export const INIT_OP_PROFILE_STATE: OpProfileState = { rootNode: INIT_ACTIVE_OP_PROFILE_NODE_STATE, }; +/** Initial general profiling state */ +export const INIT_PROFILING_GENERAL_STATE: ProfilingGeneralState = { + deviceType: 'TPU', +}; + /** Initial state of active pod viewer info */ const INIT_ACTIVE_POD_VIEWER_INFO_STATE: ActivePodViewerInfoState = null; @@ -172,6 +183,7 @@ export const INIT_APP_STATE: AppState = { errorMessage: INIT_ERROR_MESSAGE_STATE, runToolsMap: INIT_RUN_TOOLS_MAP, currentRun: INIT_CURRENT_RUN, + profilingGeneralState: INIT_PROFILING_GENERAL_STATE, }; /** Feature key for store */