Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: reward scoring improvements #166

Merged
merged 19 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
3c5f21b
fix(responses): getting entire synapse and surfacing processing times…
grantdfoster Jun 27, 2024
0c7a392
feat: adds processing time to response logging
grantdfoster Jul 16, 2024
4f57017
fix: remove process times from get_rewards until data validation is i…
grantdfoster Jul 16, 2024
565de71
fix: adds latency to validator responses
grantdfoster Jul 16, 2024
7a23dcf
fix: better sorting on the validator response for now
grantdfoster Jul 16, 2024
66d946b
fix: removes double import
grantdfoster Jul 16, 2024
ccc5764
fix: guards against deployed validators not correctly pinging uids
grantdfoster Jul 16, 2024
e4c167b
fix: helper for sending TAO
grantdfoster Jul 17, 2024
e507649
feat: responses classification
hide-on-bush-x Jul 26, 2024
8c9bec1
feat: wip twitter endpoints set
hide-on-bush-x Jul 31, 2024
6214720
feat: logic fix
hide-on-bush-x Jul 31, 2024
9582690
chore: updates requirements.txt
grantdfoster Jul 31, 2024
7e1faf6
fix: adjust timeout on vali
grantdfoster Jul 31, 2024
8c60da4
Merge branch 'main' into feature/reward-scoring-improvements
grantdfoster Jul 31, 2024
4fb6884
feat: adding support for discord profile and channel/messages, web sc…
hide-on-bush-x Aug 1, 2024
bd9c291
Merge branch 'feature/reward-scoring-improvements' of https://github.…
hide-on-bush-x Aug 1, 2024
923fa76
feat: cleanning logs
hide-on-bush-x Aug 2, 2024
0cbe8bc
Update masa/validator/forwarder.py
hide-on-bush-x Aug 5, 2024
8ee4c3d
Merge branch 'main' into feature/reward-scoring-improvements
hide-on-bush-x Aug 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ fund-validator-wallet:
fund-miner-wallet:
btcli wallet faucet --wallet.name miner --subtensor.$(SUBTENSOR_ENVIRONMENT)

## Send TAO
send:
btcli w transfer --subtensor.$(SUBTENSOR_ENVIRONMENT)

## Subnet creation
create-subnet:
btcli subnet create --wallet.name owner --subtensor.$(SUBTENSOR_ENVIRONMENT)
Expand Down
7 changes: 3 additions & 4 deletions masa/miner/web/scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def scrape_web(self, query: WebScraperQuery) -> WebScraperObject:

def format_scraped_data(self, data: requests.Response) -> WebScraperObject:
bt.logging.info(f"Formatting scraped data: {data}")
scraped_data = json.loads(
data.json()["data"]
) # Convert stringified json to dict
formatted_scraped_data = WebScraperObject(**scraped_data)
json_data = data.json()["data"]

formatted_scraped_data = WebScraperObject(**json_data)

return formatted_scraped_data
2 changes: 1 addition & 1 deletion masa/types/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class Section(TypedDict, total=False):

class WebScraperObject(TypedDict):
sections: Optional[List[Section]]
pages: List[str]
pages: Optional[List[str]]
19 changes: 11 additions & 8 deletions masa/utils/uids.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def check_uid_availability(
"""
# Filter non serving axons.
if not metagraph.axons[uid].is_serving:
bt.logging.info(f"UID: {uid} is not serving")
return False

# Filter out non validator permit.
Expand Down Expand Up @@ -116,26 +117,28 @@ async def get_random_uids(self, k: int, exclude: List[int] = None) -> torch.Long
"""
dendrite = bt.dendrite(wallet=self.wallet)

print("get random uids")

try:
# Generic sanitation
avail_uids = get_available_uids(
self.metagraph, self.config.neuron.vpermit_tao_limit
)
candidate_uids = remove_excluded_uids(avail_uids, exclude)
# healthy_uids = remove_excluded_uids(avail_uids, exclude)

# healthy_uids, _ = await ping_uids(dendrite, self.metagraph, candidate_uids)

# guard against deployed validators not finding any healthy ids via ping...
# if (len(healthy_uids) == 0):
# healthy_uids = candidate_uids

healthy_uids, _ = await ping_uids(dendrite, self.metagraph, candidate_uids)
# filtered_uids = filter_duplicated_axon_ips_for_uids(
# healthy_uids, self.metagraph
# )

k = min(k, len(healthy_uids))
# k = min(k, len(healthy_uids))
# Random sampling
random_sample = random.sample(healthy_uids, k)
print(f"Random sample: {random_sample}")
# random_sample = random.sample(healthy_uids, k)

uids = torch.tensor(random_sample)
uids = torch.tensor(avail_uids)
return uids
except Exception as e:
bt.logging.error(f"Failed to get random miner uids: {e}")
Expand Down
7 changes: 5 additions & 2 deletions masa/validator/discord/all_guilds/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from masa.api.request import Request, RequestType
from masa.validator.forwarder import Forwarder
from masa.validator.discord.all_guilds.parser import all_guilds_parser
from masa.validator.discord.all_guilds.reward import get_rewards
from masa.miner.masa_protocol_request import REQUEST_TIMEOUT_IN_SECONDS
from masa.miner.discord.all_guilds import DiscordAllGuildsRequest


class DiscordAllGuildsForwarder(Forwarder):
Expand All @@ -32,11 +32,14 @@ def __init__(self, validator):

async def forward_query(self):
try:

def source_method(query):
return DiscordAllGuildsRequest().get_discord_all_guilds()
return await self.forward(
request=Request(type=RequestType.DISCORD_ALL_GUILDS.value),
get_rewards=get_rewards,
parser_method=all_guilds_parser,
timeout=REQUEST_TIMEOUT_IN_SECONDS,
source_method=source_method
)

except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions masa/validator/discord/channel_messages/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from masa.api.request import Request, RequestType
from masa.validator.forwarder import Forwarder
from masa.validator.discord.channel_messages.parser import channel_messages_parser
from masa.validator.discord.channel_messages.reward import get_rewards
from masa.miner.discord.channel_messages import DiscordChannelMessagesRequest


class DiscordChannelMessagesForwarder(Forwarder):
Expand All @@ -35,8 +35,8 @@ async def forward_query(self, query):
request=Request(
query=query, type=RequestType.DISCORD_CHANNEL_MESSAGES.value
),
get_rewards=get_rewards,
parser_method=channel_messages_parser,
source_method=DiscordChannelMessagesRequest().get_discord_channel_messages
)

except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions masa/validator/discord/guild_channels/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from masa.api.request import Request, RequestType
from masa.validator.forwarder import Forwarder
from masa.validator.discord.guild_channels.parser import guild_channels_parser
from masa.validator.discord.guild_channels.reward import get_rewards
from masa.miner.discord.guild_channels import DiscordGuildChannelsRequest


class DiscordGuildChannelsForwarder(Forwarder):
Expand All @@ -35,8 +35,8 @@ async def forward_query(self, query):
request=Request(
query=query, type=RequestType.DISCORD_GUILD_CHANNELS.value
),
get_rewards=get_rewards,
parser_method=guild_channels_parser,
source_method=DiscordGuildChannelsRequest().get_discord_guild_channels
)

except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions masa/validator/discord/profile/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from masa.api.request import Request, RequestType
from masa.types.discord import DiscordProfileObject
from masa.validator.forwarder import Forwarder
from masa.validator.discord.profile.reward import get_rewards
from masa.miner.discord.profile import DiscordProfileRequest


class DiscordProfileForwarder(Forwarder):
Expand All @@ -33,8 +33,8 @@ async def forward_query(self, query):
try:
return await self.forward(
request=Request(query=query, type=RequestType.DISCORD_PROFILE.value),
get_rewards=get_rewards,
parser_object=DiscordProfileObject,
source_method=DiscordProfileRequest().get_profile
)

except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions masa/validator/discord/user_guilds/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from masa.api.request import Request, RequestType
from masa.validator.forwarder import Forwarder
from masa.validator.discord.user_guilds.parser import user_guilds_parser
from masa.validator.discord.user_guilds.reward import get_rewards
from masa.miner.discord.user_guilds import DiscordUserGuildsRequest


class DiscordUserGuildsForwarder(Forwarder):
Expand All @@ -33,8 +33,8 @@ async def forward_query(self):
try:
return await self.forward(
request=Request(type=RequestType.DISCORD_USER_GUILDS.value),
get_rewards=get_rewards,
parser_method=user_guilds_parser,
source_method=DiscordUserGuildsRequest().get_discord_user_guilds
)

except Exception as e:
Expand Down
165 changes: 137 additions & 28 deletions masa/validator/forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,35 @@

from masa.utils.uids import get_random_uids
import bittensor as bt
import torch
from collections import defaultdict
import math
from sklearn.cluster import KMeans
# this forwarder needs to be able to handle multiple requests, driven off of an API request


# this forwarder needs to be able to handle multiple requests, driven off of an API request
class Forwarder:
def __init__(self, validator):
self.validator = validator
self.minimum_accepted_score = 0.8

async def forward(
self, request, get_rewards, parser_object=None, parser_method=None, timeout=5
):
# TODO: This should live inside each endpoint to enable us to filter miners by different parameters in the future
# like blacklisting miners only on a specific endpoint like profiles or followers
miner_uids = await get_random_uids(
self.validator, k=self.validator.config.neuron.sample_size
)
async def forward(self, request, parser_object=None, parser_method=None, timeout=5, source_method=None):
miner_uids = await get_random_uids(self.validator, k=self.validator.config.neuron.sample_size)
bt.logging.info("Calling UIDS -----------------------------------------")
bt.logging.info(miner_uids)

if miner_uids is None:
return []

responses = await self.validator.dendrite(
synapses = await self.validator.dendrite(
axons=[self.validator.metagraph.axons[uid] for uid in miner_uids],
synapse=request,
deserialize=True,
timeout=timeout,
deserialize=False,
timeout=timeout
)

responses = [synapse.response for synapse in synapses]

# Filter and parse valid responses
valid_responses, valid_miner_uids = self.sanitize_responses_and_uids(
responses, miner_uids=miner_uids
Expand All @@ -58,33 +61,106 @@ async def forward(
elif parser_method:
parsed_responses = parser_method(valid_responses)

# Score responses
rewards = get_rewards(
self.validator, query=request.query, responses=parsed_responses
)
process_times = [synapse.dendrite.process_time for synapse,
uid in zip(synapses, miner_uids) if uid in valid_miner_uids]

# Update the scores based on the rewards
source_of_truth = await self.get_source_of_truth(
responses=parsed_responses, miner_uids=miner_uids, source_method=source_method, query=request.query)

# Score responses
rewards = self.get_rewards(responses=parsed_responses, source_of_truth=source_of_truth
)
# Update the scores based on the rewards
if len(valid_miner_uids) > 0:
self.validator.update_scores(rewards, valid_miner_uids)
if self.validator.should_set_weights():
try:
self.validator.set_weights()
except Exception as e:
bt.logging.error(f"Failed to set weights: {e}")

# Add corresponding uid to each response
response_with_uids = [
{"response": response, "uid": int(uid.item()), "score": score.item()}
for response, uid, score in zip(parsed_responses, valid_miner_uids, rewards)
responses_with_metadata = [
{"response": response, "uid": int(
uid.item()), "score": score.item(), "latency": latency}
for response, latency, uid, score in zip(parsed_responses, process_times, valid_miner_uids, rewards)
]

responses_with_metadata.sort(key=lambda x: (-x["score"], x["latency"]))
return responses_with_metadata

def get_rewards(
self,
responses: dict,
source_of_truth: dict
) -> torch.FloatTensor:

combined_responses = responses.copy()
combined_responses.append(source_of_truth)

embeddings = self.validator.model.encode(
[str(response) for response in combined_responses])

num_clusters = min(len(combined_responses), 2)
grantdfoster marked this conversation as resolved.
Show resolved Hide resolved
clustering_model = KMeans(n_clusters=num_clusters)
clustering_model.fit(embeddings)
cluster_labels = clustering_model.labels_

source_of_truth_label = cluster_labels[-1] if len(cluster_labels) > 0 else None
bt.logging.info("Source of truth -----------------------------------------")
bt.logging.info(source_of_truth)
bt.logging.info(f"Source of truth label: {source_of_truth_label}")
bt.logging.info(f"labels: {cluster_labels}")
rewards_list = [
1 if cluster_labels[i] == source_of_truth_label else self.calculate_reward(
response, source_of_truth)
for i, response in enumerate(responses)
]

response_with_uids.sort(key=lambda x: x["score"], reverse=True)

print("FINAL RESPONSES ------------------------------------------------")
print(response_with_uids)

return response_with_uids

bt.logging.info("REWARDS LIST ----------------------------------------------")
bt.logging.info(rewards_list)

return torch.FloatTensor(rewards_list).to(
self.validator.device
)

def score_dicts_difference(self, initialScore, dict1, dict2):
    """Deduct a penalty from ``initialScore`` for each difference found.

    ``dict1`` is the reference: only its keys are walked, so keys that
    exist solely in ``dict2`` are never penalized. Nested dicts are
    compared recursively; lists of equal length are compared item by
    item, and lists of different length cost ``0.1 * (1 + difference)``.

    Args:
        initialScore: Score to start deducting from.
        dict1: Reference value (a dict at the top level; any value when
            recursing into list items).
        dict2: Candidate value compared against ``dict1``.

    Returns:
        The remaining score, clamped so it never drops below 0.
    """
    score = initialScore

    # Leaf comparison. Using "or" here also covers the case where only ONE
    # side is a dict (e.g. a list item that is a dict in the reference but a
    # scalar in the candidate). Previously that case fell through to the key
    # loop and raised AttributeError/TypeError instead of being penalized.
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        if dict1 != dict2:
            return max(score - 0.1, 0)
        return max(score, 0)

    for key, expected in dict1.items():
        actual = dict2.get(key)
        if actual is None:
            # Key missing from the candidate, or present but null.
            score -= 0.1
        elif isinstance(expected, dict) and isinstance(actual, dict):
            score = self.score_dicts_difference(score, expected, actual)
        elif isinstance(expected, list) and isinstance(actual, list):
            if len(expected) != len(actual):
                length_difference = abs(len(expected) - len(actual))
                score -= 0.1 * (1 + length_difference)
            else:
                for item1, item2 in zip(expected, actual):
                    score = self.score_dicts_difference(score, item1, item2)
        elif str(expected) != str(actual):
            # Scalars are compared by string form, so 1 and "1" match.
            score -= 0.1

    return max(score, 0)

def calculate_reward(self, response: dict, source_of_truth: dict) -> float:
    """Score a single response against the reference answer.

    Args:
        response: The parsed response to score, or None.
        source_of_truth: Reference value the response is compared with.

    Returns:
        0.0 when ``response`` is None; otherwise the score left after
        ``score_dicts_difference`` deducts from a starting score of 1,
        floored at 0.
    """
    # A missing response earns nothing.
    if response is None:
        return 0.0

    bt.logging.info(f"Getting username from {response}")

    # Wrap the payload under a 'response' key before comparing it with the
    # reference.
    wrapped = {'response': response}
    remaining = self.score_dicts_difference(1, source_of_truth, wrapped)

    # Guard against a negative result.
    return max(remaining, 0)

def sanitize_responses_and_uids(self, responses, miner_uids):
valid_responses = [response for response in responses if response is not None]
Expand All @@ -94,3 +170,36 @@ def sanitize_responses_and_uids(self, responses, miner_uids):
if response is not None
]
return valid_responses, valid_miner_uids

async def get_source_of_truth(self, responses, miner_uids, source_method, query):
    """Choose the reference answer that responses are scored against.

    If more than 10 of ``miner_uids`` have a validator score of at least
    ``self.minimum_accepted_score``, the reference is the response whose
    string form collects the largest total ``exp(score)`` weight.
    Otherwise the reference comes from ``source_method(query)`` when a
    ``source_method`` is given, and is None when it is not.

    Args:
        responses: Responses to vote over; each is compared by ``str()``.
        miner_uids: UIDs used to index ``self.validator.scores``.
        source_method: Optional callable taking ``query`` and returning the
            reference answer.
        query: Argument forwarded to ``source_method``.

    Returns:
        A dict ``{'response': reference}``. A string reference is parsed
        into a Python literal first; if parsing fails it becomes ``{}``.
    """
    # Local import keeps this method self-contained.
    import ast

    scores = self.validator.scores
    most_common_response = None

    count_high_score_uids = sum(
        1 for uid in miner_uids if scores[uid] >= self.minimum_accepted_score)
    bt.logging.info(
        f"Number of UIDs with score greater than the minimum accepted: {count_high_score_uids}")

    if count_high_score_uids > 10:
        weighted_responses = defaultdict(float)
        # NOTE(review): responses and uids are paired by position. Confirm
        # the caller passes UIDs aligned one-to-one with ``responses``.
        for response, uid in zip(responses, miner_uids):
            weighted_responses[str(response)] += math.exp(scores[uid])

        # max() raises ValueError on an empty mapping; with no responses to
        # vote over, leave the reference as None instead.
        if weighted_responses:
            most_common_response = max(
                weighted_responses, key=weighted_responses.get)
    elif source_method:
        most_common_response = source_method(query)

    if isinstance(most_common_response, str):
        # SECURITY: this string can be the str() of a miner-supplied
        # response, so it must not go through eval(). literal_eval accepts
        # only Python literals (dicts, lists, strings, numbers, ...).
        try:
            most_common_response = ast.literal_eval(most_common_response)
        except (ValueError, SyntaxError, TypeError,
                MemoryError, RecursionError) as e:
            bt.logging.error(
                f"Failed to transform most_common_response to dict: {e}")
            most_common_response = {}

    return {'response': most_common_response}
Loading
Loading