from database import read
from reddit_scraper.utils import limit_iter
from typing import Iterable, Generic, TypeVar, List, Iterator

T = TypeVar('T')


class CustomIter(Generic[T]):
    """Re-iterable wrapper that lazily caches the elements of an iterable.

    Elements of a one-shot iterator (e.g. a generator) are pulled on demand
    and stored in ``l``, so the wrapper can be iterated any number of times
    and always yields the same sequence. Iterables that are not iterators
    (lists, tuples, ...) are copied into the cache immediately.
    """

    def __init__(self, i: Iterable[T]) -> None:
        # Underlying source of elements.
        self.i = i
        # Elements pulled from ``i`` so far, in source order.
        self.l: List[T] = []
        # True once ``i`` is known to be exhausted, i.e. ``l`` is complete.
        self.is_generated = False
        # If ``i`` doesn't have ``__next__``, pre-generate it: it is not lazy
        # and may not be safely consumable piece by piece.
        if not isinstance(self.i, Iterator):
            self.generate()

    def generate(self) -> None:
        """Drain all remaining source elements into the cache (idempotent)."""
        if not self.is_generated:
            # Elements already pulled are in ``l``; extend with the rest.
            self.l.extend(self.i)
            self.is_generated = True

    def _pull_one(self) -> bool:
        """Move one more element from the source into the cache.

        Returns True if an element was appended to ``l``. Returns False when
        the source is already, or has just become, exhausted; in that case
        ``is_generated`` is set so the source is not probed again.
        """
        if self.is_generated:
            return False
        for item in self.i:
            self.l.append(item)
            return True
        self.is_generated = True
        return False

    def __len__(self) -> int:
        """Return the total number of elements, consuming the whole source."""
        self.generate()
        return len(self.l)

    def __iter__(self) -> Iterator[T]:
        """Yield every element: cached ones first, then lazily pulled ones.

        Each iterator keeps its own position in the cache and pulls from the
        source only when it has caught up with the cache. Several iterators
        can therefore be advanced alternately without any of them skipping
        elements that another one pulled from the shared source.
        """
        index = 0
        # If index has reached the end of the cache, try to extend the cache
        # by one element; stop once the source is exhausted.
        while index < len(self.l) or self._pull_one():
            yield self.l[index]
            index += 1

    @property
    def empty(self) -> bool:
        """True if there are no elements; pulls at most one from the source."""
        return not self.l and not self._pull_one()

    @property
    def first(self) -> T:
        """Return the first element without consuming the rest of the source.

        Raises:
            KeyError: if there are no elements.
        """
        if self.empty:
            raise KeyError("can't get first element from empty list")
        return self.l[0]

    @property
    def last(self) -> T:
        """Return the last element, consuming the whole source.

        Raises:
            KeyError: if there are no elements.
        """
        self.generate()
        if not self.l:
            raise KeyError("can't get last element from empty list")
        return self.l[-1]


def test_iter():
    """Demo generator that prints a line before yielding each value.

    The prints make it visible when ``CustomIter`` actually pulls from the
    source and when it serves elements from its cache.
    """
    print("yield foo")
    yield "foo"
    print("yield bar")
    yield "bar"
    print("yield baz")
    yield "baz"


if __name__ == "__main__":
    ci: Iterable = CustomIter(test_iter())
    # Partial iteration: only "foo" is pulled from the generator.
    for c in ci:
        print(c)
        break
    print("#" * 10)
    # "foo" comes from the cache; "bar" and "baz" are pulled lazily.
    for c in ci:
        print(c)
    print("#" * 10)
    # Everything is served from the cache; nothing is pulled.
    for c in ci:
        print(c)