44from typing import (
55 Iterable ,
66 Tuple ,
7+ Type ,
78)
89from cytoolz import (
910 first ,
3839)
3940
4041from eth .beacon .types .states import BeaconState # noqa: F401
41- from eth .beacon .types .blocks import BaseBeaconBlock # noqa: F401
42+ from eth .beacon .types .blocks import ( # noqa: F401
43+ BaseBeaconBlock ,
44+ BeaconBlock ,
45+ )
4246from eth .beacon .validation import (
4347 validate_slot ,
4448)
4852
4953class BaseBeaconChainDB (ABC ):
5054 db = None # type: BaseAtomicDB
55+ block_class = None # type: Type[BaseBeaconBlock]
56+
57+ @abstractmethod
58+ def set_block_class (self , block_class : Type [BaseBeaconBlock ]) -> None :
59+ pass
5160
5261 #
5362 # Block API
@@ -117,24 +126,33 @@ def get(self, key: bytes) -> bytes:
117126
118127
119128class BeaconChainDB (BaseBeaconChainDB ):
120- def __init__ (self , db : BaseAtomicDB ) -> None :
129+ def __init__ (self , db : BaseAtomicDB , block_class : Type [ BaseBeaconBlock ] ) -> None :
121130 self .db = db
131+ self .block_class = block_class
132+
133+ def set_block_class (self , block_class : Type [BaseBeaconBlock ]) -> None :
134+ self .block_class = block_class
122135
123136 def persist_block (self ,
124137 block : BaseBeaconBlock ) -> Tuple [Tuple [bytes , ...], Tuple [bytes , ...]]:
125138 """
126139 Persist the given block.
127140 """
128141 with self .db .atomic_batch () as db :
129- return self ._persist_block (db , block )
142+ return self ._persist_block (db , block , self . block_class )
130143
131144 @classmethod
132145 def _persist_block (
133146 cls ,
134147 db : 'BaseDB' ,
135- block : BaseBeaconBlock ) -> Tuple [Tuple [bytes , ...], Tuple [bytes , ...]]:
148+ block : BaseBeaconBlock ,
149+ block_class : Type [BaseBeaconBlock ]) -> Tuple [Tuple [bytes , ...], Tuple [bytes , ...]]:
136150 block_chain = (block , )
137- new_canonical_blocks , old_canonical_blocks = cls ._persist_block_chain (db , block_chain )
151+ new_canonical_blocks , old_canonical_blocks = cls ._persist_block_chain (
152+ db ,
153+ block_chain ,
154+ block_class ,
155+ )
138156
139157 return new_canonical_blocks , old_canonical_blocks
140158
@@ -176,15 +194,16 @@ def get_canonical_block_by_slot(self, slot: int) -> BaseBeaconBlock:
176194 Raise BlockNotFound if there's no block with the given slot in the
177195 canonical chain.
178196 """
179- return self ._get_canonical_block_by_slot (self .db , slot )
197+ return self ._get_canonical_block_by_slot (self .db , slot , self . block_class )
180198
181199 @classmethod
182200 def _get_canonical_block_by_slot (
183201 cls ,
184202 db : BaseDB ,
185- slot : int ) -> BaseBeaconBlock :
203+ slot : int ,
204+ block_class : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
186205 canonical_block_root = cls ._get_canonical_block_root_by_slot (db , slot )
187- return cls ._get_block_by_root (db , canonical_block_root )
206+ return cls ._get_block_by_root (db , canonical_block_root , block_class )
188207
189208 def get_canonical_block_root_by_slot (self , slot : int ) -> Hash32 :
190209 """
@@ -207,21 +226,25 @@ def get_canonical_head(self) -> BaseBeaconBlock:
207226 """
208227 Return the current block at the head of the chain.
209228 """
210- return self ._get_canonical_head (self .db )
229+ return self ._get_canonical_head (self .db , self . block_class )
211230
212231 @classmethod
213- def _get_canonical_head (cls , db : BaseDB ) -> BaseBeaconBlock :
232+ def _get_canonical_head (cls ,
233+ db : BaseDB ,
234+ block_class : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
214235 try :
215236 canonical_head_root = db [SchemaV1 .make_canonical_head_root_lookup_key ()]
216237 except KeyError :
217238 raise CanonicalHeadNotFound ("No canonical head set for this chain" )
218- return cls ._get_block_by_root (db , Hash32 (canonical_head_root ))
239+ return cls ._get_block_by_root (db , Hash32 (canonical_head_root ), block_class )
219240
220241 def get_block_by_root (self , block_root : Hash32 ) -> BaseBeaconBlock :
221- return self ._get_block_by_root (self .db , block_root )
242+ return self ._get_block_by_root (self .db , block_root , self . block_class )
222243
223244 @staticmethod
224- def _get_block_by_root (db : BaseDB , block_root : Hash32 ) -> BaseBeaconBlock :
245+ def _get_block_by_root (db : BaseDB ,
246+ block_root : Hash32 ,
247+ block_class : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
225248 """
226249 Return the requested block header as specified by block root.
227250
@@ -233,7 +256,7 @@ def _get_block_by_root(db: BaseDB, block_root: Hash32) -> BaseBeaconBlock:
233256 except KeyError :
234257 raise BlockNotFound ("No block with root {0} found" .format (
235258 encode_hex (block_root )))
236- return _decode_block (block_rlp )
259+ return _decode_block (block_rlp , block_class )
237260
238261 def get_score (self , block_root : Hash32 ) -> int :
239262 return self ._get_score (self .db , block_root )
@@ -264,13 +287,14 @@ def persist_block_chain(
264287 the second containing the old canonical headers
265288 """
266289 with self .db .atomic_batch () as db :
267- return self ._persist_block_chain (db , blocks )
290+ return self ._persist_block_chain (db , blocks , self . block_class )
268291
269292 @classmethod
270293 def _persist_block_chain (
271294 cls ,
272295 db : BaseDB ,
273- blocks : Iterable [BaseBeaconBlock ]
296+ blocks : Iterable [BaseBeaconBlock ],
297+ block_class : Type [BaseBeaconBlock ]
274298 ) -> Tuple [Tuple [BaseBeaconBlock , ...], Tuple [BaseBeaconBlock , ...]]:
275299 try :
276300 first_block = first (blocks )
@@ -313,20 +337,23 @@ def _persist_block_chain(
313337 )
314338
315339 try :
316- previous_canonical_head = cls ._get_canonical_head (db ).root
340+ previous_canonical_head = cls ._get_canonical_head (db , block_class ).root
317341 head_score = cls ._get_score (db , previous_canonical_head )
318342 except CanonicalHeadNotFound :
319- return cls ._set_as_canonical_chain_head (db , block .root )
343+ return cls ._set_as_canonical_chain_head (db , block .root , block_class )
320344
321345 if score > head_score :
322- return cls ._set_as_canonical_chain_head (db , block .root )
346+ return cls ._set_as_canonical_chain_head (db , block .root , block_class )
323347 else :
324348 return tuple (), tuple ()
325349
326350 @classmethod
327351 def _set_as_canonical_chain_head (
328- cls , db : BaseDB ,
329- block_root : Hash32 ) -> Tuple [Tuple [BaseBeaconBlock , ...], Tuple [BaseBeaconBlock , ...]]:
352+ cls ,
353+ db : BaseDB ,
354+ block_root : Hash32 ,
355+ block_class : Type [BaseBeaconBlock ]
356+ ) -> Tuple [Tuple [BaseBeaconBlock , ...], Tuple [BaseBeaconBlock , ...]]:
330357 """
331358 Set the canonical chain HEAD to the block as specified by the
332359 given block root.
@@ -335,13 +362,13 @@ def _set_as_canonical_chain_head(
335362 are no longer in the canonical chain
336363 """
337364 try :
338- block = cls ._get_block_by_root (db , block_root )
365+ block = cls ._get_block_by_root (db , block_root , block_class )
339366 except BlockNotFound :
340367 raise ValueError (
341368 "Cannot use unknown block root as canonical head: {}" .format (block_root )
342369 )
343370
344- new_canonical_blocks = tuple (reversed (cls ._find_new_ancestors (db , block )))
371+ new_canonical_blocks = tuple (reversed (cls ._find_new_ancestors (db , block , block_class )))
345372 old_canonical_blocks = []
346373
347374 for block in new_canonical_blocks :
@@ -351,7 +378,7 @@ def _set_as_canonical_chain_head(
351378 # no old_canonical block, and no more possible
352379 break
353380 else :
354- old_canonical_block = cls ._get_block_by_root (db , old_canonical_root )
381+ old_canonical_block = cls ._get_block_by_root (db , old_canonical_root , block_class )
355382 old_canonical_blocks .append (old_canonical_block )
356383
357384 for block in new_canonical_blocks :
@@ -363,7 +390,11 @@ def _set_as_canonical_chain_head(
363390
364391 @classmethod
365392 @to_tuple
366- def _find_new_ancestors (cls , db : BaseDB , block : BaseBeaconBlock ) -> Iterable [BaseBeaconBlock ]:
393+ def _find_new_ancestors (
394+ cls ,
395+ db : BaseDB ,
396+ block : BaseBeaconBlock ,
397+ block_class : Type [BaseBeaconBlock ]) -> Iterable [BaseBeaconBlock ]:
367398 """
368399 Return the chain leading up from the given block until (but not including)
369400 the first ancestor it has in common with our canonical chain.
@@ -377,7 +408,7 @@ def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[Bas
377408 """
378409 while True :
379410 try :
380- orig = cls ._get_canonical_block_by_slot (db , block .slot )
411+ orig = cls ._get_canonical_block_by_slot (db , block .slot , block_class )
381412 except BlockNotFound :
382413 # This just means the block is not on the canonical chain.
383414 pass
@@ -392,7 +423,7 @@ def _find_new_ancestors(cls, db: BaseDB, block: BaseBeaconBlock) -> Iterable[Bas
392423 if block .parent_root == GENESIS_PARENT_HASH :
393424 break
394425 else :
395- block = cls ._get_block_by_root (db , block .parent_root )
426+ block = cls ._get_block_by_root (db , block .parent_root , block_class )
396427
397428 @staticmethod
398429 def _add_block_slot_to_root_lookup (db : BaseDB , block : BaseBeaconBlock ) -> None :
@@ -466,9 +497,8 @@ def get(self, key: bytes) -> bytes:
466497# relatively expensive so we cache that here, but use a small cache because we *should* only
467498# be looking up recent blocks.
468499@functools .lru_cache (128 )
469- def _decode_block (block_rlp : bytes ) -> BaseBeaconBlock :
470- # TODO: forkable Block fields?
471- return rlp .decode (block_rlp , sedes = BaseBeaconBlock )
500+ def _decode_block (block_rlp : bytes , sedes : Type [BaseBeaconBlock ]) -> BaseBeaconBlock :
501+ return rlp .decode (block_rlp , sedes = sedes )
472502
473503
474504@functools .lru_cache (128 )
0 commit comments