From 00d9b231e1cdef10c3fea7e6cc0b485bf4208031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com> Date: Thu, 4 Jul 2024 23:20:17 +0100 Subject: [PATCH] Inference test --- src/py/tty4.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/py/tty4.py diff --git a/src/py/tty4.py b/src/py/tty4.py new file mode 100644 index 0000000..4f2e9a4 --- /dev/null +++ b/src/py/tty4.py @@ -0,0 +1,47 @@ +import json + +from fac_tools import run_command +from fac_tools import get_fac_wrapper +from fac_tools import load_weights_and_biases +from fac_tools import predict_next_token +from fac_tools import reset_states + +#from transformers import AutoTokenizer + +#tokenizer = AutoTokenizer.from_pretrained("gpt2") +#help(tokenizer) + +#prompt = "Taxation" +#input_ids = tokenizer(prompt, return_tensors="pt").input_ids +#tokens = input_ids[0].numpy() +tokens = [27017, 341, 318, 12402] # Taxation is theft + +server = get_fac_wrapper("telnet") + +print("Reading in JSON") +with open("test_files/weights_and_biases_trained.json", "r") as f: + weights_and_biases = json.loads(f.read()) + f.close() + +load_weights_and_biases(server, weights_and_biases) + +#print("Input string: ", prompt) +#print("Input Tokens: ", tokens) + +#retarr=[] +#ret = 0 + +reset_states(server) + +print("Input Token: ", hex(tokens[0])) +tok = predict_next_token(server,tokens[0]) +print("Output Token: ", hex(tok)) +print("Expected Token: ", hex(tokens[1])) + +#print("Output string: ", tokenizer.decode(retarr)) + +run_command(server,"DONE") + +run_command(server,"TERMINATE") + +server.close() -- GitLab