diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..2c5c29d
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,12 @@
+CONN_URI=sqlite+aiosqlite:///data/db.sqlite3
+DEBUG=True
+PROJECT_URL=http://127.0.0.1
+PROJECT_LOGIN_SUCCESS_URL=http://google.com
+PROJECT_PORT=80
+JWT_PRIVATE_KEY='data/private_key'
+JWT_PUBLIC_KEY='data/public_key'
+BOT_TOKEN=xxx
+BOT_USERNAME=xxxxBot
+BOT_API_ID=111
+BOT_API_HASH=aaa
+BOT_MANAGER_IDS=[111,222]
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 68bc17f..e08e503 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,4 +157,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
+
+data/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c54138c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,32 @@
+# Telegram OAuth
+
+## Configuration
+
+```dotenv
+CONN_URI=sqlite+aiosqlite:///data/db.sqlite3 # 数据库 uri
+DEBUG=True # 调试模式
+PROJECT_URL=http://127.0.0.1 # 项目可访问的地址
+PROJECT_LOGIN_SUCCESS_URL=http://google.com # 登录成功后跳转的地址
+PROJECT_PORT=80 # 项目运行的端口
+JWT_PRIVATE_KEY='data/private_key' # jwt 私钥
+JWT_PUBLIC_KEY='data/public_key' # jwt 公钥
+BOT_TOKEN=xxx # 机器人 token
+BOT_USERNAME=xxxxBot # 机器人用户名
+BOT_API_ID=111 # api id
+BOT_API_HASH=aaa # api hash
+BOT_MANAGER_IDS=[111,222] # 管理员 id
+```
+
+## OIDC Endpoints
+
+Auth URL : `/oauth2/authorize`
+
+Token URL : `/oauth2/token`
+
+Cert URL : `/oauth2/keys`
+
+## OIDC Client
+
+```sql
+INSERT INTO "client" ("grant_types", "response_types", "redirect_uris", "id", "client_id", "client_secret", "scope") VALUES ('authorization_code', 'code', 'https://127.0.0.1/access/callback', 'UUID', '123456', '123456', 'openid profile email');
+```
diff --git a/aioauth_fastapi/__init__.py b/aioauth_fastapi/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/aioauth_fastapi/__version__.py b/aioauth_fastapi/__version__.py
new file mode 100644
index 0000000..d5650b1
--- /dev/null
+++ b/aioauth_fastapi/__version__.py
@@ -0,0 +1,8 @@
+__title__ = "aioauth_fastapi"
+__description__ = "aioauth integration for FastAPI."
+__url__ = "https://github.com/aliev/aioauth-fastapi"
+__version__ = "0.1.2"
+__author__ = "Ali Aliyev"
+__author_email__ = "ali@aliev.me"
+__license__ = "The MIT License (MIT)"
+__copyright__ = "Copyright 2021 Ali Aliyev"
diff --git a/aioauth_fastapi/forms.py b/aioauth_fastapi/forms.py
new file mode 100644
index 0000000..40a818b
--- /dev/null
+++ b/aioauth_fastapi/forms.py
@@ -0,0 +1,38 @@
+"""
+.. code-block:: python
+
+ from aioauth_fastapi import forms
+
+FastAPI oauth2 forms.
+
+Used to generate an OpenAPI schema.
+
+----
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+from aioauth.types import GrantType, TokenType
+from fastapi.params import Form
+
+
+@dataclass
+class TokenForm:
+ grant_type: Optional[GrantType] = Form(None) # type: ignore
+ client_id: Optional[str] = Form(None) # type: ignore
+ client_secret: Optional[str] = Form(None) # type: ignore
+ redirect_uri: Optional[str] = Form(None) # type: ignore
+ scope: Optional[str] = Form(None) # type: ignore
+ username: Optional[str] = Form(None) # type: ignore
+ password: Optional[str] = Form(None) # type: ignore
+ refresh_token: Optional[str] = Form(None) # type: ignore
+ code: Optional[str] = Form(None) # type: ignore
+ token: Optional[str] = Form(None) # type: ignore
+ code_verifier: Optional[str] = Form(None) # type: ignore
+
+
+@dataclass
+class TokenIntrospectForm:
+ token: Optional[str] = Form(None) # type: ignore
+ token_type_hint: Optional[TokenType] = Form(None) # type: ignore
diff --git a/aioauth_fastapi/router.py b/aioauth_fastapi/router.py
new file mode 100644
index 0000000..be3c3b0
--- /dev/null
+++ b/aioauth_fastapi/router.py
@@ -0,0 +1,113 @@
+"""
+.. code-block:: python
+
+ from aioauth_fastapi import router
+
+FastAPI routing of oauth2.
+
+Usage example
+
+.. code-block:: python
+
+ from aioauth_fastapi.router import get_oauth2_router
+ from aioauth.storage import BaseStorage
+ from aioauth.config import Settings
+ from aioauth.server import AuthorizationServer
+ from fastapi import FastAPI
+
+ app = FastAPI()
+
+ class SQLAlchemyCRUD(BaseStorage):
+ '''
+ SQLAlchemyCRUD methods must be implemented here.
+ '''
+
+ # NOTE: Redefinition of the default aioauth settings
+ # INSECURE_TRANSPORT must be enabled for local development only!
+ settings = Settings(
+ INSECURE_TRANSPORT=True,
+ )
+
+ storage = SQLAlchemyCRUD()
+ authorization_server = AuthorizationServer(storage)
+
+ # Include FastAPI router with oauth2 endpoints.
+ app.include_router(
+ get_oauth2_router(authorization_server, settings),
+ prefix="/oauth2",
+ tags=["oauth2"],
+ )
+
+----
+"""
+
+from typing import Callable, TypeVar
+
+from aioauth.config import Settings
+from aioauth.requests import TRequest
+from aioauth.server import AuthorizationServer
+from aioauth.storage import TStorage
+from fastapi import APIRouter, Request
+
+from .utils import (
+ RequestArguments,
+ default_request_factory,
+ to_fastapi_response,
+ to_oauth2_request,
+)
+
+ARequest = TypeVar("ARequest", bound=TRequest)
+
+
+def get_oauth2_router(
+ authorization_server: AuthorizationServer[ARequest, TStorage],
+ settings: Settings = Settings(),
+ request_factory: Callable[[RequestArguments], ARequest] = default_request_factory,
+) -> APIRouter:
+ """Function will create FastAPI router with the following oauth2 endpoints:
+
+ * POST /token
+ * Endpoint creates a token response by :py:meth:`aioauth.server.AuthorizationServer.create_token_response`
+ * POST `/token/introspect`
+ * Endpoint creates a token introspection by :py:meth:`aioauth.server.AuthorizationServer.create_token_introspection_response`
+ * GET `/authorize`
+ * Endpoint creates an authorization response by :py:meth:`aioauth.server.AuthorizationServer.create_authorization_response`
+
+ Returns:
+ :py:class:`fastapi.APIRouter`.
+ """
+ router = APIRouter()
+
+ @router.post("/token")
+ async def token(request: Request):
+ oauth2_request = await to_oauth2_request(
+ request=request, request_factory=request_factory, settings=settings
+ )
+ oauth2_response = await authorization_server.create_token_response(
+ oauth2_request
+ )
+ return await to_fastapi_response(oauth2_response)
+
+ @router.post("/token/introspect")
+ async def token_introspect(request: Request):
+ oauth2_request = await to_oauth2_request(
+ request=request, request_factory=request_factory, settings=settings
+ )
+ oauth2_response = (
+ await authorization_server.create_token_introspection_response(
+ oauth2_request
+ )
+ )
+ return await to_fastapi_response(oauth2_response)
+
+ @router.get("/authorize")
+ async def authorize(request: Request):
+ oauth2_request = await to_oauth2_request(
+ request=request, request_factory=request_factory, settings=settings
+ )
+ oauth2_response = await authorization_server.create_authorization_response(
+ oauth2_request
+ )
+ return await to_fastapi_response(oauth2_response)
+
+ return router
diff --git a/aioauth_fastapi/utils.py b/aioauth_fastapi/utils.py
new file mode 100644
index 0000000..788932e
--- /dev/null
+++ b/aioauth_fastapi/utils.py
@@ -0,0 +1,98 @@
+"""
+.. code-block:: python
+
+ from aioauth_fastapi import utils
+
+Core utils for integration with FastAPI
+
+----
+"""
+
+import json
+from dataclasses import dataclass
+from typing import Callable, Dict, Optional
+
+from aioauth.collections import HTTPHeaderDict
+from aioauth.config import Settings
+from aioauth.requests import Post, Query, TRequest, TUser
+from aioauth.requests import Request as OAuth2Request
+from aioauth.responses import Response as OAuth2Response
+from fastapi import Request, Response
+
+
+@dataclass
+class RequestArguments:
+ headers: HTTPHeaderDict
+ method: str
+ post_args: Dict
+ query_args: Dict
+ settings: Settings
+ url: str
+ user: Optional[TUser]
+
+
+def default_request_factory(request_args: RequestArguments) -> OAuth2Request:
+ return OAuth2Request(
+ headers=request_args.headers,
+ method=request_args.method, # type: ignore
+ post=Post(**request_args.post_args), # type: ignore
+ query=Query(**request_args.query_args), # type: ignore
+ settings=request_args.settings,
+ url=request_args.url,
+ user=request_args.user,
+ )
+
+
+async def to_oauth2_request(
+ request: Request,
+ settings: Settings = Settings(),
+ request_factory: Callable[[RequestArguments], TRequest] = default_request_factory,
+) -> TRequest:
+ """Converts :py:class:`fastapi.Request` instance to :py:class:`aioauth.requests.Request` instance"""
+ form = await request.form()
+
+ post_args = dict(form)
+ query_args = dict(request.query_params)
+ need_args = [
+ "client_id",
+ "redirect_uri",
+ "response_type",
+ "state",
+ "scope",
+ "nonce",
+ "code_challenge_method",
+ "code_challenge",
+ "response_mode",
+ ]
+ for arg in list(query_args.keys()):
+ if arg not in need_args:
+ del query_args[arg]
+ method = request.method
+ headers = HTTPHeaderDict(**request.headers)
+ url = str(request.url)
+
+ user = None
+
+ if request.user.is_authenticated:
+ user = request.user
+
+ request_args = RequestArguments(
+ headers=headers,
+ method=method,
+ post_args=post_args,
+ query_args=query_args,
+ settings=settings,
+ url=url,
+ user=user,
+ )
+ return request_factory(request_args)
+
+
+async def to_fastapi_response(oauth2_response: OAuth2Response) -> Response:
+ """Converts :py:class:`aioauth.responses.Response` instance to :py:class:`fastapi.Response` instance"""
+ response_content = oauth2_response.content
+ headers = dict(oauth2_response.headers)
+ status_code = oauth2_response.status_code
+ content = json.dumps(response_content)
+
+ return Response(content=content, headers=headers, status_code=status_code)
diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 0000000..436dba5
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,89 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = alembic/
+
+# template used to generate migration files
+# file_template = %%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date
+# within the migration file as well as the filename.
+# string value is passed to dateutil.tz.gettz()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the
+# "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; this defaults
+# to ./versions. When using multiple version
+# directories, initial revisions must be specified with --version-path
+# version_locations = %(here)s/bar %(here)s/bat ./versions
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+sqlalchemy.url = sqlite+aiosqlite:///data/db.sqlite3
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts. See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/alembic/env.py b/alembic/env.py
new file mode 100644
index 0000000..28bc9b5
--- /dev/null
+++ b/alembic/env.py
@@ -0,0 +1,90 @@
+import asyncio
+from logging.config import fileConfig
+
+from alembic import context
+from sqlalchemy import engine_from_config, pool
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlmodel import SQLModel
+
+from src.config import settings
+from src.oauth2.models import * # noqa
+from src.users.models import * # noqa
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+fileConfig(config.config_file_name)
+
+config.set_main_option("sqlalchemy.url", str(settings.CONN_URI))
+
+# add your model's MetaData object here
+# for 'autogenerate' support
+# from myapp import mymodel
+# target_metadata = mymodel.Base.metadata
+target_metadata = SQLModel.metadata # type: ignore
+
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+
+def run_migrations_offline():
+ """Run migrations in 'offline' mode.
+
+ This configures the context with just a URL
+ and not an Engine, though an Engine is acceptable
+ here as well. By skipping the Engine creation
+ we don't even need a DBAPI to be available.
+
+ Calls to context.execute() here emit the given string to the
+ script output.
+
+ """
+ url = config.get_main_option("sqlalchemy.url")
+ context.configure(
+ url=url,
+ target_metadata=target_metadata,
+ literal_binds=True,
+ dialect_opts={"paramstyle": "named"},
+ )
+
+ with context.begin_transaction():
+ context.run_migrations()
+
+
+def do_run_migrations(connection):
+ context.configure(connection=connection, target_metadata=target_metadata)
+
+ with context.begin_transaction():
+ context.run_migrations()
+
+
+async def run_migrations_online():
+ """Run migrations in 'online' mode.
+
+ In this scenario we need to create an Engine
+ and associate a connection with the context.
+
+ """
+ connectable = AsyncEngine(
+ engine_from_config(
+ config.get_section(config.config_ini_section),
+ prefix="sqlalchemy.",
+ poolclass=pool.NullPool,
+ future=True,
+ )
+ )
+
+ async with connectable.connect() as connection:
+ await connection.run_sync(do_run_migrations)
+
+
+if context.is_offline_mode():
+ run_migrations_offline()
+else:
+ asyncio.run(run_migrations_online())
diff --git a/alembic/script.py.mako b/alembic/script.py.mako
new file mode 100644
index 0000000..e45068f
--- /dev/null
+++ b/alembic/script.py.mako
@@ -0,0 +1,25 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+
+
+def upgrade():
+ ${upgrades if upgrades else "pass"}
+
+
+def downgrade():
+ ${downgrades if downgrades else "pass"}
diff --git a/alembic/versions/07a7ace268a7_initial_migrations.py b/alembic/versions/07a7ace268a7_initial_migrations.py
new file mode 100644
index 0000000..ab7550d
--- /dev/null
+++ b/alembic/versions/07a7ace268a7_initial_migrations.py
@@ -0,0 +1,226 @@
+"""Initial migrations
+
+Revision ID: 07a7ace268a7
+Revises:
+Create Date: 2021-10-02 22:50:10.418498
+
+"""
+import sqlalchemy as sa
+import sqlmodel
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "07a7ace268a7"
+down_revision = None
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table(
+ "users",
+ sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+ sa.Column("is_superuser", sa.Boolean(), nullable=True),
+ sa.Column("is_blocked", sa.Boolean(), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=True),
+ sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_users_id"), "users", ["id"], unique=True)
+ op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
+ op.create_index(op.f("ix_users_is_blocked"), "users", ["is_blocked"], unique=False)
+ op.create_index(
+ op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
+ )
+ op.create_index(op.f("ix_users_password"), "users", ["password"], unique=False)
+ op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
+ op.create_table(
+ "authorizationcode",
+ sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+ sa.Column("code", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("redirect_uri", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("response_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("auth_time", sa.Integer(), nullable=False),
+ sa.Column("expires_in", sa.Integer(), nullable=False),
+ sa.Column("code_challenge", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+ sa.Column(
+ "code_challenge_method", sqlmodel.sql.sqltypes.AutoString(), nullable=True
+ ),
+ sa.Column("nonce", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+ sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_auth_time"),
+ "authorizationcode",
+ ["auth_time"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_client_id"),
+ "authorizationcode",
+ ["client_id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_code"), "authorizationcode", ["code"], unique=False
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_code_challenge"),
+ "authorizationcode",
+ ["code_challenge"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_code_challenge_method"),
+ "authorizationcode",
+ ["code_challenge_method"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_expires_in"),
+ "authorizationcode",
+ ["expires_in"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_id"), "authorizationcode", ["id"], unique=True
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_nonce"), "authorizationcode", ["nonce"], unique=False
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_redirect_uri"),
+ "authorizationcode",
+ ["redirect_uri"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_response_type"),
+ "authorizationcode",
+ ["response_type"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_scope"), "authorizationcode", ["scope"], unique=False
+ )
+ op.create_index(
+ op.f("ix_authorizationcode_user_id"),
+ "authorizationcode",
+ ["user_id"],
+ unique=False,
+ )
+ op.create_table(
+ "client",
+ sa.Column("grant_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+ sa.Column("response_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+ sa.Column("redirect_uris", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+ sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+ sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_client_client_id"), "client", ["client_id"], unique=False)
+ op.create_index(
+ op.f("ix_client_client_secret"), "client", ["client_secret"], unique=False
+ )
+ op.create_index(op.f("ix_client_id"), "client", ["id"], unique=True)
+ op.create_index(op.f("ix_client_scope"), "client", ["scope"], unique=False)
+ op.create_table(
+ "token",
+ sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+ sa.Column("access_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("refresh_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("issued_at", sa.Integer(), nullable=False),
+ sa.Column("expires_in", sa.Integer(), nullable=False),
+ sa.Column("refresh_token_expires_in", sa.Integer(), nullable=False),
+ sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("token_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+ sa.Column("revoked", sa.Boolean(), nullable=False),
+ sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index(op.f("ix_token_client_id"), "token", ["client_id"], unique=False)
+ op.create_index(op.f("ix_token_expires_in"), "token", ["expires_in"], unique=False)
+ op.create_index(op.f("ix_token_id"), "token", ["id"], unique=True)
+ op.create_index(op.f("ix_token_issued_at"), "token", ["issued_at"], unique=False)
+ op.create_index(
+ op.f("ix_token_refresh_token_expires_in"),
+ "token",
+ ["refresh_token_expires_in"],
+ unique=False,
+ )
+ op.create_index(op.f("ix_token_revoked"), "token", ["revoked"], unique=False)
+ op.create_index(op.f("ix_token_scope"), "token", ["scope"], unique=False)
+ op.create_index(op.f("ix_token_token_type"), "token", ["token_type"], unique=False)
+ op.create_index(op.f("ix_token_user_id"), "token", ["user_id"], unique=False)
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_index(op.f("ix_token_user_id"), table_name="token")
+ op.drop_index(op.f("ix_token_token_type"), table_name="token")
+ op.drop_index(op.f("ix_token_scope"), table_name="token")
+ op.drop_index(op.f("ix_token_revoked"), table_name="token")
+ op.drop_index(op.f("ix_token_refresh_token_expires_in"), table_name="token")
+ op.drop_index(op.f("ix_token_issued_at"), table_name="token")
+ op.drop_index(op.f("ix_token_id"), table_name="token")
+ op.drop_index(op.f("ix_token_expires_in"), table_name="token")
+ op.drop_index(op.f("ix_token_client_id"), table_name="token")
+ op.drop_table("token")
+ op.drop_index(op.f("ix_client_scope"), table_name="client")
+ op.drop_index(op.f("ix_client_id"), table_name="client")
+ op.drop_index(op.f("ix_client_client_secret"), table_name="client")
+ op.drop_index(op.f("ix_client_client_id"), table_name="client")
+ op.drop_table("client")
+ op.drop_index(op.f("ix_authorizationcode_user_id"), table_name="authorizationcode")
+ op.drop_index(op.f("ix_authorizationcode_scope"), table_name="authorizationcode")
+ op.drop_index(
+ op.f("ix_authorizationcode_response_type"), table_name="authorizationcode"
+ )
+ op.drop_index(
+ op.f("ix_authorizationcode_redirect_uri"), table_name="authorizationcode"
+ )
+ op.drop_index(op.f("ix_authorizationcode_nonce"), table_name="authorizationcode")
+ op.drop_index(op.f("ix_authorizationcode_id"), table_name="authorizationcode")
+ op.drop_index(
+ op.f("ix_authorizationcode_expires_in"), table_name="authorizationcode"
+ )
+ op.drop_index(
+ op.f("ix_authorizationcode_code_challenge_method"),
+ table_name="authorizationcode",
+ )
+ op.drop_index(
+ op.f("ix_authorizationcode_code_challenge"), table_name="authorizationcode"
+ )
+ op.drop_index(op.f("ix_authorizationcode_code"), table_name="authorizationcode")
+ op.drop_index(
+ op.f("ix_authorizationcode_client_id"), table_name="authorizationcode"
+ )
+ op.drop_index(
+ op.f("ix_authorizationcode_auth_time"), table_name="authorizationcode"
+ )
+ op.drop_table("authorizationcode")
+ op.drop_index(op.f("ix_users_username"), table_name="users")
+ op.drop_index(op.f("ix_users_password"), table_name="users")
+ op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
+ op.drop_index(op.f("ix_users_is_blocked"), table_name="users")
+ op.drop_index(op.f("ix_users_is_active"), table_name="users")
+ op.drop_index(op.f("ix_users_id"), table_name="users")
+ op.drop_table("users")
+ # ### end Alembic commands ###
diff --git a/alembic/versions/c76c4cbb0b3b_tgid.py b/alembic/versions/c76c4cbb0b3b_tgid.py
new file mode 100644
index 0000000..2df3e5c
--- /dev/null
+++ b/alembic/versions/c76c4cbb0b3b_tgid.py
@@ -0,0 +1,27 @@
+"""tgid
+
+Revision ID: c76c4cbb0b3b
+Revises: 07a7ace268a7
+Create Date: 2024-01-13 16:12:45.884304
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'c76c4cbb0b3b'
+down_revision = '07a7ace268a7'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('users', sa.Column('tg_id', sa.BigInteger(), nullable=False))
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('users', 'tg_id')
+ # ### end Alembic commands ###
diff --git a/gen_keys.py b/gen_keys.py
new file mode 100644
index 0000000..a24edb1
--- /dev/null
+++ b/gen_keys.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+
+data_path = Path("data")
+data_path.mkdir(exist_ok=True)
+private_key_path = data_path / "private_key"
+public_key_path = data_path / "public_key"
+
+
+def gen_keys():
+ from Crypto.PublicKey import RSA
+
+ key = RSA.generate(2048)
+ private_key = key.export_key().decode("utf-8")
+ public_key = key.publickey().export_key().decode("utf-8")
+
+ if private_key_path.is_file() and public_key_path.is_file():
+ print("Keys already exist")
+ return
+ with open(private_key_path, "w") as f:
+ f.write(private_key)
+ with open(public_key_path, "w") as f:
+ f.write(public_key)
+
+
+if __name__ == "__main__":
+ gen_keys()
diff --git a/html/login.jinja b/html/login.jinja
new file mode 100644
index 0000000..208dced
--- /dev/null
+++ b/html/login.jinja
@@ -0,0 +1,40 @@
+
+
+
+
+ Title
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..112d8b5
--- /dev/null
+++ b/main.py
@@ -0,0 +1,50 @@
+import asyncio
+from signal import signal as signal_fn, SIGINT, SIGTERM, SIGABRT
+
+from src.app import web
+from src.bot import bot
+from src.logs import logs
+
+
+async def idle():
+ task = None
+
+ def signal_handler(_, __):
+ if web.web_server_task:
+ web.web_server_task.cancel()
+ task.cancel()
+
+ for s in (SIGINT, SIGTERM, SIGABRT):
+ signal_fn(s, signal_handler)
+
+ while True:
+ task = asyncio.create_task(asyncio.sleep(600))
+ web.bot_main_task = task
+ try:
+ await task
+ except asyncio.CancelledError:
+ break
+
+
+async def main():
+ logs.info("正在启动 Web Server")
+ await web.start()
+ logs.info("正在启动 Bot")
+ await bot.start()
+ try:
+ logs.info("正在运行")
+ await idle()
+ finally:
+ try:
+ await bot.stop()
+ except ConnectionError:
+ pass
+ if web.web_server:
+ try:
+ await web.web_server.shutdown()
+ except AttributeError:
+ pass
+
+
+if __name__ == "__main__":
+ bot.run(main())
diff --git a/pyromod/__init__.py b/pyromod/__init__.py
new file mode 100644
index 0000000..c22821d
--- /dev/null
+++ b/pyromod/__init__.py
@@ -0,0 +1,21 @@
+"""
+pyromod - A monkeypatched add-on for Pyrogram
+Copyright (C) 2020 Cezar H.
+
+This file is part of pyromod.
+
+pyromod is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+pyromod is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with pyromod. If not, see .
+"""
+
+__version__ = "1.5"
diff --git a/pyromod/listen/__init__.py b/pyromod/listen/__init__.py
new file mode 100644
index 0000000..b916072
--- /dev/null
+++ b/pyromod/listen/__init__.py
@@ -0,0 +1,21 @@
+"""
+pyromod - A monkeypatcher add-on for Pyrogram
+Copyright (C) 2020 Cezar H.
+
+This file is part of pyromod.
+
+pyromod is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+pyromod is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with pyromod. If not, see .
+"""
+
+from .listen import Client, MessageHandler, Chat, User
diff --git a/pyromod/listen/listen.py b/pyromod/listen/listen.py
new file mode 100644
index 0000000..a443885
--- /dev/null
+++ b/pyromod/listen/listen.py
@@ -0,0 +1,157 @@
+"""
+pyromod - A monkeypatcher add-on for Pyrogram
+Copyright (C) 2020 Cezar H.
+
+This file is part of pyromod.
+
+pyromod is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+pyromod is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with pyromod. If not, see .
+"""
+
+import asyncio
+import functools
+
+import pyrogram
+
+from src.scheduler import add_delete_message_job
+from ..utils import patch, patchable
+from ..utils.errors import ListenerCanceled, TimeoutConversationError
+
+pyrogram.errors.ListenerCanceled = ListenerCanceled
+
+
+@patch(pyrogram.client.Client)
+class Client:
+ @patchable
+ def __init__(self, *args, **kwargs):
+ self.listening = {}
+ self.using_mod = True
+
+ self.old__init__(*args, **kwargs)
+
+ @patchable
+ async def listen(self, chat_id, filters=None, timeout=None):
+ if type(chat_id) != int:
+ chat = await self.get_chat(chat_id)
+ chat_id = chat.id
+
+ future = self.loop.create_future()
+ future.add_done_callback(functools.partial(self.clear_listener, chat_id))
+ self.listening.update({chat_id: {"future": future, "filters": filters}})
+ try:
+ return await asyncio.wait_for(future, timeout)
+ except asyncio.exceptions.TimeoutError as e:
+ raise TimeoutConversationError() from e
+
+ @patchable
+ async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs):
+ request = await self.send_message(chat_id, text, *args, **kwargs)
+ response = await self.listen(chat_id, filters, timeout)
+ response.request = request
+ return response
+
+ @patchable
+ def clear_listener(self, chat_id, future):
+ if future == self.listening[chat_id]["future"]:
+ self.listening.pop(chat_id, None)
+
+ @patchable
+ def cancel_listener(self, chat_id):
+ listener = self.listening.get(chat_id)
+ if not listener or listener["future"].done():
+ return
+
+ listener["future"].set_exception(ListenerCanceled())
+ self.clear_listener(chat_id, listener["future"])
+
+ @patchable
+ def cancel_all_listener(self):
+ for chat_id in self.listening:
+ self.cancel_listener(chat_id)
+
+
+@patch(pyrogram.handlers.message_handler.MessageHandler)
+class MessageHandler:
+ @patchable
+ def __init__(self, callback: callable, filters=None):
+ self.user_callback = callback
+ self.old__init__(self.resolve_listener, filters)
+
+ @patchable
+ async def resolve_listener(self, client, message, *args):
+ listener = client.listening.get(message.chat.id)
+ if listener and not listener["future"].done():
+ listener["future"].set_result(message)
+ else:
+ if listener and listener["future"].done():
+ client.clear_listener(message.chat.id, listener["future"])
+ await self.user_callback(client, message, *args)
+
+ @patchable
+ async def check(self, client, update):
+ listener = client.listening.get(update.chat.id)
+
+ if listener and not listener["future"].done():
+ return (
+ await listener["filters"](client, update)
+ if callable(listener["filters"])
+ else True
+ )
+
+ return await self.filters(client, update) if callable(self.filters) else True
+
+
+@patch(pyrogram.types.user_and_chats.chat.Chat)
+class Chat(pyrogram.types.Chat):
+ @patchable
+ def listen(self, *args, **kwargs):
+ return self._client.listen(self.id, *args, **kwargs)
+
+ @patchable
+ def ask(self, *args, **kwargs):
+ return self._client.ask(self.id, *args, **kwargs)
+
+ @patchable
+ def cancel_listener(self):
+ return self._client.cancel_listener(self.id)
+
+
+@patch(pyrogram.types.user_and_chats.user.User)
+class User(pyrogram.types.User):
+ @patchable
+ def listen(self, *args, **kwargs):
+ return self._client.listen(self.id, *args, **kwargs)
+
+ @patchable
+ def ask(self, *args, **kwargs):
+ return self._client.ask(self.id, *args, **kwargs)
+
+ @patchable
+ def cancel_listener(self):
+ return self._client.cancel_listener(self.id)
+
+
+@patch(pyrogram.types.messages_and_media.Message)
+class Message(pyrogram.types.Message):
+ @patchable
+ async def safe_delete(self, revoke: bool = True):
+ try:
+ return await self._client.delete_messages(
+ chat_id=self.chat.id, message_ids=self.id, revoke=revoke
+ )
+ except Exception as e: # noqa
+ return False
+
+ @patchable
+ async def delay_delete(self, delay: int = 60):
+ add_delete_message_job(self, delay)
diff --git a/pyromod/utils/__init__.py b/pyromod/utils/__init__.py
new file mode 100644
index 0000000..3c9f81a
--- /dev/null
+++ b/pyromod/utils/__init__.py
@@ -0,0 +1,21 @@
+"""
+pyromod - A monkeypatcher add-on for Pyrogram
+Copyright (C) 2020 Cezar H.
+
+This file is part of pyromod.
+
+pyromod is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+pyromod is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with pyromod. If not, see .
+"""
+
+from .utils import patch, patchable
diff --git a/pyromod/utils/errors.py b/pyromod/utils/errors.py
new file mode 100644
index 0000000..d72ab10
--- /dev/null
+++ b/pyromod/utils/errors.py
@@ -0,0 +1,16 @@
+class TimeoutConversationError(Exception):
+ """
+ Occurs when the conversation times out.
+ """
+
+ def __init__(self):
+ super().__init__("Response read timed out")
+
+
+class ListenerCanceled(Exception):
+ """
+ Occurs when the listener is canceled.
+ """
+
+ def __init__(self):
+ super().__init__("Listener was canceled")
diff --git a/pyromod/utils/utils.py b/pyromod/utils/utils.py
new file mode 100644
index 0000000..2202260
--- /dev/null
+++ b/pyromod/utils/utils.py
@@ -0,0 +1,38 @@
+"""
+pyromod - A monkeypatcher add-on for Pyrogram
+Copyright (C) 2020 Cezar H.
+
+This file is part of pyromod.
+
+pyromod is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+pyromod is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with pyromod. If not, see .
+"""
+
+
+def patch(obj):
+ def is_patchable(item):
+ return getattr(item[1], "patchable", False)
+
+ def wrapper(container):
+ for name, func in filter(is_patchable, container.__dict__.items()):
+ old = getattr(obj, name, None)
+ setattr(obj, f"old{name}", old)
+ setattr(obj, name, func)
+ return container
+
+ return wrapper
+
+
+def patchable(func):
+ func.patchable = True
+ return func
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..4324cfb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+fastapi==0.109.0
+uvicorn==0.25.0
+git+https://github.com/aliev/aioauth
+sqlmodel==0.0.14
+alembic==1.13.1
+aiosqlite==0.19.0
+PyCryptodome==3.20.0
+python-jose[cryptography]==3.3.0
+python-multipart==0.0.6
+orjson==3.9.10
+jinja2==3.1.3
+pydantic~=2.5.3
+pydantic-settings==2.1.0
+SQLAlchemy~=2.0.25
+starlette~=0.35.1
+pyrogram==2.0.106
+tgcrypto==1.2.5
+pytz~=2023.3.post1
+APScheduler~=3.10.4
+coloredlogs~=15.0.1
+httpx==0.26.0
+asyncmy==0.2.9
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/app.py b/src/app.py
new file mode 100644
index 0000000..b13e50d
--- /dev/null
+++ b/src/app.py
@@ -0,0 +1,87 @@
+import asyncio
+
+from fastapi import FastAPI
+from fastapi.responses import ORJSONResponse
+from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.middleware.cors import CORSMiddleware
+
+from .config import settings
+from .events import on_shutdown, on_startup
+from .logs import logs
+from .oauth2 import endpoints as oauth2_endpoints
+from .users import endpoints as users_endpoints
+from .users.backends import TokenAuthenticationBackend
+
+
+class Web:
+ def __init__(self):
+ self.app = FastAPI(
+ title=settings.PROJECT_NAME,
+ docs_url="/api/openapi",
+ openapi_url="/api/openapi.json",
+ default_response_class=ORJSONResponse,
+ on_startup=on_startup,
+ on_shutdown=on_shutdown,
+ )
+ self.web_server = None
+ self.web_server_task = None
+ self.bot_main_task = None
+
+ def init_web(self):
+ # Include API router
+ self.app.include_router(users_endpoints.router, prefix="/api/users", tags=["users"])
+
+ # Define aioauth-fastapi endpoints
+ self.app.include_router(
+ oauth2_endpoints.router,
+ prefix="/oauth2",
+ tags=["oauth2"],
+ )
+ self.app.add_middleware(
+ CORSMiddleware,
+ allow_origins=settings.CORS_ORIGINS,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+ self.app.add_middleware(AuthenticationMiddleware, backend=TokenAuthenticationBackend())
+
+ async def start(self):
+ import uvicorn
+
+ self.init_web()
+ self.web_server = uvicorn.Server(
+ config=uvicorn.Config(
+ self.app,
+ host=settings.PROJECT_HOST,
+ port=settings.PROJECT_PORT,
+ reload=settings.DEBUG,
+ )
+ )
+ server_config = self.web_server.config
+ server_config.setup_event_loop()
+ if not server_config.loaded:
+ server_config.load()
+ self.web_server.lifespan = server_config.lifespan_class(server_config)
+ try:
+ await self.web_server.startup()
+ except OSError as e:
+ if e.errno == 10048:
+ logs.error("Web Server 端口被占用:%s", e)
+ logs.error("Web Server 启动失败,正在退出")
+ raise SystemExit from None
+
+ if self.web_server.should_exit:
+ logs.error("Web Server 启动失败,正在退出")
+ raise SystemExit from None
+ logs.info("Web Server 启动成功")
+ self.web_server_task = asyncio.create_task(self.web_server.main_loop())
+
+ async def stop(self):
+ if self.web_server_task:
+ self.web_server_task.cancel()
+ if self.bot_main_task:
+ self.bot_main_task.cancel()
+
+
+web = Web()
diff --git a/src/bot.py b/src/bot.py
new file mode 100644
index 0000000..3c46665
--- /dev/null
+++ b/src/bot.py
@@ -0,0 +1,13 @@
+import pyromod.listen
+from pyrogram import Client
+
+from .config import settings
+
+bot = Client(
+ "bot",
+ bot_token=settings.BOT_TOKEN,
+ api_id=settings.BOT_API_ID,
+ api_hash=settings.BOT_API_HASH,
+ plugins={"root": "src.telegram.plugins"},
+ workdir="data",
+)
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..2b50888
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,35 @@
+from typing import List
+
+from pydantic_settings import BaseSettings
+
+
+class Settings(BaseSettings):
+ PROJECT_NAME: str = "Telegram OAuth"
+ PROJECT_URL: str = "http://127.0.0.1:8081"
+ PROJECT_LOGIN_SUCCESS_URL: str = "http://google.com"
+ PROJECT_HOST: str = "127.0.0.1"
+ PROJECT_PORT: int = 8001
+ DEBUG: bool = True
+
+ CONN_URI: str
+
+ JWT_PUBLIC_KEY: str
+ JWT_PRIVATE_KEY: str
+
+ ACCESS_TOKEN_EXP: int = 900 # 15 minutes
+ REFRESH_TOKEN_EXP: int = 86400 # 1 day
+
+ CORS_ORIGINS: List[str] = ["*"]
+
+ BOT_TOKEN: str
+ BOT_USERNAME: str
+ BOT_API_ID: int
+ BOT_API_HASH: str
+ BOT_MANAGER_IDS: List[int]
+
+ class Config:
+ env_file = ".env"
+ case_sensitive = True
+
+
+settings = Settings()
diff --git a/src/events.py b/src/events.py
new file mode 100644
index 0000000..d88c603
--- /dev/null
+++ b/src/events.py
@@ -0,0 +1,27 @@
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import NullPool
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from .config import settings
+from .storage import sqlalchemy
+
+
+async def create_sqlalchemy_connection():
+ # NOTE: https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#using-multiple-asyncio-event-loops
+ engine = create_async_engine(settings.CONN_URI, echo=True, poolclass=NullPool)
+ async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
+ sqlalchemy.sqlalchemy_session = async_session()
+
+
+async def close_sqlalchemy_connection():
+ if sqlalchemy.sqlalchemy_session is not None:
+ await sqlalchemy.sqlalchemy_session.close()
+
+
+on_startup = [
+ create_sqlalchemy_connection,
+]
+on_shutdown = [
+ close_sqlalchemy_connection,
+]
diff --git a/src/html.py b/src/html.py
new file mode 100644
index 0000000..1c1448a
--- /dev/null
+++ b/src/html.py
@@ -0,0 +1,3 @@
+from starlette.templating import Jinja2Templates
+
+templates = Jinja2Templates(directory="html")
diff --git a/src/logs.py b/src/logs.py
new file mode 100644
index 0000000..61a3d8f
--- /dev/null
+++ b/src/logs.py
@@ -0,0 +1,21 @@
+from logging import getLogger, StreamHandler, basicConfig, INFO, CRITICAL, ERROR
+
+from coloredlogs import ColoredFormatter
+
+logs = getLogger("telegram-oauth")
+logging_format = "%(levelname)s [%(asctime)s] [%(name)s] %(message)s"
+logging_handler = StreamHandler()
+logging_handler.setFormatter(ColoredFormatter(logging_format))
+root_logger = getLogger()
+root_logger.setLevel(CRITICAL)
+root_logger.addHandler(logging_handler)
+pyro_logger = getLogger("pyrogram")
+pyro_logger.setLevel(INFO)
+sql_logger = getLogger("sqlalchemy")
+sql_logger.setLevel(CRITICAL)
+sql_engine_logger = getLogger("sqlalchemy.engine.Engine")
+sql_engine_logger.setLevel(CRITICAL)
+aioauth_logger = getLogger("aioauth")
+aioauth_logger.setLevel(INFO)
+basicConfig(level=ERROR)
+logs.setLevel(INFO)
diff --git a/src/oauth2/__init__.py b/src/oauth2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/oauth2/endpoints.py b/src/oauth2/endpoints.py
new file mode 100644
index 0000000..4acb3ce
--- /dev/null
+++ b/src/oauth2/endpoints.py
@@ -0,0 +1,56 @@
+from aioauth.config import Settings
+from aioauth.oidc.core.grant_type import AuthorizationCodeGrantType
+from aioauth.requests import Request as OAuth2Request
+from aioauth.server import AuthorizationServer
+from fastapi import APIRouter, Depends, Request
+
+from aioauth_fastapi.utils import to_fastapi_response, to_oauth2_request
+from .storage import Storage
+from ..config import settings as local_settings
+from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
+from ..users.crypto import get_pub_key_resp
+from ..utils.oauth import to_login_request
+
+router = APIRouter()
+
+settings = Settings(
+ TOKEN_EXPIRES_IN=local_settings.ACCESS_TOKEN_EXP,
+ REFRESH_TOKEN_EXPIRES_IN=local_settings.REFRESH_TOKEN_EXP,
+ INSECURE_TRANSPORT=local_settings.DEBUG,
+)
+grant_types = {
+ "authorization_code": AuthorizationCodeGrantType,
+}
+
+
+@router.post("/token")
+async def token(
+ request: Request,
+ storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
+):
+ oauth2_storage = Storage(storage=storage)
+ authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types)
+ oauth2_request: OAuth2Request = await to_oauth2_request(request, settings)
+ oauth2_response = await authorization_server.create_token_response(oauth2_request)
+ return await to_fastapi_response(oauth2_response)
+
+
+@router.get("/authorize")
+async def authorize(
+ request: Request,
+ storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
+):
+ if not request.user.is_authenticated:
+ return await to_login_request(request)
+ oauth2_storage = Storage(storage=storage)
+ authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types)
+ oauth2_request: OAuth2Request = await to_oauth2_request(request, settings)
+ oauth2_response = await authorization_server.create_authorization_response(
+ oauth2_request
+ )
+ return await to_fastapi_response(oauth2_response)
+
+
+@router.get("/keys")
+async def keys():
+ return get_pub_key_resp()
diff --git a/src/oauth2/models.py b/src/oauth2/models.py
new file mode 100644
index 0000000..4a16aad
--- /dev/null
+++ b/src/oauth2/models.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING, Optional
+
+from pydantic.types import UUID4
+from sqlmodel.main import Field, Relationship
+
+from ..storage.models import BaseTable
+
+if TYPE_CHECKING: # pragma: no cover
+ from ..users.models import User
+
+
+class Client(BaseTable, table=True): # type: ignore
+ client_id: str
+ client_secret: str
+ grant_types: str
+ response_types: str
+ redirect_uris: str
+
+ scope: str
+
+
+class AuthorizationCode(BaseTable, table=True): # type: ignore
+ code: str
+ client_id: str
+ redirect_uri: str
+ response_type: str
+ scope: str
+ auth_time: int
+ expires_in: int
+ code_challenge: Optional[str]
+ code_challenge_method: Optional[str]
+ nonce: Optional[str]
+
+ user_id: UUID4 = Field(foreign_key="users.id", nullable=False)
+ user: "User" = Relationship(back_populates="user_authorization_codes")
+
+
+class Token(BaseTable, table=True): # type: ignore
+ access_token: str
+ refresh_token: str
+ scope: str
+ issued_at: int
+ expires_in: int
+ refresh_token_expires_in: int
+ client_id: str
+ token_type: str
+ revoked: bool
+
+ user_id: UUID4 = Field(foreign_key="users.id", nullable=False)
+ user: "User" = Relationship(back_populates="user_tokens")
diff --git a/src/oauth2/storage.py b/src/oauth2/storage.py
new file mode 100644
index 0000000..d008a2a
--- /dev/null
+++ b/src/oauth2/storage.py
@@ -0,0 +1,289 @@
+from datetime import datetime, timezone
+from typing import Optional
+
+from aioauth.models import AuthorizationCode, Client, Token
+from aioauth.requests import Request
+from aioauth.storage import BaseStorage
+from aioauth.types import CodeChallengeMethod, ResponseType, TokenType
+from aioauth.utils import enforce_list
+from sqlalchemy.future import select
+from sqlalchemy.orm import selectinload
+from sqlalchemy.sql.expression import delete
+
+from .models import AuthorizationCode as AuthorizationCodeDB
+from .models import Client as ClientDB
+from .models import Token as TokenDB
+from ..config import settings
+from ..storage.sqlalchemy import SQLAlchemyStorage
+from ..users.crypto import encode_jwt, get_jwt
+from ..users.models import User
+
+
+class Storage(BaseStorage):
+ def __init__(self, storage: SQLAlchemyStorage):
+ self.storage = storage
+
+ async def get_user(self, request: Request):
+ user: Optional[User] = None
+
+ if request.query.response_type == "token":
+ # If ResponseType is token get the user from current session
+ user = request.user
+
+ if request.post.grant_type == "authorization_code":
+ # If GrantType is authorization code get user from DB by code
+ q_results = await self.storage.select(
+ select(AuthorizationCodeDB).where(
+ AuthorizationCodeDB.code == request.post.code
+ )
+ )
+
+ authorization_code: Optional[AuthorizationCodeDB]
+ authorization_code = q_results.scalars().one_or_none()
+
+ if not authorization_code:
+ return
+
+ q_results = await self.storage.select(
+ select(User).where(User.id == authorization_code.user_id)
+ )
+
+ user = q_results.scalars().one_or_none()
+
+ if request.post.grant_type == "refresh_token":
+ # Get user from token
+ q_results = await self.storage.select(
+ select(TokenDB)
+ .where(TokenDB.refresh_token == request.post.refresh_token)
+ .options(selectinload(TokenDB.user))
+ )
+
+ token: Optional[TokenDB]
+
+ token = q_results.scalars().one_or_none()
+
+ if not token:
+ return
+
+ user = token.user
+
+ return user
+
+ async def create_token(
+ self,
+ request: Request,
+ client_id: str,
+ scope: str,
+ access_token: str,
+ refresh_token: str,
+ ) -> Token:
+ """
+ Create token and store it in storage.
+ """
+ user = await self.get_user(request)
+
+ _access_token, _refresh_token = get_jwt(user)
+
+ token = Token(
+ access_token=_access_token,
+ client_id=client_id,
+ expires_in=300,
+ issued_at=int(datetime.now(tz=timezone.utc).timestamp()),
+ refresh_token=_refresh_token,
+ refresh_token_expires_in=900,
+ revoked=False,
+ scope=scope,
+ token_type="Bearer",
+ user=user,
+ )
+
+ token_record = TokenDB(
+ access_token=token.access_token,
+ refresh_token=token.refresh_token,
+ scope=token.scope,
+ issued_at=token.issued_at,
+ expires_in=token.expires_in,
+ refresh_token_expires_in=token.refresh_token_expires_in,
+ client_id=token.client_id,
+ token_type=token.token_type,
+ revoked=token.revoked,
+ user_id=user.id,
+ )
+
+ await self.storage.add(token_record)
+
+ return token
+
+ async def revoke_token(
+ self,
+ request: Request,
+ token_type: Optional[TokenType] = "refresh_token",
+ access_token: Optional[str] = None,
+ refresh_token: Optional[str] = None,
+ ) -> None:
+ """
+ Remove refresh_token from whitelist.
+ """
+ q_results = await self.storage.select(
+ select(TokenDB).where(TokenDB.refresh_token == refresh_token)
+ )
+ token_record: Optional[TokenDB]
+ token_record = q_results.scalars().one_or_none()
+
+ if token_record:
+ token_record.revoked = True
+ await self.storage.add(token_record)
+
+ async def get_token(
+ self,
+ request: Request,
+ client_id: str,
+ token_type: Optional[str] = "refresh_token",
+ access_token: Optional[str] = None,
+ refresh_token: Optional[str] = None,
+ ) -> Optional[Token]:
+ if token_type == "refresh_token":
+ q = select(TokenDB).where(TokenDB.refresh_token == refresh_token)
+ else:
+ q = select(TokenDB).where(TokenDB.access_token == access_token)
+
+ q_results = await self.storage.select(
+ q.where(TokenDB.revoked == False).options( # noqa
+ selectinload(TokenDB.user)
+ )
+ )
+
+ token_record: Optional[TokenDB]
+ token_record = q_results.scalars().one_or_none()
+
+ if token_record:
+ return Token(
+ access_token=token_record.access_token,
+ refresh_token=token_record.refresh_token,
+ scope=token_record.scope,
+ issued_at=token_record.issued_at,
+ expires_in=token_record.expires_in,
+ refresh_token_expires_in=token_record.refresh_token_expires_in,
+ client_id=client_id,
+ )
+
+ async def create_authorization_code(
+ self,
+ request: Request,
+ client_id: str,
+ scope: str,
+ response_type: ResponseType,
+ redirect_uri: str,
+ code_challenge_method: Optional[CodeChallengeMethod],
+ code_challenge: Optional[str],
+ code: str,
+ **kwargs,
+ ) -> AuthorizationCode:
+ authorization_code = AuthorizationCode(
+ auth_time=int(datetime.now(tz=timezone.utc).timestamp()),
+ client_id=client_id,
+ code=code,
+ code_challenge=code_challenge,
+ code_challenge_method=code_challenge_method,
+ expires_in=300,
+ redirect_uri=redirect_uri,
+ response_type=response_type,
+ scope=scope,
+ user=request.user,
+ )
+
+ authorization_code_record = AuthorizationCodeDB(
+ code=authorization_code.code,
+ client_id=authorization_code.client_id,
+ redirect_uri=authorization_code.redirect_uri,
+ response_type=authorization_code.response_type,
+ scope=authorization_code.scope,
+ auth_time=authorization_code.auth_time,
+ expires_in=authorization_code.expires_in,
+ code_challenge_method=authorization_code.code_challenge_method,
+ code_challenge=authorization_code.code_challenge,
+ nonce=authorization_code.nonce,
+ user_id=request.user.id,
+ )
+
+ await self.storage.add(authorization_code_record)
+
+ return authorization_code
+
+ async def get_client(
+ self, request: Request, client_id: str, client_secret: Optional[str] = None
+ ) -> Optional[Client]:
+ q_results = await self.storage.select(
+ select(ClientDB).where(ClientDB.client_id == client_id)
+ )
+
+ client_record: Optional[ClientDB]
+ client_record = q_results.scalars().one_or_none()
+
+ if not client_record:
+ return None
+
+ return Client(
+ client_id=client_record.client_id,
+ client_secret=client_record.client_secret,
+ grant_types=[client_record.grant_types],
+ response_types=[client_record.response_types],
+ redirect_uris=[client_record.redirect_uris],
+ scope=client_record.scope,
+ )
+
+ async def get_authorization_code(
+ self, request: Request, client_id: str, code: str
+ ) -> Optional[AuthorizationCode]:
+ q_results = await self.storage.select(
+ select(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code)
+ )
+
+ authorization_code_record: Optional[AuthorizationCode]
+ authorization_code_record = q_results.scalars().one_or_none()
+
+ if not authorization_code_record:
+ return None
+
+ return AuthorizationCode(
+ code=authorization_code_record.code,
+ client_id=authorization_code_record.client_id,
+ redirect_uri=authorization_code_record.redirect_uri,
+ response_type=authorization_code_record.response_type,
+ scope=authorization_code_record.scope,
+ auth_time=authorization_code_record.auth_time,
+ expires_in=authorization_code_record.expires_in,
+ code_challenge=authorization_code_record.code_challenge,
+ code_challenge_method=authorization_code_record.code_challenge_method,
+ nonce=authorization_code_record.nonce,
+ )
+
+ async def delete_authorization_code(
+ self, request: Request, client_id: str, code: str
+ ) -> None:
+ await self.storage.delete(
+ delete(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code)
+ )
+
+ async def get_id_token(
+ self,
+ request: Request,
+ client_id: str,
+ scope: str,
+ response_type: ResponseType,
+ redirect_uri: str,
+ **kwargs,
+ ) -> str:
+ scopes = enforce_list(scope)
+ user = await self.get_user(request)
+ user_data = {}
+
+ if "email" in scopes:
+ user_data["email"] = user.username
+ user_data["username"] = user.username
+
+ return encode_jwt(
+ expires_delta=settings.ACCESS_TOKEN_EXP,
+ sub=str(user.id),
+ additional_claims=user_data,
+ )
diff --git a/src/scheduler.py b/src/scheduler.py
new file mode 100644
index 0000000..ad08020
--- /dev/null
+++ b/src/scheduler.py
@@ -0,0 +1,33 @@
+import contextlib
+import datetime
+from typing import TYPE_CHECKING
+
+import pytz
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+
+if TYPE_CHECKING:
+ from src.telegram.enums import Message
+
+scheduler = AsyncIOScheduler(timezone="Asia/ShangHai")
+if not scheduler.running:
+ scheduler.start()
+
+
+async def delete_message(message: "Message") -> bool:
+ with contextlib.suppress(Exception):
+ await message.delete()
+ return True
+ return False
+
+
+def add_delete_message_job(message: "Message", delete_seconds: int = 60):
+ scheduler.add_job(
+ delete_message,
+ "date",
+ id=f"{message.chat.id}|{message.id}|delete_message",
+ name=f"{message.chat.id}|{message.id}|delete_message",
+ args=[message],
+ run_date=datetime.datetime.now(pytz.timezone("Asia/Shanghai"))
+ + datetime.timedelta(seconds=delete_seconds),
+ replace_existing=True,
+ )
diff --git a/src/storage/__init__.py b/src/storage/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/storage/models.py b/src/storage/models.py
new file mode 100644
index 0000000..62108e5
--- /dev/null
+++ b/src/storage/models.py
@@ -0,0 +1,14 @@
+import uuid
+
+from pydantic.types import UUID4
+from sqlmodel import Field, SQLModel
+
+
+class BaseTable(SQLModel):
+ id: UUID4 = Field(
+ primary_key=True,
+ default_factory=uuid.uuid4,
+ nullable=False,
+ index=True,
+ sa_column_kwargs={"unique": True},
+ )
diff --git a/src/storage/sqlalchemy.py b/src/storage/sqlalchemy.py
new file mode 100644
index 0000000..205405c
--- /dev/null
+++ b/src/storage/sqlalchemy.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+from sqlalchemy.engine.result import Result
+from sqlalchemy.sql.expression import Delete, Update
+from sqlalchemy.sql.selectable import Select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+
+class SQLAlchemyTransaction:
+ def __init__(self, session: AsyncSession) -> None:
+ self.session = session
+
+ async def __aenter__(self) -> "SQLAlchemyTransaction":
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback) -> None:
+ if exc_type is None:
+ await self.commit()
+ else:
+ await self.rollback()
+
+ await self.close()
+
+ async def rollback(self):
+ await self.session.rollback()
+
+ async def commit(self):
+ await self.session.commit()
+
+ async def close(self):
+ await self.session.close()
+
+
+class SQLAlchemyStorage:
+ def __init__(
+ self, session: AsyncSession, transaction: SQLAlchemyTransaction
+ ) -> None:
+ self.session = session
+ self.transaction = transaction
+
+ async def select(self, q: Select) -> Result:
+ async with self.transaction:
+ return await self.session.execute(q)
+
+ async def add(self, model) -> None:
+ async with self.transaction:
+ self.session.add(model)
+
+ async def delete(self, q: Delete) -> None:
+ async with self.transaction:
+ await self.session.execute(q)
+
+ async def update(self, q: Update):
+ async with self.transaction:
+ await self.session.execute(q)
+
+
+sqlalchemy_session: Optional[AsyncSession] = None
+
+
+def get_sqlalchemy_storage() -> SQLAlchemyStorage:
+ """Get SQLAlchemy storage instance.
+
+ Returns:
+ SQLAlchemyStorage: SQLAlchemy storage instance
+ """
+ sqllachemy_trancation = SQLAlchemyTransaction(session=sqlalchemy_session)
+ return SQLAlchemyStorage(
+ session=sqlalchemy_session, transaction=sqllachemy_trancation
+ )
diff --git a/src/telegram/__init__.py b/src/telegram/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/telegram/enums.py b/src/telegram/enums.py
new file mode 100644
index 0000000..2cfe36c
--- /dev/null
+++ b/src/telegram/enums.py
@@ -0,0 +1,25 @@
+from typing import Optional
+
+from pyrogram import Client as PyroClient
+from pyrogram.types import Message as PyroMessage
+
+
+class Client(PyroClient): # noqa
+ async def listen(self, chat_id, filters=None, timeout=None) -> Optional["Message"]:
+ return
+
+ async def ask(
+ self, chat_id, text, filters=None, timeout=None, *args, **kwargs
+ ) -> Optional["Message"]:
+ return
+
+ def cancel_listener(self, chat_id):
+ """Cancel the conversation with the given chat_id."""
+
+
+class Message(PyroMessage): # noqa
+ async def delay_delete(self, delete_seconds: int = 60) -> Optional[bool]:
+ return
+
+ async def safe_delete(self, revoke: bool = True) -> None:
+ return
diff --git a/src/telegram/filters.py b/src/telegram/filters.py
new file mode 100644
index 0000000..aa89b8f
--- /dev/null
+++ b/src/telegram/filters.py
@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+
+from pyrogram.filters import create
+
+from src.config import settings
+
+if TYPE_CHECKING:
+ from .enums import Message
+
+
+async def admin_filter(_, __, m: "Message"):
+ return bool(m.from_user and m.from_user.id in settings.BOT_MANAGER_IDS)
+
+
+admin = create(admin_filter)
diff --git a/src/telegram/message.py b/src/telegram/message.py
new file mode 100644
index 0000000..2b2952c
--- /dev/null
+++ b/src/telegram/message.py
@@ -0,0 +1,8 @@
+import re
+
+NO_ACCOUNT_MSG = """UID `%s` 还没有注册账号,请联系管理员注册账号。"""
+ACCOUNT_MSG = """UID: `%s`\n邮箱: `%s`"""
+REG_MSG = """请发送需要使用的邮箱"""
+MAIL_REGEX = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
+LOGIN_MSG = """请点击下面的按钮登录:"""
+LOGIN_BUTTON = """跳转登录"""
diff --git a/src/telegram/plugins/__init__.py b/src/telegram/plugins/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/telegram/plugins/account.py b/src/telegram/plugins/account.py
new file mode 100644
index 0000000..24e1645
--- /dev/null
+++ b/src/telegram/plugins/account.py
@@ -0,0 +1,24 @@
+from pyrogram import filters
+
+from src.bot import bot
+from src.config import settings
+from src.telegram.enums import Client, Message
+from src.telegram.message import ACCOUNT_MSG, NO_ACCOUNT_MSG
+from src.users.crud import get_user_crud
+
+
+async def account(message: Message, uid: int):
+ crud = get_user_crud()
+ user = await crud.get_by_tg_id(uid)
+ if user:
+ await message.reply(ACCOUNT_MSG % (user.tg_id, user.username), quote=True)
+ else:
+ await message.reply(NO_ACCOUNT_MSG % uid, quote=True)
+
+
+@bot.on_message(filters=filters.private & filters.command("account"))
+async def get_account(_: Client, message: Message):
+ uid = message.from_user.id
+ if uid in settings.BOT_MANAGER_IDS and len(message.command) >= 2 and message.command[1].isnumeric():
+ uid = int(message.command[1])
+ await account(message, uid)
diff --git a/src/telegram/plugins/edit.py b/src/telegram/plugins/edit.py
new file mode 100644
index 0000000..acc4e34
--- /dev/null
+++ b/src/telegram/plugins/edit.py
@@ -0,0 +1,44 @@
+from pyrogram import filters
+
+from pyromod.utils.errors import TimeoutConversationError
+from src.bot import bot
+from src.logs import logs
+from src.telegram.enums import Client, Message
+from src.telegram.filters import admin
+from src.telegram.message import REG_MSG, MAIL_REGEX
+from src.users.crud import get_user_crud
+
+
+async def reg(client: Client, from_id: int, uid: int):
+ msg_ = await client.send_message(from_id, REG_MSG)
+ try:
+ msg = await client.listen(from_id, filters=filters.text, timeout=60)
+ except TimeoutConversationError:
+ await msg_.edit("响应超时,请重试")
+ return
+ if msg.text and MAIL_REGEX.match(msg.text):
+ crud = get_user_crud()
+ try:
+ user = await crud.get_by_tg_id(uid)
+ if user:
+ await crud.update(user, username=msg.text)
+ else:
+ await crud.create(
+ username=msg.text,
+ password="1",
+ tg_id=uid,
+ )
+ except Exception as e:
+ logs.exception("注册失败", exc_info=e)
+ await msg.reply_text("注册失败")
+ await msg.reply_text("注册成功")
+ else:
+ await msg.reply_text("邮箱格式错误")
+
+
+@bot.on_message(filters=filters.private & filters.command("edit") & admin)
+async def edit_account(client: Client, message: Message):
+ uid = from_id = message.from_user.id
+ if len(message.command) >= 2 and message.command[1].isnumeric():
+ uid = int(message.command[1])
+ await reg(client, from_id, uid)
diff --git a/src/telegram/plugins/start.py b/src/telegram/plugins/start.py
new file mode 100644
index 0000000..090aea7
--- /dev/null
+++ b/src/telegram/plugins/start.py
@@ -0,0 +1,41 @@
+from httpx import URL
+from pyrogram import filters, Client
+from pyrogram.types import InlineKeyboardMarkup, InlineKeyboardButton
+
+from src.bot import bot
+from src.config import settings
+from src.telegram.enums import Message
+from src.telegram.message import NO_ACCOUNT_MSG, LOGIN_MSG, LOGIN_BUTTON
+from src.users.crud import get_user_crud
+from src.utils.telegram import encode_telegram_auth_data
+
+
+async def login(message: Message):
+ uid = message.from_user.id
+ crud = get_user_crud()
+ user = await crud.get_by_tg_id(uid)
+ if not user:
+ await message.reply(NO_ACCOUNT_MSG % uid, quote=True)
+ return
+ token = await encode_telegram_auth_data(uid)
+ url = settings.PROJECT_URL + "/api/users/auth"
+ url = URL(url).copy_add_param("jwt", token)
+ url = str(url)
+ await message.reply(
+ LOGIN_MSG,
+ quote=True,
+ reply_markup=InlineKeyboardMarkup(
+ [[InlineKeyboardButton(LOGIN_BUTTON, url=url)]]
+ ),
+ )
+
+
+@bot.on_message(filters=filters.private & filters.command("start"))
+async def start(client: Client, message: Message):
+ if message.command and len(message.command) >= 2:
+ action = message.command[1]
+ if action == "login":
+ await login(message)
+ return
+ me = await client.get_me()
+ await message.reply(f"Hello, I'm {me.first_name}. ")
diff --git a/src/users/__init__.py b/src/users/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/users/backends.py b/src/users/backends.py
new file mode 100644
index 0000000..8fee728
--- /dev/null
+++ b/src/users/backends.py
@@ -0,0 +1,32 @@
+from fastapi.security.utils import get_authorization_scheme_param
+from starlette.authentication import AuthCredentials, AuthenticationBackend
+
+from .crypto import authenticate, read_rsa_key_from_env
+from .models import User, UserAnonymous
+from ..config import settings
+
+
+class TokenAuthenticationBackend(AuthenticationBackend):
+ async def authenticate(self, request):
+ authorization: str = request.headers.get("Authorization")
+ _, bearer_token = get_authorization_scheme_param(authorization)
+
+ token: str = request.cookies.get("access_token") or bearer_token
+
+ if not token:
+ return AuthCredentials(), UserAnonymous()
+
+ key = read_rsa_key_from_env(settings.JWT_PUBLIC_KEY)
+
+ is_authenticated, decoded_token = authenticate(token=token, key=key)
+
+ if is_authenticated:
+ return AuthCredentials(), User(
+ id=decoded_token["sub"],
+ is_superuser=decoded_token["is_superuser"],
+ is_blocked=decoded_token["is_blocked"],
+ is_active=decoded_token["is_active"],
+ username=decoded_token["username"],
+ )
+
+ return AuthCredentials(), UserAnonymous()
diff --git a/src/users/crud.py b/src/users/crud.py
new file mode 100644
index 0000000..b3a61f6
--- /dev/null
+++ b/src/users/crud.py
@@ -0,0 +1,40 @@
+from typing import Optional
+
+from sqlalchemy.future import select
+from sqlalchemy.orm import selectinload
+from sqlalchemy.sql.expression import Update
+
+from .models import User
+from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
+
+
+class SQLAlchemyCRUD:
+ def __init__(self, storage: SQLAlchemyStorage):
+ self.storage = storage
+
+ async def get_by_tg_id(self, tg_id: int) -> Optional[User]:
+ q_results = await self.storage.select(
+ select(User)
+ .options(
+ # for relationship loading, eager loading should be applied.
+ selectinload(User.user_tokens)
+ )
+ .where(User.tg_id == tg_id)
+ )
+
+ return q_results.scalars().one_or_none()
+
+ async def create(self, **kwargs) -> None:
+ user = User(**kwargs)
+ await self.storage.add(user)
+
+ async def update(self, user: User, **kwargs) -> None:
+ await self.storage.update(
+ Update(User).where(User.id == user.id).values(**kwargs)
+ )
+
+
+def get_user_crud(storage: SQLAlchemyStorage = None) -> SQLAlchemyCRUD:
+ if storage is None:
+ storage = get_sqlalchemy_storage()
+ return SQLAlchemyCRUD(storage=storage)
diff --git a/src/users/crypto.py b/src/users/crypto.py
new file mode 100644
index 0000000..84b0724
--- /dev/null
+++ b/src/users/crypto.py
@@ -0,0 +1,166 @@
+import base64
+import pathlib
+import re
+import string
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import Dict, Tuple
+
+from Crypto.PublicKey import RSA
+from jose import constants, jwt
+from jose.exceptions import JWTError
+
+from ..config import settings
+
+RANDOM_STRING_CHARS = string.ascii_lowercase + string.ascii_uppercase + string.digits
+KEYS = {}
+
+
+def reformat_rsa_key(rsa_key: str) -> str:
+ """Reformat an RSA PEM key without newlines to one with correct newline characters
+
+ @param rsa_key: the PEM RSA key lacking newline characters
+ @return: the reformatted PEM RSA key with appropriate newline characters
+ """
+ # split headers from the body
+ split_rsa_key = re.split(r"(-+)", rsa_key)
+
+ # add newlines between headers and body
+ split_rsa_key.insert(4, "\n")
+ split_rsa_key.insert(6, "\n")
+
+ reformatted_rsa_key = "".join(split_rsa_key)
+
+ # reformat body
+ return RSA.importKey(reformatted_rsa_key).exportKey().decode("utf-8")
+
+
+def read_rsa_key_from_env(file_path: str) -> str:
+ if file_path in KEYS:
+ return KEYS[file_path]
+ path = pathlib.Path(file_path)
+
+ # path to rsa key file
+ if path.is_file():
+ with open(file_path, "rb") as key_file:
+ jwt_private_key = RSA.importKey(key_file.read()).exportKey()
+ k = jwt_private_key.decode("utf-8")
+ KEYS[file_path] = k
+ return k
+
+ # rsa key without newlines
+ if "\n" not in file_path:
+ k = reformat_rsa_key(file_path)
+ KEYS[file_path] = k
+ return k
+
+ return file_path
+
+
+def get_n(rsa: RSA):
+ bytes_data = rsa.n.to_bytes((rsa.n.bit_length() + 7) // 8, 'big')
+ return base64.urlsafe_b64encode(bytes_data).decode('utf-8')
+
+
+def get_pub_key_resp():
+ pub_key = RSA.importKey(read_rsa_key_from_env(settings.JWT_PUBLIC_KEY))
+ return {
+ "keys": [
+ {
+ "n": get_n(pub_key),
+ "kty": "RSA",
+ "alg": "RS256",
+ "kid": "sig",
+ "e": "AQAB",
+ "use": "sig"
+ }
+ ]
+ }
+
+
+def encode_jwt(
+ expires_delta,
+ sub,
+ secret=None,
+ additional_claims=None,
+ algorithm=constants.ALGORITHMS.RS256,
+):
+ if additional_claims is None:
+ additional_claims = {}
+ if secret is None:
+ secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY)
+ now = datetime.now(timezone.utc)
+
+ claims = {
+ "iat": now,
+ "jti": str(uuid.uuid4()),
+ "nbf": now,
+ "sub": sub,
+ "exp": now + timedelta(seconds=expires_delta),
+ **additional_claims,
+ }
+
+ return jwt.encode(
+ claims,
+ secret,
+ algorithm,
+ )
+
+
+def decode_jwt(
+ encoded_token,
+ secret=None,
+ algorithms=None,
+):
+ if algorithms is None:
+ algorithms = constants.ALGORITHMS.RS256
+ if secret is None:
+ secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY)
+ return jwt.decode(
+ encoded_token,
+ secret,
+ algorithms=algorithms,
+ )
+
+
+def get_jwt(user):
+ access_token = encode_jwt(
+ sub=str(user.id),
+ expires_delta=settings.ACCESS_TOKEN_EXP,
+ additional_claims={
+ "token_type": "access",
+ "is_blocked": user.is_blocked,
+ "is_superuser": user.is_superuser,
+ "username": user.username,
+ "is_active": user.is_active,
+ },
+ )
+
+ refresh_token = encode_jwt(
+ sub=str(user.id),
+ expires_delta=settings.REFRESH_TOKEN_EXP,
+ additional_claims={
+ "token_type": "refresh",
+ "is_blocked": user.is_blocked,
+ "is_superuser": user.is_superuser,
+ "username": user.username,
+ "is_active": user.is_active,
+ },
+ )
+
+ return access_token, refresh_token
+
+
+def authenticate(
+ *,
+ token: str,
+ key: str,
+) -> Tuple[bool, Dict]:
+ """Authenticate user by token"""
+ try:
+ token_header = jwt.get_unverified_header(token)
+ decoded_token = jwt.decode(token, key, algorithms=token_header.get("alg"))
+ except JWTError:
+ return False, {}
+ else:
+ return True, decoded_token
diff --git a/src/users/endpoints.py b/src/users/endpoints.py
new file mode 100644
index 0000000..c4e4a61
--- /dev/null
+++ b/src/users/endpoints.py
@@ -0,0 +1,79 @@
+from http import HTTPStatus
+
+from fastapi import APIRouter, Depends, HTTPException
+from jose import JWTError
+from starlette.requests import Request
+
+from .crud import SQLAlchemyCRUD
+from .crypto import get_jwt
+from ..config import settings
+from ..html import templates
+from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
+from ..utils.oauth import back_auth_request
+from ..utils.redirect import RedirectResponseBuilder
+from ..utils.telegram import decode_telegram_auth_data, verify_telegram_auth_data
+
+router = APIRouter()
+
+
+@router.get("/login", name="users:login:get")
+async def user_login_get(request: Request):
+ if request.user.is_authenticated:
+ if resp := await back_auth_request(request):
+ return resp
+ return RedirectResponseBuilder().build(settings.PROJECT_LOGIN_SUCCESS_URL)
+ url = request.url
+ callback_url = str(url).replace("/login", "/callback")
+ return templates.TemplateResponse(
+ "login.jinja",
+ {"request": request, "callback_url": callback_url, "username": settings.BOT_USERNAME}
+ )
+
+
+async def auth(
+ tg_id: int,
+ request: Request,
+ storage: SQLAlchemyStorage,
+):
+ if tg_id is None:
+ raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
+ crud = SQLAlchemyCRUD(storage=storage)
+ user = await crud.get_by_tg_id(tg_id=tg_id)
+
+ if user is None:
+ raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
+
+ access_token, refresh_token = get_jwt(user)
+ # NOTE: Setting expire causes an exception for requests library:
+ # https://github.com/psf/requests/issues/6004
+ if resp := await back_auth_request(request, access_token, refresh_token):
+ return resp
+ resp = RedirectResponseBuilder()
+ resp.set_cookie(
+ key="access_token", value=access_token, max_age=settings.ACCESS_TOKEN_EXP
+ )
+ resp.set_cookie(
+ key="refresh_token", value=refresh_token, max_age=settings.REFRESH_TOKEN_EXP
+ )
+ return resp.build(settings.PROJECT_LOGIN_SUCCESS_URL)
+
+
+@router.get("/callback", name="users:login")
+async def user_login(
+ request: Request,
+ storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
+):
+ tg_id = await verify_telegram_auth_data(request.query_params)
+ return await auth(tg_id, request, storage)
+
+
+@router.get("/auth", name="users:auth")
+async def user_auth(
+ request: Request,
+ storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
+):
+ try:
+ tg_id = await decode_telegram_auth_data(request.query_params)
+ except JWTError:
+ tg_id = None
+ return await auth(tg_id, request, storage)
diff --git a/src/users/models.py b/src/users/models.py
new file mode 100644
index 0000000..daf7643
--- /dev/null
+++ b/src/users/models.py
@@ -0,0 +1,37 @@
+from typing import TYPE_CHECKING, List, Optional
+
+from pydantic import BaseModel
+from sqlalchemy import Column, BigInteger
+from sqlmodel.main import Field, Relationship
+
+from ..storage.models import BaseTable
+
+if TYPE_CHECKING: # pragma: no cover
+ from ..oauth2.models import AuthorizationCode, Token
+
+
+class UserAnonymous(BaseModel):
+ @property
+ def is_authenticated(self) -> bool:
+ return False
+
+
+class User(BaseTable, table=True): # type: ignore
+ __tablename__ = "users"
+
+ is_superuser: bool = False
+ is_blocked: bool = False
+ is_active: bool = False
+
+ username: str = Field(nullable=False, sa_column_kwargs={"unique": True}, index=True)
+ password: Optional[str] = None
+ tg_id: int = Field(sa_column=Column(BigInteger(), nullable=False))
+
+ user_authorization_codes: List["AuthorizationCode"] = Relationship(
+ back_populates="user"
+ )
+ user_tokens: List["Token"] = Relationship(back_populates="user")
+
+ @property
+ def is_authenticated(self) -> bool:
+ return True
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/oauth.py b/src/utils/oauth.py
new file mode 100644
index 0000000..bafe9b2
--- /dev/null
+++ b/src/utils/oauth.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from starlette.requests import Request
+from starlette.responses import RedirectResponse
+
+from src.config import settings
+from src.users.crypto import encode_jwt, decode_jwt
+from src.utils.redirect import RedirectResponseBuilder
+
+
+async def to_login_request(request: Request) -> RedirectResponse:
+ query_params = dict(request.query_params)
+ params = ""
+ for key, value in query_params.items():
+ params += f"{key}={value}&"
+ params = params[:-1]
+ jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, "", additional_claims={"params": params})
+ resp = RedirectResponseBuilder()
+ resp.set_cookie("SEND", jwt, max_age=settings.ACCESS_TOKEN_EXP)
+ return resp.build("/api/users/login")
+
+
+async def back_auth_request(
+ request: Request,
+ access_token: str = None,
+ refresh_token: str = None,
+) -> Optional[RedirectResponse]:
+ cookie = request.cookies.get("SEND")
+ if cookie is None:
+ return None
+ params = decode_jwt(cookie)["params"]
+ resp = RedirectResponseBuilder()
+ if access_token:
+ resp.set_cookie("access_token", access_token, max_age=settings.ACCESS_TOKEN_EXP)
+ if refresh_token:
+ resp.set_cookie("refresh_token", refresh_token, max_age=settings.ACCESS_TOKEN_EXP)
+ resp.delete_cookie("SEND")
+ return resp.build(f"/oauth2/authorize?{params}", status_code=303)
diff --git a/src/utils/redirect.py b/src/utils/redirect.py
new file mode 100644
index 0000000..10af040
--- /dev/null
+++ b/src/utils/redirect.py
@@ -0,0 +1,78 @@
+import http.cookies
+import typing
+from datetime import datetime
+from email.utils import format_datetime
+
+from starlette.datastructures import MutableHeaders
+from starlette.responses import RedirectResponse
+
+
+class RedirectResponseBuilder:
+ def __init__(self):
+ self.raw_headers = []
+
+ def set_cookie(
+ self,
+ key: str,
+ value: str = "",
+ max_age: typing.Optional[int] = None,
+ expires: typing.Optional[typing.Union[datetime, str, int]] = None,
+ path: str = "/",
+ domain: typing.Optional[str] = None,
+ secure: bool = False,
+ httponly: bool = False,
+ samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax",
+ ) -> None:
+ cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
+ cookie[key] = value
+ if max_age is not None:
+ cookie[key]["max-age"] = max_age
+ if expires is not None:
+ if isinstance(expires, datetime):
+ cookie[key]["expires"] = format_datetime(expires, usegmt=True)
+ else:
+ cookie[key]["expires"] = expires
+ if path is not None:
+ cookie[key]["path"] = path
+ if domain is not None:
+ cookie[key]["domain"] = domain
+ if secure:
+ cookie[key]["secure"] = True
+ if httponly:
+ cookie[key]["httponly"] = True
+ if samesite is not None:
+ assert samesite.lower() in [
+ "strict",
+ "lax",
+ "none",
+ ], "samesite must be either 'strict', 'lax' or 'none'"
+ cookie[key]["samesite"] = samesite
+ cookie_val = cookie.output(header="").strip()
+ self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
+
+ def delete_cookie(
+ self,
+ key: str,
+ path: str = "/",
+ domain: typing.Optional[str] = None,
+ secure: bool = False,
+ httponly: bool = False,
+ samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax",
+ ) -> None:
+ self.set_cookie(
+ key,
+ max_age=0,
+ expires=0,
+ path=path,
+ domain=domain,
+ secure=secure,
+ httponly=httponly,
+ samesite=samesite,
+ )
+
+ @property
+ def headers(self) -> MutableHeaders:
+ return MutableHeaders(raw=self.raw_headers)
+
+ def build(self, url: str, status_code: int = 307):
+ return RedirectResponse(url, headers=self.headers, status_code=status_code)
diff --git a/src/utils/telegram.py b/src/utils/telegram.py
new file mode 100644
index 0000000..f5c0ddf
--- /dev/null
+++ b/src/utils/telegram.py
@@ -0,0 +1,45 @@
+import hashlib
+import hmac
+from datetime import datetime, timezone
+from typing import Optional
+
+from starlette.datastructures import QueryParams
+
+from src.config import settings
+from src.users.crypto import encode_jwt, decode_jwt
+
+
+async def verify_telegram_auth_data(params: QueryParams) -> Optional[int]:
+ data = list(params.items())
+ hash_str = ""
+ text_list = []
+ for key, value in data:
+ if key == "hash":
+ hash_str = value
+ else:
+ text_list.append(f"{key}={value}")
+ check_string = "\n".join(sorted(text_list))
+
+ secret_key = hashlib.sha256(str.encode(settings.BOT_TOKEN)).digest()
+ hmac_hash = hmac.new(secret_key, str.encode(check_string), hashlib.sha256).hexdigest()
+
+ return int(params.get("id")) if hmac_hash == hash_str else None
+
+
+async def encode_telegram_auth_data(uid: int) -> str:
+ jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, str(uid))
+ return jwt
+
+
+async def decode_telegram_auth_data(params: QueryParams) -> Optional[int]:
+ jwt = params.get("jwt")
+ if not jwt:
+ return None
+ if not jwt:
+ return None
+ data = decode_jwt(jwt)
+ now = datetime.now(timezone.utc)
+ uid, exp = data["sub"], data["exp"]
+ if exp < now.timestamp():
+ return None
+ return int(uid)