Well Well Well
425 points
Challenge
Hmm, I think there is a way to make me the Best Hedgehog ever! Help me do it, and I’ll reward you generously! ~ Jaga, the Cybersecurity Hedgehog

You seem to have found yourself at the bottom of an old well, with no obvious way to get out. The only thing you can find on the ground is a pile of bones next to a dusty old personal assistant device. You don’t see its owner anywhere, so surely they managed to escape.
You try scraping the device for clues, but almost everything is encrypted. The only usable artefact left behind came from the model’s internals, a cache of some sort. You hope this is enough to get somewhere.
Solution
We run a greedy brute-force search: for each position ID t, we feed every candidate token through the model as a length-1 sequence at position t, and keep the token whose layer-0 keys in past_key_values have the smallest squared error against the keys at that position in the KV cache we’ve been given.
To keep the runtime manageable, we batch as many candidate tokens per forward pass as the GPU allows.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

PATH_CACHE = "./kv_cache.pt"
DEVICE = torch.device("cuda")
DTYPE = torch.float32

@torch.no_grad()
def solve():
    # The recovered cache holds the layer-0 keys, the sequence length,
    # and the exact checkpoint/revision to load.
    kv_cache = torch.load(PATH_CACHE, map_location=DEVICE)
    K_rot = kv_cache["K_rot"].to(DTYPE)  # (num_heads, T, head_dim)
    T = int(kv_cache["T"])
    ckpt = kv_cache["model"]
    rev = kv_cache["revision"]

    model = AutoModelForCausalLM.from_pretrained(ckpt, revision=rev).to(DEVICE).eval()
    tok = AutoTokenizer.from_pretrained(ckpt, revision=rev, use_fast=True)

    # Candidate set: the full vocabulary minus special tokens.
    special_ids = set(tok.all_special_ids or [])
    cand_ids = torch.tensor(
        [i for i in range(tok.vocab_size) if i not in special_ids],
        dtype=torch.long,
    )

    B = 6144  # as many length-1 sequences per forward pass as the GPU can hold
    result = []
    for t in range(T):
        target = K_rot[:, t, :].unsqueeze(0)  # (1, num_heads, head_dim)
        best_err = float("inf")
        best_id = None
        for start in range(0, len(cand_ids), B):
            batch_ids = cand_ids[start : start + B]
            # Run every candidate as a length-1 sequence at position t, so the
            # rotary embedding matches what was applied in the original prompt.
            input_ids = batch_ids.view(-1, 1).to(DEVICE)
            position_ids = torch.full_like(input_ids, t)
            outputs = model(
                input_ids=input_ids,
                use_cache=True,
                position_ids=position_ids,
                return_dict=True,
            )
            # Layer-0 keys: (B, num_heads, 1, head_dim) -> (B, num_heads, head_dim)
            k0 = outputs.past_key_values[0][0]
            k_batch = k0.squeeze(2).contiguous().to(DTYPE)
            # Squared error of each candidate's keys against the cached keys.
            diffs = k_batch - target
            errs = (diffs * diffs).sum(dim=(1, 2))
            min_err, min_idx = errs.min(dim=0)
            if float(min_err) < best_err:
                best_err = float(min_err)
                best_id = int(batch_ids[int(min_idx)])
        result.append(best_id)
        # Progress print: watch the recovered prompt grow token by token.
        print(tok.decode(result, clean_up_tokenization_spaces=False))
    print()
    print(tok.decode(result, clean_up_tokenization_spaces=False))

solve()
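To sanity-check the approach before pointing it at the real artefact, it helps to build a cache file in the same format from a prompt you control. The snippet below is a minimal sketch of that; the only assumptions about the format are the fields the solver reads ("K_rot", "T", "model", "revision"), and the checkpoint name and prompt are placeholders, not the challenge's.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Placeholder checkpoint for local testing (a small RoPE model); the real
# challenge checkpoint is whatever kv_cache.pt names in its "model" field.
CKPT = "EleutherAI/pythia-70m"
REV = "main"

@torch.no_grad()
def make_cache(text, path="./kv_cache.pt"):
    tok = AutoTokenizer.from_pretrained(CKPT, revision=REV, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(CKPT, revision=REV).eval()
    input_ids = tok(text, return_tensors="pt").input_ids
    out = model(input_ids=input_ids, use_cache=True, return_dict=True)
    # Layer-0 keys: (1, num_heads, T, head_dim) -> (num_heads, T, head_dim)
    K_rot = out.past_key_values[0][0].squeeze(0)
    torch.save(
        {"K_rot": K_rot, "T": input_ids.shape[1], "model": CKPT, "revision": REV},
        path,
    )

make_cache("a test prompt you can try to recover")

Running solve() against a cache produced this way should reproduce the test prompt exactly: in the first layer, the key at position t depends only on that token's embedding and its position ID, not on the surrounding sequence, which is precisely why the greedy per-position search works.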
