Skip to content

Commit

Permalink
attributes working
Browse files Browse the repository at this point in the history
  • Loading branch information
tjdevries committed Mar 1, 2024
1 parent 2398774 commit 96f34ae
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 14 deletions.
41 changes: 40 additions & 1 deletion derive/attributes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,47 @@ type field_attributes = {
| `never ];
}

let of_constructor_attributes (_attributes : attributes) =
(* deserialize_ignored_any *)
let of_record_attributes attributes =
let deny_unknown_fields = ref false in
let serde_attr =
List.find_opt
(function { attr_name = { txt = "serde"; _ }; _ } -> true | _ -> false)
attributes
in
Option.iter
(fun attr ->
match attr.attr_payload with
| PStr
[
{
pstr_desc =
Pstr_eval ({ pexp_desc = Pexp_record (_fields, _); _ }, _);
_;
};
] ->
(* if _fields <> [] then failwith (Ppxli"Unknown attribute payload"; *)
List.iter
(function
| ( { txt = Lident "deny_unknown_fields"; _ },
{
pexp_desc = Pexp_construct ({ txt = Lident "true"; _ }, _);
_;
} ) ->
deny_unknown_fields := true
| ( { txt = Lident "deny_unknown_fields"; _ },
{
pexp_desc = Pexp_construct ({ txt = Lident "false"; _ }, _);
_;
} ) ->
deny_unknown_fields := false
| { txt = Lident txt; _ }, _ ->
failwith
(Format.sprintf "[ppx_serde] Unknown attribute %S" txt)
| _ -> ())
_fields
| _ -> ())
serde_attr;
{
rename = "";
mode = `normal;
Expand Down
41 changes: 29 additions & 12 deletions derive/de.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ module Record_deserializer = struct
)
]}
*)
let deserialize_with_unordered_fields ~ctxt labels final_expr =
let deserialize_with_unordered_fields ~ctxt type_attributes labels final_expr
=
let open Attributes in
let _ = type_attributes.deny_unknown_fields in
let loc = loc ~ctxt in
let labels = List.rev labels in
let labels = List.map Attributes.of_field_attributes labels in
Expand Down Expand Up @@ -241,6 +244,15 @@ module Record_deserializer = struct
*)
let field_visitor next =
let visit_string =
let invalid_tag_case =
let rhs =
match type_attributes.deny_unknown_fields with
| true -> [%expr Error `invalid_tag]
| false -> [%expr Ok `invalid_tag]
in
Ast.case ~lhs:(Ast.ppat_any ~loc) ~guard:None ~rhs
in

let cases =
List.map
(fun (field, attr) ->
Expand All @@ -251,10 +263,7 @@ module Record_deserializer = struct
in
Ast.case ~lhs ~rhs ~guard:None)
labels
@ [
Ast.case ~lhs:(Ast.ppat_any ~loc) ~guard:None
~rhs:[%expr Ok `invalid_tag];
]
@ [ invalid_tag_case ]
in
let body = Ast.pexp_match ~loc [%expr str] cases in
[%expr fun _ctx str -> [%e body]]
Expand Down Expand Up @@ -365,7 +374,8 @@ module Record_deserializer = struct
@@ record_expr
end

let gen_deserialize_variant_impl ~ctxt ptype_name cstr_declarations =
let gen_deserialize_variant_impl ~ctxt ptype_name type_attributes
cstr_declarations =
let loc = loc ~ctxt in
let type_name = Ast.estring ~loc ptype_name.txt in
let constructor_names =
Expand All @@ -377,7 +387,6 @@ let gen_deserialize_variant_impl ~ctxt ptype_name cstr_declarations =
in

let deser_by_constructor _type_name idx cstr =
let _ = Attributes.of_constructor_attributes cstr.pcd_attributes in
let _idx = Ast.eint ~loc idx in
let name = Longident.parse cstr.pcd_name.txt |> var ~ctxt in
match cstr.pcd_args with
Expand Down Expand Up @@ -452,7 +461,8 @@ let gen_deserialize_variant_impl ~ctxt ptype_name cstr_declarations =
| Pcstr_record labels ->
let field_count = Ast.eint ~loc (List.length labels) in
let body =
Record_deserializer.deserialize_with_unordered_fields ~ctxt labels
Record_deserializer.deserialize_with_unordered_fields ~ctxt
type_attributes labels
@@ fun record ->
let cstr = Ast.pexp_construct ~loc name (Some record) in
[%expr Ok [%e cstr]]
Expand Down Expand Up @@ -510,13 +520,14 @@ let gen_deserialize_variant_impl ~ctxt ptype_name cstr_declarations =
See [Record_deserializer] above for more info.
*)
let gen_deserialize_record_impl ~ctxt ptype_name label_declarations =
let gen_deserialize_record_impl ~ctxt ptype_name type_attributes
label_declarations =
let loc = loc ~ctxt in
let type_name = Ast.estring ~loc ptype_name.txt in
let field_count = Ast.eint ~loc (List.length label_declarations) in

let body =
Record_deserializer.deserialize_with_unordered_fields ~ctxt
Record_deserializer.deserialize_with_unordered_fields ~ctxt type_attributes
label_declarations
@@ fun record -> [%expr Ok [%e record]]
in
Expand All @@ -531,12 +542,18 @@ let gen_deserialize_impl ~ctxt type_decl =

let typename = type_decl.ptype_name.txt in

let type_attributes =
Attributes.of_record_attributes type_decl.ptype_attributes
in

let body =
match type_decl with
| { ptype_kind = Ptype_record label_declarations; ptype_name; _ } ->
gen_deserialize_record_impl ~ctxt ptype_name label_declarations
gen_deserialize_record_impl ~ctxt ptype_name type_attributes
label_declarations
| { ptype_kind = Ptype_variant cstrs_declaration; ptype_name; _ } ->
gen_deserialize_variant_impl ~ctxt ptype_name cstrs_declaration
gen_deserialize_variant_impl ~ctxt ptype_name type_attributes
cstrs_declaration
| { ptype_kind; ptype_name; _ } ->
let err =
match ptype_kind with
Expand Down
19 changes: 18 additions & 1 deletion serde_json/serde_json_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ type with_type_field = { type_ : string [@serde { rename = "type" }] }
[@@deriving serialize, deserialize]

type with_unknown_keys = { known : string } [@@deriving serialize, deserialize]
(* type with_unknown_keys = { known : string } [@@deriving serialize, deserialize] *)

type deny_unknown_fields = { known : string }
[@@deriving serialize, deserialize] [@@serde { deny_unknown_fields = true }]

let pp_variant fmt A = Format.fprintf fmt "A"
let pp_variant_with_arg fmt (B i) = Format.fprintf fmt "(B %d)" i
Expand Down Expand Up @@ -536,3 +538,18 @@ let _serde_json_parse_with_unknown_keys =
"parsed with unknown keys"
(error "%a" Serde.pp_err err);
assert false

let _serde_json_parse_deny_unknown_fields =
let str = {| {"known":"yoyo", "unknown": true} |} in
let parsed = Serde_json.of_string deserialize_deny_unknown_fields str in
match parsed with
| Ok parsed ->
assert (parsed.known = "yoyo");
Format.printf "serde_json.ser/de test %S %s\r\n%!"
"parsed with deny unknown keys"
(error "Pased but should not have");
assert false
| Error err ->
Format.printf "serde_json.ser/de test %S %s\r\n%!"
"parsed with deny unknown keys"
(keyword "OK: (found %a)" Serde.pp_err err)

0 comments on commit 96f34ae

Please sign in to comment.