feat: scoring method #53

Open · wants to merge 26 commits into base: main
Commits
58721d5  feat: scoring method  (Chkhikvadze, Feb 3, 2024)
8554e4a  Merge branch 'main' into subnet-to-subnet  (Chkhikvadze, Feb 3, 2024)
99d9a23  fix organic scoring  (Chkhikvadze, Feb 3, 2024)
e2b4f86  fix global dentire  (Chkhikvadze, Feb 3, 2024)
ccdfbe7  fix: scoring route  (Chkhikvadze, Feb 4, 2024)
3562ce6  fix: max token  (Chkhikvadze, Feb 5, 2024)
8db7b5b  fix: model for scoring  (Chkhikvadze, Feb 5, 2024)
dc4ea89  Merge branch 'main' into subnet-to-subnet  (Chkhikvadze, Feb 8, 2024)
da43459  fix: model on scoring  (Chkhikvadze, Feb 12, 2024)
0833cfd  fix: timeout of is live  (Chkhikvadze, Feb 12, 2024)
936494a  fix: scoring and update weight on api  (Chkhikvadze, Feb 12, 2024)
0db7270  fix: scoring and update weight on api  (Chkhikvadze, Feb 12, 2024)
291cf3f  fix: unused line  (Chkhikvadze, Feb 12, 2024)
bf89518  fix: update avalailbe uids 10 minutes  (Chkhikvadze, Feb 12, 2024)
36f11a0  fix: merge text and organic api  (Chkhikvadze, Feb 12, 2024)
c28b3fd  fix: weight response  (Chkhikvadze, Feb 12, 2024)
b56cbf6  fix: scoring for text validator  (Chkhikvadze, Feb 12, 2024)
b80ef37  fix: state json git ignore  (Chkhikvadze, Feb 13, 2024)
51f0981  fix: set weight periodically  (Chkhikvadze, Feb 13, 2024)
9745b60  fix: set weight in perfor synthetic data  (Chkhikvadze, Feb 13, 2024)
c01c0e6  fix: score response default value  (Chkhikvadze, Feb 13, 2024)
4b8ea1a  fix: weight setter  (Chkhikvadze, Feb 13, 2024)
a306247  fix: weights update time  (Chkhikvadze, Feb 13, 2024)
b4a91b9  fix: synthetic data run every 10 minutes  (Chkhikvadze, Feb 13, 2024)
4442de0  refactor: periodicall update  (Chkhikvadze, Feb 13, 2024)
504d693  fix: define avalaible uids  (Chkhikvadze, Feb 14, 2024)
2 changes: 1 addition & 1 deletion .gitignore
@@ -8,4 +8,4 @@ state.json
wandb/
.vscode
.envrc
.idea/
.idea/
1 change: 0 additions & 1 deletion state.json

This file was deleted.

1 change: 1 addition & 0 deletions template/__init__.py
@@ -67,6 +67,7 @@
'5FKstHjZkh4v3qAMSBa1oJcHCLjxYZ8SNTSz1opTv4hR7gVB',
'5Dd8gaRNdhm1YP7G1hcB1N842ecAUQmbLjCRLqH5ycaTGrWv',
'5HbLYXUBy1snPR8nfioQ7GoA9x76EELzEq9j7F32vWUQHm1x',
'5H66kJAzBCv2DC9poHATLQqyt3ag8FLSbHf6rMqTiRcS52rc',
] + os.environ.get('CORTEXT_MINER_ADDITIONAL_WHITELIST_VALIDATOR_KEYS', '').split(',')
WHITELISTED_KEYS = testnet_key + test_key + valid_validators
BLACKLISTED_KEYS = ["5G1NjW9YhXLadMWajvTkfcJy6up3yH2q1YzMXDTi6ijanChe"]
30 changes: 18 additions & 12 deletions validators/text_validator.py
@@ -32,15 +32,20 @@ def __init__(self, dendrite, config, subtensor, wallet: bt.wallet):
"scores": {},
"timestamps": {},
}

async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
for uid, messages in query.items():
syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
bt.logging.info(
f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
f"timeout {self.timeout}: {syn.messages[0]['content']}"
)


async def organic(self, metagraph, available_uids, messages: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
uid_to_question = {}
if len(messages) <= len(available_uids):
random_uids = random.sample(list(available_uids.keys()), len(messages))
else:
random_uids = [random.choice(list(available_uids.keys())) for _ in range(len(messages))]
for message_dict, uid in zip(messages, random_uids): # Iterate over each dictionary in the list and random_uids
(key, message_list), = message_dict.items()
prompt = message_list[-1]['content']
uid_to_question[uid] = prompt
message = message_list
syn = StreamPrompting(messages=message_list, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
bt.logging.info(f"Sending {syn.model} {self.query_type} request to uid: {uid}, timeout {self.timeout}: {message[0]['content']}")
self.wandb_data["prompts"][uid] = messages
responses = await self.dendrite(
metagraph.axons[uid],
@@ -55,8 +60,8 @@ async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> As
continue

bt.logging.trace(resp)
yield uid, resp

yield uid, key, resp
async def handle_response(self, uid: str, responses) -> tuple[str, str]:
full_response = ""
for resp in responses:
@@ -120,13 +125,14 @@ async def score_responses(
query_responses: list[tuple[int, str]], # [(uid, response)]
uid_to_question: dict[int, str], # uid -> prompt
metagraph: bt.metagraph,
is_score_all=False
) -> tuple[torch.Tensor, dict[int, float], dict]:
scores = torch.zeros(len(metagraph.hotkeys))
uid_scores_dict = {}
response_tasks = []

# Decide to score all UIDs this round based on a chance
will_score_all = self.should_i_score()
will_score_all = True if is_score_all else self.should_i_score()

for uid, response in query_responses:
self.wandb_data["responses"][uid] = response
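For reference, a minimal standalone sketch of the UID-assignment logic that the new organic() method introduces (the message shape and the name available_uids follow the diff above; this is an illustration, not the PR's exact code):

import random

def assign_messages_to_uids(messages: list[dict], available_uids: dict) -> dict:
    """Pair each incoming message dict with an available UID.

    Mirrors the sampling in TextValidator.organic: sample UIDs without
    replacement when enough are available, otherwise reuse UIDs at random.
    """
    uid_pool = list(available_uids.keys())
    if len(messages) <= len(uid_pool):
        random_uids = random.sample(uid_pool, len(messages))
    else:
        random_uids = [random.choice(uid_pool) for _ in range(len(messages))]

    uid_to_question = {}
    for message_dict, uid in zip(messages, random_uids):
        # Each element is expected to look like {key: [{"role": "user", "content": "..."}]}.
        (key, message_list), = message_dict.items()
        uid_to_question[uid] = message_list[-1]["content"]
    return uid_to_question

When there are more prompts than live UIDs, the same UID can serve several prompts, which matches the else branch in the diff.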
48 changes: 32 additions & 16 deletions validators/validator.py
@@ -25,7 +25,7 @@
import template
from template import utils
import sys

import json
from weight_setter import WeightSetter, TestWeightSetter

text_vali = None
@@ -107,9 +107,12 @@ def initialize_components(config: bt.config):


def initialize_validators(vali_config, test=False):
global text_vali, image_vali, embed_vali
global text_vali, text_vali_organic, image_vali, embed_vali

text_vali = (TextValidator if not test else TestTextValidator)(**vali_config)
text_vali_organic = (TextValidator if not test else TestTextValidator)(**vali_config)
text_vali_organic.model = 'gpt-3.5-turbo-16k'
text_vali_organic.max_tokens = 8096
image_vali = ImageValidator(**vali_config)
embed_vali = EmbeddingsValidator(**vali_config)
bt.logging.info("initialized_validators")
@@ -124,27 +127,42 @@ async def process_text_validator(request: web.Request):
access_key = request.headers.get("access-key")
if access_key != EXPECTED_ACCESS_KEY:
return web.Response(status=401, text="Invalid access key")

if len(validator_app.weight_setter.available_uids) == 0:
return web.Response(status=404, text="No available UIDs")

try:
messages_dict = {int(k): [{'role': 'user', 'content': v}] for k, v in (await request.json()).items()}
except ValueError:
return web.Response(status=400, text="Bad request format")
body = await request.json()
messages = body['messages']

response = web.StreamResponse()
await response.prepare(request)

uid_to_response = dict.fromkeys(messages_dict, "")
key_to_response = {}
uid_to_response = {}
try:
async for uid, content in text_vali.organic(validator_app.weight_setter.metagraph, messages_dict):
uid_to_response[uid] += content
await response.write(content.encode())
async for uid, key, content in text_vali_organic.organic(metagraph=validator_app.weight_setter.metagraph,
available_uids=validator_app.weight_setter.available_uids,
messages=messages):
uid_to_response[uid] = uid_to_response.get(uid, '') + content
key_to_response[key] = key_to_response.get(key, '') + content
# await response.write(content.encode())
prompts = {}
for uid, message_dict in zip(uid_to_response.keys(), messages):
(key, message_list), = message_dict.items()
prompt = message_list[-1]['content']
prompts[uid] = prompt # Update prompts correctly for each uid


validator_app.weight_setter.register_text_validator_organic_query(
uid_to_response, {k: v[0]['content'] for k, v in messages_dict.items()}
text_vali=text_vali_organic,
uid_to_response=uid_to_response,
messages_dict=prompts
)
await response.write(json.dumps(key_to_response).encode())

except Exception as e:
bt.logging.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}')
bt.logging.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}, ERROR: {e}')
await response.write(b'<<internal error>>')

return response


@@ -153,11 +171,9 @@ def __init__(self, *a, **kw):
super().__init__(*a, **kw)
self.weight_setter: WeightSetter | None = None


validator_app = ValidatorApplication()
validator_app.add_routes([web.post('/text-validator/', process_text_validator)])


def main(run_aio_app=True, test=False) -> None:
config = get_config()
wallet, subtensor, dendrite, my_uid = initialize_components(config)
@@ -177,7 +193,7 @@ def main(run_aio_app=True, test=False) -> None:

if run_aio_app:
try:
web.run_app(validator_app, port=config.http_port, loop=loop)
web.run_app(validator_app, port=config.http_port, loop=loop, shutdown_timeout=120)
except KeyboardInterrupt:
bt.logging.info("Keyboard interrupt detected. Exiting validator.")
finally:
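For context, a hedged example of how a client might call the reworked /text-validator/ endpoint after this change (the host, port, and access-key value are placeholders; the request shape follows the body['messages'] access in the handler, and the response is the single JSON object written at the end):

import asyncio
import json

import aiohttp

async def query_text_validator():
    # Host, port, and access key are placeholders; the real values come from the
    # validator's config (config.http_port) and EXPECTED_ACCESS_KEY.
    url = "http://localhost:8000/text-validator/"
    headers = {"access-key": "<EXPECTED_ACCESS_KEY>"}
    # After this PR the body carries a "messages" list of {key: chat-messages} dicts
    # instead of the previous {uid: prompt} mapping.
    body = {
        "messages": [
            {"q1": [{"role": "user", "content": "What is the capital of France?"}]},
            {"q2": [{"role": "user", "content": "Explain overfitting in one sentence."}]},
        ]
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=body) as resp:
            # The handler writes one JSON object mapping each key to its full response.
            key_to_response = json.loads(await resp.text())
            print(key_to_response)

if __name__ == "__main__":
    asyncio.run(query_text_validator())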
110 changes: 87 additions & 23 deletions validators/weight_setter.py
@@ -11,14 +11,14 @@
import wandb
import os
import shutil

import time
from template.protocol import IsAlive
from text_validator import TextValidator
from image_validator import ImageValidator
from embeddings_validator import EmbeddingsValidator

iterations_per_set_weights = 5
scoring_organic_timeout = 60
scoring_organic_timeout = 120


async def wait_for_coro_with_limit(coro, timeout: int) -> Tuple[bool, object]:
@@ -47,8 +47,56 @@ def __init__(self, loop: asyncio.AbstractEventLoop, dendrite, subtensor, config,
self.organic_scoring_tasks = set()

self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')

self.steps_passed = 0
self.available_uids = {}
self.loop.create_task(self.update_available_uids_periodically())
self.loop.create_task(self.consume_organic_scoring())

self.loop.create_task(self.perform_synthetic_scoring_and_update_weights())
self.loop.create_task(self.update_weights_periodically())

async def update_weights_periodically(self):
while True:
try:
if len(self.available_uids) == 0 or torch.all(self.total_scores == 0):
await asyncio.sleep(10)
continue

await self.update_weights(self.steps_passed)
except Exception as e:
# Log the exception or handle it as needed
bt.logging.error(f"An error occurred in update_weights_periodically: {e}")
# Optionally, decide whether to continue or break the loop based on the exception
finally:
# Ensure the sleep is in the finally block if you want the loop to always wait,
# even if an error occurs.
await asyncio.sleep(1800) # Sleep for 30 minutes


async def update_available_uids_periodically(self):
while True:
start_time = time.time()
try:
# It's assumed run_sync_in_async is a method that correctly handles running synchronous code in async.
# If not, ensure it's properly implemented to avoid blocking the event loop.
self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid))

# Directly await the asynchronous method without intermediate assignment to self.available_uids,
# unless it's used elsewhere.
available_uids = await self.get_available_uids()
uid_list = self.shuffled(list(available_uids.keys())) # Ensure shuffled is properly defined to work with async.

bt.logging.info(f"update_available_uids_periodically Number of available UIDs for periodic update: {len(uid_list)}, UIDs: {uid_list}")
except Exception as e:
bt.logging.error(f"update_available_uids_periodically Failed to update available UIDs: {e}")
# Consider whether to continue or break the loop upon certain errors.

end_time = time.time()
execution_time = end_time - start_time
bt.logging.info(f"update_available_uids_periodically Execution time for getting available UIDs amount is: {execution_time} seconds")

await asyncio.sleep(600) # 600 seconds = 10 minutes

async def run_sync_in_async(self, fn):
return await self.loop.run_in_executor(self.thread_executor, fn)
@@ -80,25 +128,19 @@ async def consume_organic_scoring(self):

async def perform_synthetic_scoring_and_update_weights(self):
while True:
for steps_passed in itertools.count():
self.metagraph = await self.run_sync_in_async(lambda: self.subtensor.metagraph(self.config.netuid))

available_uids = await self.get_available_uids()
selected_validator = self.select_validator(steps_passed)
scores, _ = await self.process_modality(selected_validator, available_uids)
self.total_scores += scores

steps_since_last_update = steps_passed % iterations_per_set_weights

if steps_since_last_update == iterations_per_set_weights - 1:
await self.update_weights(steps_passed)
else:
bt.logging.info(
f"Updating weights in {iterations_per_set_weights - steps_since_last_update - 1} iterations."
)

if len(self.available_uids) == 0:
await asyncio.sleep(10)
continue

available_uids = self.available_uids
selected_validator = self.select_validator(self.steps_passed)
scores, _ = await self.process_modality(selected_validator, available_uids)
self.total_scores += scores

self.steps_passed+=1
await asyncio.sleep(600)


def select_validator(self, steps_passed):
return self.text_vali if steps_passed % 5 in (0, 1, 2, 3) else self.image_vali

Expand All @@ -115,7 +157,7 @@ async def get_available_uids(self):
async def check_uid(self, axon, uid):
"""Asynchronously check if a UID is available."""
try:
response = await self.dendrite(axon, IsAlive(), deserialize=False, timeout=4)
response = await self.dendrite(axon, IsAlive(), deserialize=False, timeout=15)
if response.is_success:
bt.logging.trace(f"UID {uid} is active")
return axon # Return the axon info instead of the UID
@@ -181,21 +223,43 @@ async def set_weights(self, scores):
)
bt.logging.success("Successfully set weights.")

def handle_task_result_organic_query(self, task):
try:
success, data = task.result()
if success:
scores, uid_scores_dict, wandb_data = data
if self.config.wandb_on:
wandb.log(wandb_data)
bt.logging.success("wandb_log successful")
self.total_scores += scores
bt.logging.success(f"Task completed successfully. Scores updated.")
else:
bt.logging.error("Task failed. No scores updated.")
except Exception as e:
# Handle exceptions raised during task execution
bt.logging.error(f"handle_task_result_organic_query An error occurred during task execution: {e}")

def register_text_validator_organic_query(
self,
text_vali,
uid_to_response: dict[int, str], # [(uid, response)]
messages_dict: dict[int, str],
):
self.organic_scoring_tasks.add(asyncio.create_task(
self.steps_passed += 1

task = asyncio.create_task(
wait_for_coro_with_limit(
self.text_vali.score_responses(
text_vali.score_responses(
query_responses=list(uid_to_response.items()),
uid_to_question=messages_dict,
metagraph=self.metagraph,
is_score_all=True
),
scoring_organic_timeout
)
))
)
task.add_done_callback(self.handle_task_result_organic_query) # Attach the callback
self.organic_scoring_tasks.add(task)


class TestWeightSetter(WeightSetter):
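The scheduling change splits UID refresh, synthetic scoring, and weight setting into independent asyncio background tasks. A minimal sketch of that pattern, with the intervals taken from the diff (refresh_uids and set_weights below are stand-ins, not the PR's methods):

import asyncio

async def run_periodically(coro_factory, interval_seconds: float, name: str):
    """Run a coroutine factory forever at a fixed interval, logging failures
    instead of letting one exception kill the loop, as in
    update_available_uids_periodically / update_weights_periodically."""
    while True:
        try:
            await coro_factory()
        except Exception as exc:  # keep the background task alive on errors
            print(f"{name} failed: {exc}")
        await asyncio.sleep(interval_seconds)

async def main():
    async def refresh_uids():      # stand-in for querying IsAlive on every axon
        print("refreshing available UIDs")

    async def set_weights():       # stand-in for pushing accumulated scores on-chain
        print("setting weights")

    await asyncio.gather(
        run_periodically(refresh_uids, 600, "uid refresh"),    # every 10 minutes
        run_periodically(set_weights, 1800, "weight update"),  # every 30 minutes
    )

if __name__ == "__main__":
    asyncio.run(main())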